Few-Shot Learning
学習目標: 少数のサンプルから学習する手法(Metric-based / Optimization-based / Model-based)を理解し、PyTorch で Prototypical Network / MAML を実装できるようになる
Few-Shot Learning とは
クラスごとに K サンプル(K=1〜5) だけで分類できるよう学習する設定。
人間が「数枚見ただけで新しい動物を見分けられる」ような汎化を目指す。
N-way K-shot タスク
1タスク = N クラス × K サンプル のサポート集合 + Q クエリ
例: 5-way 1-shot = 5クラスから1枚ずつ示し、新しい画像が
どのクラスかを当てる
主要アプローチの分類
| カテゴリ | 代表手法 | アイデア |
|---|---|---|
| Metric-based | Siamese / Matching / Prototypical / Relation Network | 埋め込み空間で距離を学習 |
| Optimization-based | MAML / Reptile / ANIL | 「数ステップで適応できる初期パラメータ」を学習 |
| Model-based | MANN / SNAIL | 外部メモリで素早く適応 |
| Pre-train + Fine-tune | CLIP / DINO + Linear Probe | 大規模事前学習 → 下流タスクへ少量で適応 |
Prototypical Network (ProtoNet)
各クラスの埋め込みベクトルの平均(プロトタイプ)を計算し、クエリと最も近いプロトタイプを予測する。
シンプルだが強力で、Few-Shotの定番ベースライン。
数式
プロトタイプ: c_k = (1/|S_k|) Σ_{(x_i, y_i)∈S_k} f_θ(x_i)
予測: p(y=k | x) ∝ exp(-d(f_θ(x), c_k))
損失: NLL on cross-entropy across classes
実装
import torch
import torch.nn as nn
import torch.nn.functional as F
class EmbeddingNet(nn.Module):
"""4層CNN(Omniglot/miniImageNet定番)"""
def __init__(self, hid=64, out=64):
super().__init__()
def block(in_ch, out_ch):
return nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.net = nn.Sequential(
block(3, hid), block(hid, hid),
block(hid, hid), block(hid, out),
nn.Flatten(),
)
def forward(self, x):
return self.net(x)
def prototypical_loss(model, support_x, support_y, query_x, query_y, n_way):
"""1タスクのprotoloss"""
z_s = model(support_x) # (N*K, D)
z_q = model(query_x) # (N*Q, D)
# クラスごとに平均 → プロトタイプ (N, D)
prototypes = torch.stack([z_s[support_y == k].mean(0) for k in range(n_way)])
# 各クエリと各プロトタイプの距離(負の二乗距離 = logit)
dists = torch.cdist(z_q, prototypes)
logits = -dists
loss = F.cross_entropy(logits, query_y)
acc = (logits.argmax(1) == query_y).float().mean()
return loss, acc
def train_episode(model, opt, sampler, device='cuda'):
"""1エピソード(=1タスク)で更新"""
support_x, support_y, query_x, query_y, n_way = sampler.sample()
support_x, query_x = support_x.to(device), query_x.to(device)
support_y, query_y = support_y.to(device), query_y.to(device)
loss, acc = prototypical_loss(model, support_x, support_y, query_x, query_y, n_way)
opt.zero_grad(); loss.backward(); opt.step()
return loss.item(), acc.item()
Tip: 同じ埋め込みネットを使って
n_way を可変にできるのが ProtoNet の強み(5-way学習 → 10-way推論も可能)。
MAML (Model-Agnostic Meta-Learning)
「少数ステップの勾配降下で適応できる良い初期パラメータ θ」を学習する。
内側ループでタスク特化、外側ループで初期値を更新するメタ学習。
アルゴリズム
for メタ更新:
タスクをバッチサンプル
for 各タスク T_i:
θ_i' = θ - α ∇_θ L_{T_i}(f_θ, support) ← 内側ループ
θ ← θ - β ∇_θ Σ_i L_{T_i}(f_{θ_i'}, query) ← 外側ループ
実装スケッチ(1st-order MAML, FOMAML)
def maml_outer_step(model, opt, tasks, inner_lr=0.01, inner_steps=5):
meta_loss = 0
for support_x, support_y, query_x, query_y in tasks:
# === 内側ループ ===
fast_params = {n: p.clone() for n, p in model.named_parameters()}
for _ in range(inner_steps):
logits = functional_forward(model, fast_params, support_x)
loss = F.cross_entropy(logits, support_y)
grads = torch.autograd.grad(loss, fast_params.values(),
create_graph=True)
fast_params = {n: p - inner_lr * g
for (n, p), g in zip(fast_params.items(), grads)}
# === 外側: 適応後のパラメータでクエリ損失 ===
q_logits = functional_forward(model, fast_params, query_x)
meta_loss = meta_loss + F.cross_entropy(q_logits, query_y)
meta_loss = meta_loss / len(tasks)
opt.zero_grad(); meta_loss.backward(); opt.step()
実用的には learn2learn や higher ライブラリを使うと内側ループが綺麗に書けます。
MAML系の派生
| 手法 | 変更点 |
|---|---|
| FOMAML | 2階勾配を捨てて1階に近似 → 高速 |
| Reptile | 外側勾配を「適応後の重み − 初期重み」で近似 |
| ANIL | 内側ループでは最終層だけ更新 → 軽量化 |
| iMAML | implicit gradient で内側計算を分離 |
エピソード(タスク)サンプリング
Few-Shotは「学習時にも N-way K-shotのタスクをランダムに作る」のが特徴。データセット → タスク列に変換するサンプラーを書くのが要点。
import random
import torch
from torch.utils.data import Dataset
class EpisodeSampler:
def __init__(self, dataset_by_class, n_way=5, k_shot=1, q_query=15):
"""
dataset_by_class: {class_id: [tensor_image_1, tensor_image_2, ...]}
"""
self.data = dataset_by_class
self.classes = list(dataset_by_class.keys())
self.n_way, self.k_shot, self.q_query = n_way, k_shot, q_query
def sample(self):
chosen = random.sample(self.classes, self.n_way)
support_x, support_y, query_x, query_y = [], [], [], []
for label, cls in enumerate(chosen):
samples = random.sample(self.data[cls], self.k_shot + self.q_query)
for s in samples[:self.k_shot]:
support_x.append(s); support_y.append(label)
for s in samples[self.k_shot:]:
query_x.append(s); query_y.append(label)
return (torch.stack(support_x), torch.tensor(support_y),
torch.stack(query_x), torch.tensor(query_y),
self.n_way)
代表データセット
- Omniglot — 50言語の手書き文字、1623クラス × 20サンプル(Few-Shotの定番)
- miniImageNet — ImageNet から100クラス抽出、84×84
- tieredImageNet — 階層構造ありの大規模版
- Meta-Dataset — 10種類のドメイン横断ベンチマーク