Few-Shot Learning

学習目標: 少数のサンプルから学習する手法(Metric-based / Optimization-based / Model-based)を理解し、PyTorch で Prototypical Network / MAML を実装できるようになる

Few-Shot Learning とは

クラスごとに K サンプル(K=1〜5) だけで分類できるよう学習する設定。
人間が「数枚見ただけで新しい動物を見分けられる」ような汎化を目指す。

N-way K-shot タスク

1タスク = N クラス × K サンプル のサポート集合 + Q クエリ
例: 5-way 1-shot = 5クラスから1枚ずつ示し、新しい画像が
                    どのクラスかを当てる

主要アプローチの分類

カテゴリ代表手法アイデア
Metric-basedSiamese / Matching / Prototypical / Relation Network埋め込み空間で距離を学習
Optimization-basedMAML / Reptile / ANIL「数ステップで適応できる初期パラメータ」を学習
Model-basedMANN / SNAIL外部メモリで素早く適応
Pre-train + Fine-tuneCLIP / DINO + Linear Probe大規模事前学習 → 下流タスクへ少量で適応

Prototypical Network (ProtoNet)

各クラスの埋め込みベクトルの平均(プロトタイプ)を計算し、クエリと最も近いプロトタイプを予測する。
シンプルだが強力で、Few-Shotの定番ベースライン。

数式

プロトタイプ:  c_k = (1/|S_k|) Σ_{(x_i, y_i)∈S_k} f_θ(x_i)
予測:          p(y=k | x) ∝ exp(-d(f_θ(x), c_k))
損失:          NLL on cross-entropy across classes

実装

import torch
import torch.nn as nn
import torch.nn.functional as F


class EmbeddingNet(nn.Module):
    """4層CNN(Omniglot/miniImageNet定番)"""
    def __init__(self, hid=64, out=64):
        super().__init__()
        def block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            )
        self.net = nn.Sequential(
            block(3, hid), block(hid, hid),
            block(hid, hid), block(hid, out),
            nn.Flatten(),
        )
    def forward(self, x):
        return self.net(x)


def prototypical_loss(model, support_x, support_y, query_x, query_y, n_way):
    """1タスクのprotoloss"""
    z_s = model(support_x)     # (N*K, D)
    z_q = model(query_x)       # (N*Q, D)

    # クラスごとに平均 → プロトタイプ (N, D)
    prototypes = torch.stack([z_s[support_y == k].mean(0) for k in range(n_way)])

    # 各クエリと各プロトタイプの距離(負の二乗距離 = logit)
    dists = torch.cdist(z_q, prototypes)
    logits = -dists

    loss = F.cross_entropy(logits, query_y)
    acc  = (logits.argmax(1) == query_y).float().mean()
    return loss, acc


def train_episode(model, opt, sampler, device='cuda'):
    """1エピソード(=1タスク)で更新"""
    support_x, support_y, query_x, query_y, n_way = sampler.sample()
    support_x, query_x = support_x.to(device), query_x.to(device)
    support_y, query_y = support_y.to(device), query_y.to(device)

    loss, acc = prototypical_loss(model, support_x, support_y, query_x, query_y, n_way)
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item(), acc.item()
Tip: 同じ埋め込みネットを使って n_way を可変にできるのが ProtoNet の強み(5-way学習 → 10-way推論も可能)。

MAML (Model-Agnostic Meta-Learning)

「少数ステップの勾配降下で適応できる良い初期パラメータ θ」を学習する。
内側ループでタスク特化、外側ループで初期値を更新するメタ学習。

アルゴリズム

for メタ更新:
    タスクをバッチサンプル
    for 各タスク T_i:
        θ_i' = θ - α ∇_θ L_{T_i}(f_θ, support)     ← 内側ループ
    θ ← θ - β ∇_θ Σ_i L_{T_i}(f_{θ_i'}, query)     ← 外側ループ

実装スケッチ(1st-order MAML, FOMAML)

def maml_outer_step(model, opt, tasks, inner_lr=0.01, inner_steps=5):
    meta_loss = 0
    for support_x, support_y, query_x, query_y in tasks:
        # === 内側ループ ===
        fast_params = {n: p.clone() for n, p in model.named_parameters()}
        for _ in range(inner_steps):
            logits = functional_forward(model, fast_params, support_x)
            loss = F.cross_entropy(logits, support_y)
            grads = torch.autograd.grad(loss, fast_params.values(),
                                        create_graph=True)
            fast_params = {n: p - inner_lr * g
                           for (n, p), g in zip(fast_params.items(), grads)}

        # === 外側: 適応後のパラメータでクエリ損失 ===
        q_logits = functional_forward(model, fast_params, query_x)
        meta_loss = meta_loss + F.cross_entropy(q_logits, query_y)

    meta_loss = meta_loss / len(tasks)
    opt.zero_grad(); meta_loss.backward(); opt.step()

実用的には learn2learnhigher ライブラリを使うと内側ループが綺麗に書けます。

MAML系の派生

手法変更点
FOMAML2階勾配を捨てて1階に近似 → 高速
Reptile外側勾配を「適応後の重み − 初期重み」で近似
ANIL内側ループでは最終層だけ更新 → 軽量化
iMAMLimplicit gradient で内側計算を分離

エピソード(タスク)サンプリング

Few-Shotは「学習時にも N-way K-shotのタスクをランダムに作る」のが特徴。データセット → タスク列に変換するサンプラーを書くのが要点。

import random
import torch
from torch.utils.data import Dataset

class EpisodeSampler:
    def __init__(self, dataset_by_class, n_way=5, k_shot=1, q_query=15):
        """
        dataset_by_class: {class_id: [tensor_image_1, tensor_image_2, ...]}
        """
        self.data = dataset_by_class
        self.classes = list(dataset_by_class.keys())
        self.n_way, self.k_shot, self.q_query = n_way, k_shot, q_query

    def sample(self):
        chosen = random.sample(self.classes, self.n_way)
        support_x, support_y, query_x, query_y = [], [], [], []
        for label, cls in enumerate(chosen):
            samples = random.sample(self.data[cls], self.k_shot + self.q_query)
            for s in samples[:self.k_shot]:
                support_x.append(s); support_y.append(label)
            for s in samples[self.k_shot:]:
                query_x.append(s);   query_y.append(label)
        return (torch.stack(support_x), torch.tensor(support_y),
                torch.stack(query_x),   torch.tensor(query_y),
                self.n_way)

代表データセット

  • Omniglot — 50言語の手書き文字、1623クラス × 20サンプル(Few-Shotの定番)
  • miniImageNet — ImageNet から100クラス抽出、84×84
  • tieredImageNet — 階層構造ありの大規模版
  • Meta-Dataset — 10種類のドメイン横断ベンチマーク