-
Notifications
You must be signed in to change notification settings - Fork 71
/
Copy pathcql.py
435 lines (376 loc) · 16.3 KB
/
cql.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
from collections import OrderedDict
import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer
from torch import autograd
class CQLTrainer(TorchTrainer):
    """Conservative Q-Learning (CQL) trainer on top of a SAC-style actor-critic.

    Each call to :meth:`train_from_torch` performs one gradient step on:
    the entropy temperature (optional), the policy, and one or two
    Q-functions. Each Q-function loss is the Bellman MSE plus the CQL
    regularizer (a log-sum-exp of Q over sampled actions minus Q on dataset
    actions). Afterwards the target networks are soft-updated towards the
    online networks.

    When ``num_qs == 1`` only ``qf1``/``target_qf1`` are used. ``qf2`` and
    ``target_qf2`` are still required by the constructor and are still
    returned by :attr:`networks` and :meth:`get_snapshot`.
    """

    def __init__(
            self,
            env,
            policy,
            qf1,
            qf2,
            target_qf1,
            target_qf2,

            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-3,
            qf_lr=1e-3,
            optimizer_class=optim.Adam,

            soft_target_tau=1e-2,
            plotter=None,
            render_eval_paths=False,

            use_automatic_entropy_tuning=True,
            target_entropy=None,
            policy_eval_start=0,
            num_qs=2,

            # CQL
            min_q_version=3,
            temp=1.0,
            min_q_weight=1.0,

            ## sort of backup
            max_q_backup=False,
            deterministic_backup=True,
            num_random=10,
            with_lagrange=False,
            lagrange_thresh=0.0,
    ):
        """Store networks and hyper-parameters and build the optimizers.

        Args (selected):
            env: only ``env.action_space.shape`` is read, to derive the
                default ``target_entropy``.
            policy: called as ``policy(obs, reparameterize=True,
                return_log_prob=True)`` and unpacked as
                ``(action, mean, log_std, log_prob, ...)``. It must also
                provide ``log_prob(obs, actions)`` for the behaviour-cloning
                warm start.
            qf1, qf2, target_qf1, target_qf2: called as ``qf(obs, actions)``.
            target_entropy: entropy target for automatic alpha tuning.
                ``None`` means ``-prod(env.action_space.shape)``.
            policy_eval_start: number of ``train_from_torch`` calls during
                which the policy is trained by behaviour cloning.
            num_qs: ``1`` uses only ``qf1``; larger values use ``qf1`` and
                ``qf2`` with a pessimistic min.
            min_q_version: ``3`` selects the importance-sampled estimator.
                Any other value uses a plain log-sum-exp that also includes
                the dataset Q-values.
            temp: temperature of the CQL log-sum-exp.
            min_q_weight: weight of the CQL regularizer.
            max_q_backup: back up the max over 10 sampled next actions
                instead of the SAC target.
            deterministic_backup: if True, targets omit the entropy term.
            num_random: number of random and policy actions sampled per state
                for the CQL regularizer.
            with_lagrange: tune the regularizer weight ``alpha_prime`` by dual
                gradient descent with threshold ``lagrange_thresh``.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.soft_target_tau = soft_target_tau

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # Compare against None so an explicit target_entropy of 0.0 is honoured.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                # Default: minus the action dimensionality.
                self.target_entropy = -np.prod(self.env.action_space.shape).item()
            self.log_alpha = ptu.zeros(1, requires_grad=True)
            self.alpha_optimizer = optimizer_class(
                [self.log_alpha],
                lr=policy_lr,
            )

        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            self.log_alpha_prime = ptu.zeros(1, requires_grad=True)
            self.alpha_prime_optimizer = optimizer_class(
                [self.log_alpha_prime],
                lr=qf_lr,
            )

        self.plotter = plotter
        self.render_eval_paths = render_eval_paths

        self.qf_criterion = nn.MSELoss()
        # Not used by this class; kept so the attribute still exists.
        self.vf_criterion = nn.MSELoss()

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )
        self.qf1_optimizer = optimizer_class(
            self.qf1.parameters(),
            lr=qf_lr,
        )
        self.qf2_optimizer = optimizer_class(
            self.qf2.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True
        self.policy_eval_start = policy_eval_start

        # Incremented once per train_from_torch call (i.e. per gradient step,
        # despite the name) and compared against policy_eval_start.
        self._current_epoch = 0
        self._policy_update_ctr = 0
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._num_policy_steps = 1

        self.num_qs = num_qs

        ## min Q
        self.temp = temp
        self.min_q_version = min_q_version
        self.min_q_weight = min_q_weight

        # Not used by this class; kept so the attributes still exist.
        self.softmax = torch.nn.Softmax(dim=1)
        self.softplus = torch.nn.Softplus(beta=self.temp, threshold=20)

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup
        self.num_random = num_random

        # A discrete-action variant is not implemented. This flag only
        # switches what _get_policy_actions returns and which statistics
        # are logged.
        self.discrete = False

    def _get_tensor_values(self, obs, actions, network=None):
        """Evaluate ``network`` on each state paired with its repeated actions.

        Row ``i * num_repeat + j`` of ``actions`` is paired with row ``i``
        of ``obs``, where ``num_repeat = actions.shape[0] // obs.shape[0]``.
        ``obs`` is assumed to be 2-D ``(B, obs_dim)``.

        Returns:
            Tensor of shape ``(B, num_repeat, 1)``.
        """
        num_repeat = actions.shape[0] // obs.shape[0]
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(
            obs.shape[0] * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        return preds.view(obs.shape[0], num_repeat, 1)

    def _get_policy_actions(self, obs, num_actions, network=None):
        """Sample ``num_actions`` actions per state from ``network``.

        Returns:
            ``(actions, log_pis)`` where ``actions`` has
            ``obs.shape[0] * num_actions`` rows (grouped per state) and
            ``log_pis`` has shape ``(B, num_actions, 1)``. In discrete mode
            only the actions are returned.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(
            obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp, reparameterize=True, return_log_prob=True,
        )
        if not self.discrete:
            return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
        return new_obs_actions

    def _compute_target_q_values(self, next_obs, alpha):
        """Return the next-state value used in the Bellman target.

        Computed without gradient tracking, because the caller only uses it
        inside a detached target.
        """
        with torch.no_grad():
            if self.max_q_backup:
                # Max over 10 policy samples per next state, then the min
                # across the active target Q-functions.
                next_actions_temp, _ = self._get_policy_actions(
                    next_obs, num_actions=10, network=self.policy)
                target_q_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf1,
                ).max(1)[0].view(-1, 1)
                if self.num_qs > 1:
                    target_qf2_values = self._get_tensor_values(
                        next_obs, next_actions_temp, network=self.target_qf2,
                    ).max(1)[0].view(-1, 1)
                    target_q_values = torch.min(target_q_values, target_qf2_values)
                return target_q_values

            new_next_actions, _, _, new_log_pi, *_ = self.policy(
                next_obs, reparameterize=True, return_log_prob=True,
            )
            target_q_values = self.target_qf1(next_obs, new_next_actions)
            if self.num_qs > 1:
                target_q_values = torch.min(
                    target_q_values,
                    self.target_qf2(next_obs, new_next_actions),
                )
            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi
            return target_q_values

    def _cql_penalty(self, obs, q_pred, qf, random_actions, curr_actions,
                     curr_log_pis, next_actions, next_log_pis):
        """Compute the CQL regularizer for a single Q-function.

        Args:
            obs: batch of states.
            q_pred: ``qf(obs, dataset_actions)``; assumed ``(B, 1)``.
            qf: the Q-network being regularized.
            random_actions: ``B * num_random`` actions drawn uniformly from [-1, 1].
            curr_actions, curr_log_pis: policy samples at ``obs`` and their
                log-probabilities.
            next_actions, next_log_pis: policy samples drawn at the next
                states. They are evaluated here at ``obs``, not at the
                next states.

        Returns:
            ``(min_qf_loss, std_q, q_rand, q_curr_actions, q_next_actions)``,
            where ``min_qf_loss`` equals
            ``min_q_weight * (temp * mean(logsumexp(Q / temp)) - mean(q_pred))``.
        """
        q_rand = self._get_tensor_values(obs, random_actions, network=qf)
        q_curr_actions = self._get_tensor_values(obs, curr_actions, network=qf)
        q_next_actions = self._get_tensor_values(obs, next_actions, network=qf)

        cat_q = torch.cat(
            [q_rand, q_pred.unsqueeze(1), q_next_actions, q_curr_actions], 1
        )
        std_q = torch.std(cat_q, dim=1)

        if self.min_q_version == 3:
            # Importance-sampled version. The uniform density on [-1, 1]^d is
            # 0.5**d; taking d * log(0.5) avoids underflow for large d.
            random_density = curr_actions.shape[-1] * np.log(0.5)
            cat_q = torch.cat(
                [q_rand - random_density,
                 q_next_actions - next_log_pis.detach(),
                 q_curr_actions - curr_log_pis.detach()], 1
            )

        min_qf_loss = torch.logsumexp(cat_q / self.temp, dim=1).mean() * self.min_q_weight * self.temp
        # Subtract the Q-values of dataset actions.
        min_qf_loss = min_qf_loss - q_pred.mean() * self.min_q_weight
        return min_qf_loss, std_q, q_rand, q_curr_actions, q_next_actions

    def train_from_torch(self, batch):
        """Perform one CQL training step on ``batch``.

        ``batch`` must provide ``'observations'``, ``'actions'``,
        ``'rewards'``, ``'terminals'`` and ``'next_observations'``.

        Steps:
            1. Update the entropy temperature (if enabled).
            2. Compute the policy loss: behaviour cloning for the first
               ``policy_eval_start`` calls, the SAC loss afterwards.
            3. Compute Bellman targets and the CQL regularizer per
               Q-function, with an optional Lagrange (alpha_prime) update.
            4. Step the policy, then step the Q-functions.
            5. Soft-update the target networks.
            6. Record statistics for the first batch after :meth:`end_epoch`.
        """
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        use_qf2 = self.num_qs > 1

        # ---------------- Policy and alpha loss ----------------
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self._current_epoch < self.policy_eval_start:
            # For the first few steps, optionally do behavioural cloning.
            # Conventionally there is not much difference in performance
            # with or without ~20k such gradient steps.
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        else:
            q_new_actions = self.qf1(obs, new_obs_actions)
            if use_qf2:
                q_new_actions = torch.min(q_new_actions, self.qf2(obs, new_obs_actions))
            policy_loss = (alpha * log_pi - q_new_actions).mean()

        # ---------------- Bellman (QF) loss ----------------
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions) if use_qf2 else None

        target_q_values = self._compute_target_q_values(next_obs, alpha)
        q_target = (
            self.reward_scale * rewards
            + (1. - terminals) * self.discount * target_q_values
        ).detach()

        # ---------------- CQL regularizer ----------------
        # Create the random actions on the same device/dtype as the data.
        # The previous torch.FloatTensor(...) call always allocated on the CPU.
        random_actions_tensor = actions.new_empty(
            (obs.shape[0] * self.num_random, actions.shape[-1])
        ).uniform_(-1, 1)
        # The policy receives no gradient from the Q losses; sampling under
        # no_grad keeps the Q graphs free of policy parameters. This lets the
        # policy be stepped before the Q backward passes.
        with torch.no_grad():
            curr_actions_tensor, curr_log_pis = self._get_policy_actions(
                obs, num_actions=self.num_random, network=self.policy)
            new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
                next_obs, num_actions=self.num_random, network=self.policy)

        min_qf1_loss, std_q1, q1_rand, q1_curr_actions, q1_next_actions = self._cql_penalty(
            obs, q1_pred, self.qf1, random_actions_tensor,
            curr_actions_tensor, curr_log_pis, new_curr_actions_tensor, new_log_pis,
        )
        min_qf2_loss = std_q2 = q2_rand = q2_curr_actions = q2_next_actions = None
        if use_qf2:
            min_qf2_loss, std_q2, q2_rand, q2_curr_actions, q2_next_actions = self._cql_penalty(
                obs, q2_pred, self.qf2, random_actions_tensor,
                curr_actions_tensor, curr_log_pis, new_curr_actions_tensor, new_log_pis,
            )

        if self.with_lagrange:
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            gaps = [min_qf1_loss - self.target_action_gap]
            if use_qf2:
                gaps.append(min_qf2_loss - self.target_action_gap)
            # Dual step on alpha_prime only. The gaps are detached, so this
            # backward touches log_alpha_prime alone and no graph is retained.
            alpha_prime_loss = -alpha_prime * torch.stack(gaps).detach().mean()
            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss.backward()
            self.alpha_prime_optimizer.step()
            # alpha_prime is the pre-step value (a fresh tensor, not the leaf).
            min_qf1_loss = alpha_prime.detach() * gaps[0]
            if use_qf2:
                min_qf2_loss = alpha_prime.detach() * gaps[1]

        qf1_loss = self.qf_criterion(q1_pred, q_target) + min_qf1_loss
        qf2_loss = None
        if use_qf2:
            qf2_loss = self.qf_criterion(q2_pred, q_target) + min_qf2_loss

        # ---------------- Update networks ----------------
        # Step the policy first. Its loss was built from the current
        # Q-networks, and stepping the Q optimizers (in place) before this
        # backward raises an "inplace operation" autograd error on recent
        # PyTorch versions.
        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # policy_loss.backward() also wrote gradients into the Q-networks.
        # The zero_grad() calls below discard them.
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward()
        self.qf1_optimizer.step()

        if use_qf2:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward()
            self.qf2_optimizer.step()

        # ---------------- Soft updates ----------------
        ptu.soft_update_from_to(
            self.qf1, self.target_qf1, self.soft_target_tau
        )
        if use_qf2:
            ptu.soft_update_from_to(
                self.qf2, self.target_qf2, self.soft_target_tau
            )

        # ---------------- Statistics ----------------
        if self._need_to_update_eval_statistics:
            # Only the first batch after end_epoch() is recorded.
            self._need_to_update_eval_statistics = False
            stats = self.eval_statistics

            stats['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            stats['min QF1 Loss'] = np.mean(ptu.get_numpy(min_qf1_loss))
            if use_qf2:
                stats['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
                stats['min QF2 Loss'] = np.mean(ptu.get_numpy(min_qf2_loss))

            if not self.discrete:
                stats['Std QF1 values'] = np.mean(ptu.get_numpy(std_q1))
                if use_qf2:
                    stats['Std QF2 values'] = np.mean(ptu.get_numpy(std_q2))
                for label, q1_values, q2_values in (
                        ('in-distribution', q1_curr_actions, q2_curr_actions),
                        ('random', q1_rand, q2_rand),
                        ('next_actions', q1_next_actions, q2_next_actions),
                ):
                    stats.update(create_stats_ordered_dict(
                        'QF1 ' + label + ' values',
                        ptu.get_numpy(q1_values),
                    ))
                    if use_qf2:
                        stats.update(create_stats_ordered_dict(
                            'QF2 ' + label + ' values',
                            ptu.get_numpy(q2_values),
                        ))
                stats.update(create_stats_ordered_dict(
                    'actions',
                    ptu.get_numpy(actions)
                ))
                stats.update(create_stats_ordered_dict(
                    'rewards',
                    ptu.get_numpy(rewards)
                ))

            stats['Num Q Updates'] = self._num_q_update_steps
            stats['Num Policy Updates'] = self._num_policy_update_steps
            # Log the loss that was actually optimized (including alpha / BC).
            stats['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
            stats.update(create_stats_ordered_dict(
                'Q1 Predictions',
                ptu.get_numpy(q1_pred),
            ))
            if use_qf2:
                stats.update(create_stats_ordered_dict(
                    'Q2 Predictions',
                    ptu.get_numpy(q2_pred),
                ))
            stats.update(create_stats_ordered_dict(
                'Q Targets',
                ptu.get_numpy(q_target),
            ))
            stats.update(create_stats_ordered_dict(
                'Log Pis',
                ptu.get_numpy(log_pi),
            ))
            if not self.discrete:
                stats.update(create_stats_ordered_dict(
                    'Policy mu',
                    ptu.get_numpy(policy_mean),
                ))
                stats.update(create_stats_ordered_dict(
                    'Policy log std',
                    ptu.get_numpy(policy_log_std),
                ))

            if self.use_automatic_entropy_tuning:
                stats['Alpha'] = alpha.item()
                stats['Alpha Loss'] = alpha_loss.item()

            if self.with_lagrange:
                stats['Alpha_prime'] = alpha_prime.item()
                stats['min_q1_loss'] = ptu.get_numpy(min_qf1_loss).mean()
                if use_qf2:
                    stats['min_q2_loss'] = ptu.get_numpy(min_qf2_loss).mean()
                stats['threshold action gap'] = self.target_action_gap
                stats['alpha prime loss'] = alpha_prime_loss.item()

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics dict filled in by train_from_torch."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Request that statistics be recorded on the next training batch."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All online and target networks (policy, Q-functions, targets)."""
        base_list = [
            self.policy,
            self.qf1,
            self.qf2,
            self.target_qf1,
            self.target_qf2,
        ]
        return base_list

    def get_snapshot(self):
        """Return the networks keyed by name, for checkpointing."""
        return dict(
            policy=self.policy,
            qf1=self.qf1,
            qf2=self.qf2,
            target_qf1=self.target_qf1,
            target_qf2=self.target_qf2,
        )