条件付き生成
学習目標: 条件付き生成モデルの仕組みと実装方法を理解する
条件付き生成とは
条件付き生成(Conditional Generation)は、特定の条件(クラスラベル、テキスト、画像など)に基づいて 目的のデータを生成する手法です。
無条件生成
ランダムな画像を生成
クラス条件付き
「猫」「犬」などのラベルで制御
テキスト条件付き
「赤い車」などの説明文で制御
Conditional GAN (cGAN)
cGANでは、生成器と識別器の両方に条件情報を入力します。
G(z, c) → 画像, D(x, c) → 本物/偽物
z: ノイズ, c: 条件(クラスラベルなど), x: 画像
class ConditionalGenerator(nn.Module):
def __init__(self, latent_dim=100, n_classes=10, img_size=28):
super().__init__()
self.img_size = img_size
# クラスラベルの埋め込み
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear(latent_dim + n_classes, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2),
nn.Linear(1024, img_size * img_size),
nn.Tanh()
)
def forward(self, z, labels):
# ラベルを埋め込みベクトルに変換
c = self.label_embedding(labels)
# ノイズとラベルを結合
x = torch.cat([z, c], dim=1)
img = self.model(x)
return img.view(-1, 1, self.img_size, self.img_size)
class ConditionalDiscriminator(nn.Module):
def __init__(self, n_classes=10, img_size=28):
super().__init__()
self.img_size = img_size
self.label_embedding = nn.Embedding(n_classes, n_classes)
self.model = nn.Sequential(
nn.Linear(img_size * img_size + n_classes, 512),
nn.LeakyReLU(0.2),
nn.Linear(512, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, img, labels):
c = self.label_embedding(labels)
x = torch.cat([img.view(-1, self.img_size * self.img_size), c], dim=1)
return self.model(x)
Conditional VAE (cVAE)
class ConditionalVAE(nn.Module):
def __init__(self, input_dim=784, latent_dim=20, n_classes=10):
super().__init__()
# エンコーダ(入力 + ラベル)
self.encoder = nn.Sequential(
nn.Linear(input_dim + n_classes, 256),
nn.ReLU()
)
self.fc_mu = nn.Linear(256, latent_dim)
self.fc_logvar = nn.Linear(256, latent_dim)
# デコーダ(潜在変数 + ラベル)
self.decoder = nn.Sequential(
nn.Linear(latent_dim + n_classes, 256),
nn.ReLU(),
nn.Linear(256, input_dim),
nn.Sigmoid()
)
self.n_classes = n_classes
def encode(self, x, c):
# one-hotエンコーディング
c_onehot = F.one_hot(c, self.n_classes).float()
x_c = torch.cat([x, c_onehot], dim=1)
h = self.encoder(x_c)
return self.fc_mu(h), self.fc_logvar(h)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z, c):
c_onehot = F.one_hot(c, self.n_classes).float()
z_c = torch.cat([z, c_onehot], dim=1)
return self.decoder(z_c)
def forward(self, x, c):
mu, logvar = self.encode(x, c)
z = self.reparameterize(mu, logvar)
return self.decode(z, c), mu, logvar
# 特定のクラスの画像を生成
def generate_digit(model, digit, n_samples=10):
z = torch.randn(n_samples, latent_dim)
labels = torch.full((n_samples,), digit, dtype=torch.long)
with torch.no_grad():
generated = model.decode(z, labels)
return generated
テキスト条件付き生成
DALL-E、Stable Diffusionなどの最新モデルでは、テキストを条件として画像を生成します。
アーキテクチャの概要:
- テキストエンコーダ(CLIP、T5など)でテキストを埋め込みベクトルに変換
- Cross-Attentionで画像生成にテキスト情報を注入
- 拡散モデルやTransformerで画像を生成
# 概念的なコード(簡略化)
class TextConditionedGenerator(nn.Module):
def __init__(self, text_encoder, image_decoder):
super().__init__()
self.text_encoder = text_encoder # 例: CLIP
self.image_decoder = image_decoder
def forward(self, noise, text):
# テキストを埋め込みに変換
text_embedding = self.text_encoder(text)
# Cross-Attentionでテキスト情報を注入しながら画像生成
image = self.image_decoder(noise, text_embedding)
return image
理解度チェック
Q. 条件付きGANで、識別器にもラベル情報を与える理由は?