説明可能AI (XAI)
学習目標: モデルの予測理由を可視化する代表的な手法(SHAP / LIME / Captum / Grad-CAM)を理解し、PyTorchで実装できるようになる
なぜモデルの説明性が必要か
- 規制対応: GDPR の「自動意思決定の説明を受ける権利」、医療・金融の業界規制
- 信頼性: 業務担当者がモデルの予測を採用するか判断するため
- デバッグ: スプリアス特徴(学習データのバイアス)に依存していないか確認
- 科学的発見: モデルが見つけたパターンから新しい知見を得る
XAI手法の分類
| 分類軸 | カテゴリ | 例 |
|---|---|---|
| スコープ | Global(モデル全体) | Permutation Importance, SHAP global |
| Local(1サンプル) | LIME, SHAP local, Grad-CAM | |
| モデル依存 | Model-agnostic(モデル非依存) | SHAP, LIME |
| Model-specific(モデル依存) | Grad-CAM (CNN), Attention可視化 (Transformer) |
SHAP (SHapley Additive exPlanations)
ゲーム理論の Shapley値 を機械学習の特徴量寄与に応用。各特徴が予測値にどれだけ貢献したかを公平に分解する。
基本概念
予測値 = ベース値 (E[f(X)]) + Σ shap_value(feature_i) 各 shap_value は「その特徴を加えると平均的に予測がどれだけ動いたか」
PyTorchモデル向け実装
import shap
import torch
# 学習済みモデル(評価モード)
model.eval()
# 背景データ(参照分布)
background = X_train[:100] # 100サンプル
# DeepExplainer はディープラーニング向け
explainer = shap.DeepExplainer(model, background)
# テストサンプルでSHAP値を計算
shap_values = explainer.shap_values(X_test[:10])
# 可視化(特徴重要度の summary plot)
shap.summary_plot(shap_values, X_test[:10], feature_names=feature_names)
# 1サンプルの内訳
shap.force_plot(explainer.expected_value, shap_values[0], X_test[0])
画像モデル向け(GradientExplainer)
# 画像分類モデルの場合
explainer = shap.GradientExplainer(model, background_images)
shap_values = explainer.shap_values(test_images)
# 各クラスについてSHAPマップを画像オーバーレイ表示
shap.image_plot(shap_values, test_images.numpy())
注意: SHAPはサンプル数 × 特徴数のオーダーで計算コストが大きい。
テキスト/画像分類でも近似版(
PartitionExplainer)を使うのが現実的。
LIME (Local Interpretable Model-agnostic Explanations)
注目サンプル周辺で複雑なモデルをシンプルな線形モデルで局所近似することで、ローカルな説明を得る。
アルゴリズム
- 説明したいサンプル
xの周りに摂動サンプルを大量生成 - 各摂動について元モデルで予測
- 摂動と
xの距離を重みにして、線形モデル(LASSO等)を学習 - 線形モデルの係数が「特徴ごとの寄与」
テーブル/画像/テキストへの適用
from lime.lime_tabular import LimeTabularExplainer
from lime.lime_image import LimeImageExplainer
from lime.lime_text import LimeTextExplainer
# === Tabular ===
explainer = LimeTabularExplainer(
X_train.values,
feature_names=feature_names,
class_names=['neg', 'pos'],
mode='classification',
)
exp = explainer.explain_instance(X_test.iloc[0].values,
model.predict_proba,
num_features=10)
exp.show_in_notebook()
# === Image ===
def predict_fn(images):
# numpy (N, H, W, 3) → torch → softmax確率
x = torch.tensor(images).permute(0, 3, 1, 2).float() / 255
with torch.no_grad():
return model(x).softmax(1).numpy()
explainer = LimeImageExplainer()
exp = explainer.explain_instance(image_np, predict_fn,
top_labels=3, num_samples=1000)
img, mask = exp.get_image_and_mask(exp.top_labels[0], positive_only=True)
Captum (PyTorch公式XAIライブラリ)
PyTorch生まれの統合XAIライブラリ。Integrated Gradients / Saliency / GradCAM / DeepLift など多数の手法を統一APIで使える。
Integrated Gradients (画像分類)
from captum.attr import IntegratedGradients, NoiseTunnel
from captum.attr import visualization as viz
model.eval()
ig = IntegratedGradients(model)
# 帰属マップを計算
attributions, delta = ig.attribute(
input_tensor, # (1, 3, H, W)
target=predicted_class,
return_convergence_delta=True,
n_steps=50, # 補間ステップ数
)
# 可視化(元画像にヒートマップオーバーレイ)
viz.visualize_image_attr(
attributions.squeeze().cpu().permute(1, 2, 0).numpy(),
original_image=input_tensor.squeeze().cpu().permute(1, 2, 0).numpy(),
method='blended_heat_map',
sign='positive',
)
Grad-CAM (CNN専用)
from captum.attr import LayerGradCam, LayerAttribution
# 最後のConv層を指定
target_layer = model.layer4[-1]
gradcam = LayerGradCam(model, target_layer)
attrs = gradcam.attribute(input_tensor, target=predicted_class)
# 入力サイズにアップサンプル
upsampled = LayerAttribution.interpolate(attrs, input_tensor.shape[-2:])
heatmap = upsampled.squeeze().cpu().numpy()
Captum 主要メソッド早見表
| クラス | 用途 | 計算量 |
|---|---|---|
Saliency | 勾配の絶対値 | 1回 forward+backward |
IntegratedGradients | 積分勾配(理論的に堅牢) | n_steps回 |
DeepLift | 参照入力からの差分 | 1〜2回 |
GradientShap | SHAPの勾配版 | n_samples回 |
LayerGradCam | CNN向け、特定層のヒートマップ | 1回 |
LayerActivation | 中間層の活性 | 1回 forward |
パイプライン: 学習 → Captumで複数手法を試す → SHAPで定量化 → LIMEで意外なサンプルを掘り下げる、という流れが実務で多いです。