け日記

最近はPythonでいろいろやってます

KerasでLeNet-5を実装してKuzushiji-MNISTを分類する

仕事でそろそろコンピュータビジョン系の力が必要になるかも、となってきましたので、チクタク勉強を始めてます。

今回はKerasを使ってKuzushiji-MNISTの文字を分類するネットワークをLeNet-5で実装する、ということに取り組みます。

Kuzushiji-MNIST

日本の古典籍のくずし字の画像とラベルからなるデータセット (doi:10.20676/00000341) で、人文学オープンデータ共同利用センターによって作成されました。

codh.rois.ac.jp

3つのデータセットがありますが、今回は一番手頃なKuzushiji-MNISTを使います。

  • Kuzushiji-MNIST
    • ひらがな10クラスの分類タスク (28x28のグレースケール画像とラベルを70,000セットを含む)
  • Kuzushiji-49
    • ひらがな49クラスの分類タスク (28x28のグレースケール画像とラベルを270,912セットを含む)
  • Kuzushiji-Kanji
    • 漢字3832クラスの分類タスク (64x64のグレースケール画像とラベルを140,426セットを含む)

GitHub上で公開されてます。

github.com

@online{clanuwat2018deep,
  author       = {Tarin Clanuwat and Mikel Bober-Irizar and Asanobu Kitamoto and Alex Lamb and Kazuaki Yamamoto and David Ha},
  title        = {Deep Learning for Classical Japanese Literature},
  date         = {2018-12-03},
  year         = {2018},
  eprintclass  = {cs.CV},
  eprinttype   = {arXiv},
  eprint       = {cs.CV/1812.01718},
}

KerasとTensorflowをインストールしておきます。

pip install keras tensorflow

Kerasで実装・分類

データセットの取得・概観

こちらのURLからデータセットをダウンロードしておきます。今回はNumPyフォーマットを使います。

wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz
wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-labels.npz
wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz
wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-labels.npz

ダウンロードしたファイルをarrayにロードします。あわせて、ラベルデータも読み込みます。

  • 0ならば'お'、1ならば'き'、... というようにひらがな10クラスとなってます
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# データのロード
X_train_images = np.load('kmnist-train-imgs.npz')['arr_0']
Y_train_labels = np.load('kmnist-train-labels.npz')['arr_0']
X_test_images = np.load('kmnist-test-imgs.npz')['arr_0']
Y_test_labels = np.load('kmnist-test-labels.npz')['arr_0']

# ラベルデータのロード
label_map = pd.read_csv('http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist_classmap.csv')['char']
print(label_map)
# 0    お
# 1    き
# 2    す
# 3    つ
# 4    な
# 5    は
# 6    ま
# 7    や
# 8    れ
# 9    を
# Name: char, dtype: object
  • 学習データは60,000セット、テストデータは10,000セット
    • 学習データ・テストデータともに偏りはなく、10等分されてます
  • 画像は28x28で0〜255のグレースケール
# 画像の数・サイズを確認
print(X_train_images.shape, Y_train_labels.shape, X_test_images.shape, Y_test_labels.shape)
# ((60000, 28, 28, 1), (60000, 10), (10000, 28, 28, 1), (10000, 10))

# 値域を確認
np.min(X_train_images), np.max(X_train_images), np.min(X_test_images), np.max(X_test_images)
# (0, 255, 0, 255)

# ラベルの偏りを確認
print(np.unique(Y_train_labels, return_counts=True))
# (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8), array([6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000, 6000]))
print(np.unique(Y_test_labels, return_counts=True))
# (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=uint8), array([1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000]))

# 学習データ36枚を表示
plt.figure(figsize=(10,10))
for i in range(36):
    plt.subplot(6, 6, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(X_train_images[i])
    plt.xlabel(Y_train_labels[i])
plt.show()

学習データを見てみますと、現代日本人ではなかなか理解できない文字となっていることがわかります。特に'を'は今とだいぶ違いますね。
公開されてから日も浅く、データの不備もいくつか指摘されています (https://github.com/rois-codh/kmnist/issues) ので、今後アノテーションが変わる可能性もあります。

KerasでLeNetを実装

それではKerasでモデルを実装していきます。

データの前処理

Kerasに入力できるように前処理を行います。

  • 各ピクセル (256階調の整数) を0.0〜1.0の小数へ正規化
  • チャネルの次元を追加 (RGBカラーなら3チャネルとなります)
  • keras.utils.to_categoricalでラベルをone-hotエンコーディング
import keras
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.layers import Conv2D, AveragePooling2D

# 正規化
X_train = X_train_images.astype('float32')
X_test = X_test_images.astype('float32')
X_train /= 255
X_test /= 255

print(X_train.min(), X_train.max(), X_test.min(), X_test.max())
# (0.0, 1.0, 0.0, 1.0)

# チャネルの次元を加える
X_train = X_train.reshape(X_train.shape + (1,))
X_test = X_test.reshape(X_test.shape + (1,))

print(X_train.shape, X_test.shape)
# ((60000, 28, 28, 1), (10000, 28, 28, 1))

# one-hotエンコーディング
num_labels = label_map.size

Y_train = keras.utils.to_categorical(Y_train_labels, num_labels)
Y_test = keras.utils.to_categorical(Y_test_labels, num_labels)

print(Y_train.shape, Y_test.shape)
# ((60000, 10), (10000, 10))

モデルの作成

次にモデルを作ります。今回は元祖MNISTのLeNet-5を実装してみます (提案論文PDF、下図はFig 2.から抜粋) 。

  • X -> Conv -> Pooling -> Conv -> Pooling -> FC -> FC -> FC -> softmax -> y という簡単な構造です

input_shape = (X_train.shape[1], X_train.shape[2], 1)

model = Sequential()

model.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), padding='same', activation='tanh', input_shape=input_shape))
model.add(AveragePooling2D((2, 2), strides=(2, 2)))
model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), padding='valid', activation='tanh'))
model.add(AveragePooling2D((2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(120, activation='tanh'))
model.add(Dense(84, activation='tanh'))
model.add(Dense(num_labels, activation='softmax'))

model.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer=keras.optimizers.Adadelta(),
    metrics=['accuracy']
)

print(model.summary())

作成したモデルはsummaryメソッドで概観できます。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_8 (Conv2D)            (None, 28, 28, 6)         156       
_________________________________________________________________
average_pooling2d_7 (Average (None, 14, 14, 6)         0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 10, 10, 16)        2416      
_________________________________________________________________
average_pooling2d_8 (Average (None, 5, 5, 16)          0         
_________________________________________________________________
flatten_4 (Flatten)          (None, 400)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 120)               48120     
_________________________________________________________________
dense_10 (Dense)             (None, 84)                10164     
_________________________________________________________________
dense_11 (Dense)             (None, 10)                850       
=================================================================
Total params: 61,706
Trainable params: 61,706
Non-trainable params: 0
_________________________________________________________________

モデルの学習と検証

作成したモデルはfitメソッドで学習します。

  • エポック数とバッチサイズはハイパパラメータとして渡します

最終的には、学習データでは精度99%となってますが、テストデータで93%と、過学習の傾向が見られてます。

epochs = 30
batch_size = 1000

history = model.fit(x=X_train,y=Y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_test, Y_test), verbose=1)
# Train on 60000 samples, validate on 10000 samples
# Epoch 1/30
# 60000/60000 [==============================] - 25s 423us/step - loss: 1.0245 - acc: 0.6948 - val_loss: 1.0961 - val_acc: 0.6553
# Epoch 2/30
# 60000/60000 [==============================] - 26s 431us/step - loss: 0.5329 - acc: 0.8411 - val_loss: 0.7834 - val_acc: 0.7537
# ...
# Epoch 30/30
# 60000/60000 [==============================] - 26s 431us/step - loss: 0.0280 - acc: 0.9933 - val_loss: 0.2718 - val_acc: 0.9298

# 損失関数の値の推移
epoch_array = np.array(range(30))
plt.plot(epoch_array, history.history['loss'], history.history['val_loss'])
plt.legend(['loss', 'val_loss'])
plt.show()

fitの返り値 (History) から学習過程の精度と損失関数の値を取得できます。

f:id:ohke:20190316140733p:plain

モデルを使った予測

最後に予測を行います。

予測はpredictメソッドで行います。10000の内、724の分類に失敗してました。

  • 間違ってラベル付けされた文字としては、1位が'す'、2位が'き'でした
Y_predict = model.predict(X_test)

# 誤ったテストデータを取得
labels_predict = np.argmax(Y_predict, axis=1)
labels_test = np.argmax(Y_test, axis=1)

miss_indexes = np.where(labels_predict != labels_test)[0]
print(len(miss_indexes)) # 724

miss_predict = labels_predict[miss_indexes]
miss_test = labels_test[miss_indexes]

print(np.unique(miss_test, return_counts=True))
# (array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
#  array([ 62, 102, 134,  45,  83,  80,  35,  69,  50,  64]))

# 誤った36枚を表示
plt.figure(figsize=(10,10))
for i in range(36):
    plt.subplot(6, 6, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(X_test_images[miss_indexes[i]])
    plt.xlabel(f'{miss_test[i]} -> {miss_predict[i]}')

plt.show()

誤った画像を見てみると、うーん、確かに解釈が難しそうですね...