MNISTデモ

学習目標: MNISTデータセットを使って手書き数字の分類モデルを実装する

MNISTデータセット

MNISTは28×28ピクセルの手書き数字(0-9)の画像データセットです。機械学習の「Hello World」として知られています。

60,000

訓練画像

10,000

テスト画像

10

クラス(0-9)

完全な実装コード

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# デバイス設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# データの準備
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000)

# モデル定義
class MNISTNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

model = MNISTNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 訓練ループ
def train(model, train_loader, optimizer, criterion, device):
    model.train()
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# テスト
def test(model, test_loader, device):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    return correct / len(test_loader.dataset)

# 実行
for epoch in range(5):
    train(model, train_loader, optimizer, criterion, device)
    acc = test(model, test_loader, device)
    print(f"Epoch {epoch+1}: Accuracy = {acc:.4f}")

期待される結果

上記のモデルで5エポック訓練すると、約99%のテスト精度が達成できます。

  • Epoch 1: ~98%
  • Epoch 2: ~98.5%
  • Epoch 3: ~99%
  • Epoch 5: ~99.2%

理解度チェック

Q. MNISTの画像サイズは?