データパイプライン

学習目標: Dataset / DataLoader / Transform / Augmentation の組み合わせを理解し、効率的な学習データ供給パイプラインを構築できる

Dataset の3パターン

1. 既存データセット

from torchvision import datasets, transforms

train_ds = datasets.CIFAR10(
    root='./data', train=True, download=True,
    transform=transforms.ToTensor(),
)
print(f"size: {len(train_ds)}, sample: {train_ds[0][0].shape}")

2. Map-style Dataset(自作)

from torch.utils.data import Dataset
from PIL import Image
import os

class ImageFolderDataset(Dataset):
    """フォルダ名がクラスラベル"""
    def __init__(self, root, transform=None):
        self.transform = transform
        self.samples = []
        self.classes = sorted(os.listdir(root))
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
        for cls in self.classes:
            for fname in os.listdir(os.path.join(root, cls)):
                self.samples.append((os.path.join(root, cls, fname),
                                     self.class_to_idx[cls]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        img = Image.open(path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label

3. Iterable Dataset(ストリーミング)

サイズが事前に分からない、または巨大なデータに使う。
例: TFRecordのようなshardから逐次読み込み、Webからストリーミング。

from torch.utils.data import IterableDataset

class StreamingDataset(IterableDataset):
    def __init__(self, shard_paths):
        self.shard_paths = shard_paths

    def __iter__(self):
        worker = torch.utils.data.get_worker_info()
        shards = self.shard_paths
        if worker is not None:
            # 各workerに均等に分配
            shards = shards[worker.id::worker.num_workers]

        for path in shards:
            for sample in read_shard(path):
                yield sample

DataLoader の重要オプション

from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_ds,
    batch_size=128,
    shuffle=True,            # 学習時はTrue(順序バイアス防止)
    num_workers=4,           # 並列ロード(CPUコア数まで)
    pin_memory=True,         # GPU転送高速化
    persistent_workers=True, # epoch間でworkerを再起動しない
    prefetch_factor=2,       # workerが先読みするバッチ数
    drop_last=True,          # BatchNorm/DDPで端数バッチを避ける
    collate_fn=None,         # カスタムバッチ化が必要なら指定
)

カスタム collate_fn

可変長サンプルのpadding、辞書バッチ化などに使う。

def pad_collate(batch):
    """可変長シーケンスを最長に合わせてpadding"""
    sequences, labels = zip(*batch)
    lengths = torch.tensor([s.size(0) for s in sequences])
    padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
    return padded, lengths, torch.tensor(labels)


loader = DataLoader(seq_dataset, batch_size=32, collate_fn=pad_collate)

WeightedRandomSampler — 不均衡データ対応

from torch.utils.data import WeightedRandomSampler

# クラス頻度の逆数を重みに
class_counts = torch.bincount(torch.tensor([y for _, y in dataset]))
class_weights = 1.0 / class_counts.float()
sample_weights = [class_weights[y] for _, y in dataset]

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(dataset),
    replacement=True,
)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)

データ拡張 (Augmentation)

torchvision.transforms (v2)

from torchvision.transforms import v2

train_tf = v2.Compose([
    v2.RandomResizedCrop(224, scale=(0.8, 1.0)),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    v2.RandAugment(num_ops=2, magnitude=9),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
    v2.RandomErasing(p=0.25),
])

val_tf = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.485, 0.456, 0.406],
                 std=[0.229, 0.224, 0.225]),
])

MixUp / CutMix(高度な拡張)

def mixup(x, y, alpha=0.2, n_classes=10):
    """2サンプルを混合"""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))

    x_mixed = lam * x + (1 - lam) * x[perm]
    y_a = F.one_hot(y, n_classes).float()
    y_b = F.one_hot(y[perm], n_classes).float()
    y_mixed = lam * y_a + (1 - lam) * y_b

    return x_mixed, y_mixed


def cutmix(x, y, alpha=1.0, n_classes=10):
    """画像の矩形領域を別画像と入れ替え"""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(x.size(0))
    _, _, H, W = x.shape
    cut_h, cut_w = int(H * (1 - lam)**0.5), int(W * (1 - lam)**0.5)
    cy, cx = torch.randint(0, H, (1,)).item(), torch.randint(0, W, (1,)).item()
    y1, y2 = max(0, cy - cut_h // 2), min(H, cy + cut_h // 2)
    x1, x2 = max(0, cx - cut_w // 2), min(W, cx + cut_w // 2)

    x_new = x.clone()
    x_new[:, :, y1:y2, x1:x2] = x[perm, :, y1:y2, x1:x2]
    lam = 1 - ((y2 - y1) * (x2 - x1) / (H * W))

    y_a = F.one_hot(y, n_classes).float()
    y_b = F.one_hot(y[perm], n_classes).float()
    return x_new, lam * y_a + (1 - lam) * y_b

テキスト・時系列の拡張

  • テキスト: BackTranslation, 同義語置換, Cutout, EDA
  • 時系列: Jittering, Scaling, Time Warping, Window Slicing
  • 音声: SpecAugment(時間/周波数マスキング)

パイプラインのボトルネック対策

よくあるボトルネック

症状原因対策
GPU使用率が低いI/O待ち / 前処理が重いnum_workers増、prefetchと pin_memory
メモリリークWorkerの再生成persistent_workers=True
Epoch開始でフリーズworkerフォーク重いpersistent_workers + 軽量Dataset
Augmentation がCPU boundPIL / numpy遅いGPU augmentation (kornia) or v2 Compose (TVテンソル)
ディスクI/O遅い大量小ファイルHDF5 / LMDB / WebDatasetでまとめる

速度プロファイリング

import time

# 1. Dataset単体の速度
start = time.time()
for i in range(100):
    _ = dataset[i]
print(f"Dataset: {(time.time() - start)/100*1000:.1f} ms/sample")

# 2. DataLoader全体
start = time.time()
n = 0
for batch in loader:
    n += batch[0].size(0)
    if n >= 1000: break
print(f"DataLoader: {n/(time.time()-start):.1f} samples/sec")

GPU側 augmentation (kornia)

import kornia.augmentation as K

gpu_aug = K.AugmentationSequential(
    K.RandomHorizontalFlip(),
    K.RandomCrop((224, 224)),
    K.ColorJitter(0.4, 0.4, 0.4),
    K.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
).cuda()

# 学習ループで適用
for x, y in loader:
    x = x.cuda(non_blocking=True)
    x = gpu_aug(x)               # GPU上で拡張
    output = model(x)