引き続きPyTorchのお勉強してます。
今回はPyTorchで計算資源を有効活用した推論を行うための仕組みの1つTorchScriptについてまとめます。
TorchScriptとは
TorchScriptはPyTorchの中間表現 (intermediate representation) コードとその変換機構です。
主な利点は3つで、いずれも推論のシチュエーションで役立つものです。
- TorchScriptコードは独自のインタプリタで実行・解釈され、Pythonインタプリタのグローバルインタプリタロック (GIL) とも無縁なので、マルチスレッドで並列計算できる
- コードとパラメータをまるごと保存でき、Python以外の実行環境でもロードできる
- 中間コードへ落とし込むコンパイラにて最適化しやすくなる
TorchScriptの変換方法は2つあります。
- torch.jit.traceを使う
- モデルのforwardメソッドを実行時の処理を記録 (トレース) することで変換します
- torch.jit.traceは入力に依存してネットワークが変わる場合に不向きです (forward内でifやforがある場合など)
- torch.jit.scriptを使う
- モデルの定義時にデコレータを付与することで変換する
PyTorch 1.2.0以降ではtorch.jit.scriptでモジュールのプロパティやメソッドを再帰的に変換できるようになりました。その結果、デコレータ無しでもModuleオブジェクトにtorch.jit.scriptを1回通すだけでTorchScriptへ変換できます (参考) 。
本投稿ではPyTorch 1.3.1を使いますので、より楽ちんな2番目の方法を紹介したいと思います。
TorchScriptへの変換
例として、ResNetなどで用いられる残差ブロック (residual block) をピュアなPyTorchで実装し、それをTorchScriptへ変換します。
import torch import torch.nn as nn print(torch.__version__) # 1.3.1
residual blockの実装
最初に残差ブロックをPyTorchで実装します。
- 入力を2つに分け、1つは (畳み込み層 + batch norm + ReLU) * n に通し、もう1つはそのままにして、出力ではその2つを足してReLUにかけたブロックです
- ResNetについては以前の投稿も参考にしてみてください (ResNetでCIFAR-10を分類する - け日記)
通常のPyTorchのモジュールと同様に実装します。nn.Moduleを継承してクラスを作り、プロパティにモジュールとパラメータを持たせ、forwardを実装しているだけです。 モジュールの作り方については先週の投稿に詳しく記載しました。
- 畳み込み層はModuleListにまとめてます
- 入力チャネル数
in_channels
と出力チャネル数out_channels
が異なる場合、残差residual
を1x1の畳み込みにかけることで出力チャネル数に合わせてます - 最後のReLUは畳み込み層を通過した
x
と残差residual
を足した後にかけてます
この時点では特別なことはしてません。(ハマりどころは後述します。)
class ResidualBlock(nn.Module): def __init__(self, in_channels, out_channels, conv_layers, kernel_size=3): super().__init__() self.conv_layers = conv_layers padding = (kernel_size - 1) // 2 # 最初の畳み込みで出力チャネル数にあわせる convs = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)] convs.extend([ nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding) for _ in range(1, conv_layers) ]) # ModuleListで畳み込み層を1つにまとめる self.convs = nn.ModuleList(convs) self.conv_1x1 = None if in_channels != out_channels: # 入出力のチャネル数が異なる場合、残差のチャネル数を1x1畳み込みで出力チャネル数に合わせる self.conv_1x1 = nn.Conv2d(in_channels, out_channels, 1) self.batch_norm = nn.BatchNorm2d(out_channels) self.relu = nn.ReLU(inplace=True) def forward(self, x): residual = x i = 0 for conv in self.convs: x = conv(x) x = self.batch_norm(x) # 畳み込み最後のReLUはスキップ if i < (self.conv_layers - 1): x = self.relu(x) i += 1 if self.conv_1x1 is not None: # 残差のチャネル数を出力チャネル数に揃える residual = self.conv_1x1(residual) x = x + residual output = self.relu(x) return output
torch.jit.scriptによる変換
上のResidualBlockモジュールをTorchScriptに変換します。
最初にResidualBlockモジュールをインスタンス化します。
block = ResidualBlock(3, 8, 3) print(type(block)) # <class '__main__.ResidualBlock'> print(block) # ResidualBlock( # (convs): ModuleList( # (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) # ) # (conv_1x1): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1)) # (batch_norm): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) # (relu): ReLU(inplace=True) # )
そしてtorch.jit.scriptにオブジェクトを渡して変換します。torch.jit.ScriptModule型のオブジェクトが得られていれば変換成功です。
scripted_block = torch.jit.script(block) print(type(scripted_block)) # <class 'torch.jit.ScriptModule'>
ScriptModuleは、nn.Moduleを継承してますのでPythonでは今までのモジュール同様扱えますが、同時にC++のモジュールクラス (torch::jit::script::Module) のラッパにもなってます。
またScriptModuleでは、計算グラフを持つgraphプロパティと、それをもうちょっとPythonライクに見やすくしたcodeプロパティを持ってます。
以下でcodeを表示してますが、単一のforward関数が定義され、その中に全てのコードが展開されているようです。最後のif True:
のような強引な置き換えも見られます。
print(scripted_block.code) # import __torch__.___torch_mangle_340 # import __torch__.torch.jit.___torch_mangle_356 # import __torch__.torch.nn.modules.conv.___torch_mangle_341 # import __torch__.torch.nn.modules.conv.___torch_mangle_346 # import __torch__.torch.nn.modules.conv.___torch_mangle_351 # import __torch__.torch.nn.modules.batchnorm.___torch_mangle_357 # import __torch__.torch.nn.modules.activation.___torch_mangle_359 # import __torch__.torch.nn.modules.conv.___torch_mangle_361 # def forward(self, # x: Tensor) -> Tensor: # _0 = self.convs # ... (中略) ... # if _390: # residual = _391 # else: # residual = torch.conv2d(x, _388, _387.bias, [1, 1], [0, 0], [1, 1], 1) # x12 = torch.add(x10, residual, alpha=1) # if True: # output = torch.relu_(x12) # else: # output = torch.relu(x12) # return output
同じモデルになっているか簡単に確認します。同じ入力を与えると、PyTorch版もTorchScript版も同じ値を出力できていることがわかります。
input = torch.ones(1, 3, 2, 2) output1 = block(input) print(output1.size()) # torch.Size([1, 8, 2, 2]) output2 = scripted_block(input) print(output2.size()) # torch.Size([1, 8, 2, 2]) print(torch.sum(output1 != output2)) # tensor(0)
はまりどころ
上では1発で変換できたように見えますが、記述方法によってはうまく変換できないケースがあります。自分が直面した問題を2つ挙げます。
ModuleListとModuleDictはインデックスでアクセスできない
できるかどうかちょっと心配だったのですが、ModuleListとModuleDictも変換できます。
ただしアクセス方法が制限されており、イテレーション形式のみサポートされてます。例えばx = self.convs[i](x)
のようにインデックスでアクセスすると、以下のようにエラーとなります。
- 関連イシュー
module' object is not subscriptable: at <ipython-input-43-f60e7388ed80>:30:16 def forward(self, x): residual = x i = 0 for conv in self.convs: # x = conv(x) x = self.convs[i](x) ~~~~~~~~~~~~ <--- HERE x = self.batch_norm(x) # 畳み込み最後のReLUはスキップ if i < (self.conv_layers - 1): x = self.relu(x) i += 1
Noneオブジェクトの暗黙的なbool変換が行われない
PythonではNoneオブジェクトを条件式に使うとFalse扱いとなります。
# if self.conv_1x1 is not None: if self.conv_1x1: # 残差のチャネル数を出力チャネル数に揃える residual = self.conv_1x1(residual)
ところが上のように記述してしまうと変換時にエラーとなってしまいます。条件文などではbool型となるように明示的に変換したほうが良さそうです。
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-45-b535658587ed> in <module> 68 print(output) 69 ---> 70 scripted_block = torch.jit.script(block) 71 print(type(scripted_block)) # <class 'torch.jit.ScriptModule'> 72 ... (省略) ... /opt/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in _create_methods_from_stubs(self, stubs) 1421 rcbs = [m.resolution_callback for m in stubs] 1422 defaults = [get_default_args(m.original_method) for m in stubs] -> 1423 self._c._create_methods(self, defs, rcbs, defaults) 1424 1425 # For each user-defined class that subclasses ScriptModule this meta-class, RuntimeError: Could not cast value of type __torch__.torch.nn.modules.conv.___torch_mangle_430.Conv2d to bool:
プロパティのモジュールも再帰的に探索・変換される
プロパティにモジュール (自作クラスを含む) をコンポジットしている場合でも、再帰的に変換してくれます。
先程定義したResidualBlockをたくさん持つNetオブジェクトも、torch.jit.scriptで変換できます。デコレータなどは不要です。
class Net(nn.Module): def __init__(self): super().__init__() self.conv1 = ResidualBlock(3, 8, 1) self.conv2 = ResidualBlock(8, 16, 3) self.conv3 = ResidualBlock(16, 32, 3) self.conv4 = ResidualBlock(32, 64, 3) self.conv5 = ResidualBlock(64, 128, 3) self.linear = nn.Linear(7*7*128, 10) self.pool = nn.MaxPool2d(2) def forward(self, x): x = self.conv1(x) x = self.pool(x) x = self.conv2(x) x = self.pool(x) x = self.conv3(x) x = self.pool(x) x = self.conv4(x) x = self.pool(x) x = self.conv5(x) x = self.pool(x) x = x.flatten(1) x = self.linear(x) output = torch.softmax(x, dim=1) return output scripted_net = torch.jit.script(net) print(type(scripted_net)) # <class 'torch.jit.ScriptModule'>
TorchScriptモデルのセーブとロード
TorchScriptを保存するには、saveメソッドを呼び出すだけです。
torch.jit.saveではメソッドやサブモジュール、パラメータなどを全てシリアライズして保存されます。
scripted_block.save("scripted_block.pt")
上で保存したモデルはtorch.jit.loadでロードできます。
- シリアライズフォーマットはC++ APIとも互換性があり、torch::jit::loadでロードできるようになります
loaded_block = torch.jit.load("scripted_block.pt") print(type(loaded_block)) # <class 'torch.jit.ScriptModule'>
forwardすると当然同じ結果が得られます。
output3 = loaded_block(input) print(torch.sum(output2 != output3)) # tensor(0) print(torch.sum(output1 != output3)) # tensor(0)
まとめ
今回はTorchScriptへの変換について調査・整理しました。パフォーマンス面での評価は今後行いたいと思います。
まだまだ開発途上にあり、インタフェースレベルでの大きな変更も今後あるかと思います。また、周辺ライブラリの対応も後手に回っており、変換が難しいケースもあります。可能な限り最新バージョンに追随していく必要があります。
- 例えば、torchvisionが提供するResNetモデルが現行最新版 (0.4.2) でもTorchScriptに変換できなかったりします (参考イシュー)
参考
TorchScriptの情報源については、今のところ公式のドキュメントとチュートリアルが頼りになります。