Keras: スパムメッセージをLSTMで分類する

KerasでRNN (LSTM) を実装し、スパムメッセージを分類してみます。

以前、同じデータセットに対してscikit-learnを使ってナイーブベイズで分類を行いましたが、いわばそのディープラーニング版となります。

データの準備

まずは以下からデータセットをダウンロード・展開します。

UCI Machine Learning Repository: SMS Spam Collection Data Set

$ wget https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
$ unzip smsspamcollection.zip

"SMSSpamCollection"というファイルで展開されますので、DataFrameにロードします。独立変数はスパムメッセージテキスト ("text") 、従属変数がスパムかどうかのカテゴリ変数 ("category"、スパムなら1) となります。

import pandas as pd

pd.set_option("display.max_colwidth", 100)

dataset_df = pd.read_csv('./SMSSpamCollection', sep='\t', header=None)
dataset_df.rename({0: 'label', 1: 'text'}, axis=1, inplace=True)
dataset_df['category'] = dataset_df.apply(lambda r: 1 if r['label'] == 'spam' else 0, axis=1)

dataset_df.head()

f:id:ohke:20190427113459p:plain

次にデータセット全体を、学習データ4457件とテストデータ1115件 (概ね1:4) に分離します。

from sklearn.model_selection import train_test_split

X_train, X_test, Y_train, Y_test = train_test_split(
    dataset_df[['text']], dataset_df[['category']], 
    test_size=0.2, random_state=0
)

print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)
# (4457, 1) (1115, 1) (4457, 1) (1115, 1)

テキストの前処理

スパムテキストに前処理を施していきます。2段階の処理となります。

  • Step 1. トークナイズ keras.preprocessing.text.Tokenizer
    • 1つのテキストをトークン (単語) 列に分離して辞書を作り、各単語のインデックスに変換したベクトルを作ります
    • 単語とインデックスの対応付けはword_indexプロパティに持ってます
    • 小文字への変換や記号の除去などはデフォルトで行われます
  • Step 2. パディング keras.preprocessing.sequence.pad_sequences()
    • 1つのテキストが含むトークン数は可変ですので、LSTMのセル数と揃えるためにパディングします (長さmaxlenを100としてます)
    • パディングは前から詰められ、値は0となります
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

max_len = 100  # 1メッセージの最大単語数 (不足分はパディング)

tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_train['text'])
x_train = tokenizer.texts_to_sequences(X_train['text'])
x_test = tokenizer.texts_to_sequences(X_test['text'])

for text, vector in zip(X_train['text'].head(3), x_train[0:3]):
    print(text)
    print(vector)
# No I'm good for the movie, is it ok if I leave in an hourish?
# [38, 32, 56, 12, 5, 636, 9, 14, 47, 36, 1, 208, 8, 128, 3810]
# If you were/are free i can give. Otherwise nalla adi entey nattil kittum
# [36, 3, 204, 21, 51, 1, 29, 138, 949, 2527, 3811, 3812, 3813, 3814]
# Have you emigrated or something? Ok maybe 5.30 was a bit hopeful...
# [17, 3, 3815, 26, 185, 47, 404, 209, 740, 62, 4, 299, 3816]

x_train = pad_sequences(x_train, maxlen=max_len)
x_test = pad_sequences(x_test, maxlen=max_len)

print(x_train[0])
# [   0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     0   38   32   56   12    5  636    9   14   47   36    1  208    8
#   128 3810]

ちなみに、学習データに現れない単語については除去されます。以下ではテストデータで初めて現れた"latelyxxx"という単語が除去されていることを確認してます。

# "latelyxxx"は学習データに現れない単語
print(X_test['text'].iloc[9])
# Hey hun-onbus goin 2 meet him. He wants 2go out 4a meal but I donyt feel like it cuz have 2 get last bus home!But hes sweet latelyxxx
print(x_test[9])
# [   0    0    0    0    0    0    0    0    0    0    0    0    0    0
#     ...
#     0    0    0    0  133 1824  535   19  156  135   67  600 3350   52
#  4782 3437   24    1  227   57   14 1267   17   19   33  183  450   81
#    24  392]

# 392は"sweet" -> 末尾の"latelyxxx"は消されている
print(tokenizer.word_index['sweet'])  # 392
print('latelyxxx' in tokenizer.word_index.keys())  # False

モデルの実装

モデルの実装は3段階です。図を載せます (青が次元数です) 。

  • Step 1. 単語埋め込み層 keras.layers.Embedding
    • one-hotエンコーディングの疎行列 (学習データの語彙数+1) から密行列 (ここでは32次元) に変換します
    • (内部でどういったアルゴリズムで計算されているのかは、ちょっと調べたのですがわかりませんでした... 今後の宿題です)
  • Step 2. LSTM層 keras.layers.LSTM
    • LSTM層では16次元の記憶セルを設定
    • return_sequences=Falseとすることで、最後のセルの出力のみをLSTM層の出力としている
  • Step 3. 全結合層 keras.layers.Dense
    • 最後のLSTMセルの出力とバイアス (16+1=17次元) を全結合
    • 2値分類タスクのため、出力数を1、活性化関数をシグモイドとしている
from keras.models import Sequential
from keras.layers import LSTM, Dense, Embedding

vocabulary_size = len(tokenizer.word_index) + 1  # 学習データの語彙数+1

model = Sequential()

model.add(Embedding(input_dim=vocabulary_size, output_dim=32))
model.add(LSTM(16, return_sequences=False))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()
# _________________________________________________________________
# Layer (type)                 Output Shape              Param #   
# =================================================================
# embedding_15 (Embedding)     (None, None, 32)          257920    
# _________________________________________________________________
# lstm_18 (LSTM)               (None, 16)                3136      
# _________________________________________________________________
# dense_17 (Dense)             (None, 1)                 17        
# =================================================================
# Total params: 261,073
# Trainable params: 261,073
# Non-trainable params: 0
# _________________________________________________________________

学習とテスト

得られたモデルを学習させてみたところ、検証データで98.92%の精度となりました。ナイーブベイズでは97.1%の精度でしたので、それより1.8%も改善されています。

学習データでは100%の精度を実現しており、若干過学習している傾向が見られます。

# テストデータの設定
y_train = Y_train['category'].values
y_test = Y_test['category'].values

# 学習
history = model.fit(
    x_train, y_train, batch_size=32, epochs=10,
    validation_data=(x_test, y_test)
)
# Train on 4457 samples, validate on 1115 samples
# Epoch 1/10
# 4457/4457 [==============================] - 18s 4ms/step - loss: 0.3035 - acc: 0.9051 - val_loss: 0.1221 - val_acc: 0.9722
# Epoch 2/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0747 - acc: 0.9856 - val_loss: 0.0685 - val_acc: 0.9848
# Epoch 3/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0383 - acc: 0.9955 - val_loss: 0.0551 - val_acc: 0.9857
# Epoch 4/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0210 - acc: 0.9978 - val_loss: 0.0516 - val_acc: 0.9874
# Epoch 5/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0121 - acc: 0.9989 - val_loss: 0.0462 - val_acc: 0.9839
# Epoch 6/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0083 - acc: 0.9993 - val_loss: 0.0515 - val_acc: 0.9883
# Epoch 7/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0056 - acc: 0.9996 - val_loss: 0.0472 - val_acc: 0.9892
# Epoch 8/10
# 4457/4457 [==============================] - 16s 3ms/step - loss: 0.0037 - acc: 0.9996 - val_loss: 0.0480 - val_acc: 0.9883
# Epoch 9/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0026 - acc: 0.9998 - val_loss: 0.0504 - val_acc: 0.9892
# Epoch 10/10
# 4457/4457 [==============================] - 16s 4ms/step - loss: 0.0019 - acc: 1.0000 - val_loss: 0.0514 - val_acc: 0.9892

最後に検証データを使った判定結果を示します。混合行列からも、非スパムに分類されやすい傾向が見られます。

from sklearn.metrics import confusion_matrix

y_pred = model.predict_classes(x_test)

print(confusion_matrix(y_test, y_pred))
# [[953   2]
#  [ 10 150]]

# 非スパムと誤判定したメッセージ
print(X_test[y_test > y_pred.reshape(-1)]['text'])
# 684     Hi I'm sue. I am 20 years old and work as a lapdancer. I love sex. Text me live - I'm i my bedro...
# 731     Email AlertFrom: Jeri StewartSize: 2KBSubject: Low-cost prescripiton drvgsTo listen to email cal...
# 660                                             88800 and 89034 are premium phone services call 08718711108
# 751     Do you realize that in about 40 years, we'll have thousands of old ladies running around with ta...
# 4213                               Missed call alert. These numbers called but left no message. 07008009200
# 3864    Oh my god! I've found your number again! I'm so glad, text me back xafter this msgs cst std ntwk...
# 5449                                 Latest News! Police station toilet stolen, cops have nothing to go on!
# 415                                                       100 dating service cal;l 09064012103 box334sk38ch
# 1430                                For sale - arsenal dartboard. Good condition but no doubles or trebles!

# スパムと誤判定したメッセージ
print(X_test[y_test < y_pred.reshape(-1)]['text'])
# 2368    V nice! Off 2 sheffield tom 2 air my opinions on categories 2 b used 2 measure ethnicity in next...
# 2340    Cheers for the message Zogtorius. Ive been staring at my phone for an age deciding whether to t...
# Name: text, dtype: object

まとめ

KerasでLSTMを使ってスパムメッセージの分類を行いました。