モデル最適化

学習目標: 量子化・プルーニング・蒸留・コンパイルを使ってモデルを軽量・高速化する手法を理解する

最適化のトレードオフ

手法サイズ削減速度向上精度低下難易度
FP16×2×2ほぼ無し
INT8 量子化 (Post-training)×4×2〜41〜2%
INT8 量子化 (QAT)×4×2〜4ほぼ無し
構造化プルーニング×2〜5×2〜52〜5%
非構造化プルーニング×10+変わらず1〜3%
知識蒸留×5〜10×5〜101〜5%
torch.compile無し×1.5〜3無し

最適化の順序(推奨)

  1. torch.compile を試す(コード変更なし、効果大)
  2. FP16/BF16 でAutoCast(autocast +GradScaler)
  3. 蒸留 でアーキ縮小(最大の効果)
  4. 量子化 (INT8) で推論側を軽量化
  5. プルーニング でさらに削減
  6. ONNX/TensorRT でデプロイ最適化

量子化 (Quantization)

FP32重みをINT8等の低ビット表現に変換し、サイズ削減と高速化を同時に得る。

Dynamic Quantization (最も簡単)

重みのみ量子化、活性は実行時に動的量子化。CPU推論で効く。

import torch
import torch.nn as nn
import torch.quantization

# 学習済みモデル
model.eval()

# 線形層・LSTM層を INT8 量子化
quantized = torch.quantization.quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},
    dtype=torch.qint8,
)

# サイズ確認
def model_size_mb(m):
    torch.save(m.state_dict(), '/tmp/_tmp.p')
    import os; size = os.path.getsize('/tmp/_tmp.p') / 1e6
    os.remove('/tmp/_tmp.p'); return size

print(f"FP32: {model_size_mb(model):.2f} MB")
print(f"INT8: {model_size_mb(quantized):.2f} MB")

Static Quantization (PTQ: Post-Training)

キャリブレーションデータで活性のスケールも事前計算。さらに高速。

model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')   # x86

# 量子化準備: 観測器を挿入
prepared = torch.quantization.prepare(model, inplace=False)

# キャリブレーション(代表データを流す)
with torch.no_grad():
    for x, _ in calibration_loader:
        prepared(x)

# 量子化を確定
quantized = torch.quantization.convert(prepared, inplace=False)

Quantization-Aware Training (QAT)

学習時から量子化を意識。精度劣化が最小。

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model.train()
qat_model = torch.quantization.prepare_qat(model, inplace=False)

# 通常通りファインチューニング
for epoch in range(num_epochs):
    for x, y in train_loader:
        out = qat_model(x)
        loss = criterion(out, y)
        loss.backward(); optimizer.step(); optimizer.zero_grad()

qat_model.eval()
quantized = torch.quantization.convert(qat_model)

プルーニング (Pruning)

重要度の低い重みを 0 にしてスパース化。ストレージ削減・推論加速。

非構造化プルーニング (Magnitude Pruning)

import torch.nn.utils.prune as prune

# 1層だけ: 重みの絶対値が小さい上位20%を 0 にする
prune.l1_unstructured(model.fc1, name='weight', amount=0.2)

# モデル全体に一括適用
parameters_to_prune = [
    (m, 'weight') for m in model.modules()
    if isinstance(m, (nn.Linear, nn.Conv2d))
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.3,                    # 全体の30%を削除
)

# 再学習でリカバリ(重要)
fine_tune(model, train_loader, epochs=5)

# pruning を恒久化(マスクを削除)
for m, name in parameters_to_prune:
    prune.remove(m, name)

構造化プルーニング (Channel Pruning)

チャネル単位で削除 → 実際にFLOPsが減って速くなる。

# Conv層の出力チャネルのうち、L2ノルムが小さい25%を削除
prune.ln_structured(
    model.conv1, name='weight', amount=0.25, n=2, dim=0,
)

Iterative Magnitude Pruning (IMP)

段階的にプルーニングしてリカバリを繰り返す。少しずつ削るほど精度が保たれる。

target_sparsity, current = 0.9, 0.0
step = 0.1

while current < target_sparsity:
    current = min(current + step, target_sparsity)
    prune.global_unstructured(parameters_to_prune,
                              pruning_method=prune.L1Unstructured,
                              amount=current)
    fine_tune(model, train_loader, epochs=2)
    print(f"sparsity={current}, val_acc={evaluate(model, val_loader):.3f}")

スパース密度の計測

def sparsity(model):
    total, zero = 0, 0
    for p in model.parameters():
        total += p.numel()
        zero += (p == 0).sum().item()
    return zero / total

print(f"Sparsity: {sparsity(model):.1%}")

コンパイルとエクスポート

torch.compile (PyTorch 2.x)

コード変更なしで 1.5〜3 倍高速化。グラフモード最適化。

import torch

# モデルをコンパイル(学習・推論両方で使える)
compiled = torch.compile(model, mode='reduce-overhead')

# mode の選択肢:
#   'default'           : バランス
#   'reduce-overhead'   : 小モデルで効く
#   'max-autotune'      : 最高速、コンパイル時間長

TorchScript (グラフ凍結)

# Tracing: 入力例から推論パスをトレース
example_input = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example_input)
traced.save('model.pt')

# 別環境で読み込み(Python不要、C++でも動く)
loaded = torch.jit.load('model.pt')

ONNX エクスポート

torch.onnx.export(
    model, example_input, 'model.onnx',
    input_names=['input'], output_names=['output'],
    dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
    opset_version=17,
)

# 検証(onnx-runtime)
import onnxruntime as ort
sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
output = sess.run(None, {'input': example_input.numpy()})

TensorRT (NVIDIA GPU 推論)

# ONNX → TensorRT (CLI)
trtexec --onnx=model.onnx --saveEngine=model.trt \
        --fp16 --workspace=4096

# 計測: 同じ入力で平均レイテンシ
trtexec --loadEngine=model.trt --shapes=input:1x3x224x224 \
        --iterations=100

エコシステム概観

ターゲット推奨パス
NVIDIA GPUPyTorch → TensorRT (FP16/INT8)
CPU 推論ONNX Runtime / OpenVINO + INT8
モバイル (Android/iOS)PyTorch Mobile / Core ML
ブラウザONNX Runtime Web / WebGPU
エッジ (Raspberry Pi 等)TFLite / ONNX Runtime (ARM)