アラカン"BOKU"のITな日常

文系システムエンジニアの”BOKU”が勉強したこと、経験したこと、日々思うこと。

pythonで学習したグラフをJAVAで再利用する方法(3-1):Tensorflow入門の入門7/文系向け

今回から、「python版Tensorflowで学習させた学習済グラフを、JAVAで再利用する。」をやってみます。

 

簡単な利用イメージはこんな感じです。

f:id:arakan_no_boku:20170503103908j:plain

 

SpringBoot+Javaで動いているWEBアプリケーションに、ディープラーニングを使って推論する機能を組み込むのが目的です。

 

ただ、手順は若干おおいです。

 

一気にそれを書くと、めちゃくちゃ記事が長くなるので、3回にわけて書きます。

 

3-1:pythonで学習済モデルをJAVAでの再利用を考慮して保存する。(この記事)

 

3-2:保存した学習済モデルをJAVAで利用できるようにマージする   arakan-pgm-ai.hatenablog.com

 

3-3: マージした学習済モデルをJAVA側で読み込んで推論する。

arakan-pgm-ai.hatenablog.com

 

 

今回は、1回目です。

 

Trmsorflowで保存できるものには2種類あった

Tensorflowでエクスポート(保存)できるモデル(core artifacts)には2種類あったんですね。

 

この2種類です。

 

で、前回の記事で紹介した、tensorflow.trani.Saver()で保存できるののはcheckpointsのみです。

 

GraphDefは、tf.train.write_graph()を使って、testmodel.pbなどの名前をつけた、.pb(プロトコルバッファ)ファイルにします。

 

Graphと変数の状態ですから、JAVA等で再利用しようと思ったら、この2つがセットで存在しないといけないわけです。

 

ここが仲々わからなくて、checkpointsだけでやろうとして、変数が未定義だと怒られ、pb(プロトコルバッファ)ファイルのみでやろうとして、変数が初期化してないと怒られ・・を繰り返して、1週間くらいはまりました。

 

ということで、前回から、いくつか保存方法を変更します。

 

Tensorflowで学習モデルを保存する方法

変更点を先にまとめときます。

  1. 全体を名前をつけたGraphで囲む。(
  2. すべてのTensorに参照用の名前をつける。(
  3. Graphを保存する処理を追加する。(

です。

 

先にソースコード全体をのせます。

上記の変更点にあたる部分には色をつけてます。

import tensorflow as tf
import csv
import math
import os

#csvデータを読み込む部分は割愛してます。前回以前の記事参照。
gr = tf.Graph()
with gr.as_default():
     bias = tf.Variable(tf.zeros([2],dtype=tf.float32),name="bias")
stddev_1 = 2.0 / math.sqrt(5 * 2)
    weight = tf.get_variable("weight",[5,2],initializer=tf.random_normal_initializer(stddev=stddev_1,dtype=tf.float32))
    bias_h = tf.Variable(tf.zeros([5],dtype=tf.float32),name="bias_h")
stddev_2 = 2.0 / math.sqrt(5 * 5)
    weight_h1 = tf.get_variable("weight_h1",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
    weight_h2 = tf.get_variable("weight_h2",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
    weight_h3 = tf.get_variable("weight_h3",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
    data = tf.placeholder(dtype=tf.float32,shape=[None,5],name="data")
    label = tf.placeholder(dtype=tf.float32,shape=[None,2],name="label")
hidden_1 = tf.nn.relu(tf.matmul(data,weight_h1) + bias_h,name="hidden_1")
    hidden_2 = tf.nn.relu(tf.matmul(hidden_1,weight_h2) + bias_h,name="hidden_2")
    hidden_3 = tf.nn.relu(tf.matmul(hidden_2,weight_h3) + bias_h,name="hidden_3")
    y = tf.nn.softmax(tf.matmul(hidden_3,weight) + bias,name="y")
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(label * tf.log(y), reduction_indices=[1]),name="cross_entropy")
    train_step = tf.train.GradientDescentOptimizer(0.1,name="train_step").minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(label,1),name="correct_prediction")
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name="accuracy")
    cwd = os.getcwd()
    with tf.Session(graph=gr) as s:
         s.run(tf.global_variables_initializer())
         saver = tf.train.Saver()
         for k in range(10):
              s.run(train_step,feed_dict={data:data_body,label:label_body})
        acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
        print("結果:{:.2f}%".format(acc * 100))
        saver.save(s,cwd + "\\model.ckpt")
        gr_def = gr.as_graph_def()
         tf.train.write_graph(gr_def,cwd,"testmodel.pb",as_text=False)

 

上記は、Windows版のtensorflowでやってます。

 

なので、「 cwd = os.getcwd()」でカレントフォルダのパスを取得して、いちいち絶対パスにしてからファイルを読み書きしています。

 

そうしないと、Windows版ではうまくいかないので。

 

Linux版だと、そういうのは必要なく、ファイル名のみまたは 「./model.ckpt」みたいな書き方でいけるので、そこは必要に応じて書き換えてください。

 

ポイントを確認していきます。

全体を名前をつけたGraphで囲む。  

あとでGraphを保存しなければいけないので、扱いやすいように名前をつけてるだけなんですけどね。

 gr = tf.Graph()
with gr.as_default():

 

こんな感じでグラフオブジェクトを作って、そのWith句でデフォルトのグラフを使うことを宣言してます。

 

CSVデータから読み込む処理とか、再利用するときにも外部で行う必要のある部分は、Graphで囲む前に書いてはずしておくほうがいいみたいですね。

 

すべてのTensorに参照用の名前をつける。 

以下例みたいに、すべてのTensorテンソル)/オペレーションに、name="xxx"で名前を追加してます。

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name="accuracy")

 

pythonでもJAVAでも、学習済モデルを再利用する時には、こちらの名前でアクセスする必要があるからです。

 

この名前をつけておかないと、実行したときに、「Noneにアクセスしようとしてるぞ!」と怒られます。

 

これがまた、原因がわかりづらいエラーになります。

 

ですから、とりあえずエラーが回避する方を優先して、すべてに名前をつけておくようにしています。

 

Graphを保存する処理を追加する。  

前回で、checkpointsを保存する処理は書いてあるので、今回はGraphを保存する処理だけを追加で書きます。

gr_def = gr.as_graph_def()
tf.train.write_graph(gr_def,cwd,"testmodel.pb",as_text=False)

 

現在のグラフから、GraphDef・・ようするにGraphをシリアライズしたもの・・を生成して、それを、tf.train.write_graph()を使って、testmodel.pbという名前のファイルに書き出しているだけです。

 

エクスポートされたファイルの確認

うまく動いていれば、以下のようなファイルができているはずです。

  • checkpoint
  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta
  • testmodel.pb

 

上の4つがチェックポイントのエクスポートデータ、最後の一つがGraphのエクスポートデータです。

 

あとは、これをJAVA側で読み込んでやればいいな。

 

と、そう思ったら、いやいやとんでもない。

 

実は、チェックポイントのエクスポートデータをJAVA側でどうやっても読み込めなかったんです。

 

どうも、現時点では

このファイルを全部マージして、変数の状態を反映させたpb(プロトコルバッファ)ファイルを作ってやる。

 

または、全く別の形式でエクスポートして、SavedModelBundleを利用して、JAVA側でロードする方法をとる。

 

このどちらかを行う必要があるみたいなんです。

 

ただ、SavedModelBundleの関連機能は、正直まだ tf.contrib.XXXの機能と同様に、依然開発中っぽいという情報があります。

 

それに、あとあとの事を考えると、とりあえず、古い(ベーシックな)やり方も知っておかないと、応用がきかなったら嫌なので、今回は前者の方法をとることにしました。

 

このマージする方法については、次回の記事(3-2)で書きます。

arakan-pgm-ai.hatenablog.com

 


Tensorflow入門の入門カテゴリの記事一覧はこちらです。

arakan-pgm-ai.hatenablog.com

f:id:arakan_no_boku:20170404211107j:plain