PyTorchのモデルのパラメータ数をカウントする方法です。2パターンあります。
1. Module.parameters()が返す各テンソルの要素数を合計する
Module.parametersメソッドで各層のパラメータがtensorで取得できますので、numelで要素数を合計していくことでパラメータ数を計算できます。
- requires_gradがTrueのパラメータが学習可能です
import torch
import torch.nn as nn


class Net(nn.Module):
    """Small CNN used as the example model for parameter counting.

    Three convolutional stages (3 -> 16 -> 32 -> 64 channels), each ending in
    a stride-2 convolution, followed by average pooling, dropout and a
    10-way linear layer. The shape comments in ``forward`` are per-sample
    shapes for a (3, 32, 32) input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._conv(3, 16)
        self.conv2 = self._conv(16, 32)
        self.conv3 = self._conv(32, 64)
        self.pool = nn.AvgPool2d(4)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        """Run the network on a batch ``x`` and return the 10 output scores."""
        x = self.conv1(x)  # (3, 32, 32) -> (16, 16, 16)
        x = self.conv2(x)  # (16, 16, 16) -> (32, 8, 8)
        x = self.conv3(x)  # (32, 8, 8) -> (64, 4, 4)
        x = self.pool(x)   # (64, 4, 4) -> (64, 1, 1)
        # Flatten everything except the batch axis. Unlike view(-1, 64), this
        # never folds leftover spatial positions into the batch dimension:
        # an unexpected input size fails at the linear layer instead.
        x = torch.flatten(x, 1)  # (64, 1, 1) -> (64)
        x = self.dropout(x)
        x = self.fc(x)  # (64) -> (10)
        return x

    def _conv(self, input_filters, output_filters):
        """Build one stage: three 3x3 Conv-BatchNorm-ReLU units plus Dropout2d.

        The third convolution uses stride 2, so the stage halves the spatial
        resolution.
        """
        return nn.Sequential(
            nn.Conv2d(input_filters, output_filters, 3, stride=1, padding=1),
            nn.BatchNorm2d(output_filters),
            nn.ReLU(),
            nn.Conv2d(output_filters, output_filters, 3, stride=1, padding=1),
            nn.BatchNorm2d(output_filters),
            nn.ReLU(),
            nn.Conv2d(output_filters, output_filters, 3, stride=2, padding=1),
            nn.BatchNorm2d(output_filters),
            nn.ReLU(),
            nn.Dropout2d(0.1),
        )


net = Net()

# Count parameters: add up the element count of every tensor returned by
# parameters(), keeping only those with requires_grad set (the trainable ones).
params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(params)  # 121898
2. pytorch-summaryを使う
Kerasライクなモデルのサマライズを行うパッケージ pytorch-summary を使うことでもパラメータ数を取得できます。
pip install pytorch-summary
で事前にインストールします。
あとは以下のようにsummary関数にモデルと入力サイズを渡すと、モデルのパラメータ数を計算してくれます。
- Total paramsが上で求めた値（121,898）と一致することを確認してください
from torchsummary import summary

# Move the model to the GPU when one is available, otherwise keep it on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net.to(device)

# Print a per-layer summary for one (3, 32, 32) input.
summary(net, (3, 32, 32))

# Expected output:
# ----------------------------------------------------------------
#         Layer (type)               Output Shape         Param #
# ================================================================
#             Conv2d-1           [-1, 16, 32, 32]             448
#        BatchNorm2d-2           [-1, 16, 32, 32]              32
#               ReLU-3           [-1, 16, 32, 32]               0
#             Conv2d-4           [-1, 16, 32, 32]           2,320
#        BatchNorm2d-5           [-1, 16, 32, 32]              32
#               ReLU-6           [-1, 16, 32, 32]               0
#             Conv2d-7           [-1, 16, 16, 16]           2,320
#        BatchNorm2d-8           [-1, 16, 16, 16]              32
#               ReLU-9           [-1, 16, 16, 16]               0
#         Dropout2d-10           [-1, 16, 16, 16]               0
#            Conv2d-11           [-1, 32, 16, 16]           4,640
#       BatchNorm2d-12           [-1, 32, 16, 16]              64
#              ReLU-13           [-1, 32, 16, 16]               0
#            Conv2d-14           [-1, 32, 16, 16]           9,248
#       BatchNorm2d-15           [-1, 32, 16, 16]              64
#              ReLU-16           [-1, 32, 16, 16]               0
#            Conv2d-17             [-1, 32, 8, 8]           9,248
#       BatchNorm2d-18             [-1, 32, 8, 8]              64
#              ReLU-19             [-1, 32, 8, 8]               0
#         Dropout2d-20             [-1, 32, 8, 8]               0
#            Conv2d-21             [-1, 64, 8, 8]          18,496
#       BatchNorm2d-22             [-1, 64, 8, 8]             128
#              ReLU-23             [-1, 64, 8, 8]               0
#            Conv2d-24             [-1, 64, 8, 8]          36,928
#       BatchNorm2d-25             [-1, 64, 8, 8]             128
#              ReLU-26             [-1, 64, 8, 8]               0
#            Conv2d-27             [-1, 64, 4, 4]          36,928
#       BatchNorm2d-28             [-1, 64, 4, 4]             128
#              ReLU-29             [-1, 64, 4, 4]               0
#         Dropout2d-30             [-1, 64, 4, 4]               0
#         AvgPool2d-31             [-1, 64, 1, 1]               0
#           Dropout-32                   [-1, 64]               0
#            Linear-33                   [-1, 10]             650
# ================================================================
# Total params: 121,898
# Trainable params: 121,898
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.01
# Forward/backward pass size (MB): 1.53
# Params size (MB): 0.47
# Estimated Total Size (MB): 2.01
# ----------------------------------------------------------------