テンソル
学習目標: テンソルの概念と高度な操作方法を理解する
テンソルとは
テンソルは多次元配列を一般化したものです。スカラー、ベクトル、行列はすべてテンソルの特殊なケースです。
0D
スカラー
torch.tensor(3.14)
単一の値
1D
ベクトル
torch.tensor([1,2,3])
1次元の配列
2D
行列
torch.randn(3,4)
2次元の配列
nD
高階テンソル
torch.randn(2,3,4,5)
n次元の配列
テンソルの属性
import torch
x = torch.randn(3, 4, 5)
# 基本属性
print(f"Shape: {x.shape}") # torch.Size([3, 4, 5])
print(f"Size: {x.size()}") # torch.Size([3, 4, 5])
print(f"Dimensions: {x.dim()}") # 3
print(f"Dtype: {x.dtype}") # torch.float32
print(f"Device: {x.device}") # cpu または cuda:0
# 要素数
print(f"Total elements: {x.numel()}") # 60 (3*4*5)
形状操作
# view: メモリ連続なテンソルの形状変更
x = torch.randn(2, 3, 4)
y = x.view(6, 4) # 形状変更
y = x.view(-1, 4) # -1は自動計算(6, 4)
y = x.view(2, -1) # (2, 12)
# reshape: 非連続テンソルにも使える(必要時にコピー)
y = x.reshape(6, 4)
# squeeze / unsqueeze: 次元の追加・削除
x = torch.randn(1, 3, 1, 4)
y = x.squeeze() # サイズ1の次元を削除 → (3, 4)
y = x.squeeze(0) # 特定の次元だけ削除 → (3, 1, 4)
y = x.unsqueeze(0) # 次元を追加 → (1, 1, 3, 1, 4)
# permute: 次元の入れ替え
x = torch.randn(2, 3, 4)
y = x.permute(2, 0, 1) # (4, 2, 3)
# transpose: 2つの次元を入れ替え
y = x.transpose(0, 2) # (4, 3, 2)
結合と分割
# cat: テンソルの連結
a = torch.randn(2, 3)
b = torch.randn(2, 3)
c = torch.cat([a, b], dim=0) # (4, 3) 行方向に結合
c = torch.cat([a, b], dim=1) # (2, 6) 列方向に結合
# stack: 新しい次元を追加して結合
c = torch.stack([a, b], dim=0) # (2, 2, 3)
# split: テンソルの分割
x = torch.randn(6, 4)
chunks = torch.split(x, 2, dim=0) # 3つの(2, 4)テンソル
chunks = torch.split(x, [1, 2, 3], dim=0) # (1,4), (2,4), (3,4)
# chunk: 等分割
chunks = torch.chunk(x, 3, dim=0) # 3つの(2, 4)テンソル
ブロードキャスト
異なるサイズのテンソル間で演算を行う際、PyTorchは自動的に形状を拡張(ブロードキャスト)します。
# ブロードキャストの例
a = torch.randn(3, 4) # (3, 4)
b = torch.randn(4) # (4,)
c = a + b # bが(1, 4)に拡張され、(3, 4)に
# 条件: 各次元が同じか、どちらかが1の場合に可能
x = torch.randn(5, 1, 3) # (5, 1, 3)
y = torch.randn(1, 4, 3) # (1, 4, 3)
z = x + y # (5, 4, 3)
ブロードキャストのルール
- 次元数が異なる場合、少ない方の先頭に1を追加
- 各次元で、サイズが同じか、どちらかが1なら互換
- サイズ1の次元は他方のサイズに拡張される
インデックスとスライス
x = torch.arange(12).reshape(3, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])
# 基本的なインデックス
print(x[0]) # 最初の行: tensor([0, 1, 2, 3])
print(x[0, 2]) # 1行3列目: tensor(2)
print(x[:, 1]) # 2列目全体: tensor([1, 5, 9])
# スライス
print(x[1:3, 1:3]) # 部分行列
# tensor([[ 5, 6],
# [ 9, 10]])
# 条件によるインデックス
mask = x > 5
print(x[mask]) # tensor([ 6, 7, 8, 9, 10, 11])
理解度チェック
Q. shape (2, 3, 4) のテンソルをview(-1, 6)で変形すると、結果のshapeは?