アラカン"BOKU"のITな日常

文系システムエンジニアの”BOKU”が勉強したこと、経験したこと、日々思うことを書いてます。

回帰問題にリベンジする。日別売上の予測っぽいことを再びやってみる/使い方32/Neural Network Console

 使い方18でLSTMユニットで「予測っぽいことをやってみる」記事を書きました。

arakan-pgm-ai.hatenablog.com

でも、タイトルに書き足した様に「問題あり」で終わってます。 

今回は、リベンジ(笑)編です。

 

前回の反省点を整理する

 

RNN(リカレントニューラルネットワーク)と通常のニューラルネットワークの違いは、ざっくり言えば隠れ層の扱いの違いです。 

RNNは例えば時刻(t)でのx(t)に加え、過去の時刻(t-1)におけるx(t-1)も保持しておいて、時刻(t)における隠れ層に伝える点が違います。 

でも、過程は違えど、最終的に結果(y)を推論する点は同じなんですよね。 

なのに、データ(x)に対して結果ラベル(y)をすべて0にして、途中経過の出力であるx'だけを見るという使い方18のアプローチはさすがに乱暴でした。 

Googleグループでも同じ話題がコメントされてました。

groups.google.com

 

これによると、7日分の売上データがデータ(x)で、次の7日分の売上を結果(y)にするとかしないと駄目だということです。 

7日間の売上を7行1列で時系列で、0.csv、.1.csv・・のように作ったとしたら、イメージとしてはこんな感じ。

x:data,y:label

.\0.csv,./1.csv

これを学習させて、最終週の7日分(x)のみ与えて、次週を予測させる。

x:data

./999.csv

 

なるほどですね。 

実際にこれをRNNでやるとしたら、学習データは結果(y)の次の7日分をy__0~y__6に展開したこんな感じになるんですかね。

f:id:arakan_no_boku:20180404092817j:plain

 

このやり方はこちらのチュートリアルでも紹介されてますし。

support.dl.sony.com

 

実際にLSTMを使ってやってみました。 

きれいに学習は収束しましたし、それらしい数値は出力できたのですが、たぶん、用意できた数値データの量が十分ではなかったのか、なかなか思うような精度がでませんでした。 

なので、ちょっとそちらは置いといて、今回は少し違うアプローチに挑戦してみます。

 

直近7日間の数値で翌日を予測するアプローチ

 

アイディアはこうです。 

7日間の時系列データで、7日分の予測をするのではなく、直近7日分の数値を一塊のデータとして、8日目の数値を予測する。 

こうすれば、同じ期間の数値(例えば売上等)でも学習用データの量としては7倍にできますし、なにより、時系列で処理する必要がないので、普通のニューラルネットワークで対処できそうです。 

使い方18の時同様に、日別の数値(売上?)を用意して、1.0~-1.0の範囲になるように100000で割って少数以下の数字にします(C列)。

f:id:arakan_no_boku:20180403225849j:plain

 

これを、7日間をx__0~x__6のデータとして、8日目をyとする以下の1行8列のデータにして以下のようなCSVデータを作ります。

f:id:arakan_no_boku:20180403230913j:plain

 

このやり方はNNCのチュートリアルのこの記事を参考にしています。

support.dl.sony.com

 

今回は、これで約5ヶ月分(約150行くらい)を学習用、同じフォーマットで最後の1週間分(7行)を評価用として2つのファイルを作りました。 

手で作ると大変なので、こんな感じでEXCELVBACSVデータを作って、あとで学習用と評価用にわけました。 

ちなみに、ソースコード中の「forward01」は、シート名ですので、適宜書き換えてください。

Sub dataCreate3()
     Dim wkst
     Dim obj As Object
     Dim buf  As Variant
     wkst = "forward01"
     bottom = Range("C2").End(xlDown).Row
     ts = 7
     fn = 0
     Set obj = CreateObject("ADODB.Stream")
     obj.Charset = "UTF-8"
     obj.Open
     datFile = ActiveWorkbook.Path & "\" & wkst & ".csv"
     out = ""
     out = "x,x__1,x__2,x__3,x__4,x__5,x__6,y" & vbNewLine
     obj.WriteText (out)

     For r = 1 To bottom - 1
        If (r + 8) < (bottom - 1) Then
           out = ""
           out = out & Worksheets(wkst).Cells(r + 1, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 2, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 3, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 4, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 5, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 6, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 7, 3).Value & ","
           out = out & Worksheets(wkst).Cells(r + 8, 3).Value & vbNewLine
           obj.WriteText (out)
           fn = fn + 1
         End If
       Next
       obj.Position = 0
       obj.Type = adTypeBinary
       obj.Position = 3
       buf = obj.read
       obj.Position = 0
       obj.write buf
       obj.SetEOS
       obj.SaveToFile datFile, 2
       obj.Close
End Sub

 

ネットワークモデルはごくシンプルに

 

今回学習に使ったモデルです。

f:id:arakan_no_boku:20180404001822j:plain

 

inputは、x__0~x__6の7列+結果ラベル(y)の1行8列なのですが、チュートリアルにあるとおり、x__0~x__6の部分だけ・・つまり、7をサイズにします。

f:id:arakan_no_boku:20180404002129j:plain

 

AffineとTanhのあたりはoutshapeを7にあわせる以外は特別なことはありません。 

回帰問題なので、最後はsquarederrorにします。 

T.Dataset(アウトプット)は「y」です。

f:id:arakan_no_boku:20180404002603j:plain

 

学習をやってみる

 

これで、DATASETに、さきほど用意したCSVファイルを指定して、学習させます。 

学習のグラフはこんな感じ。

f:id:arakan_no_boku:20180404093600j:plain

 

わりと最初の方は、かなり「暴れる」感じのグラフになるんですが、50epochあたりから収束してきました。 

グラフだと100epochもあれば十分なようにも見えるのですが、今回、自分で試行錯誤した結果から言えば上記のように300epochくらいが最良でした。 

バッチサイズもいろいろ試しましたが、ちょっと時間がかかりましたが、一番結果がよかったのは「1」でした。

f:id:arakan_no_boku:20180404094628j:plain

 

評価をして、結果を確認する

 

今回、用意した数値(いちおう、売上想定で1週間の曜日で凸凹をつけて、全体的には上昇トレンドを意識して作ったもの)は使い方18で使ったものの流用です。 

ただし、期間をのばして約5ヶ月分の日別にしました。 

グラフに描くと、全体としてはこういう見てくれです。

f:id:arakan_no_boku:20180404095618j:plain

 

あ・・、データは小数点以下なんですが、グラフにするときには100000倍にして、元の整数部5桁の数字にもどしてますので、念のため。 

このうちの最後の方だけを拡大してみると、作成したデータはちょうど頂点あたりで終わってます。

f:id:arakan_no_boku:20180404095701j:plain

 

この続きを予測して継ぎ足してみて、不自然でない数字が予測されていれば、一旦はOKと判断できるんじゃないかと思ってます。 

さて、評価を実行した結果です。 

プロジェクト名.filesフォルダの実行日付_時刻フォルダ(20180404_130520等)の下「output_result.csv」を開きます。 

そうすると右端に「y'」の列があるので、その数値を結果としてコピーします。

f:id:arakan_no_boku:20180404172459j:plain

 

実際には、その行のx__1~x__6を、x__0~x__5にずらして、結果のy'をx__6にして、また評価して・・みたいに地道にやる必要があったりします。 

これはかなり面倒くさいので、今回のようにお試しでないなら、nnablaとかでプログラム組んだ方がよいとは思いますが・・とりあえず、今回はやってみました。 

結果のグラフです。

f:id:arakan_no_boku:20180412202353j:plain

 

赤枠で囲った部分が、予測によって生成された数字です。 

上昇トレンドも含めて、かなり、正確に予測できているような気はしますね。 

しかも、今回はちゃんと推論した結果(y')を使ってますし。 

一応、「予測っぽい」くらいのことはできてるんじゃないですかね。 

やれやれ・・。 

 

f:id:arakan_no_boku:20171115215731j:plain

ニューラルネットワークコンソールの使い方一覧はこちらです。

arakan-pgm-ai.hatenablog.com