モデル比較

学習目標: 複数モデルを公平に比較するためのメトリクス・計測方法・統計的検定を理解する

主要メトリクス

分類タスク

from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report,
)

acc  = accuracy_score(y_true, y_pred)
f1   = f1_score(y_true, y_pred, average='macro')
# 確率スコアで AUC
auc  = roc_auc_score(y_true, y_proba, multi_class='ovr')

print(classification_report(y_true, y_pred, digits=4))
print(confusion_matrix(y_true, y_pred))

回帰タスク

from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import numpy as np

mse  = mean_squared_error(y_true, y_pred)
rmse = np.sqrt(mse)
mae  = mean_absolute_error(y_true, y_pred)
r2   = r2_score(y_true, y_pred)
mape = np.mean(np.abs((y_true - y_pred) / np.maximum(np.abs(y_true), 1e-8)))

生成タスク

指標用途
FID画像生成: 本物と生成の特徴分布距離
IS画像生成: 自信度 × 多様性
LPIPS画像: 知覚距離
BLEU / ROUGE機械翻訳・要約
Perplexity言語モデル

不均衡データの注意

正解クラスが偏っていると Accuracy は誤解を招きます。F1 (macro) / AUROC / Balanced Accuracy を優先しましょう。

速度・サイズの計測

パラメータ数

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

t, tr = count_params(model)
print(f"Total: {t/1e6:.2f}M, Trainable: {tr/1e6:.2f}M")

FLOPs(理論演算量)

# pip install fvcore
from fvcore.nn import FlopCountAnalysis

flops = FlopCountAnalysis(model, dummy_input)
print(f"FLOPs: {flops.total() / 1e9:.2f} G")

推論レイテンシ(GPU)

import torch, time

def benchmark(model, x, n_warm=20, n_iter=100):
    model.eval()
    with torch.no_grad():
        # ウォームアップ
        for _ in range(n_warm):
            _ = model(x)
        torch.cuda.synchronize()

        start = time.time()
        for _ in range(n_iter):
            _ = model(x)
        torch.cuda.synchronize()
        elapsed = (time.time() - start) / n_iter
    return elapsed * 1000          # ms

x = torch.randn(1, 3, 224, 224, device='cuda')
print(f"Latency: {benchmark(model.cuda(), x):.2f} ms")
print(f"Throughput: {1000 / benchmark(model.cuda(), x):.1f} samples/sec")

メモリ使用量

torch.cuda.reset_peak_memory_stats()
out = model(x)
peak = torch.cuda.max_memory_allocated() / 1e6
print(f"Peak GPU memory: {peak:.1f} MB")

統計的有意性

2モデルのスコア差が「偶然か実力か」を区別する。シードを変えて複数回学習し、検定する。

対応のあるt検定(Paired t-test)

import numpy as np
from scipy import stats

# モデルA, B を10シードで評価
scores_a = [0.842, 0.839, 0.845, 0.841, 0.843,
            0.840, 0.844, 0.838, 0.842, 0.846]
scores_b = [0.851, 0.848, 0.854, 0.850, 0.852,
            0.849, 0.853, 0.847, 0.851, 0.855]

t, p = stats.ttest_rel(scores_b, scores_a)
print(f"diff = {np.mean(scores_b)-np.mean(scores_a):.4f}")
print(f"t = {t:.3f}, p = {p:.4f}")
# p < 0.05 なら統計的に有意な差

ブートストラップ信頼区間

def bootstrap_ci(scores, n_boot=10000, ci=0.95):
    means = []
    n = len(scores)
    for _ in range(n_boot):
        sample = np.random.choice(scores, size=n, replace=True)
        means.append(sample.mean())
    lower = np.percentile(means, (1 - ci) / 2 * 100)
    upper = np.percentile(means, (1 + ci) / 2 * 100)
    return np.mean(scores), (lower, upper)

mean, (lo, hi) = bootstrap_ci(scores_a)
print(f"acc = {mean:.4f}, 95% CI = [{lo:.4f}, {hi:.4f}]")
シード数の目安: 最低3〜5シード、可能なら10シード。少ないシードで「Aがベストです」と言うのは禁物。

比較表テンプレート

論文・実験レポートで使える比較表のテンプレ。精度 ± 標準偏差 + 計算コストを必ず併記。

モデル Params (M) FLOPs (G) Latency (ms) Accuracy vs baseline (p値)
Mean Std
ResNet-18 (baseline) 11.71.84.2 91.5±0.3
ResNet-50 25.64.18.7 92.8±0.20.003 ★
EfficientNet-B0 5.30.45.5 92.1±0.30.018 ★
ViT-Small 22.04.69.1 93.2±0.40.001 ★★
ConvNeXt-Tiny 28.64.59.5 93.5±0.20.001 ★★

★ p < 0.05、★★ p < 0.01。Latencyは 224×224, batch=1, A100 GPU。

表に含めるべき情報

  • パラメータ数(M)と FLOPs(G)
  • 精度の 平均 ± 標準偏差 (複数シード)
  • 推論レイテンシ(環境を明記)
  • ベースラインとの統計的有意差
  • 学習コスト(任意、GPU時間)

Pareto frontier の可視化

計算コスト × 精度 の散布図でモデルを並べると、「より小さく・より速く・より高精度」のフロンティアが見える。

import matplotlib.pyplot as plt

models = [
    ('ResNet-18', 1.8, 91.5),
    ('ResNet-50', 4.1, 92.8),
    ('EfficientNet-B0', 0.4, 92.1),
    ('ViT-Small', 4.6, 93.2),
    ('ConvNeXt-Tiny', 4.5, 93.5),
]
for name, flops, acc in models:
    plt.scatter(flops, acc, s=80)
    plt.text(flops + 0.05, acc, name, fontsize=9)
plt.xlabel('FLOPs (G)'); plt.ylabel('Accuracy (%)')
plt.title('精度 vs 計算量')
plt.grid(True, alpha=0.3); plt.show()