PyTorch基礎
学習目標: PyTorchの基本的な使い方と特徴を理解する
PyTorchとは
PyTorchはMeta(旧Facebook)が開発したオープンソースの深層学習フレームワークです。 動的計算グラフ、Pythonライクな書き方、研究との親和性の高さから、研究者・エンジニアに広く使われています。
PyTorchの特徴
- 動的計算グラフ: 実行時にグラフを構築(柔軟性が高い)
- Pythonic: NumPyに似た直感的なAPI
- デバッグしやすい: Python標準のデバッグツールが使える
- GPUサポート: CUDAによる高速計算
主なコンポーネント
torch- テンソル操作torch.nn- ニューラルネットワークtorch.autograd- 自動微分torch.optim- 最適化アルゴリズムtorchvision- 画像データセット・モデル
最初のPyTorchコード
import torch
# PyTorchのバージョン確認
print(f"PyTorch version: {torch.__version__}")
# GPUが使えるか確認
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"CUDA device: {torch.cuda.get_device_name(0)}")
Google ColabなどGPU環境では
torch.cuda.is_available()がTrueになります。
テンソルの基本
PyTorchの基本データ構造はテンソルです。NumPyの配列に似ていますが、GPU上で計算でき、自動微分をサポートします。
import torch
# テンソルの作成
x = torch.tensor([1, 2, 3]) # リストから
y = torch.zeros(3, 4) # ゼロで初期化
z = torch.ones(2, 3) # 1で初期化
r = torch.rand(2, 2) # 0-1の乱数
n = torch.randn(3, 3) # 標準正規分布
# NumPyからの変換
import numpy as np
np_array = np.array([1, 2, 3])
tensor_from_np = torch.from_numpy(np_array)
# NumPyへの変換
back_to_np = tensor_from_np.numpy()
テンソル演算
# 基本演算
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
print(a + b) # 要素ごとの加算
print(a * b) # 要素ごとの乗算
print(torch.dot(a, b)) # 内積
# 行列演算
A = torch.randn(3, 4)
B = torch.randn(4, 2)
C = torch.mm(A, B) # 行列積 (3x4) @ (4x2) = (3x2)
# または
C = A @ B
# 形状変更
x = torch.randn(2, 3, 4)
print(x.shape) # torch.Size([2, 3, 4])
print(x.view(6, 4).shape) # torch.Size([6, 4])
print(x.reshape(2, -1).shape) # torch.Size([2, 12])
GPUの使用
# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# テンソルをGPUに移動
x = torch.randn(3, 3)
x_gpu = x.to(device)
# または
x_gpu = x.cuda() # GPUに移動
x_cpu = x_gpu.cpu() # CPUに戻す
# モデルもGPUに移動可能
# model = model.to(device)
GPU上のテンソルとCPU上のテンソルは直接演算できません。同じデバイスに揃える必要があります。
PyTorch vs NumPy 対応表
| 操作 | NumPy | PyTorch |
|---|---|---|
| 配列作成 | np.array([1,2,3]) |
torch.tensor([1,2,3]) |
| ゼロ配列 | np.zeros((3,4)) |
torch.zeros(3,4) |
| 乱数 | np.random.randn(3,4) |
torch.randn(3,4) |
| 形状 | x.shape |
x.shape / x.size() |
| 形状変更 | x.reshape(2,6) |
x.view(2,6) / x.reshape(2,6) |
| 行列積 | np.dot(A, B) / A @ B |
torch.mm(A, B) / A @ B |
理解度チェック
Q. PyTorchの特徴として正しいのは?