ニューラルネットワーク構築
学習目標: torch.nnを使ってニューラルネットワークを構築する方法を理解する
nn.Moduleの基本
PyTorchでは、ニューラルネットワークはnn.Moduleを継承したクラスとして定義します。
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
# レイヤーの定義
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
# 順伝播の定義
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# モデルのインスタンス化
model = SimpleNet(input_size=784, hidden_size=128, output_size=10)
print(model)
主要なレイヤー
| レイヤー | 説明 | 使用例 |
|---|---|---|
nn.Linear |
全結合層 | nn.Linear(in_features, out_features) |
nn.Conv2d |
2D畳み込み層 | nn.Conv2d(in_ch, out_ch, kernel_size) |
nn.MaxPool2d |
最大プーリング | nn.MaxPool2d(kernel_size) |
nn.BatchNorm2d |
バッチ正規化 | nn.BatchNorm2d(num_features) |
nn.Dropout |
ドロップアウト | nn.Dropout(p=0.5) |
nn.LSTM |
LSTM層 | nn.LSTM(input_size, hidden_size) |
活性化関数
ReLU
nn.ReLU()
f(x) = max(0, x)
最も一般的。隠れ層に使用。Sigmoid
nn.Sigmoid()
f(x) = 1/(1+e⁻ˣ)
二値分類の出力層に。Softmax
nn.Softmax(dim=1)
確率分布に変換
多クラス分類の出力層に。nn.Sequential
シンプルなネットワークはnn.Sequentialで簡潔に定義できます。
# Sequentialを使った定義
model = nn.Sequential(
nn.Linear(784, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(128, 10)
)
# 入力を渡して出力を得る
x = torch.randn(32, 784) # バッチサイズ32、入力次元784
output = model(x) # (32, 10)
パラメータの確認
# モデルのパラメータを確認
for name, param in model.named_parameters():
print(f"{name}: {param.shape}")
# パラメータ数をカウント
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total: {total_params:,}, Trainable: {trainable_params:,}")
実践例:MNISTの分類器
class MNISTClassifier(nn.Module):
def __init__(self):
super().__init__()
self.flatten = nn.Flatten()
self.network = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(512, 256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 10)
)
def forward(self, x):
x = self.flatten(x) # (N, 1, 28, 28) → (N, 784)
logits = self.network(x) # (N, 784) → (N, 10)
return logits
model = MNISTClassifier()
print(f"パラメータ数: {sum(p.numel() for p in model.parameters()):,}")
理解度チェック
Q. nn.Moduleを継承したクラスで必ず実装する必要があるメソッドは?