-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathalgo.py
195 lines (158 loc) · 5.83 KB
/
algo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from collections import defaultdict
from functools import partial
from typing import Any, Callable, Tuple
import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from flax.training.train_state import TrainState
from jax.random import PRNGKey
"""
Inspired by code from Flax: https://github.com/google/flax/blob/main/examples/ppo/ppo_lib.py
"""
def loss_actor_and_critic(
params_model: flax.core.frozen_dict.FrozenDict,
apply_fn: Callable[..., Any],
state: jnp.ndarray,
target: jnp.ndarray,
value_old: jnp.ndarray,
log_pi_old: jnp.ndarray,
gae: jnp.ndarray,
action: jnp.ndarray,
clip_eps: float,
critic_coeff: float,
entropy_coeff: float
) -> jnp.ndarray:
state = state.astype(jnp.float32) / 255.
value_pred, pi = apply_fn(params_model, state)
value_pred = value_pred[:, 0]
log_prob = pi.log_prob(action[:,0])
value_pred_clipped = value_old + (value_pred - value_old).clip(
-clip_eps, clip_eps)
value_losses = jnp.square(value_pred - target)
value_losses_clipped = jnp.square(value_pred_clipped - target)
value_loss = 0.5 * jnp.maximum(value_losses,
value_losses_clipped).mean()
ratio = jnp.exp(log_prob - log_pi_old)
gae = (gae - gae.mean()) / (gae.std() + 1e-8)
loss_actor1 = ratio * gae
loss_actor2 = jnp.clip(ratio, 1.0 - clip_eps,
1.0 + clip_eps) * gae
loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
loss_actor = loss_actor.mean()
entropy = pi.entropy().mean()
total_loss = loss_actor + critic_coeff * value_loss - entropy_coeff * entropy
return total_loss, (value_loss, loss_actor, entropy, value_pred.mean(),
target.mean(), gae.mean())
@partial(jax.jit, static_argnums=0)
def policy(apply_fn: Callable[..., Any],
params: flax.core.frozen_dict.FrozenDict,
state: np.ndarray):
value, pi = apply_fn(params, state)
return value, pi
def select_action(
train_state: TrainState,
state: np.ndarray,
rng: PRNGKey,
sample: bool = False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, PRNGKey]:
value, pi = policy(train_state.apply_fn, train_state.params, state)
if sample:
rng, key = jax.random.split(rng)
action = pi.sample(seed=key)
else:
action = pi.mode()
log_prob = pi.log_prob(action)
return action, log_prob, value[:, 0], rng
def get_transition(
train_state: TrainState,
env,
state,
batch,
rng: PRNGKey,
):
action, log_pi, value, new_key = select_action(
train_state, state.astype(jnp.float32) / 255., rng, sample=True)
next_state, reward, done, _ = env.step(action)
batch.append(state, action, reward, done,
np.array(log_pi), np.array(value))
return train_state, next_state, batch, new_key
@partial(jax.jit)
def flatten_dims(x):
return x.swapaxes(0, 1).reshape(x.shape[0] * x.shape[1], *x.shape[2:])
def update(
train_state: TrainState,
batch: Tuple,
num_envs: int,
n_steps: int,
n_minibatch: int,
epoch_ppo: int,
clip_eps: float,
entropy_coeff: float,
critic_coeff: float,
rng: PRNGKey
):
state, action, log_pi_old, value, target, gae = batch
size_batch = num_envs * n_steps
size_minibatch = size_batch // n_minibatch
idxes = np.arange(num_envs * n_steps)
avg_metrics_dict = defaultdict(int)
for _ in range(epoch_ppo):
idxes = np.random.permutation(idxes)
idxes_list = [idxes[start:start + size_minibatch]
for start in jnp.arange(0, size_batch, size_minibatch)]
train_state, total_loss = update_jit(
train_state,
idxes_list,
flatten_dims(state),
flatten_dims(action),
flatten_dims(log_pi_old),
flatten_dims(value),
np.array(flatten_dims(target)),
np.array(flatten_dims(gae)),
clip_eps,
entropy_coeff,
critic_coeff)
total_loss, (value_loss, loss_actor, entropy, value_pred,
target_val, gae_val) = total_loss
avg_metrics_dict['total_loss'] += np.asarray(total_loss)
avg_metrics_dict['value_loss'] += np.asarray(value_loss)
avg_metrics_dict['actor_loss'] += np.asarray(loss_actor)
avg_metrics_dict['entropy'] += np.asarray(entropy)
avg_metrics_dict['value_pred'] += np.asarray(value_pred)
avg_metrics_dict['target'] += np.asarray(target_val)
avg_metrics_dict['gae'] += np.asarray(gae_val)
for k, v in avg_metrics_dict.items():
avg_metrics_dict[k] = v / (epoch_ppo)
return avg_metrics_dict, train_state, rng
@partial(jax.jit, static_argnames=("clip_eps", "entropy_coeff", "critic_coeff"))
def update_jit(
train_state: TrainState,
idxes: np.ndarray,
state,
action,
log_pi_old,
value,
target,
gae,
clip_eps: float,
entropy_coeff: float,
critic_coeff: float
):
for idx in idxes:
grad_fn = jax.value_and_grad(loss_actor_and_critic,
has_aux=True)
total_loss, grads = grad_fn(train_state.params,
train_state.apply_fn,
state=state[idx],
target=target[idx],
value_old=value[idx],
log_pi_old=log_pi_old[idx],
gae=gae[idx],
action=action[idx].reshape(-1, 1),
clip_eps=clip_eps,
critic_coeff=critic_coeff,
entropy_coeff=entropy_coeff)
train_state = train_state.apply_gradients(grads=grads)
return train_state, total_loss