PyTorch基礎

学習目標: PyTorchの基本的な使い方と特徴を理解する

PyTorchとは

PyTorchはMeta(旧Facebook)が開発したオープンソースの深層学習フレームワークです。 動的計算グラフ、Pythonライクな書き方、研究との親和性の高さから、研究者・エンジニアに広く使われています。

PyTorchの特徴
  • 動的計算グラフ: 実行時にグラフを構築(柔軟性が高い)
  • Pythonic: NumPyに似た直感的なAPI
  • デバッグしやすい: Python標準のデバッグツールが使える
  • GPUサポート: CUDAによる高速計算
主なコンポーネント
  • torch - テンソル操作
  • torch.nn - ニューラルネットワーク
  • torch.autograd - 自動微分
  • torch.optim - 最適化アルゴリズム
  • torchvision - 画像データセット・モデル

最初のPyTorchコード

import torch

# PyTorchのバージョン確認
print(f"PyTorch version: {torch.__version__}")

# GPUが使えるか確認
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
Google ColabなどGPU環境ではtorch.cuda.is_available()がTrueになります。

テンソルの基本

PyTorchの基本データ構造はテンソルです。NumPyの配列に似ていますが、GPU上で計算でき、自動微分をサポートします。

import torch

# テンソルの作成
x = torch.tensor([1, 2, 3])           # リストから
y = torch.zeros(3, 4)                  # ゼロで初期化
z = torch.ones(2, 3)                   # 1で初期化
r = torch.rand(2, 2)                   # 0-1の乱数
n = torch.randn(3, 3)                  # 標準正規分布

# NumPyからの変換
import numpy as np
np_array = np.array([1, 2, 3])
tensor_from_np = torch.from_numpy(np_array)

# NumPyへの変換
back_to_np = tensor_from_np.numpy()

テンソル演算

# 基本演算
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

print(a + b)           # 要素ごとの加算
print(a * b)           # 要素ごとの乗算
print(torch.dot(a, b)) # 内積

# 行列演算
A = torch.randn(3, 4)
B = torch.randn(4, 2)
C = torch.mm(A, B)     # 行列積 (3x4) @ (4x2) = (3x2)
# または
C = A @ B

# 形状変更
x = torch.randn(2, 3, 4)
print(x.shape)                    # torch.Size([2, 3, 4])
print(x.view(6, 4).shape)        # torch.Size([6, 4])
print(x.reshape(2, -1).shape)    # torch.Size([2, 12])

GPUの使用

# デバイスの設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# テンソルをGPUに移動
x = torch.randn(3, 3)
x_gpu = x.to(device)
# または
x_gpu = x.cuda()  # GPUに移動
x_cpu = x_gpu.cpu()  # CPUに戻す

# モデルもGPUに移動可能
# model = model.to(device)
GPU上のテンソルとCPU上のテンソルは直接演算できません。同じデバイスに揃える必要があります。

PyTorch vs NumPy 対応表

操作 NumPy PyTorch
配列作成 np.array([1,2,3]) torch.tensor([1,2,3])
ゼロ配列 np.zeros((3,4)) torch.zeros(3,4)
乱数 np.random.randn(3,4) torch.randn(3,4)
形状 x.shape x.shape / x.size()
形状変更 x.reshape(2,6) x.view(2,6) / x.reshape(2,6)
行列積 np.dot(A, B) / A @ B torch.mm(A, B) / A @ B

理解度チェック

Q. PyTorchの特徴として正しいのは?