VAE(変分オートエンコーダ)

学習目標: VAEの確率的な潜在空間と再パラメータ化トリックを理解する

VAEとオートエンコーダの違い

通常のオートエンコーダ

潜在表現zは決定的な値

z = encoder(x)

潜在空間に「穴」ができやすく、補間が不自然になることがある

VAE

潜在表現zは確率分布からサンプリング

μ, σ = encoder(x)
z ~ N(μ, σ²)

潜在空間が連続的で滑らかになり、生成に適する

再パラメータ化トリック

サンプリング操作は微分できないため、勾配が流れない問題があります。

解決策: サンプリングを以下のように書き換えます:

z = μ + σ × ε, where ε ~ N(0, 1)

εは学習に関係しない標準正規分布からのサンプル

PyTorch実装

class VAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=20):
        super().__init__()

        # エンコーダ
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)     # 平均
        self.fc_logvar = nn.Linear(256, latent_dim)  # 対数分散

        # デコーダ
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 損失関数
def vae_loss(recon_x, x, mu, logvar):
    # 再構成誤差
    recon_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # KLダイバージェンス
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

損失関数の意味

Loss = 再構成誤差 + β × KLダイバージェンス

再構成誤差

入力と出力の差。小さいほど再構成が正確。

KLダイバージェンス

潜在分布が標準正規分布にどれだけ近いか。潜在空間を整理する正則化項。

理解度チェック

Q. 再パラメータ化トリックの目的は?