条件付き生成

学習目標: 条件付き生成モデルの仕組みと実装方法を理解する

条件付き生成とは

条件付き生成(Conditional Generation)は、特定の条件(クラスラベル、テキスト、画像など)に基づいて 目的のデータを生成する手法です。

無条件生成

ランダムな画像を生成

クラス条件付き

「猫」「犬」などのラベルで制御

テキスト条件付き

「赤い車」などの説明文で制御

Conditional GAN (cGAN)

cGANでは、生成器と識別器の両方に条件情報を入力します。

G(z, c) → 画像, D(x, c) → 本物/偽物

z: ノイズ, c: 条件(クラスラベルなど), x: 画像

class ConditionalGenerator(nn.Module):
    def __init__(self, latent_dim=100, n_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size

        # クラスラベルの埋め込み
        self.label_embedding = nn.Embedding(n_classes, n_classes)

        self.model = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_size * img_size),
            nn.Tanh()
        )

    def forward(self, z, labels):
        # ラベルを埋め込みベクトルに変換
        c = self.label_embedding(labels)
        # ノイズとラベルを結合
        x = torch.cat([z, c], dim=1)
        img = self.model(x)
        return img.view(-1, 1, self.img_size, self.img_size)


class ConditionalDiscriminator(nn.Module):
    def __init__(self, n_classes=10, img_size=28):
        super().__init__()
        self.img_size = img_size

        self.label_embedding = nn.Embedding(n_classes, n_classes)

        self.model = nn.Sequential(
            nn.Linear(img_size * img_size + n_classes, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img, labels):
        c = self.label_embedding(labels)
        x = torch.cat([img.view(-1, self.img_size * self.img_size), c], dim=1)
        return self.model(x)

Conditional VAE (cVAE)

class ConditionalVAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=20, n_classes=10):
        super().__init__()

        # エンコーダ(入力 + ラベル)
        self.encoder = nn.Sequential(
            nn.Linear(input_dim + n_classes, 256),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)

        # デコーダ(潜在変数 + ラベル)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + n_classes, 256),
            nn.ReLU(),
            nn.Linear(256, input_dim),
            nn.Sigmoid()
        )

        self.n_classes = n_classes

    def encode(self, x, c):
        # one-hotエンコーディング
        c_onehot = F.one_hot(c, self.n_classes).float()
        x_c = torch.cat([x, c_onehot], dim=1)
        h = self.encoder(x_c)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        c_onehot = F.one_hot(c, self.n_classes).float()
        z_c = torch.cat([z, c_onehot], dim=1)
        return self.decoder(z_c)

    def forward(self, x, c):
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar


# 特定のクラスの画像を生成
def generate_digit(model, digit, n_samples=10):
    z = torch.randn(n_samples, latent_dim)
    labels = torch.full((n_samples,), digit, dtype=torch.long)
    with torch.no_grad():
        generated = model.decode(z, labels)
    return generated

テキスト条件付き生成

DALL-E、Stable Diffusionなどの最新モデルでは、テキストを条件として画像を生成します。

アーキテクチャの概要:

  1. テキストエンコーダ(CLIP、T5など)でテキストを埋め込みベクトルに変換
  2. Cross-Attentionで画像生成にテキスト情報を注入
  3. 拡散モデルやTransformerで画像を生成
# 概念的なコード(簡略化)
class TextConditionedGenerator(nn.Module):
    def __init__(self, text_encoder, image_decoder):
        super().__init__()
        self.text_encoder = text_encoder  # 例: CLIP
        self.image_decoder = image_decoder

    def forward(self, noise, text):
        # テキストを埋め込みに変換
        text_embedding = self.text_encoder(text)

        # Cross-Attentionでテキスト情報を注入しながら画像生成
        image = self.image_decoder(noise, text_embedding)
        return image

理解度チェック

Q. 条件付きGANで、識別器にもラベル情報を与える理由は?