Add action, reward and obs wrappers #311

Draft: wants to merge 4 commits into base branch main
2 changes: 1 addition & 1 deletion sheeprl/algos/a2c/a2c.py
@@ -373,7 +373,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.ppo.utils import log_models
6 changes: 4 additions & 2 deletions sheeprl/algos/a2c/utils.py
@@ -2,6 +2,7 @@

from typing import Any, Dict, Sequence

+import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
@@ -21,8 +22,9 @@ def prepare_obs(


@torch.no_grad()
-def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
-    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
+def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None):
+    if env is None:
+        env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
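Why pass the training env into test()? The diff itself does not say, but given the normalization wrappers added later in this PR, a plausible reason is that running statistics live on the env's wrappers, so a freshly built test env would start from empty statistics. Below is a self-contained sketch of that effect using plain gymnasium; Pendulum-v1 is only an illustrative choice and nothing in the snippet is sheeprl code.

import gymnasium as gym

# Train-time env: its NormalizeObservation wrapper accumulates running statistics on every step.
train_env = gym.wrappers.NormalizeObservation(gym.make("Pendulum-v1"))
train_env.reset(seed=0)
for _ in range(1000):
    _, _, terminated, truncated, _ = train_env.step(train_env.action_space.sample())
    if terminated or truncated:
        train_env.reset()

# A freshly built test env starts with empty statistics, so the same raw
# observation would be normalized differently than during training.
fresh_env = gym.wrappers.NormalizeObservation(gym.make("Pendulum-v1"))
fresh_env.reset(seed=0)
print(train_env.obs_rms.mean, fresh_env.obs_rms.mean)  # the two running means differ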
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -740,7 +740,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -782,7 +782,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
6 changes: 5 additions & 1 deletion sheeprl/algos/dreamer_v2/utils.py
@@ -124,6 +124,7 @@ def test(
    log_dir: str,
    test_name: str = "",
    greedy: bool = True,
+    env: gym.Env | gym.Wrapper | None = None,
):
    """Test the model on the environment with the frozen model.

@@ -136,8 +137,11 @@
            Default to "".
        greedy (bool): whether or not to sample actions.
            Default to True.
+        env (gym.Env | gym.Wrapper): the environment to test on.
+            Default to None.
    """
-    env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
+    if env is None:
+        env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
    done = False
    cumulative_rew = 0
    obs = env.reset(seed=cfg.seed)[0]
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -764,7 +764,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir, greedy=False)
+        test(player, fabric, cfg, log_dir, greedy=False, env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
6 changes: 5 additions & 1 deletion sheeprl/algos/dreamer_v3/utils.py
@@ -99,6 +99,7 @@ def test(
    log_dir: str,
    test_name: str = "",
    greedy: bool = True,
+    env: gym.Env | None = None,
):
    """Test the model on the environment with the frozen model.

@@ -111,8 +112,11 @@
            Default to "".
        greedy (bool): whether or not to sample the actions.
            Default to True.
+        env (gym.Env | gym.Wrapper): the environment to test on.
+            Default to None.
    """
-    env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
+    if env is None:
+        env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
    done = False
    cumulative_rew = 0
    obs = env.reset(seed=cfg.seed)[0]
2 changes: 1 addition & 1 deletion sheeprl/algos/droq/droq.py
@@ -426,7 +426,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.sac.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -784,7 +784,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
        player.actor_type = "task"
        fabric_player = get_single_device_fabric(fabric)
        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        test(player, fabric, cfg, log_dir, "zero-shot")
+        test(player, fabric, cfg, log_dir, "zero-shot", env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -431,7 +431,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
    if fabric.is_global_zero and cfg.algo.run_test:
        player.actor_type = "task"
        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        test(player, fabric, cfg, log_dir, "few-shot")
+        test(player, fabric, cfg, log_dir, "few-shot", env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -939,7 +939,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
        player.actor_type = "task"
        fabric_player = get_single_device_fabric(fabric)
        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        test(player, fabric, cfg, log_dir, "zero-shot")
+        test(player, fabric, cfg, log_dir, "zero-shot", env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -459,7 +459,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
    if fabric.is_global_zero and cfg.algo.run_test:
        player.actor_type = "task"
        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        test(player, fabric, cfg, log_dir, "few-shot")
+        test(player, fabric, cfg, log_dir, "few-shot", env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -1032,7 +1032,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
        player.actor_type = "task"
        fabric_player = get_single_device_fabric(fabric)
        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        test(player, fabric, cfg, log_dir, "zero-shot", greedy=False)
+        test(player, fabric, cfg, log_dir, "zero-shot", greedy=False, env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -461,7 +461,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
    if fabric.is_global_zero and cfg.algo.run_test:
        player.actor_type = "task"
        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        test(player, fabric, cfg, log_dir, "few-shot", greedy=False)
+        test(player, fabric, cfg, log_dir, "few-shot", greedy=False, env=envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.dreamer_v1.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo.py
@@ -442,7 +442,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.ppo.utils import log_models
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo/ppo_decoupled.py
@@ -355,7 +355,7 @@ def player(

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(agent, fabric, cfg, log_dir)
+        test(agent, fabric, cfg, log_dir, envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.ppo.utils import log_models
5 changes: 3 additions & 2 deletions sheeprl/algos/ppo/utils.py
@@ -36,8 +36,9 @@ def prepare_obs(


@torch.no_grad()
-def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
-    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
+def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None):
+    if env is None:
+        env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
2 changes: 1 addition & 1 deletion sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -515,7 +515,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.ppo.utils import log_models
5 changes: 3 additions & 2 deletions sheeprl/algos/ppo_recurrent/utils.py
@@ -39,8 +39,9 @@ def prepare_obs(


@torch.no_grad()
-def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
-    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
+def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None):
+    if env is None:
+        env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    agent.eval()
    done = False
    cumulative_rew = 0
5 changes: 3 additions & 2 deletions sheeprl/algos/sac/utils.py
@@ -37,8 +37,9 @@ def prepare_obs(


@torch.no_grad()
-def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
-    env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
+def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None):
+    if env is None:
+        env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
    actor.eval()
    done = False
    cumulative_rew = 0
2 changes: 1 addition & 1 deletion sheeprl/algos/sac_ae/sac_ae.py
@@ -492,7 +492,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

    envs.close()
    if fabric.is_global_zero and cfg.algo.run_test:
-        test(player, fabric, cfg, log_dir)
+        test(player, fabric, cfg, log_dir, envs.envs[0])

    if not cfg.model_manager.disabled and fabric.is_global_zero:
        from sheeprl.algos.sac_ae.utils import log_models
5 changes: 3 additions & 2 deletions sheeprl/algos/sac_ae/utils.py
@@ -40,8 +40,9 @@ def prepare_obs(


@torch.no_grad()
-def test(actor: "SACAEPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
-    env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)()
+def test(actor: "SACAEPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str, env: gym.Env | None = None):
+    if env is None:
+        env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)()
    actor.eval()
    done = False
    cumulative_rew = 0
5 changes: 5 additions & 0 deletions sheeprl/configs/env/default.yaml
@@ -5,8 +5,13 @@ sync_env: False
screen_size: 64
action_repeat: 1
grayscale: False
clip_actions: False
clip_rewards: False
clip_obs: False
clip_obs_range: null
capture_video: True
normalize_obs: False
normalize_rewards: False
frame_stack_dilation: 1
actions_as_observation:
  num_stack: -1
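The defaults above leave every new flag off. As a sketch only (the key names come from this diff; the values and the use of OmegaConf, which backs sheeprl's Hydra configs, are illustrative assumptions), an override enabling them might look like this:

from omegaconf import OmegaConf

# Illustrative values; each key maps to a wrapper applied in sheeprl/utils/env.py below.
env_overrides = OmegaConf.create(
    {
        "clip_actions": True,             # gym.wrappers.ClipAction
        "normalize_obs": True,            # NormalizeObservationWrapper
        "clip_obs": True,                 # element-wise np.clip on the encoder keys ...
        "clip_obs_range": [-10.0, 10.0],  # ... within this [low, high] range
        "normalize_rewards": True,        # gym.wrappers.NormalizeReward(gamma=cfg.algo.gamma)
    }
)
print(OmegaConf.to_yaml(env_overrides))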
51 changes: 51 additions & 0 deletions sheeprl/envs/wrappers.py
@@ -8,6 +8,7 @@
import gymnasium as gym
import numpy as np
from gymnasium.core import Env, RenderFrame
+from gymnasium.wrappers.normalize import NormalizeObservation, RunningMeanStd


class MaskVelocityWrapper(gym.ObservationWrapper):
@@ -340,3 +341,53 @@ def _get_actions_stack(self) -> np.ndarray:
        actions_stack = list(self._actions)[self._dilation - 1 :: self._dilation]
        actions = np.concatenate(actions_stack, axis=-1)
        return actions.astype(np.float32)
+
+
+class NormalizeObservationWrapper(NormalizeObservation):
+    """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
+
+    Note:
+        The normalization depends on past trajectories and observations
+        will not be normalized correctly if the wrapper was
+        newly instantiated or the policy was changed recently.
+    """
+
+    def __init__(self, env: gym.Env, epsilon: float = 1e-8):
+        """This wrapper will normalize observations s.t. each coordinate is centered with unit variance.
+
+        Args:
+            env (Env): The environment to apply the wrapper
+            epsilon: A stability parameter that is used when scaling the observations.
+        """
+        super().__init__(env, epsilon=epsilon)
+        self._is_dict_space = False
+        if isinstance(env.observation_space, gym.spaces.Dict):
+            self._is_dict_space = True
+            self.obs_rms = {
+                k: RunningMeanStd(shape=self.observation_space[k].shape) for k in self.observation_space.keys()
+            }
+
+    def step(self, action):
+        """Steps through the environment and normalizes the observation."""
+        if not self._is_dict_space:
+            return super().step(action)
+        obs, rews, terminateds, truncateds, infos = self.env.step(action)
+        obs = self.normalize(obs)
+        return obs, rews, terminateds, truncateds, infos
+
+    def reset(self, **kwargs):
+        """Resets the environment and normalizes the observation."""
+        if not self._is_dict_space:
+            return super().reset(**kwargs)
+        obs, info = self.env.reset(**kwargs)
+        return self.normalize(obs), info
+
+    def normalize(self, obs):
+        """Normalises the observation using the running mean and variance of the observations."""
+        if not self._is_dict_space:
+            return super().normalize(obs)
+        new_obs = {}
+        for k in self.observation_space.keys():
+            self.obs_rms[k].update(obs[k][np.newaxis])
+            new_obs[k] = ((obs[k][np.newaxis] - self.obs_rms[k].mean) / np.sqrt(self.obs_rms[k].var + self.epsilon))[0]
+        return new_obs
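For intuition, the dict branch of normalize() above boils down to a per-key standardization with running statistics. A standalone sketch of just that formula, with made-up values (not sheeprl code):

import numpy as np

def normalize_key(x: np.ndarray, mean: np.ndarray, var: np.ndarray, epsilon: float = 1e-8) -> np.ndarray:
    # Same expression as NormalizeObservationWrapper.normalize applies per key,
    # after RunningMeanStd has been updated with the new observation.
    return (x - mean) / np.sqrt(var + epsilon)

x = np.array([1.0, 2.0, 3.0])
print(normalize_key(x, mean=np.full(3, 2.0), var=np.ones(3)))  # approximately [-1.  0.  1.]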
30 changes: 29 additions & 1 deletion sheeprl/utils/env.py
@@ -1,6 +1,6 @@
import os
import warnings
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Sequence

import cv2
import gymnasium as gym
@@ -13,6 +13,7 @@
    FrameStack,
    GrayscaleRenderWrapper,
    MaskVelocityWrapper,
+    NormalizeObservationWrapper,
    RewardAsObservationWrapper,
)
from sheeprl.utils.imports import _IS_DIAMBRA_ARENA_AVAILABLE, _IS_DIAMBRA_AVAILABLE, _IS_DMC_AVAILABLE
@@ -211,6 +212,33 @@ def transform_obs(obs: Dict[str, Any]):
        if cfg.env.actions_as_observation.num_stack > 0 and "diambra" not in cfg.env.wrapper._target_:
            env = ActionsAsObservationWrapper(env, **cfg.env.actions_as_observation)

+        if cfg.env.normalize_obs:
+            env = NormalizeObservationWrapper(env)
+
+        if cfg.env.clip_obs:
+            if (
+                isinstance(cfg.env.clip_obs_range, Sequence)
+                and not isinstance(cfg.env.clip_obs_range, str)
+                and len(cfg.env.clip_obs_range) != 2
+            ):
+                raise ValueError(
+                    f"clip_obs_range must be a sequence of length 2, got: {cfg.env.clip_obs_range} of type "
+                    f"{type(cfg.env.clip_obs_range)}"
+                )
+            env = gym.wrappers.TransformObservation(
+                env,
+                lambda obs: {
+                    k: np.clip(obs[k], cfg.env.clip_obs_range[0], cfg.env.clip_obs_range[1]) if k in obs else obs[k]
+                    for k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder
+                },
+            )
+
+        if cfg.env.clip_actions:
+            env = gym.wrappers.ClipAction(env)
+
+        if cfg.env.normalize_rewards:
+            env = gym.wrappers.NormalizeReward(env, gamma=cfg.algo.gamma)
+
        if cfg.env.reward_as_observation:
            env = RewardAsObservationWrapper(env)

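Taken together, for a flat Box-observation env with every flag enabled, the chain added above behaves roughly like the plain-gymnasium sketch below. Assumptions to note: sheeprl actually operates on dict observations keyed by cfg.algo.mlp_keys/cnn_keys, Pendulum-v1, the clip range and gamma are illustrative choices, and the two-argument TransformObservation call matches the pre-1.0 gymnasium signature used in this diff.

import gymnasium as gym
import numpy as np

env = gym.make("Pendulum-v1")
env = gym.wrappers.NormalizeObservation(env)            # cfg.env.normalize_obs
env = gym.wrappers.TransformObservation(                # cfg.env.clip_obs with cfg.env.clip_obs_range
    env, lambda obs: np.clip(obs, -10.0, 10.0)
)
env = gym.wrappers.ClipAction(env)                      # cfg.env.clip_actions
env = gym.wrappers.NormalizeReward(env, gamma=0.99)     # cfg.env.normalize_rewards

obs, info = env.reset(seed=42)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())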