強化学習 (Reinforcement Learning)
学習目標: エージェント・環境・報酬のフレームワークを理解し、Q学習・DQN・Policy Gradient・Actor-Critic を PyTorch で実装できるようになる
強化学習の基本要素
action a_t
┌─────────────────────┐
│ ▼
[Agent] [Environment]
▲ │
└─────────────────────┘
state s_{t+1}, reward r_t
| 要素 | 説明 |
|---|---|
| 状態 s | 環境の観測(例: ゲーム画面、ロボットの関節角度) |
| 行動 a | エージェントが選ぶ操作(離散 or 連続) |
| 報酬 r | 環境から得られるシグナル(最大化したい量) |
| 方策 π(a|s) | 状態 → 行動の確率分布 |
| 価値関数 V(s) / Q(s,a) | その状態(と行動)から得られる将来報酬の期待値 |
| γ | 割引率(将来報酬の現在価値、0.9〜0.99) |
ベルマン方程式
Q(s, a) = E[r + γ · max_{a'} Q(s', a') | s, a] ← Q学習
V(s) = E_π[r + γ · V(s')] ← 価値反復
gym環境の基本ループ
import gymnasium as gym
env = gym.make("CartPole-v1")
state, info = env.reset(seed=42)
total_reward = 0
for step in range(500):
action = agent.select_action(state) # ε-greedy など
next_state, reward, terminated, truncated, info = env.step(action)
agent.store(state, action, reward, next_state, terminated)
agent.update()
total_reward += reward
state = next_state
if terminated or truncated:
break
print(f"return: {total_reward}")
env.close()
Q学習 → DQN
表形式Q学習(離散・小空間)
import numpy as np
Q = np.zeros((n_states, n_actions))
alpha, gamma, eps = 0.1, 0.99, 0.1
for ep in range(num_episodes):
s, _ = env.reset()
while True:
a = env.action_space.sample() if np.random.rand() < eps else Q[s].argmax()
s2, r, term, trunc, _ = env.step(a)
Q[s, a] += alpha * (r + gamma * Q[s2].max() - Q[s, a])
s = s2
if term or trunc: break
DQN (Deep Q-Network)
Q関数をニューラルネットで近似。経験再生バッファとターゲットネットワークで安定化。
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import deque
import random
class QNet(nn.Module):
def __init__(self, state_dim, n_actions):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 128), nn.ReLU(),
nn.Linear(128, 128), nn.ReLU(),
nn.Linear(128, n_actions),
)
def forward(self, x):
return self.net(x)
class ReplayBuffer:
def __init__(self, capacity=10000):
self.buf = deque(maxlen=capacity)
def push(self, *transition):
self.buf.append(transition)
def sample(self, batch_size):
batch = random.sample(self.buf, batch_size)
s, a, r, s2, d = zip(*batch)
return (torch.stack(s), torch.tensor(a), torch.tensor(r, dtype=torch.float32),
torch.stack(s2), torch.tensor(d, dtype=torch.float32))
def train_dqn(env, episodes=500, batch_size=64, gamma=0.99, target_sync=50):
qnet = QNet(env.observation_space.shape[0], env.action_space.n)
target = QNet(env.observation_space.shape[0], env.action_space.n)
target.load_state_dict(qnet.state_dict())
opt = torch.optim.Adam(qnet.parameters(), lr=1e-3)
buf = ReplayBuffer()
eps = 1.0
for ep in range(episodes):
s, _ = env.reset()
s = torch.tensor(s, dtype=torch.float32)
while True:
if random.random() < eps:
a = env.action_space.sample()
else:
with torch.no_grad():
a = qnet(s).argmax().item()
s2, r, term, trunc, _ = env.step(a)
s2_t = torch.tensor(s2, dtype=torch.float32)
buf.push(s, a, r, s2_t, term)
s = s2_t
if len(buf.buf) >= batch_size:
S, A, R, S2, D = buf.sample(batch_size)
with torch.no_grad():
target_q = R + gamma * target(S2).max(1).values * (1 - D)
pred_q = qnet(S).gather(1, A.unsqueeze(1)).squeeze()
loss = F.smooth_l1_loss(pred_q, target_q)
opt.zero_grad(); loss.backward(); opt.step()
if term or trunc: break
eps = max(0.05, eps * 0.995)
if ep % target_sync == 0:
target.load_state_dict(qnet.state_dict())
改良版
- Double DQN: 過大評価バイアスを抑制。
argmaxを qnet で、Q値を target で取る - Dueling DQN: 価値 V(s) と アドバンテージ A(s,a) を分離して学習
- Prioritized Replay: TD誤差の大きい経験を優先的にサンプリング
Policy Gradient (方策勾配法)
方策 π_θ(a|s) を直接最適化する。価値を経由しないため、連続行動空間にも自然に拡張できる。
REINFORCE
class PolicyNet(nn.Module):
def __init__(self, state_dim, n_actions):
super().__init__()
self.net = nn.Sequential(
nn.Linear(state_dim, 128), nn.ReLU(),
nn.Linear(128, n_actions),
)
def forward(self, x):
return F.softmax(self.net(x), dim=-1)
def train_reinforce(env, episodes=1000, gamma=0.99, lr=1e-3):
pi = PolicyNet(env.observation_space.shape[0], env.action_space.n)
opt = torch.optim.Adam(pi.parameters(), lr=lr)
for ep in range(episodes):
s, _ = env.reset()
log_probs, rewards = [], []
while True:
s_t = torch.tensor(s, dtype=torch.float32)
probs = pi(s_t)
dist = torch.distributions.Categorical(probs)
a = dist.sample()
log_probs.append(dist.log_prob(a))
s, r, term, trunc, _ = env.step(a.item())
rewards.append(r)
if term or trunc: break
# 割引リターン
returns, R = [], 0
for r in reversed(rewards):
R = r + gamma * R
returns.insert(0, R)
returns = torch.tensor(returns)
returns = (returns - returns.mean()) / (returns.std() + 1e-8)
# 方策勾配
loss = -(torch.stack(log_probs) * returns).sum()
opt.zero_grad(); loss.backward(); opt.step()
分散の大きさ: REINFORCEはエピソードリターンを使うため分散が大きい。次節のActor-CriticでValue baselineを引いて分散を減らすのが定石。
Actor-Critic / PPO
Actor (方策) と Critic (価値関数) を同時に学習。アドバンテージ A(s,a) = Q(s,a) − V(s) を使って分散を減らす。
A2C (Advantage Actor-Critic)
class ActorCritic(nn.Module):
def __init__(self, state_dim, n_actions):
super().__init__()
self.shared = nn.Sequential(nn.Linear(state_dim, 128), nn.ReLU())
self.actor = nn.Linear(128, n_actions)
self.critic = nn.Linear(128, 1)
def forward(self, x):
h = self.shared(x)
return F.softmax(self.actor(h), dim=-1), self.critic(h)
def train_a2c_step(model, opt, transitions, gamma=0.99):
states, actions, rewards, dones, next_states = zip(*transitions)
states = torch.stack(states)
actions = torch.tensor(actions)
rewards = torch.tensor(rewards, dtype=torch.float32)
next_states = torch.stack(next_states)
dones = torch.tensor(dones, dtype=torch.float32)
probs, values = model(states)
_, next_values = model(next_states)
targets = rewards + gamma * next_values.squeeze() * (1 - dones)
advantages = targets.detach() - values.squeeze()
dist = torch.distributions.Categorical(probs)
log_prob = dist.log_prob(actions)
entropy = dist.entropy().mean()
actor_loss = -(log_prob * advantages.detach()).mean()
critic_loss = F.mse_loss(values.squeeze(), targets.detach())
loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy
opt.zero_grad(); loss.backward(); opt.step()
PPO (Proximal Policy Optimization)
方策更新を「行き過ぎないように」クリッピングする現代の定番手法。OpenAI/DeepMindでも標準。
def ppo_loss(old_log_probs, new_log_probs, advantages, eps_clip=0.2):
ratio = (new_log_probs - old_log_probs).exp()
surr1 = ratio * advantages
surr2 = torch.clamp(ratio, 1 - eps_clip, 1 + eps_clip) * advantages
return -torch.min(surr1, surr2).mean()
| 手法 | 適用領域 | 特徴 |
|---|---|---|
| DQN | 離散行動 | サンプル効率○、安定するまで時間 |
| REINFORCE | 離散/連続 | シンプル、分散大 |
| A2C/A3C | 離散/連続 | 並列化容易 |
| PPO | 離散/連続 | 現代の定番。安定・調整少 |
| SAC | 連続のみ | サンプル効率○、ロボット系で強い |