知識蒸留 (Knowledge Distillation)
学習目標: 大モデル(Teacher)の知識を小モデル(Student)に転送する手法を理解し、Logit蒸留・Feature蒸留・Self蒸留を実装できる
蒸留はなぜ効くか
- ダークナレッジ: 正解クラス以外の確率分布が「クラス間の類似性」を教える
- 正則化効果: Studentが過学習しにくくなる(softなターゲット)
- モデル圧縮: 推論コストを 10倍以上削減しつつ精度を維持しやすい
- ロバスト性向上: Teacher のアンサンブル → 小Student で実運用
主要パターン
| カテゴリ | 転送先 | 代表手法 |
|---|---|---|
| Logit蒸留 (Response-based) | 出力分布 | Hinton et al. 2015 |
| Feature蒸留 (Feature-based) | 中間特徴 | FitNets, AT, OFD |
| Relation蒸留 (Relation-based) | サンプル間関係 | RKD, CRD |
| Self蒸留 | 同一モデルの異なる層/エポック | BYOT, Snapshot Distillation |
| Online蒸留 | 複数Studentが相互教師 | DML, Co-Distillation |
Logit蒸留 — Hinton et al. (2015)
Teacher と Student の softmax出力 を一致させる。温度 T で分布をなだらかにするのがポイント。
損失関数
soft_loss = T² · KL(softmax(teacher_logits/T) || softmax(student_logits/T)) hard_loss = CE(student_logits, true_label) L = α · soft_loss + (1 - α) · hard_loss
PyTorch 実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class DistillationLoss(nn.Module):
def __init__(self, T=4.0, alpha=0.7):
super().__init__()
self.T, self.alpha = T, alpha
def forward(self, student_logits, teacher_logits, labels):
# soft loss (KL divergence between softened distributions)
s = F.log_softmax(student_logits / self.T, dim=1)
t = F.softmax(teacher_logits / self.T, dim=1)
soft = F.kl_div(s, t, reduction='batchmean') * (self.T ** 2)
# hard loss
hard = F.cross_entropy(student_logits, labels)
return self.alpha * soft + (1 - self.alpha) * hard
def train_step(teacher, student, opt, criterion, x, labels):
teacher.eval()
with torch.no_grad():
teacher_logits = teacher(x)
student.train()
student_logits = student(x)
loss = criterion(student_logits, teacher_logits, labels)
opt.zero_grad(); loss.backward(); opt.step()
return loss.item()
温度 T のチューニング
| T | 効果 |
|---|---|
| T = 1 | Hard ターゲット相当(softmaxほぼone-hot) |
| T = 2〜5 | 軽くsoften、classifier系で定番 |
| T = 10+ | 強くsoften、暗黙知を強く転送 |
Feature蒸留
Teacher の中間層の活性をStudentに合わせる。Logit蒸留より転送できる情報量が多いが、層のサイズ合わせが必要。
FitNets (基本パターン)
class FeatureDistillationLoss(nn.Module):
def __init__(self, student_dim, teacher_dim):
super().__init__()
# 次元が違う場合は1x1convでマッピング
self.adapter = nn.Conv2d(student_dim, teacher_dim, 1) \
if student_dim != teacher_dim else nn.Identity()
def forward(self, student_feat, teacher_feat):
return F.mse_loss(self.adapter(student_feat), teacher_feat)
def hook_features(model, layer_names):
"""指定レイヤの出力を取得するフック"""
features = {}
def make_hook(name):
def hook(module, inp, out):
features[name] = out
return hook
handles = []
for name, module in model.named_modules():
if name in layer_names:
handles.append(module.register_forward_hook(make_hook(name)))
return features, handles
Attention Transfer (AT)
Feature mapを sum_c |F_c|² でアテンションマップに変換してから蒸留。次元の違いを気にしなくて良い。
def attention_map(f):
"""(N, C, H, W) → (N, H*W) のアテンションマップ"""
return F.normalize(f.pow(2).mean(1).view(f.size(0), -1), dim=1)
def at_loss(student_feat, teacher_feat):
return (attention_map(student_feat) - attention_map(teacher_feat)).pow(2).mean()
Self蒸留 / Online蒸留
Self蒸留 (BYOT: Be Your Own Teacher)
1つのネットワーク内で深い層 → 浅い層に蒸留。追加モデル不要で精度向上。
class BYOTModel(nn.Module):
"""ResNet 各 stage の出口に補助分類器を付ける"""
def __init__(self, backbone, n_classes):
super().__init__()
self.stages = backbone.stages # 4 stages 想定
self.aux_heads = nn.ModuleList([
nn.Linear(d, n_classes) for d in backbone.stage_dims
])
self.main_head = nn.Linear(backbone.stage_dims[-1], n_classes)
def forward(self, x):
outs = []
for stage in self.stages:
x = stage(x)
outs.append(x)
aux_logits = [head(o.mean([-1, -2])) for head, o in zip(self.aux_heads, outs)]
main_logits = self.main_head(outs[-1].mean([-1, -2]))
return main_logits, aux_logits
def byot_loss(main_logits, aux_logits_list, labels, T=4.0):
main_ce = F.cross_entropy(main_logits, labels)
aux_ce = sum(F.cross_entropy(a, labels) for a in aux_logits_list)
# 各auxを mainのsoft labelに合わせる
with torch.no_grad():
soft_target = F.softmax(main_logits / T, dim=1)
distill = sum(F.kl_div(F.log_softmax(a / T, dim=1), soft_target,
reduction='batchmean') * (T**2)
for a in aux_logits_list)
return main_ce + 0.5 * aux_ce + 0.5 * distill
DML (Deep Mutual Learning) — Online蒸留
2つのStudentを同時に学習し、互いを教師に。Teacher不要、精度↑。
def mutual_loss(logits_a, logits_b, labels):
ce_a = F.cross_entropy(logits_a, labels)
ce_b = F.cross_entropy(logits_b, labels)
# 相互の予測を softに合わせる
kl_a = F.kl_div(F.log_softmax(logits_a, dim=1),
F.softmax(logits_b.detach(), dim=1),
reduction='batchmean')
kl_b = F.kl_div(F.log_softmax(logits_b, dim=1),
F.softmax(logits_a.detach(), dim=1),
reduction='batchmean')
return ce_a + ce_b + kl_a + kl_b
| 手法 | Teacher | 追加メモリ | 典型用途 |
|---|---|---|---|
| Hinton KD | 事前学習 | +1モデル | 軽量化 |
| FitNets / AT | 事前学習 | +1モデル | 細かい知識転送 |
| BYOT | 自分 | auxヘッドのみ | 追加コスト最小で精度↑ |
| DML | 相互 | 2モデル並列 | 事前学習なし |