[RLlib] Fix env_check for parametric actions (with action mask) #34790

Merged: 9 commits, Jun 20, 2023
24 changes: 14 additions & 10 deletions rllib/algorithms/algorithm_config.py
@@ -4,6 +4,7 @@
import os
import sys
from typing import (
TYPE_CHECKING,
Any,
Callable,
Container,
@@ -12,7 +13,6 @@
Optional,
Tuple,
Type,
TYPE_CHECKING,
Union,
)

@@ -21,45 +21,42 @@
import ray
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.core.learner.learner import LearnerHyperparameters
from ray.rllib.core.learner.learner_group_config import (
LearnerGroupConfig,
ModuleSpec,
)
from ray.rllib.core.learner.learner_group_config import LearnerGroupConfig, ModuleSpec
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import ModuleID, SingleAgentRLModuleSpec
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.env.wrappers.atari_wrappers import is_atari
from ray.rllib.evaluation.collectors.sample_collector import SampleCollector
from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector
from ray.rllib.evaluation.episode import Episode
from ray.rllib.models import MODEL_DEFAULTS
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID
from ray.rllib.utils import deep_update, merge_dicts
from ray.rllib.utils.annotations import (
OverrideToImplementCustomLogic_CallToSuperRecommended,
ExperimentalAPI,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config, NotProvided
from ray.rllib.utils.from_config import NotProvided, from_config
from ray.rllib.utils.gym import (
convert_old_gym_space_to_gymnasium_space,
try_import_gymnasium_and_gym,
)
from ray.rllib.utils.policy import validate_policy_id
from ray.rllib.utils.schedules.scheduler import Scheduler
from ray.rllib.utils.serialization import (
deserialize_type,
NOT_SERIALIZABLE,
deserialize_type,
serialize_type,
)
from ray.rllib.utils.torch_utils import TORCH_COMPILE_REQUIRED_VERSION
from ray.rllib.utils.typing import (
AgentID,
AlgorithmConfigDict,
@@ -307,6 +304,7 @@ def __init__(self, algo_class=None):
# If not specified, we will try to auto-detect this.
self.is_atari = None
self.auto_wrap_old_gym_envs = True
self.action_mask_key = "action_mask"

# `self.rollouts()`
self.env_runner_cls = None
@@ -1325,6 +1323,7 @@ def environment(
disable_env_checking: Optional[bool] = NotProvided,
is_atari: Optional[bool] = NotProvided,
auto_wrap_old_gym_envs: Optional[bool] = NotProvided,
action_mask_key: Optional[str] = NotProvided,
) -> "AlgorithmConfig":
"""Sets the config's RL-environment settings.

@@ -1376,6 +1375,9 @@ def environment(
(gym.wrappers.EnvCompatibility). If False, RLlib will produce a
descriptive error on which steps to perform to upgrade to gymnasium
(or to switch this flag to True).
action_mask_key: If the observation is a dictionary, expect the value at
the key `action_mask_key` to contain a valid action mask (`numpy.int8`
array of zeros and ones). Defaults to "action_mask".

Contributor: Awesome, thanks for adding this so quickly! I think it's ready to merge now. Just waiting for tests to finish.

Returns:
This updated AlgorithmConfig object.
@@ -1408,6 +1410,8 @@
self.is_atari = is_atari
if auto_wrap_old_gym_envs is not NotProvided:
self.auto_wrap_old_gym_envs = auto_wrap_old_gym_envs
if action_mask_key is not NotProvided:
self.action_mask_key = action_mask_key

return self

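A minimal usage sketch of the new setting through the builder API (the algorithm, env name, and custom key below are illustrative, not taken from this diff):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # Hypothetical env whose observation dict stores its parametric-action
    # mask under a non-default key.
    .environment(env="MyParametricEnv", action_mask_key="legal_moves")
)
# The env check will then expect obs["legal_moves"] to be a numpy.int8
# array of zeros and ones instead of the default obs["action_mask"].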
39 changes: 17 additions & 22 deletions rllib/evaluation/rollout_worker.py
@@ -1,13 +1,10 @@
from collections import defaultdict
import copy
from gymnasium.spaces import Discrete, MultiDiscrete, Space
import importlib.util
import logging
import numpy as np
import os
import platform
import threading
import tree # pip install dm_tree
from collections import defaultdict
from types import FunctionType
from typing import (
TYPE_CHECKING,
@@ -23,6 +20,10 @@
Union,
)

import numpy as np
import tree # pip install dm_tree
from gymnasium.spaces import Discrete, MultiDiscrete, Space

import ray
from ray import ObjectRef
from ray import cloudpickle as pickle
@@ -45,8 +46,8 @@
D4RLReader,
DatasetReader,
DatasetWriter,
IOContext,
InputReader,
IOContext,
JsonReader,
JsonWriter,
MixedInput,
@@ -56,34 +57,31 @@
)
from ray.rllib.policy.policy import Policy, PolicySpec
from ray.rllib.policy.policy_map import PolicyMap
from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch
from ray.rllib.utils.filter import NoFilter
from ray.rllib.utils.from_config import from_config
from ray.rllib.policy.sample_batch import (
DEFAULT_POLICY_ID,
MultiAgentBatch,
concat_samples,
convert_ma_batch_to_sample_batch,
)
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2
from ray.rllib.utils import check_env, force_list
from ray.rllib.utils.annotations import DeveloperAPI, override
from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary
from ray.rllib.utils.deprecation import (
Deprecated,
DEPRECATED_VALUE,
Deprecated,
deprecation_warning,
)
from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG
from ray.rllib.utils.filter import Filter, get_filter
from ray.rllib.utils.filter import Filter, NoFilter, get_filter
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.from_config import from_config
from ray.rllib.utils.policy import create_policy_for_framework, validate_policy_id
from ray.rllib.utils.sgd import do_minibatch_sgd
from ray.rllib.utils.tf_run_builder import _TFRunBuilder
from ray.rllib.utils.tf_utils import (
get_gpu_devices as get_tf_gpu_devices,
get_tf_eager_cls_if_necessary,
)
from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices
from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary
from ray.rllib.utils.typing import (
AgentID,
EnvCreator,
@@ -97,11 +95,10 @@
SampleBatchType,
T,
)
from ray.tune.registry import registry_contains_input, registry_get_input
from ray.util.annotations import PublicAPI
from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once
from ray.util.iter import ParallelIteratorWorker
from ray.tune.registry import registry_contains_input, registry_get_input


if TYPE_CHECKING:
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -408,7 +405,7 @@ def gen_rollouts():
if self.env is not None:
# Validate environment (general validation function).
if not self.config.disable_env_checking:
check_env(self.env)
check_env(self.env, self.config)
# Custom validation function given, typically a function attribute of the
# algorithm trainer.
if validate_env is not None:
@@ -499,7 +496,6 @@ def wrap(env):
and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE
and not config._fake_gpus
):

devices = []
if self.config.framework_str in ["tf2", "tf"]:
devices = get_tf_gpu_devices()
@@ -1837,7 +1833,6 @@ def _build_policy_map(

# Loop through given policy-dict and add each entry to our map.
for name, policy_spec in sorted(policy_dict.items()):

# Create the actual policy object.
if policy is None:
new_policy = create_policy_for_framework(
@@ -1996,7 +1991,7 @@ def _get_output_creator_from_config(self):
def _get_make_sub_env_fn(
self, env_creator, env_context, validate_env, env_wrapper, seed
):
disable_env_checking = self.config.disable_env_checking
config = self.config

def _make_sub_env_local(vector_index):
# Used to created additional environments during environment
@@ -2008,9 +2003,9 @@ def _make_sub_env_local(vector_index):
# Create the sub-env.
env = env_creator(env_ctx)
# Validate first.
if not disable_env_checking:
if not config.disable_env_checking:
try:
check_env(env)
check_env(env, config)
except Exception as e:
logger.warning(
"We've added a module for checking environments that "
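With the config now threaded into check_env(), parametric-action envs no longer need to set _skip_env_checking; they only have to expose their mask under the configured action_mask_key. A sketch of the observation-space layout the check expects, modeled on the example envs in this PR (the "obs" bounds and shapes are placeholders):

import numpy as np
from gymnasium.spaces import Box, Dict

observation_space = Dict(
    {
        # The "real" observation; bounds and shape here are placeholders.
        "obs": Box(-10.0, 10.0, shape=(4,)),
        # Valid-actions mask under the configured key (default "action_mask"):
        # an int8 array of zeros and ones, one entry per discrete action.
        "action_mask": Box(0, 1, shape=(2,), dtype=np.int8),
    }
)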
12 changes: 7 additions & 5 deletions rllib/examples/env/cartpole_sparse_rewards.py
@@ -2,7 +2,7 @@

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Discrete, Dict, Box
from gymnasium.spaces import Box, Dict, Discrete


class CartPoleSparseRewards(gym.Env):
@@ -14,7 +14,9 @@ def __init__(self, config=None):
self.observation_space = Dict(
{
"obs": self.env.observation_space,
"action_mask": Box(low=0, high=1, shape=(self.action_space.n,)),
"action_mask": Box(
low=0, high=1, shape=(self.action_space.n,), dtype=np.int8
),
}
)
self.running_reward = 0
@@ -24,15 +26,15 @@ def reset(self, *, seed=None, options=None):
obs, infos = self.env.reset()
return {
"obs": obs,
"action_mask": np.array([1, 1], dtype=np.float32),
"action_mask": np.array([1, 1], dtype=np.int8),
}, infos

def step(self, action):
obs, rew, terminated, truncated, info = self.env.step(action)
self.running_reward += rew
score = self.running_reward if terminated else 0
return (
{"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)},
{"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)},
score,
terminated,
truncated,
@@ -43,7 +45,7 @@ def set_state(self, state):
self.running_reward = state[1]
self.env = deepcopy(state[0])
obs = np.array(list(self.env.unwrapped.state))
return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.float32)}
return {"obs": obs, "action_mask": np.array([1, 1], dtype=np.int8)}

def get_state(self):
return deepcopy(self.env), self.running_reward
15 changes: 7 additions & 8 deletions rllib/examples/env/parametric_actions_cartpole.py
@@ -1,7 +1,8 @@
import random

import gymnasium as gym
from gymnasium.spaces import Box, Dict, Discrete
import numpy as np
import random
from gymnasium.spaces import Box, Dict, Discrete


class ParametricActionsCartPole(gym.Env):
@@ -35,18 +36,17 @@ def __init__(self, max_avail_actions):
self.wrapped = gym.make("CartPole-v1")
self.observation_space = Dict(
{
"action_mask": Box(0, 1, shape=(max_avail_actions,), dtype=np.float32),
"action_mask": Box(0, 1, shape=(max_avail_actions,), dtype=np.int8),
"avail_actions": Box(-10, 10, shape=(max_avail_actions, 2)),
"cart": self.wrapped.observation_space,
}
)
self._skip_env_checking = True

def update_avail_actions(self):
self.action_assignments = np.array(
[[0.0, 0.0]] * self.action_space.n, dtype=np.float32
)
self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.float32)
self.action_mask = np.array([0.0] * self.action_space.n, dtype=np.int8)
self.left_idx, self.right_idx = random.sample(range(self.action_space.n), 2)
self.action_assignments[self.left_idx] = self.left_action_embed
self.action_assignments[self.right_idx] = self.right_action_embed
@@ -78,7 +78,7 @@ def step(self, action):
)
orig_obs, rew, done, truncated, info = self.wrapped.step(actual_action)
self.update_avail_actions()
self.action_mask = self.action_mask.astype(np.float32)
self.action_mask = self.action_mask.astype(np.int8)
obs = {
"action_mask": self.action_mask,
"avail_actions": self.action_assignments,
@@ -104,7 +104,7 @@ def __init__(self, max_avail_actions):
# Randomly set which two actions are valid and available.
self.left_idx, self.right_idx = random.sample(range(max_avail_actions), 2)
self.valid_avail_actions_mask = np.array(
[0.0] * max_avail_actions, dtype=np.float32
[0.0] * max_avail_actions, dtype=np.int8
)
self.valid_avail_actions_mask[self.left_idx] = 1
self.valid_avail_actions_mask[self.right_idx] = 1
@@ -116,7 +116,6 @@ def __init__(self, max_avail_actions):
"cart": self.wrapped.observation_space,
}
)
self._skip_env_checking = True

def reset(self, *, seed=None, options=None):
obs, infos = self.wrapped.reset()
4 changes: 2 additions & 2 deletions rllib/examples/models/parametric_actions_model.py
@@ -5,7 +5,7 @@
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_tf, try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN, FLOAT_MAX
from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN

tf1, tf, tfv = try_import_tf()
torch, nn = try_import_torch()
@@ -144,7 +144,7 @@ def __init__(

obs_cart = tf.keras.layers.Input(shape=true_obs_shape, name="obs_cart")
valid_avail_actions_mask = tf.keras.layers.Input(
shape=(num_outputs), name="valid_avail_actions_mask"
shape=(num_outputs,), name="valid_avail_actions_mask"
)

self.pred_action_embed_model = FullyConnectedNetwork(
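The TF model above takes a valid_avail_actions_mask input and imports FLOAT_MIN/FLOAT_MAX; the usual way such a mask is applied is by adding a log-mask to the action logits so invalid actions can never be sampled. A standalone sketch of that step (tensor values are made up; this shows the common pattern, not a copy of the model code):

import tensorflow as tf

# Illustrative logits and mask for 4 actions; only actions 1 and 3 are valid.
action_logits = tf.constant([[0.2, 1.5, -0.3, 0.7]], dtype=tf.float32)
valid_avail_actions_mask = tf.constant([[0.0, 1.0, 0.0, 1.0]], dtype=tf.float32)

# log(0) -> -inf, clipped to the most negative float32, so masked actions
# end up with effectively -inf logits after the addition.
inf_mask = tf.maximum(tf.math.log(valid_avail_actions_mask), tf.float32.min)
masked_logits = action_logits + inf_mask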
@@ -75,6 +75,7 @@
"num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
"num_workers": 0,
"framework": args.framework,
"action_mask_key": "valid_avail_actions_mask",
},
**cfg
)