モデル比較
学習目標: 複数モデルを公平に比較するためのメトリクス・計測方法・統計的検定を理解する
主要メトリクス
分類タスク
from sklearn.metrics import (
accuracy_score, f1_score, roc_auc_score,
confusion_matrix, classification_report,
)
acc = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')
# 確率スコアで AUC
auc = roc_auc_score(y_true, y_proba, multi_class='ovr')
print(classification_report(y_true, y_pred, digits=4))
print(confusion_matrix(y_true, y_pred))
回帰タスク
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np
mse = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)
mape = np.mean(np.abs((y_true - y_pred) / np.maximum(np.abs(y_true), 1e-8)))
生成タスク
| 指標 | 用途 |
|---|---|
| FID | 画像生成: 本物と生成の特徴分布距離 |
| IS | 画像生成: 自信度 × 多様性 |
| LPIPS | 画像: 知覚距離 |
| BLEU / ROUGE | 機械翻訳・要約 |
| Perplexity | 言語モデル |
不均衡データの注意
正解クラスが偏っていると Accuracy は誤解を招きます。F1 (macro) / AUROC / Balanced Accuracy を優先しましょう。
速度・サイズの計測
パラメータ数
def count_params(model):
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
return total, trainable
t, tr = count_params(model)
print(f"Total: {t/1e6:.2f}M, Trainable: {tr/1e6:.2f}M")
FLOPs(理論演算量)
# pip install fvcore
from fvcore.nn import FlopCountAnalysis
flops = FlopCountAnalysis(model, dummy_input)
print(f"FLOPs: {flops.total() / 1e9:.2f} G")
推論レイテンシ(GPU)
import torch, time
def benchmark(model, x, n_warm=20, n_iter=100):
model.eval()
with torch.no_grad():
# ウォームアップ
for _ in range(n_warm):
_ = model(x)
torch.cuda.synchronize()
start = time.time()
for _ in range(n_iter):
_ = model(x)
torch.cuda.synchronize()
elapsed = (time.time() - start) / n_iter
return elapsed * 1000 # ms
x = torch.randn(1, 3, 224, 224, device='cuda')
print(f"Latency: {benchmark(model.cuda(), x):.2f} ms")
print(f"Throughput: {1000 / benchmark(model.cuda(), x):.1f} samples/sec")
メモリ使用量
torch.cuda.reset_peak_memory_stats()
out = model(x)
peak = torch.cuda.max_memory_allocated() / 1e6
print(f"Peak GPU memory: {peak:.1f} MB")
統計的有意性
2モデルのスコア差が「偶然か実力か」を区別する。シードを変えて複数回学習し、検定する。
対応のあるt検定(Paired t-test)
import numpy as np
from scipy import stats
# モデルA, B を10シードで評価
scores_a = [0.842, 0.839, 0.845, 0.841, 0.843,
0.840, 0.844, 0.838, 0.842, 0.846]
scores_b = [0.851, 0.848, 0.854, 0.850, 0.852,
0.849, 0.853, 0.847, 0.851, 0.855]
t, p = stats.ttest_rel(scores_b, scores_a)
print(f"diff = {np.mean(scores_b)-np.mean(scores_a):.4f}")
print(f"t = {t:.3f}, p = {p:.4f}")
# p < 0.05 なら統計的に有意な差
ブートストラップ信頼区間
def bootstrap_ci(scores, n_boot=10000, ci=0.95):
means = []
n = len(scores)
for _ in range(n_boot):
sample = np.random.choice(scores, size=n, replace=True)
means.append(sample.mean())
lower = np.percentile(means, (1 - ci) / 2 * 100)
upper = np.percentile(means, (1 + ci) / 2 * 100)
return np.mean(scores), (lower, upper)
mean, (lo, hi) = bootstrap_ci(scores_a)
print(f"acc = {mean:.4f}, 95% CI = [{lo:.4f}, {hi:.4f}]")
シード数の目安: 最低3〜5シード、可能なら10シード。少ないシードで「Aがベストです」と言うのは禁物。
比較表テンプレート
論文・実験レポートで使える比較表のテンプレ。精度 ± 標準偏差 + 計算コストを必ず併記。
| モデル | Params (M) | FLOPs (G) | Latency (ms) | Accuracy | vs baseline (p値) | |
|---|---|---|---|---|---|---|
| Mean | Std | |||||
| ResNet-18 (baseline) | 11.7 | 1.8 | 4.2 | 91.5 | ±0.3 | — |
| ResNet-50 | 25.6 | 4.1 | 8.7 | 92.8 | ±0.2 | 0.003 ★ |
| EfficientNet-B0 | 5.3 | 0.4 | 5.5 | 92.1 | ±0.3 | 0.018 ★ |
| ViT-Small | 22.0 | 4.6 | 9.1 | 93.2 | ±0.4 | 0.001 ★★ |
| ConvNeXt-Tiny | 28.6 | 4.5 | 9.5 | 93.5 | ±0.2 | 0.001 ★★ |
★ p < 0.05、★★ p < 0.01。Latencyは 224×224, batch=1, A100 GPU。
表に含めるべき情報
- パラメータ数(M)と FLOPs(G)
- 精度の 平均 ± 標準偏差 (複数シード)
- 推論レイテンシ(環境を明記)
- ベースラインとの統計的有意差
- 学習コスト(任意、GPU時間)
Pareto frontier の可視化
計算コスト × 精度 の散布図でモデルを並べると、「より小さく・より速く・より高精度」のフロンティアが見える。
import matplotlib.pyplot as plt
models = [
('ResNet-18', 1.8, 91.5),
('ResNet-50', 4.1, 92.8),
('EfficientNet-B0', 0.4, 92.1),
('ViT-Small', 4.6, 93.2),
('ConvNeXt-Tiny', 4.5, 93.5),
]
for name, flops, acc in models:
plt.scatter(flops, acc, s=80)
plt.text(flops + 0.05, acc, name, fontsize=9)
plt.xlabel('FLOPs (G)'); plt.ylabel('Accuracy (%)')
plt.title('精度 vs 計算量')
plt.grid(True, alpha=0.3); plt.show()