PyTorch: DatasetとDataLoader (画像処理タスク編)

引き続きPyTorchのお勉強です。

画像処理タスクの文脈でDatasetとDataLoaderの使い方を整理していきます。

DatasetとDataLoader

PyTorchに限らず、ディープラーニングタスクのデータの入力については、一般的に以下の要件が挙げられます

  • データをミニバッチにまとめる
  • 任意の前処理を実行する
  • (CPU依存のため) 適切に並列処理させたい

PyTorchで上のニーズに応えるのがDatasetとDataLoaderです。Datasetが取り出したデータをDataLoaderがミニバッチにまとめるという関係になってます。いずれもtorch.utils.dataに属します

  • Datasetは、データソース (ファイルやNumPy行列など) から1個ずつデータ (入力とground truth) を取り出す
  • DataLoaderは、Datasetからバッチサイズずつデータを取り出す
    • このミニバッチがモデルへの入力となる

データセットの準備

PyTorchチュートリアルの中で提供されている画像データセットをダウンロードして、./hymenoptera_data/{train or val}/{ants or bees}/ファイル名.jpgに展開します。アリかハチを見分けるタスクになっており、ファイルパスで区別されています。

import numpy as np
from PIL import Image
from pathlib import Path
import urllib.request
import zipfile
import random

import torch
import torch.utils.data as data
from torchvision import transforms

print(torch.__version__)  # 1.3.1

urllib.request.urlretrieve(
    "https://download.pytorch.org/tutorial/hymenoptera_data.zip", 
    "hymenoptera_data.zip"
)

with zipfile.ZipFile("./hymenoptera_data.zip") as zip:
    zip.extractall(".")
    
print(list(pathlib.Path("hymenoptera_data/val").glob("**/*.jpg"))[:4])
# [PosixPath('hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg'),
#  PosixPath('hymenoptera_data/val/bees/1486120850_490388f84b.jpg'),
#  PosixPath('hymenoptera_data/val/bees/2060668999_e11edb10d0.jpg'),
#  PosixPath('hymenoptera_data/val/bees/2104135106_a65eede1de.jpg')]

Dataset

Datasetはデータソースから1個ずつデータを取り出すクラスです。DataLoaderにはこのDataset継承クラスのオブジェクトを渡します。
(ストリーミングデータ向けに、イテレータ形式で取り出すためのIterableDatasetもあります。ここではDatasetクラスを中心に扱いたいと思います。)

このクラスを継承して実装します。学習と評価で、ロードするデータやオーグメンテーション有無などの要件が異なりますので、別のクラスに分けるか、あるいは、初期化時の引数などで切り替えられるようにしておくのが一般的です。

  • 実装が必要なメソッドは__getitem__と__len__の2つ
  • __getitem__は入力とGTをタプルにして返します
    • 入力は0次元目をチャネルに変換します
class MyDataset(data.Dataset):
    def __init__(self, dir_path, input_size):
        super().__init__()
        
        self.dir_path = dir_path
        self.input_size = input_size
        
        # hymenoptera_data/{train or val}/{ants or bees}/ファイル名.jpg
        # ex. hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg
        self.image_paths = [str(p) for p in Path(self.dir_path).glob("**/*.jpg")]
        self.len = len(self.image_paths)
        
    def __len__(self):
        return self.len
    
    def __getitem__(self, index):
        p = self.image_paths[index]
        
        # 入力
        image = Image.open(p)
        image = image.resize(self.input_size)
        image = np.array(image)
        image = np.transpose(image, (2, 0, 1))
        image = torch.from_numpy(image)
        
        # ラベル (0: ants, 1: bees)
        label = p.split("/")[2]
        label = 1 if label == "bees" else 0
        
        return image, label

train_dataset = MyDataset("./hymenoptera_data/train/", (224, 224))

image, label = train_dataset[0]
print(image.size(), label)  # torch.Size([3, 224, 224]) 1

torchvision.transforms

データを取り出すとともに、入力データに対する前処理が必要な場合もDataset内で実現することがしばしばあります。
上の実装でも、画像に対してリサイズ・インデックス入れ替え・Tensor化を行ってます。

画像の場合、torchvisionパッケージに含まれるtransformsを使うとより簡単に実装できます。

  • 学習と評価で必要な前処理が異なりますので、__init__にphaseを追加
  • Composeでまとめることで、順に一連の変換処理を実行する関数へとまとめられます
class MyDataset(data.Dataset):
    def __init__(self, dir_path, input_size, phase):
        super().__init__()
        
        # ... (省略) ...
        
        if phase == "train":
            transform_ops.extend([
                transforms.RandomVerticalFlip(),  # ランダムに左右反転
                transforms.RandomRotation(20),  # ±20度でランダムに回転
                transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0))  # ランダムにクロップしてリサイズ
            ])
        elif phase == "val":
            transform_ops.append(
                transforms.Resize(input_size)  # リサイズ
            )
            
        transform_ops.extend([
            transforms.ToTensor(),  # Tensor化
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 標準化
        ])
        
        # 1つの変換処理に集約
        self.transformer = transforms.Compose(transform_ops)
        
    def __len__(self):
        return self.len
    
    def __getitem__(self, index):
        p = self.image_paths[index]
        
        # 入力
        image = Image.open(p)
        image = self.transformer(image)  # __init__で定義した関数に置き換え
        
        # ラベル (0: ants, 1: bees)
        label = p.split("/")[2]
        label = 1 if label == "bees" else 0
        
        return image, label

DataLoader

DataLoaderはバッチサイズごとにDatasetからデータを取り出します。継承などは行わず、作成したDatasetオブジェクトとバッチサイズを渡してオブジェクト化します。

  • __iter__が定義されており、forループなどでバッチ単位で取り出すことができます

batch_size=16ですので、入力が (16, 3, 224, 224) 、ラベルも (16) のサイズになります。

  • shuffle=Trueで順番がランダムになりますので、学習時はTrue、評価時はFalseとするのが一般的です
  • num_workersで実行プロセス数を指定でき、2以上の値を設定すると並列化できます (デフォルトは0でメインプロセスでの実行になります)
  • drop_last=Trueとすると、末尾の方で残りデータ数がbatch_sizeに満たない場合にスキップされます
    • バッチ標準化などはデータ数が1だと計算できずエラーになるため (これも学習時のみのTrue)
train_dataloader = data.DataLoader(
    train_dataset, batch_size=16, shuffle=True,
    num_workers=2, drop_last=True
)

images, labels = next(iter(train_dataloader))
print(images.size())  # torch.Size([16, 3, 224, 224])
print(labels.size())  # torch.Size([16])

collate_fn

入力とラベルが1対1で紐付かないケースがあります。例えば物体検出では、1画像中に含まれるオブジェクトの数は、0個の場合もあれば2個以上の場合もあります。

DataLoaderは、デフォルトではtorch.stackで1つのTensorオブジェクトへ連結しようとします。ところが以下のlabelsのように、0次元目の次元数がデータによって異なる場合は、連結できずにエラーとなってしまいます。

class DummyDataset(data.Dataset):
    # 1ラベルの物体検出のデータセット
    # (ex. 複数人が写っている写真の中から人の位置を特定する、など)
    def __init__(self, size=100):
        super().__init__()
        
        self.len = size
        
    def __len__(self):
        return self.len
        
    def __getitem__(self, index):
        labels = []
        object_num = random.randint(0, 4)  # 0〜4個でランダム

        dummy_img = torch.randint(0, 255, (3, 224, 224))  # ダミー画像
        dummy_bbox = [1, 2, 3, 4]  # ダミーのバウンディングボックス
        
        for _ in range(object_num):
            # 物体数分のバウンディングボックス
            # [], ..., [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
            labels.append(dummy_bbox)
            
        labels = torch.tensor(labels)
            
        return dummy_img, labels

dummy_dataset = DummyDataset()
image, labels = dummy_dataset[10]
print(image.size(), labels.size())

dummy_dataloader = data.DataLoader(
    dummy_dataset, batch_size=4, shuffle=True,
    num_workers=2, drop_last=True
)

images, labels = next(iter(dummy_dataloader))
# 内部で実行しているtorch.stackでエラーになる

こうした特殊な変換が必要なケースでは、DataLoaderのcollate_fnへ自前の変換関数を渡すことで対応します。

collate_fnは第1引数 (以下ではbatch) にDatasetの返り値のリストが格納されています。DummyDatasetの場合は画像 (Tensor) とラベル (Tensorリスト) のタプルリストです。
以下のようにlabelsは1つのTensorに変換せず、Tensorリストのまま返すようにすることで、ラベルが可変長の場合でもバッチにまとめられるようにしています。

def collate_fn(batch):
    # bactchはDatasetの返り値 (タプル) のリスト
    images, labels = [], []
    
    for image, label in batch:
        images.append(image)
        labels.append(label)
        
    images = torch.stack(images, dim=0)

    # labelsはTensorリストのまま

    return images, labels

dummy_dataloader = data.DataLoader(
    dummy_dataset, batch_size=4, shuffle=True,
    num_workers=2, drop_last=True, collate_fn=collate_fn
)

images, labels = next(iter(dummy_dataloader))
print(images.size())  # torch.Size([4, 3, 224, 224])
print(labels)
# [tensor([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]),
#  tensor([[1, 2, 3, 4]]),
#  tensor([]),
#  tensor([[1, 2, 3, 4], [1, 2, 3, 4]])]

まとめ

画像処理タスクのケースを中心にPyTorchのDatasetとDataLoaderについてまとめました。

参考