From 15105814e40bfcf1b83886f722bb573777d03d47 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 22 Dec 2023 17:11:19 +0100 Subject: [PATCH 1/3] Do not create metrics when timer is disabled + do not cat when unneeded --- howto/register_external_algorithm.md | 2 +- howto/register_new_algorithm.md | 2 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 4 ++-- sheeprl/algos/dreamer_v2/dreamer_v2.py | 4 ++-- sheeprl/algos/dreamer_v3/dreamer_v3.py | 4 ++-- sheeprl/algos/droq/droq.py | 4 ++-- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 4 ++-- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 4 ++-- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 4 ++-- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 4 ++-- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 4 ++-- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 4 ++-- sheeprl/algos/ppo/ppo.py | 4 ++-- sheeprl/algos/ppo/ppo_decoupled.py | 2 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 4 ++-- sheeprl/algos/sac/sac.py | 4 ++-- sheeprl/algos/sac/sac_decoupled.py | 2 +- sheeprl/algos/sac_ae/sac_ae.py | 4 ++-- sheeprl/models/models.py | 19 +++++++++++++------ sheeprl/utils/timer.py | 7 ++++--- 20 files changed, 49 insertions(+), 41 deletions(-) diff --git a/howto/register_external_algorithm.md b/howto/register_external_algorithm.md index 6b8a3aa4..f7816fcd 100644 --- a/howto/register_external_algorithm.md +++ b/howto/register_external_algorithm.md @@ -268,7 +268,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 67d2bee6..84b5fc27 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -265,7 +265,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index eedfc7c8..88d169a5 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -589,7 +589,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -681,7 +681,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update > learning_starts and updates_before_training <= 0: # Start 
training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(cfg.algo.per_rank_gradient_steps): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index c1d9e7b1..040b67fa 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -631,7 +631,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -735,7 +735,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device=fabric.device, from_numpy=cfg.buffer.from_numpy, ) - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic.module.parameters(), target_critic.parameters()): diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index f1a96a59..11df56f8 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -566,7 +566,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -678,7 +678,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): device=fabric.device, from_numpy=cfg.buffer.from_numpy, ) - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index e28ecb2a..0c57f80b 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -79,7 +79,7 @@ def train( ) actor_data = {k: actor_data[k][next(iter(actor_sampler))] for k in actor_data.keys()} - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): # Update the soft-critic for batch_idxes in critic_sampler: critic_batch_data = {k: critic_data[k][batch_idxes] for k in critic_data.keys()} @@ -283,7 +283,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with 
timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment actions, _ = agent.actor.module(torch.from_numpy(obs).to(device)) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 17471e72..55145c7d 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -629,7 +629,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -721,7 +721,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Train the agent if update >= learning_starts and updates_before_training <= 0: # Start training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(cfg.algo.per_rank_gradient_steps): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index f68c3b70..2fb2e7cb 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -271,7 +271,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): normalized_obs = {} for k in obs_keys: @@ -349,7 +349,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if player.actor_type == "exploration": player.actor = actor_task.module player.actor_type = "task" - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(cfg.algo.per_rank_gradient_steps): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 3bd54795..d371c7ca 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -776,7 +776,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -881,7 +881,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): from_numpy=cfg.buffer.from_numpy, ) # Start training - with 
timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index f87986be..931b331f 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -291,7 +291,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): normalized_obs = {} for k in obs_keys: @@ -383,7 +383,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): from_numpy=cfg.buffer.from_numpy, ) # Start training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: for cp, tcp in zip(critic_task.module.parameters(), target_critic_task.parameters()): diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 962fc701..34a01b7c 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -837,7 +837,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( update <= learning_starts @@ -950,7 +950,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): from_numpy=cfg.buffer.from_numpy, ) # Start training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 8c9aa0a5..8176263f 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -282,7 +282,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): 
preprocessed_obs = {} for k, v in obs.items(): @@ -382,7 +382,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): from_numpy=cfg.buffer.from_numpy, ) # Start training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for i in range(next(iter(local_data.values())).shape[0]): tau = 1 if per_rank_gradient_steps == 0 else cfg.algo.critic.tau if per_rank_gradient_steps % cfg.algo.critic.target_network_update_freq == 0: diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index b28dcfba..15ed07f5 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -269,7 +269,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) @@ -372,7 +372,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Flatten the first two dimensions: [Buffer_Size, Num_Envs] gathered_data = {k: v.flatten(start_dim=0, end_dim=1).float() for k, v in local_data.items()} - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): train(fabric, agent, optimizer, gathered_data, aggregator, cfg) train_step += world_size diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 8cb32364..070a1a65 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -191,7 +191,7 @@ def player( # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 5387be36..66e9cbd8 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -281,7 +281,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): with torch.no_grad(): # Sample an action given the observation received by the environment # [Seq_len, Batch_size, D] --> [1, num_envs, D] @@ -438,7 +438,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): mask = (torch.arange(max_len).expand(len(lengths), max_len) < lengths.unsqueeze(1)).T padded_sequences["mask"] = mask.to(device).bool() - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with 
timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): train(fabric, agent, optimizer, padded_sequences, aggregator, cfg) train_step += world_size diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index c7360d08..e4b7c7fc 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -234,7 +234,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): if update <= learning_starts: actions = envs.action_space.sample() else: @@ -314,7 +314,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Start training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for batch_idxes in sampler: batch = {k: v[batch_idxes] for k, v in gathered_data.items()} train( diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index d87b0b39..11f54cac 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -174,7 +174,7 @@ def player( # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): if update <= learning_starts: actions = envs.action_space.sample() else: diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 7421af7f..5d4069fd 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -310,7 +310,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment - with timer("Time/env_interaction_time", SumMetric(sync_on_compute=False)): + with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): if update < learning_starts: actions = envs.action_space.sample() else: @@ -392,7 +392,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Start training - with timer("Time/train_time", SumMetric(sync_on_compute=cfg.metric.sync_on_compute)): + with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): for batch_idxes in sampler: train( fabric, diff --git a/sheeprl/models/models.py b/sheeprl/models/models.py index fcd349ac..df774c13 100644 --- a/sheeprl/models/models.py +++ b/sheeprl/models/models.py @@ -411,6 +411,8 @@ def __init__( super().__init__() if cnn_encoder is None and mlp_encoder is None: raise ValueError("There must be at least one encoder, both cnn and mlp encoders are None") + self.has_cnn_encoder = False + self.has_mlp_encoder = False if cnn_encoder is not None: if getattr(cnn_encoder, "input_dim", None) is None: raise AttributeError( @@ -422,6 +424,7 @@ def __init__( "`cnn_encoder` must contain the `output_dim` attribute representing " "the dimension of the output tensor" ) + self.has_cnn_encoder = True if mlp_encoder is not None: if getattr(mlp_encoder, "input_dim", None) is None: raise AttributeError( @@ -433,6 +436,8 @@ def __init__( "`mlp_encoder` must 
contain the `output_dim` attribute representing " "the dimension of the output tensor" ) + self.has_mlp_encoder = True + self.has_both_encoders = self.has_cnn_encoder and self.has_mlp_encoder self.cnn_encoder = cnn_encoder self.mlp_encoder = mlp_encoder self.cnn_input_dim = self.cnn_encoder.input_dim if self.cnn_encoder is not None else None @@ -450,14 +455,16 @@ def mlp_keys(self) -> Sequence[str]: return self.mlp_encoder.keys if self.mlp_encoder is not None else [] def forward(self, obs: Dict[str, Tensor], *args, **kwargs) -> Tensor: - device = obs[list(obs.keys())[0]].device - cnn_out = torch.tensor((), device=device) - mlp_out = torch.tensor((), device=device) - if self.cnn_encoder is not None: + if self.has_cnn_encoder: cnn_out = self.cnn_encoder(obs, *args, **kwargs) - if self.mlp_encoder is not None: + if self.has_mlp_encoder: mlp_out = self.mlp_encoder(obs, *args, **kwargs) - return torch.cat((cnn_out, mlp_out), -1) + if self.has_both_encoders: + return torch.cat((cnn_out, mlp_out), -1) + elif self.has_cnn_encoder: + return cnn_out + else: + return mlp_out class MultiDecoder(nn.Module): diff --git a/sheeprl/utils/timer.py b/sheeprl/utils/timer.py index dcbd3d6a..f501e1e2 100644 --- a/sheeprl/utils/timer.py +++ b/sheeprl/utils/timer.py @@ -1,8 +1,9 @@ # timer.py +from __future__ import annotations import time from contextlib import ContextDecorator -from typing import Dict, Optional, Union +from typing import Dict, Optional, Type, Union import torch from torchmetrics import Metric, SumMetric @@ -19,11 +20,11 @@ class timer(ContextDecorator): timers: Dict[str, Metric] = {} _start_time: Optional[float] = None - def __init__(self, name: str, metric: Optional[Metric] = None) -> None: + def __init__(self, name: str, metric: Optional[Type[Metric]] = None, **kwargs) -> None: """Add timer to dict of timers after initialization""" self.name = name if not timer.disabled and self.name is not None and self.name not in self.timers.keys(): - self.timers.setdefault(self.name, metric if metric is not None else SumMetric()) + self.timers.setdefault(self.name, metric(**kwargs) if metric is not None else SumMetric(**kwargs)) def start(self) -> None: """Start a new timer""" From 19442b2cdd5957d671d1f6f514e55b0334129e69 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 22 Dec 2023 17:18:16 +0100 Subject: [PATCH 2/3] Add run test optionally --- howto/register_external_algorithm.md | 2 +- howto/register_new_algorithm.md | 2 +- sheeprl/algos/dreamer_v1/dreamer_v1.py | 2 +- sheeprl/algos/dreamer_v2/dreamer_v2.py | 2 +- sheeprl/algos/dreamer_v3/dreamer_v3.py | 2 +- sheeprl/algos/droq/droq.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 2 +- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 2 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 2 +- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 2 +- sheeprl/algos/ppo/ppo.py | 2 +- sheeprl/algos/ppo/ppo_decoupled.py | 2 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +- sheeprl/algos/sac/sac.py | 2 +- sheeprl/algos/sac/sac_decoupled.py | 2 +- sheeprl/algos/sac_ae/sac_ae.py | 6 +++--- sheeprl/algos/sac_ae/utils.py | 2 +- sheeprl/configs/algo/default.yaml | 1 + 20 files changed, 22 insertions(+), 21 deletions(-) diff --git a/howto/register_external_algorithm.md b/howto/register_external_algorithm.md index f7816fcd..a9a871e8 100644 --- a/howto/register_external_algorithm.md +++ b/howto/register_external_algorithm.md @@ -370,7 +370,7 @@ def 
ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent.module, fabric, cfg, log_dir) # Optional part in case you want to give the possibility to register your models with MLFlow diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 84b5fc27..15f21f48 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -367,7 +367,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent.module, fabric, cfg, log_dir) # Optional part in case you want to give the possibility to register your models with MLFlow diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 88d169a5..83819e47 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -775,7 +775,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 040b67fa..b35c5ba0 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -828,7 +828,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index 11df56f8..f27a9c53 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -775,7 +775,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(player, fabric, cfg, log_dir, sample_actions=True) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 0c57f80b..57d93ec7 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -385,7 +385,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent.actor.module, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 55145c7d..f5a965dd 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -835,7 +835,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 2fb2e7cb..e5270b71 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ 
b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -452,7 +452,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index d371c7ca..18d7cf8b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -1000,7 +1000,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 931b331f..da71ab9b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -484,7 +484,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 34a01b7c..a80874e9 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -1079,7 +1079,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() # task test zero-shot - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "zero-shot") diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 8176263f..4d6e94fb 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -487,7 +487,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): envs.close() # task test few-shot - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: player.actor = actor_task.module player.actor_type = "task" test(player, fabric, cfg, log_dir, "few-shot") diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 15ed07f5..300a8244 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -445,7 +445,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent.module, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 070a1a65..bf019b52 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -349,7 +349,7 @@ def player( ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py 
b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 66e9cbd8..794c9fc8 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -509,7 +509,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=ckpt_state) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent.module, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index e4b7c7fc..4ea81fe6 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -386,7 +386,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(agent.actor.module, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 11f54cac..705a53c0 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -310,7 +310,7 @@ def player( ) envs.close() - if fabric.is_global_zero: + if fabric.is_global_zero and cfg.algo.run_test: test(actor, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 5d4069fd..64d78c48 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -22,7 +22,7 @@ from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent -from sheeprl.algos.sac_ae.utils import preprocess_obs, test_sac_ae +from sheeprl.algos.sac_ae.utils import preprocess_obs, test from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder from sheeprl.utils.env import make_env @@ -471,8 +471,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) envs.close() - if fabric.is_global_zero: - test_sac_ae(agent.actor.module, fabric, cfg, log_dir) + if fabric.is_global_zero and cfg.algo.run_test: + test(agent.actor.module, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac_ae.utils import log_models diff --git a/sheeprl/algos/sac_ae/utils.py b/sheeprl/algos/sac_ae/utils.py index aba8a59e..c25cd22c 100644 --- a/sheeprl/algos/sac_ae/utils.py +++ b/sheeprl/algos/sac_ae/utils.py @@ -25,7 +25,7 @@ @torch.no_grad() -def test_sac_ae(actor: "SACAEContinuousActor", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(actor: "SACAEContinuousActor", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)() cnn_keys = actor.encoder.cnn_keys mlp_keys = actor.encoder.mlp_keys diff --git a/sheeprl/configs/algo/default.yaml b/sheeprl/configs/algo/default.yaml index b4b1c476..d83ff673 100644 --- a/sheeprl/configs/algo/default.yaml +++ b/sheeprl/configs/algo/default.yaml @@ -1,6 +1,7 @@ name: ??? total_steps: ??? per_rank_batch_size: ??? 
+run_test: True # Encoder and decoder keys cnn_keys: From f102f338462df9eefdc763a60e5fc488d89d58a5 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Fri, 22 Dec 2023 17:39:21 +0100 Subject: [PATCH 3/3] Fix wrong import --- sheeprl/algos/sac_ae/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py index 9489f4ab..df212cda 100644 --- a/sheeprl/algos/sac_ae/evaluate.py +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -6,7 +6,7 @@ from lightning import Fabric from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent -from sheeprl.algos.sac_ae.utils import test_sac_ae +from sheeprl.algos.sac_ae.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @@ -41,4 +41,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): agent, _, _ = build_agent( fabric, cfg, observation_space, action_space, state["agent"], state["encoder"], state["decoder"] ) - test_sac_ae(agent.actor, fabric, cfg, log_dir) + test(agent.actor, fabric, cfg, log_dir)
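For reference, below are two minimal, self-contained sketches of the patterns that PATCH 1/3 introduces. They are simplified approximations written for illustration, not the exact sheeprl classes: the timer sketch collapses the real class's `start()`/`stop()` methods into `__enter__`/`__exit__`, and the encoder sketch uses generic `nn.Module` stand-ins instead of the project's CNN/MLP encoders.

```python
# Sketch of the lazily-instantiated timer from PATCH 1/3: callers pass the Metric
# *class* and its kwargs instead of an instance, so no Metric object is created
# when the timer is globally disabled. Assumes torch/torchmetrics are installed;
# the context-manager bookkeeping below is a simplified stand-in, not sheeprl's code.
import time
from contextlib import ContextDecorator
from typing import Dict, Optional, Type

from torchmetrics import Metric, SumMetric


class timer_sketch(ContextDecorator):
    disabled: bool = False
    timers: Dict[str, Metric] = {}

    def __init__(self, name: str, metric: Optional[Type[Metric]] = None, **kwargs) -> None:
        self.name = name
        self._start_time: Optional[float] = None
        # Build the metric only if timing is enabled and the name is not yet registered
        if not timer_sketch.disabled and name not in timer_sketch.timers:
            timer_sketch.timers[name] = metric(**kwargs) if metric is not None else SumMetric(**kwargs)

    def __enter__(self) -> "timer_sketch":
        if not timer_sketch.disabled:
            self._start_time = time.perf_counter()
        return self

    def __exit__(self, *exc) -> None:
        if not timer_sketch.disabled and self._start_time is not None:
            timer_sketch.timers[self.name].update(time.perf_counter() - self._start_time)


# Call sites now read like the hunks above: class plus keyword arguments.
with timer_sketch("Time/train_time", SumMetric, sync_on_compute=False):
    time.sleep(0.01)
print(timer_sketch.timers["Time/train_time"].compute())
```

```python
# Sketch of the "do not cat when unneeded" change to MultiEncoder from PATCH 1/3:
# flags computed once in __init__ let forward() return the single encoder's output
# directly, instead of concatenating it with an empty tensor on every call as the
# previous implementation did. Illustrative only; encoder validation is abbreviated.
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor


class MultiEncoderSketch(nn.Module):
    def __init__(self, cnn_encoder: Optional[nn.Module], mlp_encoder: Optional[nn.Module]) -> None:
        super().__init__()
        if cnn_encoder is None and mlp_encoder is None:
            raise ValueError("There must be at least one encoder, both cnn and mlp encoders are None")
        self.cnn_encoder = cnn_encoder
        self.mlp_encoder = mlp_encoder
        self.has_cnn_encoder = cnn_encoder is not None
        self.has_mlp_encoder = mlp_encoder is not None
        self.has_both_encoders = self.has_cnn_encoder and self.has_mlp_encoder

    def forward(self, obs: Dict[str, Tensor], *args, **kwargs) -> Tensor:
        # torch.cat is only paid for when both encoders actually produce an output
        if self.has_both_encoders:
            return torch.cat(
                (self.cnn_encoder(obs, *args, **kwargs), self.mlp_encoder(obs, *args, **kwargs)), -1
            )
        if self.has_cnn_encoder:
            return self.cnn_encoder(obs, *args, **kwargs)
        return self.mlp_encoder(obs, *args, **kwargs)
```

PATCH 2/3 needs no separate sketch: it gates every final `test(...)` call behind the new `cfg.algo.run_test` flag (defaulting to `True` in `sheeprl/configs/algo/default.yaml`) and renames `test_sac_ae` to `test` for consistency, so skipping the evaluation run should amount to a single config override (e.g. a Hydra-style `algo.run_test=False`, assuming the usual sheeprl CLI). PATCH 3/3 then updates `sac_ae/evaluate.py` to import the renamed function.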