GAN基礎
学習目標: GANの敵対的学習の仕組みを理解する
GANとは
GAN(Generative Adversarial Network)は、2つのネットワーク(生成器と識別器)が 互いに競い合いながら学習する生成モデルです。
生成器 G
偽物を作る
敵対的
識別器 D
本物と偽物を見分ける
直感的な理解:偽札作りのたとえ
生成器 G(偽札作り)
目標: 識別器を騙せる偽札を作る
- 最初は下手な偽札を作る
- 識別器に見破られる
- 徐々に精巧な偽札を作れるようになる
識別器 D(警察)
目標: 本物と偽物を見分ける
- 本物の札と偽札を比較
- 偽札の特徴を学習
- より巧妙な偽札も見破れるようになる
PyTorch実装
import torch
import torch.nn as nn
# 生成器
class Generator(nn.Module):
def __init__(self, latent_dim=100, output_dim=784):
super().__init__()
self.model = nn.Sequential(
nn.Linear(latent_dim, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, output_dim),
nn.Tanh()
)
def forward(self, z):
return self.model(z)
# 識別器
class Discriminator(nn.Module):
def __init__(self, input_dim=784):
super().__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
return self.model(x)
# 学習
G = Generator()
D = Discriminator()
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002)
for epoch in range(epochs):
for real_data in dataloader:
batch_size = real_data.size(0)
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
# 識別器の学習
optimizer_D.zero_grad()
z = torch.randn(batch_size, 100)
fake_data = G(z)
d_loss = criterion(D(real_data), real_labels) + criterion(D(fake_data.detach()), fake_labels)
d_loss.backward()
optimizer_D.step()
# 生成器の学習
optimizer_G.zero_grad()
g_loss = criterion(D(fake_data), real_labels) # 識別器を騙したい
g_loss.backward()
optimizer_G.step()
GANの課題
モード崩壊
生成器が限られたパターンしか生成しなくなる問題
学習の不安定性
GとDのバランスが崩れると学習が失敗する
評価の難しさ
生成品質の定量的な評価が難しい
理解度チェック
Q. 生成器の学習時、損失関数は何を最大化しようとしますか?