発展アーキテクチャ

学習目標: StyleGAN・ViT-VAE・Score-based Diffusion など、生成AIの発展形アーキテクチャを理解する

StyleGAN — スタイル制御可能な GAN

革新的な特徴

  • マッピングネットワーク: ノイズ z を中間潜在空間 w にマップ。w はもつれが少なく属性操作しやすい
  • AdaIN (Adaptive Instance Normalization): 各解像度に w から得たスタイルを注入
  • 段階的解像度上昇: 4×4 → 8×8 → ... → 1024×1024 へと徐々に細部を生成
  • Stochastic Variation: 各層にノイズを注入し、髪の毛・皮膚の質感などの細かな差を表現

アーキテクチャ概要

       z (latent) ────► Mapping Network (8層MLP) ────► w
                                                       │
                                                       ▼ (各層に分配)
       const (4×4) ──► [Conv + AdaIN + Noise]  ──► 4×4
                       [Conv + AdaIN + Noise]  ──► 8×8 (upscale)
                       [Conv + AdaIN + Noise]  ──► 16×16
                              ...
                       [Conv + AdaIN + Noise]  ──► 1024×1024
                                                       │
                                                       ▼
                                                  生成画像

PyTorch 実装スケッチ (Mapping Network + Synthesis Block)

class MappingNetwork(nn.Module):
    """z ∈ Z → w ∈ W への 8層 MLP"""
    def __init__(self, z_dim=512, w_dim=512, n_layers=8):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers += [nn.Linear(z_dim, w_dim),
                       nn.LeakyReLU(0.2, inplace=True)]
            z_dim = w_dim
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        z = F.normalize(z, dim=1)   # PixelNorm 相当
        return self.net(z)


class AdaIN(nn.Module):
    """Adaptive Instance Normalization: w からスタイルを生成して注入"""
    def __init__(self, w_dim, channels):
        super().__init__()
        self.norm  = nn.InstanceNorm2d(channels)
        self.style = nn.Linear(w_dim, channels * 2)  # スケール+バイアス

    def forward(self, x, w):
        gamma, beta = self.style(w).chunk(2, dim=1)
        gamma = gamma.unsqueeze(2).unsqueeze(3)
        beta  = beta.unsqueeze(2).unsqueeze(3)
        return gamma * self.norm(x) + beta


class StyleGANBlock(nn.Module):
    """1解像度ぶんの生成ブロック"""
    def __init__(self, in_ch, out_ch, w_dim):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.adain = AdaIN(w_dim, out_ch)
        self.noise_scale = nn.Parameter(torch.zeros(1, out_ch, 1, 1))
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x, w):
        x = self.conv(x)
        noise = torch.randn(x.size(0), 1, x.size(2), x.size(3), device=x.device)
        x = x + self.noise_scale * noise
        x = self.adain(x, w)
        return self.act(x)

主な使い方テクニック

テクニック説明
Truncation Trick w' = w_mean + ψ * (w - w_mean) で多様性↔品質をトレードオフ (ψ=0.5〜1.0)
Style Mixing 低解像度層と高解像度層で異なる w を使い、属性を混ぜる(顔の輪郭A + 髪型B)
Latent Walk w 空間で2点間を補間し、なめらかな変化を生成

Vision Transformer VAE

ViT-VAE の特徴

  • パッチベース処理: 画像を 16×16 のパッチに分割してトークン化
  • Self-Attention: パッチ間の関係を捉える(CNNの局所性に縛られない)
  • グローバルな依存性: 画像全体の構造を1層で把握できる
  • 解釈性: アテンションマップで「どこに注目したか」が可視化可能

アーキテクチャ

画像 (H, W, 3)
   │
   ▼ パッチ分割 (P×P)
パッチ列 (N, P*P*3)   N = (H/P) * (W/P)
   │
   ▼ Linear projection
パッチ埋め込み (N, D) + 位置埋め込み
   │
   ▼ Transformer Encoder × L
コンテキスト化された埋め込み (N, D)
   │
   ▼ Aggregator (CLS or mean)
   │
   ▼ μ, log σ² の予測
潜在 z = μ + σ * ε
   │
   ▼ Transformer Decoder
再構成パッチ列
   │
   ▼ パッチ再構成
復元画像 (H, W, 3)

PyTorch 実装スケッチ

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2
        # Conv2d で簡潔にパッチ化+埋め込み
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.pos = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)  # (B, N, D)
        return x + self.pos


class ViTVAE(nn.Module):
    def __init__(self, img_size=32, patch_size=4, embed_dim=128,
                 n_heads=4, n_layers=6, latent_dim=64):
        super().__init__()
        self.embed = PatchEmbedding(img_size, patch_size, 3, embed_dim)

        enc_layer = nn.TransformerEncoderLayer(embed_dim, n_heads,
                                               batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, n_layers)
        self.to_latent = nn.Linear(embed_dim, latent_dim * 2)

        self.from_latent = nn.Linear(latent_dim, embed_dim * self.embed.n_patches)
        dec_layer = nn.TransformerEncoderLayer(embed_dim, n_heads,
                                               batch_first=True)
        self.decoder = nn.TransformerEncoder(dec_layer, n_layers)
        self.to_pixels = nn.Linear(embed_dim, 3 * patch_size * patch_size)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def forward(self, x):
        h = self.embed(x)                         # (B, N, D)
        h = self.encoder(h).mean(dim=1)           # (B, D)
        mu, logvar = self.to_latent(h).chunk(2, dim=1)
        z = self.reparameterize(mu, logvar)
        # decode...
        return mu, logvar, z

アテンション可視化

ViT-VAE はアテンション重みを取り出すことで「どのパッチがどのパッチに注目したか」をヒートマップで可視化できます。 CNNにはない強力な解釈性です。

# 最終層のアテンション重みを取り出す
def get_attention_map(model, x):
    # forward_pre_hook で attention を保存
    saved = {}
    def hook(module, input, output):
        # output[1] が attention weights (need_weights=True 時)
        saved['attn'] = output[1]

    handle = model.encoder.layers[-1].self_attn.register_forward_hook(hook)
    _ = model(x)
    handle.remove()
    return saved['attn']  # (B, n_heads, N, N)

Score-based Diffusion Models

特徴

  • スコア関数の推定: ∇_x log p(x) を学習する
  • 連続時間表現 (SDE): 拡散プロセスを確率微分方程式として定式化
  • 柔軟なサンプリング: 数値ソルバーで step 数や軌道を自由に調整
  • DDPMとの統一: Song et al. (2021) で連続時間の SDE 視点に統合

📐 理論的背景

順過程 (Forward SDE):  dx = f(x, t) dt + g(t) dW
        ノイズを少しずつ加える

逆過程 (Reverse SDE):  dx = [f(x, t) - g(t)² ∇_x log p_t(x)] dt + g(t) dŴ
                            ↑ スコア関数(モデルが学習する対象)

ニューラルネットワーク s_θ(x, t) を ∇_x log p_t(x) に近づける:

L = E_t,x [λ(t) || s_θ(x_t, t) - ∇_x log p_t(x_t) ||²]
        ↑ Denoising Score Matching で学習可能

ノイズスケジュールの比較

SDEf(x, t)g(t)特徴
VE (Variance Exploding) 0√(d[σ²(t)]/dt) 分散発散型。NCSN系
VP (Variance Preserving) -β(t)/2 · x√β(t) 分散保存型。DDPM系
sub-VP -β(t)/2 · x√(β(t)·(1-e^{-2∫β})) 分散制限版。安定

実装のポイント

# Denoising Score Matching ロス
def dsm_loss(score_net, x, sde, t_min=1e-5, t_max=1.0):
    """Song et al. 2021 のシンプルなDSM"""
    # ランダムなタイムステップ
    t = torch.rand(x.size(0), device=x.device) * (t_max - t_min) + t_min

    # SDE から x_t, perturbation kernel の mean/std を取得
    mean, std = sde.marginal_prob(x, t)
    eps = torch.randn_like(x)
    x_t = mean + std[:, None, None, None] * eps

    # スコア推定: -ε / σ が目標
    score = score_net(x_t, t)
    target = -eps / std[:, None, None, None]

    loss = ((score - target) ** 2 * std[:, None, None, None] ** 2).mean()
    return loss


# Predictor-Corrector サンプリング
@torch.no_grad()
def pc_sample(score_net, sde, shape, n_steps=1000, snr=0.16):
    x = sde.prior_sampling(shape).to(device)
    for i in range(n_steps):
        t = torch.full((shape[0],), 1.0 - i / n_steps, device=device)
        # Predictor (Euler-Maruyama)
        drift, diffusion = sde.reverse_sde(score_net, x, t)
        dt = -1.0 / n_steps
        x = x + drift * dt + diffusion[:, None, None, None] * \
            torch.randn_like(x) * (-dt) ** 0.5
        # Corrector (Langevin step)
        grad = score_net(x, t)
        grad_norm = grad.flatten(1).norm(dim=1).mean()
        noise_norm = math.sqrt(x.numel() / x.size(0))
        step_size = (snr * noise_norm / grad_norm) ** 2 * 2
        x = x + step_size * grad + (2 * step_size) ** 0.5 * torch.randn_like(x)
    return x
参考論文: Song, Sohl-Dickstein et al. Score-Based Generative Modeling through Stochastic Differential Equations (ICLR 2021)