Tensorflow.jsを使って、MNIST画像の評価をやってみます。
今回やってみるのは、こんな感じ
Kerasでトレーニングした学習済モデルを、コンバートして、tensorflow.jsで読み込める形式にコンバートしたものを使って、固定のMNIST画像(PNGファイルにしたもの)を読んで、評価して結果を表示する。
そんな感じの簡単なデモを作ります。
うまくいった時の画像イメージはこんな感じ。
超シンプルですが、これだけでも以下の処理ステップが必要です。
- tensorflow.jsの読み込み
- 変換したKerasで学習済のモデルのロード
- 画像を読み込み、ImageDataに変換
- ImageDataをTensorに変換
- Tensorに変換した画像イメージの評価
- 評価結果の表示
しかも・・ハマりどころ満載でした。
JavaScript部分のソースとハマりどころの補足
Tensorflow.jsを使用している部分のソースです。
<!-- Load TensorFlow.js --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"></script> <script>// <![CDATA[ async function main(){ const model = await tf.loadModel("<?php echo base_url('model/model.json'); ?>"); //画像オブジェクトを生成 var width = 28; var height = 28; var img = new Image(); img.src = "<?php echo base_url('images/001.png'); ?>"; var canvas = document.createElement("canvas"); canvas.setAttribute("width", width); canvas.setAttribute("height", height); var context = canvas.getContext("2d"); context.drawImage(img, 0, 0, width, height); var imageData = context.getImageData(0, 0, width, height); const example = tf.fromPixels(imageData, 1).reshape([1,28,28]); const prediction = model.predict(example); $("#result").text("この画像の数字は「" + prediction.argMax(-1).dataSync() + "」だよ!"); } main(); // ]]></script>
途中に2か所、PHPの記述がはいってます。
JavaScript以外の部分は、「PHP+CodeIgniter3+Bootstrap4」で動かしているので、変換済モデルや画像ファイルなどの静的データにアクセスするのに、こういう書き方をしないといけなかったというだけです。(興味があれば以下の記事をどうぞ)
なので、実行時には、例えば「http://localhost/hogehoge/resource/model/model.json」みたいなURLに置き換わっていると思って読み替えてください。
さて、ポイントを補足します。
tensorflow.jsの読み込み
以下の部分です。
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"></script>
これでだけで、「tf.loadData」のようにして、TensorFlow.jsの機能が使えるようになります。
実は、もうひとつのやり方があります。
import * as tf from '@tensorflow/tfjs';
これでもいいです。
こちらの方がモダンな書き方です。
でも、一度躓いたら、JavaScriptに精通している方以外だと、たぶん苦しみます。
scriptタグで「type="module"」にしとかないといけないとか。
相対参照「@tensorflow/tfjs'」がきちんと解決されなくて、エラーがでるとか。
Chromeで実行すると「/が頭についてない」などと怒られるとか。
などなど、なにかとハマりどころ満載です。
自分も、途中で疲れて、今回は簡単なほうを選択しました(笑)
変換したKerasで学習済のモデルのロード
この1行です。
const model = await tf.loadModel("<?php echo base_url('model/model.json'); ?>");
「<?php echo base_url('model/model.json'); ?>」の部分は、CodeIgniter の書き方で、model.jsonへのパスを書いている感じです。
実行時には、例えば「http://localhost/hogehoge/resource/model/model.json」に置き換わっていると思ってください。
注意が必要なのは「await」がついているところです。
awaitは、「async function」の実行を停止して、tf.loadModelの終了を待って、再開する感じで、とてもスマートに「非同期処理」を実現してくれます。
と・・いうことは。
必ず、「async function」の中で使わないといけないということです。
だから、上記でもわざわざ「async function main(){}」で囲って処理を書いて、最後に「main();」で実行してます。
画像を読み込み、ImageDataに変換
ここは、ちょっと長いです。
①
var width = 28;
var height = 28;
var img = new Image();
img.src = "<?php echo base_url('images/001.png'); ?>";②
var canvas = document.createElement("canvas");
canvas.setAttribute("width", width);
canvas.setAttribute("height", height);
var context = canvas.getContext("2d");
context.drawImage(img, 0, 0, width, height);③
var imageData = context.getImageData(0, 0, width, height);
画像ファイル(001.png)は、アクセス可能な静的フォルダにおいて、そこへのパスをCodeIgniterの機能を使って書いてるので、ちょっと見づらいですが、ご容赦を。
やってることは。
の3段階で、ImageDataを生成しているだけです。
ImageDataをTensorに変換
ImageDataを、Tensor(Tensorflowで使う形式)に変換します。
const example = tf.fromPixels(imageData, 1).reshape([1,28,28]);
これ一発です。
サンプルだと、「 tf.fromPixels(imageData, 1)」までしか書いてないことが多いですが、reshapeであわせてやらないと、[null,28,28]になっているぞ・・みたいにエラーになります。
結構ポイントです。
Tensorに変換した画像イメージの評価
Tensorデータ(今回の例だとexsample)に対して評価を実行します。
const prediction = model.predict(example);
評価結果は、predictionにはいります。
これもTensorです。
評価結果の表示
評価結果はHTMLに出力したいので、JQueryのIDセレクタを使って、id="result"のタグに出力する文字列を生成します。
$("#result").text("この画像の数字は「" + prediction.argMax(-1).dataSync() + "」だよ!");
評価結果の「predict」には以下のようなデータがはいってます。
[[0,0,0,0,0,0,0,0,1],]
ここから、結果を得るには「argMax(-1).dataSync() 」を使います。
こうすることで、最も大きい数がはいっているインデックスを返してくれます。
これで処理は終わりです。
なお。
上記のJavaScriptのソースで、var、const、letとかが変数についてます。
一応、以下のような意味があります。
- var:再代入可能な変数/関数スコープ
- let:再代入可能な変数/ブロックスコープ(関数スコープより狭い)
- const:初期化のみで再代入不可/スコープはletと同じ
でも、上記サンプルの中ではあまり深く考えずに無造作に使ってます。
モダンじゃない・・と気になる方は、修正して使ってください(笑)
PHP・CodeIgniter・bootstrapを加えた全体ソースです
最初に、Tensorflow.jsを使っているViewの部分です。
tfdemo.php
<!doctype html> <html lang="jp"> <head> <!-- Required meta tags --> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> <!-- Bootstrap CSS --> <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/css/bootstrap.min.css" integrity="sha384-MCw98/SFnGE8fJT3GXwEOngsV7Zt27NXFoaoApmYm81iuXoPkFOJwJ8ERdknLPMO" crossorigin="anonymous"> <!-- Optional JavaScript --> <!-- jQuery first, then Popper.js, then Bootstrap JS --> <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.3/umd/popper.min.js" integrity="sha384-ZMP7rVo3mIykV+2+9J3UJ46jBk0WLaUAdn689aCwoqbBJiSnjAK/l8WvCWPIPm49" crossorigin="anonymous"></script> <script src="https://stackpath.bootstrapcdn.com/bootstrap/4.1.3/js/bootstrap.min.js" integrity="sha384-ChfqqxuZUCnJSK3+MXmPNIyE6ZbWh2IMqE241rYiqJxyMiZ6OW/JmZQ5stwEULTy" crossorigin="anonymous"></script> <!-- Load TensorFlow.js --> <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@0.13.0"> </script> <script> async function main(){ const model = await tf.loadModel("<?php echo base_url('model/model.json'); ?>"); //画像オブジェクトを生成 var width = 28; var height = 28; var img = new Image(); img.src = "<?php echo base_url('images/001.png'); ?>"; var canvas = document.createElement("canvas"); canvas.setAttribute("width", width); canvas.setAttribute("height", height); var context = canvas.getContext("2d"); context.drawImage(img, 0, 0, width, height); var imageData = context.getImageData(0, 0, width, height); const example = tf.fromPixels(imageData, 1).reshape([1,28,28]); const prediction = model.predict(example); $("#result").text("この画像の数字は「" + prediction.argMax(-1).dataSync() + "」だよ!"); } main(); </script> <title><?php echo $title; ?></title> </head> <body> <div class="container"> <div class="row bg-primary text-white"> <div class="col-2"> サンプル </div> <div class="col-2"> ヘダー </div> <div class="col-8"> CodeIgniterの練習です </div> </div> <div class="row"> <div class="col text-center my-5"> <h2><?php echo $title; ?></h2> </div> </div> <div class="row"> <div class="col text-center my-5"> <?php echo img('images/001.png')?> </div> </div> <div class="row"> <div class="col text-center my-5"> <h1 id="result"></h1> </div> </div> <div class="row bg-primary text-white"> <div class="col"> フッター </div> </div><!-- container --> </body> </html>
Tfdemo.php
<?php defined('BASEPATH') or exit('No direct script access allowed'); /* * Tensorflow.jsを利用するデモ * * Kerasで学習済のモデルを引き継ぐ * データはMNISTを利用する * */ class Tfdemo extends CI_Controller{ public function __construct() { parent::__construct(); $this->load->helper('url'); $this->load->helper('html'); } public function doit(){ $this->load->helper('form'); $this->load->library('form_validation'); $data['title'] = "Tensorflow.jsのサンプル"; if (! file_exists(APPPATH . 'views/demos/tfdemo_body.php')) { show_404(); } $data['msg'] = "入力を正常に受け取りました"; $this->load->view('demos/tfdemo.php', $data); } }
PHP・CodeIgniter・Bootstrapを使ってます。
ではでは。