知識蒸留 (Knowledge Distillation)

学習目標: 大モデル(Teacher)の知識を小モデル(Student)に転送する手法を理解し、Logit蒸留・Feature蒸留・Self蒸留を実装できる

蒸留はなぜ効くか

  • ダークナレッジ: 正解クラス以外の確率分布が「クラス間の類似性」を教える
  • 正則化効果: Studentが過学習しにくくなる(softなターゲット)
  • モデル圧縮: 推論コストを 10倍以上削減しつつ精度を維持しやすい
  • ロバスト性向上: Teacher のアンサンブル → 小Student で実運用

主要パターン

カテゴリ転送先代表手法
Logit蒸留 (Response-based)出力分布Hinton et al. 2015
Feature蒸留 (Feature-based)中間特徴FitNets, AT, OFD
Relation蒸留 (Relation-based)サンプル間関係RKD, CRD
Self蒸留同一モデルの異なる層/エポックBYOT, Snapshot Distillation
Online蒸留複数Studentが相互教師DML, Co-Distillation

Logit蒸留 — Hinton et al. (2015)

Teacher と Student の softmax出力 を一致させる。温度 T で分布をなだらかにするのがポイント。

損失関数

soft_loss = T² · KL(softmax(teacher_logits/T) || softmax(student_logits/T))
hard_loss = CE(student_logits, true_label)

L = α · soft_loss + (1 - α) · hard_loss

PyTorch 実装

import torch
import torch.nn as nn
import torch.nn.functional as F


class DistillationLoss(nn.Module):
    def __init__(self, T=4.0, alpha=0.7):
        super().__init__()
        self.T, self.alpha = T, alpha

    def forward(self, student_logits, teacher_logits, labels):
        # soft loss (KL divergence between softened distributions)
        s = F.log_softmax(student_logits / self.T, dim=1)
        t = F.softmax(teacher_logits / self.T, dim=1)
        soft = F.kl_div(s, t, reduction='batchmean') * (self.T ** 2)

        # hard loss
        hard = F.cross_entropy(student_logits, labels)

        return self.alpha * soft + (1 - self.alpha) * hard


def train_step(teacher, student, opt, criterion, x, labels):
    teacher.eval()
    with torch.no_grad():
        teacher_logits = teacher(x)

    student.train()
    student_logits = student(x)
    loss = criterion(student_logits, teacher_logits, labels)

    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

温度 T のチューニング

T効果
T = 1Hard ターゲット相当(softmaxほぼone-hot)
T = 2〜5軽くsoften、classifier系で定番
T = 10+強くsoften、暗黙知を強く転送

Feature蒸留

Teacher の中間層の活性をStudentに合わせる。Logit蒸留より転送できる情報量が多いが、層のサイズ合わせが必要。

FitNets (基本パターン)

class FeatureDistillationLoss(nn.Module):
    def __init__(self, student_dim, teacher_dim):
        super().__init__()
        # 次元が違う場合は1x1convでマッピング
        self.adapter = nn.Conv2d(student_dim, teacher_dim, 1) \
                       if student_dim != teacher_dim else nn.Identity()

    def forward(self, student_feat, teacher_feat):
        return F.mse_loss(self.adapter(student_feat), teacher_feat)


def hook_features(model, layer_names):
    """指定レイヤの出力を取得するフック"""
    features = {}
    def make_hook(name):
        def hook(module, inp, out):
            features[name] = out
        return hook
    handles = []
    for name, module in model.named_modules():
        if name in layer_names:
            handles.append(module.register_forward_hook(make_hook(name)))
    return features, handles

Attention Transfer (AT)

Feature mapを sum_c |F_c|² でアテンションマップに変換してから蒸留。次元の違いを気にしなくて良い。

def attention_map(f):
    """(N, C, H, W) → (N, H*W) のアテンションマップ"""
    return F.normalize(f.pow(2).mean(1).view(f.size(0), -1), dim=1)


def at_loss(student_feat, teacher_feat):
    return (attention_map(student_feat) - attention_map(teacher_feat)).pow(2).mean()

Self蒸留 / Online蒸留

Self蒸留 (BYOT: Be Your Own Teacher)

1つのネットワーク内で深い層 → 浅い層に蒸留。追加モデル不要で精度向上。

class BYOTModel(nn.Module):
    """ResNet 各 stage の出口に補助分類器を付ける"""
    def __init__(self, backbone, n_classes):
        super().__init__()
        self.stages = backbone.stages          # 4 stages 想定
        self.aux_heads = nn.ModuleList([
            nn.Linear(d, n_classes) for d in backbone.stage_dims
        ])
        self.main_head = nn.Linear(backbone.stage_dims[-1], n_classes)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        aux_logits = [head(o.mean([-1, -2])) for head, o in zip(self.aux_heads, outs)]
        main_logits = self.main_head(outs[-1].mean([-1, -2]))
        return main_logits, aux_logits


def byot_loss(main_logits, aux_logits_list, labels, T=4.0):
    main_ce = F.cross_entropy(main_logits, labels)
    aux_ce  = sum(F.cross_entropy(a, labels) for a in aux_logits_list)

    # 各auxを mainのsoft labelに合わせる
    with torch.no_grad():
        soft_target = F.softmax(main_logits / T, dim=1)
    distill = sum(F.kl_div(F.log_softmax(a / T, dim=1), soft_target,
                           reduction='batchmean') * (T**2)
                  for a in aux_logits_list)

    return main_ce + 0.5 * aux_ce + 0.5 * distill

DML (Deep Mutual Learning) — Online蒸留

2つのStudentを同時に学習し、互いを教師に。Teacher不要、精度↑。

def mutual_loss(logits_a, logits_b, labels):
    ce_a = F.cross_entropy(logits_a, labels)
    ce_b = F.cross_entropy(logits_b, labels)

    # 相互の予測を softに合わせる
    kl_a = F.kl_div(F.log_softmax(logits_a, dim=1),
                    F.softmax(logits_b.detach(), dim=1),
                    reduction='batchmean')
    kl_b = F.kl_div(F.log_softmax(logits_b, dim=1),
                    F.softmax(logits_a.detach(), dim=1),
                    reduction='batchmean')
    return ce_a + ce_b + kl_a + kl_b
手法Teacher追加メモリ典型用途
Hinton KD事前学習+1モデル軽量化
FitNets / AT事前学習+1モデル細かい知識転送
BYOT自分auxヘッドのみ追加コスト最小で精度↑
DML相互2モデル並列事前学習なし