Learnable initial recurrent state choice #256

Merged · 7 commits · Apr 5, 2024
2 changes: 1 addition & 1 deletion sheeprl/__init__.py
@@ -52,7 +52,7 @@
 np.int = np.int64
 np.bool = bool

-__version__ = "0.5.5.dev0"
+__version__ = "0.5.5.dev1"


 # Replace `moviepy.decorators.use_clip_fps_by_default` method to work with python 3.8, 3.9, and 3.10
30 changes: 23 additions & 7 deletions sheeprl/algos/dreamer_v3/agent.py
@@ -370,17 +370,23 @@ def __init__(
         distribution_cfg: Dict[str, Any],
         discrete: int = 32,
         unimix: float = 0.01,
+        learnable_initial_recurrent_state: bool = True,
     ) -> None:
         super().__init__()
         self.recurrent_model = recurrent_model
         self.representation_model = representation_model
         self.transition_model = transition_model
-        self.distribution_cfg = distribution_cfg
         self.discrete = discrete
         self.unimix = unimix
+        self.distribution_cfg = distribution_cfg
-        self.initial_recurrent_state = nn.Parameter(
-            torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32)
-        )
+        if learnable_initial_recurrent_state:
+            self.initial_recurrent_state = nn.Parameter(
+                torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32)
+            )
+        else:
+            self.register_buffer(
+                "initial_recurrent_state", torch.zeros(recurrent_model.recurrent_state_size, dtype=torch.float32)
+            )

     def get_initial_states(self, batch_shape: Sequence[int] | torch.Size) -> Tuple[Tensor, Tensor]:
         initial_recurrent_state = torch.tanh(self.initial_recurrent_state).expand(*batch_shape, -1)
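This hunk is the core of the PR: with `learnable_initial_recurrent_state=True` (the default) the initial recurrent state is an `nn.Parameter` trained jointly with the rest of the world model; with `False` it is registered as a buffer and stays fixed at zero. A stripped-down, runnable sketch of just this logic — `TinyRSSM` is illustrative, not the real `RSSM`:

    import torch
    import torch.nn as nn

    class TinyRSSM(nn.Module):
        """Illustration only: isolates the initial-state logic from this diff."""

        def __init__(self, recurrent_state_size: int = 4, learnable_initial_recurrent_state: bool = True):
            super().__init__()
            zeros = torch.zeros(recurrent_state_size, dtype=torch.float32)
            if learnable_initial_recurrent_state:
                # Trainable: receives gradients like any other weight.
                self.initial_recurrent_state = nn.Parameter(zeros)
            else:
                # Fixed at zero, but still follows .to(device) and state_dict().
                self.register_buffer("initial_recurrent_state", zeros)

        def get_initial_states(self, batch_shape):
            # tanh bounds the learned state in (-1, 1); tanh(0) = 0, so both
            # variants start training from the same all-zeros initial state.
            return torch.tanh(self.initial_recurrent_state).expand(*batch_shape, -1)

    learnable = TinyRSSM(learnable_initial_recurrent_state=True)
    fixed = TinyRSSM(learnable_initial_recurrent_state=False)
    print(sum(p.numel() for p in learnable.parameters()))  # 4 -> trained
    print(sum(p.numel() for p in fixed.parameters()))      # 0 -> stays zero
    print(learnable.get_initial_states((2,)).shape)        # torch.Size([2, 4])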
@@ -521,8 +527,17 @@ def __init__(
         distribution_cfg: Dict[str, Any],
         discrete: int = 32,
         unimix: float = 0.01,
+        learnable_initial_recurrent_state: bool = True,
     ) -> None:
-        super().__init__(recurrent_model, representation_model, transition_model, distribution_cfg, discrete, unimix)
+        super().__init__(
+            recurrent_model,
+            representation_model,
+            transition_model,
+            distribution_cfg,
+            discrete,
+            unimix,
+            learnable_initial_recurrent_state,
+        )

     def dynamic(
         self, posterior: Tensor, recurrent_state: Tensor, action: Tensor, is_first: Tensor
@@ -1020,7 +1035,7 @@ def build_agent(
         layer_norm_kw=world_model_cfg.recurrent_model.layer_norm.kw,
     )
     represention_model_input_size = encoder.output_dim
-    if not cfg.algo.decoupled_rssm:
+    if not cfg.algo.world_model.decoupled_rssm:
         represention_model_input_size += recurrent_state_size
     representation_ln_cls = hydra.utils.get_class(world_model_cfg.representation_model.layer_norm.cls)
     representation_model = MLP(
@@ -1055,7 +1070,7 @@
         ],
     )

-    if cfg.algo.decoupled_rssm:
+    if cfg.algo.world_model.decoupled_rssm:
         rssm_cls = DecoupledRSSM
     else:
         rssm_cls = RSSM
@@ -1066,6 +1081,7 @@
         distribution_cfg=cfg.distribution,
         discrete=world_model_cfg.discrete_size,
         unimix=cfg.algo.unimix,
+        learnable_initial_recurrent_state=cfg.algo.world_model.learnable_initial_recurrent_state,
     ).to(fabric.device)

     cnn_decoder = (
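A side note on the `register_buffer` choice in the non-learnable branch: unlike a plain tensor attribute, a buffer follows the module through calls such as the `.to(fabric.device)` above and is saved in `state_dict`, so device placement and checkpointing behave the same in both modes. A quick standalone check (illustrative, not SheepRL code):

    import torch
    import torch.nn as nn

    m = nn.Module()
    m.register_buffer("initial_recurrent_state", torch.zeros(4))
    print("initial_recurrent_state" in m.state_dict())  # True: checkpointed
    m = m.to("cpu")  # with a GPU available, .to("cuda") would move the buffer too
    print(m.initial_recurrent_state.device)             # cpu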
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -112,7 +112,7 @@ def train(
     # Embed observations from the environment
     embedded_obs = world_model.encoder(batch_obs)

-    if cfg.algo.decoupled_rssm:
+    if cfg.algo.world_model.decoupled_rssm:
         posteriors_logits, posteriors = world_model.rssm._representation(embedded_obs)
         for i in range(0, sequence_length):
             if i == 0:
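The renamed flag gates the decoupled-RSSM path above: since the decoupled representation model does not take the recurrent state as input (see the `represention_model_input_size` change in `agent.py`), posteriors for the whole sequence can be computed in a single batched call, while the recurrent state is still rolled out step by step. A minimal sketch of that control flow with stand-in models (illustrative, not SheepRL's actual modules):

    import torch
    import torch.nn as nn

    T, B, E, H = 5, 2, 8, 16  # time, batch, embedding, recurrent-state sizes
    embedded_obs = torch.randn(T, B, E)
    representation = nn.Linear(E, H)   # stand-in: conditions only on embeddings
    recurrent = nn.GRUCell(H, H)       # stand-in recurrent model

    # Decoupled: one batched representation call over all timesteps...
    posteriors = representation(embedded_obs)  # (T, B, H)

    # ...while the deterministic recurrent state still unrolls sequentially.
    recurrent_state = torch.zeros(B, H)
    for i in range(T):
        recurrent_state = recurrent(posteriors[i], recurrent_state)
    print(recurrent_state.shape)  # torch.Size([2, 16])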
@@ -450,7 +450,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         cfg.algo.world_model.stochastic_size,
         cfg.algo.world_model.recurrent_model.recurrent_state_size,
         discrete_size=cfg.algo.world_model.discrete_size,
-        decoupled_rssm=cfg.algo.decoupled_rssm,
+        decoupled_rssm=cfg.algo.world_model.decoupled_rssm,
     )

     # Optimizers
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/evaluate.py
@@ -63,7 +63,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
         cfg.algo.world_model.stochastic_size,
         cfg.algo.world_model.recurrent_model.recurrent_state_size,
         discrete_size=cfg.algo.world_model.discrete_size,
-        decoupled_rssm=cfg.algo.decoupled_rssm,
+        decoupled_rssm=cfg.algo.world_model.decoupled_rssm,
     )

     test(player, fabric, cfg, log_dir, sample_actions=True)
3 changes: 2 additions & 1 deletion sheeprl/configs/algo/dreamer_v3.yaml
@@ -39,7 +39,6 @@ dense_act: torch.nn.SiLU
 cnn_act: torch.nn.SiLU
 unimix: 0.01
 hafner_initialization: True
-decoupled_rssm: False

 # World model
 world_model:
@@ -51,6 +50,8 @@ world_model:
   kl_regularizer: 1.0
   continue_scale_factor: 1.0
   clip_gradients: 1000.0
+  decoupled_rssm: False
+  learnable_initial_recurrent_state: True

   # Encoder
   encoder:
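Both flags now live under `algo.world_model`, so existing configs or command-line overrides that referenced the old `algo.decoupled_rssm` key must be updated (e.g. a Hydra override becomes `algo.world_model.decoupled_rssm=True`). A minimal sketch of the new layout as it resolves with OmegaConf — the dict mirrors the YAML above; the snippet itself is only illustrative:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "algo": {
                "world_model": {
                    "decoupled_rssm": False,
                    "learnable_initial_recurrent_state": True,
                }
            }
        }
    )
    # Old (pre-PR) access path, no longer valid: cfg.algo.decoupled_rssm
    assert cfg.algo.world_model.decoupled_rssm is False
    assert cfg.algo.world_model.learnable_initial_recurrent_state is True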