PyTorchは誤差逆伝播とパラメータ更新をどうやって行っているのか?

引き続きお仕事でPyTorchを使った開発を行っているのですが、これまでKerasで高度にラッピングされた学習フレームワークしか経験が無かったので、お作法的なところで躓くこと・疑問に思うことがよくありました。

  • loss.backward()で計算グラフを伝って誤差逆伝播されるのはなんとなくわかる
    • だけど、その計算方法や計算結果は誰が持ってて、入力側へどうやって渡してるのだろうか...
  • optimizer.zero_grad()optimizer.step()は何をしているの?

今回はPyTorchの誤差逆伝播やパラメータ更新について調べて整理しました。

この投稿ではPyTorch 1.1.0を使ってます。

import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

print(torch.__version__)  # 1.1.0

1次関数のモデリング例

1次関数  y = wx + b + \epsilon のパラメータ (w=4.5, b=7.0) を獲得する簡単な問題を例にします。

np.random.seed(42)

w = 4.5
b = 7.0

x_array = np.array([-1.5, -1.0, -0.1, 0.9, 1.8, 2.2, 3.1])
y_array = w * x_array + np.random.normal(size=x_array.shape[0])
plt.scatter(x_array, y_array)  # 下図

x = torch.tensor(x_array).float()
print(x)  # tensor([-1.5000, -1.0000, -0.1000,  0.9000,  1.8000,  2.2000,  3.1000])
y = torch.tensor(y_array).float()
print(y)  # tensor([ 0.7467,  2.3617,  7.1977, 12.5730, 14.8658, 16.6659, 22.5292])

以降ではモデル (1次関数) と誤差関数 (平均2乗誤差) を定義・使用しています。

誤差逆伝播

まずPyTorchの誤差逆伝播の実装を掘り下げて見ていきます。誤差逆伝播は3ステップで行います。

  1. 勾配計算が必要なパラメータ (requires_grad = True) を使ってモデルで計算を行い、
  2. 得られた出力とground truthから誤差を計算して、
  3. 誤差関数の出力 (Tensorオブジェクト) のbackwardメソッドを呼び出す
# w=1.0, b=0.0で初期化
param_w = torch.tensor([1.0], requires_grad=True)
param_b = torch.tensor([0.0], requires_grad=True)
print(param_w.is_leaf, param_b.is_leaf)  # True True

# 1次関数を仮定して計算 (Step 1)
y_p = param_w * x + param_b
print(y_p, y_p.is_leaf)
# tensor([-1.5000, -1.0000, -0.1000,  0.9000,  1.8000,  2.2000,  3.1000],
#        grad_fn=<AddBackward0>) False

# 平均2乗誤差を計算 (Step 2)
loss = torch.mean((y_p - y)**2)
print(loss)
# tensor(137.6195, grad_fn=<MeanBackward0>)

print(y_p.grad_fn)
# <AddBackward0 object at 0x7f33db45f588>
print(loss.grad_fn.next_functions[0][0].next_functions[0][0].next_functions[0][0])
# <AddBackward0 object at 0x7f33db45fac8>

print("before:", param_w.grad, param_b.grad)
# before: None None

# 誤差を伝播 (Step 3)
loss.backward()

print("after:", param_w.grad, param_b.grad)
# after: tensor([-33.8909]) tensor([-20.4400])

Step 1と2で計算グラフの構築が行われます。

grad_fnプロパティに偏微分 (勾配) 計算に使う関数オブジェクトがセットされてます。Tensorの演算関数毎に定義されており、以下のように__call__で実行できます。加算の場合は、2つのオペランドにそのまま勾配が伝搬されますので、値1.5が2つ並んだTensorオブジェクトが得られます。

y_p.grad_fn(torch.tensor([-1.5]))  # (tensor([1.5000]), tensor([1.5000]))

またnext_functionsプロパティには、誤差を伝播する先の関数オブジェクトへの参照が入ってます。next_functionsプロパティによって計算グラフのルート (loss) からリーフ (パラメータ、ここではparam_wとparam_b) へ伝うことができます。lossのnext_functionを3段伝って行くと、y_pのgrad_fn (AddBackward0) に至ることがわかるかと思います。
ただし、メモリアドレスが一致していません (0x7f33db45f588 と 0x7f33db45fac8) 。PyTorchの計算グラフは計算を実行するごとに組み替えられ、y_p = param_w * x + param_b を実行したときと loss = torch.mean((y_p - y)**2) を実行した時で実体が異なるためです。

Step 3で計算グラフのルート (loss) からリーフ (パラメータ、ここではparam_wとparam_b) へ勾配を伝播させ、順々に計算・累積させていきます。勾配は各Tensorオブジェクトのgradプロパティで持っており、backwardをコールすることでセットされます

backwardメソッドでは勾配が累積されることに注意です。ミニバッチなどを考慮してこのようになっているのですが、以下のように3回続けてbackwardを呼び出すと勾配も3倍になります。
このためパラメータの更新が終わった後に勾配をゼロクリアする必要があります。optimパッケージのzero_gradメソッドでやっていることはこれと同じです。

for i in range(3):
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    loss.backward()
    print(param_w.grad, param_b.grad)

# tensor([-33.8909]) tensor([-20.4400])
# tensor([-67.7818]) tensor([-40.8801])
# tensor([-101.6727]) tensor([-61.3201])

for i in range(3):
    if param_w.grad: param_w.grad.zero_()
    if param_b.grad: param_b.grad.zero_()
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    loss.backward()
    print(param_w.grad, param_b.grad)

# tensor([-33.8909]) tensor([-20.4400])
# tensor([-33.8909]) tensor([-20.4400])
# tensor([-33.8909]) tensor([-20.4400])

また評価時などで損失関数の計算などは行うが勾配計算に含めない (= 学習に使わない) といったこともできます。with torch.no_grad():のコンテキストで括るだけでOKです。

  • このコンテキスト内で生成されたTensorオブジェクトはgradを計算しません
with torch.no_grad():
    print(param_w.requires_grad, param_b.requires_grad)
    
    y_p = param_w * x + param_b
    print(y_p, y_p.requires_grad, y_p.grad)
    # tensor([-1.5000, -1.0000, -0.1000,  0.9000,  1.8000,  2.2000,  3.1000]) False None
    
    loss = torch.mean((y_p - y)**2)
    print(loss, loss.requires_grad, loss.grad)
    # tensor(137.6195) False None

パラメータ更新

ここまで仕組みを理解できれば、後はgradプロパティに適当な学習率をかけてパラメータを更新すれば、簡単な勾配降下法を実装できることが予想できます。optimパッケージのstepメソッドはまさにこれをやってます。

ただし、param_wとparam_bはフォワードにて既に計算グラフ内に組み込まれていますので、以下のようにナイーブに計算するとおかしくなります。

learning_rate = 0.01

param_w = param_w - param_w.grad * learning_rate
param_b = param_b - param_b.grad * learning_rate

print(param_w, param_b)
# tensor([1.3389], grad_fn=<SubBackward0>) tensor([0.2044], grad_fn=<SubBackward0>)

Tensorのdetachメソッドで計算グラフから切り離すことができます。

  • さらに勾配を再計算する場合は、忘れずrequires_gradプロパティをTrueにセットしてください
learning_rate = 0.01

param_w = (param_w - param_w.grad * learning_rate).detach().requires_grad_()
param_b = (param_b - param_b.grad * learning_rate).detach().requires_grad_()

print(param_w, param_b)
# tensor([1.2309], requires_grad=True) tensor([0.0644], requires_grad=True)

ここまでの実装をまとめると、Tensorだけで学習できるようになります。

param_w = torch.tensor([1.0], requires_grad=True)
param_b = torch.tensor([0.0], requires_grad=True)

epochs = 300
learning_rate = 0.01

for epoch in range(1, epochs + 1):
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    
    loss.backward()
    
    param_w = (param_w - param_w.grad * learning_rate).detach().requires_grad_()
    param_b = (param_b - param_b.grad * learning_rate).detach().requires_grad_()
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss={loss}, param_w={float(param_w)}, param_b={float(param_b)}")

# Epoch 10: loss=52.31643295288086, param_w=3.4754106998443604, param_b=1.6848385334014893
# Epoch 20: loss=23.30514144897461, param_w=4.616385459899902, param_b=2.8110599517822266
# ...
# Epoch 300: loss=0.5062416791915894, param_w=4.625252723693848, param_b=7.377382755279541

optim.SGDで置き換えても、まったく同じ結果になることが確認できます。つまりTensorだけでSGDを再実装できたことになります。やったね。

# ... (省略) ...

optimizer = optim.SGD([param_w, param_b], lr=learning_rate)

for epoch in range(1, epochs + 1):
    y_p = param_w * x + param_b
    loss = torch.mean((y_p - y)**2)
    
    optimizer.zero_grad()
    loss.backward()
    
    optimizer.step()
        
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss={loss}, param_w={float(param_w)}, param_b={float(param_b)}")
        
# Epoch 10: loss=52.31643295288086, param_w=3.4754106998443604, param_b=1.6848385334014893
# Epoch 20: loss=23.30513572692871, param_w=4.616385459899902, param_b=2.8110601902008057
# ...
# Epoch 300: loss=0.5062416791915894, param_w=4.625252723693848, param_b=7.377382755279541

まとめ

PyTorchの誤差逆伝播とパラメータ更新について詳しく見ていきました。

  • 誤差はTensor.gradプロパティに累積して持つ
  • Tensor.next_functionsプロパティを伝ってパラメータまで伝搬させる
  • gradプロパティから勾配を計算してパラメータを更新する

参考

Deep Learning with PyTorch | PyTorch

  • 4章を主に参考にしました