自己教師あり学習 (Self-Supervised Learning)
学習目標: ラベルなしデータから表現を学ぶ手法(対照学習 / マスク復元 / BYOL系)を理解し、SimCLR と MAE を実装できる
SSL のモチベーション
- ラベル不要: Webから大量のデータをそのまま使える
- 汎用表現: 下流タスク(分類・検出・セグ)に転移しやすい
- 事前学習標準化: BERT・SimCLR・MAE が現代の事前学習の主流
SSL タスクの構造
1. データから自動的に「擬似ラベル」を作る (pretext task) 2. 擬似ラベルを当てるように encoder を学習 3. 下流タスク向けに encoder を固定 or fine-tune
擬似タスクの例
| カテゴリ | 例 | 代表手法 |
|---|---|---|
| 対照学習 | 同じ画像の2つの拡張を近づける | SimCLR, MoCo |
| マスク復元 | 一部を隠して復元 | MAE, BERT, BEiT |
| 予測タスク | パッチ位置、回転角を予測 | Jigsaw, RotNet |
| クラスタリング | 埋め込みをクラスタ化 | DeepCluster, SwAV |
| 蒸留型 | Teacher-Student 自己蒸留 | BYOL, DINO, MSN |
SimCLR — シンプル対照学習
アイデア
画像 x ──┬─ aug_1 ─► encoder ─► projector ─► z_i
└─ aug_2 ─► encoder ─► projector ─► z_j
z_i と z_j を近づけ、他バッチサンプルとは遠ざける
(NT-Xent 損失 = 温度付きsoftmax cross-entropy)
PyTorch 実装
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class SimCLRTransform:
"""同じ画像から2つの異なるaugmentedビューを作る"""
def __init__(self, size=224):
self.t = transforms.Compose([
transforms.RandomResizedCrop(size, scale=(0.2, 1.0)),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
transforms.RandomGrayscale(p=0.2),
transforms.GaussianBlur(kernel_size=23),
transforms.ToTensor(),
])
def __call__(self, x):
return self.t(x), self.t(x)
class Projector(nn.Module):
def __init__(self, in_dim, hidden=2048, out=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_dim, hidden), nn.BatchNorm1d(hidden), nn.ReLU(),
nn.Linear(hidden, out),
)
def forward(self, x):
return self.net(x)
def nt_xent_loss(z1, z2, temperature=0.5):
"""NT-Xent (info-NCE) loss"""
N = z1.size(0)
z = F.normalize(torch.cat([z1, z2], dim=0), dim=1) # (2N, D)
sim = z @ z.T / temperature # (2N, 2N)
mask = torch.eye(2*N, dtype=torch.bool, device=z.device)
sim.masked_fill_(mask, -float('inf')) # 自己は除外
# 正解ペアのインデックス: i ↔ i+N
targets = torch.cat([torch.arange(N, 2*N), torch.arange(0, N)]).to(z.device)
return F.cross_entropy(sim, targets)
def train_step(encoder, projector, opt, batch_views):
x1, x2 = batch_views
h1, h2 = encoder(x1), encoder(x2)
z1, z2 = projector(h1), projector(h2)
loss = nt_xent_loss(z1, z2)
opt.zero_grad(); loss.backward(); opt.step()
return loss.item()
学習のポイント
- 強い augmentation が要(ColorJitter/RandomGrayscale/GaussianBlur がほぼ必須)
- 大きなバッチサイズ(256〜8192)。負例が多いほど精度↑
- Projection Head をはさむと表現が改善
- 下流タスクでは projector を捨てて encoder のみ使う
MAE — Masked Autoencoder (He et al. 2022)
画像の 75%のパッチをマスクし、見える25%から復元するように ViT を学習。シンプルかつ強力。
アーキテクチャ
画像 ─► パッチ分割 ─► 75%をマスク
│
▼ (見える25%のみ)
Heavy Encoder (ViT-Large)
│
▼ + 学習可能なマスクトークン
Light Decoder (浅いTransformer)
│
▼
ピクセル空間で復元
MSE Loss (マスクパッチのみ)
実装の要点
import torch.nn as nn
class MAE(nn.Module):
def __init__(self, encoder, decoder, patch_size=16, mask_ratio=0.75):
super().__init__()
self.encoder = encoder # ViT
self.decoder = decoder # 浅いTransformer
self.patch_size = patch_size
self.mask_ratio = mask_ratio
self.mask_token = nn.Parameter(torch.zeros(1, 1, encoder.embed_dim))
def random_mask(self, x):
"""N枚×P個のパッチからmask_ratio割をマスク"""
N, L, D = x.shape
n_keep = int(L * (1 - self.mask_ratio))
noise = torch.rand(N, L, device=x.device)
ids_shuffle = noise.argsort(dim=1)
ids_restore = ids_shuffle.argsort(dim=1)
ids_keep = ids_shuffle[:, :n_keep]
x_masked = torch.gather(x, 1, ids_keep.unsqueeze(-1).expand(-1, -1, D))
return x_masked, ids_restore, ids_keep
def forward(self, imgs):
# パッチ化+位置埋め込み
x = self.encoder.patch_embed(imgs)
x_visible, ids_restore, _ = self.random_mask(x)
# エンコーダは見えるパッチのみに作用
latent = self.encoder.blocks(x_visible)
# マスク位置にmask_tokenを挿入し、元順に並び替え
N, L_total, D = imgs.size(0), x.size(1), latent.size(-1)
mask_tokens = self.mask_token.expand(N, L_total - latent.size(1), -1)
full = torch.cat([latent, mask_tokens], dim=1)
full = torch.gather(full, 1, ids_restore.unsqueeze(-1).expand(-1, -1, D))
# デコーダで復元
pred = self.decoder(full)
return pred
def loss(self, pred, target, mask):
"""マスクされたパッチでのみMSE"""
loss = ((pred - target) ** 2).mean(dim=-1)
return (loss * mask).sum() / mask.sum()
なぜMAEが効くのか
- 75%という高いマスク率で「context → 全体」の推論を強制
- エンコーダが見えるパッチのみを扱うので計算効率↑
- ピクセル空間で復元するため意味理解が育つ
BYOL / DINO — 負例なし自己蒸留
SimCLRのような負例を使わず、Online ↔ Target ネットワークの「自己蒸留」で学習。BYOL (DeepMind), DINO (FAIR) が代表。
BYOL のアーキテクチャ
view_1 ─► online encoder ─► online projector ─► online predictor ──► p view_2 ─► target encoder ─► target projector ────────────────────► z (stop_grad) loss = || normalize(p) - normalize(z) ||² target_θ ← ema(target_θ, online_θ, τ=0.996)
DINO の特徴
- ViT に適用、クラスタリングを経由しないシンプルな蒸留
- Teacher 出力は 温度の異なるsoftmax + センタリング で安定化
- 学習済みTeacher は 自己注意マップが綺麗にオブジェクト位置を捉える(教師なしセグ可能)
EMA 更新の実装
@torch.no_grad()
def update_target(target_net, online_net, tau=0.996):
for tp, op in zip(target_net.parameters(), online_net.parameters()):
tp.data = tau * tp.data + (1 - tau) * op.data
def byol_step(online, target, predictor, opt, x1, x2):
o1 = predictor(online(x1))
o2 = predictor(online(x2))
with torch.no_grad():
t1 = target(x1)
t2 = target(x2)
loss = F.mse_loss(F.normalize(o1, dim=-1), F.normalize(t2, dim=-1)) + \
F.mse_loss(F.normalize(o2, dim=-1), F.normalize(t1, dim=-1))
opt.zero_grad(); loss.backward(); opt.step()
update_target(target, online)
| 手法 | 負例 | 特徴 |
|---|---|---|
| SimCLR | 必要(大バッチ) | シンプル、計算重 |
| MoCo v3 | キュー使用 | 小バッチでも動く |
| BYOL | 不要 | EMA Target + Predictor |
| DINO | 不要 | ViT に強い、自己注意マップ◎ |
| MAE | 不要 | マスク復元、ViT専用、大規模化容易 |