"BOKU"のITな日常

還暦越えの文系システムエンジニアの”BOKU”は新しいことが大好きです。

過学習抑制の手段であるWeight DecayはSGDと相性が良くて、Adamとは良くない/使い方35

今回はとても地味・・なテーマです。

Weight Decayです。

 

Weight Decayとは

 

Weight Decay は直訳すると「荷重減衰」みたいな感じになります。 

Weight Decayは過学習の抑制手法のひとつです。

過学習というのは、学習(訓練)データに適合しすぎて、学習(訓練)データと異なるデータでの正解率が低くなってしまうことです。 

過学習は重み(Weight)が大きな値をもつことによって発生することが多いということから、学習過程で重み(Weight)が大きくならないようにペナルティを課す方法です。

ニューラルネットワークの学習は、損失関数(レイヤーでいうとCategoricalCrossEntropyとか)で、どの位正解と差があるかを求めて、その差異が小さくなる様、逆上って(逆伝播)重みやバイアスの値を更新していくわけです。 

その過程で、損失関数の値に、なんらかの値をペナルティとして加算してやれば、重み(Weight)が大きくならないようにできる。

そういう理屈です。

 

Weight Decayのペナルティの値

 

じゃあ、そのなんらかの値は何か。 

いくつか方法はあるみたいですが、よく使われているのが「重みの各要素の二乗を足し合わせたもの」に任意の係数をかけ合わせたものです。 

この「重みの各要素の二乗を足し合わせたもの」をペナルティに使う方法を「L2ノルム正規化=L2正規化」と呼びます。 

数式で理解したい方は、L2ノルム正規化などでググれば、たぶんいくらでも情報がでてくると思いますので、そちらでどうぞ(笑)

 

Neural Network Consoleのマニュアルの説明

 

Weight Decayという項目は、Neural Network ConsoleのConfigタブのOptimizerのところにあります。

リファレンスを見てみると、一応説明はあります。

support.dl.sony.com

書いてあるのはたったの2行。

4 Weight Decay(L2正則化)の強度を設定するにはWeight DecayにWeight Decayの係数を指定します。

上記の説明を読んだあとなら、Configタグで指定するWeight Decayの係数というのは上記でいう「任意の係数」にあたるのがわかると思います。

でも・・。

そこまで知らない人には、全くわからない。

そういう説明ですよね。

 

Weight Decayの係数の適切な値とは

 

f:id:arakan_no_boku:20180412235717j:plain

 

ここが悩ましいところです。 

この数値にしておけば、なんでもOK!みたいな数値はありません。

(自分が知らないだけかもしれませんけど・・) 

本格的にやるなら、ベイズ最適化などハイパーパラメータの自動最適化の手法はいろいろあるみたいですが、趣味でやってる人間にとっては大げさです。

qiita.com

とりあえず、地道にやるしかないかな。

 

どの値から始めるかを検討

 

正直、書籍とかによっても、この係数についてはバラバラです。 

0.1くらいを設定・・という人もあれば、0.0000001位が最適だったなんて人もいる。 

仕方ないので、ちょこちょこ試しながら、様子を見てます。 

その結果。 

自分は、Neural Network Consoleでやる場合は「0.0001」から始めることが、最近は多いです。 

経験的にベストでないにしても、割合ベターな結果が期待できる感じなので。 

ご参考になれば。

 

Weight Decayは、OptimizerのAdamと相性がよくない

 

Neural Network ConsoleのOptimizerのデフォルトは「Adam」です。 

基本、Adamは収束が速く、とりあえず選択して問題がないOptimizerです。

arakan-pgm-ai.hatenablog.com

 

でも、「学習データと異なるデータでの正解率をあげる=汎化」目的で、Weight Decayを使うときには、Adamはベストチョイスといえなくなります。 

こちらのTweetに貼ってあるリンク(論文のPDF)にその理由が説明されています。

簡単に言うなら、Weight Decayの係数を0以外に変更しても、Adamを使っている限り、過学習の傾向があまり改善しないのです。

 

論より証拠でやってみる。まず「Adam」で。

 

学習した結果が、こういうグラフになっているCNNのネットワークを題材にします。 

f:id:arakan_no_boku:20180414004750j:plain

 

Traiingエラー(赤の実線)がきれいに収束しているのに、Validationエラー(赤の点線)が途中から逆に増加して、どんどん離れていってます。 

もう、典型的な過学習の状態です。 

なので、評価した結果のAccuracyも「97.65%」と、かなり低めです。

f:id:arakan_no_boku:20180414005126j:plain

 

あまり改善しないことの確認

 

ここに、OptimizerはAdamのままで、Weight Decayの係数を「0.0001」にしてみます。

f:id:arakan_no_boku:20180413231926j:plain

 結果はどうかというと、改善はされます。 

でも、形的にあまり変わりませんし、過学習の傾向は残ったままです。

f:id:arakan_no_boku:20180413232737j:plain

f:id:arakan_no_boku:20180413232758j:plain

 

今度は「Sgd」に変更してやってみる

 

OptimizerのUpdaterを「sgd」に変更してみます。

f:id:arakan_no_boku:20180413233252j:plain

 それ以外はまったく同じです。 

 

SGDの場合は効果があった

 

そうすると学習結果のグラフの形が劇的に変わります。

f:id:arakan_no_boku:20180414005636j:plain

 

traiingエラーとValidationエラーの差がほぼありません。 

かなりいい感じで汎化できていると期待できます。 

ただ、残念ながら「SGDは収束が遅い」ので、50epochでは収束しきっていません。 

この時点で評価したら、Accuracyは97%前半しかいきません。 

ちょっと悔しいので、epoch数を増やして再チャレンジして、最終的にはここまでいきました。

f:id:arakan_no_boku:20180414002910j:plain

ま・・、元々のネットワーク構成がよろしくないのか、いまいちですけど。

 

特徴を理解して状況で使い分ける感じですかね

 

まとめると。 

Weight decayの値を0以外(例えば 0.0001等)にすると、L2正規化が働いて、過学習の抑制効果があります。 

ただ、Optimizerタブで「Adam」を選択していると、相性の問題で、あまり効果がありません。 

そこを「sgd」に変更すると、Weight decayの汎化効果が最大限に発揮されますが、今度は収束が遅くなるので、Max Epochの数を増やすなどの対応が必要になります。 

結局、状況で「今、どちらに重点をおくべきか」で、どちらを選択するかを判断しないといけないということですね。 

f:id:arakan_no_boku:20171115215731j:plain