torchtextを使った自然言語処理タスクの前処理

PyTorch: DatasetとDataLoader (画像処理タスク編) - け日記 にてDatasetとDataLoaderの使い方を紹介しました。

今回は自然言語処理のにフォーカスし、torchtextを使った自然言語処理 (NLP) タスクの前処理について整理します。

NLPタスクの前処理

大別すると5段階に分けられます。ディープラーニングではミニバッチ単位でこの前処理を行います。

  1. クリーニング (HTMLタグの除去など)
  2. 単語への分割 (トークナイズ)
  3. 単語の正規化 (ステミングなど)
  4. ストップワードの除去
  5. 単語のベクトル化

こちらの記事が非常によく整理されてます。

torchtext

PyTorchでNLPタスクに取り組むにあたって、当然こういった前処理の実装も必要となります。この前処理を共通化・構造化するお便利ユーティリティパッケージがtorchtextです。

大まかな流れは、画像処理タスクの場合と同じですが、torchtextで提供されたクラスをそれぞれ使います。

  1. データソース (CSVファイルなど) から1個ずつデータをロードして前処理 (torchtext.data.Datasetの継承クラス)
  2. Datasetからデータを取り出してミニバッチにまとめる (torchtext.data.Batchを内包するクラス)

torchtext特有の概念として、Datasetは1つ以上のFieldを持ちます。
Field単位で、入力文 or (何らかの) 正解ラベルなのかを指定したり、前処理を定義できたりします。

実装

スパムメッセージのデータセット (UCI Machine Learning Repository: SMS Spam Collection Data Set) を使いながら、実際の実装を見ていきます。

最初に上記リンクからデータセットをダウンロード・解凍します。
展開されたSMSSpamCollectionは、ヘッダ無し・タブ区切りのテキストとなっており、1列目がham or spam (正解ラベル) 、2列目がメッセージ本文となってます。

$ wget https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
$ unzip smsspamcollection.zip
$ head SMSSpamCollection
$ head -n3 SMSSpamCollection
ham Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham Ok lar... Joking wif u oni...
spam    Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's

今回はtorchtext 0.5.0を使っています。

import string
import re
import torch
import torchtext

print(torch.__version__)  # 1.3.1
print(torchtext.__version__)  # 0.5.0

nltkのインストールやステミングについてはこちらもご参考ください。

Field

最初にFieldを定義します。

Fieldの役割は、文章やラベルに対してクリーニング・トークナイズ・正規化・ストップワードの除去を行い、数値表現 (= Tensor) に変換することです。具体的な前処理の実装は、このFieldに定義することになります。

# 入力文フィールド
text_field = torchtext.data.Field(
    sequential=True,  # Falseなら、tokenizeしない (デフォルト: True)
    use_vocab=True,  # Trueなら、ボキャブラリを使って数値化する
    init_token="<INIT>",  # 文頭に付与するトークン (デフォルト: None)
    eos_token="<EOS>",  # 文末に付与するトークン (デフォルト: None)
    lower=True,  # 小文字に統一 (デフォルト: False)
    fix_length=256,  # 1文あたりのワード数 (不足する場合はパディング)
    pad_first=False,  # パディングを最初にいれるかどうか (デフォルト: False)
    truncate_first=False,  # fix_lengthを超えた場合に先頭から削るかどうか (デフォルト: False -> 後ろから削られる)
    tokenize=tokenize,  # トークナイズ処理 (デフォルト: str.split)
    stop_words=set(stopwords.words("english")),  # 除去するストップワード
    preprocessing=torchtext.data.Pipeline(preprocessing),  # 後述 (デフォルト: None)
    postprocessing=None,  # ミニバッチ単位で行う処理 (デフォルト: None)
    dtype=torch.long,  # データの型 (デフォルト: torch.long)
    batch_first=True,  # ミニバッチの次元を最初に追加するかどうか (デフォルト: False)
    include_lengths=False,  # パディングした文とあわせて長さを返すかどうか (デフォルト: False)
    is_target=False  # ラベルフィールドかどうか (デフォルト: False) 
)

# ラベルフィールド
label_field = torchtext.data.Field(
    sequential=False,
    use_vocab=False,
    preprocessing=lambda l: 1 if l == "spam" else 0,
    is_target=True
)

def tokenize(s: str) -> list:
    # 記号はスペースで置換して除去
    for p in string.punctuation:
        s = s.replace(p, " ")
    
    # 連続する空白は1つにする
    s = re.sub(r" +", r" ", s).strip()
    
    # スペースで分割
    return s.split()

stemmer = PorterStemmer()
def preprocessing(s: str) -> str:
    # ステミング (語幹抽出)
    return stemmer.stem(s)

引数の詳細はコメントを見ていただくとして、Pipelineとpreprocessingについて補足します。

Pipeline

preprocessingではPipelineオブジェクトを渡す必要があります。

Pipelineクラスはトークン (またはトークンリスト) からトークン (またはトークンリスト) の変換を行います。(加えて、他のPielineオブジェクトを数珠つなぎにして、一連の処理として定義することもできます。)

トークンからトークンへの変換関数 (上の例ではpreprocessing) を渡して初期化してます。

    ...
    preprocessing=torchtext.data.Pipeline(preprocessing),
    ...

preprocessの順序

Fieldクラスの前処理はpreprocessメソッドで行われるのですが、以下の順番で処理されます。

  1. UTF-8へのエンコード (Python 2のみ)
  2. トークナイズ
  3. 小文字化 (lower=Trueのみ)
  4. ストップワードの除去 (stop_wordsに値が設定されている場合のみ)
  5. preprocessingの実行 (preprocessingが設定されている場合のみ)

注目すべきはトークナイズの後にpreprocessingが実行される点です。トークナイズの前、つまり分割された単語ではなく1つの文を入力として何らかの処理を行いたい場合は、tokenizeに含める必要があります。

例えば "100", "23,000", "3.14" といった数値表現を全て "0" に置き換える場合、preprocessの段階では "," や "." で分割されて別々の数字トークンとなるため、1つの数字トークンとして処理できません。
これを回避するためには、tokenizeの時点で数字の間の "," や "." を除去することで1つの数字トークンにするなどの工夫が必要となります。

Dataset

次に、ファイル (SMSSpamCollection) からデータを取り出して、上で定義したFieldを紐付けるDatasetを作成します。

CSV・TSV・JSONの構造化されたテキストファイルをインプットとするTabularDatasetが提供されてます。このクラスを使い、SMSSpamCollectionを入力としたDatasetを定義します。
fieldsにはタプルのリストを渡します。1リストアイテム=1行、1タプル=1カラム (フィールド) にそれぞれ対応します。タプルの最初の文字列 ("Label" と "Text") は、後ほどDatasetオブジェクトからフィールドにアクセスするための属性名となります。

加えて、splitメソッドでtrainとtestに分割しておきます。

# TSVファイルからDatasetにロード
dataset = torchtext.data.TabularDataset(
    path="./SMSSpamCollection",
    format="tsv", # tsv, csv, json
    fields=[("Label", label_field), ("Text", text_field)],
    skip_header=False,
)

# trainとtestで分割
train_dataset, test_dataset = dataset.split(
    split_ratio=0.8,
)

単語ベクトル

torchtextでは単語ベクトルをVectorsの継承クラスで実装されており、いくつか学習済みのベクトルを使うインタフェースが提供されています。今回はFastTextを使います。

  • 最初のロードではダウンロードが走りますが、以降はキャッシュされます
  • ID -> 文字列の対応を表すitos (list) と 文字列 -> IDの対応を表すstoi (dict) のプロパティを持っています
  • ベクトル値はvectors (IDをキーとするlist) からで取得できます
fasttext = torchtext.vocab.FastText(language="en")
# 他の学習済みモデルを使う場合
# vectors = torchtext.vocab.Vectors("./wiki-news-300d-1M.vec")

print(type(fasttext.itos), type(fasttext.stoi))
# <class 'list'> <class 'dict'>
print(len(fasttext.itos))  # 2519370
print(fasttext.vectors[fasttext.stoi["get"]])
# tensor([ 1.6691e-01,  2.0164e-02, -9.7542e-02, -3.3805e-02, -2.6137e-01,
#          ...
#          8.2180e-02, -3.0632e-01,  2.1503e-01,  1.9489e-01, -1.5824e-01])

ボキャブラリの作成

次に学習用のDatasetからボキャブラリを作成します。

作成済みのField (ここではtext_field) のbuild_vocabメソッドを呼び出すと、vocab属性が追加されます。
vocabはVocabオブジェクトで、データセットに現れた単語とそのベクトルに加えて、パディングやUNKも含められます。

  • "is"は存在しないトークン (ストップワードで除去) であるため、stoi["is"]は0、すなわち""となります
print(hasattr(text_field, "vocab"))  # False

text_field.build_vocab(train_dataset, vectors=fasttext, min_freq=3)

print(hasattr(text_field, "vocab"))  # True

print(text_field.vocab.vectors.shape)  # torch.Size([2444, 300])
print(text_field.vocab.vectors)
# tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
#         ...
#         [-0.1913, -0.1712,  0.0094,  ...,  0.1175,  0.2918, -0.0492]])

print(text_field.vocab.itos[:10])
# ['<unk>', '<pad>', '<INIT>', '<EOS>', 'u', 'call', '2', 'get', 'go', 'ur']

print(text_field.vocab.stoi["get"])  # 7
print(text_field.vocab.itos[7])  # get
print(text_field.vocab.vectors[text_field.vocab.stoi["get"]])
# tensor([ 1.6691e-01,  2.0164e-02, -9.7542e-02, -3.3805e-02, -2.6137e-01,
#          ...
#          8.2180e-02, -3.0632e-01,  2.1503e-01,  1.9489e-01, -1.5824e-01])

# 存在しないトークン -> "<unk>"
print(text_field.vocab.stoi["is"])  # 0

Iterator

最後にDatasetからミニバッチ単位でデータをロードする処理を実装します。

本家PyTorchのDataLoaderにあたるユーティリティクラスとして、torchtextではIterator (またはその派生クラス) が提供されています。以下のように使います。

  • ミニバッチごとにFieldで定義した属性名 (Text, Label) でアクセスしてます
  • ミニバッチごとに、Textの次元数は (batch_size, fix_length) 、Labelの次元数は (batch_size) となります

サイズを小さくするために、Textフィールドに単語ベクトル値ではなく単語IDが入っている点に注意してください。ベクトル表現はネットワーク側で取得する必要があります。

batch_size = 8

train_iterator = torchtext.data.Iterator(
    train_dataset, batch_size=batch_size, 
    train=True  # train=Trueならシャッフルソートは有効
)
test_iterator = torchtext.data.Iterator(
    test_dataset, batch_size=batch_size, 
    train=False, sort=False
)

batch = next(iter(train_iterator))
print(batch.Text.size())  # torch.Size([8, 256])
print(batch.Text[0])
# tensor([  2, 604,   8,  37,   8,   0,   0,  34,  18,   8,   6, 708, 161,   3,
#           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
#           ...
#           1,   1,   1,   1])
print(batch.Label)  # tensor([0, 0, 0, 0, 0, 0, 0, 1])

単語埋め込み層

torchでベクトル表現に変換するモジュールは以下の実装となります。
Embeddingのクラスメソッドfrom_pretrainedにボキャブラリの単語ベクトルを渡すことで、Embeddingモジュールオブジェクトが生成されます。これを最初の層にすればOKです。

class WordEmbedding(torch.nn.Module):
    def __init__(self, vectors):
        super().__init__()

        self.embedding = torch.nn.Embedding.from_pretrained(
            embeddings=vectors, freeze=True
        )

    def forward(self, x):
        x = self.embedding(x)
        return x
    
word_embedding = WordEmbedding(text_field.vocab.vectors)
x = word_embedding(batch.Text)
print(x.size())  # torch.Size([8, 256, 300])
print(x[0])
# tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
#         [-0.2761, -0.0085, -0.1173,  ..., -0.0522,  0.1952,  0.0860],
#         [ 0.0769,  0.0298,  0.0174,  ...,  0.0628,  0.0840, -0.1559],
#         ...,
#         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
#         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

まとめ

今回はPyTorchでNLPタスクの前処理に便利なtorchtextを紹介しました。

参考