From 5644c598000564dc1b501b5d7b663960a161a551 Mon Sep 17 00:00:00 2001
From: belerico
Date: Thu, 4 Apr 2024 22:24:53 +0200
Subject: [PATCH 1/4] Learnable initial recurrent state choice

---
 sheeprl/algos/dreamer_v3/agent.py    | 15 +++++++++++----
 sheeprl/configs/algo/dreamer_v3.yaml |  1 +
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index adcbf195..45194c8d 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -361,17 +361,23 @@ def __init__(
         distribution_cfg: Dict[str, Any],
         discrete: int = 32,
         unimix: float = 0.01,
+        learnable_initial_recurrent_state: bool = True,
     ) -> None:
         super().__init__()
         self.recurrent_model = recurrent_model
         self.representation_model = representation_model
         self.transition_model = transition_model
+        self.distribution_cfg = distribution_cfg
         self.discrete = discrete
         self.unimix = unimix
-        self.distribution_cfg = distribution_cfg
-        self.initial_recurrent_state = nn.Parameter(
-            torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32)
-        )
+        if learnable_initial_recurrent_state:
+            self.initial_recurrent_state = nn.Parameter(
+                torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32)
+            )
+        else:
+            self.register_buffer(
+                "initial_recurrent_state", torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32)
+            )
 
     def get_initial_states(self, batch_shape: Sequence[int] | torch.Size) -> Tuple[Tensor, Tensor]:
         initial_recurrent_state = torch.tanh(self.initial_recurrent_state).expand(*batch_shape, -1)
@@ -1057,6 +1063,7 @@ def build_agent(
         distribution_cfg=cfg.distribution,
         discrete=world_model_cfg.discrete_size,
         unimix=cfg.algo.unimix,
+        learnable_initial_recurrent_state=cfg.algo.learnable_initial_recurrent_state,
     ).to(fabric.device)
 
     cnn_decoder = (
diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml
index 9b6e85fd..fe92e87b 100644
--- a/sheeprl/configs/algo/dreamer_v3.yaml
+++ b/sheeprl/configs/algo/dreamer_v3.yaml
@@ -40,6 +40,7 @@ cnn_act: torch.nn.SiLU
 unimix: 0.01
 hafner_initialization: True
 decoupled_rssm: False
+learnable_initial_recurrent_state: True
 
 # World model
 world_model:

From c0188fab53edf05b856fb163bf0ade48b2c91aeb Mon Sep 17 00:00:00 2001
From: belerico
Date: Fri, 5 Apr 2024 08:30:38 +0200
Subject: [PATCH 2/4] Fix DecoupledRSSM to accept the learnable_initial_recurrent_state flag

---
 sheeprl/algos/dreamer_v3/agent.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index 45194c8d..116b618f 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -518,8 +518,17 @@ def __init__(
         distribution_cfg: Dict[str, Any],
         discrete: int = 32,
         unimix: float = 0.01,
+        learnable_initial_recurrent_state: bool = True,
     ) -> None:
-        super().__init__(recurrent_model, representation_model, transition_model, distribution_cfg, discrete, unimix)
+        super().__init__(
+            recurrent_model,
+            representation_model,
+            transition_model,
+            distribution_cfg,
+            discrete,
+            unimix,
+            learnable_initial_recurrent_state,
+        )
 
     def dynamic(
         self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, is_first: Tensor
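For context on the toggle introduced by patches 1 and 2 above: with learnable_initial_recurrent_state=True the zero-initialized initial recurrent state is an nn.Parameter, so it is optimized together with the rest of the world model (and squashed with tanh when queried); with False it is registered as a buffer, so it stays fixed at zero but is still saved in checkpoints and follows the module across devices. A minimal standalone sketch of that distinction (illustrative only, not sheeprl code; the class name and state size are made up):

    import torch
    from torch import nn


    class InitialStateDemo(nn.Module):
        # Illustrative stand-in (not sheeprl code) for the RSSM's initial recurrent state.
        def __init__(self, recurrent_state_size: int = 8, learnable: bool = True) -> None:
            super().__init__()
            init = torch.zeros(recurrent_state_size, dtype=torch.float32)
            if learnable:
                # Trainable: updated together with the rest of the world model.
                self.initial_recurrent_state = nn.Parameter(init)
            else:
                # Fixed at zero, but still part of the state_dict and moved with .to(device).
                self.register_buffer("initial_recurrent_state", init)

        def get_initial_state(self, batch_shape) -> torch.Tensor:
            # As in the patch: squash with tanh and broadcast over the batch shape.
            return torch.tanh(self.initial_recurrent_state).expand(*batch_shape, -1)


    demo = InitialStateDemo(learnable=True)
    print(sum(p.numel() for p in demo.parameters()))                          # 8 trainable values
    print(InitialStateDemo(learnable=False).get_initial_state((4, 1)).shape)  # torch.Size([4, 1, 8])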
From 4e1703089556ad5239aad99ed4ecb0d5dc7b4552 Mon Sep 17 00:00:00 2001
From: belerico
Date: Fri, 5 Apr 2024 10:12:58 +0200
Subject: [PATCH 3/4] Move hyperparams to rightful key inside world_model

---
 sheeprl/algos/dreamer_v3/agent.py      | 6 +++---
 sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 ++--
 sheeprl/algos/dreamer_v3/evaluate.py   | 2 +-
 sheeprl/configs/algo/dreamer_v3.yaml   | 4 ++--
 4 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index 116b618f..fec8ac67 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -1026,7 +1026,7 @@ def build_agent(
         layer_norm_kw=world_model_cfg.recurrent_model.layer_norm.kw,
     )
     represention_model_input_size = encoder.output_dim
-    if not cfg.algo.decoupled_rssm:
+    if not cfg.algo.world_model.decoupled_rssm:
         represention_model_input_size += recurrent_state_size
     representation_ln_cls = hydra.utils.get_class(world_model_cfg.representation_model.layer_norm.cls)
     representation_model = MLP(
@@ -1061,7 +1061,7 @@ def build_agent(
         ],
     )
 
-    if cfg.algo.decoupled_rssm:
+    if cfg.algo.world_model.decoupled_rssm:
         rssm_cls = DecoupledRSSM
     else:
         rssm_cls = RSSM
@@ -1072,7 +1072,7 @@ def build_agent(
         distribution_cfg=cfg.distribution,
         discrete=world_model_cfg.discrete_size,
         unimix=cfg.algo.unimix,
-        learnable_initial_recurrent_state=cfg.algo.learnable_initial_recurrent_state,
+        learnable_initial_recurrent_state=cfg.algo.world_model.learnable_initial_recurrent_state,
     ).to(fabric.device)
 
     cnn_decoder = (
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index 7a7d2c17..cb5a2a56 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -112,7 +112,7 @@ def train(
     # Embed observations from the environment
     embedded_obs = world_model.encoder(batch_obs)
 
-    if cfg.algo.decoupled_rssm:
+    if cfg.algo.world_model.decoupled_rssm:
         posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs)
         for i in range(0, sequence_length):
             if i == 0:
@@ -450,7 +450,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         cfg.algo.world_model.stochastic_size,
         cfg.algo.world_model.recurrent_model.recurrent_state_size,
         discrete_size=cfg.algo.world_model.discrete_size,
-        decoupled_rssm=cfg.algo.decoupled_rssm,
+        decoupled_rssm=cfg.algo.world_model.decoupled_rssm,
     )
 
     # Optimizers
diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py
index 12f885c6..cc0e67f9 100644
--- a/sheeprl/algos/dreamer_v3/evaluate.py
+++ b/sheeprl/algos/dreamer_v3/evaluate.py
@@ -63,7 +63,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
         cfg.algo.world_model.stochastic_size,
         cfg.algo.world_model.recurrent_model.recurrent_state_size,
         discrete_size=cfg.algo.world_model.discrete_size,
-        decoupled_rssm=cfg.algo.decoupled_rssm,
+        decoupled_rssm=cfg.algo.world_model.decoupled_rssm,
     )
 
     test(player, fabric, cfg, log_dir, sample_actions=True)
diff --git a/sheeprl/configs/algo/dreamer_v3.yaml b/sheeprl/configs/algo/dreamer_v3.yaml
index fe92e87b..3ccdb999 100644
--- a/sheeprl/configs/algo/dreamer_v3.yaml
+++ b/sheeprl/configs/algo/dreamer_v3.yaml
@@ -39,8 +39,6 @@ dense_act: torch.nn.SiLU
 cnn_act: torch.nn.SiLU
 unimix: 0.01
 hafner_initialization: True
-decoupled_rssm: False
-learnable_initial_recurrent_state: True
 
 # World model
 world_model:
@@ -52,6 +50,8 @@ world_model:
   kl_regularizer: 1.0
   continue_scale_factor: 1.0
   clip_gradients: 1000.0
+  decoupled_rssm: False
+  learnable_initial_recurrent_state: True
 
 # Encoder
 encoder:
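Patch 3 above only relocates the two keys in the default YAML, so a config saved with the previous layout (for example alongside an older checkpoint) would still carry decoupled_rssm and learnable_initial_recurrent_state directly under algo. A hypothetical migration sketch, not part of this PR, written against OmegaConf (which Hydra configs are based on); the helper name is made up:

    from omegaconf import DictConfig, open_dict


    def migrate_algo_cfg(cfg: DictConfig) -> DictConfig:
        # Move the two hyperparameters from `algo` to `algo.world_model` if they are
        # still at the old location; leave already-migrated configs untouched.
        with open_dict(cfg):
            for key in ("decoupled_rssm", "learnable_initial_recurrent_state"):
                if key in cfg.algo and key not in cfg.algo.world_model:
                    cfg.algo.world_model[key] = cfg.algo.pop(key)
        return cfg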
From 24654e5000191d0e7a100fb69a367de95333d211 Mon Sep 17 00:00:00 2001
From: belerico
Date: Fri, 5 Apr 2024 10:39:39 +0200
Subject: [PATCH 4/4] Update version

---
 sheeprl/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py
index 5ec12dc8..495a97f4 100644
--- a/sheeprl/__init__.py
+++ b/sheeprl/__init__.py
@@ -52,7 +52,7 @@
 np.int = np.int64
 np.bool = bool
 
-__version__ = "0.5.5.dev0"
+__version__ = "0.5.5.dev1"
 
 # Replace `moviepy.decorators.use_clip_fps_by_default` method to work with python 3.8, 3.9, and 3.10