GAN基礎

学習目標: GANの敵対的学習の仕組みを理解する

GANとは

GAN(Generative Adversarial Network)は、2つのネットワーク(生成器と識別器)が 互いに競い合いながら学習する生成モデルです。

生成器 G

偽物を作る

敵対的

識別器 D

本物と偽物を見分ける

直感的な理解:偽札作りのたとえ

生成器 G(偽札作り)

目標: 識別器を騙せる偽札を作る

  • 最初は下手な偽札を作る
  • 識別器に見破られる
  • 徐々に精巧な偽札を作れるようになる
識別器 D(警察)

目標: 本物と偽物を見分ける

  • 本物の札と偽札を比較
  • 偽札の特徴を学習
  • より巧妙な偽札も見破れるようになる

PyTorch実装

import torch
import torch.nn as nn

# 生成器
class Generator(nn.Module):
    def __init__(self, latent_dim=100, output_dim=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, output_dim),
            nn.Tanh()
        )

    def forward(self, z):
        return self.model(z)

# 識別器
class Discriminator(nn.Module):
    def __init__(self, input_dim=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# 学習
G = Generator()
D = Discriminator()
criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(G.parameters(), lr=0.0002)
optimizer_D = torch.optim.Adam(D.parameters(), lr=0.0002)

for epoch in range(epochs):
    for real_data in dataloader:
        batch_size = real_data.size(0)
        real_labels = torch.ones(batch_size, 1)
        fake_labels = torch.zeros(batch_size, 1)

        # 識別器の学習
        optimizer_D.zero_grad()
        z = torch.randn(batch_size, 100)
        fake_data = G(z)
        d_loss = criterion(D(real_data), real_labels) + criterion(D(fake_data.detach()), fake_labels)
        d_loss.backward()
        optimizer_D.step()

        # 生成器の学習
        optimizer_G.zero_grad()
        g_loss = criterion(D(fake_data), real_labels)  # 識別器を騙したい
        g_loss.backward()
        optimizer_G.step()

GANの課題

モード崩壊

生成器が限られたパターンしか生成しなくなる問題

学習の不安定性

GとDのバランスが崩れると学習が失敗する

評価の難しさ

生成品質の定量的な評価が難しい

理解度チェック

Q. 生成器の学習時、損失関数は何を最大化しようとしますか?