自己教師あり学習 (Self-Supervised Learning)

学習目標: ラベルなしデータから表現を学ぶ手法(対照学習 / マスク復元 / BYOL系)を理解し、SimCLR と MAE を実装できる

SSL のモチベーション

  • ラベル不要: Webから大量のデータをそのまま使える
  • 汎用表現: 下流タスク(分類・検出・セグ)に転移しやすい
  • 事前学習標準化: BERT・SimCLR・MAE が現代の事前学習の主流

SSL タスクの構造

1. データから自動的に「擬似ラベル」を作る (pretext task)
2. 擬似ラベルを当てるように encoder を学習
3. 下流タスク向けに encoder を固定 or fine-tune

擬似タスクの例

カテゴリ代表手法
対照学習同じ画像の2つの拡張を近づけるSimCLR, MoCo
マスク復元一部を隠して復元MAE, BERT, BEiT
予測タスクパッチ位置、回転角を予測Jigsaw, RotNet
クラスタリング埋め込みをクラスタ化DeepCluster, SwAV
蒸留型Teacher-Student 自己蒸留BYOL, DINO, MSN

SimCLR — シンプル対照学習

アイデア

画像 x ──┬─ aug_1 ─► encoder ─► projector ─► z_i
        └─ aug_2 ─► encoder ─► projector ─► z_j

z_i と z_j を近づけ、他バッチサンプルとは遠ざける
(NT-Xent 損失 = 温度付きsoftmax cross-entropy)

PyTorch 実装

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

class SimCLRTransform:
    """同じ画像から2つの異なるaugmentedビューを作る"""
    def __init__(self, size=224):
        self.t = transforms.Compose([
            transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=23),
            transforms.ToTensor(),
        ])
    def __call__(self, x):
        return self.t(x), self.t(x)


class Projector(nn.Module):
    def __init__(self, in_dim, hidden=2048, out=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
            nn.Linear(hidden, out),
        )
    def forward(self, x):
        return self.net(x)


def nt_xent_loss(z1, z2, temperature=0.5):
    """NT-Xent (info-NCE) loss"""
    N = z1.size(0)
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # (2N, D)
    sim = z @ z.T / temperature                          # (2N, 2N)
    mask = torch.eye(2*N, dtype=torch.bool, device=z.device)
    sim.masked_fill_(mask, -float('inf'))                # 自己は除外

    # 正解ペアのインデックス: i ↔ i+N
    targets = torch.cat([torch.arange(N, 2*N), torch.arange(0, N)]).to(z.device)
    return F.cross_entropy(sim, targets)


def train_step(encoder, projector, opt, batch_views):
    x1, x2 = batch_views
    h1, h2 = encoder(x1), encoder(x2)
    z1, z2 = projector(h1), projector(h2)
    loss = nt_xent_loss(z1, z2)
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

学習のポイント

  • 強い augmentation が要(ColorJitter/RandomGrayscale/GaussianBlur がほぼ必須)
  • 大きなバッチサイズ(256〜8192)。負例が多いほど精度↑
  • Projection Head をはさむと表現が改善
  • 下流タスクでは projector を捨てて encoder のみ使う

MAE — Masked Autoencoder (He et al. 2022)

画像の 75%のパッチをマスクし、見える25%から復元するように ViT を学習。シンプルかつ強力。

アーキテクチャ

画像 ─► パッチ分割 ─► 75%をマスク
                     │
                     ▼ (見える25%のみ)
                Heavy Encoder (ViT-Large)
                     │
                     ▼ + 学習可能なマスクトークン
                Light Decoder (浅いTransformer)
                     │
                     ▼
                ピクセル空間で復元
                MSE Loss (マスクパッチのみ)

実装の要点

import torch.nn as nn


class MAE(nn.Module):
    def __init__(self, encoder, decoder, patch_size=16, mask_ratio=0.75):
        super().__init__()
        self.encoder = encoder      # ViT
        self.decoder = decoder      # 浅いTransformer
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.mask_token = nn.Parameter(torch.zeros(1, 1, encoder.embed_dim))

    def random_mask(self, x):
        """N枚×P個のパッチからmask_ratio割をマスク"""
        N, L, D = x.shape
        n_keep = int(L * (1 - self.mask_ratio))
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = noise.argsort(dim=1)
        ids_restore = ids_shuffle.argsort(dim=1)
        ids_keep    = ids_shuffle[:, :n_keep]

        x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
        return x_masked, ids_restore, ids_keep

    def forward(self, imgs):
        # パッチ化+位置埋め込み
        x = self.encoder.patch_embed(imgs)
        x_visible, ids_restore, _ = self.random_mask(x)

        # エンコーダは見えるパッチのみに作用
        latent = self.encoder.blocks(x_visible)

        # マスク位置にmask_tokenを挿入し、元順に並び替え
        N, L_total, D = imgs.size(0), x.size(1), latent.size(-1)
        mask_tokens = self.mask_token.expand(N, L_total - latent.size(1), -1)
        full = torch.cat([latent, mask_tokens], dim=1)
        full = torch.gather(full, 1, ids_restore.unsqueeze(-1).expand(-1, -1, D))

        # デコーダで復元
        pred = self.decoder(full)
        return pred

    def loss(self, pred, target, mask):
        """マスクされたパッチでのみMSE"""
        loss = ((pred - target) ** 2).mean(dim=-1)
        return (loss * mask).sum() / mask.sum()

なぜMAEが効くのか

  • 75%という高いマスク率で「context → 全体」の推論を強制
  • エンコーダが見えるパッチのみを扱うので計算効率↑
  • ピクセル空間で復元するため意味理解が育つ

BYOL / DINO — 負例なし自己蒸留

SimCLRのような負例を使わず、Online ↔ Target ネットワークの「自己蒸留」で学習。BYOL (DeepMind), DINO (FAIR) が代表。

BYOL のアーキテクチャ

view_1 ─► online encoder ─► online projector ─► online predictor ──► p
view_2 ─► target encoder ─► target projector ────────────────────► z (stop_grad)

loss = || normalize(p) - normalize(z) ||²

target_θ ← ema(target_θ, online_θ, τ=0.996)

DINO の特徴

  • ViT に適用、クラスタリングを経由しないシンプルな蒸留
  • Teacher 出力は 温度の異なるsoftmax + センタリング で安定化
  • 学習済みTeacher は 自己注意マップが綺麗にオブジェクト位置を捉える(教師なしセグ可能)

EMA 更新の実装

@torch.no_grad()
def update_target(target_net, online_net, tau=0.996):
    for tp, op in zip(target_net.parameters(), online_net.parameters()):
        tp.data = tau * tp.data + (1 - tau) * op.data


def byol_step(online, target, predictor, opt, x1, x2):
    o1 = predictor(online(x1))
    o2 = predictor(online(x2))
    with torch.no_grad():
        t1 = target(x1)
        t2 = target(x2)
    loss = F.mse_loss(F.normalize(o1, dim=-1), F.normalize(t2, dim=-1)) + \
           F.mse_loss(F.normalize(o2, dim=-1), F.normalize(t1, dim=-1))
    opt.zero_grad(); loss.backward(); opt.step()
    update_target(target, online)
手法負例特徴
SimCLR必要(大バッチ)シンプル、計算重
MoCo v3キュー使用小バッチでも動く
BYOL不要EMA Target + Predictor
DINO不要ViT に強い、自己注意マップ◎
MAE不要マスク復元、ViT専用、大規模化容易