PyTorchのモデルをTorchScriptへ変換する

引き続きPyTorchのお勉強してます。

今回はPyTorchで計算資源を有効活用した推論を行うための仕組みの1つTorchScriptについてまとめます。

TorchScriptとは

TorchScriptはPyTorchの中間表現 (intermediate representation) コードとその変換機構です。

主な利点は3つで、いずれも推論のシチュエーションで役立つものです。

  • TorchScriptコードは独自のインタプリタで実行・解釈され、Pythonインタプリタのグローバルインタプリタロック (GIL) とも無縁なので、マルチスレッドで並列計算できる
  • コードとパラメータをまるごと保存でき、Python以外の実行環境でもロードできる
  • 中間コードへ落とし込むコンパイラにて最適化しやすくなる

TorchScriptの変換方法は2つあります。

  • torch.jit.traceを使う
    • モデルのforwardメソッドを実行時の処理を記録 (トレース) することで変換します
    • torch.jit.traceは入力に依存してネットワークが変わる場合に不向きです (forward内でifやforがある場合など)
  • torch.jit.scriptを使う
    • モデルの定義時にデコレータを付与することで変換する

PyTorch 1.2.0以降ではtorch.jit.scriptでモジュールのプロパティやメソッドを再帰的に変換できるようになりました。その結果、デコレータ無しでもModuleオブジェクトにtorch.jit.scriptを1回通すだけでTorchScriptへ変換できます (参考) 。
本投稿ではPyTorch 1.3.1を使いますので、より楽ちんな2番目の方法を紹介したいと思います。

TorchScriptへの変換

例として、ResNetなどで用いられる残差ブロック (residual block) をピュアなPyTorchで実装し、それをTorchScriptへ変換します。

import torch
import torch.nn as nn

print(torch.__version__)  # 1.3.1

residual blockの実装

最初に残差ブロックをPyTorchで実装します。

  • 入力を2つに分け、1つは (畳み込み層 + batch norm + ReLU) * n に通し、もう1つはそのままにして、出力ではその2つを足してReLUにかけたブロックです
  • ResNetについては以前の投稿も参考にしてみてください (ResNetでCIFAR-10を分類する - け日記)

通常のPyTorchのモジュールと同様に実装します。nn.Moduleを継承してクラスを作り、プロパティにモジュールとパラメータを持たせ、forwardを実装しているだけです。 モジュールの作り方については先週の投稿に詳しく記載しました。

  • 畳み込み層はModuleListにまとめてます
  • 入力チャネル数in_channelsと出力チャネル数out_channelsが異なる場合、残差residualを1x1の畳み込みにかけることで出力チャネル数に合わせてます
  • 最後のReLUは畳み込み層を通過したxと残差residualを足した後にかけてます

この時点では特別なことはしてません。(ハマりどころは後述します。)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, conv_layers, kernel_size=3):
        super().__init__()
        
        self.conv_layers = conv_layers
        
        padding = (kernel_size - 1) // 2
        # 最初の畳み込みで出力チャネル数にあわせる
        convs = [nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding)]
        convs.extend([
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding)
            for _ in range(1, conv_layers)
        ])
        # ModuleListで畳み込み層を1つにまとめる
        self.convs = nn.ModuleList(convs)
        
        self.conv_1x1 = None
        if in_channels != out_channels:
            # 入出力のチャネル数が異なる場合、残差のチャネル数を1x1畳み込みで出力チャネル数に合わせる
            self.conv_1x1 = nn.Conv2d(in_channels, out_channels, 1)
        
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        residual = x
        
        i = 0
        for conv in self.convs:
            x = conv(x)
            x = self.batch_norm(x)
            # 畳み込み最後のReLUはスキップ
            if i < (self.conv_layers - 1):
                x = self.relu(x)
            i += 1
        
        if self.conv_1x1 is not None:
            # 残差のチャネル数を出力チャネル数に揃える
            residual = self.conv_1x1(residual)
                
        x = x + residual
        
        output = self.relu(x)
        
        return output

torch.jit.scriptによる変換

上のResidualBlockモジュールをTorchScriptに変換します。

最初にResidualBlockモジュールをインスタンス化します。

block = ResidualBlock(3, 8, 3)
print(type(block))  # <class '__main__.ResidualBlock'>
print(block)
# ResidualBlock(
#   (convs): ModuleList(
#     (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (1): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#     (2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#   )
#   (conv_1x1): Conv2d(3, 8, kernel_size=(1, 1), stride=(1, 1))
#   (batch_norm): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   (relu): ReLU(inplace=True)
# )

そしてtorch.jit.scriptにオブジェクトを渡して変換します。torch.jit.ScriptModule型のオブジェクトが得られていれば変換成功です。

scripted_block = torch.jit.script(block)
print(type(scripted_block))  # <class 'torch.jit.ScriptModule'>

ScriptModuleは、nn.Moduleを継承してますのでPythonでは今までのモジュール同様扱えますが、同時にC++のモジュールクラス (torch::jit::script::Module) のラッパにもなってます。

またScriptModuleでは、計算グラフを持つgraphプロパティと、それをもうちょっとPythonライクに見やすくしたcodeプロパティを持ってます。

以下でcodeを表示してますが、単一のforward関数が定義され、その中に全てのコードが展開されているようです。最後のif True:のような強引な置き換えも見られます。

print(scripted_block.code)
# import __torch__.___torch_mangle_340
# import __torch__.torch.jit.___torch_mangle_356
# import __torch__.torch.nn.modules.conv.___torch_mangle_341
# import __torch__.torch.nn.modules.conv.___torch_mangle_346
# import __torch__.torch.nn.modules.conv.___torch_mangle_351
# import __torch__.torch.nn.modules.batchnorm.___torch_mangle_357
# import __torch__.torch.nn.modules.activation.___torch_mangle_359
# import __torch__.torch.nn.modules.conv.___torch_mangle_361
# def forward(self,
#     x: Tensor) -> Tensor:
#   _0 = self.convs
# ... (中略) ...
#   if _390:
#     residual = _391
#   else:
#     residual = torch.conv2d(x, _388, _387.bias, [1, 1], [0, 0], [1, 1], 1)
#   x12 = torch.add(x10, residual, alpha=1)
#   if True:
#     output = torch.relu_(x12)
#   else:
#     output = torch.relu(x12)
#   return output

同じモデルになっているか簡単に確認します。同じ入力を与えると、PyTorch版もTorchScript版も同じ値を出力できていることがわかります。

input = torch.ones(1, 3, 2, 2)

output1 = block(input)
print(output1.size())  # torch.Size([1, 8, 2, 2])

output2 = scripted_block(input)
print(output2.size())  # torch.Size([1, 8, 2, 2])

print(torch.sum(output1 != output2))  # tensor(0)

はまりどころ

上では1発で変換できたように見えますが、記述方法によってはうまく変換できないケースがあります。自分が直面した問題を2つ挙げます。

ModuleListとModuleDictはインデックスでアクセスできない

できるかどうかちょっと心配だったのですが、ModuleListとModuleDictも変換できます。

ただしアクセス方法が制限されており、イテレーション形式のみサポートされてます。例えばx = self.convs[i](x)のようにインデックスでアクセスすると、以下のようにエラーとなります。

module' object is not subscriptable:
at <ipython-input-43-f60e7388ed80>:30:16
    def forward(self, x):
        residual = x
        
        i = 0
        for conv in self.convs:
            # x = conv(x)
            x = self.convs[i](x)
                ~~~~~~~~~~~~ <--- HERE
            x = self.batch_norm(x)
            # 畳み込み最後のReLUはスキップ
            if i < (self.conv_layers - 1):
                x = self.relu(x)
            i += 1

Noneオブジェクトの暗黙的なbool変換が行われない

PythonではNoneオブジェクトを条件式に使うとFalse扱いとなります。

        # if self.conv_1x1 is not None:
        if self.conv_1x1:
            # 残差のチャネル数を出力チャネル数に揃える
            residual = self.conv_1x1(residual)

ところが上のように記述してしまうと変換時にエラーとなってしまいます。条件文などではbool型となるように明示的に変換したほうが良さそうです。

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-45-b535658587ed> in <module>
     68 print(output)
     69 
---> 70 scripted_block = torch.jit.script(block)
     71 print(type(scripted_block))  # <class 'torch.jit.ScriptModule'>
     72 

... (省略) ...

/opt/anaconda3/lib/python3.7/site-packages/torch/jit/__init__.py in _create_methods_from_stubs(self, stubs)
   1421     rcbs = [m.resolution_callback for m in stubs]
   1422     defaults = [get_default_args(m.original_method) for m in stubs]
-> 1423     self._c._create_methods(self, defs, rcbs, defaults)
   1424 
   1425 # For each user-defined class that subclasses ScriptModule this meta-class,

RuntimeError: 
Could not cast value of type __torch__.torch.nn.modules.conv.___torch_mangle_430.Conv2d to bool:

プロパティのモジュールも再帰的に探索・変換される

プロパティにモジュール (自作クラスを含む) をコンポジットしている場合でも、再帰的に変換してくれます。

先程定義したResidualBlockをたくさん持つNetオブジェクトも、torch.jit.scriptで変換できます。デコレータなどは不要です。

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = ResidualBlock(3, 8, 1)
        self.conv2 = ResidualBlock(8, 16, 3)
        self.conv3 = ResidualBlock(16, 32, 3)
        self.conv4 = ResidualBlock(32, 64, 3)
        self.conv5 = ResidualBlock(64, 128, 3)
        self.linear = nn.Linear(7*7*128, 10)
        
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.conv3(x)
        x = self.pool(x)
        x = self.conv4(x)
        x = self.pool(x)
        x = self.conv5(x)
        x = self.pool(x)
        
        x = x.flatten(1)
        
        x = self.linear(x)
        
        output = torch.softmax(x, dim=1)
        
        return output

scripted_net = torch.jit.script(net)
print(type(scripted_net))  # <class 'torch.jit.ScriptModule'>

TorchScriptモデルのセーブとロード

TorchScriptを保存するには、saveメソッドを呼び出すだけです。

torch.jit.saveではメソッドやサブモジュール、パラメータなどを全てシリアライズして保存されます。

scripted_block.save("scripted_block.pt")

上で保存したモデルはtorch.jit.loadでロードできます。

  • シリアライズフォーマットはC++ APIとも互換性があり、torch::jit::loadでロードできるようになります
loaded_block = torch.jit.load("scripted_block.pt")
print(type(loaded_block))  # <class 'torch.jit.ScriptModule'>

forwardすると当然同じ結果が得られます。

output3 = loaded_block(input)

print(torch.sum(output2 != output3))  # tensor(0)
print(torch.sum(output1 != output3))  # tensor(0)

まとめ

今回はTorchScriptへの変換について調査・整理しました。パフォーマンス面での評価は今後行いたいと思います。

まだまだ開発途上にあり、インタフェースレベルでの大きな変更も今後あるかと思います。また、周辺ライブラリの対応も後手に回っており、変換が難しいケースもあります。可能な限り最新バージョンに追随していく必要があります。

  • 例えば、torchvisionが提供するResNetモデルが現行最新版 (0.4.2) でもTorchScriptに変換できなかったりします (参考イシュー)

参考

TorchScriptの情報源については、今のところ公式のドキュメントとチュートリアルが頼りになります。

PyTorchでネットワークを実装する

引き続きPyTorchのお勉強中です。

前々回はテンソル前回は誤差逆伝播について見ていきましたが、今回はtorch.nnのモジュールを中心にネットワークの作り方について整理していきます。

前回より少し難しくして、2次関数  y = w2 \times x^2 + w1 \times x + b + \epsilon (w=4.5, b=7.0, εは誤差) を例にしていきます。

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

print(torch.__version__)  # 1.1.0

np.random.seed(42)

w2 = 0.6
w1 = 4.5
b = 7.0

x_array = np.array([-1.5, -1.0, -0.1, 0.9, 1.8, 2.2, 3.1])
y_array = w2 * x_array**2 + w * x_array + b + np.random.normal(size=x_array.shape[0])
plt.scatter(x_array, y_array)

x = torch.tensor(x_array).unsqueeze(1).float()
print(x)  # tensor([[-1.5000], [-1.0000], [-0.1000], [0.9000], [1.8000], [2.2000], [3.1000]])
y = torch.tensor(y_array).unsqueeze(1).float()
print(y)  # tensor([[2.0967], [2.9617], [7.2037], [13.0590], [16.8098], [19.5699], [28.2952]])

モジュールとtorch.nn

PyTorchにおいて、畳み込みフィルタ・活性化関数・損失関数などのネットワークの"層"を実装しているのが、モジュールです。このモジュールを組み合わせることによって、ネットワークを実装します。

PyTorchではよく使われる層はモジュールとして既に実装されていることが多く、それらはtorch.nnに集約されています。

モジュールの基本的な使い方

最初に  y = w \times x + b の1次元関数でモデリング・学習するコードを実装し、モジュールの基本的な使い方を確認します。

以下がそのコードです。2つのモジュールを使っており、SGDで学習してます。学習のループの詳細は前回の投稿も参考にしてみてください。

  • nn.Linearは1次関数で、傾きと切片がパラメータとなります
    • __init__引数は入力サイズ・出力サイズ
  • nn.MSELossは平均2乗誤差を計算するモジュールで、損失関数として利用しています
model = nn.Linear(1, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(1, 2001):
    y_p = model(x)
    loss = loss_fn(y_p, y)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch == 1 or epoch % 200 == 0:
        print(f"Epoch {epoch}: loss={loss}, params={list(model.parameters())}")

# Epoch 1: loss=247.5394287109375, params=[Parameter containing:tensor([[-0.3880]], requires_grad=True), Parameter containing:tensor([0.6815], requires_grad=True)]
# ...
# Epoch 2000: loss=2.2267117500305176, params=[Parameter containing:tensor([[5.5755]], requires_grad=True), Parameter containing:tensor([8.3513], requires_grad=True)]

概ね前回の投稿と同じですが、torch.nnのモジュールの特徴を2点補足します。

parametersメソッド

学習可能なパラメータはparametersメソッドで取得できます。

optim.SGDなどのOptimizerを継承した最適化クラスをインスタンス化する際に、第1引数 (params) にparametersメソッドから得られた値を渡すことで勾配計算 (学習) の対象に含めることができます。

print(list(model.parameters()))
# [Parameter containing: tensor([[5.5755]], requires_grad=True),
#  Parameter containing: tensor([8.3513], requires_grad=True)]

モジュールのパラメータはParameterオブジェクトとしてプロパティに持っています。実体はtensorなのですが、parametersメソッドはこのクラスのプロパティのみを探索しています。

  • Linearでは2つのプロパティ (weightとbias) がParameterオブジェクトです (参考)

0次元目はバッチのインデックス

nn.Moduleは0次元目がバッチ内のインデックスが入っていることを前提としてます。そのため以下のケースはエラーとなります。モデルは入力サイズ1を期待しているのに、実際の入力サイズは7のためです。
上のコードではunsqueezeメソッドで1次元目 (サイズ1) を追加することで回避してます。

x = torch.tensor(x_array).float()  # unsqueeze無し
print(x.shape)  # torch.Size([7])

y_pred = model(x)
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# ...
# RuntimeError: size mismatch, m1: [1 x 7], m2: [1 x 1] at /opt/conda/conda-bld/pytorch_1556653114079/work/aten/src/TH/generic/THTensorMath.cpp:961

ネットワークの自作

上でLinearを使ってモデリングしましたが、学習の結果lossは下がりきりませんでした。もうちょっと複雑なモデルを導入する必要がありそうです。

そこで自前のネットワークを定義していくのですが、上の例で見たシンプルなインタフェース (__call__で順伝播、backwardで逆伝播) は維持したいところです。

LinearやMSELossなどのモジュール (TensorFlowではレイヤと呼ばれる) はnn.Moduleクラスを継承してます。

nn.Moduleクラスを継承したクラスを作成することで、シンプルなインタフェースで学習を実現できます。

以下コードでは、隠れ層 (1層) の簡単なネットワークをnn.Moduleを継承したMyModelで実装しています。使い方は上で見たLinearと全く同じことに着目してください。

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linear_1 = nn.Linear(1, 5)
        self.linear_2 = nn.Linear(5, 1)
        
    def forward(self, x):
        x = self.linear_1(x)
        x = torch.relu(x)
        x = self.linear_2(x)
        return x

model = MyModel()
print(model)
# MyModel(
#   (linear_1): Linear(in_features=1, out_features=5, bias=True)
#   (linear_2): Linear(in_features=5, out_features=1, bias=True)
# )

loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

for epoch in range(1, 2001):
    y_p = model(x)
    loss = loss_fn(y_p, y)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch == 1 or epoch % 200 == 0:
        print(f"Epoch {epoch}: loss={loss}")
        
# Epoch 1: loss=238.2289276123047
# Epoch 200: loss=1.2844644784927368
# ...
# Epoch 2000: loss=1.1387771368026733
        
print(list(model.parameters()))
# [Parameter containing: tensor([[-0.9788], [0.6803], [1.7500], [1.0663], [0.5548]], requires_grad=True), 
#  Parameter containing: tensor([-0.5819, 0.9828, 0.7439, 1.3531, 1.1561], requires_grad=True), 
#  Parameter containing: tensor([[0.3001, 0.8646, 1.9179, 1.5838, 1.1189]], requires_grad=True), 
#  Parameter containing: tensor([1.4064], requires_grad=True)]

MyModelで実装しているのは__init__メソッドとforwardメソッドのみです。それぞれポイントを説明していきます。

__init__

__init__メソッドではネットワークを構成するパラメータとモジュールを初期化して、プロパティにセットしていきます。

モジュールのプロパティに対しても再帰的にパラメータを探索します。以下ではMyLinearsのインスタンスをプロパティに持つMyModelを定義してますが、print(model)でMyLinearの中のモジュールのパラメータまで認識されていることがわかります。

class MyLinears(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linear1 = nn.Linear(1, 5)
        self.linear2 = nn.Linear(5, 1)
        
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        return x

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.my_linears = MyLinears()
        self.linear = nn.Linear(1, 1)
        
    def forward(self, x):
        x = self.my_linears(x)
        x = torch.relu(x)
        x = self.linear(x)
        return x

model = MyModel()
print(model)
# MyModel(
#   (my_linears): MyLinears(
#     (linear1): Linear(in_features=1, out_features=5, bias=True)
#     (linear2): Linear(in_features=5, out_features=1, bias=True)
#   )
#   (linear): Linear(in_features=1, out_features=1, bias=True)
# )

ただしModuleを継承したプロパティのみがパラメータの探索対象となることに注意です。以下のようにlistなどでセットしてもパラメータ扱いされません。

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linears = [nn.Linear(1, 5), nn.Linear(5, 1)]
        
    def forward(self, x):
        x = self.linears[0](x)
        x = torch.relu(x)
        x = self.linears[1](x)
        return x

model = MyModel()
print(model) # MyModel()
print(list(model.parameters()))  # []

ModuleListModuleDictでプロパティにセットすると、パラメータを探索するようになります。

  • それぞれ__getitem__や__len__などは実装されており、list・dictと同じアクセスができます
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linears = nn.ModuleList([
            nn.Linear(1, 5), nn.Linear(5, 1)
        ])
        
    def forward(self, x):
        x = self.linears[0](x)
        x = torch.relu(x)
        x = self.linears[1](x)
        return x

model = MyModel()
print(model)
# MyModel(
#   (linears): ModuleList(
#     (0): Linear(in_features=1, out_features=5, bias=True)
#     (1): Linear(in_features=5, out_features=1, bias=True)
#   )
# )

forward

forwardメソッドで各モジュールを順につなげます。
上の実装ではxを介して 入力 -> Linear -> relu -> Linear -> 出力 という順番で受け渡ししてます。

ただしこのforwardメソッドを直接実行して順伝播させるのはNGで、__call__を使わないといけません。__call__内ではforward以外にhookも一緒に実行しているためです。

hookには順伝播時に実行したい任意の処理を定義できます。以下ではregister_forward_pre_hookとregister_forward_hookでprint関数をセットし、__call__で順伝播させることで各hookを実行させてます。

  • register_forward_pre_hookメソッドで渡された関数は順伝播前、register_forward_hookメソッドで渡された関数は順伝播後に実行されます
  • ちなみにregister_backward_hookもあります
model.register_forward_pre_hook(lambda module, input: print(f"forward pre hook (input: {len(input)})"))
model.register_forward_hook(lambda module, input, output: print(f"forward hook (output: {output.shape})"))

y_p = model(x)
# forward pre hook (input: 1)
# forward hook (output: torch.Size([7, 1, 1]))

またreluはプロパティとして定義していない点にも注目です。学習可能なパラメータを持たない層 (relu, tanhやbatch_normなど) は状態の無いただの関数なので、プロパティで持つ必要がありません。

torch以下にはそういった関数形式のインタフェースで層が提供されています。torch.reluもその1つです。モジュールクラスではパラメータをプロパティで持ちますが、関数の場合は引数で重みなども渡します。
nn.ReLUクラスとしても提供されていますので、何度も同じ関数を挟むのであればプロパティとして持って再利用するほうが良いかもしれません。このあたりは好みの問題です。

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.linear_1 = nn.Linear(1, 5)
        self.linear_2 = nn.Linear(5, 1)
        # ...
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        x = self.relu(x)
        # ...
        return x

Sequential

上のように「入力xを最初のモジュールの入力として、その出力を2番目のモジュールの入力として、...」という単純な数珠つなぎのforwardであれば、nn.Sequentialが便利です。
Sequentialクラスは上の数珠つなぎのforwardが既に定義されたモジュールなので、各層のモジュールオブジェクトを順番通りに渡すだけでモジュールを作成できます。

model = nn.Sequential(
    nn.Linear(1, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

y_p = model(x)

まとめ

今回はPyTorchでネットワークを作りました。

  • PyTorchの層はnn.Moduleのサブクラスで実現されており、自分でネットワークを作る場合もnn.Moduleの継承クラスで実装する
  • Propertyオブジェクトとしてプロパティに持たせることで、学習可能なパラメータとして認識される
  • Moduleサブクラスオブジェクトをプロパティに持つことで、そのプロパティに対しても再帰的にパラメータが探索される

参考

Deep Learning with PyTorch | PyTorch

  • 5章を主に参考にしました

PyTorchは誤差逆伝播とパラメータ更新をどうやって行っているのか?

引き続きお仕事でPyTorchを使った開発を行っているのですが、これまでKerasで高度にラッピングされた学習フレームワークしか経験が無かったので、お作法的なところで躓くこと・疑問に思うことがよくありました。

  • loss.backward()で計算グラフを伝って誤差逆伝播されるのはなんとなくわかる
    • だけど、その計算方法や計算結果は誰が持ってて、入力側へどうやって渡してるのだろうか...
  • optimizer.zero_grad()optimizer.step()は何をしているの?

今回はPyTorchの誤差逆伝播やパラメータ更新について調べて整理しました。

この投稿ではPyTorch 1.1.0を使ってます。

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

print(torch.__version__)  # 1.1.0

1次関数のモデリング例

1次関数  y = wx + b + \epsilon のパラメータ (w=4.5, b=7.0) を獲得する簡単な問題を例にします。

np.random.seed(42)

w = 4.5
b = 7.0

x_array = np.array([-1.5, -1.0, -0.1, 0.9, 1.8, 2.2, 3.1])
y_array = w * x_array + np.random.normal(size=x_array.shape[0])
plt.scatter(x_array, y_array)  # 下図

x = torch.tensor(x_array).float()
print(x)  # tensor([-1.5000, -1.0000, -0.1000,  0.9000,  1.8000,  2.2000,  3.1000])
y = torch.tensor(y_array).float()
print(y)  # tensor([ 0.7467,  2.3617,  7.1977, 12.5730, 14.8658, 16.6659, 22.5292])

以降ではモデル (1次関数) と誤差関数 (平均2乗誤差) を定義・使用しています。

誤差逆伝播

まずPyTorchの誤差逆伝播の実装を掘り下げて見ていきます。誤差逆伝播は3ステップで行います。

  1. 勾配計算が必要なパラメータ (requires_grad = True) を使ってモデルで計算を行い、
  2. 得られた出力とground truthから誤差を計算して、
  3. 誤差関数の出力 (Tensorオブジェクト) のbackwardメソッドを呼び出す
# w=1.0, b=0.0で初期化
param_w = torch.tensor([1.0], requires_grad=True)
param_b = torch.tensor([0.0], requires_grad=True)
print(param_w.is_leaf, param_b.is_leaf)  # True True

# 1次関数を仮定して計算 (Step 1)
y_p = param_w * x + param_b
print(y_p, y_p.is_leaf)
# tensor([-1.5000, -1.0000, -0.1000,  0.9000,  1.8000,  2.2000,  3.1000],
#        grad_fn=<AddBackward0>) False

# 平均2乗誤差を計算 (Step 2)
loss = torch.mean((y_p - y)**2)
print(loss)
# tensor(137.6195, grad_fn=<MeanBackward0>)

print(y_p.grad_fn)
# <AddBackward0 object at 0x7f33db45f588>
print(loss.grad_fn.next_functions[0][0].next_functions[0][0].next_functions[0][0])
# <AddBackward0 object at 0x7f33db45fac8>

print("before:", param_w.grad, param_b.grad)
# before: None None

# 誤差を伝播 (Step 3)
loss.backward()

print("after:", param_w.grad, param_b.grad)
# after: tensor([-33.8909]) tensor([-20.4400])

Step 1と2で計算グラフの構築が行われます。

grad_fnプロパティに偏微分 (勾配) 計算に使う関数オブジェクトがセットされてます。Tensorの演算関数毎に定義されており、以下のように__call__で実行できます。加算の場合は、2つのオペランドにそのまま勾配が伝搬されますので、値1.5が2つ並んだTensorオブジェクトが得られます。

y_p.grad_fn(torch.tensor([-1.5]))  # (tensor([1.5000]), tensor([1.5000]))

またnext_functionsプロパティには、誤差を伝播する先の関数オブジェクトへの参照が入ってます。next_functionsプロパティによって計算グラフのルート (loss) からリーフ (パラメータ、ここではparam_wとparam_b) へ伝うことができます。lossのnext_functionを3段伝って行くと、y_pのgrad_fn (AddBackward0) に至ることがわかるかと思います。
ただし、メモリアドレスが一致していません (0x7f33db45f588 と 0x7f33db45fac8) 。PyTorchの計算グラフは計算を実行するごとに組み替えられ、y_p = param_w * x + param_b を実行したときと loss = torch.mean((y_p - y)**2) を実行した時で実体が異なるためです。

Step 3で計算グラフのルート (loss) からリーフ (パラメータ、ここではparam_wとparam_b) へ勾配を伝播させ、順々に計算・累積させていきます。勾配は各Tensorオブジェクトのgradプロパティで持っており、backwardをコールすることでセットされます

backwardメソッドでは勾配が累積されることに注意です。ミニバッチなどを考慮してこのようになっているのですが、以下のように3回続けてbackwardを呼び出すと勾配も3倍になります。
このためパラメータの更新が終わった後に勾配をゼロクリアする必要があります。optimパッケージのzero_gradメソッドでやっていることはこれと同じです。

for i in range(3):
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    loss.backward()
    print(param_w.grad, param_b.grad)

# tensor([-33.8909]) tensor([-20.4400])
# tensor([-67.7818]) tensor([-40.8801])
# tensor([-101.6727]) tensor([-61.3201])

for i in range(3):
    if param_w.grad: param_w.grad.zero_()
    if param_b.grad: param_b.grad.zero_()
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    loss.backward()
    print(param_w.grad, param_b.grad)

# tensor([-33.8909]) tensor([-20.4400])
# tensor([-33.8909]) tensor([-20.4400])
# tensor([-33.8909]) tensor([-20.4400])

また評価時などで損失関数の計算などは行うが勾配計算に含めない (= 学習に使わない) といったこともできます。with torch.no_grad():のコンテキストで括るだけでOKです。

  • このコンテキスト内で生成されたTensorオブジェクトはgradを計算しません
with torch.no_grad():
    print(param_w.requires_grad, param_b.requires_grad)
    
    y_p = param_w * x + param_b
    print(y_p, y_p.requires_grad, y_p.grad)
    # tensor([-1.5000, -1.0000, -0.1000,  0.9000,  1.8000,  2.2000,  3.1000]) False None
    
    loss = torch.mean((y_p - y)**2)
    print(loss, loss.requires_grad, loss.grad)
    # tensor(137.6195) False None

パラメータ更新

ここまで仕組みを理解できれば、後はgradプロパティに適当な学習率をかけてパラメータを更新すれば、簡単な勾配降下法を実装できることが予想できます。optimパッケージのstepメソッドはまさにこれをやってます。

ただし、param_wとparam_bはフォワードにて既に計算グラフ内に組み込まれていますので、以下のようにナイーブに計算するとおかしくなります。

learning_rate = 0.01

param_w = param_w - param_w.grad * learning_rate
param_b = param_b - param_b.grad * learning_rate

print(param_w, param_b)
# tensor([1.3389], grad_fn=<SubBackward0>) tensor([0.2044], grad_fn=<SubBackward0>)

Tensorのdetachメソッドで計算グラフから切り離すことができます。

  • さらに勾配を再計算する場合は、忘れずrequires_gradプロパティをTrueにセットしてください
learning_rate = 0.01

param_w = (param_w - param_w.grad * learning_rate).detach().requires_grad_()
param_b = (param_b - param_b.grad * learning_rate).detach().requires_grad_()

print(param_w, param_b)
# tensor([1.2309], requires_grad=True) tensor([0.0644], requires_grad=True)

ここまでの実装をまとめると、Tensorだけで学習できるようになります。

param_w = torch.tensor([1.0], requires_grad=True)
param_b = torch.tensor([0.0], requires_grad=True)

epochs = 300
learning_rate = 0.01

for epoch in range(1, epochs + 1):
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    
    loss.backward()
    
    param_w = (param_w - param_w.grad * learning_rate).detach().requires_grad_()
    param_b = (param_b - param_b.grad * learning_rate).detach().requires_grad_()
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss={loss}, param_w={float(param_w)}, param_b={float(param_b)}")

# Epoch 10: loss=52.31643295288086, param_w=3.4754106998443604, param_b=1.6848385334014893
# Epoch 20: loss=23.30514144897461, param_w=4.616385459899902, param_b=2.8110599517822266
# ...
# Epoch 300: loss=0.5062416791915894, param_w=4.625252723693848, param_b=7.377382755279541

optim.SGDで置き換えても、まったく同じ結果になることが確認できます。つまりTensorだけでSGDを再実装できたことになります。やったね。

# ... (省略) ...

optimizer = optim.SGD([param_w, param_b], lr=learning_rate)

for epoch in range(1, epochs + 1):
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    
    optimizer.zero_grad()
    loss.backward()
    
    optimizer.step()
        
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss={loss}, param_w={float(param_w)}, param_b={float(param_b)}")
        
# Epoch 10: loss=52.31643295288086, param_w=3.4754106998443604, param_b=1.6848385334014893
# Epoch 20: loss=23.30513572692871, param_w=4.616385459899902, param_b=2.8110601902008057
# ...
# Epoch 300: loss=0.5062416791915894, param_w=4.625252723693848, param_b=7.377382755279541

まとめ

PyTorchの誤差逆伝播とパラメータ更新について詳しく見ていきました。

  • 誤差はTensor.gradプロパティに累積して持つ
  • Tensor.next_functionsプロパティを伝ってパラメータまで伝搬させる
  • gradプロパティから勾配を計算してパラメータを更新する

参考

Deep Learning with PyTorch | PyTorch

  • 4章を主に参考にしました