cartPole.py
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# If you are on gymnasium instead, use: import gymnasium as gym
# gym can have compatibility issues on Python 3.11; if so, pip install gym==0.26.2
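# This script trains a REINFORCE (Monte Carlo policy gradient) agent with a
# mean-return baseline on CartPole-v1, then renders one evaluation episode.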
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim=4, hidden_dim=128, action_dim=2):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Output logits, one per action (action_dim=2)
        logits = self.fc3(x)
        return logits

    def get_action(self, state):
        """
        Given a single state (shape: (4,)),
        return a sampled action and its log-probability.
        """
        # Convert to a batch-of-1 tensor
        state = torch.FloatTensor(state).unsqueeze(0)  # shape (1, 4)
        logits = self.forward(state)                   # shape (1, 2)
        # Sample from a Categorical distribution over the logits
        dist = torch.distributions.Categorical(logits=logits)
        action = dist.sample()  # integer action: 0 or 1
        log_prob = dist.log_prob(action)
        return action.item(), log_prob
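# Note: Categorical(logits=...) applies a softmax internally, so the network can
# output raw, unnormalized scores; dist.log_prob(action) then gives log π(a|s)
# for the sampled action, which is what the policy gradient update needs.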
def discount_rewards(rewards, gamma=0.99):
    """
    Given per-step rewards [r0, r1, r2, ...],
    return the discounted return G_t for every time step t.
    """
    discounted = np.zeros_like(rewards, dtype=np.float32)
    running_add = 0
    for t in reversed(range(len(rewards))):
        running_add = rewards[t] + gamma * running_add
        discounted[t] = running_add
    return discounted
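# Example: with gamma = 0.9, discount_rewards([1.0, 1.0, 1.0], 0.9) returns
# [2.71, 1.9, 1.0], since G_2 = 1, G_1 = 1 + 0.9*1, and G_0 = 1 + 0.9*1.9.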
def normalize_rewards(rewards):
    """
    Normalize the rewards to have mean 0 and standard deviation 1.
    """
    rewards = np.array(rewards)
    rewards = (rewards - np.mean(rewards)) / (np.std(rewards) + 1e-8)
    return rewards
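# Subtracting the mean and dividing by the standard deviation of the returns is
# a common REINFORCE variance-reduction trick; it keeps gradient magnitudes
# comparable across episodes of different lengths.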
def run_episode(env, policy_net, gamma=0.99):
    """
    Run one episode and collect (log_prob, reward) pairs.
    Returns:
    - log_probs: list of log-probability tensors
    - normalized_r: normalized discounted return for each time step
    - total_reward: cumulative reward of the episode (for evaluation)
    """
    log_probs = []
    rewards = []
    total_reward = 0
    state = env.reset()[0]  # gym 0.26+/gymnasium: env.reset() returns (obs, info)
    done = False
    while not done:
        action, log_prob = policy_net.get_action(state)
        # Take the action
        next_state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Record
        log_probs.append(log_prob)
        rewards.append(reward)
        total_reward += reward
        state = next_state
    # Compute discounted returns, then normalize them
    discounted_r = discount_rewards(rewards, gamma)  # shape (episode_length,)
    normalized_r = normalize_rewards(discounted_r)
    return log_probs, normalized_r, total_reward
def train_cartpole(
    max_episodes=1000,
    gamma=0.99,
    lr=1e-3,
    hidden_dim=128
):
    env = gym.make("CartPole-v1")
    # CartPole observation has 4 dimensions, the action space has 2 discrete actions
    policy_net = PolicyNetwork(state_dim=4, hidden_dim=hidden_dim, action_dim=2)
    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    rewards = []
    for episode in range(max_episodes):
        log_probs, discounted_r, total_reward = run_episode(env, policy_net, gamma)
        # Baseline: mean of the (already normalized) returns
        baseline = np.mean(discounted_r)
        # Policy gradient loss:
        # Σ_t [ -log π(a_t|s_t) * (G_t - baseline) ]
        loss = 0
        for log_prob, Gt in zip(log_probs, discounted_r):
            loss += -log_prob * (Gt - baseline)
        # Backpropagate and update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Record the total reward of this episode
        rewards.append(total_reward)
        # Show training progress
        print(f"Episode {episode}, Reward = {total_reward}")
        # Stop early once an episode reaches the maximum score of 500
        if total_reward >= 500:
            print("Solved CartPole!")
            break
    env.close()
    # Plot reward per episode
    plt.plot(rewards)
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.title('Total Reward per Episode')
    plt.show()
    return policy_net
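# train_cartpole() above is the REINFORCE estimator
#   ∇J(θ) ≈ Σ_t ∇θ log π_θ(a_t|s_t) * (G_t - b),
# implemented by minimizing its negation with Adam, where b is the mean-return baseline.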
def play_cartpole(env, policy_net, render=True):
    state = env.reset()[0]
    done = False
    total_reward = 0
    while not done:
        if render:
            env.render()
        action, _ = policy_net.get_action(state)
        state, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        total_reward += reward
    return total_reward
if __name__ == "__main__":
    policy_net = train_cartpole(max_episodes=1000, gamma=0.97, lr=3e-4, hidden_dim=256)
    # Test the trained policy
    env = gym.make("CartPole-v1", render_mode="human")  # gym 0.26+ requires render_mode="human"
    score = play_cartpole(env, policy_net, render=True)
    print("Test Score:", score)
    env.close()