アラカン"BOKU"のITな日常

あれこれ興味をもって考えたことを書いてます

Tensorflow入門の入門9:学習済モデルをTensorflow for JAVAで再利用して推論してみる

今回はpythonで学習させた「学習済モデル」を保存したファイルを、JAVAでインポートして、そのパラメータを利用して推論するのをやってみます。

f:id:arakan_no_boku:20170503103908j:plain

 

3回にわけて書いてきた記事の3回目です。

 

初めて見る方は、学習済モデルの保存(1回目)と、JAVAで利用できるようにマージする(2回目)から続けてみてもらった方が、わかりよいと思います。

 

1回目

 

2回目

 

JAVAでTensorflowを利用できるようにする

自分の開発環境はSTSで、フレームワークとしてSpringBootを使ってます。

 

そこでTensorflow連携テスト用のプロジェクトを、Mavenベースで作ってます。

 

その前提下での情報であることはご容赦ください。

 

もし、Mavenベースでない環境をお使いの場合はお手数ですが、こちらの記事を見てもらって、環境構築をお願いします。


Mavenベースの場合、Tensorflowを使えるようにするのは、pom.xmlに以下を追加するだけです。ただし、自分はwindows版のTensorflowを使っていることに注意してください。

<dependency>
     <groupId>org.tensorflow</groupId>
     <artifactId>tensorflow</artifactId>
     <version>1.1.0-rc0-windows-fix</version>
</dependency>

 

Windows版以外なら、Googleで「Maven tensorflow」とでも検索すれば設定情報はすぐ見つかりますので、それを参照してくださいね。 

 

JAVAでTensorflowの学習済モデルを利用する

さて、やっと本題です。

 

学習済モデルは、前回マージした「frozen_model.pb」を使います。

 

やることは以下の通りです。

  1. frozen_model.pbを読み、Graphオブジェクトを構築する。
  2. 構築したGraphでSessionオブジェクトを構築する。
  3. テスト用データのTensorオブジェクトを構築する。
  4. 推論を実行する。
  5. オブジェクトをクローズしてメモリを解放する。

 

最初にソースを全文載せて、ポイントにわけて確認する方式でやっていきます。

import java.io.File;
import java.nio.file.Files;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class ExampleTf {
    public static void main(String args) throws Exception {
        final File modelFile = new File("C:\\XXX\\frozen_model.pb");
        byte
graphDef = Files.readAllBytes(modelFile.toPath());
       Graph graph = new Graph();
       graph.importGraphDef(graphDef);
       Session session = new Session(graph);
       float d = new float{new float{20.0f ,20.0f ,20.0f ,0.0f ,0.0f}};
       float
l= new float{new float{1.0f ,0.0f}};
       Tensor data = Tensor.create(d);
       Tensor label = Tensor.create(l);
       Tensor acc = session.runner().feed("data", data).feed("label",label).fetch("accuracy").run().get(0);
       System.out.println("結果:" + acc.floatValue());
       data.close();
       label.close();
       acc.close();
       session.close();
       graph.close();
    }
}

 

 frozen_model.pbを読み、Graphオブジェクトを構築する

この部分のコードです。

 

サンプルとしてのわかりやすさを優先してフルパスなどを直接書いてたりしてますが、そのへんの手抜きはご容赦ください。

 

なお、このパスの書き方はWindows専用(Windows版Tensorflow for java専用)です。

final File modelFile = new File("C:\\XXX\\frozen_model.pb");
byte graphDef = Files.readAllBytes(modelFile.toPath());
Graph graph = new Graph();
graph.importGraphDef(graphDef);

 

Fileオブジェクトを構築し、内容をbyte配列で読み込んでます。

 

ファイルから読み込んだByte配列は、シリアライズされた状態のGraphですから、空のGraphオブジェクトを生成して、インポートしてやる必要があります。

 

 

まんまですね。

構築したGraphでSessionオブジェクトを構築する

この部分のコードです。

Session session = new Session(graph);

 

Sessionオブジェクトを構築して、runner()で実行しなければならないのは、python版と同じです。

 

テスト用データのTensorオブジェクトを構築する

とりあえず、わかりやすさ優先でテストデータは1件だけ手作りにしてみます。

 

今回の場合は、python側のデータ・タイプが「tf.float32」なので、Java側では floatになります。

 

ただ、Tensorflow for java で引数などに使えるのは、やっぱり基本的にはTensorだけですので、データを作成する手順としては以下になります。

  1. 対象のTensorのRankに合わせてfloatのデータを作る。
  2. 上記データで、Tensorオブジェクトを構築する。

 

対象のTensorのRankに合わせてfloatのデータを作るコードです。

上段はデータ、下段がラベルです。

float d = new float{new float{20.0f ,20.0f ,20.0f ,0.0f ,0.0f}};
float l= new float{new float[]{1.0f ,0.0f}};

 

上記データで、Tensorオブジェクトを構築する部分のコードです。

同様に上段はデータ、下段がラベルです。

Tensor data = Tensor.create(d);
Tensor label = Tensor.create(l);

 

推論を実行する

とりあえず、1件だけですがデータを与えて推論を実行します。

 

その部分と、結果を出力しているコードです。

Tensor acc = session.runner().feed("data", data).feed("label",label).fetch("accuracy").run().get(0);
System.out.println("結果:" + acc.floatValue());

 

Sessionオブジェクトのrunner()を使います。

 

feed("data",data)の第一引数は、python側でGraphを保存する際につけた名前(この例だと、"data")を指定します。

 

feed()は、pythonのコードでのfeed_dictに対応します。

 

fetch("accuracy")でも、同様に保存時にオペレーションにつけた名前を指定します。

 

結果はTensorオブジェクトで返されますので、そこからfloatValue()で値をとりだして、コンソールに出力しています。

 

オブジェクトをクローズしてメモリを解放する

ここが一番重要です。

 

Graph、Session,Tensorなどのオブジェクトを構築したら、必ず明示的にClose()しなければなりません。

 

自動的にメモリが開放されることはありません。

 

メモリリークとかの原因にもなりかねないので、必ず、Close()する癖はつけといたほうが良さそうですね。

 

試してみます。

実行してみました。

結果:1.0

 

いけてそうですね。

 

なお、実行すると以下のようなワーニングメッセージが4行くらい必ず出力されます。

The TensorFlow library wasn't compiled to use SSE instructions, but these are available on your machine and could speed up CPU computations.

エラーではありませんから、気にしなくても良いです。

 

まとめ

Tensorflow for java と、python版の両方で作業をした感覚としては、圧倒的にpython版の方が機能も豊富ですし、使いやすいです。

 

コードもシンプルに書けますしね。

 

だから、ニューラルネットワークを構築したらい、学習するような複雑な処理は、python版でやった方が絶対良いです。

 

でも、Webシステムとかで何かを入力させて、それを学習済モデルで推論して結果を返すような仕組みを作るなら、フロント部分はJAVAで開発する方が、圧倒的に生産性が高いです。

 

だから、python版でバックグラウンドで学習させて、学習済モデルを連携して、Javaで開発したWebシステム側で使うという役割分担ができればいいなと思ってました。

 

今回確認できたことで、それがやれることがわかったのは、うれしいですねえ。

 


Tensorflow入門の入門の前の記

Tensorflow入門の入門8:学習済モデルをJAVAで利用できるように保存(freeze_graph)しなおす。

f:id:arakan_no_boku:20170404211107j:plain