モデル管理

学習目標: 学習済みモデルの保存・読込・比較・配布の方法を理解する

モデルの保存

state_dict だけを保存する(推奨)

パラメータだけ保存するのが安全・互換性が高い方法。コード側でモデルを再構築してから読み込む。

import torch

# 保存
torch.save(model.state_dict(), 'vae_model.pth')

# 読み込み(モデル定義は別途必要)
model = VAE(latent_dim=64)              # 同じ構成で再構築
model.load_state_dict(torch.load('vae_model.pth'))
model.eval()

学習途中の状態(チェックポイント)を保存

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss.item(),
    'config': {'latent_dim': 64, 'lr': 1e-4},
}
torch.save(checkpoint, f'checkpoint_epoch{epoch:03d}.pth')

再開時はこう書く:

checkpoint = torch.load('checkpoint_epoch050.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1

for epoch in range(start_epoch, num_epochs):
    train_one_epoch(...)

モデル全体を保存(非推奨だが楽)

クラス定義ごと pickle される。リファクタすると壊れやすいので長期保存には不向き。

torch.save(model, 'whole_model.pth')   # ⚠ クラス定義に依存
model = torch.load('whole_model.pth')

本番向け: ONNX エクスポート

PyTorch 非依存で推論できる形式。TensorRT / ONNX Runtime で高速推論可。

dummy = torch.randn(1, 64)              # 入力サンプル
torch.onnx.export(
    model, dummy, 'vae.onnx',
    input_names=['z'], output_names=['image'],
    dynamic_axes={'z': {0: 'batch'}, 'image': {0: 'batch'}},
    opset_version=17,
)

モデルの読み込み

推論モードで読み込む

model = VAE(latent_dim=64)
model.load_state_dict(torch.load('vae_model.pth', map_location='cpu'))
model.eval()                # ← BatchNorm/Dropout を推論モードに

# サンプル生成
with torch.no_grad():
    z = torch.randn(8, 64)
    samples = model.decode(z)

map_location='cpu' は GPU で保存したモデルを CPU で読み込むときに必要。 指定しないと RuntimeError: CUDA is not available になる。

strict=False で不一致を許容

転移学習などで一部の層だけ流用する場合:

missing, unexpected = model.load_state_dict(
    torch.load('partial.pth'), strict=False
)
print(f"missing keys:    {missing}")
print(f"unexpected keys: {unexpected}")

Hugging Face Hub からダウンロード

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a fluffy cat on a windowsill").images[0]
image.save("cat.png")

モデル間の比較

比較すべき主要メトリクス

メトリクス意味典型的な使い方
FID (Fréchet Inception Distance) 生成画像と本物画像の特徴分布の距離 低いほど良い。GANの定番指標
IS (Inception Score) 分類器の自信度 × 多様性 高いほど良い。FIDに置き換わりつつある
LPIPS 知覚的距離(学習済み特徴の差) 2枚の画像の類似度
Reconstruction Loss VAE等での再構成誤差 MSE/BCE。低いほど良い
Inference Time 1サンプル生成にかかる時間 本番運用ではこれが効く
Model Size パラメータ数 (M) / ディスクサイズ (MB) エッジ展開時に重要

パラメータ数のカウント

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable

total, trainable = count_params(model)
print(f"Total: {total/1e6:.2f}M, Trainable: {trainable/1e6:.2f}M")

推論速度の計測

import time
model.eval()
x = torch.randn(1, 3, 64, 64).to(device)

# ウォームアップ
for _ in range(10):
    with torch.no_grad():
        _ = model(x)
torch.cuda.synchronize()

# 計測
start = time.time()
N = 100
with torch.no_grad():
    for _ in range(N):
        _ = model(x)
torch.cuda.synchronize()
elapsed = (time.time() - start) / N
print(f"平均推論時間: {elapsed*1000:.2f} ms/sample")

配布・デプロイ

Hugging Face Hub にアップロード

from huggingface_hub import HfApi, login

login(token="hf_xxx...")
api = HfApi()

# モデルファイルをアップロード
api.upload_file(
    path_or_fileobj="vae_model.pth",
    path_in_repo="vae_model.pth",
    repo_id="your-name/my-vae",
    repo_type="model",
)

軽量化テクニック

手法削減効果備考
FP16 (半精度)サイズ・速度 約2倍model.half() または torch.autocast
INT8量子化サイズ 約4倍torch.quantization / TensorRT
蒸留パラメータ 5〜10倍減大モデル → 小モデルへの転移学習
枝刈り (Pruning)パラメータ 2〜3倍減重要度の低い重みを 0 にする
LoRA差分のみ数MB大モデルの微調整に最適

本番推論サーバの基本パターン (FastAPI)

from fastapi import FastAPI
from pydantic import BaseModel
import torch

app = FastAPI()
model = VAE(latent_dim=64)
model.load_state_dict(torch.load('vae_model.pth', map_location='cpu'))
model.eval()

class GenerateRequest(BaseModel):
    n_samples: int = 1
    seed: int | None = None

@app.post("/generate")
def generate(req: GenerateRequest):
    if req.seed is not None:
        torch.manual_seed(req.seed)
    with torch.no_grad():
        z = torch.randn(req.n_samples, 64)
        samples = model.decode(z).cpu().numpy().tolist()
    return {"samples": samples}

実運用ではバッチング・GPUプール・キャッシュ・タイムアウトを足す。 重い推論は celery 等の非同期キュー経由が安全。