テンソル

学習目標: テンソルの概念と高度な操作方法を理解する

テンソルとは

テンソルは多次元配列を一般化したものです。スカラー、ベクトル、行列はすべてテンソルの特殊なケースです。

0D
スカラー
torch.tensor(3.14)

単一の値

1D
ベクトル
torch.tensor([1,2,3])

1次元の配列

2D
行列
torch.randn(3,4)

2次元の配列

nD
高階テンソル
torch.randn(2,3,4,5)

n次元の配列

テンソルの属性

import torch

x = torch.randn(3, 4, 5)

# 基本属性
print(f"Shape: {x.shape}")       # torch.Size([3, 4, 5])
print(f"Size: {x.size()}")       # torch.Size([3, 4, 5])
print(f"Dimensions: {x.dim()}")  # 3
print(f"Dtype: {x.dtype}")       # torch.float32
print(f"Device: {x.device}")     # cpu または cuda:0

# 要素数
print(f"Total elements: {x.numel()}")  # 60 (3*4*5)

形状操作

# view: メモリ連続なテンソルの形状変更
x = torch.randn(2, 3, 4)
y = x.view(6, 4)           # 形状変更
y = x.view(-1, 4)          # -1は自動計算(6, 4)
y = x.view(2, -1)          # (2, 12)

# reshape: 非連続テンソルにも使える(必要時にコピー)
y = x.reshape(6, 4)

# squeeze / unsqueeze: 次元の追加・削除
x = torch.randn(1, 3, 1, 4)
y = x.squeeze()           # サイズ1の次元を削除 → (3, 4)
y = x.squeeze(0)          # 特定の次元だけ削除 → (3, 1, 4)
y = x.unsqueeze(0)        # 次元を追加 → (1, 1, 3, 1, 4)

# permute: 次元の入れ替え
x = torch.randn(2, 3, 4)
y = x.permute(2, 0, 1)    # (4, 2, 3)

# transpose: 2つの次元を入れ替え
y = x.transpose(0, 2)     # (4, 3, 2)

結合と分割

# cat: テンソルの連結
a = torch.randn(2, 3)
b = torch.randn(2, 3)
c = torch.cat([a, b], dim=0)  # (4, 3) 行方向に結合
c = torch.cat([a, b], dim=1)  # (2, 6) 列方向に結合

# stack: 新しい次元を追加して結合
c = torch.stack([a, b], dim=0)  # (2, 2, 3)

# split: テンソルの分割
x = torch.randn(6, 4)
chunks = torch.split(x, 2, dim=0)  # 3つの(2, 4)テンソル
chunks = torch.split(x, [1, 2, 3], dim=0)  # (1,4), (2,4), (3,4)

# chunk: 等分割
chunks = torch.chunk(x, 3, dim=0)  # 3つの(2, 4)テンソル

ブロードキャスト

異なるサイズのテンソル間で演算を行う際、PyTorchは自動的に形状を拡張(ブロードキャスト)します。

# ブロードキャストの例
a = torch.randn(3, 4)      # (3, 4)
b = torch.randn(4)         # (4,)
c = a + b                  # bが(1, 4)に拡張され、(3, 4)に

# 条件: 各次元が同じか、どちらかが1の場合に可能
x = torch.randn(5, 1, 3)   # (5, 1, 3)
y = torch.randn(1, 4, 3)   # (1, 4, 3)
z = x + y                  # (5, 4, 3)
ブロードキャストのルール
  1. 次元数が異なる場合、少ない方の先頭に1を追加
  2. 各次元で、サイズが同じか、どちらかが1なら互換
  3. サイズ1の次元は他方のサイズに拡張される

インデックスとスライス

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11]])

# 基本的なインデックス
print(x[0])         # 最初の行: tensor([0, 1, 2, 3])
print(x[0, 2])      # 1行3列目: tensor(2)
print(x[:, 1])      # 2列目全体: tensor([1, 5, 9])

# スライス
print(x[1:3, 1:3])  # 部分行列
# tensor([[ 5,  6],
#         [ 9, 10]])

# 条件によるインデックス
mask = x > 5
print(x[mask])      # tensor([ 6,  7,  8,  9, 10, 11])

理解度チェック

Q. shape (2, 3, 4) のテンソルをview(-1, 6)で変形すると、結果のshapeは?