引き続きPyTorchのお勉強です。
画像処理タスクの文脈でDatasetとDataLoaderの使い方を整理していきます。
DatasetとDataLoader
PyTorchに限らず、ディープラーニングタスクのデータの入力については、一般的に以下の要件が挙げられます
- データをミニバッチにまとめる
- 任意の前処理を実行する
- (CPU依存のため) 適切に並列処理させたい
PyTorchで上のニーズに応えるのがDatasetとDataLoaderです。Datasetが取り出したデータをDataLoaderがミニバッチにまとめるという関係になってます。いずれもtorch.utils.dataに属します
- Datasetは、データソース (ファイルやNumPy行列など) から1個ずつデータ (入力とground truth) を取り出す
- DataLoaderは、Datasetからバッチサイズずつデータを取り出す
- このミニバッチがモデルへの入力となる
データセットの準備
PyTorchチュートリアルの中で提供されている画像データセットをダウンロードして、./hymenoptera_data/{train or val}/{ants or bees}/ファイル名.jpg
に展開します。アリかハチを見分けるタスクになっており、ファイルパスで区別されています。
import numpy as np from PIL import Image from pathlib import Path import urllib.request import zipfile import random import torch import torch.utils.data as data from torchvision import transforms print(torch.__version__) # 1.3.1 urllib.request.urlretrieve( "https://download.pytorch.org/tutorial/hymenoptera_data.zip", "hymenoptera_data.zip" ) with zipfile.ZipFile("./hymenoptera_data.zip") as zip: zip.extractall(".") print(list(pathlib.Path("hymenoptera_data/val").glob("**/*.jpg"))[:4]) # [PosixPath('hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg'), # PosixPath('hymenoptera_data/val/bees/1486120850_490388f84b.jpg'), # PosixPath('hymenoptera_data/val/bees/2060668999_e11edb10d0.jpg'), # PosixPath('hymenoptera_data/val/bees/2104135106_a65eede1de.jpg')]
Dataset
Datasetはデータソースから1個ずつデータを取り出すクラスです。DataLoaderにはこのDataset継承クラスのオブジェクトを渡します。
(ストリーミングデータ向けに、イテレータ形式で取り出すためのIterableDatasetもあります。ここではDatasetクラスを中心に扱いたいと思います。)
このクラスを継承して実装します。学習と評価で、ロードするデータやオーグメンテーション有無などの要件が異なりますので、別のクラスに分けるか、あるいは、初期化時の引数などで切り替えられるようにしておくのが一般的です。
- 実装が必要なメソッドは__getitem__と__len__の2つ
- __getitem__は入力とGTをタプルにして返します
- 入力は0次元目をチャネルに変換します
class MyDataset(data.Dataset): def __init__(self, dir_path, input_size): super().__init__() self.dir_path = dir_path self.input_size = input_size # hymenoptera_data/{train or val}/{ants or bees}/ファイル名.jpg # ex. hymenoptera_data/val/bees/2478216347_535c8fe6d7.jpg self.image_paths = [str(p) for p in Path(self.dir_path).glob("**/*.jpg")] self.len = len(self.image_paths) def __len__(self): return self.len def __getitem__(self, index): p = self.image_paths[index] # 入力 image = Image.open(p) image = image.resize(self.input_size) image = np.array(image) image = np.transpose(image, (2, 0, 1)) image = torch.from_numpy(image) # ラベル (0: ants, 1: bees) label = p.split("/")[2] label = 1 if label == "bees" else 0 return image, label train_dataset = MyDataset("./hymenoptera_data/train/", (224, 224)) image, label = train_dataset[0] print(image.size(), label) # torch.Size([3, 224, 224]) 1
torchvision.transforms
データを取り出すとともに、入力データに対する前処理が必要な場合もDataset内で実現することがしばしばあります。
上の実装でも、画像に対してリサイズ・インデックス入れ替え・Tensor化を行ってます。
画像の場合、torchvisionパッケージに含まれるtransformsを使うとより簡単に実装できます。
- 学習と評価で必要な前処理が異なりますので、__init__にphaseを追加
- Composeでまとめることで、順に一連の変換処理を実行する関数へとまとめられます
class MyDataset(data.Dataset): def __init__(self, dir_path, input_size, phase): super().__init__() # ... (省略) ... if phase == "train": transform_ops.extend([ transforms.RandomVerticalFlip(), # ランダムに左右反転 transforms.RandomRotation(20), # ±20度でランダムに回転 transforms.RandomResizedCrop(input_size, scale=(0.8, 1.0)) # ランダムにクロップしてリサイズ ]) elif phase == "val": transform_ops.append( transforms.Resize(input_size) # リサイズ ) transform_ops.extend([ transforms.ToTensor(), # Tensor化 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 標準化 ]) # 1つの変換処理に集約 self.transformer = transforms.Compose(transform_ops) def __len__(self): return self.len def __getitem__(self, index): p = self.image_paths[index] # 入力 image = Image.open(p) image = self.transformer(image) # __init__で定義した関数に置き換え # ラベル (0: ants, 1: bees) label = p.split("/")[2] label = 1 if label == "bees" else 0 return image, label
DataLoader
DataLoaderはバッチサイズごとにDatasetからデータを取り出します。継承などは行わず、作成したDatasetオブジェクトとバッチサイズを渡してオブジェクト化します。
- __iter__が定義されており、forループなどでバッチ単位で取り出すことができます
batch_size=16ですので、入力が (16, 3, 224, 224) 、ラベルも (16) のサイズになります。
- shuffle=Trueで順番がランダムになりますので、学習時はTrue、評価時はFalseとするのが一般的です
- num_workersで実行プロセス数を指定でき、2以上の値を設定すると並列化できます (デフォルトは0でメインプロセスでの実行になります)
- drop_last=Trueとすると、末尾の方で残りデータ数がbatch_sizeに満たない場合にスキップされます
- バッチ標準化などはデータ数が1だと計算できずエラーになるため (これも学習時のみのTrue)
train_dataloader = data.DataLoader( train_dataset, batch_size=16, shuffle=True, num_workers=2, drop_last=True ) images, labels = next(iter(train_dataloader)) print(images.size()) # torch.Size([16, 3, 224, 224]) print(labels.size()) # torch.Size([16])
collate_fn
入力とラベルが1対1で紐付かないケースがあります。例えば物体検出では、1画像中に含まれるオブジェクトの数は、0個の場合もあれば2個以上の場合もあります。
DataLoaderは、デフォルトではtorch.stackで1つのTensorオブジェクトへ連結しようとします。ところが以下のlabelsのように、0次元目の次元数がデータによって異なる場合は、連結できずにエラーとなってしまいます。
class DummyDataset(data.Dataset): # 1ラベルの物体検出のデータセット # (ex. 複数人が写っている写真の中から人の位置を特定する、など) def __init__(self, size=100): super().__init__() self.len = size def __len__(self): return self.len def __getitem__(self, index): labels = [] object_num = random.randint(0, 4) # 0〜4個でランダム dummy_img = torch.randint(0, 255, (3, 224, 224)) # ダミー画像 dummy_bbox = [1, 2, 3, 4] # ダミーのバウンディングボックス for _ in range(object_num): # 物体数分のバウンディングボックス # [], ..., [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] labels.append(dummy_bbox) labels = torch.tensor(labels) return dummy_img, labels dummy_dataset = DummyDataset() image, labels = dummy_dataset[10] print(image.size(), labels.size()) dummy_dataloader = data.DataLoader( dummy_dataset, batch_size=4, shuffle=True, num_workers=2, drop_last=True ) images, labels = next(iter(dummy_dataloader)) # 内部で実行しているtorch.stackでエラーになる
こうした特殊な変換が必要なケースでは、DataLoaderのcollate_fnへ自前の変換関数を渡すことで対応します。
collate_fnは第1引数 (以下ではbatch) にDatasetの返り値のリストが格納されています。DummyDatasetの場合は画像 (Tensor) とラベル (Tensorリスト) のタプルリストです。
以下のようにlabelsは1つのTensorに変換せず、Tensorリストのまま返すようにすることで、ラベルが可変長の場合でもバッチにまとめられるようにしています。
def collate_fn(batch): # bactchはDatasetの返り値 (タプル) のリスト images, labels = [], [] for image, label in batch: images.append(image) labels.append(label) images = torch.stack(images, dim=0) # labelsはTensorリストのまま return images, labels dummy_dataloader = data.DataLoader( dummy_dataset, batch_size=4, shuffle=True, num_workers=2, drop_last=True, collate_fn=collate_fn ) images, labels = next(iter(dummy_dataloader)) print(images.size()) # torch.Size([4, 3, 224, 224]) print(labels) # [tensor([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]), # tensor([[1, 2, 3, 4]]), # tensor([]), # tensor([[1, 2, 3, 4], [1, 2, 3, 4]])]
まとめ
画像処理タスクのケースを中心にPyTorchのDatasetとDataLoaderについてまとめました。