モデル管理
学習目標: 学習済みモデルの保存・読込・比較・配布の方法を理解する
モデルの保存
state_dict だけを保存する(推奨)
パラメータだけ保存するのが安全・互換性が高い方法。コード側でモデルを再構築してから読み込む。
import torch
# 保存
torch.save(model.state_dict(), 'vae_model.pth')
# 読み込み(モデル定義は別途必要)
model = VAE(latent_dim=64) # 同じ構成で再構築
model.load_state_dict(torch.load('vae_model.pth'))
model.eval()
学習途中の状態(チェックポイント)を保存
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scheduler_state_dict': scheduler.state_dict(),
'loss': loss.item(),
'config': {'latent_dim': 64, 'lr': 1e-4},
}
torch.save(checkpoint, f'checkpoint_epoch{epoch:03d}.pth')
再開時はこう書く:
checkpoint = torch.load('checkpoint_epoch050.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
start_epoch = checkpoint['epoch'] + 1
for epoch in range(start_epoch, num_epochs):
train_one_epoch(...)
モデル全体を保存(非推奨だが楽)
クラス定義ごと pickle される。リファクタすると壊れやすいので長期保存には不向き。
torch.save(model, 'whole_model.pth') # ⚠ クラス定義に依存
model = torch.load('whole_model.pth')
本番向け: ONNX エクスポート
PyTorch 非依存で推論できる形式。TensorRT / ONNX Runtime で高速推論可。
dummy = torch.randn(1, 64) # 入力サンプル
torch.onnx.export(
model, dummy, 'vae.onnx',
input_names=['z'], output_names=['image'],
dynamic_axes={'z': {0: 'batch'}, 'image': {0: 'batch'}},
opset_version=17,
)
モデルの読み込み
推論モードで読み込む
model = VAE(latent_dim=64)
model.load_state_dict(torch.load('vae_model.pth', map_location='cpu'))
model.eval() # ← BatchNorm/Dropout を推論モードに
# サンプル生成
with torch.no_grad():
z = torch.randn(8, 64)
samples = model.decode(z)
map_location='cpu' は GPU で保存したモデルを CPU で読み込むときに必要。
指定しないと RuntimeError: CUDA is not available になる。
strict=False で不一致を許容
転移学習などで一部の層だけ流用する場合:
missing, unexpected = model.load_state_dict(
torch.load('partial.pth'), strict=False
)
print(f"missing keys: {missing}")
print(f"unexpected keys: {unexpected}")
Hugging Face Hub からダウンロード
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to("cuda")
image = pipe("a fluffy cat on a windowsill").images[0]
image.save("cat.png")
モデル間の比較
比較すべき主要メトリクス
| メトリクス | 意味 | 典型的な使い方 |
|---|---|---|
| FID (Fréchet Inception Distance) | 生成画像と本物画像の特徴分布の距離 | 低いほど良い。GANの定番指標 |
| IS (Inception Score) | 分類器の自信度 × 多様性 | 高いほど良い。FIDに置き換わりつつある |
| LPIPS | 知覚的距離(学習済み特徴の差) | 2枚の画像の類似度 |
| Reconstruction Loss | VAE等での再構成誤差 | MSE/BCE。低いほど良い |
| Inference Time | 1サンプル生成にかかる時間 | 本番運用ではこれが効く |
| Model Size | パラメータ数 (M) / ディスクサイズ (MB) | エッジ展開時に重要 |
パラメータ数のカウント
def count_params(model):
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
return total, trainable
total, trainable = count_params(model)
print(f"Total: {total/1e6:.2f}M, Trainable: {trainable/1e6:.2f}M")
推論速度の計測
import time
model.eval()
x = torch.randn(1, 3, 64, 64).to(device)
# ウォームアップ
for _ in range(10):
with torch.no_grad():
_ = model(x)
torch.cuda.synchronize()
# 計測
start = time.time()
N = 100
with torch.no_grad():
for _ in range(N):
_ = model(x)
torch.cuda.synchronize()
elapsed = (time.time() - start) / N
print(f"平均推論時間: {elapsed*1000:.2f} ms/sample")
配布・デプロイ
Hugging Face Hub にアップロード
from huggingface_hub import HfApi, login
login(token="hf_xxx...")
api = HfApi()
# モデルファイルをアップロード
api.upload_file(
path_or_fileobj="vae_model.pth",
path_in_repo="vae_model.pth",
repo_id="your-name/my-vae",
repo_type="model",
)
軽量化テクニック
| 手法 | 削減効果 | 備考 |
|---|---|---|
| FP16 (半精度) | サイズ・速度 約2倍 | model.half() または torch.autocast |
| INT8量子化 | サイズ 約4倍 | torch.quantization / TensorRT |
| 蒸留 | パラメータ 5〜10倍減 | 大モデル → 小モデルへの転移学習 |
| 枝刈り (Pruning) | パラメータ 2〜3倍減 | 重要度の低い重みを 0 にする |
| LoRA | 差分のみ数MB | 大モデルの微調整に最適 |
本番推論サーバの基本パターン (FastAPI)
from fastapi import FastAPI
from pydantic import BaseModel
import torch
app = FastAPI()
model = VAE(latent_dim=64)
model.load_state_dict(torch.load('vae_model.pth', map_location='cpu'))
model.eval()
class GenerateRequest(BaseModel):
n_samples: int = 1
seed: int | None = None
@app.post("/generate")
def generate(req: GenerateRequest):
if req.seed is not None:
torch.manual_seed(req.seed)
with torch.no_grad():
z = torch.randn(req.n_samples, 64)
samples = model.decode(z).cpu().numpy().tolist()
return {"samples": samples}
実運用ではバッチング・GPUプール・キャッシュ・タイムアウトを足す。
重い推論は celery 等の非同期キュー経由が安全。