-
Notifications
You must be signed in to change notification settings - Fork 6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RLlib; Offline RL] CQL: Support multi-GPU/CPU setup and different learning rates for actor, critic, and alpha. #47402
Changes from 5 commits
8923dc2
75d91ce
331f0ef
fe79e3d
ecbd588
2be80fd
cb9ecbe
54b5d2d
e6b3769
62d6dfb
edbe263
100aabe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,19 +2,20 @@ | |
from typing import Optional, Type, Union | ||
|
||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided | ||
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy | ||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy | ||
from ray.rllib.algorithms.sac.sac import ( | ||
SAC, | ||
SACConfig, | ||
) | ||
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( | ||
AddObservationsFromEpisodesToBatch, | ||
) | ||
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa | ||
AddNextObservationsFromEpisodesToTrainBatch, | ||
) | ||
from ray.rllib.core.learner.learner import Learner | ||
from ray.rllib.algorithms.cql.cql_tf_policy import CQLTFPolicy | ||
from ray.rllib.algorithms.cql.cql_torch_policy import CQLTorchPolicy | ||
from ray.rllib.algorithms.sac.sac import ( | ||
SAC, | ||
SACConfig, | ||
) | ||
from ray.rllib.core.rl_module.rl_module import RLModuleSpec | ||
from ray.rllib.execution.rollout_ops import ( | ||
synchronous_parallel_sample, | ||
) | ||
|
@@ -48,7 +49,7 @@ | |
SAMPLE_TIMER, | ||
TIMERS, | ||
) | ||
from ray.rllib.utils.typing import ResultDict | ||
from ray.rllib.utils.typing import ResultDict, RLModuleSpecType | ||
|
||
tf1, tf, tfv = try_import_tf() | ||
tfp = try_import_tfp() | ||
|
@@ -84,6 +85,12 @@ def __init__(self, algo_class=None): | |
self.lagrangian_thresh = 5.0 | ||
self.min_q_weight = 5.0 | ||
self.lr = 3e-4 | ||
# Note, the new stack defines a separate learning rate for each
# component (actor, critic, alpha). The base learning rate `lr` has
# to be set to `None`, if using the new stack.
# Each value is a plain float: a trailing comma after the literal would
# silently turn the attribute into a 1-tuple such as `(2e-4,)`.
self.actor_lr = 2e-4
self.critic_lr = 8e-4
self.alpha_lr = 9e-4
|
||
# Changes to Algorithm's/SACConfig's default: | ||
|
||
|
@@ -234,6 +241,28 @@ def validate(self) -> None: | |
"Set this hyperparameter in the `AlgorithmConfig.offline_data`." | ||
) | ||
|
||
@override(SACConfig)
def get_default_rl_module_spec(self) -> RLModuleSpecType:
    """Return the default `RLModuleSpec` for this config.

    Returns:
        An `RLModuleSpec` with `CQLTorchRLModule` as the module class and
        `SACCatalog` as the catalog class.

    Raises:
        ValueError: If `self.framework_str` is anything other than
            `"torch"`.
    """
    # Function-scope import, executed first exactly as before so that the
    # order of possible import errors vs. the framework check is unchanged.
    from ray.rllib.algorithms.sac.sac_catalog import SACCatalog

    # Guard clause: reject every framework except torch up front.
    if self.framework_str != "torch":
        raise ValueError(
            f"The framework {self.framework_str} is not supported. Use `torch`."
        )

    from ray.rllib.algorithms.cql.torch.cql_torch_rl_module import CQLTorchRLModule

    return RLModuleSpec(module_class=CQLTorchRLModule, catalog_class=SACCatalog)
|
||
@property | ||
def _model_config_auto_includes(self): | ||
return super()._model_config_auto_includes | { | ||
"num_actions": self.num_actions, | ||
"_deterministic_loss": self._deterministic_loss, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: Let's remove this deterministic loss thing. It's a relic from a long time ago (2020) when I was trying to debug SAC on torch vs our old SAC on tf. It serves no real purpose and just bloats the code. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Great!! That saves us many lines of code! |
||
} | ||
|
||
|
||
class CQL(SAC): | ||
"""CQL (derived from SAC).""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, so for the new stack, users have to set this to None, manually? I guess this is ok (explicit is always good).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, exactly. We discussed this in the other PR concerning `SAC`.