diff --git a/rllib/BUILD b/rllib/BUILD index 26b2c4426d813..f040dbab4e73f 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1367,14 +1367,6 @@ py_test( srcs = ["evaluation/tests/test_env_runner_v2.py"] ) -# @OldAPIStack -py_test( - name = "evaluation/tests/test_episode", - tags = ["team:rllib", "evaluation"], - size = "small", - srcs = ["evaluation/tests/test_episode.py"] -) - # @OldAPIStack py_test( name = "evaluation/tests/test_episode_v2", @@ -3181,26 +3173,6 @@ py_test( args = ["--as-test", "--framework=torch", "--stop-reward=7.2"] ) -#@OldAPIStack -py_test( - name = "examples/centralized_critic_2_tf", - main = "examples/centralized_critic_2.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/centralized_critic_2.py"], - args = ["--as-test", "--framework=tf", "--stop-reward=6.0"] -) - -#@OldAPIStack -py_test( - name = "examples/centralized_critic_2_torch", - main = "examples/centralized_critic_2.py", - tags = ["team:rllib", "exclusive", "examples"], - size = "medium", - srcs = ["examples/centralized_critic_2.py"], - args = ["--as-test", "--framework=torch", "--stop-reward=6.0"] -) - py_test( name = "examples/custom_recurrent_rnn_tokenizer_repeat_after_me_tf2", main = "examples/custom_recurrent_rnn_tokenizer.py", diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index f2462a8450759..863e06eec904a 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -62,7 +62,6 @@ from ray.rllib.env.env_runner import EnvRunner from ray.rllib.env.env_runner_group import EnvRunnerGroup from ray.rllib.env.utils import _gym_env_creator -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.metrics import ( collect_episodes, summarize_episodes, @@ -1634,10 +1633,7 @@ def restore_workers(self, workers: EnvRunnerGroup) -> None: # worker of an EnvRunnerGroup. state = from_worker.get_state() # Take out (old) connector states from local worker's state. - if ( - self.config.enable_connectors - and not self.config.enable_env_runner_and_connector_v2 - ): + if not self.config.enable_env_runner_and_connector_v2: for pol_states in state["policy_states"].values(): pol_states.pop("connector_configs", None) state_ref = ray.put(state) @@ -2040,7 +2036,7 @@ def compute_single_action( full_fetch: bool = False, explore: Optional[bool] = None, timestep: Optional[int] = None, - episode: Optional[Episode] = None, + episode=None, unsquash_action: Optional[bool] = None, clip_action: Optional[bool] = None, # Kwargs placeholder for future compatibility. @@ -2127,53 +2123,44 @@ def compute_single_action( f"PolicyID '{policy_id}' not found in PolicyMap of the " f"Algorithm's local worker!" ) - local_worker = self.env_runner_group.local_env_runner - - if not self.config.get("enable_connectors"): - # Check the preprocessor and preprocess, if necessary. - pp = local_worker.preprocessors[policy_id] - if pp and type(pp).__name__ != "NoPreprocessor": - observation = pp.transform(observation) - observation = local_worker.filters[policy_id](observation, update=False) - else: - # Just preprocess observations, similar to how it used to be done before. - pp = policy.agent_connectors[ObsPreprocessorConnector] - - # convert the observation to array if possible - if not isinstance(observation, (np.ndarray, dict, tuple)): - try: - observation = np.asarray(observation) - except Exception: + # Just preprocess observations, similar to how it used to be done before. 
+ pp = policy.agent_connectors[ObsPreprocessorConnector] + + # convert the observation to array if possible + if not isinstance(observation, (np.ndarray, dict, tuple)): + try: + observation = np.asarray(observation) + except Exception: + raise ValueError( + f"Observation type {type(observation)} cannot be converted to " + f"np.ndarray." + ) + if pp: + assert len(pp) == 1, "Only one preprocessor should be in the pipeline" + pp = pp[0] + + if not pp.is_identity(): + # Note(Kourosh): This call will leave the policy's connector + # in eval mode. would that be a problem? + pp.in_eval() + if observation is not None: + _input_dict = {Columns.OBS: observation} + elif input_dict is not None: + _input_dict = {Columns.OBS: input_dict[Columns.OBS]} + else: raise ValueError( - f"Observation type {type(observation)} cannot be converted to " - f"np.ndarray." + "Either observation or input_dict must be provided." ) - if pp: - assert len(pp) == 1, "Only one preprocessor should be in the pipeline" - pp = pp[0] - - if not pp.is_identity(): - # Note(Kourosh): This call will leave the policy's connector - # in eval mode. would that be a problem? - pp.in_eval() - if observation is not None: - _input_dict = {Columns.OBS: observation} - elif input_dict is not None: - _input_dict = {Columns.OBS: input_dict[Columns.OBS]} - else: - raise ValueError( - "Either observation or input_dict must be provided." - ) - # TODO (Kourosh): Create a new util method for algorithm that - # computes actions based on raw inputs from env and can keep track - # of its own internal state. - acd = AgentConnectorDataType("0", "0", _input_dict) - # make sure the state is reset since we are only applying the - # preprocessor - pp.reset(env_id="0") - ac_o = pp([acd])[0] - observation = ac_o.data[Columns.OBS] + # TODO (Kourosh): Create a new util method for algorithm that + # computes actions based on raw inputs from env and can keep track + # of its own internal state. + acd = AgentConnectorDataType("0", "0", _input_dict) + # make sure the state is reset since we are only applying the + # preprocessor + pp.reset(env_id="0") + ac_o = pp([acd])[0] + observation = ac_o.data[Columns.OBS] # Input-dict. if input_dict is not None: @@ -2225,7 +2212,7 @@ def compute_actions( full_fetch: bool = False, explore: Optional[bool] = None, timestep: Optional[int] = None, - episodes: Optional[List[Episode]] = None, + episodes=None, unsquash_actions: Optional[bool] = None, clip_actions: Optional[bool] = None, **kwargs, diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 6b7bd8cea053f..d444c43476834 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -106,7 +106,7 @@ from ray.rllib.core.learner import Learner from ray.rllib.core.learner.learner_group import LearnerGroup from ray.rllib.core.rl_module.rl_module import RLModule - from ray.rllib.evaluation.episode import Episode as OldEpisode + from ray.rllib.utils.typing import EpisodeType logger = logging.getLogger(__name__) @@ -370,7 +370,6 @@ def __init__(self, algo_class: Optional[type] = None): self.observation_filter = "NoFilter" self.update_worker_filter_stats = True self.use_worker_filter_stats = True - self.enable_connectors = True self.sampler_perf_stats_ema_coef = None # `self.learners()` @@ -572,6 +571,7 @@ def __init__(self, algo_class: Optional[type] = None): # TODO: Remove, once all deprecation_warning calls upon using these keys # have been removed. 
# === Deprecated keys === + self.enable_connectors = DEPRECATED_VALUE self.simple_optimizer = DEPRECATED_VALUE self.monitor = DEPRECATED_VALUE self.evaluation_num_episodes = DEPRECATED_VALUE @@ -1758,7 +1758,6 @@ def env_runners( exploration_config: Optional[dict] = NotProvided, # @OldAPIStack create_env_on_local_worker: Optional[bool] = NotProvided, # @OldAPIStack sample_collector: Optional[Type[SampleCollector]] = NotProvided, # @OldAPIStack - enable_connectors: Optional[bool] = NotProvided, # @OldAPIStack remote_worker_envs: Optional[bool] = NotProvided, # @OldAPIStack remote_env_batch_wait_ms: Optional[float] = NotProvided, # @OldAPIStack preprocessor_pref: Optional[str] = NotProvided, # @OldAPIStack @@ -1776,6 +1775,8 @@ def env_runners( worker_health_probe_timeout_s=DEPRECATED_VALUE, worker_restore_timeout_s=DEPRECATED_VALUE, synchronize_filter=DEPRECATED_VALUE, + # deprecated + enable_connectors=DEPRECATED_VALUE, ) -> "AlgorithmConfig": """Sets the rollout worker configuration. @@ -1822,9 +1823,6 @@ def env_runners( because it doesn't have to sample (done by remote_workers; worker_indices > 0) nor evaluate (done by evaluation workers; see below). - enable_connectors: Use connector based environment runner, so that all - preprocessing of obs and postprocessing of actions are done in agent - and action connectors. env_to_module_connector: A callable taking an Env as input arg and returning an env-to-module ConnectorV2 (might be a pipeline) object. module_to_env_connector: A callable taking an Env and an RLModule as input @@ -1933,29 +1931,29 @@ def env_runners( Returns: This updated AlgorithmConfig object. """ + if enable_connectors != DEPRECATED_VALUE: + deprecation_warning( + old="AlgorithmConfig.env_runners(enable_connectors=...)", + error=False, + ) if num_rollout_workers != DEPRECATED_VALUE: deprecation_warning( old="AlgorithmConfig.env_runners(num_rollout_workers)", new="AlgorithmConfig.env_runners(num_env_runners)", - error=False, + error=True, ) - self.num_env_runners = num_rollout_workers if num_envs_per_worker != DEPRECATED_VALUE: deprecation_warning( old="AlgorithmConfig.env_runners(num_envs_per_worker)", new="AlgorithmConfig.env_runners(num_envs_per_env_runner)", - error=False, + error=True, ) - self.num_envs_per_env_runner = num_envs_per_worker if validate_workers_after_construction != DEPRECATED_VALUE: deprecation_warning( old="AlgorithmConfig.env_runners(validate_workers_after_construction)", new="AlgorithmConfig.env_runners(validate_env_runners_after_" "construction)", - error=False, - ) - self.validate_env_runners_after_construction = ( - validate_workers_after_construction + error=True, ) if env_runner_cls is not NotProvided: @@ -1987,8 +1985,6 @@ def env_runners( self.sample_collector = sample_collector if create_env_on_local_worker is not NotProvided: self.create_env_on_local_worker = create_env_on_local_worker - if enable_connectors is not NotProvided: - self.enable_connectors = enable_connectors if env_to_module_connector is not NotProvided: self._env_to_module_connector = env_to_module_connector if module_to_env_connector is not NotProvided: @@ -2874,7 +2870,7 @@ def multi_agent( ] = NotProvided, policy_map_capacity: Optional[int] = NotProvided, policy_mapping_fn: Optional[ - Callable[[AgentID, "OldEpisode"], PolicyID] + Callable[[AgentID, "EpisodeType"], PolicyID] ] = NotProvided, policies_to_train: Optional[ Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]] @@ -4466,14 +4462,6 @@ def _validate_new_api_stack_settings(self): "to False (old 
API stack), instead." ) - # New API stack (RLModule, Learner APIs) only works with connectors. - if not self.enable_connectors: - raise ValueError( - "The new API stack (RLModule and Learner APIs) only works with " - "connectors! Please enable connectors via " - "`config.env_runners(enable_connectors=True)`." - ) - # LR-schedule checking. Scheduler.validate( fixed_value_or_schedule=self.lr, diff --git a/rllib/algorithms/appo/appo_tf_policy.py b/rllib/algorithms/appo/appo_tf_policy.py index 9129dde30f829..4af36f099df92 100644 --- a/rllib/algorithms/appo/appo_tf_policy.py +++ b/rllib/algorithms/appo/appo_tf_policy.py @@ -17,7 +17,6 @@ VTraceClipGradients, VTraceOptimizer, ) -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( compute_bootstrap_value, compute_gae_for_sample_batch, @@ -362,7 +361,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[SampleBatch] = None, - episode: Optional["Episode"] = None, + episode=None, ): # Call super's postprocess_trajectory first. # sample_batch = super().postprocess_trajectory( diff --git a/rllib/algorithms/appo/appo_torch_policy.py b/rllib/algorithms/appo/appo_torch_policy.py index 56ab8f11267e1..34a09d10373f9 100644 --- a/rllib/algorithms/appo/appo_torch_policy.py +++ b/rllib/algorithms/appo/appo_torch_policy.py @@ -17,7 +17,6 @@ make_time_major, VTraceOptimizer, ) -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import ( compute_bootstrap_value, compute_gae_for_sample_batch, @@ -378,7 +377,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, - episode: Optional["Episode"] = None, + episode=None, ): # Call super's postprocess_trajectory first. # sample_batch = super().postprocess_trajectory( diff --git a/rllib/algorithms/callbacks.py b/rllib/algorithms/callbacks.py index 29ac52b871594..2931a269e35d7 100644 --- a/rllib/algorithms/callbacks.py +++ b/rllib/algorithms/callbacks.py @@ -11,7 +11,6 @@ from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.episode_v2 import EpisodeV2 from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.policy import Policy @@ -227,7 +226,7 @@ def on_episode_created( self, *, # TODO (sven): Deprecate Episode/EpisodeV2 with new API stack. - episode: Union[EpisodeType, Episode, EpisodeV2], + episode: Union[EpisodeType, EpisodeV2], # TODO (sven): Deprecate this arg new API stack (in favor of `env_runner`). 
worker: Optional["EnvRunner"] = None, env_runner: Optional["EnvRunner"] = None, @@ -284,7 +283,7 @@ def on_episode_created( def on_episode_start( self, *, - episode: Union[EpisodeType, Episode, EpisodeV2], + episode: Union[EpisodeType, EpisodeV2], env_runner: Optional["EnvRunner"] = None, metrics_logger: Optional[MetricsLogger] = None, env: Optional[gym.Env] = None, @@ -326,7 +325,7 @@ def on_episode_start( def on_episode_step( self, *, - episode: Union[EpisodeType, Episode, EpisodeV2], + episode: Union[EpisodeType, EpisodeV2], env_runner: Optional["EnvRunner"] = None, metrics_logger: Optional[MetricsLogger] = None, env: Optional[gym.Env] = None, @@ -369,7 +368,7 @@ def on_episode_step( def on_episode_end( self, *, - episode: Union[EpisodeType, Episode, EpisodeV2], + episode: Union[EpisodeType, EpisodeV2], env_runner: Optional["EnvRunner"] = None, metrics_logger: Optional[MetricsLogger] = None, env: Optional[gym.Env] = None, @@ -473,7 +472,7 @@ def on_postprocess_trajectory( self, *, worker: "EnvRunner", - episode: Episode, + episode, agent_id: AgentID, policy_id: PolicyID, policies: Dict[PolicyID, Policy], @@ -603,7 +602,7 @@ def __init__(self): def on_episode_end( self, *, - episode: Union[EpisodeType, Episode, EpisodeV2], + episode: Union[EpisodeType, EpisodeV2], env_runner: Optional["EnvRunner"] = None, metrics_logger: Optional[MetricsLogger] = None, env: Optional[gym.Env] = None, @@ -743,7 +742,7 @@ def on_postprocess_trajectory( self, *, worker: "EnvRunner", - episode: Episode, + episode, agent_id: AgentID, policy_id: PolicyID, policies: Dict[PolicyID, Policy], diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 39b9a1731e5e6..fc5c9ae4b0f72 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -1222,7 +1222,6 @@ def aggregate_into_larger_batch(): # AgentCollectors, RolloutWorkers, Policies, TrajectoryView API, etc..): if ( self.config.batch_mode == "truncate_episodes" - and self.config.enable_connectors and self.config.restart_failed_env_runners ): if any( diff --git a/rllib/algorithms/impala/impala_tf_policy.py b/rllib/algorithms/impala/impala_tf_policy.py index 44038e692dd2a..d06d0065b124c 100644 --- a/rllib/algorithms/impala/impala_tf_policy.py +++ b/rllib/algorithms/impala/impala_tf_policy.py @@ -8,7 +8,6 @@ from typing import Dict, List, Optional, Type, Union from ray.rllib.algorithms.impala import vtrace_tf as vtrace -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import compute_bootstrap_value from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import Categorical, TFActionDistribution @@ -416,7 +415,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[SampleBatch] = None, - episode: Optional["Episode"] = None, + episode=None, ): # Call super's postprocess_trajectory first. 
# sample_batch = super().postprocess_trajectory( diff --git a/rllib/algorithms/impala/impala_torch_policy.py b/rllib/algorithms/impala/impala_torch_policy.py index 579c85a392d81..51a264a69cd7f 100644 --- a/rllib/algorithms/impala/impala_torch_policy.py +++ b/rllib/algorithms/impala/impala_torch_policy.py @@ -4,7 +4,6 @@ from typing import Dict, List, Optional, Type, Union import ray -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import compute_bootstrap_value from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.action_dist import ActionDistribution @@ -398,7 +397,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[SampleBatch] = None, - episode: Optional["Episode"] = None, + episode=None, ): # Call super's postprocess_trajectory first. # sample_batch = super().postprocess_trajectory( diff --git a/rllib/algorithms/marwil/marwil_tf_policy.py b/rllib/algorithms/marwil/marwil_tf_policy.py index 5c5194b76ab80..5f75a8424c766 100644 --- a/rllib/algorithms/marwil/marwil_tf_policy.py +++ b/rllib/algorithms/marwil/marwil_tf_policy.py @@ -1,7 +1,6 @@ import logging from typing import Any, Dict, List, Optional, Type, Union -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing from ray.rllib.models.action_dist import ActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -38,7 +37,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, - episode: Optional["Episode"] = None, + episode=None, ): sample_batch = super().postprocess_trajectory( sample_batch, other_agent_batches, episode diff --git a/rllib/algorithms/sac/sac_tf_policy.py b/rllib/algorithms/sac/sac_tf_policy.py index 3a3072986e159..5ec142ec0a0d9 100644 --- a/rllib/algorithms/sac/sac_tf_policy.py +++ b/rllib/algorithms/sac/sac_tf_policy.py @@ -17,7 +17,6 @@ ) from ray.rllib.algorithms.sac.sac_tf_model import SACTFModel from ray.rllib.algorithms.sac.sac_torch_model import SACTorchModel -from ray.rllib.evaluation.episode import Episode from ray.rllib.models import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import ( @@ -123,7 +122,7 @@ def postprocess_trajectory( policy: Policy, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, - episode: Optional[Episode] = None, + episode=None, ) -> SampleBatch: """Postprocesses a trajectory and returns the processed trajectory. 
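Across the policy files above (APPO, IMPALA, MARWIL, SAC), `postprocess_trajectory` keeps its `episode` argument but drops the `Optional[Episode]` annotation, since `ray.rllib.evaluation.episode.Episode` no longer exists. A minimal sketch of an old-API-stack postprocessing function after this change; only the signature shape comes from this diff, the GAE body is an illustrative assumption:

from typing import Dict, Optional

from ray.rllib.evaluation.postprocessing import compute_gae_for_sample_batch
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.typing import AgentID


def postprocess_trajectory(
    policy: Policy,
    sample_batch: SampleBatch,
    other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None,
    episode=None,  # formerly Optional["Episode"]; kept only for call-compatibility
) -> SampleBatch:
    # Illustrative body (an assumption, not taken from this diff): run standard
    # GAE postprocessing and simply forward `episode`, which may now be None or
    # an EpisodeV2 instance.
    return compute_gae_for_sample_batch(
        policy, sample_batch, other_agent_batches, episode
    )
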
diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index 11d55a741be32..b88f16636698c 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -177,7 +177,6 @@ def test_rl_module_api(self): ) .environment("CartPole-v1") .framework("torch") - .env_runners(enable_connectors=True) ) self.assertEqual(config.rl_module_spec.module_class, PPOTorchRLModule) @@ -239,7 +238,6 @@ def test_learner_api(self): enable_env_runner_and_connector_v2=True, ) .environment("CartPole-v1") - .env_runners(enable_connectors=True) ) self.assertEqual(config.learner_class, PPOTorchLearner) diff --git a/rllib/algorithms/tests/test_callbacks_old_api_stack.py b/rllib/algorithms/tests/test_callbacks_old_api_stack.py index da3888f756ca7..0d72cd7abceb6 100644 --- a/rllib/algorithms/tests/test_callbacks_old_api_stack.py +++ b/rllib/algorithms/tests/test_callbacks_old_api_stack.py @@ -4,7 +4,6 @@ import ray from ray.rllib.algorithms.callbacks import DefaultCallbacks, make_multi_callbacks from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.evaluation.episode import Episode from ray.rllib.examples.envs.classes.random_env import RandomEnv @@ -55,11 +54,7 @@ def on_episode_created( # Make sure the passed in episode is really brand new. assert episode.env_id == env_index - if isinstance(episode, Episode): - assert episode.length == 0 - assert episode.started is False - else: - assert episode.length == -1 + assert episode.length == -1 assert episode.worker is worker diff --git a/rllib/env/env_runner_group.py b/rllib/env/env_runner_group.py index f7697bad2beed..b6395c1c4a5f7 100644 --- a/rllib/env/env_runner_group.py +++ b/rllib/env/env_runner_group.py @@ -111,8 +111,8 @@ def __init__( """ if num_workers != DEPRECATED_VALUE or local_worker != DEPRECATED_VALUE: deprecation_warning( - old="WorkerSet(num_workers=... OR local_worker=...)", - new="EnvRunnerGroup(num_env_runners=... AND local_env_runner=...)", + old="WorkerSet(num_workers=..., local_worker=...)", + new="EnvRunnerGroup(num_env_runners=..., local_env_runner=...)", error=True, ) @@ -862,7 +862,7 @@ def foreach_worker( synchronous execution). return_obj_refs: whether to return ObjectRef instead of actual results. Note, for fault tolerance reasons, these returned ObjectRefs should - never be resolved with ray.get() outside of this WorkerSet. + never be resolved with ray.get() outside of this EnvRunnerGroup. mark_healthy: Whether to mark all those workers healthy again that are currently marked unhealthy AND that returned results from the remote call (within the given `timeout_seconds`). @@ -936,7 +936,7 @@ def foreach_worker_with_id( timeout_seconds: Time to wait for results. Default is None. return_obj_refs: whether to return ObjectRef instead of actual results. Note, for fault tolerance reasons, these returned ObjectRefs should - never be resolved with ray.get() outside of this WorkerSet. + never be resolved with ray.get() outside of this EnvRunnerGroup. mark_healthy: Whether to mark all those workers healthy again that are currently marked unhealthy AND that returned results from the remote call (within the given `timeout_seconds`). 
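For user configs, the net effect of the `algorithm_config.py` and test changes above: `enable_connectors` is now a deprecated no-op (connector-based sampling is always on), and the old worker-count kwargs now raise instead of being silently forwarded. A short before/after sketch; PPO and CartPole are chosen only for illustration:

from ray.rllib.algorithms.ppo import PPOConfig

# Before this change, both of these kwargs were accepted:
# config = PPOConfig().env_runners(num_rollout_workers=2, enable_connectors=True)

# After: `num_rollout_workers` raises a deprecation error (error=True), and
# `enable_connectors` only logs a deprecation warning and is otherwise ignored.
config = (
    PPOConfig()
    .environment("CartPole-v1")
    .env_runners(num_env_runners=2)
)
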
diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py index 9febd9cc05d60..98caf12c57fa1 100644 --- a/rllib/env/tests/test_multi_agent_env.py +++ b/rllib/env/tests/test_multi_agent_env.py @@ -701,7 +701,7 @@ def test_multi_agent_with_sometimes_zero_agents_observing(self): config = ( PPOConfig() .environment("sometimes_zero_agents") - .env_runners(num_env_runners=0, enable_connectors=True) + .env_runners(num_env_runners=0) ) algo = config.build() for i in range(4): diff --git a/rllib/evaluation/__init__.py b/rllib/evaluation/__init__.py index 50309f30f3bf1..08f5bd48be3db 100644 --- a/rllib/evaluation/__init__.py +++ b/rllib/evaluation/__init__.py @@ -1,4 +1,3 @@ -from ray.rllib.evaluation.episode import Episode from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.sample_batch_builder import ( SampleBatchBuilder, @@ -18,5 +17,4 @@ "SyncSampler", "compute_advantages", "collect_metrics", - "Episode", ] diff --git a/rllib/evaluation/collectors/sample_collector.py b/rllib/evaluation/collectors/sample_collector.py index 867b7ec1ae232..3e977bee8b0fd 100644 --- a/rllib/evaluation/collectors/sample_collector.py +++ b/rllib/evaluation/collectors/sample_collector.py @@ -2,7 +2,6 @@ from abc import ABCMeta, abstractmethod from typing import TYPE_CHECKING, Dict, List, Optional, Union -from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils.annotations import OldAPIStack @@ -61,7 +60,7 @@ def __init__(self, def add_init_obs( self, *, - episode: Episode, + episode, agent_id: AgentID, policy_id: PolicyID, init_obs: TensorType, @@ -160,7 +159,7 @@ def add_action_reward_next_obs( raise NotImplementedError @abstractmethod - def episode_step(self, episode: Episode) -> None: + def episode_step(self, episode) -> None: """Increases the episode step counter (across all agents) by one. Args: @@ -240,7 +239,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> \ @abstractmethod def postprocess_episode( self, - episode: Episode, + episode, is_done: bool = False, check_dones: bool = False, build: bool = False, diff --git a/rllib/evaluation/collectors/simple_list_collector.py b/rllib/evaluation/collectors/simple_list_collector.py index 68dc3f638657e..45c07baeb6077 100644 --- a/rllib/evaluation/collectors/simple_list_collector.py +++ b/rllib/evaluation/collectors/simple_list_collector.py @@ -8,7 +8,6 @@ from ray.rllib.env.base_env import _DUMMY_AGENT_ID from ray.rllib.evaluation.collectors.sample_collector import SampleCollector from ray.rllib.evaluation.collectors.agent_collector import AgentCollector -from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.policy import Policy from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, concat_samples @@ -166,10 +165,10 @@ def __init__( # episode. self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int) # Maps episode ID to Episode. 
- self.episodes: Dict[EpisodeID, Episode] = {} + self.episodes = {} @override(SampleCollector) - def episode_step(self, episode: Episode) -> None: + def episode_step(self, episode) -> None: episode_id = episode.episode_id # In the rase case that an "empty" step is taken at the beginning of # the episode (none of the agents has an observation in the obs-dict @@ -219,7 +218,7 @@ def episode_step(self, episode: Episode) -> None: def add_init_obs( self, *, - episode: Episode, + episode, agent_id: AgentID, env_id: EnvID, policy_id: PolicyID, @@ -419,7 +418,7 @@ def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType] @override(SampleCollector) def postprocess_episode( self, - episode: Episode, + episode, is_done: bool = False, check_dones: bool = False, build: bool = False, @@ -588,9 +587,7 @@ def postprocess_episode( if build: return self._build_multi_agent_batch(episode) - def _build_multi_agent_batch( - self, episode: Episode - ) -> Union[MultiAgentBatch, SampleBatch]: + def _build_multi_agent_batch(self, episode) -> Union[MultiAgentBatch, SampleBatch]: ma_batch = {} for pid, collector in episode.batch_builder.policy_collectors.items(): diff --git a/rllib/evaluation/episode.py b/rllib/evaluation/episode.py deleted file mode 100644 index bb07537d8fac3..0000000000000 --- a/rllib/evaluation/episode.py +++ /dev/null @@ -1,451 +0,0 @@ -import random -from collections import defaultdict -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple - -import numpy as np -import tree # pip install dm_tree - -from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.policy.policy_map import PolicyMap -from ray.rllib.utils.annotations import OldAPIStack -from ray.rllib.utils.deprecation import deprecation_warning -from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray -from ray.rllib.utils.typing import ( - AgentID, - EnvActionType, - EnvID, - EnvInfoDict, - EnvObsType, - PolicyID, - SampleBatchType, -) -from ray.util import log_once - -if TYPE_CHECKING: - from ray.rllib.evaluation.rollout_worker import RolloutWorker - from ray.rllib.evaluation.sample_batch_builder import MultiAgentSampleBatchBuilder - - -@OldAPIStack -class Episode: - """Tracks the current state of a (possibly multi-agent) episode. - - Attributes: - new_batch_builder: Create a new MultiAgentSampleBatchBuilder. - add_extra_batch: Return a built MultiAgentBatch to the sampler. - batch_builder: Batch builder for the current episode. - total_reward: Summed reward across all agents in this episode. - length: Length of this episode. - episode_id: Unique id identifying this trajectory. - agent_rewards: Summed rewards broken down by agent. - custom_metrics: Dict where the you can add custom metrics. - user_data: Dict that you can use for temporary storage. E.g. - in between two custom callbacks referring to the same episode. - hist_data: Dict mapping str keys to List[float] for storage of - per-timestep float data throughout the episode. - - Use case 1: Model-based rollouts in multi-agent: - A custom compute_actions() function in a policy can inspect the - current episode state and perform a number of rollouts based on the - policies and state of other agents in the environment. - - Use case 2: Returning extra rollouts data. - The model rollouts can be returned back to the sampler by calling: - - .. testcode:: - :skipif: True - - batch = episode.new_batch_builder() - for each transition: - batch.add_values(...) 
# see sampler for usage - episode.extra_batches.add(batch.build_and_reset()) - """ - - def __init__( - self, - policies: PolicyMap, - policy_mapping_fn: Callable[[AgentID, "Episode", "RolloutWorker"], PolicyID], - batch_builder_factory: Callable[[], "MultiAgentSampleBatchBuilder"], - extra_batch_callback: Callable[[SampleBatchType], None], - env_id: EnvID, - *, - worker: Optional["RolloutWorker"] = None, - ): - """Initializes an Episode instance. - - Args: - policies: The PolicyMap object (mapping PolicyIDs to Policy - objects) to use for determining, which policy is used for - which agent. - policy_mapping_fn: The mapping function mapping AgentIDs to - PolicyIDs. - batch_builder_factory: - extra_batch_callback: - env_id: The environment's ID in which this episode runs. - worker: The RolloutWorker instance, in which this episode runs. - """ - self.new_batch_builder: Callable[ - [], "MultiAgentSampleBatchBuilder" - ] = batch_builder_factory - self.add_extra_batch: Callable[[SampleBatchType], None] = extra_batch_callback - self.batch_builder: "MultiAgentSampleBatchBuilder" = batch_builder_factory() - self.total_reward: float = 0.0 - self.length: int = 0 - self.started = False - self.episode_id: int = random.randrange(int(1e18)) - self.env_id = env_id - self.worker = worker - self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = defaultdict(float) - self.custom_metrics: Dict[str, float] = {} - self.user_data: Dict[str, Any] = {} - self.hist_data: Dict[str, List[float]] = {} - self.media: Dict[str, Any] = {} - self.policy_map: PolicyMap = policies - self._policies = self.policy_map # backward compatibility - self.policy_mapping_fn: Callable[ - [AgentID, "Episode", "RolloutWorker"], PolicyID - ] = policy_mapping_fn - self.is_faulty = False - self._next_agent_index: int = 0 - self._agent_to_index: Dict[AgentID, int] = {} - self._agent_to_policy: Dict[AgentID, PolicyID] = {} - self._agent_to_rnn_state: Dict[AgentID, List[Any]] = {} - self._agent_to_last_obs: Dict[AgentID, EnvObsType] = {} - self._agent_to_last_raw_obs: Dict[AgentID, EnvObsType] = {} - self._agent_to_last_terminated: Dict[AgentID, bool] = {} - self._agent_to_last_truncated: Dict[AgentID, bool] = {} - self._agent_to_last_info: Dict[AgentID, EnvInfoDict] = {} - self._agent_to_last_action: Dict[AgentID, EnvActionType] = {} - self._agent_to_last_extra_action_outs: Dict[AgentID, dict] = {} - self._agent_to_prev_action: Dict[AgentID, EnvActionType] = {} - self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(list) - - def policy_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> PolicyID: - """Returns and stores the policy ID for the specified agent. - - If the agent is new, the policy mapping fn will be called to bind the - agent to a policy for the duration of the entire episode (even if the - policy_mapping_fn is changed in the meantime!). - - Args: - agent_id: The agent ID to lookup the policy ID for. - - Returns: - The policy ID for the specified agent. - """ - - # Perform a new policy_mapping_fn lookup and bind AgentID for the - # duration of this episode to the returned PolicyID. - if agent_id not in self._agent_to_policy: - # Try new API: pass in agent_id and episode as named args. 
- # New signature should be: (agent_id, episode, worker, **kwargs) - try: - policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn( - agent_id, self, worker=self.worker - ) - except TypeError as e: - if ( - "positional argument" in e.args[0] - or "unexpected keyword argument" in e.args[0] - ): - if log_once("policy_mapping_new_signature"): - deprecation_warning( - old="policy_mapping_fn(agent_id)", - new="policy_mapping_fn(agent_id, episode, " - "worker, **kwargs)", - ) - policy_id = self._agent_to_policy[ - agent_id - ] = self.policy_mapping_fn(agent_id) - else: - raise e - # Use already determined PolicyID. - else: - policy_id = self._agent_to_policy[agent_id] - - # PolicyID not found in policy map -> Error. - if policy_id not in self.policy_map: - raise KeyError( - "policy_mapping_fn returned invalid policy id " f"'{policy_id}'!" - ) - return policy_id - - def last_observation_for( - self, agent_id: AgentID = _DUMMY_AGENT_ID - ) -> Optional[EnvObsType]: - """Returns the last observation for the specified AgentID. - - Args: - agent_id: The agent's ID to get the last observation for. - - Returns: - Last observation the specified AgentID has seen. None in case - the agent has never made any observations in the episode. - """ - - return self._agent_to_last_obs.get(agent_id) - - def last_raw_obs_for( - self, agent_id: AgentID = _DUMMY_AGENT_ID - ) -> Optional[EnvObsType]: - """Returns the last un-preprocessed obs for the specified AgentID. - - Args: - agent_id: The agent's ID to get the last un-preprocessed - observation for. - - Returns: - Last un-preprocessed observation the specified AgentID has seen. - None in case the agent has never made any observations in the - episode. - """ - return self._agent_to_last_raw_obs.get(agent_id) - - def last_info_for( - self, agent_id: AgentID = _DUMMY_AGENT_ID - ) -> Optional[EnvInfoDict]: - """Returns the last info for the specified AgentID. - - Args: - agent_id: The agent's ID to get the last info for. - - Returns: - Last info dict the specified AgentID has seen. - None in case the agent has never made any observations in the - episode. - """ - return self._agent_to_last_info.get(agent_id) - - def last_action_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType: - """Returns the last action for the specified AgentID, or zeros. - - The "last" action is the most recent one taken by the agent. - - Args: - agent_id: The agent's ID to get the last action for. - - Returns: - Last action the specified AgentID has executed. - Zeros in case the agent has never performed any actions in the - episode. - """ - policy_id = self.policy_for(agent_id) - policy = self.policy_map[policy_id] - - # Agent has already taken at least one action in the episode. - if agent_id in self._agent_to_last_action: - if policy.config.get("_disable_action_flattening"): - return self._agent_to_last_action[agent_id] - else: - return flatten_to_single_ndarray(self._agent_to_last_action[agent_id]) - # Agent has not acted yet, return all zeros. 
- else: - if policy.config.get("_disable_action_flattening"): - return tree.map_structure( - lambda s: np.zeros_like(s.sample(), s.dtype) - if hasattr(s, "dtype") - else np.zeros_like(s.sample()), - policy.action_space_struct, - ) - else: - flat = flatten_to_single_ndarray(policy.action_space.sample()) - if hasattr(policy.action_space, "dtype"): - return np.zeros_like(flat, dtype=policy.action_space.dtype) - return np.zeros_like(flat) - - def prev_action_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> EnvActionType: - """Returns the previous action for the specified agent, or zeros. - - The "previous" action is the one taken one timestep before the - most recent action taken by the agent. - - Args: - agent_id: The agent's ID to get the previous action for. - - Returns: - Previous action the specified AgentID has executed. - Zero in case the agent has never performed any actions (or only - one) in the episode. - """ - policy_id = self.policy_for(agent_id) - policy = self.policy_map[policy_id] - - # We are at t > 1 -> There has been a previous action by this agent. - if agent_id in self._agent_to_prev_action: - if policy.config.get("_disable_action_flattening"): - return self._agent_to_prev_action[agent_id] - else: - return flatten_to_single_ndarray(self._agent_to_prev_action[agent_id]) - # We're at t <= 1, so return all zeros. - else: - if policy.config.get("_disable_action_flattening"): - return tree.map_structure( - lambda a: np.zeros_like(a, a.dtype) - if hasattr(a, "dtype") # noqa - else np.zeros_like(a), # noqa - self.last_action_for(agent_id), - ) - else: - return np.zeros_like(self.last_action_for(agent_id)) - - def last_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float: - """Returns the last reward for the specified agent, or zero. - - The "last" reward is the one received most recently by the agent. - - Args: - agent_id: The agent's ID to get the last reward for. - - Returns: - Last reward for the the specified AgentID. - Zero in case the agent has never performed any actions - (and thus received rewards) in the episode. - """ - - history = self._agent_reward_history[agent_id] - # We are at t > 0 -> Return previously received reward. - if len(history) >= 1: - return history[-1] - # We're at t=0, so there is no previous reward, just return zero. - else: - return 0.0 - - def prev_reward_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> float: - """Returns the previous reward for the specified agent, or zero. - - The "previous" reward is the one received one timestep before the - most recently received reward of the agent. - - Args: - agent_id: The agent's ID to get the previous reward for. - - Returns: - Previous reward for the the specified AgentID. - Zero in case the agent has never performed any actions (or only - one) in the episode. - """ - - history = self._agent_reward_history[agent_id] - # We are at t > 1 -> Return reward prior to most recent (last) one. - if len(history) >= 2: - return history[-2] - # We're at t <= 1, so there is no previous reward, just return zero. - else: - return 0.0 - - def rnn_state_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> List[Any]: - """Returns the last RNN state for the specified agent. - - Args: - agent_id: The agent's ID to get the most recent RNN state for. - - Returns: - Most recent RNN state of the the specified AgentID. 
- """ - - if agent_id not in self._agent_to_rnn_state: - policy_id = self.policy_for(agent_id) - policy = self.policy_map[policy_id] - self._agent_to_rnn_state[agent_id] = policy.get_initial_state() - return self._agent_to_rnn_state[agent_id] - - def last_terminated_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool: - """Returns the last `terminated` flag for the specified AgentID. - - Args: - agent_id: The agent's ID to get the last `terminated` flag for. - - Returns: - Last terminated flag for the specified AgentID. - """ - if agent_id not in self._agent_to_last_terminated: - self._agent_to_last_terminated[agent_id] = False - return self._agent_to_last_terminated[agent_id] - - def last_truncated_for(self, agent_id: AgentID = _DUMMY_AGENT_ID) -> bool: - """Returns the last `truncated` flag for the specified AgentID. - - Args: - agent_id: The agent's ID to get the last `truncated` flag for. - - Returns: - Last truncated flag for the specified AgentID. - """ - if agent_id not in self._agent_to_last_truncated: - self._agent_to_last_truncated[agent_id] = False - return self._agent_to_last_truncated[agent_id] - - def last_extra_action_outs_for( - self, - agent_id: AgentID = _DUMMY_AGENT_ID, - ) -> dict: - """Returns the last extra-action outputs for the specified agent. - - This data is returned by a call to - `Policy.compute_actions_from_input_dict` as the 3rd return value - (1st return value = action; 2nd return value = RNN state outs). - - Args: - agent_id: The agent's ID to get the last extra-action outs for. - - Returns: - The last extra-action outs for the specified AgentID. - """ - return self._agent_to_last_extra_action_outs[agent_id] - - def get_agents(self) -> List[AgentID]: - """Returns list of agent IDs that have appeared in this episode. - - Returns: - The list of all agent IDs that have appeared so far in this - episode. 
- """ - return list(self._agent_to_index.keys()) - - def _add_agent_rewards(self, reward_dict: Dict[AgentID, float]) -> None: - for agent_id, reward in reward_dict.items(): - if reward is not None: - self.agent_rewards[agent_id, self.policy_for(agent_id)] += reward - self.total_reward += reward - self._agent_reward_history[agent_id].append(reward) - - def _set_rnn_state(self, agent_id, rnn_state): - self._agent_to_rnn_state[agent_id] = rnn_state - - def _set_last_observation(self, agent_id, obs): - self._agent_to_last_obs[agent_id] = obs - - def _set_last_raw_obs(self, agent_id, obs): - self._agent_to_last_raw_obs[agent_id] = obs - - def _set_last_terminated(self, agent_id, terminated): - self._agent_to_last_terminated[agent_id] = terminated - - def _set_last_truncated(self, agent_id, truncated): - self._agent_to_last_truncated[agent_id] = truncated - - def _set_last_info(self, agent_id, info): - self._agent_to_last_info[agent_id] = info - - def _set_last_action(self, agent_id, action): - if agent_id in self._agent_to_last_action: - self._agent_to_prev_action[agent_id] = self._agent_to_last_action[agent_id] - self._agent_to_last_action[agent_id] = action - - def _set_last_extra_action_outs(self, agent_id, pi_info): - self._agent_to_last_extra_action_outs[agent_id] = pi_info - - def _agent_index(self, agent_id): - if agent_id not in self._agent_to_index: - self._agent_to_index[agent_id] = self._next_agent_index - self._next_agent_index += 1 - return self._agent_to_index[agent_id] - - @property - def _policy_mapping_fn(self): - deprecation_warning( - old="Episode._policy_mapping_fn", - new="Episode.policy_mapping_fn", - error=True, - ) - return self.policy_mapping_fn diff --git a/rllib/evaluation/observation_function.py b/rllib/evaluation/observation_function.py index fe33b3d60666f..c670ed5192cfe 100644 --- a/rllib/evaluation/observation_function.py +++ b/rllib/evaluation/observation_function.py @@ -2,7 +2,7 @@ from ray.rllib.env import BaseEnv from ray.rllib.policy import Policy -from ray.rllib.evaluation import Episode, RolloutWorker +from ray.rllib.evaluation import RolloutWorker from ray.rllib.utils.annotations import OldAPIStack from ray.rllib.utils.framework import TensorType from ray.rllib.utils.typing import AgentID, PolicyID @@ -28,7 +28,7 @@ def __call__( worker: RolloutWorker, base_env: BaseEnv, policies: Dict[PolicyID, Policy], - episode: Episode, + episode, **kw ) -> Dict[AgentID, TensorType]: """Callback run on each environment step to observe the environment. diff --git a/rllib/evaluation/postprocessing.py b/rllib/evaluation/postprocessing.py index 82ef79e70facc..4b0a6c79bd602 100644 --- a/rllib/evaluation/postprocessing.py +++ b/rllib/evaluation/postprocessing.py @@ -2,7 +2,6 @@ import scipy.signal from typing import Dict, Optional -from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils.annotations import DeveloperAPI, OldAPIStack @@ -157,7 +156,7 @@ def compute_gae_for_sample_batch( policy: Policy, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, - episode: Optional[Episode] = None, + episode=None, ) -> SampleBatch: """Adds GAE (generalized advantage estimations) to a trajectory. 
diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 07678502923c7..b2a803389eee0 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -20,9 +20,7 @@ Union, ) -import numpy as np -import tree # pip install dm_tree -from gymnasium.spaces import Discrete, MultiDiscrete, Space +from gymnasium.spaces import Space import ray from ray import ObjectRef @@ -70,7 +68,7 @@ from ray.rllib.utils.annotations import OldAPIStack, override from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG -from ray.rllib.utils.filter import Filter, NoFilter, get_filter +from ray.rllib.utils.filter import Filter, NoFilter from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.utils.from_config import from_config from ray.rllib.utils.policy import create_policy_for_framework @@ -99,7 +97,6 @@ if TYPE_CHECKING: from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.callbacks import DefaultCallbacks # noqa - from ray.rllib.evaluation.episode import Episode tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -1089,7 +1086,7 @@ def add_policy( action_space: Optional[Space] = None, config: Optional[PartialAlgorithmConfigDict] = None, policy_state: Optional[PolicyState] = None, - policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None, + policy_mapping_fn=None, policies_to_train: Optional[ Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]] ] = None, @@ -1221,7 +1218,7 @@ def remove_policy( def set_policy_mapping_fn( self, - policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None, + policy_mapping_fn: Optional[Callable[[AgentID, Any], PolicyID]] = None, ) -> None: """Sets `self.policy_mapping_fn` to a new callable (if provided). @@ -1425,8 +1422,6 @@ def set_state(self, state: dict) -> None: # key in `state` entirely (will be part of the policies then). self.sync_filters(state["filters"]) - connector_enabled = self.config.enable_connectors - # Support older checkpoint versions (< 1.0), in which the policy_map # was stored under the "state" key, not "policy_states". policy_states = ( @@ -1448,9 +1443,7 @@ def set_state(self, state: dict) -> None: ) else: policy_spec = ( - PolicySpec.deserialize(spec) - if connector_enabled or isinstance(spec, dict) - else spec + PolicySpec.deserialize(spec) if isinstance(spec, dict) else spec ) self.add_policy( policy_id=pid, @@ -1795,11 +1788,6 @@ def _get_complete_policy_specs_dict( if preprocessor is not None: obs_space = preprocessor.observation_space - if not merged_conf.enable_connectors: - # If connectors are not enabled, rollout worker will handle - # the running of these preprocessors. - self.preprocessors[name] = preprocessor - policy_spec.config = merged_conf policy_spec.observation_space = obs_space @@ -1865,37 +1853,22 @@ def _update_filter_dict(self, policy_dict: MultiAgentPolicyConfigDict) -> None: for name, policy_spec in sorted(policy_dict.items()): new_policy = self.policy_map[name] - if policy_spec.config.enable_connectors: - # Note(jungong) : We should only create new connectors for the - # policy iff we are creating a new policy from scratch. i.e, - # we should NOT create new connectors when we already have the - # policy object created before this function call or have the - # restoring states from the caller. 
- # Also note that we cannot just check the existence of connectors - # to decide whether we should create connectors because we may be - # restoring a policy that has 0 connectors configured. - if ( - new_policy.agent_connectors is None - or new_policy.action_connectors is None - ): - # TODO(jungong) : revisit this. It will be nicer to create - # connectors as the last step of Policy.__init__(). - create_connectors_for_policy(new_policy, policy_spec.config) - maybe_get_filters_for_syncing(self, name) - else: - filter_shape = tree.map_structure( - lambda s: ( - None - if isinstance(s, (Discrete, MultiDiscrete)) # noqa - else np.array(s.shape) - ), - new_policy.observation_space_struct, - ) - - self.filters[name] = get_filter( - policy_spec.config.observation_filter, - filter_shape, - ) + # Note(jungong) : We should only create new connectors for the + # policy iff we are creating a new policy from scratch. i.e, + # we should NOT create new connectors when we already have the + # policy object created before this function call or have the + # restoring states from the caller. + # Also note that we cannot just check the existence of connectors + # to decide whether we should create connectors because we may be + # restoring a policy that has 0 connectors configured. + if ( + new_policy.agent_connectors is None + or new_policy.action_connectors is None + ): + # TODO(jungong) : revisit this. It will be nicer to create + # connectors as the last step of Policy.__init__(). + create_connectors_for_policy(new_policy, policy_spec.config) + maybe_get_filters_for_syncing(self, name) def _call_callbacks_on_create_policy(self): """Calls the on_create_policy callback for each policy in the policy map.""" diff --git a/rllib/evaluation/sample_batch_builder.py b/rllib/evaluation/sample_batch_builder.py index 7b7f284747729..6baaa0611ee3e 100644 --- a/rllib/evaluation/sample_batch_builder.py +++ b/rllib/evaluation/sample_batch_builder.py @@ -1,10 +1,9 @@ import collections import logging import numpy as np -from typing import List, Any, Dict, Optional, TYPE_CHECKING +from typing import List, Any, Dict, TYPE_CHECKING from ray.rllib.env.base_env import _DUMMY_AGENT_ID -from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.policy import Policy from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch from ray.rllib.utils.annotations import OldAPIStack @@ -147,7 +146,7 @@ def add_values(self, agent_id: AgentID, policy_id: AgentID, **values: Any) -> No self.agent_builders[agent_id].add_values(**values) - def postprocess_batch_so_far(self, episode: Optional[Episode] = None) -> None: + def postprocess_batch_so_far(self, episode=None) -> None: """Apply policy postprocessors to any unprocessed rows. This pushes the postprocessed per-agent batches onto the per-policy @@ -240,7 +239,7 @@ def check_missing_dones(self) -> None: "to True. " ) - def build_and_reset(self, episode: Optional[Episode] = None) -> MultiAgentBatch: + def build_and_reset(self, episode=None) -> MultiAgentBatch: """Returns the accumulated sample batches for each policy. 
Any unprocessed rows will be first postprocessed with a policy diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 9e34fd237ee03..dadc65451cce3 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -1,62 +1,30 @@ import logging import queue -import time from abc import ABCMeta, abstractmethod from collections import defaultdict, namedtuple from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterator, List, Optional, - Set, - Tuple, Type, Union, ) -import numpy as np -import tree # pip install dm_tree - -from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv, convert_to_base_env +from ray.rllib.env.base_env import BaseEnv, convert_to_base_env from ray.rllib.evaluation.collectors.sample_collector import SampleCollector from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector -from ray.rllib.evaluation.env_runner_v2 import ( - EnvRunnerV2, - _fetch_atari_metrics, - _get_or_raise, - _PerfStats, -) -from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.env_runner_v2 import EnvRunnerV2, _PerfStats from ray.rllib.evaluation.metrics import RolloutMetrics from ray.rllib.offline import InputReader -from ray.rllib.policy.policy import Policy -from ray.rllib.policy.policy_map import PolicyMap -from ray.rllib.policy.sample_batch import SampleBatch, concat_samples +from ray.rllib.policy.sample_batch import concat_samples from ray.rllib.utils.annotations import OldAPIStack, override -from ray.rllib.utils.debug import summarize from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.numpy import convert_to_numpy, make_action_immutable -from ray.rllib.utils.spaces.space_utils import clip_action, unbatch, unsquash_action -from ray.rllib.utils.typing import ( - AgentID, - EnvActionType, - EnvID, - EnvInfoDict, - EnvObsType, - MultiEnvDict, - PolicyID, - SampleBatchType, - TensorStructType, -) +from ray.rllib.utils.typing import SampleBatchType from ray.util.debug import log_once if TYPE_CHECKING: - from gymnasium.envs.classic_control.rendering import SimpleImageViewer - from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.evaluation.observation_function import ObservationFunction from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -236,35 +204,19 @@ def __init__( ) self.render = render - if worker.config.enable_connectors: - # Keep a reference to the underlying EnvRunnerV2 instance for - # unit testing purpose. - self._env_runner_obj = EnvRunnerV2( - worker=worker, - base_env=self.base_env, - multiple_episodes_in_batch=multiple_episodes_in_batch, - callbacks=callbacks, - perf_stats=self.perf_stats, - rollout_fragment_length=rollout_fragment_length, - count_steps_by=count_steps_by, - render=self.render, - ) - self._env_runner = self._env_runner_obj.run() - else: - # Create the rollout generator to use for calls to `get_data()`. - self._env_runner = _env_runner( - worker, - self.base_env, - self.extra_batches.put, - normalize_actions, - clip_actions, - multiple_episodes_in_batch, - callbacks, - self.perf_stats, - observation_fn, - self.sample_collector, - self.render, - ) + # Keep a reference to the underlying EnvRunnerV2 instance for + # unit testing purpose. 
+ self._env_runner_obj = EnvRunnerV2( + worker=worker, + base_env=self.base_env, + multiple_episodes_in_batch=multiple_episodes_in_batch, + callbacks=callbacks, + perf_stats=self.perf_stats, + rollout_fragment_length=rollout_fragment_length, + count_steps_by=count_steps_by, + render=self.render, + ) + self._env_runner = self._env_runner_obj.run() self.metrics_queue = queue.Queue() @override(SamplerInput) @@ -299,809 +251,3 @@ def get_extra_batches(self) -> List[SampleBatchType]: except queue.Empty: break return extra - - -@OldAPIStack -def _env_runner( - worker: "RolloutWorker", - base_env: BaseEnv, - extra_batch_callback: Callable[[SampleBatchType], None], - normalize_actions: bool, - clip_actions: bool, - multiple_episodes_in_batch: bool, - callbacks: "DefaultCallbacks", - perf_stats: _PerfStats, - observation_fn: "ObservationFunction", - sample_collector: Optional[SampleCollector] = None, - render: bool = None, -) -> Iterator[SampleBatchType]: - """This implements the common experience collection logic. - - Args: - worker: Reference to the current rollout worker. - base_env: Env implementing BaseEnv. - extra_batch_callback: function to send extra batch data to. - multiple_episodes_in_batch: Whether to pack multiple - episodes into each batch. This guarantees batches will be exactly - `rollout_fragment_length` in size. - normalize_actions: Whether to normalize actions to the action - space's bounds. - clip_actions: Whether to clip actions to the space range. - callbacks: User callbacks to run on episode events. - perf_stats: Record perf stats into this object. - observation_fn: Optional multi-agent - observation func to use for preprocessing observations. - sample_collector: An optional - SampleCollector object to use. - render: Whether to try to render the environment after each - step. - - Yields: - Object containing state, action, reward, terminal condition, - and other fields as dictated by `policy`. - """ - - # May be populated with used for image rendering - simple_image_viewer: Optional["SimpleImageViewer"] = None - - def _new_episode(env_id): - episode = Episode( - worker.policy_map, - worker.policy_mapping_fn, - # SimpleListCollector will find or create a - # simple_list_collector._PolicyCollector as batch_builder - # for this episode later. Here we simply provide a None factory. - lambda: None, # batch_builder_factory - extra_batch_callback, - env_id=env_id, - worker=worker, - ) - return episode - - active_episodes: Dict[EnvID, Episode] = _NewEpisodeDefaultDict(_new_episode) - - # Before the very first poll (this will reset all vector sub-environments): - # Call custom `before_sub_environment_reset` callbacks for all sub-environments. - for env_id, sub_env in base_env.get_sub_environments(as_dict=True).items(): - _create_episode(active_episodes, env_id, callbacks, worker, base_env) - - while True: - perf_stats.incr("iters", 1) - - t0 = time.time() - # Get observations from all ready agents. - # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ... - ( - unfiltered_obs, - rewards, - terminateds, - truncateds, - infos, - off_policy_actions, - ) = base_env.poll() - env_poll_time = time.time() - t0 - - if log_once("env_returns"): - logger.info("Raw obs from env: {}".format(summarize(unfiltered_obs))) - logger.info("Info return from env: {}".format(summarize(infos))) - - # Process observations and prepare for policy evaluation. 
- t1 = time.time() - # types: Set[EnvID], Dict[PolicyID, List[_PolicyEvalData]], - # List[Union[RolloutMetrics, SampleBatchType]] - active_envs, to_eval, outputs = _process_observations( - worker=worker, - base_env=base_env, - active_episodes=active_episodes, - unfiltered_obs=unfiltered_obs, - rewards=rewards, - terminateds=terminateds, - truncateds=truncateds, - infos=infos, - multiple_episodes_in_batch=multiple_episodes_in_batch, - callbacks=callbacks, - observation_fn=observation_fn, - sample_collector=sample_collector, - ) - perf_stats.incr("raw_obs_processing_time", time.time() - t1) - for o in outputs: - yield o - - # Do batched policy eval (accross vectorized envs). - t2 = time.time() - # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]] - eval_results = _do_policy_eval( - to_eval=to_eval, - policies=worker.policy_map, - sample_collector=sample_collector, - active_episodes=active_episodes, - ) - perf_stats.incr("inference_time", time.time() - t2) - - # Process results and update episode state. - t3 = time.time() - actions_to_send: Dict[ - EnvID, Dict[AgentID, EnvActionType] - ] = _process_policy_eval_results( - to_eval=to_eval, - eval_results=eval_results, - active_episodes=active_episodes, - active_envs=active_envs, - off_policy_actions=off_policy_actions, - policies=worker.policy_map, - normalize_actions=normalize_actions, - clip_actions=clip_actions, - ) - perf_stats.incr("action_processing_time", time.time() - t3) - - # Return computed actions to ready envs. We also send to envs that have - # taken off-policy actions; those envs are free to ignore the action. - t4 = time.time() - base_env.send_actions(actions_to_send) - perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4) - - # Try to render the env, if required. - if render: - t5 = time.time() - # Render can either return an RGB image (uint8 [w x h x 3] numpy - # array) or take care of rendering itself (returning True). - rendered = base_env.try_render() - # Rendering returned an image -> Display it in a SimpleImageViewer. - if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3: - # ImageViewer not defined yet, try to create one. - if simple_image_viewer is None: - try: - from gymnasium.envs.classic_control.rendering import ( - SimpleImageViewer, - ) - - simple_image_viewer = SimpleImageViewer() - except (ImportError, ModuleNotFoundError): - render = False # disable rendering - logger.warning( - "Could not import gymnasium.envs.classic_control." - "rendering! Try `pip install gymnasium[all]`." - ) - if simple_image_viewer: - simple_image_viewer.imshow(rendered) - elif rendered not in [True, False, None]: - raise ValueError( - f"The env's ({base_env}) `try_render()` method returned an" - " unsupported value! Make sure you either return a " - "uint8/w x h x 3 (RGB) image or handle rendering in a " - "window and then return `True`." 
- ) - perf_stats.incr("env_render_time", time.time() - t5) - - -@OldAPIStack -def _process_observations( - *, - worker: "RolloutWorker", - base_env: BaseEnv, - active_episodes: Dict[EnvID, Episode], - unfiltered_obs: Dict[EnvID, Dict[AgentID, EnvObsType]], - rewards: Dict[EnvID, Dict[AgentID, float]], - terminateds: Dict[EnvID, Dict[AgentID, bool]], - truncateds: Dict[EnvID, Dict[AgentID, bool]], - infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]], - multiple_episodes_in_batch: bool, - callbacks: "DefaultCallbacks", - observation_fn: "ObservationFunction", - sample_collector: SampleCollector, -) -> Tuple[ - Set[EnvID], - Dict[PolicyID, List[_PolicyEvalData]], - List[Union[RolloutMetrics, SampleBatchType]], -]: - """Record new data from the environment and prepare for policy evaluation. - - Args: - worker: Reference to the current rollout worker. - base_env: Env implementing BaseEnv. - active_episodes: Mapping from - episode ID to currently ongoing Episode object. - unfiltered_obs: Doubly keyed dict of env-ids -> agent ids - -> unfiltered observation tensor, returned by a `BaseEnv.poll()` - call. - rewards: Doubly keyed dict of env-ids -> agent ids -> - rewards tensor, returned by a `BaseEnv.poll()` call. - terminateds: Doubly keyed dict of env-ids -> agent ids -> - boolean `terminated` flags, returned by a `BaseEnv.poll()` call. - truncateds: Doubly keyed dict of env-ids -> agent ids -> - boolean `truncated` flags, returned by a `BaseEnv.poll()` call. - infos: Doubly keyed dict of env-ids -> agent ids -> - info dicts, returned by a `BaseEnv.poll()` call. - multiple_episodes_in_batch: Whether to pack multiple - episodes into each batch. This guarantees batches will be exactly - `rollout_fragment_length` in size. - callbacks: User callbacks to run on episode events. - observation_fn: Optional multi-agent - observation func to use for preprocessing observations. - sample_collector: The SampleCollector object - used to store and retrieve environment samples. - - Returns: - Tuple consisting of 1) active_envs: Set of non-terminated env ids. - 2) to_eval: Map of policy_id to list of agent _PolicyEvalData. - 3) outputs: List of metrics and samples to return from the sampler. - """ - - # Output objects. - active_envs: Set[EnvID] = set() - to_eval: Dict[PolicyID, List[_PolicyEvalData]] = defaultdict(list) - outputs: List[Union[RolloutMetrics, SampleBatchType]] = [] - - # For each (vectorized) sub-environment. - # types: EnvID, Dict[AgentID, EnvObsType] - for env_id, all_agents_obs in unfiltered_obs.items(): - episode: Episode = active_episodes[env_id] - - # Check for env_id having returned an error instead of a multi-agent obs dict. - # This is how our BaseEnv can tell the caller to `poll()` that one of its - # sub-environments is faulty and should be restarted (and the ongoing episode - # should not be used for training). - if isinstance(all_agents_obs, Exception): - episode.is_faulty = True - assert terminateds[env_id]["__all__"] is True, ( - f"ERROR: When a sub-environment (env-id {env_id}) returns an error as " - "observation, the terminateds[__all__] flag must also be set to True!" - ) - # This will be filled with dummy observations below. - all_agents_obs = {} - - # Add init obs and infos (from the call to `reset/try_reset`) to episode. 
- for aid, obs in all_agents_obs.items(): - episode._set_last_raw_obs(aid, obs) - common_infos = infos[env_id].get("__common__", {}) - episode._set_last_info("__common__", common_infos) - for aid, info in infos[env_id].items(): - episode._set_last_info(aid, info) - - # Episode is brand new. - if episode.started is False: - # Call the episode start callback(s). - _call_on_episode_start(episode, env_id, callbacks, worker, base_env) - else: - sample_collector.episode_step(episode) - episode._add_agent_rewards(rewards[env_id]) - - # Check episode termination conditions. - if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]: - all_agents_done = True - # Check whether we have to create a fake-last observation - # for some agents (the environment is not required to do so if - # terminateds[__all__]=True or truncateds[__all__]=True). - for ag_id in episode.get_agents(): - if ( - not episode.last_terminated_for(ag_id) - and not episode.last_truncated_for(ag_id) - and ag_id not in all_agents_obs - ): - # Create a fake (all-0s) observation. - obs_sp = worker.policy_map[ - episode.policy_for(ag_id) - ].observation_space - obs_sp = getattr(obs_sp, "original_space", obs_sp) - all_agents_obs[ag_id] = tree.map_structure( - np.zeros_like, obs_sp.sample() - ) - else: - all_agents_done = False - active_envs.add(env_id) - - # Custom observation function is applied before preprocessing. - if observation_fn: - all_agents_obs: Dict[AgentID, EnvObsType] = observation_fn( - agent_obs=all_agents_obs, - worker=worker, - base_env=base_env, - policies=worker.policy_map, - episode=episode, - ) - if not isinstance(all_agents_obs, dict): - raise ValueError("observe() must return a dict of agent observations") - - # For each agent in the environment. - # types: AgentID, EnvObsType - for agent_id, raw_obs in all_agents_obs.items(): - assert agent_id != "__all__" - - last_observation: EnvObsType = episode.last_observation_for(agent_id) - agent_terminated = bool( - terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id) - ) - agent_truncated = bool( - truncateds[env_id]["__all__"] or truncateds[env_id].get(agent_id, False) - ) - - # A new agent (initial obs) is already done -> Skip entirely. - if last_observation is None and (agent_terminated or agent_truncated): - continue - - policy_id: PolicyID = episode.policy_for(agent_id) - - preprocessor = _get_or_raise(worker.preprocessors, policy_id) - prep_obs: EnvObsType = raw_obs - if preprocessor is not None: - prep_obs = preprocessor.transform(raw_obs) - if log_once("prep_obs"): - logger.info("Preprocessed obs: {}".format(summarize(prep_obs))) - filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)( - prep_obs - ) - if log_once("filtered_obs"): - logger.info("Filtered obs: {}".format(summarize(filtered_obs))) - - episode._set_last_observation(agent_id, filtered_obs) - episode._set_last_terminated(agent_id, agent_terminated) - episode._set_last_truncated(agent_id, agent_truncated) - agent_infos = infos[env_id].get(agent_id, {}) - - # Record transition info if applicable. - if last_observation is None: - sample_collector.add_init_obs( - episode=episode, - agent_id=agent_id, - env_id=env_id, - policy_id=policy_id, - init_obs=filtered_obs, - init_infos=agent_infos, - t=episode.length - 1, - ) - else: - # Add actions, rewards, next-obs to collectors. - values_dict = { - SampleBatch.T: episode.length - 1, - SampleBatch.ENV_ID: env_id, - SampleBatch.AGENT_INDEX: episode._agent_index(agent_id), - # Action (slot 0) taken at timestep t. 
- SampleBatch.ACTIONS: episode.last_action_for(agent_id), - # Reward received after taking a at timestep t. - SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0), - # After taking action=a, did we terminate the episode? - SampleBatch.TERMINATEDS: agent_terminated, - # Was the episode truncated artificially - # (e.g. b/c of some time limit)? - SampleBatch.TRUNCATEDS: agent_truncated, - # Next observation. - SampleBatch.NEXT_OBS: filtered_obs, - } - # Add extra-action-fetches (policy-inference infos) to - # collectors. - pol = worker.policy_map[policy_id] - for key, value in episode.last_extra_action_outs_for(agent_id).items(): - if key in pol.view_requirements: - values_dict[key] = value - # Env infos for this agent. - if SampleBatch.INFOS in pol.view_requirements: - values_dict[SampleBatch.INFOS] = agent_infos - sample_collector.add_action_reward_next_obs( - episode.episode_id, - agent_id, - env_id, - policy_id, - agent_terminated or agent_truncated, - values_dict, - ) - - if not agent_terminated and not agent_truncated: - item = _PolicyEvalData( - env_id, - agent_id, - filtered_obs, - agent_infos, - None - if last_observation is None - else episode.rnn_state_for(agent_id), - None - if last_observation is None - else episode.last_action_for(agent_id), - rewards[env_id].get(agent_id, 0.0), - ) - to_eval[policy_id].append(item) - - # Invoke the `on_episode_step` callback after the step is logged - # to the episode. - # Exception: The very first env.poll() call causes the env to get reset - # (no step taken yet, just a single starting observation logged). - # We need to skip this callback in this case. - if not episode.is_faulty and episode.length > 0: - callbacks.on_episode_step( - worker=worker, - base_env=base_env, - policies=worker.policy_map, - episode=episode, - env_index=env_id, - ) - - # Episode is terminated for all agents (terminateds[__all__] == True or - # truncateds[__all__] == True). - if all_agents_done: - # If, we are not allowed to pack the next episode into the same - # SampleBatch (batch_mode=complete_episodes) -> Build the - # MultiAgentBatch from a single episode and add it to "outputs". - # Otherwise, just postprocess and continue collecting across - # episodes. - # If an episode was marked faulty, perform regular postprocessing - # (to e.g. properly flush and clean up the SampleCollector's buffers), - # but then discard the entire batch and don't return it. - ma_sample_batch = None - if not episode.is_faulty or episode.length > 0: - ma_sample_batch = sample_collector.postprocess_episode( - episode, - is_done=True, - check_dones=True, - build=episode.is_faulty or not multiple_episodes_in_batch, - ) - if not episode.is_faulty: - # Call each (in-memory) policy's Exploration.on_episode_end - # method. - # Note: This may break the exploration (e.g. ParameterNoise) of - # policies in the `policy_map` that have not been recently used - # (and are therefore stashed to disk). However, we certainly do not - # want to loop through all (even stashed) policies here as that - # would counter the purpose of the LRU policy caching. - for p in worker.policy_map.cache.values(): - if getattr(p, "exploration", None) is not None: - p.exploration.on_episode_end( - policy=p, - environment=base_env, - episode=episode, - tf_sess=p.get_session(), - ) - # Call custom on_episode_end callback. 
- callbacks.on_episode_end( - worker=worker, - base_env=base_env, - policies=worker.policy_map, - episode=episode, - env_index=env_id, - ) - - # Now that all callbacks are done and users had the chance to add custom - # metrics based on the last observation in the episode, finish up metrics - # object and append to `outputs`. - atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(base_env) - if not episode.is_faulty: - if atari_metrics is not None: - for m in atari_metrics: - outputs.append( - m._replace( - custom_metrics=episode.custom_metrics, - hist_data=episode.hist_data, - ) - ) - else: - outputs.append( - RolloutMetrics( - episode.length, - episode.total_reward, - dict(episode.agent_rewards), - episode.custom_metrics, - {}, - episode.hist_data, - episode.media, - ) - ) - else: - # Add metrics about a faulty episode. - outputs.append(RolloutMetrics(episode_faulty=True)) - - # Only after the RolloutMetrics were appended, append the collected sample - # batch, if any. - if not episode.is_faulty and ma_sample_batch: - outputs.append(ma_sample_batch) - - # Terminated: Try to reset the sub environment. - # Clean up old finished episode. - del active_episodes[env_id] - - # Create a new episode and call `on_episode_created` callback(s). - _create_episode(active_episodes, env_id, callbacks, worker, base_env) - - # The sub environment at index `env_id` might throw an exception - # during the following `try_reset()` attempt. If configured with - # `restart_failed_sub_environments=True`, the BaseEnv will restart - # the affected sub environment (create a new one using its c'tor) and - # must reset the recreated sub env right after that. - # Should the sub environment fail indefinitely during these - # repeated reset attempts, the entire worker will be blocked. - # This would be ok, b/c the alternative would be the worker crashing - # entirely. - while True: - resetted_obs, resetted_infos = base_env.try_reset(env_id) - if resetted_obs is None or not isinstance( - resetted_obs[env_id], Exception - ): - break - else: - # Failed to reset, add metrics about a faulty episode. - outputs.append(RolloutMetrics(episode_faulty=True)) - - # Creates a new episode if this is not async return. - # If reset is async, we will get its result in some future poll. - if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN: - new_episode: Episode = active_episodes[env_id] - - resetted_obs = resetted_obs[env_id] - resetted_infos = resetted_infos[env_id] - - # Add init obs and infos (from the call to `reset/try_reset`) to - # episode. 
- for aid, obs in resetted_obs.items(): - new_episode._set_last_raw_obs(aid, obs) - common_infos = resetted_infos.get("__common__", {}) - new_episode._set_last_info("__common__", common_infos) - for aid, info in resetted_infos.items(): - new_episode._set_last_info(aid, info) - - _call_on_episode_start(new_episode, env_id, callbacks, worker, base_env) - - _assert_episode_not_faulty(new_episode) - if observation_fn: - resetted_obs: Dict[AgentID, EnvObsType] = observation_fn( - agent_obs=resetted_obs, - worker=worker, - base_env=base_env, - policies=worker.policy_map, - episode=new_episode, - ) - # types: AgentID, EnvObsType - for agent_id, raw_obs in resetted_obs.items(): - policy_id: PolicyID = new_episode.policy_for(agent_id) - preproccessor = _get_or_raise(worker.preprocessors, policy_id) - - prep_obs: EnvObsType = raw_obs - if preproccessor is not None: - prep_obs = preproccessor.transform(raw_obs) - filtered_obs: EnvObsType = _get_or_raise(worker.filters, policy_id)( - prep_obs - ) - new_episode._set_last_observation(agent_id, filtered_obs) - - # Add initial obs to buffer. - sample_collector.add_init_obs( - episode=new_episode, - agent_id=agent_id, - env_id=env_id, - policy_id=policy_id, - init_obs=filtered_obs, - init_infos=resetted_infos, - t=new_episode.length - 1, - ) - - item = _PolicyEvalData( - env_id, - agent_id, - filtered_obs, - new_episode.last_info_for(agent_id) or {}, - new_episode.rnn_state_for(agent_id), - None, - 0.0, - ) - to_eval[policy_id].append(item) - - # Try to build something. - if multiple_episodes_in_batch: - sample_batches = ( - sample_collector.try_build_truncated_episode_multi_agent_batch() - ) - if sample_batches: - outputs.extend(sample_batches) - - return active_envs, to_eval, outputs - - -@OldAPIStack -def _do_policy_eval( - *, - to_eval: Dict[PolicyID, List[_PolicyEvalData]], - policies: PolicyMap, - sample_collector: SampleCollector, - active_episodes: Dict[EnvID, Episode], -) -> Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]: - """Call compute_actions on collected episode/model data to get next action. - - Args: - to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects - (items in these lists will be the batch's items for the model - forward pass). - policies: Mapping from policy ID to Policy obj. - sample_collector: The SampleCollector object to use. - active_episodes: Mapping of EnvID to its currently active episode. - - Returns: - Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs. - """ - - eval_results: Dict[PolicyID, TensorStructType] = {} - - if log_once("compute_actions_input"): - logger.info("Inputs to compute_actions():\n\n{}\n".format(summarize(to_eval))) - - for policy_id, eval_data in to_eval.items(): - # In case the policyID has been removed from this worker, we need to - # re-assign policy_id and re-lookup the Policy object to use. - try: - policy: Policy = _get_or_raise(policies, policy_id) - except ValueError: - # Important: Get the policy_mapping_fn from the active - # Episode as the policy_mapping_fn from the worker may - # have already been changed (mapping fn stay constant - # within one episode). 
- episode = active_episodes[eval_data[0].env_id] - _assert_episode_not_faulty(episode) - policy_id = episode.policy_mapping_fn( - eval_data[0].agent_id, episode, worker=episode.worker - ) - policy: Policy = _get_or_raise(policies, policy_id) - - input_dict = sample_collector.get_inference_input_dict(policy_id) - eval_results[policy_id] = policy.compute_actions_from_input_dict( - input_dict, - timestep=policy.global_timestep, - episodes=[active_episodes[t.env_id] for t in eval_data], - ) - - if log_once("compute_actions_result"): - logger.info( - "Outputs of compute_actions():\n\n{}\n".format(summarize(eval_results)) - ) - - return eval_results - - -@OldAPIStack -def _process_policy_eval_results( - *, - to_eval: Dict[PolicyID, List[_PolicyEvalData]], - eval_results: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]], - active_episodes: Dict[EnvID, Episode], - active_envs: Set[int], - off_policy_actions: MultiEnvDict, - policies: Dict[PolicyID, Policy], - normalize_actions: bool, - clip_actions: bool, -) -> Dict[EnvID, Dict[AgentID, EnvActionType]]: - """Process the output of policy neural network evaluation. - - Records policy evaluation results into the given episode objects and - returns replies to send back to agents in the env. - - Args: - to_eval: Mapping of policy IDs to lists of _PolicyEvalData objects. - eval_results: Mapping of policy IDs to list of - actions, rnn-out states, extra-action-fetches dicts. - active_episodes: Mapping from episode ID to currently ongoing - Episode object. - active_envs: Set of non-terminated env ids. - off_policy_actions: Doubly keyed dict of env-ids -> agent ids -> - off-policy-action, returned by a `BaseEnv.poll()` call. - policies: Mapping from policy ID to Policy. - normalize_actions: Whether to normalize actions to the action - space's bounds. - clip_actions: Whether to clip actions to the action space's bounds. - - Returns: - Nested dict of env id -> agent id -> actions to be sent to - Env (np.ndarrays). - """ - - actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict) - - # types: int - for env_id in active_envs: - actions_to_send[env_id] = {} # at minimum send empty dict - - # types: PolicyID, List[_PolicyEvalData] - for policy_id, eval_data in to_eval.items(): - actions: TensorStructType = eval_results[policy_id][0] - actions = convert_to_numpy(actions) - - rnn_out_cols: StateBatch = eval_results[policy_id][1] - extra_action_out_cols: dict = eval_results[policy_id][2] - - # In case actions is a list (representing the 0th dim of a batch of - # primitive actions), try converting it first. - if isinstance(actions, list): - actions = np.array(actions) - - # Store RNN state ins/outs and extra-action fetches to episode. - for f_i, column in enumerate(rnn_out_cols): - extra_action_out_cols["state_out_{}".format(f_i)] = column - - policy: Policy = _get_or_raise(policies, policy_id) - # Split action-component batches into single action rows. - actions: List[EnvActionType] = unbatch(actions) - # types: int, EnvActionType - for i, action in enumerate(actions): - # Normalize, if necessary. - if normalize_actions: - action_to_send = unsquash_action(action, policy.action_space_struct) - # Clip, if necessary. 
- elif clip_actions: - action_to_send = clip_action(action, policy.action_space_struct) - else: - action_to_send = action - - env_id: int = eval_data[i].env_id - agent_id: AgentID = eval_data[i].agent_id - episode: Episode = active_episodes[env_id] - _assert_episode_not_faulty(episode) - episode._set_rnn_state( - agent_id, tree.map_structure(lambda x: x[i], rnn_out_cols) - ) - episode._set_last_extra_action_outs( - agent_id, tree.map_structure(lambda x: x[i], extra_action_out_cols) - ) - if env_id in off_policy_actions and agent_id in off_policy_actions[env_id]: - episode._set_last_action(agent_id, off_policy_actions[env_id][agent_id]) - else: - episode._set_last_action(agent_id, action) - - assert agent_id not in actions_to_send[env_id] - # Flag actions as immutable to notify the user when trying to change it - # and to avoid hardly traceable errors. - tree.traverse(make_action_immutable, action_to_send, top_down=False) - actions_to_send[env_id][agent_id] = action_to_send - - return actions_to_send - - -@OldAPIStack -def _create_episode(active_episodes, env_id, callbacks, worker, base_env): - # Make sure we are really creating a new episode here. - assert env_id not in active_episodes - - # Create a new episode under the given `env_id` and call the - # `on_episode_created` callbacks. - new_episode = active_episodes[env_id] - # Call `on_episode_created()` callback. - callbacks.on_episode_created( - worker=worker, - base_env=base_env, - policies=worker.policy_map, - env_index=env_id, - episode=new_episode, - ) - return new_episode - - -@OldAPIStack -def _call_on_episode_start(episode, env_id, callbacks, worker, base_env): - # Call each policy's Exploration.on_episode_start method. - # Note: This may break the exploration (e.g. ParameterNoise) of - # policies in the `policy_map` that have not been recently used - # (and are therefore stashed to disk). However, we certainly do not - # want to loop through all (even stashed) policies here as that - # would counter the purpose of the LRU policy caching. - for p in worker.policy_map.cache.values(): - if getattr(p, "exploration", None) is not None: - p.exploration.on_episode_start( - policy=p, - environment=base_env, - episode=episode, - tf_sess=p.get_session(), - ) - callbacks.on_episode_start( - worker=worker, - base_env=base_env, - policies=worker.policy_map, - episode=episode, - env_index=env_id, - ) - episode.started = True - - -def _to_column_format(rnn_state_rows: List[List[Any]]) -> StateBatch: - num_cols = len(rnn_state_rows[0]) - return [[row[i] for row in rnn_state_rows] for i in range(num_cols)] - - -def _assert_episode_not_faulty(episode): - if episode.is_faulty: - raise AssertionError( - "Episodes marked as `faulty` should not be kept in the " - f"`active_episodes` map! Episode ID={episode.episode_id}." - ) diff --git a/rllib/evaluation/tests/test_env_runner_v2.py b/rllib/evaluation/tests/test_env_runner_v2.py index 0c072c33e79ee..d5d139f385a7f 100644 --- a/rllib/evaluation/tests/test_env_runner_v2.py +++ b/rllib/evaluation/tests/test_env_runner_v2.py @@ -61,8 +61,6 @@ def test_sample_batch_rollout_single_agent_env(self): .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, ) ) @@ -88,8 +86,6 @@ def test_sample_batch_rollout_multi_agent_env(self): .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. 
- enable_connectors=True, ) ) @@ -153,8 +149,6 @@ def compute_actions( .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, rollout_fragment_length=100, ) .multi_agent( @@ -220,8 +214,6 @@ def __init__(self, *args, **kwargs): .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, ) .multi_agent( policies={ @@ -294,8 +286,6 @@ def on_create_policy(self, *, policy_id, policy) -> None: .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, ) ) @@ -317,8 +307,6 @@ def test_start_episode(self): .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, ) .multi_agent( policies={ @@ -373,8 +361,6 @@ def test_env_runner_output(self): .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, ) .multi_agent( policies={ @@ -432,8 +418,6 @@ def on_episode_end( .env_runners( num_envs_per_env_runner=1, num_env_runners=0, - # Enable EnvRunnerV2. - enable_connectors=True, ) .multi_agent( policies={ diff --git a/rllib/evaluation/tests/test_episode.py b/rllib/evaluation/tests/test_episode.py deleted file mode 100644 index d61e94cf3302e..0000000000000 --- a/rllib/evaluation/tests/test_episode.py +++ /dev/null @@ -1,174 +0,0 @@ -import ray -import unittest -from typing import Dict, List, Optional, Union, Tuple -import numpy as np -from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.env.multi_agent_env import MultiAgentEnv -from ray.rllib.evaluation.rollout_worker import RolloutWorker -from ray.rllib.evaluation.episode import Episode -from ray.rllib.examples.envs.classes.mock_env import MockEnv3 -from ray.rllib.policy import Policy -from ray.rllib.utils import override -from ray.rllib.utils.typing import TensorStructType, TensorType - -NUM_STEPS = 25 -NUM_AGENTS = 4 - - -class LastInfoCallback(DefaultCallbacks): - def __init__(self): - super(LastInfoCallback, self).__init__() - self.tc = unittest.TestCase() - self.step = 0 - - def on_episode_start( - self, worker, base_env, policies, episode, env_index, **kwargs - ): - self.step = 0 - self._check_last_values(episode) - - def on_episode_step(self, worker, base_env, episode, env_index=None, **kwargs): - self.step += 1 - self._check_last_values(episode) - - def on_episode_end(self, worker, base_env, policies, episode, **kwargs): - self._check_last_values(episode) - - def _check_last_values(self, episode): - last_obs = { - k: np.where(v)[0].item() for k, v in episode._agent_to_last_obs.items() - } - last_raw_obs = episode._agent_to_last_raw_obs - last_info = episode._agent_to_last_info - last_terminated = episode._agent_to_last_terminated - last_truncated = episode._agent_to_last_truncated - last_action = episode._agent_to_last_action - last_reward = {k: v[-1] for k, v in episode._agent_reward_history.items()} - if self.step == 0: - for last in [ - last_obs, - last_terminated, - last_truncated, - last_action, - last_reward, - ]: - self.tc.assertEqual(last, {}) - self.tc.assertTrue("__common__" in last_info) - self.tc.assertTrue(len(last_raw_obs) > 0) - for agent in last_raw_obs.keys(): - index = int(str(agent).replace("agent", "")) - self.tc.assertEqual(last_raw_obs[agent], 0) - self.tc.assertEqual(last_info[agent]["timestep"], self.step + index) - else: - for agent in last_obs.keys(): - index = 
int(str(agent).replace("agent", "")) - self.tc.assertEqual(last_obs[agent], self.step + index) - self.tc.assertEqual(last_reward[agent], self.step + index) - self.tc.assertEqual(last_terminated[agent], self.step == NUM_STEPS) - self.tc.assertEqual(last_truncated[agent], self.step == NUM_STEPS) - if self.step == 1: - self.tc.assertEqual(last_action[agent], 0) - else: - self.tc.assertEqual(last_action[agent], self.step + index - 1) - self.tc.assertEqual(last_info[agent]["timestep"], self.step + index) - - -class EchoPolicy(Policy): - @override(Policy) - def compute_actions( - self, - obs_batch: Union[List[TensorStructType], TensorStructType], - state_batches: Optional[List[TensorType]] = None, - prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, - prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, - info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["Episode"]] = None, - explore: Optional[bool] = None, - timestep: Optional[int] = None, - **kwargs, - ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: - return obs_batch.argmax(axis=1), [], {} - - -class EpisodeEnv(MultiAgentEnv): - def __init__(self, episode_length, num): - super().__init__() - self.agents = [MockEnv3(episode_length) for _ in range(num)] - self.terminateds = set() - self.truncateds = set() - self.observation_space = self.agents[0].observation_space - self.action_space = self.agents[0].action_space - - def reset(self, *, seed=None, options=None): - self.terminateds = set() - self.truncateds = set() - obs_and_infos = [a.reset() for a in self.agents] - return ( - {i: oi[0] for i, oi in enumerate(obs_and_infos)}, - {i: dict(oi[1], **{"timestep": i}) for i, oi in enumerate(obs_and_infos)}, - ) - - def step(self, action_dict): - obs, rew, terminated, truncated, info = {}, {}, {}, {}, {} - for i, action in action_dict.items(): - obs[i], rew[i], terminated[i], truncated[i], info[i] = self.agents[i].step( - action - ) - obs[i] = obs[i] + i - rew[i] = rew[i] + i - info[i]["timestep"] = info[i]["timestep"] + i - if terminated[i]: - self.terminateds.add(i) - if truncated[i]: - self.truncateds.add(i) - terminated["__all__"] = len(self.terminateds) == len(self.agents) - truncated["__all__"] = len(self.truncateds) == len(self.agents) - return obs, rew, terminated, truncated, info - - -class TestEpisodeLastValues(unittest.TestCase): - @classmethod - def setUpClass(cls): - ray.init(num_cpus=1) - - @classmethod - def tearDownClass(cls): - ray.shutdown() - - def test_single_agent_env(self): - ev = RolloutWorker( - env_creator=lambda _: MockEnv3(NUM_STEPS), - default_policy_class=EchoPolicy, - # Episode only works with env runner v1. - config=AlgorithmConfig() - .env_runners(enable_connectors=False) - .env_runners(num_env_runners=0) - .callbacks(LastInfoCallback), - ) - ev.sample() - - def test_multi_agent_env(self): - ev = RolloutWorker( - env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS), - default_policy_class=EchoPolicy, - # Episode only works with env runner v1. 
- config=AlgorithmConfig() - .env_runners(enable_connectors=False) - .env_runners(num_env_runners=0) - .callbacks(LastInfoCallback) - .multi_agent( - policies={str(agent_id) for agent_id in range(NUM_AGENTS)}, - policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: ( - str(agent_id) - ), - ), - ) - ev.sample() - - -if __name__ == "__main__": - import pytest - import sys - - sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib/evaluation/tests/test_episode_v2.py b/rllib/evaluation/tests/test_episode_v2.py index ee493b6c655f3..2a4f0131df631 100644 --- a/rllib/evaluation/tests/test_episode_v2.py +++ b/rllib/evaluation/tests/test_episode_v2.py @@ -77,10 +77,7 @@ def test_single_agent_env(self): ev = RolloutWorker( env_creator=lambda _: MockEnv3(NUM_STEPS), default_policy_class=EchoPolicy, - config=AlgorithmConfig().env_runners( - enable_connectors=True, - num_env_runners=0, - ), + config=AlgorithmConfig().env_runners(num_env_runners=0), ) ma_batch = ev.sample() self.assertEqual(ma_batch.count, 200) @@ -101,7 +98,7 @@ def test_multi_agent_env(self): str(agent_id) ), ) - .env_runners(enable_connectors=True, num_env_runners=0), + .env_runners(num_env_runners=0), ) sample_batches = ev.sample() self.assertEqual(len(sample_batches.policy_batches), 4) diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 42f4813885f31..d52529d1e6328 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -493,14 +493,8 @@ def test_reward_clipping(self): ) self.assertEqual(max(sample["rewards"]), 1) result = collect_metrics(ws, []) - # Shows different behavior when connector is on/off. - if config.enable_connectors: - # episode_return_mean shows the correct clipped value. - self.assertEqual(result[EPISODE_RETURN_MEAN], 10) - else: - # episode_return_mean shows the unclipped raw value - # when connector is off, and old env_runner v1 is used. - self.assertEqual(result[EPISODE_RETURN_MEAN], 1000) + # episode_return_mean shows the correct clipped value. + self.assertEqual(result[EPISODE_RETURN_MEAN], 10) ev.stop() # Clipping in certain range (-2.0, 2.0). diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index a59b24a5317fd..0eeea1ea2c8f0 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -1,11 +1,10 @@ -from ray.rllib.env.env_runner_group import EnvRunnerGroup -from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.deprecation import Deprecated -deprecation_warning( - old="WorkerSet", + +@Deprecated( new="ray.rllib.env.env_runner_group.EnvRunnerGroup", help="The class has only be renamed w/o any changes in functionality.", - error=False, + error=True, ) - -WorkerSet = EnvRunnerGroup +class WorkerSet: + pass diff --git a/rllib/examples/_old_api_stack/connectors/prepare_checkpoint.py b/rllib/examples/_old_api_stack/connectors/prepare_checkpoint.py index 5242c70909649..01861ec6faa20 100644 --- a/rllib/examples/_old_api_stack/connectors/prepare_checkpoint.py +++ b/rllib/examples/_old_api_stack/connectors/prepare_checkpoint.py @@ -6,12 +6,8 @@ def create_appo_cartpole_checkpoint(output_dir, use_lstm=False): - # enable_connectors defaults to True. Just trying to be explicit here. 
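For orientation, a minimal sketch of the simplified checkpoint-preparation config that the hunk below arrives at, now that `enable_connectors` no longer exists as a setting (connectors are always active on this stack). Illustrative only, not part of the patch; it assumes the standard `APPOConfig` import, `use_lstm=False`, and an `output_dir` variable:

from ray.rllib.algorithms.appo import APPOConfig

# Connectors can no longer be toggled, so the explicit
# `.env_runners(enable_connectors=True)` call is simply dropped.
config = (
    APPOConfig()
    .environment("CartPole-v1")
    .training(model={"use_lstm": False})
)
algo = config.build()                     # env runners come up with connectors on
algo.train()
algo.save(checkpoint_dir=output_dir)      # write the checkpoint, as in the helper below
algo.stop()
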
config = ( - APPOConfig() - .environment("CartPole-v1") - .env_runners(enable_connectors=True) - .training(model={"use_lstm": use_lstm}) + APPOConfig().environment("CartPole-v1").training(model={"use_lstm": use_lstm}) ) # Build algorithm object. algo = config.build() diff --git a/rllib/examples/centralized_critic.py b/rllib/examples/centralized_critic.py index a54caf84100ce..0cbe110810cf4 100644 --- a/rllib/examples/centralized_critic.py +++ b/rllib/examples/centralized_critic.py @@ -106,10 +106,7 @@ def centralized_critic_postprocessing( not pytorch and policy.loss_initialized() ): assert other_agent_batches is not None - if policy.config["enable_connectors"]: - [(_, _, opponent_batch)] = list(other_agent_batches.values()) - else: - [(_, opponent_batch)] = list(other_agent_batches.values()) + [(_, _, opponent_batch)] = list(other_agent_batches.values()) # also record the opponent obs and actions in the trajectory sample_batch[OPPONENT_OBS] = opponent_batch[SampleBatch.CUR_OBS] diff --git a/rllib/examples/centralized_critic_2.py b/rllib/examples/centralized_critic_2.py deleted file mode 100644 index cdc86f218ceef..0000000000000 --- a/rllib/examples/centralized_critic_2.py +++ /dev/null @@ -1,172 +0,0 @@ -# @OldAPIStack - -# *********************************************************************************** -# IMPORTANT NOTE: This script uses the old API stack and will soon be replaced by -# `ray.rllib.examples.multi_agent.pettingzoo_shared_value_function.py`! -# *********************************************************************************** - - -"""An example of implementing a centralized critic with ObservationFunction. - -The advantage of this approach is that it's very simple and you don't have to -change the algorithm at all -- just use callbacks and a custom model. -However, it is a bit less principled in that you have to change the agent -observation spaces to include data that is only used at train time. - -See also: centralized_critic.py for an alternative approach that instead -modifies the policy to add a centralized value function. -""" - -import numpy as np -from gymnasium.spaces import Dict, Discrete -import argparse -import os - -from ray import air, tune -from ray.air.constants import TRAINING_ITERATION -from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.examples._old_api_stack.models.centralized_critic_models import ( - YetAnotherCentralizedCriticModel, - YetAnotherTorchCentralizedCriticModel, -) -from ray.rllib.examples.envs.classes.two_step_game import TwoStepGame -from ray.rllib.models import ModelCatalog -from ray.rllib.policy.sample_batch import SampleBatch -from ray.rllib.utils.metrics import ( - ENV_RUNNER_RESULTS, - EPISODE_RETURN_MEAN, - NUM_ENV_STEPS_SAMPLED_LIFETIME, -) -from ray.rllib.utils.test_utils import check_learning_achieved - -parser = argparse.ArgumentParser() -parser.add_argument( - "--framework", - choices=["tf", "tf2", "torch"], - default="torch", - help="The DL framework specifier.", -) -parser.add_argument( - "--as-test", - action="store_true", - help="Whether this script should be run as a test: --stop-reward must " - "be achieved within --stop-timesteps AND --stop-iters.", -) -parser.add_argument( - "--stop-iters", type=int, default=100, help="Number of iterations to train." -) -parser.add_argument( - "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train." 
-) -parser.add_argument( - "--stop-reward", type=float, default=7.99, help="Reward at which we stop training." -) - - -class FillInActions(DefaultCallbacks): - """Fills in the opponent actions info in the training batches.""" - - def on_postprocess_trajectory( - self, - worker, - episode, - agent_id, - policy_id, - policies, - postprocessed_batch, - original_batches, - **kwargs, - ): - to_update = postprocessed_batch[SampleBatch.CUR_OBS] - other_id = 1 if agent_id == 0 else 0 - action_encoder = ModelCatalog.get_preprocessor_for_space(Discrete(2)) - - # set the opponent actions into the observation - _, opponent_batch = original_batches[other_id] - opponent_actions = np.array( - [action_encoder.transform(a) for a in opponent_batch[SampleBatch.ACTIONS]] - ) - to_update[:, -2:] = opponent_actions - - -def central_critic_observer(agent_obs, **kw): - """Rewrites the agent obs to include opponent data for training.""" - - new_obs = { - 0: { - "own_obs": agent_obs[0], - "opponent_obs": agent_obs[1], - "opponent_action": 0, # filled in by FillInActions - }, - 1: { - "own_obs": agent_obs[1], - "opponent_obs": agent_obs[0], - "opponent_action": 0, # filled in by FillInActions - }, - } - return new_obs - - -if __name__ == "__main__": - args = parser.parse_args() - - ModelCatalog.register_custom_model( - "cc_model", - YetAnotherTorchCentralizedCriticModel - if args.framework == "torch" - else YetAnotherCentralizedCriticModel, - ) - - action_space = Discrete(2) - observer_space = Dict( - { - "own_obs": Discrete(6), - # These two fields are filled in by the CentralCriticObserver, and are - # not used for inference, only for training. - "opponent_obs": Discrete(6), - "opponent_action": Discrete(2), - } - ) - - config = ( - PPOConfig() - .environment(TwoStepGame) - .framework(args.framework) - .env_runners( - batch_mode="complete_episodes", - num_env_runners=0, - # TODO(avnishn) make a new example compatible w connectors. - enable_connectors=False, - ) - .callbacks(FillInActions) - .training(model={"custom_model": "cc_model"}) - .multi_agent( - policies={ - "pol1": (None, observer_space, action_space, {}), - "pol2": (None, observer_space, action_space, {}), - }, - policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "pol1" - if agent_id == 0 - else "pol2", - observation_fn=central_critic_observer, - ) - # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. - .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) - ) - - stop = { - TRAINING_ITERATION: args.stop_iters, - NUM_ENV_STEPS_SAMPLED_LIFETIME: args.stop_timesteps, - f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward, - } - - tuner = tune.Tuner( - "PPO", - param_space=config.to_dict(), - run_config=air.RunConfig(stop=stop, verbose=1), - ) - results = tuner.fit() - - if args.as_test: - check_learning_achieved(results, args.stop_reward) diff --git a/rllib/examples/envs/external_envs/cartpole_server.py b/rllib/examples/envs/external_envs/cartpole_server.py index 7a8fe1a5a3f72..43f25c9a52885 100755 --- a/rllib/examples/envs/external_envs/cartpole_server.py +++ b/rllib/examples/envs/external_envs/cartpole_server.py @@ -33,7 +33,6 @@ from ray import air, tune from ray.air.constants import TRAINING_ITERATION from ray.rllib.env.policy_server_input import PolicyServerInput -from ray.rllib.examples.metrics.custom_metrics_and_callbacks import MyCallbacks from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, EPISODE_RETURN_MEAN, @@ -174,8 +173,6 @@ def _input(ioctx): ) # DL framework to use. 
.framework(args.framework) - # Create a "chatty" client/server or not. - .callbacks(MyCallbacks if args.callbacks_verbose else None) # Use the `PolicyServerInput` to generate experiences. .offline_data(input_=_input) # Use n worker processes to listen on different ports. diff --git a/rllib/examples/envs/external_envs/unity3d_server.py b/rllib/examples/envs/external_envs/unity3d_server.py index b9799658ba388..4457102877e18 100755 --- a/rllib/examples/envs/external_envs/unity3d_server.py +++ b/rllib/examples/envs/external_envs/unity3d_server.py @@ -135,7 +135,6 @@ def _input(ioctx): .env_runners( num_env_runners=args.num_workers, rollout_fragment_length=20, - enable_connectors=False, ) .environment( env=None, diff --git a/rllib/policy/dynamic_tf_policy_v2.py b/rllib/policy/dynamic_tf_policy_v2.py index e2ad3d6da0ab1..7368696044bdc 100644 --- a/rllib/policy/dynamic_tf_policy_v2.py +++ b/rllib/policy/dynamic_tf_policy_v2.py @@ -3,7 +3,7 @@ import logging import re import tree # pip install dm_tree -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 @@ -38,9 +38,6 @@ ) from ray.util.debug import log_once -if TYPE_CHECKING: - from ray.rllib.evaluation import Episode - tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) @@ -343,7 +340,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[SampleBatch] = None, - episode: Optional["Episode"] = None, + episode=None, ): """Post process trajectory in the format of a SampleBatch. diff --git a/rllib/policy/eager_tf_policy.py b/rllib/policy/eager_tf_policy.py index 4a5d463ae7649..c2e4fa33f1592 100644 --- a/rllib/policy/eager_tf_policy.py +++ b/rllib/policy/eager_tf_policy.py @@ -10,7 +10,6 @@ import tree # pip install dm_tree -from ray.rllib.evaluation.episode import Episode from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.repeated_values import RepeatedValues from ray.rllib.policy.policy import Policy, PolicyState @@ -173,7 +172,7 @@ def compute_actions_from_input_dict( input_dict: Dict[str, TensorType], explore: bool = None, timestep: Optional[int] = None, - episodes: Optional[List[Episode]] = None, + episodes=None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: """Traced version of Policy.compute_actions_from_input_dict.""" @@ -462,7 +461,7 @@ def compute_actions_from_input_dict( input_dict: Dict[str, TensorType], explore: bool = None, timestep: Optional[int] = None, - episodes: Optional[List[Episode]] = None, + episodes=None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: if not self.config.get("eager_tracing") and not tf1.executing_eagerly(): @@ -511,7 +510,7 @@ def compute_actions( prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["Episode"]] = None, + episodes: Optional[List] = None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs, diff --git a/rllib/policy/eager_tf_policy_v2.py b/rllib/policy/eager_tf_policy_v2.py index a37e86df93ae0..9aedd3112292c 100644 --- a/rllib/policy/eager_tf_policy_v2.py +++ b/rllib/policy/eager_tf_policy_v2.py @@ -12,7 +12,6 @@ import tree # pip install dm_tree from ray.rllib.utils.numpy import convert_to_numpy -from 
ray.rllib.evaluation.episode import Episode from ray.rllib.models.catalog import ModelCatalog from ray.rllib.models.modelv2 import ModelV2 from ray.rllib.models.tf.tf_action_dist import TFActionDistribution @@ -339,7 +338,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[SampleBatch] = None, - episode: Optional["Episode"] = None, + episode=None, ): """Post process trajectory in the format of a SampleBatch. @@ -420,7 +419,7 @@ def compute_actions_from_input_dict( input_dict: Dict[str, TensorType], explore: bool = None, timestep: Optional[int] = None, - episodes: Optional[List[Episode]] = None, + episodes=None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: self._is_training = False diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py index 64cc3db90722d..500292627e6da 100644 --- a/rllib/policy/policy.py +++ b/rllib/policy/policy.py @@ -4,7 +4,6 @@ import platform from abc import ABCMeta, abstractmethod from typing import ( - TYPE_CHECKING, Any, Callable, Collection, @@ -77,9 +76,6 @@ tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() -if TYPE_CHECKING: - from ray.rllib.evaluation import Episode - logger = logging.getLogger(__name__) @@ -455,7 +451,7 @@ def compute_single_action( prev_reward: Optional[TensorStructType] = None, info: dict = None, input_dict: Optional[SampleBatch] = None, - episode: Optional["Episode"] = None, + episode=None, explore: Optional[bool] = None, timestep: Optional[int] = None, # Kwars placeholder for future compatibility. @@ -558,7 +554,7 @@ def compute_actions_from_input_dict( input_dict: Union[SampleBatch, Dict[str, TensorStructType]], explore: Optional[bool] = None, timestep: Optional[int] = None, - episodes: Optional[List["Episode"]] = None, + episodes=None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: """Computes actions from collected samples (across multiple-agents). @@ -615,7 +611,7 @@ def compute_actions( prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["Episode"]] = None, + episodes: Optional[List] = None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs, @@ -692,7 +688,7 @@ def postprocess_trajectory( other_agent_batches: Optional[ Dict[AgentID, Tuple["Policy", SampleBatch]] ] = None, - episode: Optional["Episode"] = None, + episode=None, ) -> SampleBatch: """Implements algorithm-specific trajectory postprocessing. @@ -967,14 +963,13 @@ def get_state(self) -> PolicyState: ) state["policy_spec"] = policy_spec.serialize() - if self.config.get("enable_connectors", False): - # Checkpoint connectors state as well if enabled. - connector_configs = {} - if self.agent_connectors: - connector_configs["agent"] = self.agent_connectors.to_state() - if self.action_connectors: - connector_configs["action"] = self.action_connectors.to_state() - state["connector_configs"] = connector_configs + # Checkpoint connectors state as well if enabled. + connector_configs = {} + if self.agent_connectors: + connector_configs["agent"] = self.agent_connectors.to_state() + if self.action_connectors: + connector_configs["action"] = self.action_connectors.to_state() + state["connector_configs"] = connector_configs return state @@ -988,10 +983,6 @@ def restore_connectors(self, state: PolicyState): # To avoid a circular dependency problem cause by SampleBatch. 
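A small illustrative snippet (not part of the patch; it assumes an already built Policy instance named `policy`) of what the unconditional connector checkpointing in `get_state()` above means for callers: the `connector_configs` key is now always written, so downstream code no longer needs to check an `enable_connectors` flag before reading it:

# `state` is a plain PolicyState dict; "connector_configs" is always present,
# though its "agent"/"action" entries only exist if such connectors were set up.
state = policy.get_state()
connector_configs = state["connector_configs"]
agent_connector_state = connector_configs.get("agent")    # None if no agent connectors
action_connector_state = connector_configs.get("action")  # None if no action connectors
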
from ray.rllib.connectors.util import restore_connectors_for_policy - # No-op if connector is not enabled. - if not self.config.get("enable_connectors", False): - return - connector_configs = state.get("connector_configs", {}) if "agent" in connector_configs: self.agent_connectors = restore_connectors_for_policy( diff --git a/rllib/policy/policy_template.py b/rllib/policy/policy_template.py index e2fb1a1a3c2f4..f7bbb7142ecab 100644 --- a/rllib/policy/policy_template.py +++ b/rllib/policy/policy_template.py @@ -6,7 +6,6 @@ Optional, Tuple, Type, - TYPE_CHECKING, Union, ) @@ -26,9 +25,6 @@ from ray.rllib.utils.numpy import convert_to_numpy from ray.rllib.utils.typing import ModelGradients, TensorType, AlgorithmConfigDict -if TYPE_CHECKING: - from ray.rllib.evaluation.episode import Episode # noqa - jax, _ = try_import_jax() torch, _ = try_import_torch() @@ -52,7 +48,7 @@ def build_policy_class( Policy, SampleBatch, Optional[Dict[Any, SampleBatch]], - Optional["Episode"], + Optional[Any], ], SampleBatch, ] @@ -140,7 +136,7 @@ def build_policy_class( overrides. If None, uses only(!) the user-provided PartialAlgorithmConfigDict as dict for this Policy. postprocess_fn (Optional[Callable[[Policy, SampleBatch, - Optional[Dict[Any, SampleBatch]], Optional["Episode"]], + Optional[Dict[Any, SampleBatch]], Optional[Any]], SampleBatch]]): Optional callable for post-processing experience batches (called after the super's `postprocess_trajectory` method). stats_fn (Optional[Callable[[Policy, SampleBatch], diff --git a/rllib/policy/tests/test_policy_checkpoint_restore.py b/rllib/policy/tests/test_policy_checkpoint_restore.py index 93449c550fd48..87ff462e7787b 100644 --- a/rllib/policy/tests/test_policy_checkpoint_restore.py +++ b/rllib/policy/tests/test_policy_checkpoint_restore.py @@ -53,11 +53,7 @@ def test_policy_from_checkpoint_twice_torch(self): def test_add_policy_connector_enabled(self): with tempfile.TemporaryDirectory() as tmpdir: - config = ( - APPOConfig() - .environment("CartPole-v1") - .env_runners(enable_connectors=True) - ) + config = APPOConfig().environment("CartPole-v1") algo = config.build() algo.train() result = algo.save(checkpoint_dir=tmpdir) diff --git a/rllib/policy/tf_policy.py b/rllib/policy/tf_policy.py index cceca81dd5d4b..11c524f9c2bf4 100644 --- a/rllib/policy/tf_policy.py +++ b/rllib/policy/tf_policy.py @@ -1,6 +1,6 @@ import logging import math -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import gymnasium as gym import numpy as np @@ -35,9 +35,6 @@ ) from ray.util.debug import log_once -if TYPE_CHECKING: - from ray.rllib.evaluation import Episode - tf1, tf, tfv = try_import_tf() logger = logging.getLogger(__name__) @@ -308,7 +305,7 @@ def compute_actions_from_input_dict( input_dict: Union[SampleBatch, Dict[str, TensorType]], explore: bool = None, timestep: Optional[int] = None, - episodes: Optional[List["Episode"]] = None, + episode=None, **kwargs, ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]: explore = explore if explore is not None else self.config["explore"] @@ -349,7 +346,7 @@ def compute_actions( prev_action_batch: Union[List[TensorType], TensorType] = None, prev_reward_batch: Union[List[TensorType], TensorType] = None, info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["Episode"]] = None, + episodes=None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs, diff --git a/rllib/policy/tf_policy_template.py 
b/rllib/policy/tf_policy_template.py index 83b132de5ceea..fcc123b6a5ef7 100644 --- a/rllib/policy/tf_policy_template.py +++ b/rllib/policy/tf_policy_template.py @@ -1,5 +1,5 @@ import gymnasium as gym -from typing import Callable, Dict, List, Optional, Tuple, Type, Union, TYPE_CHECKING +from typing import Callable, Dict, List, Optional, Tuple, Type, Union from ray.rllib.models.tf.tf_action_dist import TFActionDistribution from ray.rllib.models.modelv2 import ModelV2 @@ -17,15 +17,11 @@ from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY from ray.rllib.utils.typing import ( - AgentID, ModelGradients, TensorType, AlgorithmConfigDict, ) -if TYPE_CHECKING: - from ray.rllib.evaluation import Episode - tf1, tf, tfv = try_import_tf() @@ -38,17 +34,7 @@ def build_tf_policy( Union[TensorType, List[TensorType]], ], get_default_config: Optional[Callable[[None], AlgorithmConfigDict]] = None, - postprocess_fn: Optional[ - Callable[ - [ - Policy, - SampleBatch, - Optional[Dict[AgentID, SampleBatch]], - Optional["Episode"], - ], - SampleBatch, - ] - ] = None, + postprocess_fn=None, stats_fn: Optional[Callable[[Policy, SampleBatch], Dict[str, TensorType]]] = None, optimizer_fn: Optional[ Callable[[Policy, AlgorithmConfigDict], "tf.keras.optimizers.Optimizer"] diff --git a/rllib/policy/torch_policy.py b/rllib/policy/torch_policy.py index f0a047400db4a..64eeb83740019 100644 --- a/rllib/policy/torch_policy.py +++ b/rllib/policy/torch_policy.py @@ -6,7 +6,6 @@ import threading import time from typing import ( - TYPE_CHECKING, Any, Callable, Dict, @@ -53,9 +52,6 @@ TensorType, ) -if TYPE_CHECKING: - from ray.rllib.evaluation import Episode # noqa - torch, nn = try_import_torch() logger = logging.getLogger(__name__) @@ -332,7 +328,7 @@ def compute_actions( prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["Episode"]] = None, + episodes=None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs, diff --git a/rllib/policy/torch_policy_v2.py b/rllib/policy/torch_policy_v2.py index 66a7e1993d4a7..08216eb6d5da6 100644 --- a/rllib/policy/torch_policy_v2.py +++ b/rllib/policy/torch_policy_v2.py @@ -5,7 +5,7 @@ import os import threading import time -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import gymnasium as gym import numpy as np @@ -54,9 +54,6 @@ TensorType, ) -if TYPE_CHECKING: - from ray.rllib.evaluation import Episode # noqa - torch, nn = try_import_torch() logger = logging.getLogger(__name__) @@ -388,7 +385,7 @@ def postprocess_trajectory( self, sample_batch: SampleBatch, other_agent_batches: Optional[Dict[Any, SampleBatch]] = None, - episode: Optional["Episode"] = None, + episode=None, ) -> SampleBatch: """Postprocesses a trajectory and returns the processed trajectory. 
@@ -517,7 +514,7 @@ def compute_actions( prev_action_batch: Union[List[TensorStructType], TensorStructType] = None, prev_reward_batch: Union[List[TensorStructType], TensorStructType] = None, info_batch: Optional[Dict[str, list]] = None, - episodes: Optional[List["Episode"]] = None, + episodes=None, explore: Optional[bool] = None, timestep: Optional[int] = None, **kwargs, diff --git a/rllib/utils/policy.py b/rllib/utils/policy.py index 813e441909180..9cadcb08b0547 100644 --- a/rllib/utils/policy.py +++ b/rllib/utils/policy.py @@ -128,10 +128,6 @@ def parse_policy_specs_from_checkpoint( w = pickle.loads(checkpoint_dict["worker"]) policy_config = w["policy_config"] - assert policy_config.get("enable_connectors", False), ( - "load_policies_from_checkpoint only works for checkpoints generated by stacks " - "with connectors enabled." - ) policy_states = w.get("policy_states", w["state"]) serialized_policy_specs = w["policy_specs"] policy_specs = {