NN設計のベストプラクティス

学習目標: 入力設計・層構成・初期化・正規化・スキップ接続など、性能を左右する設計選択を体系的に押さえる

入力データの設計

特徴量の前処理パターン

特徴量タイプ処理
連続値(正規分布的)標準化 (x - μ) / σ
連続値(偏った分布)log変換 → 標準化
カテゴリ(少数)One-hot encoding
カテゴリ(多数)Embedding層
順序カテゴリ整数化 + Embedding or 線形エンコード
テキストTokenize → Embedding
画像Resize → Normalize (mean/std)

Embedding 設計の経験則

embedding_dim ≈ min(50, (cardinality + 1) ** 0.25 * 6)

例:
  cardinality=10    → dim=12
  cardinality=100   → dim=20
  cardinality=1000  → dim=33
  cardinality=10000 → dim=50
class TabularModel(nn.Module):
    """カテゴリ + 連続値を扱う汎用モデル"""
    def __init__(self, cat_dims, n_cont, hidden=128, n_out=1):
        super().__init__()
        # 各カテゴリ列の embedding
        self.embeds = nn.ModuleList([
            nn.Embedding(c, min(50, (c+1)**0.25 * 6 + 0.5 // 1))
            for c in cat_dims
        ])
        emb_total = sum(e.embedding_dim for e in self.embeds)

        self.bn_cont = nn.BatchNorm1d(n_cont)
        self.mlp = nn.Sequential(
            nn.Linear(emb_total + n_cont, hidden), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(hidden, hidden), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(hidden, n_out),
        )

    def forward(self, x_cat, x_cont):
        x_emb = torch.cat([e(x_cat[:, i]) for i, e in enumerate(self.embeds)], dim=1)
        x_cont = self.bn_cont(x_cont)
        return self.mlp(torch.cat([x_emb, x_cont], dim=1))

タスクに合うアーキテクチャを選ぶ

データ第一選択次の選択肢
テーブル勾配ブースティング (XGBoost/LightGBM)MLP, TabNet, FT-Transformer
画像(中規模)ResNet50 + 転移学習EfficientNet, ConvNeXt, ViT
画像(小規模 <1万枚)事前学習モデル + Linear probe強い augmentation + 蒸留
テキストBERT/RoBERTa fine-tuningLLM zero-shot / few-shot
時系列(短期)1D CNN, GRUTransformer Encoder
時系列(長期)Informer / TFTState Space Models
音声Wav2Vec 2.0 / HuBERT 事前学習Conformer
3D点群PointNet++Point Transformer

層の深さと幅

  • 幅 (width) を増やすほど学習容量↑、計算コストは2次オーダーで増
  • 深さ (depth) を増やすほど抽象化↑、勾配消失/爆発のリスク↑(スキップ接続で緩和)
  • 経験則: 「深く・スキップ付き」が大体強い (ResNet, Transformer)
  • 計算予算が決まっている場合は深さ>幅が効くことが多い

初期化と正規化

重み初期化(活性に応じて選ぶ)

活性推奨初期化
ReLU / LeakyReLUKaiming He(fan_in)
Tanh / SigmoidXavier (Glorot)
GELU / SiLUKaiming or 標準正規 (Transformer系では fan_in 系)
ResNet残差最終BatchNormを 0 初期化(残差が恒等から学習)
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)

model.apply(init_weights)

正規化レイヤの使い分け

正規化使う場面
BatchNorm画像分類CNN、batch_size≥16
LayerNormNLP/Transformer、batch_size依存しない
GroupNorm小batch学習、検出/セグ
InstanceNormスタイル変換、生成系
RMSNormLLM最新トレンド(LayerNormの軽量版)

スキップ接続とアテンション

なぜスキップ接続か

  • 勾配の流れ: 浅い層まで勾配が伝わりやすい
  • 恒等写像からの学習: 何もしない=悪化しない、を保証
  • 表現力: 残差ブロックが特徴のリファインを担当
def residual_block(x, layer):
    """残差接続(基本パターン)"""
    return x + layer(x)


# Transformer ブロックの中
def transformer_block(x, attn, ffn, ln1, ln2):
    x = x + attn(ln1(x))     # Pre-LN
    x = x + ffn(ln2(x))
    return x

Self-Attention の最小実装

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_h = d_model // n_heads
        self.n_h = n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.n_h, self.d_h)
        qkv = qkv.permute(2, 0, 3, 1, 4)              # (3, B, h, N, d_h)
        q, k, v = qkv[0], qkv[1], qkv[2]

        scores = (q @ k.transpose(-2, -1)) / (self.d_h ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = scores.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, N, D)
        return self.out(out)

Feed-Forward (FFN) サイズ

Transformer のFFN は 4×d_model が定番。最近は SwiGLU (Gated Linear Unit) も人気。

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff=None):
        super().__init__()
        d_ff = d_ff or int(d_model * 8 / 3)
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up   = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

設計の落とし穴とTips

よくある落とし穴

症状原因対応
Loss が動かない初期化失敗 / 学習率小すぎ / 勾配消失lr finder, Kaiming初期化, スキップ接続追加
Loss が NaN学習率大 / 入力にNaN / 数値オーバー勾配clip, 入力チェック, FP32維持
過学習モデル過大 / データ少 / Dropoutなしaugmentation, weight_decay, early stop
過小学習モデル小 / lr小 / 特徴量不適幅深さ↑, lr↑, 特徴量見直し
BN とDropoutの併用で性能低下分布シフトBNの後にDropoutは避ける
BatchSize変更で精度大幅低下lr scaling 忘れlr ∝ batch_size で調整

設計プロセスの推奨フロー

  1. ベースライン: 最もシンプルなモデルから(MLP, ResNet18等)
  2. 過学習させてみる: 学習データに過学習できなければ実装ミス疑い
  3. regularization: 過学習が確認できたらDropout / weight_decay / augmentation
  4. 容量を上げる: 深さ・幅を増やす(必要なら)
  5. アーキ変更: 残差・アテンション・分岐構造
  6. 転移学習: 大規模事前学習モデルを使う
  7. ハイパラ調整: 最後に lr/wd/schedule

サニティチェック

def sanity_check(model, x, y):
    """学習開始前に確認すべきこと"""
    # 1. forward が通る
    out = model(x)
    assert out.shape[0] == x.shape[0]

    # 2. loss は finite
    loss = F.cross_entropy(out, y)
    assert torch.isfinite(loss), "Initial loss is NaN/Inf"

    # 3. 学習可能パラメータがある
    n_trainable = sum(p.requires_grad for p in model.parameters())
    assert n_trainable > 0

    # 4. 勾配が流れる
    loss.backward()
    for name, p in model.named_parameters():
        if p.requires_grad and p.grad is None:
            print(f"WARNING: no gradient for {name}")
Karpathy の格言: "Neural networks are leaky abstractions". 必ず小さく動かしてから大きくする。 最初の60秒で過学習できなければ、それは 実装バグです。