データパイプライン
学習目標: Dataset / DataLoader / Transform / Augmentation の組み合わせを理解し、効率的な学習データ供給パイプラインを構築できる
Dataset の3パターン
1. 既存データセット
from torchvision import datasets, transforms
train_ds = datasets.CIFAR10(
root='./data', train=True, download=True,
transform=transforms.ToTensor(),
)
print(f"size: {len(train_ds)}, sample: {train_ds[0][0].shape}")
2. Map-style Dataset(自作)
from torch.utils.data import Dataset
from PIL import Image
import os
class ImageFolderDataset(Dataset):
"""フォルダ名がクラスラベル"""
def __init__(self, root, transform=None):
self.transform = transform
self.samples = []
self.classes = sorted(os.listdir(root))
self.class_to_idx = {c: i for i, c in enumerate(self.classes)}
for cls in self.classes:
for fname in os.listdir(os.path.join(root, cls)):
self.samples.append((os.path.join(root, cls, fname),
self.class_to_idx[cls]))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
path, label = self.samples[idx]
img = Image.open(path).convert('RGB')
if self.transform:
img = self.transform(img)
return img, label
3. Iterable Dataset(ストリーミング)
サイズが事前に分からない、または巨大なデータに使う。
例: TFRecordのようなshardから逐次読み込み、Webからストリーミング。
from torch.utils.data import IterableDataset
class StreamingDataset(IterableDataset):
def __init__(self, shard_paths):
self.shard_paths = shard_paths
def __iter__(self):
worker = torch.utils.data.get_worker_info()
shards = self.shard_paths
if worker is not None:
# 各workerに均等に分配
shards = shards[worker.id::worker.num_workers]
for path in shards:
for sample in read_shard(path):
yield sample
DataLoader の重要オプション
from torch.utils.data import DataLoader
train_loader = DataLoader(
train_ds,
batch_size=128,
shuffle=True, # 学習時はTrue(順序バイアス防止)
num_workers=4, # 並列ロード(CPUコア数まで)
pin_memory=True, # GPU転送高速化
persistent_workers=True, # epoch間でworkerを再起動しない
prefetch_factor=2, # workerが先読みするバッチ数
drop_last=True, # BatchNorm/DDPで端数バッチを避ける
collate_fn=None, # カスタムバッチ化が必要なら指定
)
カスタム collate_fn
可変長サンプルのpadding、辞書バッチ化などに使う。
def pad_collate(batch):
"""可変長シーケンスを最長に合わせてpadding"""
sequences, labels = zip(*batch)
lengths = torch.tensor([s.size(0) for s in sequences])
padded = nn.utils.rnn.pad_sequence(sequences, batch_first=True)
return padded, lengths, torch.tensor(labels)
loader = DataLoader(seq_dataset, batch_size=32, collate_fn=pad_collate)
WeightedRandomSampler — 不均衡データ対応
from torch.utils.data import WeightedRandomSampler
# クラス頻度の逆数を重みに
class_counts = torch.bincount(torch.tensor([y for _, y in dataset]))
class_weights = 1.0 / class_counts.float()
sample_weights = [class_weights[y] for _, y in dataset]
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(dataset),
replacement=True,
)
loader = DataLoader(dataset, batch_size=64, sampler=sampler)
データ拡張 (Augmentation)
torchvision.transforms (v2)
from torchvision.transforms import v2
train_tf = v2.Compose([
v2.RandomResizedCrop(224, scale=(0.8, 1.0)),
v2.RandomHorizontalFlip(p=0.5),
v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
v2.RandAugment(num_ops=2, magnitude=9),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
v2.RandomErasing(p=0.25),
])
val_tf = v2.Compose([
v2.Resize(256),
v2.CenterCrop(224),
v2.ToImage(),
v2.ToDtype(torch.float32, scale=True),
v2.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
MixUp / CutMix(高度な拡張)
def mixup(x, y, alpha=0.2, n_classes=10):
"""2サンプルを混合"""
lam = torch.distributions.Beta(alpha, alpha).sample().item()
perm = torch.randperm(x.size(0))
x_mixed = lam * x + (1 - lam) * x[perm]
y_a = F.one_hot(y, n_classes).float()
y_b = F.one_hot(y[perm], n_classes).float()
y_mixed = lam * y_a + (1 - lam) * y_b
return x_mixed, y_mixed
def cutmix(x, y, alpha=1.0, n_classes=10):
"""画像の矩形領域を別画像と入れ替え"""
lam = torch.distributions.Beta(alpha, alpha).sample().item()
perm = torch.randperm(x.size(0))
_, _, H, W = x.shape
cut_h, cut_w = int(H * (1 - lam)**0.5), int(W * (1 - lam)**0.5)
cy, cx = torch.randint(0, H, (1,)).item(), torch.randint(0, W, (1,)).item()
y1, y2 = max(0, cy - cut_h // 2), min(H, cy + cut_h // 2)
x1, x2 = max(0, cx - cut_w // 2), min(W, cx + cut_w // 2)
x_new = x.clone()
x_new[:, :, y1:y2, x1:x2] = x[perm, :, y1:y2, x1:x2]
lam = 1 - ((y2 - y1) * (x2 - x1) / (H * W))
y_a = F.one_hot(y, n_classes).float()
y_b = F.one_hot(y[perm], n_classes).float()
return x_new, lam * y_a + (1 - lam) * y_b
テキスト・時系列の拡張
- テキスト: BackTranslation, 同義語置換, Cutout, EDA
- 時系列: Jittering, Scaling, Time Warping, Window Slicing
- 音声: SpecAugment(時間/周波数マスキング)
パイプラインのボトルネック対策
よくあるボトルネック
| 症状 | 原因 | 対策 |
|---|---|---|
| GPU使用率が低い | I/O待ち / 前処理が重い | num_workers増、prefetchと pin_memory |
| メモリリーク | Workerの再生成 | persistent_workers=True |
| Epoch開始でフリーズ | workerフォーク重い | persistent_workers + 軽量Dataset |
| Augmentation がCPU bound | PIL / numpy遅い | GPU augmentation (kornia) or v2 Compose (TVテンソル) |
| ディスクI/O遅い | 大量小ファイル | HDF5 / LMDB / WebDatasetでまとめる |
速度プロファイリング
import time
# 1. Dataset単体の速度
start = time.time()
for i in range(100):
_ = dataset[i]
print(f"Dataset: {(time.time() - start)/100*1000:.1f} ms/sample")
# 2. DataLoader全体
start = time.time()
n = 0
for batch in loader:
n += batch[0].size(0)
if n >= 1000: break
print(f"DataLoader: {n/(time.time()-start):.1f} samples/sec")
GPU側 augmentation (kornia)
import kornia.augmentation as K
gpu_aug = K.AugmentationSequential(
K.RandomHorizontalFlip(),
K.RandomCrop((224, 224)),
K.ColorJitter(0.4, 0.4, 0.4),
K.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
).cuda()
# 学習ループで適用
for x, y in loader:
x = x.cuda(non_blocking=True)
x = gpu_aug(x) # GPU上で拡張
output = model(x)