モデル最適化
学習目標: 量子化・プルーニング・蒸留・コンパイルを使ってモデルを軽量・高速化する手法を理解する
最適化のトレードオフ
| 手法 | サイズ削減 | 速度向上 | 精度低下 | 難易度 |
|---|---|---|---|---|
| FP16 | ×2 | ×2 | ほぼ無し | 易 |
| INT8 量子化 (Post-training) | ×4 | ×2〜4 | 1〜2% | 易 |
| INT8 量子化 (QAT) | ×4 | ×2〜4 | ほぼ無し | 中 |
| 構造化プルーニング | ×2〜5 | ×2〜5 | 2〜5% | 中 |
| 非構造化プルーニング | ×10+ | 変わらず | 1〜3% | 易 |
| 知識蒸留 | ×5〜10 | ×5〜10 | 1〜5% | 中 |
| torch.compile | 無し | ×1.5〜3 | 無し | 易 |
最適化の順序(推奨)
- torch.compile を試す(コード変更なし、効果大)
- FP16/BF16 でAutoCast(autocast +GradScaler)
- 蒸留 でアーキ縮小(最大の効果)
- 量子化 (INT8) で推論側を軽量化
- プルーニング でさらに削減
- ONNX/TensorRT でデプロイ最適化
量子化 (Quantization)
FP32重みをINT8等の低ビット表現に変換し、サイズ削減と高速化を同時に得る。
Dynamic Quantization (最も簡単)
重みのみ量子化、活性は実行時に動的量子化。CPU推論で効く。
import torch
import torch.nn as nn
import torch.quantization
# 学習済みモデル
model.eval()
# 線形層・LSTM層を INT8 量子化
quantized = torch.quantization.quantize_dynamic(
model,
{nn.Linear, nn.LSTM},
dtype=torch.qint8,
)
# サイズ確認
def model_size_mb(m):
torch.save(m.state_dict(), '/tmp/_tmp.p')
import os; size = os.path.getsize('/tmp/_tmp.p') / 1e6
os.remove('/tmp/_tmp.p'); return size
print(f"FP32: {model_size_mb(model):.2f} MB")
print(f"INT8: {model_size_mb(quantized):.2f} MB")
Static Quantization (PTQ: Post-Training)
キャリブレーションデータで活性のスケールも事前計算。さらに高速。
model.eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm') # x86
# 量子化準備: 観測器を挿入
prepared = torch.quantization.prepare(model, inplace=False)
# キャリブレーション(代表データを流す)
with torch.no_grad():
for x, _ in calibration_loader:
prepared(x)
# 量子化を確定
quantized = torch.quantization.convert(prepared, inplace=False)
Quantization-Aware Training (QAT)
学習時から量子化を意識。精度劣化が最小。
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model.train()
qat_model = torch.quantization.prepare_qat(model, inplace=False)
# 通常通りファインチューニング
for epoch in range(num_epochs):
for x, y in train_loader:
out = qat_model(x)
loss = criterion(out, y)
loss.backward(); optimizer.step(); optimizer.zero_grad()
qat_model.eval()
quantized = torch.quantization.convert(qat_model)
プルーニング (Pruning)
重要度の低い重みを 0 にしてスパース化。ストレージ削減・推論加速。
非構造化プルーニング (Magnitude Pruning)
import torch.nn.utils.prune as prune
# 1層だけ: 重みの絶対値が小さい上位20%を 0 にする
prune.l1_unstructured(model.fc1, name='weight', amount=0.2)
# モデル全体に一括適用
parameters_to_prune = [
(m, 'weight') for m in model.modules()
if isinstance(m, (nn.Linear, nn.Conv2d))
]
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=0.3, # 全体の30%を削除
)
# 再学習でリカバリ(重要)
fine_tune(model, train_loader, epochs=5)
# pruning を恒久化(マスクを削除)
for m, name in parameters_to_prune:
prune.remove(m, name)
構造化プルーニング (Channel Pruning)
チャネル単位で削除 → 実際にFLOPsが減って速くなる。
# Conv層の出力チャネルのうち、L2ノルムが小さい25%を削除
prune.ln_structured(
model.conv1, name='weight', amount=0.25, n=2, dim=0,
)
Iterative Magnitude Pruning (IMP)
段階的にプルーニングしてリカバリを繰り返す。少しずつ削るほど精度が保たれる。
target_sparsity, current = 0.9, 0.0
step = 0.1
while current < target_sparsity:
current = min(current + step, target_sparsity)
prune.global_unstructured(parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=current)
fine_tune(model, train_loader, epochs=2)
print(f"sparsity={current}, val_acc={evaluate(model, val_loader):.3f}")
スパース密度の計測
def sparsity(model):
total, zero = 0, 0
for p in model.parameters():
total += p.numel()
zero += (p == 0).sum().item()
return zero / total
print(f"Sparsity: {sparsity(model):.1%}")
コンパイルとエクスポート
torch.compile (PyTorch 2.x)
コード変更なしで 1.5〜3 倍高速化。グラフモード最適化。
import torch
# モデルをコンパイル(学習・推論両方で使える)
compiled = torch.compile(model, mode='reduce-overhead')
# mode の選択肢:
# 'default' : バランス
# 'reduce-overhead' : 小モデルで効く
# 'max-autotune' : 最高速、コンパイル時間長
TorchScript (グラフ凍結)
# Tracing: 入力例から推論パスをトレース
example_input = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(model, example_input)
traced.save('model.pt')
# 別環境で読み込み(Python不要、C++でも動く)
loaded = torch.jit.load('model.pt')
ONNX エクスポート
torch.onnx.export(
model, example_input, 'model.onnx',
input_names=['input'], output_names=['output'],
dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}},
opset_version=17,
)
# 検証(onnx-runtime)
import onnxruntime as ort
sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
output = sess.run(None, {'input': example_input.numpy()})
TensorRT (NVIDIA GPU 推論)
# ONNX → TensorRT (CLI)
trtexec --onnx=model.onnx --saveEngine=model.trt \
--fp16 --workspace=4096
# 計測: 同じ入力で平均レイテンシ
trtexec --loadEngine=model.trt --shapes=input:1x3x224x224 \
--iterations=100
エコシステム概観
| ターゲット | 推奨パス |
|---|---|
| NVIDIA GPU | PyTorch → TensorRT (FP16/INT8) |
| CPU 推論 | ONNX Runtime / OpenVINO + INT8 |
| モバイル (Android/iOS) | PyTorch Mobile / Core ML |
| ブラウザ | ONNX Runtime Web / WebGPU |
| エッジ (Raspberry Pi 等) | TFLite / ONNX Runtime (ARM) |