CIFARデモ
学習目標: CIFAR-10データセットを使ってカラー画像の分類モデルを実装する
CIFAR-10データセット
CIFAR-10は32×32ピクセルのカラー画像で、10種類の物体クラスを含みます。
10クラス:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
実装コード
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# データ拡張付きの前処理
train_transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))
])
# データセット
train_dataset = datasets.CIFAR10('./data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10('./data', train=False, transform=test_transform)
# CNNモデル
class CIFAR10Net(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Dropout(0.25),
nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Dropout(0.25),
nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.Conv2d(256, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Dropout(0.25),
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(256 * 4 * 4, 512),
nn.ReLU(),
nn.Dropout(0.5),
nn.Linear(512, 10)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
期待される性能
| モデル | パラメータ数 | 精度 |
|---|---|---|
| シンプルなCNN | ~1M | ~85% |
| 上記モデル(データ拡張あり) | ~2.5M | ~90% |
| ResNet-18(転移学習) | ~11M | ~95% |
理解度チェック
Q. CIFAR-10の画像サイズは?