PyTorchのモデルをTorchScriptへ変換する

引き続きPyTorchのお勉強してます。

今回はPyTorchで計算資源を有効活用した推論を行うための仕組みの1つTorchScriptについてまとめます。

TorchScriptとは

TorchScriptはPyTorchの中間表現 (intermediate representation) コードとその変換機構です。

主な利点は3つで、いずれも推論のシチュエーションで役立つものです。

  • TorchScriptコードは独自のインタプリタで実行・解釈され、Pythonインタプリタのグローバルインタプリタロック (GIL) とも無縁なので、マルチスレッドで並列計算できる
  • コードとパラメータをまるごと保存でき、Python以外の実行環境でもロードできる
  • 中間コードへ落とし込むコンパイラにて最適化しやすくなる

TorchScriptの変換方法は2つあります。

  • torch.jit.traceを使う
    • モデルのforwardメソッドを実行時の処理を記録 (トレース) することで変換します
    • torch.jit.traceは入力に依存してネットワークが変わる場合に不向きです (forward内でifやforがある場合など)
  • torch.jit.scriptを使う
    • モデルの定義時にデコレータを付与することで変換する

PyTorch 1.2.0以降ではtorch.jit.scriptでモジュールのプロパティやメソッドを再帰的に変換できるようになりました。その結果、デコレータ無しでもModuleオブジェクトにtorch.jit.scriptを1回通すだけでTorchScriptへ変換できます (参考) 。
本投稿ではPyTorch 1.3.1を使いますので、より楽ちんな2番目の方法を紹介したいと思います。

TorchScriptへの変換

例として、ResNetなどで用いられる残差ブロック (residual block) をピュアなPyTorchで実装し、それをTorchScriptへ変換します。

import torch
import torch.nn as nn

print(torch.__version__)  # 1.3.1

residual blockの実装

最初に残差ブロックをPyTorchで実装します。

  • 入力を2つに分け、1つは (畳み込み層 + batch norm + ReLU) * n に通し、もう1つはそのままにして、出力ではその2つを足してReLUにかけたブロックです
  • ResNetについては以前の投稿も参考にしてみてください (ResNetでCIFAR-10を分類する - け日記)

通常のPyTorchのモジュールと同様に実装します。nn.Moduleを継承してクラスを作り、プロパティにモジュールとパラメータを持たせ、forwardを実装しているだけです。 モジュールの作り方については先週の投稿に詳しく記載しました。

  • 畳み込み層はModuleListにまとめてます
  • 入力チャネル数in_channelsと出力チャネル数out_channelsが異なる場合、残差residualを1x1の畳み込みにかけることで出力チャネル数に合わせてます
  • 最後のReLUは畳み込み層を通過したxと残差residualを足した後にかけてます

この時点では特別なことはしてません。(ハマりどころは後述します。)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, conv_layers, kernel_size=3):
        super().__init__()
        
        self.conv_layers = conv_layers
        
        padding = (kernel_size - 1) // 2
        # 最初の畳み込みで出力チャネル数にあわせる
        convs = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)]
        convs.extend([
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
            for _ in range(1, conv_layers)
        ])
        # ModuleListで畳み込み層を1つにまとめる
        self.convs = nn.ModuleList(convs)
        
        self.conv_1x1 = None
        if in_channels != out_channels:
            # 入出力のチャネル数が異なる場合、残差のチャネル数を1x1畳み込みで出力チャネル数に合わせる
            self.conv_1x1 = nn.Conv2d(in_channels, out_channels, 1)
        
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        residual = x
        
        i = 0
        for conv in self.convs:
            x = conv(x)
            x = self.batch_norm(x)
            # 畳み込み最後のReLUはスキップ
            if i < (self.conv_layers - 1):
                x = self.relu(x)
            i += 1
        
        if self.conv_1x1 is not None:
            # 残差のチャネル数を出力チャネル数に揃える
            residual = self.conv_1x1(residual)
                
        x = x + residual
        
        output = self.relu(x)
        
        return output

torch.jit.scriptによる変換

上のResidualBlockモジュールをTorchScriptに変換します。

最初にResidualBlockモジュールをインスタンス化します。

block = ResidualBlock(3, 8, 3)
print(type(block))  # <class '__main__.ResidualBlock'>
print(block)
# ResidualBlock(
#   (convs): ModuleList(
#     (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   )
#   (conv_1x1): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1))
#   (batch_norm): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (relu): ReLU(inplace=True)
# )

そしてtorch.jit.scriptにオブジェクトを渡して変換します。torch.jit.ScriptModule型のオブジェクトが得られていれば変換成功です。

scripted_block = torch.jit.script(block)
print(type(scripted_block))  # <class 'torch.jit.ScriptModule'>

ScriptModuleは、nn.Moduleを継承してますのでPythonでは今までのモジュール同様扱えますが、同時にC++のモジュールクラス (torch::jit::script::Module) のラッパにもなってます。

またScriptModuleでは、計算グラフを持つgraphプロパティと、それをもうちょっとPythonライクに見やすくしたcodeプロパティを持ってます。

以下でcodeを表示してますが、単一のforward関数が定義され、その中に全てのコードが展開されているようです。最後のif True:のような強引な置き換えも見られます。

print(scripted_block.code)
# import __torch__.___torch_mangle_340
# import __torch__.torch.jit.___torch_mangle_356
# import __torch__.torch.nn.modules.conv.___torch_mangle_341
# import __torch__.torch.nn.modules.conv.___torch_mangle_346
# import __torch__.torch.nn.modules.conv.___torch_mangle_351
# import __torch__.torch.nn.modules.batchnorm.___torch_mangle_357
# import __torch__.torch.nn.modules.activation.___torch_mangle_359
# import __torch__.torch.nn.modules.conv.___torch_mangle_361
# def forward(self,
#     x: Tensor) -> Tensor:
#   _0 = self.convs
# ... (中略) ...
#   if _390:
#     residual = _391
#   else:
#     residual = torch.conv2d(x, _388, _387.bias, [1, 1], [0, 0], [1, 1], 1)
#   x12 = torch.add(x10, residual, alpha=1)
#   if True:
#     output = torch.relu_(x12)
#   else:
#     output = torch.relu(x12)
#   return output

同じモデルになっているか簡単に確認します。同じ入力を与えると、PyTorch版もTorchScript版も同じ値を出力できていることがわかります。

input = torch.ones(1, 3, 2, 2)

output1 = block(input)
print(output1.size())  # torch.Size([1, 8, 2, 2])

output2 = scripted_block(input)
print(output2.size())  # torch.Size([1, 8, 2, 2])

print(torch.sum(output1 != output2))  # tensor(0)

はまりどころ

上では1発で変換できたように見えますが、記述方法によってはうまく変換できないケースがあります。自分が直面した問題を2つ挙げます。

ModuleListとModuleDictはインデックスでアクセスできない

できるかどうかちょっと心配だったのですが、ModuleListとModuleDictも変換できます。

ただしアクセス方法が制限されており、イテレーション形式のみサポートされてます。例えばx = self.convs[i](x)のようにインデックスでアクセスすると、以下のようにエラーとなります。

module' object is not subscriptable:
at <ipython-input-43-f60e7388ed80>:30:16
    def forward(self, x):
        residual = x
        
        i = 0
        for conv in self.convs:
            # x = conv(x)
            x = self.convs[i](x)
                ~~~~~~~~~~~~ <--- HERE
            x = self.batch_norm(x)
            # 畳み込み最後のReLUはスキップ
            if i < (self.conv_layers - 1):
                x = self.relu(x)
            i += 1

Noneオブジェクトの暗黙的なbool変換が行われない

PythonではNoneオブジェクトを条件式に使うとFalse扱いとなります。

        # if self.conv_1x1 is not None:
        if self.conv_1x1:
            # 残差のチャネル数を出力チャネル数に揃える
            residual = self.conv_1x1(residual)

ところが上のように記述してしまうと変換時にエラーとなってしまいます。条件文などではbool型となるように明示的に変換したほうが良さそうです。

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-45-b535658587ed> in <module>
     68 print(output)
     69 
---> 70 scripted_block = torch.jit.script(block)
     71 print(type(scripted_block))  # <class 'torch.jit.ScriptModule'>
     72 

... (省略) ...

/opt/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in _create_methods_from_stubs(self, stubs)
   1421     rcbs = [m.resolution_callback for m in stubs]
   1422     defaults = [get_default_args(m.original_method) for m in stubs]
-> 1423     self._c._create_methods(self, defs, rcbs, defaults)
   1424 
   1425 # For each user-defined class that subclasses ScriptModule this meta-class,

RuntimeError: 
Could not cast value of type __torch__.torch.nn.modules.conv.___torch_mangle_430.Conv2d to bool:

プロパティのモジュールも再帰的に探索・変換される

プロパティにモジュール (自作クラスを含む) をコンポジットしている場合でも、再帰的に変換してくれます。

先程定義したResidualBlockをたくさん持つNetオブジェクトも、torch.jit.scriptで変換できます。デコレータなどは不要です。

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = ResidualBlock(3, 8, 1)
        self.conv2 = ResidualBlock(8, 16, 3)
        self.conv3 = ResidualBlock(16, 32, 3)
        self.conv4 = ResidualBlock(32, 64, 3)
        self.conv5 = ResidualBlock(64, 128, 3)
        self.linear = nn.Linear(7*7*128, 10)
        
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = self.pool(x)
        x = self.conv4(x)
        x = self.pool(x)
        x = self.conv5(x)
        x = self.pool(x)
        
        x = x.flatten(1)
        
        x = self.linear(x)
        
        output = torch.softmax(x, dim=1)
        
        return output

scripted_net = torch.jit.script(net)
print(type(scripted_net))  # <class 'torch.jit.ScriptModule'>

TorchScriptモデルのセーブとロード

TorchScriptを保存するには、saveメソッドを呼び出すだけです。

torch.jit.saveではメソッドやサブモジュール、パラメータなどを全てシリアライズして保存されます。

scripted_block.save("scripted_block.pt")

上で保存したモデルはtorch.jit.loadでロードできます。

  • シリアライズフォーマットはC++ APIとも互換性があり、torch::jit::loadでロードできるようになります
loaded_block = torch.jit.load("scripted_block.pt")
print(type(loaded_block))  # <class 'torch.jit.ScriptModule'>

forwardすると当然同じ結果が得られます。

output3 = loaded_block(input)

print(torch.sum(output2 != output3))  # tensor(0)
print(torch.sum(output1 != output3))  # tensor(0)

まとめ

今回はTorchScriptへの変換について調査・整理しました。パフォーマンス面での評価は今後行いたいと思います。

まだまだ開発途上にあり、インタフェースレベルでの大きな変更も今後あるかと思います。また、周辺ライブラリの対応も後手に回っており、変換が難しいケースもあります。可能な限り最新バージョンに追随していく必要があります。

  • 例えば、torchvisionが提供するResNetモデルが現行最新版 (0.4.2) でもTorchScriptに変換できなかったりします (参考イシュー)

参考

TorchScriptの情報源については、今のところ公式のドキュメントとチュートリアルが頼りになります。