From bbed38732fb1145d60cdfa093e0bcebc569ccf1e Mon Sep 17 00:00:00 2001
From: Federico Belotti
Date: Wed, 10 Jul 2024 10:27:51 +0000
Subject: [PATCH 1/3] Add action, reward and obs wrappers

---
 sheeprl/configs/env/default.yaml |  5 ++++
 sheeprl/envs/wrappers.py         | 50 ++++++++++++++++++++++++++++++++
 sheeprl/utils/env.py             | 30 +++++++++++++++++++++++++-
 3 files changed, 84 insertions(+), 1 deletion(-)

diff --git a/sheeprl/configs/env/default.yaml b/sheeprl/configs/env/default.yaml
index 459d0cab..6e728a4e 100644
--- a/sheeprl/configs/env/default.yaml
+++ b/sheeprl/configs/env/default.yaml
@@ -5,8 +5,13 @@ sync_env: False
 screen_size: 64
 action_repeat: 1
 grayscale: False
+clip_actions: False
 clip_rewards: False
+clip_obs: False
+clip_obs_range: null
 capture_video: True
+normalize_obs: False
+normalize_rewards: False
 frame_stack_dilation: 1
 actions_as_observation:
   num_stack: -1
diff --git a/sheeprl/envs/wrappers.py b/sheeprl/envs/wrappers.py
index cc285b11..74999260 100644
--- a/sheeprl/envs/wrappers.py
+++ b/sheeprl/envs/wrappers.py
@@ -8,6 +8,7 @@
 import gymnasium as gym
 import numpy as np
 from gymnasium.core import Env, RenderFrame
+from gymnasium.wrappers.normalize import NormalizeObservation, RunningMeanStd, update_mean_var_count_from_moments
 
 
 class MaskVelocityWrapper(gym.ObservationWrapper):
@@ -340,3 +341,52 @@ def _get_actions_stack(self) -> np.ndarray:
         actions_stack = list(self._actions)[self._dilation - 1 :: self._dilation]
         actions = np.concatenate(actions_stack, axis=-1)
         return actions.astype(np.float32)
+
+
+class NormalizeObservationWrapper(NormalizeObservation):
+    """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
+
+    Note:
+        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
+        newly instantiated or the policy was changed recently.
+    """
+
+    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
+        """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
+
+        Args:
+            env (Env): The environment to apply the wrapper
+            epsilon: A stability parameter that is used when scaling the observations.
+        """
+        super().__init__(env, epsilon=epsilon)
+        self._is_dict_space = False
+        if isinstance(env.observation_space, gym.spaces.Dict):
+            self._is_dict_space = True
+            self.obs_rms = {
+                k: RunningMeanStd(shape=self.observation_space[k].shape) for k in self.observation_space.keys()
+            }
+
+    def step(self, action):
+        """Steps through the environment and normalizes the observation."""
+        if not self._is_dict_space:
+            return super().step(action)
+        obs, rews, terminateds, truncateds, infos = self.env.step(action)
+        obs = self.normalize(obs)
+        return obs, rews, terminateds, truncateds, infos
+
+    def reset(self, **kwargs):
+        """Resets the environment and normalizes the observation."""
+        if not self._is_dict_space:
+            return super().reset(**kwargs)
+        obs, info = self.env.reset(**kwargs)
+        return self.normalize(obs), info
+
+    def normalize(self, obs):
+        """Normalises the observation using the running mean and variance of the observations."""
+        if not self._is_dict_space:
+            return super().normalize(obs)
+        new_obs = {}
+        for k in self.observation_space.keys():
+            self.obs_rms[k].update(obs[k][np.newaxis])
+            new_obs[k] = ((obs[k][np.newaxis] - self.obs_rms[k].mean) / np.sqrt(self.obs_rms[k].var + self.epsilon))[0]
+        return new_obs
diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py
index 750d85ee..245810c6 100644
--- a/sheeprl/utils/env.py
+++ b/sheeprl/utils/env.py
@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Sequence
 
 import cv2
 import gymnasium as gym
@@ -13,6 +13,7 @@
     FrameStack,
     GrayscaleRenderWrapper,
     MaskVelocityWrapper,
+    NormalizeObservationWrapper,
     RewardAsObservationWrapper,
 )
 from sheeprl.utils.imports import _IS_DIAMBRA_ARENA_AVAILABLE, _IS_DIAMBRA_AVAILABLE, _IS_DMC_AVAILABLE
@@ -211,6 +212,33 @@ def transform_obs(obs: Dict[str, Any]):
         if cfg.env.actions_as_observation.num_stack > 0 and "diambra" not in cfg.env.wrapper._target_:
             env = ActionsAsObservationWrapper(env, **cfg.env.actions_as_observation)
 
+        if cfg.env.normalize_obs:
+            env = NormalizeObservationWrapper(env)
+
+        if cfg.env.clip_obs:
+            if (
+                isinstance(cfg.env.clip_obs_range, Sequence)
+                and not isinstance(cfg.env.clip_obs_range, str)
+                and len(cfg.env.clip_obs_range) != 2
+            ):
+                raise ValueError(
+                    f"clip_obs_range must be a sequence of length 2, got: {cfg.env.clip_obs_range} of type "
+                    f"{type(cfg.env.clip_obs_range)}"
+                )
+            env = gym.wrappers.TransformObservation(
+                env,
+                lambda obs: {
+                    k: np.clip(obs[k], cfg.env.clip_obs_range[0], cfg.env.clip_obs_range[1])
+                    for k in cfg.algo.mlp_keys.encoder
+                },
+            )
+
+        if cfg.env.clip_actions:
+            env = gym.wrappers.ClipAction(env)
+
+        if cfg.env.normalize_rewards:
+            env = gym.wrappers.NormalizeReward(env, gamma=cfg.algo.gamma)
+
         if cfg.env.reward_as_observation:
             env = RewardAsObservationWrapper(env)
 
From 37197aec947a42bcd9c1b998db0f4572d1201f61 Mon Sep 17 00:00:00 2001
From: belerico
Date: Wed, 10 Jul 2024 16:17:07 +0200
Subject: [PATCH 2/3] Clip pixel obs also

---
 sheeprl/envs/wrappers.py | 5 +++--
 sheeprl/utils/env.py     | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/sheeprl/envs/wrappers.py b/sheeprl/envs/wrappers.py
index 74999260..34c48cde 100644
--- a/sheeprl/envs/wrappers.py
+++ b/sheeprl/envs/wrappers.py
@@ -8,7 +8,7 @@
 import gymnasium as gym
 import numpy as np
 from gymnasium.core import Env, RenderFrame
-from gymnasium.wrappers.normalize import NormalizeObservation, RunningMeanStd, update_mean_var_count_from_moments
+from gymnasium.wrappers.normalize import NormalizeObservation, RunningMeanStd
 
 
 class MaskVelocityWrapper(gym.ObservationWrapper):
@@ -347,7 +347,8 @@ class NormalizeObservationWrapper(NormalizeObservation):
     """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
 
     Note:
-        The normalization depends on past trajectories and observations will not be normalized correctly if the wrapper was
+        The normalization depends on past trajectories and observations
+        will not be normalized correctly if the wrapper was
         newly instantiated or the policy was changed recently.
     """
 
diff --git a/sheeprl/utils/env.py b/sheeprl/utils/env.py
index 245810c6..be485747 100644
--- a/sheeprl/utils/env.py
+++ b/sheeprl/utils/env.py
@@ -228,8 +228,8 @@ def transform_obs(obs: Dict[str, Any]):
             env = gym.wrappers.TransformObservation(
                 env,
                 lambda obs: {
-                    k: np.clip(obs[k], cfg.env.clip_obs_range[0], cfg.env.clip_obs_range[1])
-                    for k in cfg.algo.mlp_keys.encoder
+                    k: np.clip(obs[k], cfg.env.clip_obs_range[0], cfg.env.clip_obs_range[1]) if k in obs else obs[k]
+                    for k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder
                 },
             )
 
From 0f5b02da9126c068b97872a4ebe5cbdc8b51eb40 Mon Sep 17 00:00:00 2001
From: belerico
Date: Wed, 10 Jul 2024 22:08:00 +0200
Subject: [PATCH 3/3] Test env as the first training env (due to norm stats to be shared)

---
 sheeprl/algos/a2c/a2c.py                     | 2 +-
 sheeprl/algos/a2c/utils.py                   | 6 ++++--
 sheeprl/algos/dreamer_v1/dreamer_v1.py       | 2 +-
 sheeprl/algos/dreamer_v2/dreamer_v2.py       | 2 +-
 sheeprl/algos/dreamer_v2/utils.py            | 6 +++++-
 sheeprl/algos/dreamer_v3/dreamer_v3.py       | 2 +-
 sheeprl/algos/dreamer_v3/utils.py            | 6 +++++-
 sheeprl/algos/droq/droq.py                   | 2 +-
 sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 2 +-
 sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py  | 2 +-
 sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 2 +-
 sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py  | 2 +-
 sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 2 +-
 sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py  | 2 +-
 sheeprl/algos/ppo/ppo.py                     | 2 +-
 sheeprl/algos/ppo/ppo_decoupled.py           | 2 +-
 sheeprl/algos/ppo/utils.py                   | 5 +++--
 sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 2 +-
 sheeprl/algos/ppo_recurrent/utils.py         | 5 +++--
 sheeprl/algos/sac/utils.py                   | 5 +++--
 sheeprl/algos/sac_ae/sac_ae.py               | 2 +-
 sheeprl/algos/sac_ae/utils.py                | 5 +++--
 22 files changed, 41 insertions(+), 27 deletions(-)

diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py
index 54bbe7b1..01f6f4db 100644
--- a/sheeprl/algos/a2c/a2c.py
+++ b/sheeprl/algos/a2c/a2c.py
@@ -373,7 +373,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
 
     if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, envs.envs[0])
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
         from sheeprl.algos.ppo.utils import log_models
diff --git a/sheeprl/algos/a2c/utils.py b/sheeprl/algos/a2c/utils.py
index 88fb0099..3ccffc91 100644
--- a/sheeprl/algos/a2c/utils.py
+++ b/sheeprl/algos/a2c/utils.py
@@ -2,6 +2,7 @@
 
 from typing import Any, Dict, Sequence
 
+import gymnasium as gym
 import numpy as np
 import torch
 from lightning import Fabric
@@ -21,8 +22,9 @@ def prepare_obs(
 
 
 @torch.no_grad()
-def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
-    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
+def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None):
+    if env is None:
+        env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
     agent.eval()
     done = False
     cumulative_rew = 0
diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py
index 30eadde7..ef7704bc 100644
--- a/sheeprl/algos/dreamer_v1/dreamer_v1.py
+++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -740,7 +740,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
 
     if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, env=envs.envs[0])
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
         from sheeprl.algos.dreamer_v1.utils import log_models
diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py
index 49f16751..cc3f7155 100644
--- a/sheeprl/algos/dreamer_v2/dreamer_v2.py
+++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -782,7 +782,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
 
     if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, env=envs.envs[0])
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
         from sheeprl.algos.dreamer_v1.utils import log_models
diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py
index 3a846858..9969e06d 100644
--- a/sheeprl/algos/dreamer_v2/utils.py
+++ b/sheeprl/algos/dreamer_v2/utils.py
@@ -124,6 +124,7 @@ def test(
     log_dir: str,
     test_name: str = "",
     greedy: bool = True,
+    env: gym.Env | gym.Wrapper | None = None,
 ):
     """Test the model on the environment with the frozen model.
 
@@ -136,8 +137,11 @@ def test(
             Default to "".
         greedy (bool): whether or not to sample actions.
             Default to True.
+        env (gym.Env | gym.Wrapper): the environment to test on.
+            Default to None.
     """
-    env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
+    if env is None:
+        env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
     done = False
     cumulative_rew = 0
     obs = env.reset(seed=cfg.seed)[0]
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index babcebf8..074006fa 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -764,7 +764,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     envs.close()
 
     if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir, greedy=False)
+        test(player, fabric, cfg, log_dir, greedy=False, env=envs.envs[0])
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
         from sheeprl.algos.dreamer_v1.utils import log_models
diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py
index 2fdac419..49fb6cbd 100644
--- a/sheeprl/algos/dreamer_v3/utils.py
+++ b/sheeprl/algos/dreamer_v3/utils.py
@@ -99,6 +99,7 @@ def test(
     log_dir: str,
     test_name: str = "",
     greedy: bool = True,
+    env: gym.Env | None = None,
 ):
     """Test the model on the environment with the frozen model.
 
@@ -111,8 +112,11 @@ def test(
             Default to "".
         greedy (bool): whether or not to sample the actions.
            Default to True.
+        env (gym.Env | gym.Wrapper): the environment to test on.
+            Default to None.
""" - env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() + if env is None: + env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() done = False cumulative_rew = 0 obs = env.reset(seed=cfg.seed)[0] diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index b5cf8c35..691eaf1a 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -426,7 +426,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(player, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir, envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac.utils import log_models diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 6c184b45..c20404fa 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -784,7 +784,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.actor_type = "task" fabric_player = get_single_device_fabric(fabric) player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - test(player, fabric, cfg, log_dir, "zero-shot") + test(player, fabric, cfg, log_dir, "zero-shot", env=envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index b071c977..2a3bf0a4 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -431,7 +431,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - test(player, fabric, cfg, log_dir, "few-shot") + test(player, fabric, cfg, log_dir, "few-shot", env=envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index ecf285aa..f3b4e4de 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -939,7 +939,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.actor_type = "task" fabric_player = get_single_device_fabric(fabric) player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - test(player, fabric, cfg, log_dir, "zero-shot") + test(player, fabric, cfg, log_dir, "zero-shot", env=envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index fcc59fe4..10503a89 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -459,7 +459,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - test(player, fabric, cfg, log_dir, "few-shot") + test(player, fabric, cfg, log_dir, "few-shot", env=envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from 
sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 2a5bb00d..2de5d040 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -1032,7 +1032,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.actor_type = "task" fabric_player = get_single_device_fabric(fabric) player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - test(player, fabric, cfg, log_dir, "zero-shot", greedy=False) + test(player, fabric, cfg, log_dir, "zero-shot", greedy=False, env=envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index d825ac24..39b18704 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -461,7 +461,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) - test(player, fabric, cfg, log_dir, "few-shot", greedy=False) + test(player, fabric, cfg, log_dir, "few-shot", greedy=False, env=envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 95057f2d..aec1dde1 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -442,7 +442,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(player, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir, envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 79f97337..a6b64218 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -355,7 +355,7 @@ def player( envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(agent, fabric, cfg, log_dir) + test(agent, fabric, cfg, log_dir, envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index 4b5e4634..8f7f5bfa 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -36,8 +36,9 @@ def prepare_obs( @torch.no_grad() -def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): - env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() +def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None): + if env is None: + env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False cumulative_rew = 0 diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 744a0e82..2d1e751e 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -515,7 +515,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(player, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir, envs.envs[0]) if not 
cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.ppo.utils import log_models diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 47111ade..27910d25 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -39,8 +39,9 @@ def prepare_obs( @torch.no_grad() -def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): - env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() +def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None): + if env is None: + env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False cumulative_rew = 0 diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index 9432db3f..6b98b42f 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -37,8 +37,9 @@ def prepare_obs( @torch.no_grad() -def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): - env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() +def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None): + if env is None: + env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() actor.eval() done = False cumulative_rew = 0 diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 9ce76677..aabbc06f 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -492,7 +492,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(player, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir, envs.envs[0]) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac_ae.utils import log_models diff --git a/sheeprl/algos/sac_ae/utils.py b/sheeprl/algos/sac_ae/utils.py index 680f5ee1..cdb7412e 100644 --- a/sheeprl/algos/sac_ae/utils.py +++ b/sheeprl/algos/sac_ae/utils.py @@ -40,8 +40,9 @@ def prepare_obs( @torch.no_grad() -def test(actor: "SACAEPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): - env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)() +def test(actor: "SACAEPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None): + if env is None: + env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)() actor.eval() done = False cumulative_rew = 0
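
The sketch below (illustrative only, not part of the patches) shows how the new flags compose when they are enabled. The environment id, the clip range, and the gamma value are placeholder assumptions; the real code reads them from cfg.env and cfg.algo, and the observation clipping in the patch operates on the dict observations selected by cfg.algo.mlp_keys.encoder / cfg.algo.cnn_keys.encoder rather than on a flat array as done here.

import gymnasium as gym
import numpy as np

from sheeprl.envs.wrappers import NormalizeObservationWrapper

env = gym.make("Pendulum-v1")  # placeholder continuous-control env with a Box action space
env = NormalizeObservationWrapper(env)  # normalize_obs: True (running mean/var normalization)
low, high = -5.0, 5.0  # clip_obs: True, clip_obs_range: [-5.0, 5.0] (placeholder range)
env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, low, high))  # clip_obs
env = gym.wrappers.ClipAction(env)  # clip_actions: True (Box action spaces only)
env = gym.wrappers.NormalizeReward(env, gamma=0.99)  # normalize_rewards: True, gamma from cfg.algo.gamma

Because NormalizeObservationWrapper keeps running mean/variance statistics, patch 3 reuses the first training environment for the final test run, so that the test episode is normalized with the statistics accumulated during training instead of a freshly initialized wrapper.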