TensorFlowで画像を分類する実験をしたときに書いた一連のプログラムをGitHubに公開しました.
この記事を書いてからいろいろと手直しをして,まだまだ改善点は多いですが,少しずつ整ってきました.上記の実験でやっていることは,TensorFlowのチュートリアルにある「CIFAR-10」という画像分類問題です.
自分のオリジナルのテーマで画像分類問題をつくって取り組もうとすると,モデルの定義や学習のコードを書くことの他にも,以下のようなことをやらなければなりません.
- 訓練用の画像データを集める
- 評価用の画像データを集める
- 集めた画像データにラベル付けをする
- 画像とラベルのデータセットをTensorFlowに入力する
- 実際に画像を与えて学習成果を試す
これらのうち,最初の3項目は先日公開した『tfPhotoPalette』を使うことで少しだけラクができるようになりました.
今回は,このツールが生成するJSONファイルを与えて,実際に画像を与えて分類してみるまでの流れを担うPythonプログラムを紹介します.README.mdも併せてご覧ください.
オリジナルテーマを作成する
READMEの冒頭手順に従って,新たに挑戦しようとしているテーマ用のフォルダをworkspace/
フォルダ内に作成します.workspaceフォルダ内にseed/
という空のテーマがあり,それをコピーしているだけです.以下のようなディレクトリ構成になっています.
workspace/shokujin/
画像を集める
- tfPhotoPaletteで画像を切り抜いてラベル付けした後エクスポートしたJSONファイルを適当なファイル名で
raw_jsons/
フォルダに保存します. - ツールで生成したJSONファイルであれば,いくつのファイルに分かれていてもOKです.
- ひとつのファイルに複数ラベルぶんのデータが含まれていてもOKです.
- 複数人でツールを使ってそれぞれがエクスポートした場合はそれらをそのまま配置すればOKです.
ラベルを設定する
cifar10.labels
というファイルに,今回の学習で分類対象にするラベル名を設定します.食神さんの定食を分類する例では,以下のように書きます.
4 0, t5, teishoku-5 1, t1, teishoku-1 2, t4, teishoku-4 3, t3, teishoku-3
先頭の行に分類数(いくつのクラスに分けるか)を書き,続いて各行にラベル情報を書いていきます.カンマ区切りで書かれている値は順に,「プログラム中で扱うラベル番号
,ツールで与えたラベル名
,答えとして表示する名称
」です.
ここで設定したラベル名以外は,仮にエクスポートしたJSONに含まれていた場合でも学習の際には無視されます.このようにすることで,十分な数が揃っていない対象ラベルを除外指定することができます.
学習用と評価用に分ける
gen_labeled_jsons.py
を実行すると,raw_jsons/
フォルダ内のすべてのJSONファイルを読み取り,ラベル毎に「ラベルXに該当する画像はY件あります.何件ぶんを評価用に切り分けますか?」のように聞いてきます.適当な数値を入力していってください.
Tfrecordsファイルを生成する
.tfrecords
は画像とラベルの組をTensorFlowに与えるためのファイル形式です.gen_tfrecords.py
を実行すると,先ほど用意した学習用/評価用のデータセットをこのフォーマットにして新たに保存します.
学習 〜 評価する
ここでやっていることはチュートリアルとほとんど同じです.
train.py
: 学習用のtfrecordsのデータセットを用いて学習します.eval.py
: 評価用のデータセットを使って,正答率を求めます.
実際に写真を与えて遊ぶ
play.py
の--jpg
オプションに任意の画像ファイルを与えると,その写真がどのグループに分類されるかを確かめることができます.例えば,この間食べた3番定食の写真を与えて
$ python play.py --theme=shokujin --jpg ~/Desktop/teishoku-sample.jpg
のように実行すると,
のように結果が得られます.なんとか正解しているようです.
学習成果をすぐに試してみたいけれど,丁度よい写真が手元にない!という場合には,gen_toys.py
を使います.--ans
オプションに正解ラベル番号を与えると,評価用のデータセット中からランダムで5個のJPEGファイルを生成してくれます.
今回一連のツールを作ったことにより,思い立ったらすぐに画像を集めることができて,データが集まってきたらすぐに実行できるような環境が整ったので満足しています.肝心の学習モデルはチュートリアルのtensorflow.models.image.cifar10
を使ったままなので,今度はこちらを設計できるようになれるよう頑張ります*1.
GitHub
*1:食神の定食写真は引き続き大募集中です!