読者です 読者をやめる 読者になる 読者になる

アラカン"BOKU"のITな日常

あれこれ興味をもって考えたことを書いてます

手作りディープラーニング。学習機能を実装してみたこと。

手作りディープラーニングで学習機能を実装した話を書きます。

 

勉強を兼ねて、フルスクラッチディープラーニングの実装に挑戦中で、前回は推論部だけ作りましたので、その続きです。学習機能がないと、まったくのおバカ状態ですからね。

 

ディープラーニングの学習って、ざっくり言えば、重みとバイアスというパラメータの値を適切な結果がでるように調整することです。

 

下の図の例だと、z1のノードを求めるのに「z1 = x1W1 + x2W2 + b」で計算します。

この式のW1,W2の部分が「重み」、bの部分が「バイアス」です。

f:id:arakan_no_boku:20161108055220j:plain

 

で、どうするかというと、損失関数というものを使って、どの位正解と離れているかを数値(例えば、正解に近いと 0.1213456・・、離れていると 0.892356など)の大小で表現して、それが最小の状態になるように重みとバイアスを更新していくわけです。

 

今回は、損失関数は交差エントロピー誤差を使いました。計算式とかは書きません。ブログが見づらくなるので。

 

もし、興味がある方はネットで探すといろいろ親切に解説いただいているサイトがあったりしますので、すいませんが、そちらを見てください。例えば、こちらです。

GITHUBですが、Wikiで計算式とコーディング例を丁寧に解説いただいてます。ありがたいですね。

 

早速学習させてみた

テストには、例によってMNISTデータ・セットを使いました。

 

学習した結果、前回 9%だった正解率が、97%強97.00%~97.31%程度)になりました。

 

わずか3層固定のネットワークで、かつ、特別、学習精度を向上させる対応をほどこしていない素の状態であることを考えると、びっくりするような結果じゃないかと自分では思います。

 

実際に動かして意外だったのは、学習に思ったより時間がかかったことです。

 

予想としては、60000枚の学習データがあるので、✖10ののべ60万枚程度学習すれば、97%~98%くらいの正解率には達するんじゃないかと思っていましたが、おおはずれでした。

 

グラフを見てください。

f:id:arakan_no_boku:20161111062519j:plain

細い点線でグラフを書いているのでわかりにくくてすいません。

 

このグラフのX軸は、実行回数です。Y軸は学習結果をもとに実行した10000枚のテスト画像に対する正解率です。

 

学習は、1回の実行で画像を100枚ずつ処理していて、500回ごとにグラフに出力しています。グラフは50000回実行した時のものです。

 

最初の1000回くらいで、かんたんに89%位まであがったのですが、そこからは、まさにカメの歩みでなかなか上がりません。結局、97%強に達するのに、のべ500万枚の学習が必要でした。

 

でも、時間はかかっても、このまま続けていくと、限りなく100%に近づいていくのか?という疑問が当然でてきます。

 

なので、一旦、適当なところで学習データを保存して、そこからさらに学習を続ける実験をしてみました。

f:id:arakan_no_boku:20161112091022j:plain

 上のグラフは、追加で20000回分の学習した結果をトレースしたものです。

 

見づらいので、グラフのY軸は、90%未満は切り落として表示しています。画像だけ見ると、若干上がっているように見えますが、実際は凸凹しながらほぼ横ばい・・つまり、頭打ちになりました。学習の限界ということなのでしょうね。

 

正解率が凸凹するのは、処理速度を稼ぐため、100枚単位くらいを1処理単位にして、損失率の平均をとって、それに基づいてパラメータを更新する方法をとっているからです。

 

サンプリングされた100枚の内容によって正解率が変わるのは当然なので、予想はしてました。

 

なんですが、ログを追いかけていると、資料やできあいのフレームワークのデモでは気づかなかったことも見えてきて、とても興味深かったです。