畳み込みニューラルネットワークが持つ座標変換の問題に着目してCoordConvを提案したAn intriguing failing of convolutional neural networks and the CoordConv solution (NeurIPS'18, arXiv) について紹介します。
@incollection{NIPS2018_8169, title = {An intriguing failing of convolutional neural networks and the CoordConv solution}, author = {Liu, Rosanne and Lehman, Joel and Molino, Piero and Petroski Such, Felipe and Frank, Eric and Sergeev, Alex and Yosinski, Jason}, booktitle = {Advances in Neural Information Processing Systems 31}, editor = {S. Bengio and H. Wallach and H. Larochelle and K. Grauman and N. Cesa-Bianchi and R. Garnett}, pages = {9605--9616}, year = {2018}, publisher = {Curran Associates, Inc.}, url = {http://papers.nips.cc/paper/8169-an-intriguing-failing-of-convolutional-neural-networks-and-the-coordconv-solution.pdf} }
畳み込みニューラルネットワークの座標変換
CNNでは空間表現を、密なデカルト表現から疎なピクセルベースの表現へ (またはその逆) が難しい。畳み込み層の単純なスタックも、画像分類のようなタスクには効果があるが、物体検出や生成などのタスクには不向き。
CoordConv
下図右がここで提案しているCoordConvです。チャネル方向に座標情報 (以下ではi coordinateとj coordinateの2チャネル) を連結しているのが特徴です。
- i coordinateは、
[[0 0 ... 0] [1 1 ... 1] ... [h-1 h-1 ... h-1]]
(行内が同一の値) となるランク1行列 - j coordinateは、
[[0 1 ... w-1] [0 1 ... w-1] ... [0 1 ... w-1]]
(列内が同一の値) となるランク1行列 - いずれも [-1, 1] に収まるように線形スケール
通常の畳み込み層の性質を維持しつつ、タスクの特性に応じて追加されていた座標情報をシンプルに表現している。
- パラメータ数は
->
でd (coordinatesチャネル数で2) 分のみの増加
- coordinatesチャネルに対する重みが0になれば、通常の畳み込み同様の並進不変性は保たれる
座標変換が表現されていることをトイタスクで確認
座標変換がうまくできることを簡単なタスクで確認。
- タスク1: (x, y)と対応する座標が1.0となるように予測する (下図上段)
- タスク2: (x, y)を中心とする正方形をレンダリングする (下図中段)
- タスク3: 赤と青の2つの図形をランダムに生成した画像で学習・生成する (下図下段)
タスク1の結果をプロットしたのが下図です。座標に対応する点 (ピクセル) を予測するClassificationとその逆変換のRegressionの2つを比較していますが、いずれにおいても、通常の畳み込み (中段) はほぼ予測できてないが、CoordConvでは精度100%をマークするという結果になってます。
CoordConvを実践的なタスクに適用
畳み込みをCoordConvに置き換えてうまく座標変換できるようにすることで、一般のタスクも改善できるかどうかを調べる。
画像分類
並進不変性を必要とするImageNet画像分類タスクの場合、Accuracyの改善につながらなかったが、一方で悪化もなかった。
- 追加したcoordinatesチャネルによって並進不変性が損なわれないということが確認できる
物体検出
物体検出では、ピクセル空間からデカルト空間のBBoxに変換する座標変換を行っており、CoordConvがマッチしそうです。
Faster R-CNNのRPNをCoordConvで置き換えたネットワークで、簡単な物体検出タスクを行ったところ、IoUが大きく改善した。
- Faster R-CNNについてはこちらも参照: 論文メモ: Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks - け日記
生成モデル
生成モデルで起こる位置のモード崩壊は、潜在空間からピクセル空間への変換の難しさにも起因していると考え、CoordConvが一助になると予想した。
タスク3について、2種類のGANモデル (通常の畳み込みバージョンとCoordConvバージョン) で学習・生成した。通常の畳み込みでは図形の中心が実際の分布よりも中央に寄りやすい傾向が見られたが、CoordConvのモデルではそれが改善していた (下図 (b)) 。1000サンプルの平均値を見ても、CoordConvの方が実際の平均に近い結果が得られた (下図 (d)) 。
- ただし、図形間の距離の分布はどちらかというとCoordConvの方が悪くなっている (下図 (c))
まとめ
畳み込み層に位置情報をシンプルに埋め込んだCoordConvについて紹介しました。学習・推論にかかる追加コストも小さく、また、幅広いタスクに適用・応用できる点で、アドバンテージのある提案だと思います。