diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py
index a037f7bb052b3..ae916a8186f35 100644
--- a/rllib/algorithms/algorithm_config.py
+++ b/rllib/algorithms/algorithm_config.py
@@ -4,6 +4,7 @@
 import os
 import sys
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Container,
@@ -12,7 +13,6 @@
     Optional,
     Tuple,
     Type,
-    TYPE_CHECKING,
     Union,
 )
 
@@ -21,17 +21,13 @@
 import ray
 from ray.rllib.algorithms.callbacks import DefaultCallbacks
 from ray.rllib.core.learner.learner import LearnerHyperparameters
-from ray.rllib.core.learner.learner_group_config import (
-    LearnerGroupConfig,
-    ModuleSpec,
-)
+from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig, ModuleSpec
 from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
 from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
 from ray.rllib.env.env_context import EnvContext
 from ray.rllib.env.multi_agent_env import MultiAgentEnv
 from ray.rllib.env.wrappers.atari_wrappers import is_atari
 from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
-from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
 from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
 from ray.rllib.evaluation.episode import Episode
 from ray.rllib.models import MODEL_DEFAULTS
@@ -39,16 +35,16 @@
 from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
 from ray.rllib.utils import deep_update, merge_dicts
 from ray.rllib.utils.annotations import (
-    OverrideToImplementCustomLogic_CallToSuperRecommended,
     ExperimentalAPI,
+    OverrideToImplementCustomLogic_CallToSuperRecommended,
 )
 from ray.rllib.utils.deprecation import (
-    Deprecated,
     DEPRECATED_VALUE,
+    Deprecated,
     deprecation_warning,
 )
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.from_config import from_config, NotProvided
+from ray.rllib.utils.from_config import NotProvided, from_config
 from ray.rllib.utils.gym import (
     convert_old_gym_space_to_gymnasium_space,
     try_import_gymnasium_and_gym,
@@ -56,10 +52,11 @@
 from ray.rllib.utils.policy import validate_policy_id
 from ray.rllib.utils.schedules.scheduler import Scheduler
 from ray.rllib.utils.serialization import (
-    deserialize_type,
     NOT_SERIALIZABLE,
+    deserialize_type,
     serialize_type,
 )
+from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
 from ray.rllib.utils.typing import (
     AgentID,
     AlgorithmConfigDict,
@@ -307,6 +304,7 @@ def __init__(self, algo_class=None):
         # If not specified, we will try to auto-detect this.
         self.is_atari = None
         self.auto_wrap_old_gym_envs = True
+        self.action_mask_key = "action_mask"
 
         # `self.rollouts()`
         self.env_runner_cls = None
@@ -1325,6 +1323,7 @@ def environment(
         disable_env_checking: Optional[bool] = NotProvided,
         is_atari: Optional[bool] = NotProvided,
         auto_wrap_old_gym_envs: Optional[bool] = NotProvided,
+        action_mask_key: Optional[str] = NotProvided,
     ) -> "AlgorithmConfig":
         """Sets the config's RL-environment settings.
 
@@ -1376,6 +1375,9 @@ def environment(
                 (gym.wrappers.EnvCompatibility). If False, RLlib will produce a
                 descriptive error on which steps to perform to upgrade to gymnasium
                 (or to switch this flag to True).
+            action_mask_key: If the observation is a dict, expect the value under
+                this key to be a valid action mask (a `numpy.int8` array of zeros
+                and ones). Defaults to "action_mask".
 
         Returns:
             This updated AlgorithmConfig object.
@@ -1408,6 +1410,8 @@ def environment(
             self.is_atari = is_atari
         if auto_wrap_old_gym_envs is not NotProvided:
             self.auto_wrap_old_gym_envs = auto_wrap_old_gym_envs
+        if action_mask_key is not NotProvided:
+            self.action_mask_key = action_mask_key
 
         return self
 
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py
index af6c9b37d4d4b..d6a4ca253d3cf 100644
--- a/rllib/evaluation/rollout_worker.py
+++ b/rllib/evaluation/rollout_worker.py
@@ -1,13 +1,10 @@
-from collections import defaultdict
 import copy
-from gymnasium.spaces import Discrete, MultiDiscrete, Space
 import importlib.util
 import logging
-import numpy as np
 import os
 import platform
 import threading
-import tree  # pip install dm_tree
+from collections import defaultdict
 from types import FunctionType
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +20,10 @@
     Union,
 )
 
+import numpy as np
+import tree  # pip install dm_tree
+from gymnasium.spaces import Discrete, MultiDiscrete, Space
+
 import ray
 from ray import ObjectRef
 from ray import cloudpickle as pickle
@@ -45,8 +46,8 @@
     D4RLReader,
     DatasetReader,
     DatasetWriter,
-    IOContext,
     InputReader,
+    IOContext,
     JsonReader,
     JsonWriter,
     MixedInput,
@@ -56,13 +57,11 @@
 )
 from ray.rllib.policy.policy import Policy, PolicySpec
 from ray.rllib.policy.policy_map import PolicyMap
-from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
-from ray.rllib.utils.filter import NoFilter
-from ray.rllib.utils.from_config import from_config
 from ray.rllib.policy.sample_batch import (
     DEFAULT_POLICY_ID,
     MultiAgentBatch,
     concat_samples,
+    convert_ma_batch_to_sample_batch,
 )
 from ray.rllib.policy.torch_policy import TorchPolicy
 from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
@@ -70,20 +69,19 @@
 from ray.rllib.utils.annotations import DeveloperAPI, override
 from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
 from ray.rllib.utils.deprecation import (
-    Deprecated,
     DEPRECATED_VALUE,
+    Deprecated,
     deprecation_warning,
 )
 from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
-from ray.rllib.utils.filter import Filter, get_filter
+from ray.rllib.utils.filter import Filter, NoFilter, get_filter
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
+from ray.rllib.utils.from_config import from_config
 from ray.rllib.utils.policy import create_policy_for_framework, validate_policy_id
 from ray.rllib.utils.sgd import do_minibatch_sgd
 from ray.rllib.utils.tf_run_builder import _TFRunBuilder
-from ray.rllib.utils.tf_utils import (
-    get_gpu_devices as get_tf_gpu_devices,
-    get_tf_eager_cls_if_necessary,
-)
+from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
+from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
 from ray.rllib.utils.typing import (
     AgentID,
     EnvCreator,
@@ -97,11 +95,10 @@
     SampleBatchType,
     T,
 )
+from ray.tune.registry import registry_contains_input, registry_get_input
 from ray.util.annotations import PublicAPI
 from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
 from ray.util.iter import ParallelIteratorWorker
-from ray.tune.registry import registry_contains_input, registry_get_input
-
 
 if TYPE_CHECKING:
     from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -408,7 +405,7 @@ def gen_rollouts():
         if self.env is not None:
             # Validate environment (general validation function).
             if not self.config.disable_env_checking:
-                check_env(self.env)
+                check_env(self.env, self.config)
             # Custom validation function given, typically a function attribute of the
             # algorithm trainer.
             if validate_env is not None:
@@ -499,7 +496,6 @@ def wrap(env):
             and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE
             and not config._fake_gpus
         ):
-            devices = []
             if self.config.framework_str in ["tf2", "tf"]:
                 devices = get_tf_gpu_devices()
 
@@ -1837,7 +1833,6 @@ def _build_policy_map(
 
         # Loop through given policy-dict and add each entry to our map.
         for name, policy_spec in sorted(policy_dict.items()):
-
             # Create the actual policy object.
             if policy is None:
                 new_policy = create_policy_for_framework(
@@ -1996,7 +1991,7 @@ def _get_output_creator_from_config(self):
     def _get_make_sub_env_fn(
         self, env_creator, env_context, validate_env, env_wrapper, seed
     ):
-        disable_env_checking = self.config.disable_env_checking
+        config = self.config
 
         def _make_sub_env_local(vector_index):
             # Used to created additional environments during environment
@@ -2008,9 +2003,9 @@ def _make_sub_env_local(vector_index):
             # Create the sub-env.
             env = env_creator(env_ctx)
             # Validate first.
-            if not disable_env_checking:
+            if not config.disable_env_checking:
                 try:
-                    check_env(env)
+                    check_env(env, config)
                 except Exception as e:
                     logger.warning(
                         "We've added a module for checking environments that "
diff --git a/rllib/examples/env/cartpole_sparse_rewards.py b/rllib/examples/env/cartpole_sparse_rewards.py
index c07465e07acaf..d68f7614c1033 100644
--- a/rllib/examples/env/cartpole_sparse_rewards.py
+++ b/rllib/examples/env/cartpole_sparse_rewards.py
@@ -2,7 +2,7 @@
 
 import gymnasium as gym
 import numpy as np
-from gymnasium.spaces import Discrete, Dict, Box
+from gymnasium.spaces import Box, Dict, Discrete
 
 
 class CartPoleSparseRewards(gym.Env):
@@ -14,7 +14,9 @@ def __init__(self, config=None):
         self.observation_space = Dict(
             {
                 "obs": self.env.observation_space,
-                "action_mask": Box(low=0, high=1, shape=(self.action_space.n,)),
+                "action_mask": Box(
+                    low=0, high=1, shape=(self.action_space.n,), dtype=np.int8
+                ),
             }
         )
         self.running_reward = 0
@@ -24,7 +26,7 @@ def reset(self, *, seed=None, options=None):
         obs, infos = self.env.reset()
         return {
             "obs": obs,
-            "action_mask": np.array([1, 1], dtype=np.float32),
+            "action_mask": np.array([1, 1], dtype=np.int8),
         }, infos
 
     def step(self, action):
@@ -32,7 +34,7 @@ def step(self, action):
         self.running_reward += rew
         score = self.running_reward if terminated else 0
         return (
-            {"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)},
+            {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)},
             score,
             terminated,
             truncated,
@@ -43,7 +45,7 @@ def set_state(self, state):
         self.running_reward = state[1]
         self.env = deepcopy(state[0])
         obs = np.array(list(self.env.unwrapped.state))
-        return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)}
+        return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)}
 
     def get_state(self):
         return deepcopy(self.env), self.running_reward
diff --git a/rllib/examples/env/parametric_actions_cartpole.py b/rllib/examples/env/parametric_actions_cartpole.py
index 52985924bf5bc..94fb78f417b8f 100644
--- a/rllib/examples/env/parametric_actions_cartpole.py
+++ b/rllib/examples/env/parametric_actions_cartpole.py
@@ -1,7 +1,8 @@
+import random
+
 import gymnasium as gym
-from gymnasium.spaces import Box, Dict, Discrete
 import numpy as np
-import random
+from gymnasium.spaces import Box, Dict, Discrete
 
 
 class ParametricActionsCartPole(gym.Env):
@@ -35,18 +36,17 @@ def __init__(self, max_avail_actions):
         self.wrapped = gym.make("CartPole-v1")
         self.observation_space = Dict(
             {
-                "action_mask": Box(0, 1, shape=(max_avail_actions,), dtype=np.float32),
+                "action_mask": Box(0, 1, shape=(max_avail_actions,), dtype=np.int8),
                 "avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
                 "cart": self.wrapped.observation_space,
             }
         )
-        self._skip_env_checking = True
 
     def update_avail_actions(self):
         self.action_assignments = np.array(
             [[0.0, 0.0]] * self.action_space.n, dtype=np.float32
         )
-        self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.float32)
+        self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8)
         self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2)
         self.action_assignments[self.left_idx] = self.left_action_embed
         self.action_assignments[self.right_idx] = self.right_action_embed
@@ -78,7 +78,7 @@ def step(self, action):
             )
         orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
         self.update_avail_actions()
-        self.action_mask = self.action_mask.astype(np.float32)
+        self.action_mask = self.action_mask.astype(np.int8)
         obs = {
             "action_mask": self.action_mask,
             "avail_actions": self.action_assignments,
@@ -104,7 +104,7 @@ def __init__(self, max_avail_actions):
         # Randomly set which two actions are valid and available.
         self.left_idx, self.right_idx = random.sample(range(max_avail_actions), 2)
         self.valid_avail_actions_mask = np.array(
-            [0.0] * max_avail_actions, dtype=np.float32
+            [0.0] * max_avail_actions, dtype=np.int8
         )
         self.valid_avail_actions_mask[self.left_idx] = 1
         self.valid_avail_actions_mask[self.right_idx] = 1
@@ -116,7 +116,6 @@ def __init__(self, max_avail_actions):
                 "cart": self.wrapped.observation_space,
             }
         )
-        self._skip_env_checking = True
 
     def reset(self, *, seed=None, options=None):
         obs, infos = self.wrapped.reset()
diff --git a/rllib/examples/models/parametric_actions_model.py b/rllib/examples/models/parametric_actions_model.py
index b61ad6fea4c8b..20711553b82b8 100644
--- a/rllib/examples/models/parametric_actions_model.py
+++ b/rllib/examples/models/parametric_actions_model.py
@@ -5,7 +5,7 @@
 from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
 from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
 from ray.rllib.utils.framework import try_import_tf, try_import_torch
-from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX
+from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN
 
 tf1, tf, tfv = try_import_tf()
 torch, nn = try_import_torch()
@@ -144,7 +144,7 @@ def __init__(
 
         obs_cart = tf.keras.layers.Input(shape=true_obs_shape, name="obs_cart")
         valid_avail_actions_mask = tf.keras.layers.Input(
-            shape=(num_outputs), name="valid_avail_actions_mask"
+            shape=(num_outputs,), name="valid_avail_actions_mask"
         )
 
         self.pred_action_embed_model = FullyConnectedNetwork(
diff --git a/rllib/examples/parametric_actions_cartpole_embeddings_learnt_by_model.py b/rllib/examples/parametric_actions_cartpole_embeddings_learnt_by_model.py
index e0f1a7c858c09..a2f22791813e1 100644
--- a/rllib/examples/parametric_actions_cartpole_embeddings_learnt_by_model.py
+++ b/rllib/examples/parametric_actions_cartpole_embeddings_learnt_by_model.py
@@ -75,6 +75,7 @@
             "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
             "num_workers": 0,
             "framework": args.framework,
+            "action_mask_key": "valid_avail_actions_mask",
         },
         **cfg
     )
diff --git a/rllib/utils/pre_checks/env.py b/rllib/utils/pre_checks/env.py
index 7d00d74be0f46..4bdeb6a15d35d 100644
--- a/rllib/utils/pre_checks/env.py
+++ b/rllib/utils/pre_checks/env.py
@@ -1,10 +1,11 @@
 """Common pre-checks for all RLlib experiments."""
-from copy import copy
 import logging
-import numpy as np
 import traceback
+from copy import copy
+from typing import TYPE_CHECKING, Optional, Set, Union
+
+import numpy as np
 import tree  # pip install dm_tree
-from typing import TYPE_CHECKING, Set, Union
 
 from ray.actor import ActorHandle
 from ray.rllib.utils.annotations import DeveloperAPI
@@ -18,6 +19,7 @@
 from ray.util import log_once
 
 if TYPE_CHECKING:
+    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
     from ray.rllib.env import BaseEnv, MultiAgentEnv, VectorEnv
 
 logger = logging.getLogger(__name__)
@@ -26,23 +28,25 @@
 
 
 @DeveloperAPI
-def check_env(env: EnvType) -> None:
+def check_env(env: EnvType, config: Optional["AlgorithmConfig"] = None) -> None:
     """Run pre-checks on env that uncover common errors in environments.
 
     Args:
         env: Environment to be checked.
+        config: Additional checks config.
 
     Raises:
         ValueError: If env is not an instance of SUPPORTED_ENVIRONMENT_TYPES.
        ValueError: See check_gym_env docstring for details.
     """
+    from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
     from ray.rllib.env import (
         BaseEnv,
+        ExternalEnv,
+        ExternalMultiAgentEnv,
         MultiAgentEnv,
         RemoteBaseEnv,
         VectorEnv,
-        ExternalMultiAgentEnv,
-        ExternalEnv,
     )
 
     if hasattr(env, "_skip_env_checking") and env._skip_env_checking:
@@ -78,7 +82,7 @@ def check_env(env: EnvType) -> None:
         elif isinstance(env, VectorEnv):
             check_vector_env(env)
         elif isinstance(env, gym.Env) or old_gym and isinstance(env, old_gym.Env):
-            check_gym_environments(env)
+            check_gym_environments(env, AlgorithmConfig() if config is None else config)
         elif isinstance(env, BaseEnv):
             check_base_env(env)
         else:
@@ -102,11 +106,14 @@ def check_env(env: EnvType) -> None:
 
 
 @DeveloperAPI
-def check_gym_environments(env: Union[gym.Env, "old_gym.Env"]) -> None:
+def check_gym_environments(
+    env: Union[gym.Env, "old_gym.Env"], config: "AlgorithmConfig"
+) -> None:
     """Checking for common errors in a gymnasium/gym environments.
 
     Args:
         env: Environment to be checked.
+        config: Additional checks config.
 
     Warning:
         If env has no attribute spec with a sub attribute,
@@ -212,6 +219,12 @@ def check_gym_environments(env: Union[gym.Env, "old_gym.Env"]) -> None:
                     space_type,
                 )
             )
+    # Sample a valid action in case of parametric actions.
+    if isinstance(reset_obs, dict):
+        if config.action_mask_key in reset_obs:
+            sampled_action = env.action_space.sample(
+                mask=reset_obs[config.action_mask_key]
+            )
 
     # Check if env.step can run, and generates observations rewards, done
     # signals and infos that are within their respective spaces and are of
diff --git a/rllib/utils/tests/test_check_env.py b/rllib/utils/tests/test_check_env.py
index 2f91c46ca289e..65ce3e215da90 100644
--- a/rllib/utils/tests/test_check_env.py
+++ b/rllib/utils/tests/test_check_env.py
@@ -1,19 +1,21 @@
-import gymnasium as gym
-from gymnasium.spaces import Box, Dict, Discrete
 import logging
+import unittest
+from unittest.mock import MagicMock, Mock
+
+import gymnasium as gym
 import numpy as np
 import pytest
-import unittest
-from unittest.mock import Mock, MagicMock
+from gymnasium.spaces import Box, Dict, Discrete
 
 from ray.rllib.env.base_env import convert_to_base_env
-from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper
+from ray.rllib.env.multi_agent_env import MultiAgentEnvWrapper, make_multi_agent
+from ray.rllib.examples.env.parametric_actions_cartpole import ParametricActionsCartPole
 from ray.rllib.examples.env.random_env import RandomEnv
 from ray.rllib.utils.pre_checks.env import (
+    check_base_env,
     check_env,
     check_gym_environments,
     check_multiagent_environments,
-    check_base_env,
 )
 
 
@@ -25,10 +27,10 @@ def inject_fixtures(self, caplog):
     def test_has_observation_and_action_space(self):
         env = Mock(spec=[])
         with pytest.raises(AttributeError, match="Env must have observation_space."):
-            check_gym_environments(env)
+            check_gym_environments(env, Mock())
         env = Mock(spec=["observation_space"])
         with pytest.raises(AttributeError, match="Env must have action_space."):
-            check_gym_environments(env)
+            check_gym_environments(env, Mock())
 
     def test_obs_and_action_spaces_are_gym_spaces(self):
         env = RandomEnv()
@@ -104,6 +106,10 @@ def test_step(self):
         with pytest.raises(ValueError, match=error):
             check_env(env)
 
+    def test_parametric_actions(self):
+        env = ParametricActionsCartPole(10)
+        check_env(env)
+
 
 class TestCheckMultiAgentEnv(unittest.TestCase):
     @pytest.fixture(autouse=True)
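
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new `action_mask_key` setting and the mask-aware `check_env()` are meant to work together for an env whose mask lives under a non-default key. The env class `MyMaskedEnv` and its "valid_actions" key are made up for illustration; only `AlgorithmConfig.environment(action_mask_key=...)` and `check_env(env, config)` come from the changes above.

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.utils.pre_checks.env import check_env


class MyMaskedEnv(gym.Env):
    """Hypothetical parametric-actions env; its mask lives under "valid_actions"."""

    def __init__(self, config=None):
        self.action_space = Discrete(4)
        self.observation_space = Dict(
            {
                "obs": Box(-1.0, 1.0, shape=(2,), dtype=np.float32),
                # int8 array of zeros and ones, as the env checker expects.
                "valid_actions": Box(0, 1, shape=(4,), dtype=np.int8),
            }
        )

    def _obs(self):
        return {
            "obs": np.zeros(2, dtype=np.float32),
            # Only actions 0 and 1 are currently valid.
            "valid_actions": np.array([1, 1, 0, 0], dtype=np.int8),
        }

    def reset(self, *, seed=None, options=None):
        return self._obs(), {}

    def step(self, action):
        # A real env would handle (or penalize) invalid actions here.
        return self._obs(), 1.0, False, False, {}


# Point RLlib (and the env pre-check) at the non-default mask key, so that
# check_env() samples only valid actions via env.action_space.sample(mask=...).
config = AlgorithmConfig().environment(MyMaskedEnv, action_mask_key="valid_actions")
check_env(MyMaskedEnv(), config)

With the default key ("action_mask"), as used by CartPoleSparseRewards above, no extra configuration is needed.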