ニューラルネットワーク構築

学習目標: torch.nnを使ってニューラルネットワークを構築する方法を理解する

nn.Moduleの基本

PyTorchでは、ニューラルネットワークはnn.Moduleを継承したクラスとして定義します。

import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # レイヤーの定義
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # 順伝播の定義
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# モデルのインスタンス化
model = SimpleNet(input_size=784, hidden_size=128, output_size=10)
print(model)

主要なレイヤー

レイヤー 説明 使用例
nn.Linear 全結合層 nn.Linear(in_features, out_features)
nn.Conv2d 2D畳み込み層 nn.Conv2d(in_ch, out_ch, kernel_size)
nn.MaxPool2d 最大プーリング nn.MaxPool2d(kernel_size)
nn.BatchNorm2d バッチ正規化 nn.BatchNorm2d(num_features)
nn.Dropout ドロップアウト nn.Dropout(p=0.5)
nn.LSTM LSTM層 nn.LSTM(input_size, hidden_size)

活性化関数

ReLU
nn.ReLU()

f(x) = max(0, x)

最も一般的。隠れ層に使用。
Sigmoid
nn.Sigmoid()

f(x) = 1/(1+e⁻ˣ)

二値分類の出力層に。
Softmax
nn.Softmax(dim=1)

確率分布に変換

多クラス分類の出力層に。

nn.Sequential

シンプルなネットワークはnn.Sequentialで簡潔に定義できます。

# Sequentialを使った定義
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 10)
)

# 入力を渡して出力を得る
x = torch.randn(32, 784)  # バッチサイズ32、入力次元784
output = model(x)          # (32, 10)

パラメータの確認

# モデルのパラメータを確認
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

# パラメータ数をカウント
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total_params:,}, Trainable: {trainable_params:,}")

実践例:MNISTの分類器

class MNISTClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.network = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        x = self.flatten(x)      # (N, 1, 28, 28) → (N, 784)
        logits = self.network(x)  # (N, 784) → (N, 10)
        return logits

model = MNISTClassifier()
print(f"パラメータ数: {sum(p.numel() for p in model.parameters()):,}")

理解度チェック

Q. nn.Moduleを継承したクラスで必ず実装する必要があるメソッドは?