torchtextを使った自然言語処理タスクの前処理
PyTorch: DatasetとDataLoader (画像処理タスク編) - け日記 にてDatasetとDataLoaderの使い方を紹介しました。
今回は自然言語処理のにフォーカスし、torchtextを使った自然言語処理 (NLP) タスクの前処理について整理します。
NLPタスクの前処理
大別すると5段階に分けられます。ディープラーニングではミニバッチ単位でこの前処理を行います。
- クリーニング (HTMLタグの除去など)
- 単語への分割 (トークナイズ)
- 単語の正規化 (ステミングなど)
- ストップワードの除去
- 単語のベクトル化
こちらの記事が非常によく整理されてます。
torchtext
PyTorchでNLPタスクに取り組むにあたって、当然こういった前処理の実装も必要となります。この前処理を共通化・構造化するお便利ユーティリティパッケージがtorchtextです。
大まかな流れは、画像処理タスクの場合と同じですが、torchtextで提供されたクラスをそれぞれ使います。
- データソース (CSVファイルなど) から1個ずつデータをロードして前処理 (torchtext.data.Datasetの継承クラス)
- Datasetからデータを取り出してミニバッチにまとめる (torchtext.data.Batchを内包するクラス)
torchtext特有の概念として、Datasetは1つ以上のFieldを持ちます。
Field単位で、入力文 or (何らかの) 正解ラベルなのかを指定したり、前処理を定義できたりします。
実装
スパムメッセージのデータセット (UCI Machine Learning Repository: SMS Spam Collection Data Set) を使いながら、実際の実装を見ていきます。
最初に上記リンクからデータセットをダウンロード・解凍します。
展開されたSMSSpamCollectionは、ヘッダ無し・タブ区切りのテキストとなっており、1列目がham or spam (正解ラベル) 、2列目がメッセージ本文となってます。
$ wget https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip $ unzip smsspamcollection.zip $ head SMSSpamCollection $ head -n3 SMSSpamCollection ham Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat... ham Ok lar... Joking wif u oni... spam Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
今回はtorchtext 0.5.0を使っています。
import string import re import torch import torchtext print(torch.__version__) # 1.3.1 print(torchtext.__version__) # 0.5.0
nltkのインストールやステミングについてはこちらもご参考ください。
Field
最初にFieldを定義します。
Fieldの役割は、文章やラベルに対してクリーニング・トークナイズ・正規化・ストップワードの除去を行い、数値表現 (= Tensor) に変換することです。具体的な前処理の実装は、このFieldに定義することになります。
# 入力文フィールド text_field = torchtext.data.Field( sequential=True, # Falseなら、tokenizeしない (デフォルト: True) use_vocab=True, # Trueなら、ボキャブラリを使って数値化する init_token="<INIT>", # 文頭に付与するトークン (デフォルト: None) eos_token="<EOS>", # 文末に付与するトークン (デフォルト: None) lower=True, # 小文字に統一 (デフォルト: False) fix_length=256, # 1文あたりのワード数 (不足する場合はパディング) pad_first=False, # パディングを最初にいれるかどうか (デフォルト: False) truncate_first=False, # fix_lengthを超えた場合に先頭から削るかどうか (デフォルト: False -> 後ろから削られる) tokenize=tokenize, # トークナイズ処理 (デフォルト: str.split) stop_words=set(stopwords.words("english")), # 除去するストップワード preprocessing=torchtext.data.Pipeline(preprocessing), # 後述 (デフォルト: None) postprocessing=None, # ミニバッチ単位で行う処理 (デフォルト: None) dtype=torch.long, # データの型 (デフォルト: torch.long) batch_first=True, # ミニバッチの次元を最初に追加するかどうか (デフォルト: False) include_lengths=False, # パディングした文とあわせて長さを返すかどうか (デフォルト: False) is_target=False # ラベルフィールドかどうか (デフォルト: False) ) # ラベルフィールド label_field = torchtext.data.Field( sequential=False, use_vocab=False, preprocessing=lambda l: 1 if l == "spam" else 0, is_target=True ) def tokenize(s: str) -> list: # 記号はスペースで置換して除去 for p in string.punctuation: s = s.replace(p, " ") # 連続する空白は1つにする s = re.sub(r" +", r" ", s).strip() # スペースで分割 return s.split() stemmer = PorterStemmer() def preprocessing(s: str) -> str: # ステミング (語幹抽出) return stemmer.stem(s)
引数の詳細はコメントを見ていただくとして、Pipelineとpreprocessingについて補足します。
Pipeline
preprocessingではPipelineオブジェクトを渡す必要があります。
Pipelineクラスはトークン (またはトークンリスト) からトークン (またはトークンリスト) の変換を行います。(加えて、他のPielineオブジェクトを数珠つなぎにして、一連の処理として定義することもできます。)
トークンからトークンへの変換関数 (上の例ではpreprocessing) を渡して初期化してます。
... preprocessing=torchtext.data.Pipeline(preprocessing), ...
preprocessの順序
Fieldクラスの前処理はpreprocessメソッドで行われるのですが、以下の順番で処理されます。
- UTF-8へのエンコード (Python 2のみ)
- トークナイズ
- 小文字化 (
lower=True
のみ) - ストップワードの除去 (stop_wordsに値が設定されている場合のみ)
- preprocessingの実行 (preprocessingが設定されている場合のみ)
注目すべきはトークナイズの後にpreprocessingが実行される点です。トークナイズの前、つまり分割された単語ではなく1つの文を入力として何らかの処理を行いたい場合は、tokenizeに含める必要があります。
例えば "100", "23,000", "3.14" といった数値表現を全て "0" に置き換える場合、preprocessの段階では "," や "." で分割されて別々の数字トークンとなるため、1つの数字トークンとして処理できません。
これを回避するためには、tokenizeの時点で数字の間の "," や "." を除去することで1つの数字トークンにするなどの工夫が必要となります。
Dataset
次に、ファイル (SMSSpamCollection) からデータを取り出して、上で定義したFieldを紐付けるDatasetを作成します。
CSV・TSV・JSONの構造化されたテキストファイルをインプットとするTabularDatasetが提供されてます。このクラスを使い、SMSSpamCollectionを入力としたDatasetを定義します。
fieldsにはタプルのリストを渡します。1リストアイテム=1行、1タプル=1カラム (フィールド) にそれぞれ対応します。タプルの最初の文字列 ("Label" と "Text") は、後ほどDatasetオブジェクトからフィールドにアクセスするための属性名となります。
加えて、splitメソッドでtrainとtestに分割しておきます。
# TSVファイルからDatasetにロード dataset = torchtext.data.TabularDataset( path="./SMSSpamCollection", format="tsv", # tsv, csv, json fields=[("Label", label_field), ("Text", text_field)], skip_header=False, ) # trainとtestで分割 train_dataset, test_dataset = dataset.split( split_ratio=0.8, )
単語ベクトル
torchtextでは単語ベクトルをVectorsの継承クラスで実装されており、いくつか学習済みのベクトルを使うインタフェースが提供されています。今回はFastTextを使います。
- 最初のロードではダウンロードが走りますが、以降はキャッシュされます
- ID -> 文字列の対応を表すitos (list) と 文字列 -> IDの対応を表すstoi (dict) のプロパティを持っています
- ベクトル値はvectors (IDをキーとするlist) からで取得できます
fasttext = torchtext.vocab.FastText(language="en") # 他の学習済みモデルを使う場合 # vectors = torchtext.vocab.Vectors("./wiki-news-300d-1M.vec") print(type(fasttext.itos), type(fasttext.stoi)) # <class 'list'> <class 'dict'> print(len(fasttext.itos)) # 2519370 print(fasttext.vectors[fasttext.stoi["get"]]) # tensor([ 1.6691e-01, 2.0164e-02, -9.7542e-02, -3.3805e-02, -2.6137e-01, # ... # 8.2180e-02, -3.0632e-01, 2.1503e-01, 1.9489e-01, -1.5824e-01])
ボキャブラリの作成
次に学習用のDatasetからボキャブラリを作成します。
作成済みのField (ここではtext_field) のbuild_vocabメソッドを呼び出すと、vocab属性が追加されます。
vocabはVocabオブジェクトで、データセットに現れた単語とそのベクトルに加えて、パディングやUNKも含められます。
- "is"は存在しないトークン (ストップワードで除去) であるため、
stoi["is"]
は0、すなわち""となります
print(hasattr(text_field, "vocab")) # False text_field.build_vocab(train_dataset, vectors=fasttext, min_freq=3) print(hasattr(text_field, "vocab")) # True print(text_field.vocab.vectors.shape) # torch.Size([2444, 300]) print(text_field.vocab.vectors) # tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], # ... # [-0.1913, -0.1712, 0.0094, ..., 0.1175, 0.2918, -0.0492]]) print(text_field.vocab.itos[:10]) # ['<unk>', '<pad>', '<INIT>', '<EOS>', 'u', 'call', '2', 'get', 'go', 'ur'] print(text_field.vocab.stoi["get"]) # 7 print(text_field.vocab.itos[7]) # get print(text_field.vocab.vectors[text_field.vocab.stoi["get"]]) # tensor([ 1.6691e-01, 2.0164e-02, -9.7542e-02, -3.3805e-02, -2.6137e-01, # ... # 8.2180e-02, -3.0632e-01, 2.1503e-01, 1.9489e-01, -1.5824e-01]) # 存在しないトークン -> "<unk>" print(text_field.vocab.stoi["is"]) # 0
Iterator
最後にDatasetからミニバッチ単位でデータをロードする処理を実装します。
本家PyTorchのDataLoaderにあたるユーティリティクラスとして、torchtextではIterator (またはその派生クラス) が提供されています。以下のように使います。
- ミニバッチごとにFieldで定義した属性名 (Text, Label) でアクセスしてます
- ミニバッチごとに、Textの次元数は (batch_size, fix_length) 、Labelの次元数は (batch_size) となります
サイズを小さくするために、Textフィールドに単語ベクトル値ではなく単語IDが入っている点に注意してください。ベクトル表現はネットワーク側で取得する必要があります。
batch_size = 8 train_iterator = torchtext.data.Iterator( train_dataset, batch_size=batch_size, train=True # train=Trueならシャッフルソートは有効 ) test_iterator = torchtext.data.Iterator( test_dataset, batch_size=batch_size, train=False, sort=False ) batch = next(iter(train_iterator)) print(batch.Text.size()) # torch.Size([8, 256]) print(batch.Text[0]) # tensor([ 2, 604, 8, 37, 8, 0, 0, 34, 18, 8, 6, 708, 161, 3, # 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, # ... # 1, 1, 1, 1]) print(batch.Label) # tensor([0, 0, 0, 0, 0, 0, 0, 1])
単語埋め込み層
torchでベクトル表現に変換するモジュールは以下の実装となります。
Embeddingのクラスメソッドfrom_pretrainedにボキャブラリの単語ベクトルを渡すことで、Embeddingモジュールオブジェクトが生成されます。これを最初の層にすればOKです。
class WordEmbedding(torch.nn.Module): def __init__(self, vectors): super().__init__() self.embedding = torch.nn.Embedding.from_pretrained( embeddings=vectors, freeze=True ) def forward(self, x): x = self.embedding(x) return x word_embedding = WordEmbedding(text_field.vocab.vectors) x = word_embedding(batch.Text) print(x.size()) # torch.Size([8, 256, 300]) print(x[0]) # tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], # [-0.2761, -0.0085, -0.1173, ..., -0.0522, 0.1952, 0.0860], # [ 0.0769, 0.0298, 0.0174, ..., 0.0628, 0.0840, -0.1559], # ..., # [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], # [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], # [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]])
まとめ
今回はPyTorchでNLPタスクの前処理に便利なtorchtextを紹介しました。