KerasでDCGANを作ってKMNISTのくずし字を生成する

KMNISTのくずし字をDCGANで生成する、というモデルをKerasで作ります。

DCGAN

DCGAN (Deep Convolutional GAN) はGAN (Generative Adversarial Network) の生成モデルの一種で、画像を生成するものです (提案論文) 。

GANは2つのモデルを学習によって獲得します。生成モデルは判別モデルを騙すように、判別モデルは生成モデルに騙されないように、それぞれ学習を行うことで、より高度な生成モデルを作ることが目的です。

  • 生成モデル (generator): データセットに近い (ありそうな) データを生成するモデル
  • 判別モデル (discriminator): データセットのデータ (本物) なのか、生成されたデータ (贋物) なのかを判別するモデル

DCGANでは、100次元の一様分布乱数を入力として画像を生成します (提案論文Figure 1抜粋) 。

f:id:ohke:20190518161024p:plain

データの準備

KMNISTくずし字データセット (doi:10.20676/00000341) を最初に準備します。KerasでLeNet-5を実装してKuzushiji-MNISTを分類する - け日記も参考にしてみてください。

codh.rois.ac.jp

今回は学習用の画像のみを使います。npz形式でダウンロードしておきます。

$ wget http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz

numpyでロードします。28ピクセル x 28ピクセル x 1チャネル (=グレースケール) x 60,000枚となります。文字は10種です。

import numpy as np

# データのロード
X_images = np.load('kmnist-train-imgs.npz')['arr_0'][:, :, :, np.newaxis]  # (60000, 28, 28, 1)

# -1〜+1に正規化
X_images = (X_images - [127.5]) / 127.5

モデルの作成

上の通り、DCGANは生成モデルと判別モデルからなりますので、それぞれKerasで実装していきます。

f:id:ohke:20190518172846p:plain

各モデルの実装は↓の書籍を参考にしています。

直感 Deep Learning ―Python×Kerasでアイデアを形にするレシピ

直感 Deep Learning ―Python×Kerasでアイデアを形にするレシピ

  • 作者: Antonio Gulli,Sujit Pal,大串正矢,久保隆宏,中山光樹
  • 出版社/メーカー: オライリージャパン
  • 発売日: 2018/08/17
  • メディア: 単行本(ソフトカバー)
  • この商品を含むブログを見る

import keras
from keras.models import Sequential
from keras.layers import (
    Dense, BatchNormalization, Activation, Reshape, UpSampling2D, 
    Conv2D, MaxPooling2D, Flatten
)
from keras.optimizers import Adam

生成モデルの実装

生成モデルでは100次元の一様分布の乱数 (-1〜+1) を入力とし、最終的に28 x 28 x 1で出力しています。

  • UpSampling2Dでは、画像の縦横を拡大します
    • size=(2,2)の場合、各セル (1x1) が2x2へコピーされます
  • 提案論文では64 x 64 x 3となっていますが、KMNISTに揃えています
# generatorの定義
def generator_model(name='generator'):
    model = Sequential(name=name)
    
    model.add(Dense(1024, input_shape=(100,), activation='tanh'))
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((7, 7, 128), input_shape=(7*7*128,)))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (5, 5), padding='same', activation='tanh', data_format='channels_last'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, (5, 5), padding='same', activation='tanh', data_format='channels_last'))
    
    return model

判別モデルの実装

次に判別モデルですが、畳み込み層 x 2と全結合層 x 1の単純な構造です。

# discriminatorの定義
def discrimanator_model(name='discriminator'):
    model = Sequential(name=name)
    
    model.add(Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1), activation='tanh', data_format='channels_last'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (5, 5), activation='tanh', data_format='channels_last'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024, activation='tanh'))
    model.add(Dense(1, activation='sigmoid'))
    
    return model

DCGANモデルの実装

生成モデルと判別モデルを組み合わせます。それぞれ独立して学習するために、別々にcompileしていることを確認してください。

  • 生成モデル学習中に判別モデルのパラメータを変えてはいけないので、判別モデルを学習しないようにしてます (discriminator.trainable = False)
    • この場合でも、discriminatorを直接学習する場合は、パラメータは更新されます
# dcganの定義
def dcgan_model(generator, discriminator):
    model = Sequential()
    
    model.add(generator)
    model.add(discriminator)
    
    return model

# 各コンポーネントの生成
def create_models():
    discriminator = discrimanator_model()
    print(discriminator.summary())

    discriminator.compile(
        loss=keras.losses.binary_crossentropy, 
        optimizer=Adam()
    )

    discriminator.trainable = False
    
    generator = generator_model()
    print(generator.summary())

    dcgan = dcgan_model(generator, discriminator)
    print(dcgan.summary())

    dcgan.compile(
        loss=keras.losses.binary_crossentropy, 
        optimizer=Adam()
    )
    
    return (dcgan, generator, discriminator)

dcgan, generator, discriminator = create_models()

モデルの学習

学習は以下のステップで行います。

  1. generatorで、画像を生成し、データセットの画像と混ぜる (ここでは30枚ずつ1:1です)
  2. discriminatorで、生成画像 (y=0) or データセット画像 (y=1) を判別するように、ミニバッチで学習する
  3. generatorで、一様乱数から画像を生成し、discriminatorでy=1と判別されるように、同じくミニバッチで学習する (全て生成画像を入力とするため、目的変数はすべて1)
epochs = 30
batch_size = 30
batches = int(X_images.shape[0] / batch_size)
z_dimensions = 100

discriminator_losses = []
generator_losses = []

for epoch in range(epochs):
    for i in range(batches):
        # データセット画像 (本物)
        genuine_images = X_images[i*batch_size : i*batch_size + batch_size]  # (30, 28, 28, 1)

        # 生成画像 (贋物)
        noise = np.random.uniform(-1, 1, (batch_size, z_dimensions)) # (30, 100)
        fake_images = generator.predict(noise)  # (30, 28, 28, 1)

        # discriminatorの学習
        X = np.concatenate([genuine_images, fake_images])  # データセット画像と生成画像を混ぜる
        y = [1] * batch_size + [0] * batch_size  # データセットの画像を1とするように目的変数をセット
        discriminator_loss = discriminator.train_on_batch(X, y)
        
        # generatorの学習
        noise = np.random.uniform(-1, 1, (batch_size, z_dimensions)) # (30, 100)
        generator_loss =  dcgan.train_on_batch(noise, [1]*batch_size)  # データセット画像 (y=1) と誤判定されたいの (全て1)
        
    print('epoch ', epoch, ': discriminator_loss=', discriminator_loss, 'generator_loss=', generator_loss)
    generator_losses.append(generator_loss)
    discriminator_losses.append(discriminator_loss)
    
    # エポック完了後、生成した画像の一部を書き出す
    save_images(fake_images, str(epoch))

生成された画像

学習の結果、生成された画像を見ていきましょう。なおデータセットには以下のような10種の文字が含まれています (こちらより抜粋) 。

f:id:ohke:20190518183657p:plain

epoch 0終了後

最初のエポック終了後ですが、まだ文字らしくなっていません。

f:id:ohke:20190518182704p:plain f:id:ohke:20190518182719p:plain f:id:ohke:20190518182729p:plain f:id:ohke:20190518182738p:plain f:id:ohke:20190518182748p:plain

epoch 4終了後

文字らしくなってきており、"は (ハ)" や "つ" などは、形になってきてます。

f:id:ohke:20190518183525p:plain f:id:ohke:20190518183536p:plain f:id:ohke:20190518183545p:plain f:id:ohke:20190518183552p:plain f:id:ohke:20190518183601p:plain

epoch 29終了後

生成された文字のほとんどが "は (ハ)" や "つ" に偏っていました。これらの文字は字形が単純で生成しやすく、結果そればかりを生成するモデルになったと思われます。文字種ごとに分ける、などデータセットで分離する必要がありそうです。

f:id:ohke:20190518190331p:plain f:id:ohke:20190518190341p:plain f:id:ohke:20190518190351p:plain f:id:ohke:20190518190357p:plain f:id:ohke:20190518190405p:plain