アラカン"BOKU"のITな日常

あれこれ興味をもって考えたことを書いてます

tensorflow入門の入門6:学習結果を保存して、再利用する:(文系向け)

学習済のパラメータを保存して、再利用できるようにしてみます。

 

前回で、少しだけ深いニューラルネットワークを構築し、CSVデータからデータを読み込んで学習出来る様にしました。

arakan-pgm-ai.hatenablog.com

 

その学習結果を保存して、再利用するように変更していきます。

 

ここまでできれば、自分の希望する任意のデータをCSVファイルにして、学習・推論させる雛形として、一通りのものが揃いますから。

 

学習結果の保存と再利用(読み込み)する機能は、Tensorflowで、ちゃんと用意されてます。・・当たり前ですね。

 

tensorflow.train.Saver() です。

 

それで、save() で保存、restore() で読み込みです。

 

超簡単です。

 

ただ、それを書く位置と順序に注意事項があるみたいなので、今回はその辺の確認をする感じなるのかな。

 

学習結果を保存する

保存機能を付け加えたコードを抜粋してみます。

 

ベースは前回のコードで、付け加えた部分だけ、色を変えてます。

with tf.Session() as s:
     s.run(tf.global_variables_initializer())
     saver = tf.train.Saver()
     cwd = os.getcwd()
     for k in range(10):
          s.run(train_step,feed_dict={data:data_body,label:label_body})
     acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
     print("結果:{:.2f}%".format(acc * 100))
     saver.save(s,cwd + "\\model.ckpt")

 

1行ずつ確認していきます。

 

まず、Saver()のオペレーションを作ります。

 saver = tf.train.Saver()

 

注意点としては以下の2点です。 

  • Sessionの中でなければならない。
  • tf.global_variables_initializer()の実行後でなければならない。

 

2つ目は、変数を使っている時だけです。

 

もし、tf.global_variables_initializer()の実行前に書くと、実行時に変数がない(No variables to save)とエラーがでます。

 

次の行は、Windows版のTensorflowの時だけは必ず必要です。

cwd = os.getcwd()

 

ネットのサンプルプログラムは、Linux版ベースで書かれているものが多いので、この処理は書いていません。

 

でも、Windows版ではこうしないと、ディレクトリが見つからない(ValueError: Parent directory of model.ckpt doesn't exist, can't save.)と怒られます。

 

で、最後にsave()コマンドで保存します。

saver.save(s,cwd + "\\model.ckpt")

 

上記はWindows版のケースです。

 

Linux版だと以下の書き方でもいけるはずですが、Windows版はこれだとエラーになります。

saver.save(s, "model.ckpt")

 

Windows版の場合は、絶対パスで指定してやらないと、ディレクトリを見つけられないみたいですね。

 

細かいとこですけど、意外にはまりそうな箇所ではありますね。

 

これで実行して成功すると、カレントディレクトリに以下のように最低4つのファイルができるみたいです。

f:id:arakan_no_boku:20170514234621j:plain

 

読み込んで再利用する

じゃあ、先程保存した学習済パラメータを使って、推論をやってみて前回と同じ結果がでるかどうか試してみます。

 

学習済パラメータを読み込んだだけで、全く学習を行わないで、推論だけやるコードを抜粋します。

 

ベースは前回のコードで、付け加えた部分だけ、色を変えてます。

with tf.Session() as s:
     s.run(tf.global_variables_initializer())
     saver = tf.train.Saver()
     cwd = os.getcwd()
     saver.restore(s,cwd + "\\model.ckpt")
     acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
    print("結果:{:.2f}%".format(acc * 100))

 

最初の2行は保存のときと同じです。

 

読み込んで、パラメータをリストアしているのは以下の部分です。

saver.restore(s,cwd + "\\model.ckpt")

 

Windows版だけ絶対指定が必要なのも同じです。

 

まあ、saveと全く同じ形で、restoreという名前なので、別に説明しなくても何をしているかはわかりますね。

 

これで、学習済パラメータが再現されていれば、テストデータで分類した結果が前回と全く同じになるはずです。

 

さて、どうかな。

f:id:arakan_no_boku:20170513232641j:plain

 

おお!完璧じゃないですか。

 


Temsorflow入門の入門の前の記事

Tensorflow入門の入門5:少しだけ深いニューラルネットワークへ:(文系向け)

f:id:arakan_no_boku:20170404211107j:plain