PyTorchでパラメータ数をカウントする

PyTorchのモデルのパラメータ数をカウントする方法です。2パターンあります。

1. Moduleのparametersメソッドで取得したパラメータの要素数を合計する

Module.parametersメソッドで各層のパラメータがtensorで取得できますので、numelで要素数を合計していくことでパラメータ数を計算できます。

  • requires_gradがTrueのパラメータだけを数えると、学習可能なパラメータ数になります
import torch
import torch.nn as nn

class Net(nn.Module):
    """Small CNN classifier mapping 3x32x32 inputs to 10 logits.

    Three convolutional stages (each ending in a stride-2 convolution that
    halves the spatial size) are followed by 4x4 average pooling, dropout
    and a linear classifier.
    """

    def __init__(self):
        super().__init__()

        # Attribute names and creation order are kept stable so that
        # state_dict keys and the parameter count do not change.
        self.conv1 = self._conv(3, 16)
        self.conv2 = self._conv(16, 32)
        self.conv3 = self._conv(32, 64)
        self.pool = nn.AvgPool2d(4)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        """Return logits of shape (N, 10) for a batch of shape (N, 3, 32, 32)."""
        x = self.conv1(x)  # (3, 32, 32) -> (16, 16, 16)
        x = self.conv2(x)  # (16, 16, 16) -> (32, 8, 8)
        x = self.conv3(x)  # (32, 8, 8) -> (64, 4, 4)
        x = self.pool(x)   # (64, 4, 4) -> (64, 1, 1)

        # (64, 1, 1) -> (64) -> (10)
        # Flatten everything except the batch dimension. Unlike
        # view(-1, 64), this never folds spatial positions into the batch
        # dimension: an input that is not 32x32 fails loudly in the linear
        # layer instead of silently returning the wrong batch size.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

    def _conv(self, input_filters, output_filters):
        """Build one stage: three 3x3 Conv-BatchNorm-ReLU units plus dropout.

        The first unit changes the channel count from ``input_filters`` to
        ``output_filters``; the third uses stride 2 to halve height and width.
        """
        return nn.Sequential(
            nn.Conv2d(input_filters, output_filters, 3, stride=1, padding=1),
            nn.BatchNorm2d(output_filters),
            nn.ReLU(),
            nn.Conv2d(output_filters, output_filters, 3, stride=1, padding=1),
            nn.BatchNorm2d(output_filters),
            nn.ReLU(),
            nn.Conv2d(output_filters, output_filters, 3, stride=2, padding=1),
            nn.BatchNorm2d(output_filters),
            nn.ReLU(),
            nn.Dropout2d(0.1),
        )

net = Net()

# Count trainable parameters: add up the element count of every parameter
# tensor whose requires_grad flag is set.
params = sum(
    tensor.numel() for tensor in net.parameters() if tensor.requires_grad
)

print(params)  # 121898

2. pytorch-summaryを使う

Kerasライクなモデルのサマライズを行うパッケージ pytorch-summary を使うことでもパラメータ数を取得できます。
pip install torchsummaryで事前にインストールします（リポジトリ名は pytorch-summary ですが、PyPI上のパッケージ名およびインポート名は torchsummary です）。

github.com/sksq96/pytorch-summary

あとは以下のようにsummary関数にモデルと入力サイズ（バッチ次元を除いた形状）を渡すと、モデルのパラメータ数を計算してくれます。

  • Total paramsが上で求めた値と一致することを確認してください
from torchsummary import summary

# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
net = net.to(device)  # Module.to returns the same module instance

# Print the per-layer table for a single 3x32x32 input (batch dim omitted).
summary(net, input_size=(3, 32, 32))
# ----------------------------------------------------------------
#         Layer (type)               Output Shape         Param #
# ================================================================
#             Conv2d-1           [-1, 16, 32, 32]             448
#        BatchNorm2d-2           [-1, 16, 32, 32]              32
#               ReLU-3           [-1, 16, 32, 32]               0
#             Conv2d-4           [-1, 16, 32, 32]           2,320
#        BatchNorm2d-5           [-1, 16, 32, 32]              32
#               ReLU-6           [-1, 16, 32, 32]               0
#             Conv2d-7           [-1, 16, 16, 16]           2,320
#        BatchNorm2d-8           [-1, 16, 16, 16]              32
#               ReLU-9           [-1, 16, 16, 16]               0
#         Dropout2d-10           [-1, 16, 16, 16]               0
#            Conv2d-11           [-1, 32, 16, 16]           4,640
#       BatchNorm2d-12           [-1, 32, 16, 16]              64
#              ReLU-13           [-1, 32, 16, 16]               0
#            Conv2d-14           [-1, 32, 16, 16]           9,248
#       BatchNorm2d-15           [-1, 32, 16, 16]              64
#              ReLU-16           [-1, 32, 16, 16]               0
#            Conv2d-17             [-1, 32, 8, 8]           9,248
#       BatchNorm2d-18             [-1, 32, 8, 8]              64
#              ReLU-19             [-1, 32, 8, 8]               0
#         Dropout2d-20             [-1, 32, 8, 8]               0
#            Conv2d-21             [-1, 64, 8, 8]          18,496
#       BatchNorm2d-22             [-1, 64, 8, 8]             128
#              ReLU-23             [-1, 64, 8, 8]               0
#            Conv2d-24             [-1, 64, 8, 8]          36,928
#       BatchNorm2d-25             [-1, 64, 8, 8]             128
#              ReLU-26             [-1, 64, 8, 8]               0
#            Conv2d-27             [-1, 64, 4, 4]          36,928
#       BatchNorm2d-28             [-1, 64, 4, 4]             128
#              ReLU-29             [-1, 64, 4, 4]               0
#         Dropout2d-30             [-1, 64, 4, 4]               0
#         AvgPool2d-31             [-1, 64, 1, 1]               0
#           Dropout-32                   [-1, 64]               0
#            Linear-33                   [-1, 10]             650
# ================================================================
# Total params: 121,898
# Trainable params: 121,898
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.01
# Forward/backward pass size (MB): 1.53
# Params size (MB): 0.47
# Estimated Total Size (MB): 2.01
# ----------------------------------------------------------------