CNN(畳み込みニューラルネットワーク)

学習目標: CNNの構造と画像認識への適用方法を理解する

CNNとは

CNN(Convolutional Neural Network)は、画像認識に特化したニューラルネットワークです。 畳み込み層で局所的な特徴を抽出し、階層的に複雑なパターンを学習します。

CNNの主要コンポーネント
畳み込み層

フィルタ(カーネル)を使って画像から特徴を抽出

プーリング層

特徴マップのサイズを縮小し、位置不変性を獲得

全結合層

抽出した特徴を使って最終的な分類を実行

畳み込み層

import torch.nn as nn

# 2D畳み込み層
conv = nn.Conv2d(
    in_channels=3,      # 入力チャンネル数(RGB=3)
    out_channels=32,    # 出力チャンネル数(フィルタ数)
    kernel_size=3,      # カーネルサイズ(3x3)
    stride=1,           # ストライド
    padding=1           # パディング
)

# 入力: (N, C, H, W) = (バッチ, チャンネル, 高さ, 幅)
x = torch.randn(16, 3, 32, 32)  # バッチ16、RGB、32x32画像
out = conv(x)  # (16, 32, 32, 32)
出力サイズの計算: out_size = (in_size + 2×padding - kernel_size) / stride + 1

プーリング層

# 最大プーリング
maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # 2x2プーリング
# (N, C, 32, 32) → (N, C, 16, 16)

# 平均プーリング
avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

# グローバル平均プーリング(画像全体を1つの値に)
gap = nn.AdaptiveAvgPool2d((1, 1))  # (N, C, H, W) → (N, C, 1, 1)

CNN実装例

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # 特徴抽出部分
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 32x32 → 16x16

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 16x16 → 8x8

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 8x8 → 4x4
        )

        # 分類部分
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

model = SimpleCNN()
x = torch.randn(1, 3, 32, 32)  # CIFAR-10サイズ
print(model(x).shape)  # torch.Size([1, 10])

有名なCNNアーキテクチャ

名前 特徴 パラメータ数
LeNet-5 1998 最初の実用的CNN 60K
AlexNet 2012 ReLU、Dropout導入 60M
VGG-16 2014 3x3カーネルの積み重ね 138M
ResNet 2015 スキップ接続(残差学習) 25M

理解度チェック

Q. プーリング層の主な役割は?