VAE(変分オートエンコーダ)
学習目標: VAEの確率的な潜在空間と再パラメータ化トリックを理解する
VAEとオートエンコーダの違い
通常のオートエンコーダ
潜在表現zは決定的な値
z = encoder(x)
潜在空間に「穴」ができやすく、補間が不自然になることがある
VAE
潜在表現zは確率分布からサンプリング
μ, σ = encoder(x)
z ~ N(μ, σ²)
潜在空間が連続的で滑らかになり、生成に適する
再パラメータ化トリック
サンプリング操作は微分できないため、勾配が流れない問題があります。
解決策: サンプリングを以下のように書き換えます:
z = μ + σ × ε, where ε ~ N(0, 1)
εは学習に関係しない標準正規分布からのサンプル
PyTorch実装
class VAE(nn.Module):
def __init__(self, input_dim=784, latent_dim=20):
super().__init__()
# エンコーダ
self.encoder = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU()
)
self.fc_mu = nn.Linear(256, latent_dim) # 平均
self.fc_logvar = nn.Linear(256, latent_dim) # 対数分散
# デコーダ
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid()
)
def encode(self, x):
h = self.encoder(x)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
return self.decoder(z)
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
# 損失関数
def vae_loss(recon_x, x, mu, logvar):
# 再構成誤差
recon_loss = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
# KLダイバージェンス
kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return recon_loss + kl_loss
損失関数の意味
Loss = 再構成誤差 + β × KLダイバージェンス
再構成誤差
入力と出力の差。小さいほど再構成が正確。
KLダイバージェンス
潜在分布が標準正規分布にどれだけ近いか。潜在空間を整理する正則化項。
理解度チェック
Q. 再パラメータ化トリックの目的は?