"BOKU"のITな日常

興味のむくまま気の向くままに調べたり・まとめたりしてます。

pythonで学習したグラフをJAVAで再利用できる形式で保存する/Tensorflow-r1.4

今回から、「python版Tensorflowで学習させた学習済グラフを、JAVAで再利用する。」をやってみます。 

f:id:arakan_no_boku:20190313220058j:plain

 

利用イメージ

 

簡単な利用イメージはこんな感じです。

f:id:arakan_no_boku:20170503103908j:plain

 

SpringBoot+Javaで動いているWEBアプリケーションに、ディープラーニングを使って推論する機能を組み込むのが目的です。 

ただ、手順は若干おおいです。 

一気にそれを書くと、めちゃくちゃ記事が長くなるので、3回にわけて書きます。 

3-1:pythonで学習済モデルをJAVAで再利用できる形式で保存する。(この記事) 

3-2:保存した学習済モデルをJAVAで利用できるようにマージする   

3-3: マージした学習済モデルをJAVA側で読み込んで推論する。 

今回は、1回目です。

 

pythonで学習済モデルをJAVAで再利用できる形式で保存する

 

Tensorflowでエクスポート(保存)できるモデル(core artifacts)には2種類あったんですね。 

この2種類です。

でも、tensorflow.trani.Saver()で保存できるののはcheckpointsのみです。 

GraphDefは、tf.train.write_graph()を使って、testmodel.pbなどの名前をつけた、.pb(プロトコルバッファ)ファイルにします。 

Graphと変数の状態ですから、JAVA等で再利用しようと思ったら、この2つがセットで存在しないといけないわけです。 

ここが仲々わからなくて、checkpointsだけでやろうとして、変数が未定義だと怒られ、pb(プロトコルバッファ)ファイルのみでやろうとして、変数が初期化してないと怒られ・・を繰り返して、1週間くらいはまりました。 

ということで、前回から、いくつか保存方法を変更します。

 

Tensorflowで学習モデルを保存する方法の変更点

 

変更点を先にまとめときます。

  1. 全体を名前をつけたGraphで囲む。(
  2. すべてのTensorに参照用の名前をつける。(
  3. Graphを保存する処理を追加する。(

です。 

先にソースコード全体をのせます。

上記の変更点にあたる部分には色をつけてます。

import tensorflow as tf
import csv
import math
import os

#csvデータを読み込む部分は割愛してます。前回以前の記事参照。
gr = tf.Graph()
with gr.as_default():
     bias = tf.Variable(tf.zeros([2],dtype=tf.float32),name="bias")
stddev_1 = 2.0 / math.sqrt(5 * 2)
    weight = tf.get_variable("weight",[5,2],initializer=tf.random_normal_initializer(stddev=stddev_1,dtype=tf.float32))
    bias_h = tf.Variable(tf.zeros([5],dtype=tf.float32),name="bias_h")
stddev_2 = 2.0 / math.sqrt(5 * 5)
    weight_h1 = tf.get_variable("weight_h1",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
    weight_h2 = tf.get_variable("weight_h2",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
    weight_h3 = tf.get_variable("weight_h3",[5,5],initializer=tf.random_normal_initializer(stddev=stddev_2,dtype=tf.float32))
    data = tf.placeholder(dtype=tf.float32,shape=[None,5],name="data")
    label = tf.placeholder(dtype=tf.float32,shape=[None,2],name="label")
hidden_1 = tf.nn.relu(tf.matmul(data,weight_h1) + bias_h,name="hidden_1")
    hidden_2 = tf.nn.relu(tf.matmul(hidden_1,weight_h2) + bias_h,name="hidden_2")
    hidden_3 = tf.nn.relu(tf.matmul(hidden_2,weight_h3) + bias_h,name="hidden_3")
    y = tf.nn.softmax(tf.matmul(hidden_3,weight) + bias,name="y")
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(label * tf.log(y), reduction_indices=[1]),name="cross_entropy")
    train_step = tf.train.GradientDescentOptimizer(0.1,name="train_step").minimize(cross_entropy)
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(label,1),name="correct_prediction")
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name="accuracy")
    cwd = os.getcwd()
    with tf.Session(graph=gr) as s:
         s.run(tf.global_variables_initializer())
         saver = tf.train.Saver()
         for k in range(10):
              s.run(train_step,feed_dict={data:data_body,label:label_body})
        acc = s.run(accuracy, feed_dict={data:data_body_test,label:label_body_test})
        print("結果:{:.2f}%".format(acc * 100))
        saver.save(s,cwd + "\\model.ckpt")
        gr_def = gr.as_graph_def()
         tf.train.write_graph(gr_def,cwd,"testmodel.pb",as_text=False)

 

上記は、Windows版のtensorflowでやってます。 

なので、「 cwd = os.getcwd()」でカレントフォルダのパスを取得して、いちいち絶対パスにしてからファイルを読み書きしています。 

そうしないと、Windows版ではうまくいかないので。 

Linux版だと、そういうのは必要なく、ファイル名のみまたは 「./model.ckpt」みたいな書き方でいけるので、そこは必要に応じて書き換えてください。 

ポイントを確認していきます。

 

全体を名前をつけたGraphで囲む。  

 

あとでGraphを保存しなければいけないので、扱いやすいように名前をつけてるだけなんですけどね。

 gr = tf.Graph()
with gr.as_default():

 

こんな感じでグラフオブジェクトを作って、そのWith句でデフォルトのグラフを使うことを宣言してます。 

CSVデータから読み込む処理とか、再利用するときにも外部で行う必要のある部分は、Graphで囲む前に書いてはずしておくほうがいいみたいですね。

 

すべてのTensorに参照用の名前をつける。 

 

以下例みたいに、すべてのTensorテンソル)/オペレーションに、name="xxx"で名前を追加してます。

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"),name="accuracy")

 

pythonでもJAVAでも、学習済モデルを再利用する時には、こちらの名前でアクセスする必要があるからです。 

この名前をつけておかないと、実行したときに、「Noneにアクセスしようとしてるぞ!」と怒られます。 

これがまた、原因がわかりづらいエラーになります。 

ですから、とりあえずエラーが回避する方を優先して、すべてに名前をつけておくようにしています。

 

Graphを保存する処理を追加する。  

 

前回で、checkpointsを保存する処理は書いてあるので、今回はGraphを保存する処理だけを追加で書きます。

gr_def = gr.as_graph_def()
tf.train.write_graph(gr_def,cwd,"testmodel.pb",as_text=False)

現在のグラフから、GraphDef・・ようするにGraphをシリアライズしたもの・・を生成して、それを、tf.train.write_graph()を使って、testmodel.pbという名前のファイルに書き出しているだけです。

 

エクスポートされたファイルの確認

 

うまく動いていれば、以下のようなファイルができているはずです。

  • checkpoint
  • model.ckpt.data-00000-of-00001
  • model.ckpt.index
  • model.ckpt.meta
  • testmodel.pb

 上の4つがチェックポイントのエクスポートデータ、最後の一つがGraphのエクスポートデータです。 

あとは、これをJAVA側で読み込んでやればいいな。 

と、そう思ったら、いやいやとんでもない。 

実は、チェックポイントのエクスポートデータをJAVA側でどうやっても読み込めなかったんです。 

どうも、現時点では

このファイルを全部マージして、変数の状態を反映させたpb(プロトコルバッファ)ファイルを作ってやる。 

または、全く別の形式でエクスポートして、SavedModelBundleを利用して、JAVA側でロードする方法をとる。 

このどちらかを行う必要があるみたいなんです。 

ただ、SavedModelBundleの関連機能は、正直まだ tf.contrib.XXXの機能と同様に、依然開発中っぽいという情報があります。 

それに、あとあとの事を考えると、とりあえず、古い(ベーシックな)やり方も知っておかないと、応用がきかなったら嫌なので、今回は前者の方法をとることにしました。

 

マージする方法については、次回の記事(3-2)で書きます。

arakan-pgm-ai.hatenablog.com