ポケモンのアイコンを識別するやつ(3)
CPPXのXです。
下記の学習部分を説明しようかと思います。
このソースコードについて説明していきたいと思います。
推論部分に使っているモデルの一部を説明をしようかと思います。
現在、3個の識別器を使っているのですが、それぞれ切り分けられるような構造になっているので、モデルの数と同じ3回くらいに渡って全部説明しよかなという感じです。
では、目次です。
環境
- python 3.6.5
- tensorflow 1.10.0
説明
テストデータについて
スマホで撮ったパーティ画像を使って、105件分のアイコン切り抜き画像を作成しました。(pkcp_data.zipに入ってるやつです)
今回言っている精度はその105件に対する識別率です。
モデル
全体像はこんな感じです。
rgbを使った識別器と、hsvのhs部分だけ使った識別器を作り、その結果を混ぜて最終予測をする識別器という構成になっています。
mix部分が識別精度への貢献が一番でかいので、今回はそこの話をしようかなと思います。
それぞれの精度がどんな感じかというと、rgb(96%前後), hs(90%前後), mix(99%前後)
hsの識別器がなかなか精度上がらないです。
では、mix部分の説明に入ろうと思います。
まず中身を見てみます。
割とシンプルな感じにしています。
rpool
rpoolについては、maxpoolした後に、depthwise convした結果を足し合わせています(separableではないです)。
チャネル増幅は4倍とかです。
足し合わせる時には、元が同じ特徴マップを足し合わせています。
気持ち的には、チャネル増幅しつつ残渣求める雰囲気になるのかなという感じです。
計算グラフはこんな感じです。
2回depthwiseがあるのは、maxpoolしてから畳んだ方がいいのか、stride2で畳んだ方がいいのか分からなかったので、両方やって連結しているためです。
ただのmaxpool + pointwiseにしても大丈夫だとは思います。
その辺り検証重ねる必要あると思うのですが、1回の学習がかなり時間かかるのでそのうちという感じです。
pointwise
pointwiseは特殊なことしてなくて1 x 1の通常畳み込みです。
block
block部分は、xceptionの中間部分(separable conv x 3) + seで構成されています。
このブロックが精度への寄与率かなり高いです。
さらに、通常の畳み込み処理を使った他のブロックよりかなり高速です。
計算グラフはこんな感じです。
connect部分はxception * se + input(ブロックへの)という感じです。
今回だとse入れても入れなくても特に何も変わらなかったので気分的に入れてます。
今回は2ループ回しています。
pred
pred部分はただの全結合1層でそのままsoftmaxへGOです。
mixで使ってるパーツは以上です。
mixへの入力は、他識別器からの特徴マップです。
最初の入力部分(addしてるとこ)では、他識別器の入力に近い特徴マップを入力しています。
2個目の入力部分(concatしてるとこ)では、他識別器の出力に近い特徴マップを入力しています。
計算グラフでも見てみます。
xceptionブロックに入る前と、全結合の直前からmixに繋いでいます。
rdepとかいうのはrpoolのmaxpoolしないやつです。
rdepの話はrgb識別器の説明をする時にします。
気持ち的には、3人寄れば文殊の知恵じゃね?的なところからきています。
rgbとhsの識別器にもpredがついているのですが、こちらでも予測を行います。
転移学習に近い感じなのでしょうか。rgbとhsはこのpredから誤差を注入して予め学習しておきます。
これらの識別器は独立しているので、それぞれモデルの中身を変えても問題ないです。
増やしても問題ないです。
ただ、増やした場合、単純に計算量が倍増するのでどんどん重いネットワークになっていきます。
独立していることを活かして、それぞれ別の計算資源に投げることもできます。
自分の手元だと並列に動かせる計算資源がないのでその辺りはやっていません。
学習時もmix以外は完全に独立させて学習できるので、資源豊富なら色々な識別器を繋ぐと面白そうです。
適当なゴミカスネットワークを繋いでもmixが取捨選択してくれるので、軽量弱識別器を大量に繋ぐとか面白そうかなと思ったのですが、それを極端化して実現してるのがseparable convなのかなとも思います。
mixについては以上です。
おわり
いい感じのネットワークができたのかなと思います。
mixに繋がる識別器が、それぞれ弱くても何かに特化してれば最終的な精度が上がるのが良い感じです。
なんか人間味溢れてますね。
次の記事では、rgbの識別器について説明を行おうかと思います。
今後やりたいこと
別データセット(CIFAR100とかimagenetとか)を学習した識別器を繋いだら精度上がったりしないかなあ。
denseネットの構造を取り入れたい。
中間層とか可視化しながら、うまく識別できないデータについて考察しなきゃいけないのかなあ・・・・。
モデルのダイエットしたい。
他の記事
ポケモンのアイコンを識別するやつ(1)
ポケモンのアイコンを識別するやつ(2)
ポケモンのアイコンを識別するやつ(3) now