実装パターン集

学習目標: VAE / GAN / Denoising AE / 完全な学習パイプラインの PyTorch 実装パターンを押さえる

🔄 シンプルVAE

モデル定義

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super().__init__()
        # Encoder
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu     = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # Decoder
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

損失関数(ELBO)

def vae_loss(recon_x, x, mu, logvar):
    # 再構成損失 (ベルヌーイ尤度)
    bce = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')
    # KLダイバージェンス: -0.5 * Σ(1 + log σ² - μ² - σ²)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

学習ループ

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(1, num_epochs + 1):
    model.train()
    total = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(data)
        loss = vae_loss(recon, data, mu, logvar)
        loss.backward()
        optimizer.step()
        total += loss.item()
    print(f"Epoch {epoch}: avg loss = {total / len(train_loader.dataset):.4f}")

生成(潜在空間からサンプリング)

model.eval()
with torch.no_grad():
    z = torch.randn(64, 20).to(device)
    samples = model.decode(z).view(-1, 1, 28, 28).cpu()
    # samples を画像として保存

⚔️ 基本GAN

Generator と Discriminator

class Generator(nn.Module):
    def __init__(self, latent_dim=100, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.LeakyReLU(0.2),
            nn.Linear(256, 512),         nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),        nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),    nn.Tanh(),
        )
    def forward(self, z):
        return self.net(z)


class Discriminator(nn.Module):
    def __init__(self, img_dim=784):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(img_dim, 1024), nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(1024, 512),     nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(512, 256),      nn.LeakyReLU(0.2), nn.Dropout(0.3),
            nn.Linear(256, 1),        nn.Sigmoid(),
        )
    def forward(self, x):
        return self.net(x)

学習ループ(ミニマックスゲーム)

G = Generator().to(device)
D = Discriminator().to(device)
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()

for epoch in range(num_epochs):
    for real, _ in train_loader:
        N = real.size(0)
        real = real.view(N, -1).to(device)
        ones, zeros = torch.ones(N, 1).to(device), torch.zeros(N, 1).to(device)

        # === Discriminator update ===
        opt_d.zero_grad()
        d_real = D(real)
        z = torch.randn(N, 100).to(device)
        fake = G(z).detach()           # Generator は更新しない
        d_fake = D(fake)
        d_loss = bce(d_real, ones) + bce(d_fake, zeros)
        d_loss.backward()
        opt_d.step()

        # === Generator update ===
        opt_g.zero_grad()
        z = torch.randn(N, 100).to(device)
        fake = G(z)
        g_loss = bce(D(fake), ones)    # 「本物」と騙したい
        g_loss.backward()
        opt_g.step()

学習の安定化テクニック

テクニック効果
Label Smoothing本物ラベルを 1.0 → 0.9 にして D の過信を抑える
WGAN-GPBCE の代わりに Wasserstein 距離 + 勾配ペナルティ
Spectral NormD の各層に SN を入れて Lipschitz 制約
TTURG と D で異なる学習率(D側を大きく)

🔧 デノイジングオートエンコーダ

入力にノイズを加え、ノイズ前の画像を復元するように学習。表現学習・前処理に有用。

モデル

class DenoisingAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 2, stride=2), nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 2, stride=2), nn.Sigmoid(),
        )

    def forward(self, x):
        return self.decoder(self.encoder(x))

ノイズの種類

def add_gaussian_noise(x, std=0.1):
    return torch.clamp(x + std * torch.randn_like(x), 0., 1.)

def add_salt_pepper_noise(x, prob=0.05):
    mask = torch.rand_like(x)
    x = x.clone()
    x[mask < prob / 2] = 0.0      # pepper
    x[mask > 1 - prob / 2] = 1.0  # salt
    return x

def add_blockout_noise(x, n_blocks=3, size=5):
    x = x.clone()
    for _ in range(n_blocks):
        i = torch.randint(0, x.size(-2) - size, (1,)).item()
        j = torch.randint(0, x.size(-1) - size, (1,)).item()
        x[..., i:i+size, j:j+size] = 0.0
    return x

学習ループ

model = DenoisingAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

for epoch in range(num_epochs):
    for clean, _ in train_loader:
        clean = clean.to(device)
        noisy = add_gaussian_noise(clean, std=0.2)

        optimizer.zero_grad()
        reconstructed = model(noisy)
        loss = criterion(reconstructed, clean)   # ターゲットはノイズ前
        loss.backward()
        optimizer.step()

応用例

  • 古い写真の修復 — 傷・ノイズを取り除く
  • 低光量画像の改善 — センサーノイズ除去
  • 医療画像の処理 — MRI/CT のアーチファクト除去
  • 事前学習 — ノイズ除去を補助タスクにして、エンコーダを特徴抽出器として再利用

🚀 完全な学習パイプライン

VAE/GAN/Diffusion 共通で使える、本番運用を意識した骨組みです。

1. 設定 (YAML)

# config.yaml
model:
  name: vae
  latent_dim: 64
  hidden_dim: 400

train:
  batch_size: 128
  epochs: 100
  lr: 1.0e-3
  scheduler: cosine
  warmup_epochs: 5

data:
  dataset: mnist
  augmentation: [random_horizontal_flip, normalize]

logging:
  tensorboard: true
  log_every: 100
  checkpoint_every: 5

2. データ準備(DataLoader + 拡張)

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_tfm = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

train_ds = datasets.MNIST('data', train=True, transform=train_tfm, download=True)
val_ds   = datasets.MNIST('data', train=False, transform=transforms.ToTensor())

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_ds, batch_size=256, shuffle=False)

3. 学習ループ(チェックポイント・スケジューラ・ロギング)

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/vae_exp1')
scaler = torch.cuda.amp.GradScaler()           # FP16 学習
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

global_step = 0
best_val = float('inf')

for epoch in range(epochs):
    model.train()
    for x, _ in train_loader:
        x = x.to(device, non_blocking=True)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            recon, mu, logvar = model(x)
            loss = vae_loss(recon, x, mu, logvar)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        if global_step % 100 == 0:
            writer.add_scalar('train/loss', loss.item(), global_step)
        global_step += 1

    scheduler.step()

    # 検証
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for x, _ in val_loader:
            x = x.to(device)
            recon, mu, logvar = model(x)
            val_loss += vae_loss(recon, x, mu, logvar).item()
    val_loss /= len(val_loader.dataset)
    writer.add_scalar('val/loss', val_loss, epoch)

    # ベストモデル保存
    if val_loss < best_val:
        best_val = val_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, 'best.pth')

writer.close()

4. 評価とエクスポート

# ベストモデルを復元
ckpt = torch.load('best.pth')
model.load_state_dict(ckpt['model_state_dict'])
model.eval()

# 評価(再構成品質)
with torch.no_grad():
    psnr_total, n = 0.0, 0
    for x, _ in val_loader:
        x = x.to(device)
        recon, _, _ = model(x)
        mse = F.mse_loss(recon, x, reduction='none').mean(dim=(1,2,3))
        psnr = -10 * torch.log10(mse).mean().item()
        psnr_total += psnr * x.size(0); n += x.size(0)
    print(f"Validation PSNR: {psnr_total/n:.2f} dB")

# ONNXエクスポート
dummy = torch.randn(1, 1, 28, 28).to(device)
torch.onnx.export(model, dummy, 'vae.onnx', opset_version=17)

5. ベストプラクティス チェックリスト

  • ✅ ハイパーパラメータは YAML / Hydra で管理(コード変更なしで実験切替)
  • SEED を固定して再現性確保
  • ✅ TensorBoard / W&B でメトリクスを記録
  • ✅ Gradient clipping で発散を防ぐ
  • ✅ Mixed Precision (AMP) で速度2倍・VRAM削減
  • ✅ ベストモデルとラストモデルの両方を保存
  • ✅ 検証ロスが N エポック改善しなければ early stopping
  • ✅ 学習中の生成サンプルを可視化(劣化に早く気付ける)