[RLlib] DreamerV3: Main algo code and required changes to some RLlib APIs (RolloutWorker) #35386
Conversation
rllib/algorithms/algorithm_config.py
Outdated
@@ -296,6 +295,9 @@ def __init__(self, algo_class=None):
        self.auto_wrap_old_gym_envs = True

+       # `self.rollouts()`
+       # TODO (sven): Clean up the configuration of fully customizable
We can now publicly configure the class used for rollouts. This used to be configurable via config.debugging(worker_cls=..), but that path was not working correctly.
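A minimal sketch of how this newly public setting could be used (MyEnvRunner is hypothetical; the parameter name follows the AlgorithmConfig.rollouts(env_runner_class=...) setting this PR adds):

```python
from ray.rllib.algorithms.ppo import PPOConfig

# Hypothetical custom rollout class; any class implementing the rollout
# worker interface could be plugged in here.
class MyEnvRunner:
    ...

# The rollout class is now publicly configurable, replacing the broken
# `config.debugging(worker_cls=..)` path.
config = PPOConfig().rollouts(env_runner_class=MyEnvRunner)
```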
rllib/algorithms/algorithm_config.py
Outdated
@@ -838,7 +843,7 @@ def validate(self) -> None:
            self.model["_disable_action_flattening"] = True
        if self.model.get("custom_preprocessor"):
            deprecation_warning(
-               old="model_config['custom_preprocessor']",
+               old="AlgorithmConfig.training(model={'custom_preprocessor': ...})",
Enhanced the deprecation message.
rllib/algorithms/algorithm_config.py
Outdated
@@ -2716,12 +2731,22 @@ def get_multi_agent_setup(
            # Normal env (gym.Env or MultiAgentEnv): These should have the
            # `observation_space` and `action_space` properties.
            elif env is not None:
-               if hasattr(env, "observation_space") and isinstance(
+               if hasattr(env, "single_observation_space") and isinstance(
Support new gym.vector.Env envs, which have single_observation_space and single_action_space properties.
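A hedged sketch of the kind of check this enables (infer_spaces is an illustrative helper, not the actual RLlib code):

```python
import gymnasium as gym

def infer_spaces(env):
    # New `gym.vector.Env` instances expose `single_observation_space` /
    # `single_action_space` (the spaces of one sub-env) in addition to
    # the batched `observation_space` / `action_space`.
    if hasattr(env, "single_observation_space") and isinstance(
        env.single_observation_space, gym.Space
    ):
        return env.single_observation_space, env.single_action_space
    # Plain gym.Env (or MultiAgentEnv): fall back to the usual properties.
    return env.observation_space, env.action_space

# A vectorized CartPole reports the spaces of a single sub-environment.
obs_space, act_space = infer_spaces(gym.vector.make("CartPole-v1", num_envs=4))
```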
@@ -60,7 +60,7 @@ def _forward_exploration(self, batch: NestedDict) -> Mapping[str, Any]:

        return output

-   @override(TfRLModule)
+   @override(RLModule)
Fixed the @override decorator to reference RLModule.
rllib/core/learner/learner.py
Outdated
@@ -352,7 +352,7 @@ def _configure_optimizers_per_module_helper(
            pairs.append(pair)
        elif isinstance(pair_or_pairs, dict):
            # pair_or_pairs is a NamedParamOptimizerPairs
-           for name, pair in pairs.items():
+           for name, pair in pair_or_pairs.items():
This was a bug, but it was not visible for Learners that only use the default (single) optimizer path.
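A simplified repro of why it stayed hidden (names are illustrative): the buggy loop iterated the accumulator list, which would raise immediately, but this branch is only entered when a dict of named optimizer pairs is passed in, which the default single-optimizer path never does.

```python
pairs = []        # accumulator for anonymous (param, optimizer) pairs
named_pairs = {}  # accumulator for named pairs
pair_or_pairs = {"world_model": ("wm_params", "wm_optim")}  # user input

# Buggy: `for name, pair in pairs.items()` -- a list has no `.items()`,
# so this raised as soon as anyone actually passed named pairs.
for name, pair in pair_or_pairs.items():  # the fix
    named_pairs[name] = pair
```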
rllib/core/learner/learner.py
Outdated
@@ -435,8 +435,25 @@ def compute_gradients(self, loss: Mapping[str, Any]) -> ParamDictType:
            The gradients in the same format as self._params.
        """

+   @OverrideToImplementCustomLogic
Sorted this into a better position. These three should always appear together, in this order (see the sketch below):
1. compute_gradients
2. postprocess_gradients
3. apply_gradients
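For reference, a minimal sketch of how the three phases chain inside a Learner update step (illustrative only; the real update logic does more):

```python
def update_step_sketch(self, loss):
    # 1) Differentiate the loss(es) w.r.t. self._params.
    grads = self.compute_gradients(loss)
    # 2) Optionally clip/rescale/transform the raw gradients.
    grads = self.postprocess_gradients(grads)
    # 3) Hand the final gradients to the optimizer(s).
    self.apply_gradients(grads)
```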
rllib/core/learner/learner.py
Outdated
    @abc.abstractmethod
-   def apply_gradients(self, gradients: ParamDictType) -> None:
+   def apply_gradients(self, gradients_dict: ParamDictType) -> None:
Consistent argument naming.
rllib/core/learner/learner.py
Outdated
        forward passes within this method, and to use the "forward_train" outputs to
        compute the required tensors for loss calculation.
        "fwd_out". The returned dictionary must contain a key called
        `self.TOTAL_LOSS_KEY`, which will be used to compute gradients. It is
Use constant name for this key.
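A toy sketch of the resulting convention (the loss terms are placeholders; RLlib defines the constant on the Learner class, inlined here to keep the sketch self-contained):

```python
class ToyLearner:
    TOTAL_LOSS_KEY = "total_loss"

    def compute_loss(self, *, fwd_out, batch):
        pi_loss = fwd_out["pi_loss"]
        vf_loss = fwd_out["vf_loss"]
        # Gradients are computed from `self.TOTAL_LOSS_KEY`; all other
        # entries are carried along purely for metrics/logging.
        return {
            self.TOTAL_LOSS_KEY: pi_loss + vf_loss,
            "pi_loss": pi_loss,
            "vf_loss": vf_loss,
        }
```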
rllib/core/learner/learner.py
Outdated
@@ -811,7 +807,7 @@ def update(
        reduce_fn: Callable[[List[Mapping[str, Any]]], ResultDict] = (
            _reduce_mean_results
        ),
-   ) -> Mapping[str, Any]:
+   ) -> Union[Mapping[str, Any], List[Mapping[str, Any]]]:
If reduce_fn is not given, update() might return a list of result dicts.
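A small sketch of the two return shapes (reduce_mean mimics what a default mean-reducing reduce_fn does; it is not the actual _reduce_mean_results implementation):

```python
from typing import Any, Dict, List

def reduce_mean(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    # Average every metric across the per-minibatch result dicts.
    return {k: sum(r[k] for r in results) / len(results) for k in results[0]}

per_minibatch = [{"loss": 1.0}, {"loss": 3.0}]

reduced = reduce_mean(per_minibatch)  # -> {"loss": 2.0}   (Mapping[str, Any])
unreduced = per_minibatch             # reduce_fn=None -> List[Mapping[str, Any]]
```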
rllib/core/learner/tf/tf_learner.py
Outdated
@@ -124,7 +122,7 @@ def postprocess_gradients(
        return gradients_dict

    @override(Learner)
-   def apply_gradients(self, gradients: ParamDictType) -> None:
+   def apply_gradients(self, gradients_dict: ParamDictType) -> None:
Same here: more consistent argument naming.
@@ -490,11 +489,13 @@ def helper(_batch):
            # constraint on forward_train and compute_loss APIs. This seems to be
            # inefficient. Make it efficient.
            _batch = NestedDict(_batch)
-           with tf.GradientTape() as tape:
+           with tf.GradientTape(persistent=True) as tape:
Necessary for multiple optimizers that operate on the same RLModule.
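A self-contained TF2 illustration of why persistent=True matters when more than one optimizer takes gradients from the same tape:

```python
import tensorflow as tf

x = tf.Variable(2.0)  # stands in for world-model params
y = tf.Variable(3.0)  # stands in for actor params

with tf.GradientTape(persistent=True) as tape:
    loss_a = x * x
    loss_b = x * y

# A non-persistent tape is released after the first `gradient()` call;
# `persistent=True` permits one call per optimizer on the same tape.
grads_a = tape.gradient(loss_a, [x])     # [4.0]
grads_b = tape.gradient(loss_b, [x, y])  # [3.0, 2.0]
del tape  # free the tape's resources once all gradients are taken
```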
@@ -93,24 +94,8 @@ def compute_gradients(

        return grads

-   @OverrideToImplementCustomLogic_CallToSuperRecommended
Moved down to a better location in the file.
            self._params[pid].grad = grad

        # for each optimizer call its step function with the gradients
        for optim in self._optimizer_parameters:
            optim.step()

+   @OverrideToImplementCustomLogic_CallToSuperRecommended
.. to here :)
rllib/core/models/catalog.py
Outdated
@@ -25,93 +25,6 @@
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space


-def _multi_action_dist_partial_helper(
Moved down to make the main Catalog class in this file more prominent. We should generally move private functions to the end of a file, to avoid confusion and keep the main class(es) visible.
Why are these changes needed?

DreamerV3:
- Main algo code (dreamerv3.py, README), plus compilation and model-size (architecture) tests.
- Added DreamerV3Catalog.
- Added the DreamerV3 Algorithm class and config.

Some changes to RLlib:
- New AlgorithmConfig.rollouts(env_runner_class=...) setting.
- TfLearner.update() now uses tf.GradientTape(persistent=True). Managed to keep the Learner API as-is by simply overriding the DreamerV3TfLearner.compute_gradients() method (see the sketch below). Without this override, DreamerV3 on tf will not learn, as computing gradients for the TOTAL_LOSS_KEY over all model params messes up the world-model gradients.
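A hedged sketch of the idea behind that override (the loss/variable partitioning below is illustrative, not the actual DreamerV3 code):

```python
def compute_gradients_sketch(tape, losses, params):
    """Take per-component gradients instead of d(total_loss)/d(all_params).

    `losses`: component name -> scalar loss tensor.
    `params`: component name -> list of that component's tf.Variables.
    `tape` must have been created with `persistent=True`.
    """
    grads = {}
    for component, loss in losses.items():
        # Each loss only flows into its own sub-model's variables, so e.g.
        # actor/critic losses cannot corrupt the world-model gradients.
        grads[component] = tape.gradient(loss, params[component])
    return grads
```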
Related issue number

Checks
- I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
- I've run scripts/format.sh to lint the changes in this PR.
- If I've added a method in Tune, I've added it in doc/source/tune/api/ under the corresponding .rst file.