ResNetでCIFAR-10を分類する

KerasでResNetを作ってCIFAR-10を分類し、通常のCNNモデルと比較します。

ResNet

ResNetはCNNのモデルの1つです。
Microsoft ResearchのKaiming Heらが2015年に提案1し、その年のILSVRCではResNetで学習したモデルが優勝しました。

VGGやGoogLeNetにて、畳み込み層を重ねることでより良い感じの特徴抽出ができることが明らかになっていましたが、多層になればなるほど前段の層で勾配消失問題が顕著になります。

ResNetは勾配消失問題を解消するために、複数の畳み込み層をスキップするshortcut connectionを導入しました。このresidual blockの図を以下に示します。

  • 今までの畳み込み層は、入力を各レイヤに通すdeep path (青) のみ
  • ResNetでは、deep pathの出力と、入力をそのまま伝搬するshortcut connection (赤) を加算する

学習過程で畳み込み層F(x)は、畳み込み層の出力H(x)をそのまま学習するのではなく、入力xとの残差H(x)-xで重みを学習するようになります (ResNetのResidualはここから来てます) 。

deep pathを通過することで勾配が小さくなっていってましたが、shortcut connectionで畳み込み層をスキップできるので、前段の層にも勾配を伝搬可能となった、ということです。

  • 論文中では、deep path内でサイズを変えて合流させるBottleneckアーキテクチャも触れられていましたが、今回はそれを行わないPlainアーキテクチャを組みます

また、以下の特徴も性能へ好影響を与えていることが指摘されてます。

  • 表現力が高すぎて役に立たない畳み込み層を残差0でスキップできる
  • 複数のパスを通しているためアンサンブル学習に近い

CIFAR-10

32x32のカラー画像 (3チャネル) を10クラスに分類するタスクで、Kerasのデータセット付属のデータセットです。

  • 学習データ50,000サンプル、テストサンプル10,000サンプルからなります
  • 各クラスのサンプル数は均等

Kerasで実装

以下の3パターンを実装します。

  • thinモデル: 畳み込み層x3 (Shortcut connection無し, 出力層8x8x32)
  • plainモデル: 畳み込み層x13 (Shortcut connection無し, 出力層4x4x64)
  • residualモデル: 畳み込み層x13 (Shortcut connection有り, 出力層4x4x64)
import matplotlib.pyplot as plt
import keras
from keras import backend as K
from keras.models import Sequential, Model
from keras.layers import Input, Conv2D, MaxPool2D, BatchNormalization, ReLU, Flatten, Dense, Add, Dropout
from keras.datasets import cifar10

# データのロード
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
# (50000, 32, 32, 3) (50000, 1) (10000, 32, 32, 3) (10000, 1)

# 0〜1へ正規化
X_train = X_train.astype('float32') / 255.0
X_test = X_test.astype('float32') / 255.0

# one-hotエンコーディング
Y_train = keras.utils.to_categorical(y_train)
Y_test = keras.utils.to_categorical(y_test)

input_shape = X_train.shape[1:]
train_samples = X_train.shape[0]
test_samples = X_test.shape[0]

# ハイパパラメータ
epochs = 200
batch_size = 50

畳み込み層以外の実装

最初に畳み込み層以外の実装を見ていきます。共通して使うモデルため、生成関数generate_modelを定義してます。

  • 最初の畳み込み層で (32, 32, 3) -> (16, 16, 32) にします
  • 畳み込み層はblock_sets*blocks*block_layers層で構成します
    • block_fでブロックの実装を関数で渡します (後述)
    • 1ブロックセット終了時に、2x2でプーリングし、次のブロックセットのフィルタ数を倍に拡張
  • 全結合層で 畳み込み層の出力 -> 100 -> 10 で絞っていきます
  • 畳み込み後にBatchNormalization, ReLU, Dropoutを行う
# モデルの生成
def generate_model(input_shape, block_f, blocks, block_sets, block_layers=2, first_filters=32, kernel_size=(3,3)):
  inputs = Input(shape=input_shape)
  
  # 入力層
  x = Conv2D(filters=first_filters, kernel_size=kernel_size, padding='same')(inputs)
  x = BatchNormalization()(x)
  x = ReLU()(x)
  x = MaxPool2D((2, 2))(x)
  x = Dropout(0.2)(x)
  
  # 畳み込み層
  for s in range(block_sets):
    filters =  first_filters * (2**s)
    
    for b in range(blocks):
      x = block_f(x, kernel_size, filters, block_layers)
      
    x = MaxPool2D((2, 2))(x)
    x = Dropout(0.2)(x)
  
  # 出力層
  x = Flatten()(x)
  x = Dropout(0.4)(x)
  x = Dense(100)(x)
  x = ReLU()(x)
  outputs = Dense(10, activation='softmax')(x)
  
  model = Model(input=inputs, output=outputs)
  
  return model

shortcut connection無しのブロック

layersの数だけ畳み込み層を追加する実装です。畳み込み後にBatchNormalizationとReLUを付加してます。

# shortcut connection無しのブロック
def plain_block(x, kernel_size, filters, layers):
  for l in range(layers):
    x = Conv2D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    
  return x

shortcut connection有りのブロック (residual block)

shortcut connection有りのブロックも同様にlayers層の畳み込み層を構成しますが、入力信号をshortcut_xに保持しておき、最後のReLU関数の前にAddでdeep pathの出力信xと加算してます。

  • 加算ではフィルタ数が一致する必要があるため、ズレている場合は1x1の畳み込み層を噛ませてフィルタ数を揃えてます
# shortcut path有りのブロック (residual block)
def residual_block(x, kernel_size, filters, layers=2):
  shortcut_x = x
  
  for l in range(layers):
    x = Conv2D(filters, kernel_size, padding='same')(x)
    x = BatchNormalization()(x)
    
    if l == layers-1:
      if K.int_shape(x) != K.int_shape(shortcut_x):
        shortcut_x = Conv2D(filters, (1, 1), padding='same')(shortcut_x)  # 1x1フィルタ
      
      x = Add()([x, shortcut_x])
      
    x = ReLU()(x)
    
  return x

検証

3つのモデルを生成・学習し、train lossを比較します。

  • 引数block_fをthin/plainとresidualで切り替えていることに着目してください
# ハイパパラメータ
epochs = 200
batch_size = 50

# thinモデル
thin_model  = generate_model(input_shape, plain_block, blocks=1, block_sets=1)
thin_model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
thin_history = thin_model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_test, Y_test))

# plainモデル
plain_model  = generate_model(input_shape, plain_block, blocks=3, block_sets=2)
plain_model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
plain_history = plain_model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_test, Y_test))

# residualモデル
residual_model  = generate_model(input_shape, residual_block, blocks=3, block_sets=2)
residual_model.compile(loss=keras.losses.categorical_crossentropy, optimizer=keras.optimizers.Adam(), metrics=['accuracy'])
residual_history = residual_model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_test, Y_test))

CIFAR-10で実験

それではtrain lossを示します。今回の実験では、畳み込み層が浅く、精度・収束速度の両方においてResNetの持ち味は発揮できていないようです。

  • thinとplain/residualを比較すると、一段小さくなっており、畳み込み層を厚くすることで特徴抽出が改善できていることを確認できます
  • plainとresidualを比較すると、わずかですが一貫してplainの方が小さい値となってます

f:id:ohke:20190621163549p:plain

畳み込み層をさらに増やして実験 (13層 -> 25層)

上の結果を受けて、今度はplainとresidualの畳み込み層を13層から25層まで増やして学習させてみます。

  • 1個のブロックセットが持つ畳み込み層を6層から12層に増やしています
    • 1+12*2で、合計25層となります
  • それ以外は変更を加えていません
# plainモデル
plain_model  = generate_model(input_shape, plain_block, blocks=6, block_sets=2, first_filters=32)

# residualモデル
residual_model  = generate_model(input_shape, residual_block, blocks=6, block_sets=2, first_filters=32)

同じくtrain lossを比較すると、常にplainモデルよりもresidualモデルが下回っており、また収束も速くなってます。
一方でplainモデルは、12層バージョンよりもlossが大きい傾向にあり、層が増えると学習が難しくなることも伺えます。

f:id:ohke:20190621201147p:plain

検証データでのaccuracyで比較しても、plainモデルより3〜4%程度改善されています

  • 13層のplainモデルと比較しても、約2%改善しています

f:id:ohke:20190621201655p:plain

まとめ

ResNetをKerasを使って実装し、CIFAR-10タスクでプレーンなCNNと比較しました。畳み込み層を深くすると、精度・収束速度の両方でResNetの方が良くなったことを確認しました。


  1. Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, Deep Residual Learning for Image Recognition (2015) (arXiv)