From 182744bbd151c166b8028355eae12a5da63fb3cc Mon Sep 17 00:00:00 2001 From: Sven Mika Date: Wed, 26 Oct 2022 11:31:56 +0200 Subject: [PATCH] [RLlib] AlgorithmConfig: Next steps (volume 01); Algos, RolloutWorker, PolicyMap, WorkerSet use AlgorithmConfig objects under the hood. (#29395) --- ...-saving-and-loading-algos-and-policies.rst | 2 +- rllib/algorithms/a2c/a2c.py | 7 +- rllib/algorithms/a2c/tests/test_a2c.py | 25 +- rllib/algorithms/a3c/a3c.py | 7 +- rllib/algorithms/a3c/tests/test_a3c.py | 16 +- rllib/algorithms/algorithm.py | 460 +++------ rllib/algorithms/algorithm_config.py | 781 +++++++++++++- rllib/algorithms/alpha_star/alpha_star.py | 48 +- .../alpha_star/distributed_learners.py | 8 + rllib/algorithms/alpha_star/league_builder.py | 52 +- .../alpha_star/tests/test_alpha_star.py | 9 +- rllib/algorithms/alpha_zero/alpha_zero.py | 8 +- rllib/algorithms/apex_ddpg/apex_ddpg.py | 10 +- .../apex_ddpg/tests/test_apex_ddpg.py | 18 +- rllib/algorithms/apex_dqn/apex_dqn.py | 72 +- .../apex_dqn/tests/test_apex_dqn.py | 27 +- rllib/algorithms/appo/appo.py | 9 +- rllib/algorithms/ars/ars.py | 17 +- rllib/algorithms/bandit/bandit.py | 8 +- rllib/algorithms/bandit/tests/test_bandits.py | 6 +- rllib/algorithms/bc/bc.py | 4 +- rllib/algorithms/bc/tests/test_bc.py | 8 +- rllib/algorithms/cql/cql.py | 8 +- rllib/algorithms/cql/tests/test_cql.py | 15 +- rllib/algorithms/ddpg/ddpg.py | 4 +- rllib/algorithms/ddpg/tests/test_ddpg.py | 5 +- rllib/algorithms/ddppo/ddppo.py | 11 +- rllib/algorithms/ddppo/tests/test_ddppo.py | 16 +- rllib/algorithms/dqn/dqn.py | 4 +- rllib/algorithms/dqn/tests/test_dqn.py | 21 +- rllib/algorithms/dreamer/dreamer.py | 9 +- .../algorithms/dreamer/tests/test_dreamer.py | 6 +- rllib/algorithms/dt/dt.py | 43 +- rllib/algorithms/es/es.py | 18 +- rllib/algorithms/impala/impala.py | 13 +- rllib/algorithms/impala/tests/test_impala.py | 20 +- rllib/algorithms/maddpg/maddpg_tf_policy.py | 14 +- rllib/algorithms/maml/maml.py | 4 +- rllib/algorithms/maml/tests/test_maml.py | 8 +- rllib/algorithms/marwil/tests/test_marwil.py | 18 +- rllib/algorithms/mbmpo/mbmpo.py | 6 +- rllib/algorithms/mbmpo/tests/test_mbmpo.py | 8 +- rllib/algorithms/mock.py | 2 +- rllib/algorithms/pg/pg.py | 11 +- rllib/algorithms/pg/pg_tf_policy.py | 11 +- rllib/algorithms/pg/pg_torch_policy.py | 13 +- rllib/algorithms/pg/tests/test_pg.py | 12 +- rllib/algorithms/ppo/ppo.py | 13 +- rllib/algorithms/ppo/tests/test_ppo.py | 21 +- rllib/algorithms/qmix/qmix.py | 8 +- rllib/algorithms/qmix/tests/test_qmix.py | 8 +- rllib/algorithms/r2d2/r2d2.py | 5 +- rllib/algorithms/sac/sac.py | 16 +- rllib/algorithms/sac/tests/test_rnnsac.py | 3 +- rllib/algorithms/sac/tests/test_sac.py | 50 +- rllib/algorithms/simple_q/simple_q.py | 7 +- .../simple_q/tests/test_simple_q.py | 12 +- rllib/algorithms/slateq/slateq.py | 8 +- rllib/algorithms/tests/test_algorithm.py | 4 +- .../algorithms/tests/test_worker_failures.py | 20 +- rllib/env/policy_client.py | 15 +- rllib/env/tests/test_external_env.py | 72 +- .../tests/test_external_multi_agent_env.py | 38 +- rllib/env/tests/test_multi_agent_env.py | 139 ++- rllib/env/wrappers/model_vector_env.py | 4 +- .../evaluation/collectors/agent_collector.py | 4 +- rllib/evaluation/rollout_worker.py | 974 +++++++++--------- rllib/evaluation/sampler.py | 10 +- rllib/evaluation/tests/test_episode.py | 31 +- rllib/evaluation/tests/test_episode_v2.py | 29 +- rllib/evaluation/tests/test_rollout_worker.py | 407 +++++--- .../tests/test_trajectory_view_api.py | 68 +- rllib/evaluation/worker_set.py | 281 ++--- 
.../documentation/replay_buffer_demo.py | 49 +- .../saving_and_loading_algos_and_policies.py | 4 +- rllib/examples/hierarchical_training.py | 7 +- rllib/execution/multi_gpu_learner_thread.py | 10 +- rllib/execution/rollout_ops.py | 5 +- rllib/execution/train_ops.py | 10 +- rllib/offline/estimators/tests/utils.py | 4 +- .../offline/tests/test_feature_importance.py | 6 +- rllib/policy/dynamic_tf_policy_v2.py | 6 - rllib/policy/eager_tf_policy_v2.py | 12 +- rllib/policy/policy_map.py | 60 +- .../checkpoints/create_checkpoints.py | 8 +- rllib/tests/test_execution.py | 13 +- rllib/tests/test_perf.py | 8 +- rllib/utils/debug/memory.py | 2 +- rllib/utils/policy.py | 5 + rllib/utils/pre_checks/multi_agent.py | 137 --- rllib/utils/serialization.py | 24 +- rllib/utils/test_utils.py | 3 +- rllib/utils/tests/test_errors.py | 36 +- rllib/utils/tf_utils.py | 5 +- 94 files changed, 2566 insertions(+), 1997 deletions(-) delete mode 100644 rllib/utils/pre_checks/multi_agent.py diff --git a/doc/source/rllib/rllib-saving-and-loading-algos-and-policies.rst b/doc/source/rllib/rllib-saving-and-loading-algos-and-policies.rst index 014c308598ad7..ad4ae76e4c325 100644 --- a/doc/source/rllib/rllib-saving-and-loading-algos-and-policies.rst +++ b/doc/source/rllib/rllib-saving-and-loading-algos-and-policies.rst @@ -82,7 +82,7 @@ handle any checkpoints created with Ray 2.0 or any version up to ``V``. .. code-block:: shell - $ mode rllib_checkpoint.json + $ more rllib_checkpoint.json {"type": "Algorithm", "checkpoint_version": "1.0"} Now, let's check out the `policies/` sub-directory: diff --git a/rllib/algorithms/a2c/a2c.py b/rllib/algorithms/a2c/a2c.py index efbb2862e4d6d..27d9eccff1ead 100644 --- a/rllib/algorithms/a2c/a2c.py +++ b/rllib/algorithms/a2c/a2c.py @@ -35,11 +35,12 @@ class A2CConfig(A3CConfig): >>> from ray import tune >>> config = A2CConfig().training(lr=0.01, grad_clip=30.0)\ ... .resources(num_gpus=0)\ - ... .rollouts(num_rollout_workers=2) + ... .rollouts(num_rollout_workers=2)\ + ... .environment("CartPole-v1") >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build() + >>> algo.train() Example: >>> import ray.air as air diff --git a/rllib/algorithms/a2c/tests/test_a2c.py b/rllib/algorithms/a2c/tests/test_a2c.py index e1854fb51db18..8609c3ce586b7 100644 --- a/rllib/algorithms/a2c/tests/test_a2c.py +++ b/rllib/algorithms/a2c/tests/test_a2c.py @@ -27,13 +27,14 @@ def test_a2c_compilation(self): # Test against all frameworks. 
for _ in framework_iterator(config, with_eager_tracing=True): for env in ["CartPole-v0", "Pendulum-v1", "PongDeterministic-v0"]: - trainer = config.build(env=env) + config.environment(env) + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) - trainer.stop() + check_compute_single_action(algo) + algo.stop() def test_a2c_exec_impl(self): config = ( @@ -43,12 +44,12 @@ def test_a2c_exec_impl(self): ) for _ in framework_iterator(config): - trainer = config.build() - results = trainer.train() + algo = config.build() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) - trainer.stop() + check_compute_single_action(algo) + algo.stop() def test_a2c_exec_impl_microbatch(self): config = ( @@ -59,12 +60,12 @@ def test_a2c_exec_impl_microbatch(self): ) for _ in framework_iterator(config): - trainer = config.build() - results = trainer.train() + algo = config.build() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) - trainer.stop() + check_compute_single_action(algo) + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/a3c/a3c.py b/rllib/algorithms/a3c/a3c.py index a5117629dab6c..9d0206b0db9b4 100644 --- a/rllib/algorithms/a3c/a3c.py +++ b/rllib/algorithms/a3c/a3c.py @@ -35,11 +35,12 @@ class A3CConfig(AlgorithmConfig): >>> from ray import tune >>> config = A3CConfig().training(lr=0.01, grad_clip=30.0)\ ... .resources(num_gpus=0)\ - ... .rollouts(num_rollout_workers=4) + ... .rollouts(num_rollout_workers=4)\ + ... .environment("CartPole-v1") >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build() + >>> algo.train() Example: >>> config = A3CConfig() diff --git a/rllib/algorithms/a3c/tests/test_a3c.py b/rllib/algorithms/a3c/tests/test_a3c.py index a470503bfb10a..49fdff3327af2 100644 --- a/rllib/algorithms/a3c/tests/test_a3c.py +++ b/rllib/algorithms/a3c/tests/test_a3c.py @@ -31,15 +31,15 @@ def test_a3c_compilation(self): for env in ["CartPole-v1", "Pendulum-v1", "PongDeterministic-v0"]: print("env={}".format(env)) config.model["use_lstm"] = env == "CartPole-v1" - trainer = config.build(env=env) + algo = config.build(env=env) for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) check_compute_single_action( - trainer, include_state=config.model["use_lstm"] + algo, include_state=config.model["use_lstm"] ) - trainer.stop() + algo.stop() def test_a3c_entropy_coeff_schedule(self): """Test A3C entropy coeff schedule support.""" @@ -78,17 +78,17 @@ def _step_n_times(trainer, n: int): # Test against all frameworks. for _ in framework_iterator(config): - trainer = config.build(env="CartPole-v1") + algo = config.build(env="CartPole-v1") - coeff = _step_n_times(trainer, 1) # 20 timesteps + coeff = _step_n_times(algo, 1) # 20 timesteps # Should be close to the starting coeff of 0.01 self.assertGreaterEqual(coeff, 0.005) - coeff = _step_n_times(trainer, 10) # 200 timesteps + coeff = _step_n_times(algo, 10) # 200 timesteps # Should have annealed to the final coeff of 0.0001. 
self.assertLessEqual(coeff, 0.00011) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 9900c03202990..bb3a760e9070c 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -7,7 +7,6 @@ import importlib import json import logging -import math import numpy as np import os from packaging import version @@ -36,7 +35,6 @@ import ray.cloudpickle as pickle from ray.exceptions import GetTimeoutError, RayActorError, RayError from ray.rllib.algorithms.algorithm_config import AlgorithmConfig -from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.algorithms.registry import ALGORITHMS as ALL_ALGORITHMS from ray.rllib.env.env_context import EnvContext from ray.rllib.env.utils import _gym_env_creator @@ -62,9 +60,9 @@ DirectMethod, DoublyRobust, ) -from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID, SampleBatch, concat_samples -from ray.rllib.utils import deep_update, FilterManager, merge_dicts +from ray.rllib.utils import deep_update, FilterManager from ray.rllib.utils.annotations import ( DeveloperAPI, ExperimentalAPI, @@ -95,7 +93,6 @@ ) from ray.rllib.utils.metrics.learner_info import LEARNER_INFO from ray.rllib.utils.policy import validate_policy_id -from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent from ray.rllib.utils.replay_buffers import MultiAgentReplayBuffer from ray.rllib.utils.spaces import space_utils from ray.rllib.utils.typing import ( @@ -204,7 +201,7 @@ class Algorithm(Trainable): ] # List of keys that are always fully overridden if present in any dict or sub-dict - _override_all_key_list = ["off_policy_estimation_methods"] + _override_all_key_list = ["off_policy_estimation_methods", "policies"] _progress_metrics = [ "episode_reward_mean", @@ -311,40 +308,49 @@ def from_state(state: Dict) -> "Algorithm": @PublicAPI def __init__( self, - config: Optional[Union[PartialAlgorithmConfigDict, AlgorithmConfig]] = None, - env: Optional[Union[str, EnvType]] = None, + config: Union[AlgorithmConfig, PartialAlgorithmConfigDict], + env=None, # deprecated arg logger_creator: Optional[Callable[[], Logger]] = None, **kwargs, ): """Initializes an Algorithm instance. Args: - config: Algorithm-specific configuration dict. - env: Name of the environment to use (e.g. a gym-registered str), - a full class path (e.g. - "ray.rllib.examples.env.random_env.RandomEnv"), or an Env - class directly. Note that this arg can also be specified via - the "env" key in `config`. + config: Algorithm-specific configuration object. logger_creator: Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created. **kwargs: Arguments passed to the Trainable base class. - """ - # User provided (partial) config (this may be w/o the default - # Algorithm's Config object). Will get merged with AlgorithmConfig() - # in self.setup(). - config = config or {} - # Resolve AlgorithmConfig into a plain dict. + # Resolve possible dict into an AlgorithmConfig object. # TODO: In the future, only support AlgorithmConfig objects here. - if isinstance(config, AlgorithmConfig): - config = config.to_dict() + if isinstance(config, dict): + default_config = self.get_default_config() + # `self.get_default_config()` also returned a dict -> + # Last resort: Create core AlgorithmConfig from merged dicts. 
+ if isinstance(default_config, dict): + config = AlgorithmConfig.from_dict( + config_dict=self.merge_trainer_configs(default_config, config, True) + ) + else: + config = default_config.update_from_dict(config) + + if env is not None: + deprecation_warning( + old=f"algo = Algorithm(env='{env}', ...)", + new=f"algo = AlgorithmConfig().environment('{env}').build()", + error=False, + ) + config.environment(env) + + # Freeze our AlgorithmConfig object (no more changes possible). + config.freeze() # Convert `env` provided in config into a concrete env creator callable, which # takes an EnvContext (config dict) as arg and returning an RLlib supported Env # type (e.g. a gym.Env). self._env_id, self.env_creator = self._get_env_id_and_creator( - env or config.get("env"), config + config.env, config ) env_descr = ( self._env_id.__name__ if isinstance(self._env_id, type) else self._env_id @@ -367,7 +373,7 @@ class directly. Note that this arg can also be specified via # Allow users to more precisely configure the created logger # via "logger_config.type". - if config.get("logger_config") and "type" in config["logger_config"]: + if config.logger_config and "type" in config.logger_config: def default_logger_creator(config): """Creates a custom logger with the default prefix.""" @@ -394,6 +400,9 @@ def default_logger_creator(config): self._episodes_to_be_collected = [] self._remote_workers_for_metrics = [] + # The fully qualified AlgorithmConfig used for evaluation + # (or None if evaluation not setup). + self.evaluation_config: Optional[AlgorithmConfig] = None # Evaluation WorkerSet and metrics last returned by `self.evaluate()`. self.evaluation_workers: Optional[WorkerSet] = None # If evaluation duration is "auto", use a AsyncRequestsManager to be more @@ -411,7 +420,11 @@ def default_logger_creator(config): } } - super().__init__(config=config, logger_creator=logger_creator, **kwargs) + super().__init__( + config=config, + logger_creator=logger_creator, + **kwargs, + ) # Check, whether `training_iteration` is still a tune.Trainable property # and has not been overridden by the user in the attempt to implement the @@ -427,19 +440,24 @@ def default_logger_creator(config): @OverrideToImplementCustomLogic @classmethod - def get_default_config(cls) -> AlgorithmConfigDict: - return AlgorithmConfig().to_dict() + def get_default_config(cls) -> Union[AlgorithmConfig, AlgorithmConfigDict]: + return AlgorithmConfig() @OverrideToImplementCustomLogic_CallToSuperRecommended @override(Trainable) - def setup(self, config: PartialAlgorithmConfigDict): - - # Setup our config: Merge the user-supplied config (which could - # be a partial config dict with the class' default). - self.config = self.merge_trainer_configs( - self.get_default_config(), config, self._allow_unknown_configs - ) - self.config["env"] = self._env_id + def setup(self, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict]): + + # Setup our config: Merge the user-supplied config dict (which could + # be a partial config dict) with the class' default. + if not isinstance(config, AlgorithmConfig): + assert isinstance(config, PartialAlgorithmConfigDict) + config_obj = self.get_default_config() + if not isinstance(config_obj, AlgorithmConfig): + assert isinstance(config, PartialAlgorithmConfigDict) + config_obj = AlgorithmConfig().from_dict(config_obj) + config_obj.update_from_dict(config) + config_obj.env = self._env_id + self.config = config_obj # Validate the framework settings in config. 
self.validate_framework(self.config) @@ -484,10 +502,9 @@ def setup(self, config: PartialAlgorithmConfigDict): ope_dict = {str(ope): {"type": ope} for ope in input_evaluation} deprecation_warning( old="config.input_evaluation={}".format(input_evaluation), - new='config["evaluation_config"]' - '["off_policy_estimation_methods"]={}'.format( - ope_dict, - ), + new="config.evaluation(evaluation_config={" + f"'off_policy_estimation_methods'={ope_dict}" + "})", error=True, help="Running OPE during training is not recommended.", ) @@ -524,8 +541,8 @@ def setup(self, config: PartialAlgorithmConfigDict): self.workers = WorkerSet( env_creator=self.env_creator, validate_env=self.validate_env, - policy_class=self.get_default_policy_class(self.config), - trainer_config=self.config, + default_policy_class=self.get_default_policy_class(self.config), + config=self.config, num_workers=self.config["num_workers"], local_worker=True, logdir=self.logdir, @@ -574,67 +591,15 @@ def setup(self, config: PartialAlgorithmConfigDict): "policies" ] = self.workers.local_worker().policy_dict + # Validate evaluation config. + self.evaluation_config = self.config.get_evaluation_config_object() + self.validate_config(self.evaluation_config) + # Evaluation WorkerSet setup. # User would like to setup a separate evaluation worker set. - - # Update with evaluation settings: - user_eval_config = copy.deepcopy(self.config["evaluation_config"]) - - # Merge user-provided eval config with the base config. This makes sure - # the eval config is always complete, no matter whether we have eval - # workers or perform evaluation on the (non-eval) local worker. - eval_config = merge_dicts(self.config, user_eval_config) - self.config["evaluation_config"] = eval_config - - if self.config.get("evaluation_num_workers", 0) > 0 or self.config.get( - "evaluation_interval" - ): - logger.debug(f"Using evaluation_config: {user_eval_config}.") - - # Validate evaluation config. - self.validate_config(eval_config) - - # Set the `in_evaluation` flag. - eval_config["in_evaluation"] = True - - # Evaluation duration unit: episodes. - # Switch on `complete_episode` rollouts. Also, make sure - # rollout fragments are short so we never have more than one - # episode in one rollout. - if eval_config["evaluation_duration_unit"] == "episodes": - eval_config.update( - { - "batch_mode": "complete_episodes", - "rollout_fragment_length": 1, - } - ) - # Evaluation duration unit: timesteps. - # - Set `batch_mode=truncate_episodes` so we don't perform rollouts - # strictly along episode borders. - # Set `rollout_fragment_length` such that desired steps are divided - # equally amongst workers or - in "auto" duration mode - set it - # to a reasonably small number (10), such that a single `sample()` - # call doesn't take too much time and we can stop evaluation as soon - # as possible after the train step is completed. - else: - eval_config.update( - { - "batch_mode": "truncate_episodes", - "rollout_fragment_length": 10 - if self.config["evaluation_duration"] == "auto" - else int( - math.ceil( - self.config["evaluation_duration"] - / (self.config["evaluation_num_workers"] or 1) - ) - ), - } - ) - - self.config["evaluation_config"] = eval_config - + if self.config.evaluation_num_workers > 0 or self.config.evaluation_interval: _, env_creator = self._get_env_id_and_creator( - eval_config.get("env"), eval_config + self.evaluation_config.env, self.evaluation_config ) # Create a separate evaluation worker set for evaluation. 
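# Illustrative sketch (not part of this patch): with the `setup()` refactor above,
# evaluation settings live directly on the AlgorithmConfig, per-eval-worker overrides
# are passed as a nested `evaluation_config` dict, and `Algorithm.setup()` derives the
# complete eval config via `config.get_evaluation_config_object()`. The algorithm
# (PPO), env name, and all values below are example choices, not prescribed by the PR.
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .evaluation(
        evaluation_interval=1,        # evaluate after every `train()` call
        evaluation_num_workers=1,     # a separate evaluation WorkerSet gets created
        evaluation_duration=10,
        evaluation_duration_unit="episodes",
        # Overrides applied only to the evaluation workers:
        evaluation_config={"explore": False},
    )
)
algo = config.build()
# After setup, `algo.evaluation_config` is a full AlgorithmConfig object
# (no longer a partial override dict).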
@@ -644,8 +609,8 @@ def setup(self, config: PartialAlgorithmConfigDict): self.evaluation_workers: WorkerSet = WorkerSet( env_creator=env_creator, validate_env=None, - policy_class=self.get_default_policy_class(self.config), - trainer_config=eval_config, + default_policy_class=self.get_default_policy_class(self.config), + config=self.evaluation_config, num_workers=self.config["evaluation_num_workers"], # Don't even create a local worker if num_workers > 0. local_worker=False, @@ -681,7 +646,6 @@ def setup(self, config: PartialAlgorithmConfigDict): mod, obj = method_type.rsplit(".", 1) mod = importlib.import_module(mod) method_type = getattr(mod, obj) - if isinstance(method_type, type) and issubclass( method_type, OfflineEvaluator ): @@ -710,7 +674,10 @@ def _init(self, config: AlgorithmConfigDict, env_creator: EnvCreator) -> None: raise NotImplementedError @OverrideToImplementCustomLogic - def get_default_policy_class(self, config: AlgorithmConfigDict) -> Type[Policy]: + def get_default_policy_class( + self, + config: Union[AlgorithmConfig, AlgorithmConfigDict], + ) -> Type[Policy]: """Returns a default Policy class to use, given a config. This class will be used inside RolloutWorkers' PolicyMaps in case @@ -892,7 +859,7 @@ def evaluate( # In "auto" mode (only for parallel eval + training): Run as long # as training lasts. unit = self.config["evaluation_duration_unit"] - eval_cfg = self.config["evaluation_config"] + eval_cfg = self.evaluation_config rollout = eval_cfg["rollout_fragment_length"] num_envs = eval_cfg["num_envs_per_worker"] auto = self.config["evaluation_duration"] == "auto" @@ -1092,7 +1059,7 @@ def _evaluate_async( # In "auto" mode (only for parallel eval + training): Run as long # as training lasts. unit = self.config["evaluation_duration_unit"] - eval_cfg = self.config["evaluation_config"] + eval_cfg = self.evaluation_config rollout = eval_cfg["rollout_fragment_length"] num_envs = eval_cfg["num_envs_per_worker"] auto = self.config["evaluation_duration"] == "auto" @@ -1673,7 +1640,7 @@ def add_policy( *, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, - config: Optional[PartialAlgorithmConfigDict] = None, + config: Optional[Union[AlgorithmConfig, PartialAlgorithmConfigDict]] = None, policy_state: Optional[PolicyState] = None, policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, policies_to_train: Optional[ @@ -1703,7 +1670,7 @@ def add_policy( If None, try to infer this space from the environment. action_space: The action space of the policy to add. If None, try to infer this space from the environment. - config: The config overrides for the policy to add. + config: The config object or overrides for the policy to add. policy_state: Optional state dict to apply to the new policy instance, right after its construction. policy_mapping_fn: An optional (updated) policy mapping function @@ -2006,7 +1973,7 @@ def cleanup(self) -> None: @classmethod @override(Trainable) def default_resource_request( - cls, config: PartialAlgorithmConfigDict + cls, config: Union[AlgorithmConfig, PartialAlgorithmConfigDict] ) -> Union[Resources, PlacementGroupFactory]: # Default logic for RLlib Algorithms: @@ -2016,8 +1983,14 @@ def default_resource_request( # workers to determine their CPU/GPU resource needs. # Convenience config handles. 
- cf = dict(cls.get_default_config(), **config) - eval_cf = cf["evaluation_config"] + default_config = cls.get_default_config() + # TODO: Have to make this work for now for AlgorithmConfigs (returned by + # get_default_config(). Use only AlgorithmConfigs once all Algorithms + # return an AlgorothmConfig from their get_default_config() method. + if not isinstance(default_config, dict): + default_config = default_config.to_dict() + cf = dict(default_config, **config) + eval_cf = cf["evaluation_config"] or {} local_worker = { "CPU": cf["num_cpus_for_driver"], @@ -2069,14 +2042,14 @@ def _before_evaluate(self): @staticmethod def _get_env_id_and_creator( - env_specifier: Union[str, EnvType, None], config: PartialAlgorithmConfigDict + env_specifier: Union[str, EnvType, None], config: AlgorithmConfig ) -> Tuple[Optional[str], EnvCreator]: """Returns env_id and creator callable given original env id from config. Args: env_specifier: An env class, an already tune registered env ID, a known gym env name, or None (if no env is used). - config: The Algorithm's (maybe partial) config dict. + config: The AlgorithmConfig object. Returns: Tuple consisting of a) env ID string and b) env creator callable. @@ -2109,7 +2082,7 @@ def env_creator_from_classpath(env_context): elif isinstance(env_specifier, type): env_id = env_specifier # .__name__ - if config.get("remote_worker_envs"): + if config["remote_worker_envs"]: # Check gym version (0.22 or higher?). # If > 0.21, can't perform auto-wrapping of the given class as this # would lead to a pickle error. @@ -2183,12 +2156,15 @@ def _sync_weights_to_workers( @classmethod @override(Trainable) - def resource_help(cls, config: AlgorithmConfigDict) -> str: + def resource_help(cls, config: Union[AlgorithmConfig, AlgorithmConfigDict]) -> str: return ( - "\n\nYou can adjust the resource requests of RLlib agents by " - "setting `num_workers`, `num_gpus`, and other configs. See " - "the DEFAULT_CONFIG defined by each agent for more info.\n\n" - "The config of this agent is: {}".format(config) + "\n\nYou can adjust the resource requests of RLlib Algorithms by calling " + "`AlgorithmConfig.resources(" + "num_gpus=.., num_cpus_per_worker=.., num_gpus_per_worker=.., ..)` or " + "`AgorithmConfig.rollouts(num_rollout_workers=..)`. See " + "the `ray.rllib.algorithms.algorithm_config.AlgorithmConfig` classes " + "(each Algorithm has its own subclass of this class) for more info.\n\n" + f"The config of this Algorithm is: {config}" ) @classmethod @@ -2235,12 +2211,13 @@ def merge_trainer_configs( ) @staticmethod - def validate_framework(config: PartialAlgorithmConfigDict) -> None: - """Validates the config dictionary wrt the framework settings. + def validate_framework( + config: Union[AlgorithmConfig, PartialAlgorithmConfigDict] + ) -> None: + """Validates the config object (or dictionary) wrt. the framework settings. Args: - config: The config dictionary to be validated. - + config: The config object (or dictionary) to be validated. """ _tf1, _tf, _tfv = None, None, None _torch = None @@ -2313,47 +2290,39 @@ def resolve_tf_settings(): @OverrideToImplementCustomLogic_CallToSuperRecommended @DeveloperAPI - def validate_config(self, config: AlgorithmConfigDict) -> None: - """Validates a given config dict for this Algorithm. + def validate_config( + self, + config: Union[AlgorithmConfig, AlgorithmConfigDict], + ) -> None: + """Validates a given config object (or dictionary) for this Algorithm. Users should override this method to implement custom validation behavior. 
It is recommended to call `super().validate_config()` in this override. Args: - config: The given config dict to check. + config: The given config object (or dictionary) to check. Raises: ValueError: If there is something wrong with the config. """ - model_config = config.get("model") - if model_config is None: - config["model"] = model_config = {} - - # Use DefaultCallbacks class, if callbacks is None. - if config["callbacks"] is None: - config["callbacks"] = DefaultCallbacks - # Check, whether given `callbacks` is a callable. - if not callable(config["callbacks"]): - raise ValueError( - "`callbacks` must be a callable method that " - "returns a subclass of DefaultCallbacks, got " - f"{config['callbacks']}!" - ) + from ray.rllib.models.catalog import MODEL_DEFAULTS + + model_config = config.get("model", MODEL_DEFAULTS) # Multi-GPU settings. simple_optim_setting = config.get("simple_optimizer", DEPRECATED_VALUE) - if simple_optim_setting != DEPRECATED_VALUE: - deprecation_warning(old="simple_optimizer", error=False) - # Validate "multiagent" sub-dict and convert policy 4-tuples to - # PolicySpec objects. - policies, is_multi_agent = check_multi_agent(config) + framework = config.get("framework", "tf") - framework = config.get("framework") + if simple_optim_setting is True: + pass # Multi-GPU setting: Must use MultiGPUTrainOneStep. - if config.get("num_gpus", 0) > 1: - if framework in ["tfe", "tf2"]: + elif config.get("num_gpus", 0) > 1: + # TODO: AlphaStar uses >1 GPUs differently (1 per policy actor), so this is + # ok for tf2 here. + # Remove this hacky check, once we have fully moved to the RLTrainer API. + if framework in ["tfe", "tf2"] and type(self).__name__ != "AlphaStar": raise ValueError( "`num_gpus` > 1 not supported yet for " "framework={}!".format(framework) @@ -2373,18 +2342,32 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: config["simple_optimizer"] = True # Multi-agent case: Try using MultiGPU optimizer (only # if all policies used are DynamicTFPolicies or TorchPolicies). - elif is_multi_agent: + elif ( + (isinstance(config, AlgorithmConfig) and config.is_multi_agent()) + or isinstance(config, dict) + and AlgorithmConfig.from_dict(config).is_multi_agent() + ): from ray.rllib.policy.dynamic_tf_policy import DynamicTFPolicy from ray.rllib.policy.torch_policy import TorchPolicy default_policy_cls = self.get_default_policy_class(config) + policies = config["multiagent"]["policies"] + policy_specs = ( + [ + PolicySpec(*spec) if isinstance(spec, (tuple, list)) else spec + for spec in policies.values() + ] + if isinstance(policies, dict) + else [PolicySpec() for _ in policies] + ) + if any( - (p.policy_class or default_policy_cls) is None + (spec.policy_class or default_policy_cls) is None or not issubclass( - p.policy_class or default_policy_cls, + spec.policy_class or default_policy_cls, (DynamicTFPolicy, TorchPolicy), ) - for p in config["multiagent"]["policies"].values() + for spec in policy_specs ): config["simple_optimizer"] = True else: @@ -2397,7 +2380,7 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: if framework in ["tfe", "tf2"]: raise ValueError( "`simple_optimizer=False` not supported for " - "framework={}!".format(framework) + "config.framework({})!".format(framework) ) # Check model config. 
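# Illustrative sketch (not part of this patch): the multi-agent validation above accepts
# `policies` entries either as PolicySpec objects or as legacy 4-tuples of
# (policy_cls_or_None, obs_space, act_space, config); None fields are inferred later via
# `get_multi_agent_setup()`. The env name, agent IDs, and policy IDs below are assumed
# example values only.
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

config = (
    PPOConfig()
    .environment("my_multi_agent_env")  # assumed: an already registered MultiAgentEnv
    .multi_agent(
        policies={
            "learned": PolicySpec(),                      # everything inferred
            "scripted": (None, None, None, {"lr": 0.0}),  # legacy 4-tuple form
        },
        # New mapping-fn signature: (agent_id, episode, worker, **kwargs) -> PolicyID
        policy_mapping_fn=lambda agent_id, episode, worker, **kw: (
            "learned" if agent_id == "agent_0" else "scripted"
        ),
        policies_to_train=["learned"],
    )
)
assert config.is_multi_agent()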
@@ -2419,155 +2402,12 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: "model.lstm_use_prev_action and model.lstm_use_prev_reward", error=True, ) - model_config["lstm_use_prev_action"] = prev_a_r - model_config["lstm_use_prev_reward"] = prev_a_r - - # Check batching/sample collection settings. - if config["batch_mode"] not in ["truncate_episodes", "complete_episodes"]: - raise ValueError( - "`batch_mode` must be one of [truncate_episodes|" - "complete_episodes]! Got {}".format(config["batch_mode"]) - ) # Store multi-agent batch count mode. self._by_agent_steps = ( self.config["multiagent"].get("count_steps_by") == "agent_steps" ) - # Metrics settings. - if ( - config.get("metrics_smoothing_episodes", DEPRECATED_VALUE) - != DEPRECATED_VALUE - ): - deprecation_warning( - old="metrics_smoothing_episodes", - new="metrics_num_episodes_for_smoothing", - error=True, - ) - config["metrics_num_episodes_for_smoothing"] = config[ - "metrics_smoothing_episodes" - ] - if config.get("min_iter_time_s", DEPRECATED_VALUE) != DEPRECATED_VALUE: - deprecation_warning( - old="min_iter_time_s", - new="min_time_s_per_iteration", - error=True, - ) - config["min_time_s_per_iteration"] = config["min_iter_time_s"] or 0 - - if config.get("min_time_s_per_reporting", DEPRECATED_VALUE) != DEPRECATED_VALUE: - deprecation_warning( - old="min_time_s_per_reporting", - new="min_time_s_per_iteration", - error=True, - ) - config["min_time_s_per_iteration"] = config["min_time_s_per_reporting"] or 0 - - if ( - config.get("min_sample_timesteps_per_reporting", DEPRECATED_VALUE) - != DEPRECATED_VALUE - ): - deprecation_warning( - old="min_sample_timesteps_per_reporting", - new="min_sample_timesteps_per_iteration", - error=True, - ) - config["min_sample_timesteps_per_iteration"] = ( - config["min_sample_timesteps_per_reporting"] or 0 - ) - - if ( - config.get("min_train_timesteps_per_reporting", DEPRECATED_VALUE) - != DEPRECATED_VALUE - ): - deprecation_warning( - old="min_train_timesteps_per_reporting", - new="min_train_timesteps_per_iteration", - error=True, - ) - config["min_train_timesteps_per_iteration"] = ( - config["min_train_timesteps_per_reporting"] or 0 - ) - - if config.get("collect_metrics_timeout", DEPRECATED_VALUE) != DEPRECATED_VALUE: - # TODO: Warn once all algos use the `training_iteration` method. - # deprecation_warning( - # old="collect_metrics_timeout", - # new="metrics_episode_collection_timeout_s", - # error=False, - # ) - config["metrics_episode_collection_timeout_s"] = config[ - "collect_metrics_timeout" - ] - - if config.get("timesteps_per_iteration", DEPRECATED_VALUE) != DEPRECATED_VALUE: - deprecation_warning( - old="timesteps_per_iteration", - new="`min_sample_timesteps_per_iteration` OR " - "`min_train_timesteps_per_iteration`", - error=True, - ) - config["min_sample_timesteps_per_iteration"] = ( - config["timesteps_per_iteration"] or 0 - ) - config["timesteps_per_iteration"] = DEPRECATED_VALUE - - # Evaluation settings. - - # Deprecated setting: `evaluation_num_episodes`. - if config.get("evaluation_num_episodes", DEPRECATED_VALUE) != DEPRECATED_VALUE: - deprecation_warning( - old="evaluation_num_episodes", - new="`evaluation_duration` and `evaluation_duration_unit=episodes`", - error=True, - ) - config["evaluation_duration"] = config["evaluation_num_episodes"] - config["evaluation_duration_unit"] = "episodes" - config["evaluation_num_episodes"] = DEPRECATED_VALUE - - # If `evaluation_num_workers` > 0, warn if `evaluation_interval` is - # None (also set `evaluation_interval` to 1). 
- if config["evaluation_num_workers"] > 0 and not config["evaluation_interval"]: - logger.warning( - f"You have specified {config['evaluation_num_workers']} " - "evaluation workers, but your `evaluation_interval` is None! " - "Therefore, evaluation will not occur automatically with each" - " call to `Algorithm.train()`. Instead, you will have to call " - "`Algorithm.evaluate()` manually in order to trigger an " - "evaluation run." - ) - # If `evaluation_num_workers=0` and - # `evaluation_parallel_to_training=True`, warn that you need - # at least one remote eval worker for parallel training and - # evaluation, and set `evaluation_parallel_to_training` to False. - elif config["evaluation_num_workers"] == 0 and config.get( - "evaluation_parallel_to_training", False - ): - logger.warning( - "`evaluation_parallel_to_training` can only be done if " - "`evaluation_num_workers` > 0! Setting " - "`evaluation_parallel_to_training` to False." - ) - config["evaluation_parallel_to_training"] = False - - # If `evaluation_duration=auto`, error if - # `evaluation_parallel_to_training=False`. - if config["evaluation_duration"] == "auto": - if not config["evaluation_parallel_to_training"]: - raise ValueError( - "`evaluation_duration=auto` not supported for " - "`evaluation_parallel_to_training=False`!" - ) - # Make sure, it's an int otherwise. - elif ( - not isinstance(config["evaluation_duration"], int) - or config["evaluation_duration"] <= 0 - ): - raise ValueError( - "`evaluation_duration` ({}) must be an int and " - ">0!".format(config["evaluation_duration"]) - ) - @staticmethod @ExperimentalAPI def validate_env(env: EnvType, env_context: EnvContext) -> None: @@ -2854,11 +2694,19 @@ def _checkpoint_info_to_algorithm_state( if pid in policy_ids } # Remove policies from multiagent dict that are not in `policy_ids`. - policies_dict = state["config"]["multiagent"]["policies"] - policies_dict = { - pid: spec for pid, spec in policies_dict.items() if pid in policy_ids - } - state["config"]["multiagent"]["policies"] = policies_dict + new_config = AlgorithmConfig.from_dict(state["config"]) + new_policies = new_config.policies + if isinstance(new_policies, (set, list, tuple)): + new_policies = {pid for pid in new_policies if pid in policy_ids} + else: + new_policies = { + pid: spec for pid, spec in new_policies.items() if pid in policy_ids + } + new_config.multi_agent( + policies=new_policies, + policies_to_train=policies_to_train, + ) + state["config"] = new_config.to_dict() # Prepare local `worker` state to add policies' states into it, # read from separate policy checkpoint files. @@ -3002,7 +2850,7 @@ def _run_one_evaluation( self._automatic_evaluation_duration_fn, unit, self.config["evaluation_num_workers"], - self.config["evaluation_config"], + self.evaluation_config, train_future, ) ) @@ -3015,10 +2863,8 @@ def _run_one_evaluation( num_recreated = self.try_recover_from_step_attempt( error=e, worker_set=self.evaluation_workers, - ignore=self.config["evaluation_config"].get("ignore_worker_failures"), - recreate=self.config["evaluation_config"].get( - "recreate_failed_workers" - ), + ignore=self.evaluation_config.get("ignore_worker_failures"), + recreate=self.evaluation_config.get("recreate_failed_workers"), ) # `self._evaluate_async` handles its own worker failures and already adds # this metric, but `self.evaluate` doesn't. 
@@ -3214,8 +3060,8 @@ def _make_workers( return WorkerSet( env_creator=env_creator, validate_env=validate_env, - policy_class=policy_class, - trainer_config=config, + default_policy_class=policy_class, + config=config, num_workers=num_workers, local_worker=local_worker, logdir=self.logdir, diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 421f5eb232cc8..f82bd772fbf29 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -1,27 +1,41 @@ import copy -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, Union - import gym +from gym.spaces import Space +import logging +import math +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union +import ray +from ray.util import log_once from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.evaluation.collectors.sample_collector import SampleCollector from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector from ray.rllib.models import MODEL_DEFAULTS +from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils import deep_update, merge_dicts from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.policy import validate_policy_id from ray.rllib.utils.typing import ( AlgorithmConfigDict, EnvConfigDict, EnvType, + MultiAgentPolicyConfigDict, PartialAlgorithmConfigDict, + PolicyID, ResultDict, + SampleBatchType, ) from ray.tune.logger import Logger if TYPE_CHECKING: from ray.rllib.algorithms.algorithm import Algorithm +logger = logging.getLogger(__name__) + class AlgorithmConfig: """A RLlib AlgorithmConfig builds an RLlib Algorithm from a given configuration. @@ -37,7 +51,7 @@ class AlgorithmConfig: ... .rollouts(num_rollout_workers=4) ... .callbacks(MemoryTrackingCallbacks) >>> # A config object can be used to construct the respective Trainer. 
- >>> rllib_trainer = config.build() + >>> rllib_algo = config.build() Example: >>> from ray.rllib.algorithms.algorithm_config import AlgorithmConfig @@ -154,10 +168,13 @@ def __init__(self, algo_class=None): } # `self.multi_agent()` - self.policies = {} + self._is_multi_agent = False + self.policies = {DEFAULT_POLICY_ID: PolicySpec()} self.policy_map_capacity = 100 self.policy_map_cache = None - self.policy_mapping_fn = None + self.policy_mapping_fn = ( + lambda aid, episode, worker, **kwargs: DEFAULT_POLICY_ID + ) self.policies_to_train = None self.observation_fn = None self.count_steps_by = "env_steps" @@ -172,6 +189,7 @@ def __init__(self, algo_class=None): self.output_config = {} self.output_compress_columns = ["obs", "new_obs"] self.output_max_file_size = 64 * 1024 * 1024 + self.offline_sampling = False # `self.evaluation()` self.evaluation_interval = None @@ -179,7 +197,7 @@ def __init__(self, algo_class=None): self.evaluation_duration_unit = "episodes" self.evaluation_sample_timeout_s = 180.0 self.evaluation_parallel_to_training = False - self.evaluation_config = {} + self.evaluation_config = None self.off_policy_estimation_methods = {} self.ope_split_batch_by_episode = True self.evaluation_num_workers = 0 @@ -216,6 +234,9 @@ def __init__(self, algo_class=None): self._disable_action_flattening = False self._disable_execution_plan_api = True + # Has this config object been frozen (cannot alter its attributes anymore). + self._is_frozen = False + # TODO: Remove, once all deprecation_warning calls upon using these keys # have been removed. # === Deprecated keys === @@ -251,6 +272,7 @@ def to_dict(self) -> AlgorithmConfigDict: """ config = copy.deepcopy(vars(self)) config.pop("algo_class") + config.pop("_is_frozen") # Worst naming convention ever: NEVER EVER use reserved key-words... if "lambda_" in config: @@ -264,15 +286,7 @@ def to_dict(self) -> AlgorithmConfigDict: # Setup legacy multi-agent sub-dict: config["multiagent"] = {} - for k in [ - "policies", - "policy_map_capacity", - "policy_map_cache", - "policy_mapping_fn", - "policies_to_train", - "observation_fn", - "count_steps_by", - ]: + for k in self.multiagent.keys(): config["multiagent"][k] = config.pop(k) # Switch out deprecated vs new config keys. @@ -282,14 +296,154 @@ def to_dict(self) -> AlgorithmConfigDict: config["framework"] = config.pop("framework_str", None) config["num_cpus_for_driver"] = config.pop("num_cpus_for_local_worker", 1) + for dep_k in [ + "monitor", + "evaluation_num_episodes", + "metrics_smoothing_episodes", + "timesteps_per_iteration", + "min_iter_time_s", + "collect_metrics_timeout", + "buffer_size", + "prioritized_replay", + "learning_starts", + "replay_batch_size", + "replay_mode", + "prioritized_replay_alpha", + "prioritized_replay_beta", + "prioritized_replay_eps", + "min_time_s_per_reporting", + "min_train_timesteps_per_reporting", + "min_sample_timesteps_per_reporting", + "input_evaluation", + ]: + if config.get(dep_k) == DEPRECATED_VALUE: + config.pop(dep_k, None) + return config + @classmethod + def from_dict(cls, config_dict: dict) -> "AlgorithmConfig": + """Creates an AlgorithmConfig from a legacy python config dict. + + Examples: + >>> from ray.rllib.algorithms.ppo.ppo import DEFAULT_CONFIG, PPOConfig + >>> ppo_config = PPOConfig.from_dict(DEFAULT_CONFIG) + >>> ppo = ppo_config.build(env="Pendulum-v1") + + Args: + config_dict: The legacy formatted python config dict for some algorithm. + + Returns: + A new AlgorithmConfig object that matches the given python config dict. 
+ """ + # Create a default config object of this class. + config_obj = cls() + # Remove `_is_frozen` flag from config dict in case the AlgorithmConfig that + # the dict was derived from was already frozen (we don't want to copy the + # frozenness). + config_dict.pop("_is_frozen", None) + config_obj.update_from_dict(config_dict) + return config_obj + + def update_from_dict( + self, + config_dict: PartialAlgorithmConfigDict, + ) -> "AlgorithmConfig": + """Modifies this AlgorithmConfig via the provided python config dict. + + Warns if `config_dict` contains deprecated keys. + Silently sets even properties of `self` that do NOT exist. This way, this method + may be used to configure custom Policies which do not have their own specific + AlgorithmConfig classes, e.g. + `ray.rllib.examples.policy.random_policy::RandomPolicy`. + + Args: + config_dict: The old-style python config dict (PartialAlgorithmConfigDict) + to use for overriding some properties defined in there. + + Returns: + This updated AlgorithmConfig object. + """ + # Modify our properties one by one. + for key, value in config_dict.items(): + key = self._translate_special_keys(key, warn_deprecated=False) + + # Set our multi-agent settings. + if key == "multiagent": + kwargs = { + k: value[k] + for k in [ + "policies", + "policy_map_capacity", + "policy_map_cache", + "policy_mapping_fn", + "policies_to_train", + "observation_fn", + "count_steps_by", + ] + if k in value + } + self.multi_agent(**kwargs) + # Some keys must use `.update()` from given config dict (to not lose + # any sub-keys). + elif key == "callbacks_class": + self.callbacks(callbacks_class=value) + elif key == "env_config": + self.environment(env_config=value) + elif key == "model": + self.training(model=value) + # If config key matches a property, just set it, otherwise, warn and set. + else: + if not hasattr(self, key) and log_once( + "unknown_property_in_algo_config" + ): + logger.warning( + f"Cannot create {type(self).__name__} from given " + f"`config_dict`! Property {key} not supported." + ) + setattr(self, key, value) + + return self + + def copy(self, copy_frozen: Optional[bool] = None) -> "AlgorithmConfig": + """Creates a deep copy of this config and (un)freezes if necessary. + + Args: + copy_frozen: Whether the created deep copy will be frozen or not. If None, + keep the same frozen status that `self` currently has. + + Returns: + A deep copy of `self` that is (un)frozen. + """ + cp = copy.deepcopy(self) + if copy_frozen is True: + cp.freeze() + elif copy_frozen is False: + cp._is_frozen = False + if isinstance(cp.evaluation_config, AlgorithmConfig): + cp.evaluation_config._is_frozen = False + return cp + + def freeze(self) -> None: + """Freezes this config object, such that no attributes can be set anymore. + + Algorithms should use this method to make sure that their config objects + remain read-only after this. + """ + if self._is_frozen: + return + self._is_frozen = True + # Also freeze underlying eval config, if applicable. + if isinstance(self.evaluation_config, AlgorithmConfig): + self.evaluation_config.freeze() + def build( self, env: Optional[Union[str, EnvType]] = None, logger_creator: Optional[Callable[[], Logger]] = None, + use_copy: bool = True, ) -> "Algorithm": - """Builds an Algorithm from the AlgorithmConfig. + """Builds an Algorithm from this AlgorithmConfig (or a copy thereof). Args: env: Name of the environment to use (e.g. a gym-registered str), @@ -299,6 +453,10 @@ class directly. 
Note that this arg can also be specified via the "env" key in `config`. logger_creator: Callable that creates a ray.tune.Logger object. If unspecified, a default logger is created. + use_copy: Whether to deepcopy `self` and pass the copy to the Algorithm + (instead of `self`) as config. This is useful in case you would like to + recycle the same AlgorithmConfig over and over, e.g. in a test case, in + which we loop over different DL-frameworks. Returns: A ray.rllib.algorithms.algorithm.Algorithm object. @@ -311,8 +469,7 @@ class directly. Note that this arg can also be specified via self.logger_creator = logger_creator return self.algo_class( - config=self.to_dict(), - env=self.env, + config=self if not use_copy else copy.deepcopy(self), logger_creator=self.logger_creator, ) @@ -513,7 +670,11 @@ def environment( if env is not None: self.env = env if env_config is not None: - self.env_config = env_config + deep_update( + self.env_config, + env_config, + True, + ) if observation_space is not None: self.observation_space = observation_space if action_space is not None: @@ -603,15 +764,23 @@ def rollouts( batch_mode: How to build per-Sampler (RolloutWorker) batches, which are then usually concat'd to form the train batch. Note that "steps" below can mean different things (either env- or agent-steps) and depends on the - `count_steps_by` (multiagent) setting below. - "truncate_episodes": Each produced batch (when calling - RolloutWorker.sample()) will contain exactly `rollout_fragment_length` - steps. This mode guarantees evenly sized batches, but increases + `count_steps_by` setting, adjustable via + `AlgorithmConfig.multi_agent(count_steps_by=..)`: + 1) "truncate_episodes": Each call to sample() will return a + batch of at most `rollout_fragment_length * num_envs_per_worker` in + size. The batch will be exactly `rollout_fragment_length * num_envs` + in size if postprocessing does not change batch sizes. Episodes + may be truncated in order to meet this size requirement. + This mode guarantees evenly sized batches, but increases variance as the future return must now be estimated at truncation boundaries. - "complete_episodes": Each unroll happens exactly over one episode, from - beginning to end. Data collection will not stop unless the episode - terminates or a configured horizon (hard or soft) is hit. + 2) "complete_episodes": Each call to sample() will return a + batch of at least `rollout_fragment_length * num_envs_per_worker` in + size. Episodes will not be truncated, but multiple episodes + may be packed within one batch to meet the (minimum) batch size. + Note that when `num_envs_per_worker > 1`, episode steps will be buffered + until the episode completes, and hence batches may contain + significant amounts of off-policy data. remote_worker_envs: If using num_envs_per_worker > 1, whether to create those new envs in remote processes instead of in the same worker. This adds overheads, but can make sense if your envs can take much @@ -696,8 +865,16 @@ def rollouts( self.enable_connectors = enable_connectors if rollout_fragment_length is not None: self.rollout_fragment_length = rollout_fragment_length + + # Check batching/sample collection settings. if batch_mode is not None: + if batch_mode not in ["truncate_episodes", "complete_episodes"]: + raise ValueError( + "`config.batch_mode` must be one of [truncate_episodes|" + "complete_episodes]! 
Got {}".format(batch_mode) + ) self.batch_mode = batch_mode + if remote_worker_envs is not None: self.remote_worker_envs = remote_worker_envs if remote_env_batch_wait_ms is not None: @@ -754,6 +931,7 @@ def training( train_batch_size: Training batch size, if applicable. model: Arguments passed into the policy model. See models/catalog.py for a full list of the available model options. + TODO: Provide ModelConfig objects instead of dicts. optimizer: Arguments to pass to the policy optimizer. Returns: @@ -766,7 +944,7 @@ def training( if train_batch_size is not None: self.train_batch_size = train_batch_size if model is not None: - self.model = model + self.model.update(model) if optimizer is not None: self.optimizer = merge_dicts(self.optimizer, optimizer) @@ -784,6 +962,15 @@ def callbacks(self, callbacks_class) -> "AlgorithmConfig": Returns: This updated AlgorithmConfig object. """ + if callbacks_class is None: + callbacks_class = DefaultCallbacks + # Check, whether given `callbacks` is a callable. + if not callable(callbacks_class): + raise ValueError( + "`config.callbacks_class` must be a callable method that " + "returns a subclass of DefaultCallbacks, got " + f"{callbacks_class}!" + ) self.callbacks_class = callbacks_class return self @@ -926,11 +1113,16 @@ def evaluation( if evaluation_parallel_to_training is not None: self.evaluation_parallel_to_training = evaluation_parallel_to_training if evaluation_config is not None: - # Convert another AlgorithmConfig into dict. - if isinstance(evaluation_config, AlgorithmConfig): - self.evaluation_config = evaluation_config.to_dict() - else: - self.evaluation_config = evaluation_config + from ray.rllib.algorithms.algorithm import Algorithm + + self.evaluation_config = deep_update( + self.evaluation_config or {}, + evaluation_config, + True, + Algorithm._allow_unknown_subkeys, + Algorithm._override_all_subkeys_if_type_changes, + Algorithm._override_all_key_list, + ) if off_policy_estimation_methods is not None: self.off_policy_estimation_methods = off_policy_estimation_methods if evaluation_num_workers is not None: @@ -944,6 +1136,46 @@ def evaluation( if ope_split_batch_by_episode is not None: self.ope_split_batch_by_episode = ope_split_batch_by_episode + # If `evaluation_num_workers` > 0, warn if `evaluation_interval` is + # None (also set `evaluation_interval` to 1). + if self.evaluation_num_workers > 0 and not self.evaluation_interval: + logger.warning( + f"You have specified {self.evaluation_num_workers} " + "evaluation workers, but your `evaluation_interval` is None! " + "Therefore, evaluation will not occur automatically with each" + " call to `Algorithm.train()`. Instead, you will have to call " + "`Algorithm.evaluate()` manually in order to trigger an " + "evaluation run." + ) + # If `evaluation_num_workers=0` and + # `evaluation_parallel_to_training=True`, warn that you need + # at least one remote eval worker for parallel training and + # evaluation, and set `evaluation_parallel_to_training` to False. + elif self.evaluation_num_workers == 0 and self.evaluation_parallel_to_training: + raise ValueError( + "`evaluation_parallel_to_training` can only be done if " + "`evaluation_num_workers` > 0! Try setting " + "`config.evaluation_parallel_to_training` to False." + ) + + # If `evaluation_duration=auto`, error if + # `evaluation_parallel_to_training=False`. 
+ if self.evaluation_duration == "auto": + if not self.evaluation_parallel_to_training: + raise ValueError( + "`evaluation_duration=auto` not supported for " + "`evaluation_parallel_to_training=False`!" + ) + # Make sure, it's an int otherwise. + elif ( + not isinstance(self.evaluation_duration, int) + or self.evaluation_duration <= 0 + ): + raise ValueError( + f"`evaluation_duration` ({self.evaluation_duration}) must be an " + f"int and >0!" + ) + return self def offline_data( @@ -959,6 +1191,7 @@ def offline_data( output_config=None, output_compress_columns=None, output_max_file_size=None, + offline_sampling=None, ) -> "AlgorithmConfig": """Sets the config's offline data settings. @@ -1016,6 +1249,11 @@ def offline_data( output data. output_max_file_size: Max output file size before rolling over to a new file. + offline_sampling: Whether sampling for the Algorithm happens via + reading from offline data. If True, RolloutWorkers will NOT limit the + number of collected batches within the same `sample()` call based on + the number of sub-environments within the worker (no sub-environments + present). Returns: This updated AlgorithmConfig object. @@ -1047,6 +1285,8 @@ def offline_data( self.output_compress_columns = output_compress_columns if output_max_file_size is not None: self.output_max_file_size = output_max_file_size + if offline_sampling is not None: + self.offline_sampling = offline_sampling return self @@ -1064,16 +1304,22 @@ def multi_agent( ) -> "AlgorithmConfig": """Sets the config's multi-agent settings. + Validates the new multi-agent settings and translates everything into + a unified multi-agent setup format. For example a `policies` list or set + of IDs is properly converted into a dict mapping these IDs to PolicySpecs. + Args: - policies: Map of type MultiAgentPolicyConfigDict from policy ids to tuples - of (policy_cls, obs_space, act_space, config). This defines the - observation and action spaces of the policies and any extra config. + policies: Map of type MultiAgentPolicyConfigDict from policy ids to either + 4-tuples of (policy_cls, obs_space, act_space, config) or PolicySpecs. + These tuples or PolicySpecs define the class of the policy, the + observation- and action spaces of the policies, and any extra config. policy_map_capacity: Keep this many policies in the "policy_map" (before writing least-recently used ones to disk/S3). policy_map_cache: Where to store overflowing (least-recently used) policies? Could be a directory (str) or an S3 location. None for using the default output dir. - policy_mapping_fn: Function mapping agent ids to policy ids. + policy_mapping_fn: Function mapping agent ids to policy ids. The signature + is: (agent_id, episode, worker, **kwargs) -> PolicyID. policies_to_train: Determines those policies that should be updated. Options are: - None, for all policies. @@ -1097,17 +1343,54 @@ def multi_agent( This updated AlgorithmConfig object. """ if policies is not None: + # Make sure our Policy IDs are ok (this should work whether `policies` + # is a dict or just any Sequence). + for pid in policies: + validate_policy_id(pid, error=False) + # Policy IDs must be strings. + if not isinstance(pid, str): + raise KeyError( + f"Policy IDs must always be of type `str`, got {type(pid)}" + ) + # Validate each policy spec in a given dict. + if isinstance(policies, dict): + for pid, spec in policies.items(): + # If not a PolicySpec object, values must be lists/tuples of len 4. 
+ if not isinstance(spec, PolicySpec): + if not isinstance(spec, (list, tuple)) or len(spec) != 4: + raise ValueError( + "Policy specs must be tuples/lists of " + "(cls or None, obs_space, action_space, config), " + f"got {spec} for PolicyID={pid}" + ) + # TODO: Switch from dict to AlgorithmConfigOverride, once available. + # Config not a dict. + elif ( + not isinstance(spec.config, (AlgorithmConfig, dict)) + and spec.config is not None + ): + raise ValueError( + f"Multi-agent policy config for {pid} must be a dict or " + f"AlgorithmConfig object, but got {type(spec.config)}!" + ) self.policies = policies + if policy_map_capacity is not None: self.policy_map_capacity = policy_map_capacity + if policy_map_cache is not None: self.policy_map_cache = policy_map_cache + if policy_mapping_fn is not None: + # Attempt to create a `policy_mapping_fn` from config dict. Helpful + # is users would like to specify custom callable classes in yaml files. + if isinstance(policy_mapping_fn, dict): + policy_mapping_fn = from_config(policy_mapping_fn) self.policy_mapping_fn = policy_mapping_fn - if policies_to_train is not None: - self.policies_to_train = policies_to_train + if observation_fn is not None: self.observation_fn = observation_fn + if replay_mode != DEPRECATED_VALUE: deprecation_warning( old="AlgorithmConfig.multi_agent(replay_mode=..)", @@ -1115,11 +1398,57 @@ def multi_agent( "replay_buffer_config={'replay_mode': ..})", error=True, ) + if count_steps_by is not None: + if count_steps_by not in ["env_steps", "agent_steps"]: + raise ValueError( + "config.multi_agent(count_steps_by=..) must be one of " + f"[env_steps|agent_steps], not {count_steps_by}!" + ) self.count_steps_by = count_steps_by + if policies_to_train is not None: + assert isinstance(policies_to_train, (list, set, tuple)) or callable( + policies_to_train + ), ( + "ERROR: `policies_to_train`must be a [list|set|tuple] or a " + "callable taking PolicyID and SampleBatch and returning " + "True|False (trainable or not?)." + ) + # Check `policies_to_train` for invalid entries. + if isinstance(policies_to_train, (list, set, tuple)): + if len(policies_to_train) == 0: + logger.warning( + "`config.multi_agent(policies_to_train=..)` is empty! " + "Make sure - if you would like to learn at least one policy - " + "to add its ID to that list." + ) + for pid in policies_to_train: + if pid not in self.policies: + raise ValueError( + "`config.multi_agent(policies_to_train=..)` contains " + f"policy ID ({pid}) that was not defined in " + f"`config.multi_agent(policies=..)`!" + ) + self.policies_to_train = policies_to_train + + # Is this a multi-agent setup? True, iff DEFAULT_POLICY_ID is only + # PolicyID found in policies dict. + self._is_multi_agent = ( + len(self.policies) > 1 or DEFAULT_POLICY_ID not in self.policies + ) + return self + def is_multi_agent(self) -> bool: + """Returns whether this config specifies a multi-agent setup. + + Returns: + True, if a) >1 policies defined OR b) 1 policy defined, but its ID is NOT + DEFAULT_POLICY_ID. + """ + return self._is_multi_agent + def reporting( self, *, @@ -1307,3 +1636,379 @@ def experimental( self._disable_execution_plan_api = _disable_execution_plan_api return self + + def get_evaluation_config_object( + self, + ) -> Optional["AlgorithmConfig"]: + """Creates a full AlgorithmConfig object from `self.evaluation_config`. + + Returns: + A fully valid AlgorithmConfig object that can be used for the evaluation + WorkerSet. If `self` is already an evaluation config object, return None. 
+ """ + if self.in_evaluation: + assert self.evaluation_config is None + return None + + # Convert AlgorithmConfig into dict (for later updating from dict). + evaluation_config = self.evaluation_config + if isinstance(evaluation_config, AlgorithmConfig): + evaluation_config = evaluation_config.to_dict() + + # Create unfrozen copy of self to be used as the to-be-returned eval + # AlgorithmConfig. + eval_config_obj = self.copy(copy_frozen=False) + # Switch on the `in_evaluation` flag and remove `evaluation_config` + # (set to None). + eval_config_obj.in_evaluation = True + eval_config_obj.evaluation_config = None + # Update with evaluation settings: + eval_config_obj.update_from_dict(evaluation_config or {}) + + # Evaluation duration unit: episodes. + # Switch on `complete_episode` rollouts. Also, make sure + # rollout fragments are short so we never have more than one + # episode in one rollout. + if self.evaluation_duration_unit == "episodes": + eval_config_obj.batch_mode = "complete_episodes" + eval_config_obj.rollout_fragment_length = 1 + # Evaluation duration unit: timesteps. + # - Set `batch_mode=truncate_episodes` so we don't perform rollouts + # strictly along episode borders. + # Set `rollout_fragment_length` such that desired steps are divided + # equally amongst workers or - in "auto" duration mode - set it + # to a reasonably small number (10), such that a single `sample()` + # call doesn't take too much time and we can stop evaluation as soon + # as possible after the train step is completed. + else: + eval_config_obj.batch_mode = "truncate_episodes" + eval_config_obj.rollout_fragment_length = ( + 10 + if self.evaluation_duration == "auto" + else int( + math.ceil( + self.evaluation_duration / (self.evaluation_num_workers or 1) + ) + ) + ) + + return eval_config_obj + + def get_multi_agent_setup( + self, + *, + policies: Optional[MultiAgentPolicyConfigDict] = None, + env: Optional[EnvType] = None, + spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, + default_policy_class: Optional[Type[Policy]] = None, + ) -> Tuple[MultiAgentPolicyConfigDict, Callable[[PolicyID, SampleBatchType], bool]]: + """Infers the observation- and action spaces in a multi-agent policy dict. + + Args: + policies: The multi-agent `policies` dict mapping policy IDs + to PolicySpec objects. Note that the `policy_class`, + `observation_space`, and `action_space` properties in these PolicySpecs + may be None and must therefore be inferred here. + env: An optional env instance, from which to infer the different spaces for + the different policies. + spaces: Optional dict mapping policy IDs to tuples of 1) observation space + and 2) action space that should be used for the respective policy. + These spaces were usually provided by an already instantiated remote + worker. + default_policy_class: The Policy class to use should a PolicySpec have its + policy_class property set to None. + + Returns: + A tuple consisting of 1) a MultiAgentPolicyConfigDict and 2) a + `is_policy_to_train(PolicyID, SampleBatchType) -> bool` callable. + """ + policies = copy.deepcopy(policies or self.policies) + + # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy + # automatically via empty PolicySpec (will make RLlib infer observation- and + # action spaces as well as the Policy's class). + if isinstance(policies, (set, list, tuple)): + policies = {pid: PolicySpec() for pid in policies} + + # Try extracting spaces from env or from given spaces dict. 
+ env_obs_space = None + env_act_space = None + + # Env is a ray.remote: Get spaces via its (automatically added) + # `_get_spaces()` method. + if isinstance(env, ray.actor.ActorHandle): + env_obs_space, env_act_space = ray.get(env._get_spaces.remote()) + # Normal env (gym.Env or MultiAgentEnv): These should have the + # `observation_space` and `action_space` properties. + elif env is not None: + if hasattr(env, "observation_space") and isinstance( + env.observation_space, gym.Space + ): + env_obs_space = env.observation_space + + if hasattr(env, "action_space") and isinstance(env.action_space, gym.Space): + env_act_space = env.action_space + # Last resort: Try getting the env's spaces from the spaces + # dict's special __env__ key. + if spaces is not None: + if env_obs_space is None: + env_obs_space = spaces.get("__env__", [None])[0] + if env_act_space is None: + env_act_space = spaces.get("__env__", [None, None])[1] + + # Check each defined policy ID and unify its spec. + for pid, policy_spec in policies.copy().items(): + # Convert to PolicySpec if plain list/tuple. + if not isinstance(policy_spec, PolicySpec): + policies[pid] = policy_spec = PolicySpec(*policy_spec) + + # Infer policy classes for policies dict, if not provided (None). + if policy_spec.policy_class is None and default_policy_class is not None: + policies[pid].policy_class = default_policy_class + + # Infer observation space. + if policy_spec.observation_space is None: + if spaces is not None and pid in spaces: + obs_space = spaces[pid][0] + elif env_obs_space is not None: + # Multi-agent case AND different agents have different spaces: + # Need to reverse map spaces (for the different agents) to certain + # policy IDs. + if ( + isinstance(env, MultiAgentEnv) + and hasattr(env, "_spaces_in_preferred_format") + and env._spaces_in_preferred_format + ): + obs_space = None + mapping_fn = self.policy_mapping_fn + if mapping_fn: + for aid in env.get_agent_ids(): + # Match: Assign spaces for this agentID to the PolicyID. + if mapping_fn(aid, None, None) == pid: + # Make sure, different agents that map to the same + # policy don't have different spaces. + if ( + obs_space is not None + and env_obs_space[aid] != obs_space + ): + raise ValueError( + "Two agents in your environment map to the " + "same policyID (as per your `policy_mapping" + "_fn`), however, these agents also have " + "different observation spaces!" + ) + obs_space = env_obs_space[aid] + # Otherwise, just use env's obs space as-is. + else: + obs_space = env_obs_space + # Space given directly in config. + elif self.observation_space: + obs_space = self.observation_space + else: + raise ValueError( + "`observation_space` not provided in PolicySpec for " + f"{pid} and env does not have an observation space OR " + "no spaces received from other workers' env(s) OR no " + "`observation_space` specified in config!" + ) + + policies[pid].observation_space = obs_space + + # Infer action space. + if policy_spec.action_space is None: + if spaces is not None and pid in spaces: + act_space = spaces[pid][1] + elif env_act_space is not None: + # Multi-agent case AND different agents have different spaces: + # Need to reverse map spaces (for the different agents) to certain + # policy IDs. 
+ if ( + isinstance(env, MultiAgentEnv) + and hasattr(env, "_spaces_in_preferred_format") + and env._spaces_in_preferred_format + ): + act_space = None + mapping_fn = self.policy_mapping_fn + if mapping_fn: + for aid in env.get_agent_ids(): + # Match: Assign spaces for this AgentID to the PolicyID. + if mapping_fn(aid, None, None) == pid: + # Make sure, different agents that map to the same + # policy don't have different spaces. + if ( + act_space is not None + and env_act_space[aid] != act_space + ): + raise ValueError( + "Two agents in your environment map to the " + "same policyID (as per your `policy_mapping" + "_fn`), however, these agents also have " + "different action spaces!" + ) + act_space = env_act_space[aid] + # Otherwise, just use env's action space as-is. + else: + act_space = env_act_space + elif self.action_space: + act_space = self.action_space + else: + raise ValueError( + "`action_space` not provided in PolicySpec for " + f"{pid} and env does not have an action space OR " + "no spaces received from other workers' env(s) OR no " + "`action_space` specified in config!" + ) + policies[pid].action_space = act_space + + # Config is None -> Set to {}. + if policies[pid].config is None: + policies[pid].config = {} + + # If container given, construct a simple default callable returning True + # if the PolicyID is found in the list/set of IDs. + is_policy_to_train = self.policies_to_train + if self.policies_to_train is not None and not callable(self.policies_to_train): + pols = set(self.policies_to_train) + + def is_policy_to_train(pid, batch=None): + return pid in pols + + return policies, is_policy_to_train + + def __setattr__(self, key, value): + if hasattr(self, "_is_frozen") and self._is_frozen: + # TODO: Remove `simple_optimizer` entirely. + # Remove need to set `worker_index` in RolloutWorker's c'tor. + if key not in ["simple_optimizer", "worker_index", "_is_frozen"]: + raise AttributeError( + f"Cannot set attribute ({key}) of an already frozen " + "AlgorithmConfig!" + ) + super().__setattr__(key, value) + + def __getitem__(self, item): + # TODO: Uncomment this once all algorithms use AlgorithmConfigs under the + # hood (as well as Ray Tune). + # if log_once("algo_config_getitem"): + # logger.warning( + # "AlgorithmConfig objects should NOT be used as dict! " + # f"Try accessing `{item}` directly as a property." + # ) + item = self._translate_special_keys(item) + return getattr(self, item) + + def __setitem__(self, key, value): + # TODO: Remove comments once all methods/functions only support + # AlgorithmConfigs and there is no more ambiguity anywhere in the code + # on whether an AlgorithmConfig is used or an old python config dict. + # raise AttributeError( + # "AlgorithmConfig objects should not have their values set like dicts" + # f"(`config['{key}'] = {value}`), " + # f"but via setting their properties directly (config.{prop} = {value})." 
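A sketch of calling `get_multi_agent_setup()` directly, assuming a PPOConfig whose spaces come from the config itself (no env instance, no remote-worker spaces); all IDs and spaces here are illustrative only:

    >>> import gym
    >>> from ray.rllib.algorithms.ppo import PPOConfig
    >>> from ray.rllib.policy.policy import PolicySpec
    >>> config = (
    ...     PPOConfig()
    ...     .environment(
    ...         observation_space=gym.spaces.Box(-1.0, 1.0, (4,)),
    ...         action_space=gym.spaces.Discrete(2),
    ...     )
    ...     .multi_agent(
    ...         policies={"pol_0": PolicySpec(), "pol_1": PolicySpec()},
    ...         policies_to_train=["pol_0"],
    ...     )
    ... )
    >>> policies, is_policy_to_train = config.get_multi_agent_setup()
    >>> policies["pol_1"].action_space
    Discrete(2)
    >>> is_policy_to_train("pol_0"), is_policy_to_train("pol_1")
    (True, False)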
+ # ) + super().__setattr__(key, value) + + def __contains__(self, item) -> bool: + prop = self._translate_special_keys(item, warn_deprecated=False) + return hasattr(self, prop) + + def get(self, key, default=None): + prop = self._translate_special_keys(key, warn_deprecated=False) + return getattr(self, prop, default) + + def pop(self, key, default=None): + return self.get(key, default) + + def keys(self): + return self.to_dict().keys() + + def values(self): + return self.to_dict().values() + + def items(self): + return self.to_dict().items() + + @staticmethod + def _translate_special_keys(key: str, warn_deprecated: bool = True) -> str: + # Handle special key (str) -> `AlgorithmConfig.[some_property]` cases. + if key == "callbacks": + key = "callbacks_class" + elif key == "create_env_on_driver": + key = "create_env_on_local_worker" + elif key == "custom_eval_function": + key = "custom_evaluation_function" + elif key == "framework": + key = "framework_str" + elif key == "input": + key = "input_" + elif key == "lambda": + key = "lambda_" + elif key == "num_cpus_for_driver": + key = "num_cpus_for_local_worker" + + # Deprecated keys. + if warn_deprecated: + if key == "collect_metrics_timeout": + deprecation_warning( + old="collect_metrics_timeout", + new="metrics_episode_collection_timeout_s", + error=True, + ) + elif key == "metrics_smoothing_episodes": + deprecation_warning( + old="config.metrics_smoothing_episodes", + new="config.metrics_num_episodes_for_smoothing", + error=True, + ) + elif key == "min_iter_time_s": + deprecation_warning( + old="config.min_iter_time_s", + new="config.min_time_s_per_iteration", + error=True, + ) + elif key == "min_time_s_per_reporting": + deprecation_warning( + old="config.min_time_s_per_reporting", + new="config.min_time_s_per_iteration", + error=True, + ) + elif key == "min_sample_timesteps_per_reporting": + deprecation_warning( + old="config.min_sample_timesteps_per_reporting", + new="config.min_sample_timesteps_per_iteration", + error=True, + ) + elif key == "min_train_timesteps_per_reporting": + deprecation_warning( + old="config.min_train_timesteps_per_reporting", + new="config.min_train_timesteps_per_iteration", + error=True, + ) + elif key == "timesteps_per_iteration": + deprecation_warning( + old="config.timesteps_per_iteration", + new="`config.min_sample_timesteps_per_iteration` OR " + "`config.min_train_timesteps_per_iteration`", + error=True, + ) + elif key == "evaluation_num_episodes": + deprecation_warning( + old="config.evaluation_num_episodes", + new="`config.evaluation_duration` and " + "`config.evaluation_duration_unit=episodes`", + error=True, + ) + + return key + + @property + def multiagent(self): + return { + "policies": self.policies, + "policy_mapping_fn": self.policy_mapping_fn, + "policies_to_train": self.policies_to_train, + "policy_map_capacity": self.policy_map_capacity, + "policy_map_cache": self.policy_map_cache, + "count_steps_by": self.count_steps_by, + "observation_fn": self.observation_fn, + } diff --git a/rllib/algorithms/alpha_star/alpha_star.py b/rllib/algorithms/alpha_star/alpha_star.py index 3f4c10d6142b2..5c0aa61b3aa6c 100644 --- a/rllib/algorithms/alpha_star/alpha_star.py +++ b/rllib/algorithms/alpha_star/alpha_star.py @@ -11,6 +11,7 @@ import ray.rllib.algorithms.appo.appo as appo from ray.actor import ActorHandle from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.alpha_star.distributed_learners import DistributedLearners 
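The dict-compatibility helpers above keep old-style key access working during the transition; a sketch of the key translation (assuming PPOConfig):

    >>> from ray.rllib.algorithms.ppo import PPOConfig
    >>> config = PPOConfig().framework("torch")
    >>> config["framework"]   # translated to the `framework_str` property
    'torch'
    >>> "lambda" in config    # translated to `lambda_`
    True
    >>> config.get("create_env_on_driver")  # -> `create_env_on_local_worker`
    False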
from ray.rllib.algorithms.alpha_star.league_builder import AlphaStarLeagueBuilder from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -18,6 +19,7 @@ from ray.rllib.execution.parallel_requests import AsyncRequestsManager from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import Deprecated from ray.rllib.utils.from_config import from_config @@ -54,8 +56,8 @@ class AlphaStarConfig(appo.APPOConfig): ... .rollouts(num_rollout_workers=64) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.alpha_star import AlphaStarConfig @@ -140,6 +142,8 @@ def __init__(self, algo_class=None): # values. self.vtrace_drop_last_ts = False self.min_time_s_per_iteration = 2 + self.policies = None + self.simple_optimizer = True # __sphinx_doc_end__ # fmt: on @@ -215,7 +219,18 @@ def training( if timeout_s_learner_manager is not None: self.timeout_s_learner_manager = timeout_s_learner_manager if league_builder_config is not None: - self.league_builder_config = league_builder_config + # Override entire `league_builder_config` if `type` key changes. + # Update, if `type` key remains the same or is not specified. + new_league_builder_config = deep_update( + {"league_builder_config": self.league_builder_config}, + {"league_builder_config": league_builder_config}, + False, + ["league_builder_config"], + ["league_builder_config"], + ) + self.league_builder_config = new_league_builder_config[ + "league_builder_config" + ] if max_num_policies_to_train is not None: self.max_num_policies_to_train = max_num_policies_to_train if max_requests_in_flight_per_sampler_worker is not None: @@ -248,7 +263,7 @@ def default_resource_request(cls, config): # Construct a dummy LeagueBuilder, such that it gets the opportunity to # adjust the multiagent config, according to its setup, and we can then # properly infer the resources to allocate. - from_config(cf["league_builder_config"], trainer=None, trainer_config=cf) + from_config(cf["league_builder_config"], algo=None, algo_config=cf) max_num_policies_to_train = cf["max_num_policies_to_train"] or len( cf["multiagent"].get("policies_to_train") or cf["multiagent"]["policies"] @@ -316,25 +331,23 @@ def get_default_config(cls) -> AlgorithmConfigDict: def validate_config(self, config: AlgorithmConfigDict): # Create the LeagueBuilder object, allowing it to build the multiagent # config as well. - self.league_builder = from_config( - config["league_builder_config"], trainer=self, trainer_config=config - ) + if not config.get("in_evaluation"): + self.league_builder = from_config( + config["league_builder_config"], algo=self, algo_config=config + ) super().validate_config(config) @override(appo.APPO) - def setup(self, config: PartialAlgorithmConfigDict): + def setup(self, config: AlgorithmConfig): # Call super's setup to validate config, create RolloutWorkers # (train and eval), etc.. - num_gpus_saved = config["num_gpus"] - config["num_gpus"] = min(config["num_gpus"], 1) super().setup(config) - self.config["num_gpus"] = num_gpus_saved + local_worker = self.workers.local_worker() # - Create n policy learner actors (@ray.remote-converted Policies) on # one or more GPU nodes. 
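To illustrate the merge semantics of the `deep_update` call above (a standalone sketch with hypothetical values): if the incoming dict omits or repeats the existing `type`, it is merged into the current `league_builder_config`; if `type` changes, the whole sub-config is replaced.

    >>> from ray.rllib.utils import deep_update
    >>> base = {"league_builder_config": {
    ...     "type": "AlphaStarLeagueBuilder", "num_random_policies": 2}}
    >>> update = {"league_builder_config": {"num_random_policies": 4}}
    >>> merged = deep_update(
    ...     base, update, False,
    ...     ["league_builder_config"], ["league_builder_config"])
    >>> merged["league_builder_config"]
    {'type': 'AlphaStarLeagueBuilder', 'num_random_policies': 4}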
# - On each such node, also locate one replay buffer shard. - ma_cfg = self.config["multiagent"] # By default, set max_num_policies_to_train to the number of policy IDs # provided in the multiagent config. if self.config["max_num_policies_to_train"] is None: @@ -367,8 +380,15 @@ def setup(self, config: PartialAlgorithmConfigDict): replay_actor_class=ReplayActor, replay_actor_args=replay_actor_args, ) - for pid, policy_spec in ma_cfg["policies"].items(): - if pid in self.workers.local_worker().get_policies_to_train(): + policies, _ = self.config.get_multi_agent_setup( + spaces=local_worker.spaces, + default_policy_class=local_worker.default_policy_class, + ) + for pid, policy_spec in policies.items(): + if ( + local_worker.is_policy_to_train is None + or local_worker.is_policy_to_train(pid) + ): distributed_learners.add_policy(pid, policy_spec) # Store distributed_learners on all RolloutWorkers diff --git a/rllib/algorithms/alpha_star/distributed_learners.py b/rllib/algorithms/alpha_star/distributed_learners.py index b8c8d3e87baff..a085e597bdcef 100644 --- a/rllib/algorithms/alpha_star/distributed_learners.py +++ b/rllib/algorithms/alpha_star/distributed_learners.py @@ -4,6 +4,7 @@ import ray from ray.actor import ActorHandle from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary @@ -147,6 +148,10 @@ def __init__( replay_actor_class, replay_actor_args, ): + # For now, remain in config dict-land (b/c we are dealing with Policy classes + # here which do NOT use AlgorithmConfig yet). + if isinstance(config, AlgorithmConfig): + config = config.to_dict() self.config = config self.has_replay_buffer = False self.max_num_policies = max_num_policies @@ -216,6 +221,9 @@ def _add_replay_buffer_and_policy( policy_spec.policy_class, config ) + if isinstance(config, AlgorithmConfig): + config = config.to_dict() + colocated = create_colocated_actors( actor_specs=[ (self.replay_actor_class, self.replay_actor_args, {}, 1), diff --git a/rllib/algorithms/alpha_star/league_builder.py b/rllib/algorithms/alpha_star/league_builder.py index 4b3a773aa4a4e..f72287252da6a 100644 --- a/rllib/algorithms/alpha_star/league_builder.py +++ b/rllib/algorithms/alpha_star/league_builder.py @@ -6,30 +6,31 @@ from typing import Any, DefaultDict, Dict from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.examples.policy.random_policy import RandomPolicy from ray.rllib.policy.policy import PolicySpec from ray.rllib.utils.annotations import ExperimentalAPI, override from ray.rllib.utils.numpy import softmax -from ray.rllib.utils.typing import PolicyID, AlgorithmConfigDict, ResultDict +from ray.rllib.utils.typing import PolicyID, ResultDict logger = logging.getLogger(__name__) @ExperimentalAPI class LeagueBuilder(metaclass=ABCMeta): - def __init__(self, trainer: Algorithm, trainer_config: AlgorithmConfigDict): + def __init__(self, algo: Algorithm, algo_config: AlgorithmConfig): """Initializes a LeagueBuilder instance. Args: - trainer: The Algorithm object by which this league builder is used. + algo: The Algorithm object by which this league builder is used. Algorithm calls `build_league()` after each training step. 
- trainer_config: The (not yet validated) config dict to be + algo_config: The (not yet validated) config to be used on the Algorithm. Child classes of `LeagueBuilder` should preprocess this to add e.g. multiagent settings to this config. """ - self.trainer = trainer - self.config = trainer_config + self.algo = algo + self.config = algo_config def build_league(self, result: ResultDict) -> None: """Method containing league-building logic. Called after train step. @@ -67,8 +68,8 @@ def build_league(self, result: ResultDict) -> None: class AlphaStarLeagueBuilder(LeagueBuilder): def __init__( self, - trainer: Algorithm, - trainer_config: AlgorithmConfigDict, + algo: Algorithm, + algo_config: AlgorithmConfig, num_random_policies: int = 2, num_learning_league_exploiters: int = 4, num_learning_main_exploiters: int = 4, @@ -86,10 +87,10 @@ def __init__( M: Main self-play (main vs main). Args: - trainer: The Algorithm object by which this league builder is used. + algo: The Algorithm object by which this league builder is used. Algorithm calls `build_league()` after each training step to reconfigure the league structure (e.g. to add/remove policies). - trainer_config: The (not yet validated) config dict to be + algo_config: The (not yet validated) config to be used on the Algorithm. Child classes of `LeagueBuilder` should preprocess this to add e.g. multiagent settings to this config. @@ -113,7 +114,7 @@ def __init__( prob_main_exploiter_playing_against_learning_main: Probability of a main-exploiter vs (training!) main match. """ - super().__init__(trainer, trainer_config) + super().__init__(algo, algo_config) self.win_rate_threshold_for_new_snapshot = win_rate_threshold_for_new_snapshot self.keep_new_snapshot_training_prob = keep_new_snapshot_training_prob @@ -131,13 +132,13 @@ def __init__( ) # Build trainer's multiagent config. - ma_config = self.config["multiagent"] + self.config._is_frozen = False # Make sure the multiagent config dict has no policies defined: - assert not ma_config.get("policies"), ( - "ERROR: `config.multiagent.policies` should not be pre-defined! " + assert self.config.policies is None, ( + "ERROR: `config.policies` should be None (not pre-defined by user)! " "AlphaStarLeagueBuilder will construct this itself." ) - ma_config["policies"] = policies = {} + policies = {} self.main_policies = 1 self.league_exploiters = ( @@ -149,7 +150,7 @@ def __init__( policies["main_0"] = PolicySpec() # Train all non-random policies that exist at beginning. - ma_config["policies_to_train"] = ["main_0"] + policies_to_train = ["main_0"] # Add random policies. i = -1 @@ -160,23 +161,26 @@ def __init__( for j in range(num_learning_league_exploiters): pid = f"league_exploiter_{j + i + 1}" policies[pid] = PolicySpec() - ma_config["policies_to_train"].append(pid) + policies_to_train.append(pid) # Add initial (learning) main-exploiters. for j in range(num_learning_league_exploiters): pid = f"main_exploiter_{j + i + 1}" policies[pid] = PolicySpec() - ma_config["policies_to_train"].append(pid) + policies_to_train.append(pid) # Build initial policy mapping function: main_0 vs main_exploiter_0. 
- ma_config["policy_mapping_fn"] = ( + self.config.policy_mapping_fn = ( lambda aid, ep, worker, **kw: "main_0" if ep.episode_id % 2 == aid else "main_exploiter_0" ) + self.config.policies = policies + self.config.policies_to_train = policies_to_train + self.config.freeze() @override(LeagueBuilder) def build_league(self, result: ResultDict) -> None: - local_worker = self.trainer.workers.local_worker() + local_worker = self.algo.workers.local_worker() # If no evaluation results -> Use hist data gathered for training. if "evaluation" in result: @@ -191,7 +195,7 @@ def build_league(self, result: ResultDict) -> None: set(local_worker.policy_map.keys()) - trainable_policies ) - logger.info(f"League building after iter {self.trainer.iteration}:") + logger.info(f"League building after iter {self.algo.iteration}:") # Calculate current win-rates. for policy_id, rew in hist_stats.items(): @@ -363,10 +367,10 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): return "main_0" # Add and set the weights of the new polic(y/ies). - state = self.trainer.get_policy(policy_id).get_state() - self.trainer.add_policy( + state = self.algo.get_policy(policy_id).get_state() + self.algo.add_policy( policy_id=new_pol_id, - policy_cls=type(self.trainer.get_policy(policy_id)), + policy_cls=type(self.algo.get_policy(policy_id)), policy_state=state, policy_mapping_fn=policy_mapping_fn, policies_to_train=trainable_policies, diff --git a/rllib/algorithms/alpha_star/tests/test_alpha_star.py b/rllib/algorithms/alpha_star/tests/test_alpha_star.py index 44623bdef7285..c8a8b1c6cd9f4 100644 --- a/rllib/algorithms/alpha_star/tests/test_alpha_star.py +++ b/rllib/algorithms/alpha_star/tests/test_alpha_star.py @@ -52,13 +52,14 @@ def test_alpha_star_compilation(self): num_iterations = 2 for _ in framework_iterator(config, with_eager_tracing=True): - trainer = config.build() + config.policies = None + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() print(results) check_train_results(results) - check_compute_single_action(trainer) - trainer.stop() + check_compute_single_action(algo) + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/alpha_zero/alpha_zero.py b/rllib/algorithms/alpha_zero/alpha_zero.py index 90788d50fb396..82f298eda75bc 100644 --- a/rllib/algorithms/alpha_zero/alpha_zero.py +++ b/rllib/algorithms/alpha_zero/alpha_zero.py @@ -73,8 +73,8 @@ class AlphaZeroConfig(AlgorithmConfig): ... .rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.alpha_zero import AlphaZeroConfig @@ -148,12 +148,12 @@ def __init__(self, algo_class=None): self.train_batch_size = 4000 self.batch_mode = "complete_episodes" # Extra configuration that disables exploration. 
- self.evaluation_config = { + self.evaluation(evaluation_config={ "mcts_config": { "argmax_tree_policy": True, "add_dirichlet_noise": False, }, - } + }) # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/apex_ddpg/apex_ddpg.py b/rllib/algorithms/apex_ddpg/apex_ddpg.py index ed5bee9c84c67..7db4688f484db 100644 --- a/rllib/algorithms/apex_ddpg/apex_ddpg.py +++ b/rllib/algorithms/apex_ddpg/apex_ddpg.py @@ -23,8 +23,8 @@ class ApexDDPGConfig(DDPGConfig): >>> config = ApexDDPGConfig().training(lr=0.01).resources(num_gpus=1) >>> print(config.to_dict()) >>> # Build a Trainer object from the config and run one training iteration. - >>> trainer = config.build(env="Pendulum-v1") - >>> trainer.train() + >>> algo = config.build(env="Pendulum-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.apex_ddpg.apex_ddpg import ApexDDPGConfig @@ -111,7 +111,6 @@ def __init__(self, algo_class=None): def training( self, *, - optimizer: Optional[dict] = None, max_requests_in_flight_per_sampler_worker: Optional[int] = None, max_requests_in_flight_per_replay_worker: Optional[int] = None, timeout_s_sampler_manager: Optional[float] = None, @@ -121,9 +120,6 @@ def training( """Sets the training related configuration. Args: - optimizer: Apex-DDPG optimizer settings (dict). Set the number of reply - buffer shards in here via the `num_replay_buffer_shards` key - (default=4). max_requests_in_flight_per_sampler_worker: Max number of inflight requests to each sampling worker. See the AsyncRequestsManager class for more details. Tuning these values is important when running experimens with @@ -158,8 +154,6 @@ def training( """ super().training(**kwargs) - if optimizer is not None: - self.optimizer = optimizer if max_requests_in_flight_per_sampler_worker is not None: self.max_requests_in_flight_per_sampler_worker = ( max_requests_in_flight_per_sampler_worker diff --git a/rllib/algorithms/apex_ddpg/tests/test_apex_ddpg.py b/rllib/algorithms/apex_ddpg/tests/test_apex_ddpg.py index 5b75c6309c6f2..ba6eef1a8fd63 100644 --- a/rllib/algorithms/apex_ddpg/tests/test_apex_ddpg.py +++ b/rllib/algorithms/apex_ddpg/tests/test_apex_ddpg.py @@ -22,24 +22,22 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self): """Test whether APEX-DDPG can be built on all frameworks.""" config = ( apex_ddpg.ApexDDPGConfig() + .environment(env="Pendulum-v1") .rollouts(num_rollout_workers=2) .reporting(min_sample_timesteps_per_iteration=100) .training( num_steps_sampled_before_learning_starts=0, optimizer={"num_replay_buffer_shards": 1}, ) - .environment(env="Pendulum-v1") ) num_iterations = 1 for _ in framework_iterator(config, with_eager_tracing=True): - trainer = config.build() + algo = config.build() # Test per-worker scale distribution. - infos = trainer.workers.foreach_policy( - lambda p, _: p.get_exploration_state() - ) + infos = algo.workers.foreach_policy(lambda p, _: p.get_exploration_state()) scale = [i["cur_scale"] for i in infos] expected = [ 0.4 ** (1 + (i + 1) / float(config.num_workers - 1) * 7) @@ -48,20 +46,18 @@ def test_apex_ddpg_compilation_and_per_worker_epsilon_values(self): check(scale, [0.0] + expected) for _ in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) + check_compute_single_action(algo) # Test again per-worker scale distribution # (should not have changed). 
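The `evaluation(evaluation_config=...)` pattern used for AlphaZero above is the same one end users can rely on for per-evaluation overrides; a sketch (assuming PPOConfig and `explore=False` as the override):

    >>> from ray.rllib.algorithms.ppo import PPOConfig
    >>> config = PPOConfig().evaluation(
    ...     evaluation_interval=1,
    ...     evaluation_num_workers=1,
    ...     evaluation_config={"explore": False},
    ... )
    >>> config.explore, config.get_evaluation_config_object().explore
    (True, False)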
- infos = trainer.workers.foreach_policy( - lambda p, _: p.get_exploration_state() - ) + infos = algo.workers.foreach_policy(lambda p, _: p.get_exploration_state()) scale = [i["cur_scale"] for i in infos] check(scale, [0.0] + expected) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/apex_dqn/apex_dqn.py b/rllib/algorithms/apex_dqn/apex_dqn.py index 32f526adf6775..e140cc4eec0fa 100644 --- a/rllib/algorithms/apex_dqn/apex_dqn.py +++ b/rllib/algorithms/apex_dqn/apex_dqn.py @@ -15,18 +15,16 @@ import platform import random from collections import defaultdict -from typing import Callable, Dict, List, Optional, Type +from typing import Dict, List, Optional import ray from ray._private.dict import merge_dicts from ray.actor import ActorHandle -from ray.rllib import Policy from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.dqn.dqn import DQN, DQNConfig from ray.rllib.algorithms.dqn.learner_thread import LearnerThread from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.execution.parallel_requests import AsyncRequestsManager -from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.actors import create_colocated_actors from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import DEPRECATED_VALUE, Deprecated @@ -200,21 +198,6 @@ def __init__(self, algo_class=None): def training( self, *, - num_atoms: Optional[int] = None, - v_min: Optional[float] = None, - v_max: Optional[float] = None, - noisy: Optional[bool] = None, - sigma0: Optional[float] = None, - dueling: Optional[bool] = None, - hiddens: Optional[int] = None, - double_q: Optional[bool] = None, - n_step: Optional[int] = None, - before_learn_on_batch: Callable[ - [Type[MultiAgentBatch], List[Type[Policy]], Type[int]], - Type[MultiAgentBatch], - ] = None, - training_intensity: Optional[float] = None, - replay_buffer_config: Optional[dict] = None, max_requests_in_flight_per_sampler_worker: Optional[int] = None, max_requests_in_flight_per_replay_worker: Optional[int] = None, timeout_s_sampler_manager: Optional[float] = None, @@ -228,25 +211,40 @@ def training( When this is greater than 1, distributional Q-learning is used. v_min: Minimum value estimation v_max: Maximum value estimation - noisy: Whether to use noisy network to aid exploration. This adds - parametric noise to the model weights. + noisy: Whether to use noisy network to aid exploration. This adds parametric + noise to the model weights. sigma0: Control the initial parameter noise for noisy nets. - dueling: Whether to use dueling DQN policy. + dueling: Whether to use dueling DQN. hiddens: Dense-layer setup for each the advantage branch and the value branch - double_q: Whether to use double DQN for the policy. + double_q: Whether to use double DQN. n_step: N-step for Q-learning. before_learn_on_batch: Callback to run before learning on a multi-agent batch of experiences. - training_intensity: The ratio of timesteps to train on for every - timestep that is sampled. This must be greater than 0. + training_intensity: The intensity with which to update the model (vs + collecting samples from the env). + If None, uses "natural" values of: + `train_batch_size` / (`rollout_fragment_length` x `num_workers` x + `num_envs_per_worker`). + If not None, will make sure that the ratio between timesteps inserted + into and sampled from the buffer matches the given values. 
+ Example: + training_intensity=1000.0 + train_batch_size=250 + rollout_fragment_length=1 + num_workers=1 (or 0) + num_envs_per_worker=1 + -> natural value = 250 / 1 = 250.0 + -> will make sure that replay+train op will be executed 4x asoften as + rollout+insert op (4 * 250 = 1000). + See: rllib/algorithms/dqn/dqn.py::calculate_rr_weights for further + details. replay_buffer_config: Replay buffer config. Examples: { "_enable_replay_buffer_api": True, "type": "MultiAgentReplayBuffer", "capacity": 50000, - "replay_batch_size": 32, "replay_sequence_length": 1, } - OR - @@ -306,30 +304,6 @@ def training( # Pass kwargs onto super's `training()` method. super().training(**kwargs) - if num_atoms is not None: - self.num_atoms = num_atoms - if v_min is not None: - self.v_min = v_min - if v_max is not None: - self.v_max = v_max - if noisy is not None: - self.noisy = noisy - if sigma0 is not None: - self.sigma0 = sigma0 - if dueling is not None: - self.dueling = dueling - if hiddens is not None: - self.hiddens = hiddens - if double_q is not None: - self.double_q = double_q - if n_step is not None: - self.n_step = n_step - if before_learn_on_batch is not None: - self.before_learn_on_batch = before_learn_on_batch - if training_intensity is not None: - self.training_intensity = training_intensity - if replay_buffer_config is not None: - self.replay_buffer_config = replay_buffer_config if max_requests_in_flight_per_sampler_worker is not None: self.max_requests_in_flight_per_sampler_worker = ( max_requests_in_flight_per_sampler_worker diff --git a/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py b/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py index b4a32025a8d02..75c5048ee38bd 100644 --- a/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py +++ b/rllib/algorithms/apex_dqn/tests/test_apex_dqn.py @@ -23,6 +23,7 @@ def tearDown(self): def test_apex_zero_workers(self): config = ( apex_dqn.ApexDQNConfig() + .environment("CartPole-v0") .rollouts(num_rollout_workers=0) .resources(num_gpus=0) .training( @@ -38,16 +39,17 @@ def test_apex_zero_workers(self): ) for _ in framework_iterator(config): - trainer = config.build(env="CartPole-v0") - results = trainer.train() + algo = config.build() + results = algo.train() check_train_results(results) print(results) - trainer.stop() + algo.stop() def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): """Test whether APEXDQN can be built on all frameworks.""" config = ( apex_dqn.ApexDQNConfig() + .environment("CartPole-v0") .rollouts(num_rollout_workers=3) .resources(num_gpus=0) .training( @@ -63,34 +65,31 @@ def test_apex_dqn_compilation_and_per_worker_epsilon_values(self): ) for _ in framework_iterator(config, with_eager_tracing=True): - trainer = config.build(env="CartPole-v0") + algo = config.build() # Test per-worker epsilon distribution. - infos = trainer.workers.foreach_policy( - lambda p, _: p.get_exploration_state() - ) + infos = algo.workers.foreach_policy(lambda p, _: p.get_exploration_state()) expected = [0.4, 0.016190862, 0.00065536] check([i["cur_epsilon"] for i in infos], [0.0] + expected) - check_compute_single_action(trainer) + check_compute_single_action(algo) for i in range(2): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) # Test again per-worker epsilon distribution # (should not have changed). 
- infos = trainer.workers.foreach_policy( - lambda p, _: p.get_exploration_state() - ) + infos = algo.workers.foreach_policy(lambda p, _: p.get_exploration_state()) check([i["cur_epsilon"] for i in infos], [0.0] + expected) - trainer.stop() + algo.stop() def test_apex_lr_schedule(self): config = ( apex_dqn.ApexDQNConfig() + .environment("CartPole-v0") .rollouts( num_rollout_workers=1, rollout_fragment_length=5, @@ -143,7 +142,7 @@ def _step_n_times(algo, n: int): ] for _ in framework_iterator(config, frameworks=("torch", "tf")): - algo = config.build(env="CartPole-v0") + algo = config.build() lr = _step_n_times(algo, 3) # 50 timesteps # Close to 0.2 diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index cff5ef6713c28..60bcb0692939c 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -41,11 +41,12 @@ class APPOConfig(ImpalaConfig): >>> from ray.rllib.algorithms.appo import APPOConfig >>> config = APPOConfig().training(lr=0.01, grad_clip=30.0)\ ... .resources(num_gpus=1)\ - ... .rollouts(num_rollout_workers=16) + ... .rollouts(num_rollout_workers=16)\ + ... .environment("CartPole-v1") >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build() + >>> algo.train() Example: >>> from ray.rllib.algorithms.appo import APPOConfig @@ -196,7 +197,7 @@ def __call__(self, fetches): class APPO(Impala): def __init__(self, config, *args, **kwargs): - """Initializes a DDPPO instance.""" + """Initializes an APPO instance.""" super().__init__(config, *args, **kwargs) # After init: Initialize target net. diff --git a/rllib/algorithms/ars/ars.py b/rllib/algorithms/ars/ars.py index 39e1cb0dff2b0..01d9fc2021df6 100644 --- a/rllib/algorithms/ars/ars.py +++ b/rllib/algorithms/ars/ars.py @@ -50,11 +50,12 @@ class ARSConfig(AlgorithmConfig): >>> from ray.rllib.algorithms.ars import ARSConfig >>> config = ARSConfig().training(sgd_stepsize=0.02, report_length=20)\ ... .resources(num_gpus=0)\ - ... .rollouts(num_rollout_workers=4) + ... .rollouts(num_rollout_workers=4)\ + ... .environment("CartPole-v1") >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build() + >>> algo.train() Example: >>> from ray.rllib.algorithms.ars import ARSConfig @@ -97,13 +98,17 @@ def __init__(self): # Override some of AlgorithmConfig's default values with ARS-specific values. self.num_workers = 2 self.observation_filter = "MeanStdFilter" + # ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0). # Therefore, we must be careful not to use more than 1 env per eval worker # (would break ARSPolicy's compute_single_action method) and to not do # obs-filtering. 
- self.evaluation_config["num_envs_per_worker"] = 1 - self.evaluation_config["observation_filter"] = "NoFilter" - + self.evaluation( + evaluation_config={ + "num_envs_per_worker": 1, + "observation_filter": "NoFilter", + } + ) # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/bandit/bandit.py b/rllib/algorithms/bandit/bandit.py index 983e9bb537133..24019f42201b3 100644 --- a/rllib/algorithms/bandit/bandit.py +++ b/rllib/algorithms/bandit/bandit.py @@ -47,8 +47,8 @@ class BanditLinTSConfig(BanditConfig): >>> config = BanditLinTSConfig().rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env=WheelBanditEnv) - >>> trainer.train() + >>> algo = config.build(env=WheelBanditEnv) + >>> algo.train() """ def __init__(self): @@ -70,8 +70,8 @@ class BanditLinUCBConfig(BanditConfig): >>> config = BanditLinUCBConfig().rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env=WheelBanditEnv) - >>> trainer.train() + >>> algo = config.build(env=WheelBanditEnv) + >>> algo.train() """ def __init__(self): diff --git a/rllib/algorithms/bandit/tests/test_bandits.py b/rllib/algorithms/bandit/tests/test_bandits.py index f4065d4631085..4ae2647b7d11b 100644 --- a/rllib/algorithms/bandit/tests/test_bandits.py +++ b/rllib/algorithms/bandit/tests/test_bandits.py @@ -54,15 +54,15 @@ def test_bandit_lin_ucb_compilation(self): ): for train_batch_size in [1, 10]: config.training(train_batch_size=train_batch_size) - trainer = config.build() + algo = config.build() results = None for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) # Force good learning behavior (this is a very simple env). self.assertTrue(results["episode_reward_mean"] == 10.0) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/bc/bc.py b/rllib/algorithms/bc/bc.py index b21ab61b0a27b..16e89bcccaa94 100644 --- a/rllib/algorithms/bc/bc.py +++ b/rllib/algorithms/bc/bc.py @@ -14,8 +14,8 @@ class BCConfig(MARWILConfig): ... .offline_data(input_="./rllib/tests/data/cartpole/large.json") >>> print(config.to_dict()) >>> # Build a Trainer object from the config and run 1 training iteration. - >>> trainer = config.build() - >>> trainer.train() + >>> algo = config.build() + >>> algo.train() Example: >>> from ray.rllib.algorithms.bc import BCConfig diff --git a/rllib/algorithms/bc/tests/test_bc.py b/rllib/algorithms/bc/tests/test_bc.py index 2e27b04354084..cdcfbd16b173a 100644 --- a/rllib/algorithms/bc/tests/test_bc.py +++ b/rllib/algorithms/bc/tests/test_bc.py @@ -50,10 +50,10 @@ def test_bc_compilation_and_learning_from_offline_file(self): # Test for all frameworks. 
for _ in framework_iterator(config, frameworks=("tf", "torch")): - trainer = config.build(env="CartPole-v0") + algo = config.build(env="CartPole-v0") learnt = False for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) @@ -72,9 +72,9 @@ def test_bc_compilation_and_learning_from_offline_file(self): "data!".format(min_reward) ) - check_compute_single_action(trainer, include_prev_action_reward=True) + check_compute_single_action(algo, include_prev_action_reward=True) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/cql/cql.py b/rllib/algorithms/cql/cql.py index 05b39cc5ee287..e0beffee4c091 100644 --- a/rllib/algorithms/cql/cql.py +++ b/rllib/algorithms/cql/cql.py @@ -49,8 +49,8 @@ class CQLConfig(SACConfig): ... .rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Trainer object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() """ def __init__(self, algo_class=None): @@ -138,10 +138,6 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: new="min_train_timesteps_per_iteration", error=True, ) - config["min_train_timesteps_per_iteration"] = config[ - "timesteps_per_iteration" - ] - config["timesteps_per_iteration"] = DEPRECATED_VALUE # Call super's validation method. super().validate_config(config) diff --git a/rllib/algorithms/cql/tests/test_cql.py b/rllib/algorithms/cql/tests/test_cql.py index 692eadc185789..7e747c8e29359 100644 --- a/rllib/algorithms/cql/tests/test_cql.py +++ b/rllib/algorithms/cql/tests/test_cql.py @@ -73,20 +73,19 @@ def test_cql_compilation(self): # Test for tf/torch frameworks. for fw in framework_iterator(config, with_eager_tracing=True): - trainer = config.build() + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) eval_results = results["evaluation"] print( - f"iter={trainer.iteration} " - f"R={eval_results['episode_reward_mean']}" + f"iter={algo.iteration} " f"R={eval_results['episode_reward_mean']}" ) - check_compute_single_action(trainer) + check_compute_single_action(algo) # Get policy and model. - pol = trainer.get_policy() + pol = algo.get_policy() cql_model = pol.model if fw == "tf": pol.get_session().__enter__() @@ -95,7 +94,7 @@ def test_cql_compilation(self): # using the data from CQL's global replay buffer. # Get a sample (MultiAgentBatch). - batch = trainer.workers.local_worker().input_reader.next() + batch = algo.workers.local_worker().input_reader.next() multi_agent_batch = batch.as_multi_agent() # All experiences have been buffered for `default_policy` batch = multi_agent_batch.policy_batches["default_policy"] @@ -140,7 +139,7 @@ def test_cql_compilation(self): if fw == "tf": pol.get_session().__exit__(None, None, None) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/ddpg/ddpg.py b/rllib/algorithms/ddpg/ddpg.py index dce3cef67bd73..688c2b783bd67 100644 --- a/rllib/algorithms/ddpg/ddpg.py +++ b/rllib/algorithms/ddpg/ddpg.py @@ -20,8 +20,8 @@ class DDPGConfig(SimpleQConfig): >>> config = DDPGConfig().training(lr=0.01).resources(num_gpus=1) >>> print(config.to_dict()) >>> # Build a Trainer object from the config and run one training iteration. 
- >>> trainer = config.build(env="Pendulum-v1") - >>> trainer.train() + >>> algo = config.build(env="Pendulum-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.ddpg.ddpg import DDPGConfig diff --git a/rllib/algorithms/ddpg/tests/test_ddpg.py b/rllib/algorithms/ddpg/tests/test_ddpg.py index 320358c7b66d8..c059c0909e003 100644 --- a/rllib/algorithms/ddpg/tests/test_ddpg.py +++ b/rllib/algorithms/ddpg/tests/test_ddpg.py @@ -66,6 +66,7 @@ def test_ddpg_exploration_and_with_random_prerun(self): core_config = ( ddpg.DDPGConfig() + .environment("Pendulum-v1") .rollouts(num_rollout_workers=0) .training(num_steps_sampled_before_learning_starts=0) ) @@ -76,7 +77,7 @@ def test_ddpg_exploration_and_with_random_prerun(self): for _ in framework_iterator(core_config): config = copy.deepcopy(core_config) # Default OUNoise setup. - algo = config.build(env="Pendulum-v1") + algo = config.build() # Setting explore=False should always return the same action. a_ = algo.compute_single_action(obs, explore=False) check(algo.get_policy().global_timestep, 1) @@ -104,7 +105,7 @@ def test_ddpg_exploration_and_with_random_prerun(self): } ) - algo = ddpg.DDPG(config=config, env="Pendulum-v1") + algo = config.build() # ts=0 (get a deterministic action as per explore=False). deterministic_action = algo.compute_single_action(obs, explore=False) check(algo.get_policy().global_timestep, 1) diff --git a/rllib/algorithms/ddppo/ddppo.py b/rllib/algorithms/ddppo/ddppo.py index 126b321d1d939..0be7595b68f33 100644 --- a/rllib/algorithms/ddppo/ddppo.py +++ b/rllib/algorithms/ddppo/ddppo.py @@ -58,8 +58,8 @@ class DDPPOConfig(PPOConfig): ... .rollouts(num_rollout_workers=10) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.ddppo import DDPPOConfig @@ -181,7 +181,7 @@ class directly. Note that this arg can also be specified via config=config, env=env, logger_creator=logger_creator, **kwargs ) - if "train_batch_size" in config.keys() and config["train_batch_size"] != -1: + if "train_batch_size" in config and config["train_batch_size"] != -1: # Users should not define `train_batch_size` directly (always -1). raise ValueError( "Set rollout_fragment_length instead of train_batch_size for DDPPO." @@ -237,7 +237,10 @@ def validate_config(self, config): "num_gpus_per_worker=1." ) # `batch_mode` must be "truncate_episodes". - if config["batch_mode"] != "truncate_episodes": + if ( + not config.get("in_evaluation") + and config["batch_mode"] != "truncate_episodes" + ): raise ValueError( "Distributed data parallel requires truncate_episodes batch mode." ) diff --git a/rllib/algorithms/ddppo/tests/test_ddppo.py b/rllib/algorithms/ddppo/tests/test_ddppo.py index c0c5c52223b18..0b4e4e0c0ef6d 100644 --- a/rllib/algorithms/ddppo/tests/test_ddppo.py +++ b/rllib/algorithms/ddppo/tests/test_ddppo.py @@ -29,18 +29,18 @@ def test_ddppo_compilation(self): num_iterations = 2 for _ in framework_iterator(config, frameworks="torch"): - trainer = config.build(env="CartPole-v0") + algo = config.build(env="CartPole-v0") for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) # Make sure, weights on all workers are the same. 
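Given the DDPPO checks above, batch size is controlled via `rollout_fragment_length` rather than `train_batch_size`; a minimal sketch with hypothetical values:

    >>> from ray.rllib.algorithms.ddppo import DDPPOConfig
    >>> config = DDPPOConfig().rollouts(
    ...     num_rollout_workers=2,
    ...     rollout_fragment_length=100,
    ... )
    >>> config.train_batch_size  # must stay at its -1 sentinel for DDPPO
    -1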
- weights = trainer.workers.foreach_worker(lambda w: w.get_weights()) + weights = algo.workers.foreach_worker(lambda w: w.get_weights()) for w in weights[1:]: check(w, weights[1]) - check_compute_single_action(trainer) - trainer.stop() + check_compute_single_action(algo) + algo.stop() def test_ddppo_schedule(self): """Test whether lr_schedule will anneal lr to 0""" @@ -51,15 +51,15 @@ def test_ddppo_schedule(self): num_iterations = 10 for _ in framework_iterator(config, "torch"): - trainer = config.build(env="CartPole-v0") + algo = config.build(env="CartPole-v0") lr = -100.0 for _ in range(num_iterations): - result = trainer.train() + result = algo.train() if result["info"][LEARNER_INFO]: lr = result["info"][LEARNER_INFO][DEFAULT_POLICY_ID][ LEARNER_STATS_KEY ]["cur_lr"] - trainer.stop() + algo.stop() assert lr == 0.0, "lr should anneal to 0.0" def test_validate_config(self): diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 2c19d49726cc3..7bfc2ecd0909a 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -178,12 +178,12 @@ def training( Type[MultiAgentBatch], ] = None, training_intensity: Optional[float] = None, - replay_buffer_config: Optional[dict] = None, td_error_loss_fn: Optional[str] = None, categorical_distribution_temperature: Optional[float] = None, **kwargs, ) -> "DQNConfig": """Sets the training related configuration. + Args: num_atoms: Number of atoms for representing the distribution of return. When this is greater than 1, distributional Q-learning is used. @@ -286,8 +286,6 @@ def training( self.before_learn_on_batch = before_learn_on_batch if training_intensity is not None: self.training_intensity = training_intensity - if replay_buffer_config is not None: - self.replay_buffer_config = replay_buffer_config if td_error_loss_fn is not None: self.td_error_loss_fn = td_error_loss_fn assert self.td_error_loss_fn in [ diff --git a/rllib/algorithms/dqn/tests/test_dqn.py b/rllib/algorithms/dqn/tests/test_dqn.py index 2d8b1218aea7d..d3642dbe29fb6 100644 --- a/rllib/algorithms/dqn/tests/test_dqn.py +++ b/rllib/algorithms/dqn/tests/test_dqn.py @@ -26,6 +26,7 @@ def test_dqn_compilation(self): num_iterations = 1 config = ( dqn.dqn.DQNConfig() + .environment("CartPole-v0") .rollouts(num_rollout_workers=2) .training(num_steps_sampled_before_learning_starts=0) ) @@ -33,8 +34,7 @@ def test_dqn_compilation(self): for _ in framework_iterator(config, with_eager_tracing=True): # Double-dueling DQN. print("Double-dueling") - plain_config = deepcopy(config) - trainer = dqn.DQN(config=plain_config, env="CartPole-v0") + trainer = config.build() for i in range(num_iterations): results = trainer.train() check_train_results(results) @@ -48,7 +48,7 @@ def test_dqn_compilation(self): rainbow_config = deepcopy(config).training( num_atoms=10, noisy=True, double_q=True, dueling=True, n_step=5 ) - trainer = dqn.DQN(config=rainbow_config, env="CartPole-v0") + trainer = rainbow_config.build() for i in range(num_iterations): results = trainer.train() check_train_results(results) @@ -65,6 +65,7 @@ def test_dqn_compilation_integer_rewards(self): num_iterations = 1 config = ( dqn.dqn.DQNConfig() + .environment("Taxi-v3") .rollouts(num_rollout_workers=2) .training(num_steps_sampled_before_learning_starts=0) ) @@ -72,8 +73,7 @@ def test_dqn_compilation_integer_rewards(self): for _ in framework_iterator(config, with_eager_tracing=True): # Double-dueling DQN. 
print("Double-dueling") - plain_config = deepcopy(config) - trainer = dqn.DQN(config=plain_config, env="Taxi-v3") + trainer = config.build() for i in range(num_iterations): results = trainer.train() check_train_results(results) @@ -87,7 +87,7 @@ def test_dqn_compilation_integer_rewards(self): rainbow_config = deepcopy(config).training( num_atoms=10, noisy=True, double_q=True, dueling=True, n_step=5 ) - trainer = dqn.DQN(config=rainbow_config, env="Taxi-v3") + trainer = rainbow_config.build() for i in range(num_iterations): results = trainer.train() check_train_results(results) @@ -101,6 +101,7 @@ def test_dqn_exploration_and_soft_q_config(self): """Tests, whether a DQN Agent outputs exploration/softmaxed actions.""" config = ( dqn.dqn.DQNConfig() + .environment("FrozenLake-v1") .rollouts(num_rollout_workers=0) .environment(env_config={"is_slippery": False, "map_name": "4x4"}) ).training(num_steps_sampled_before_learning_starts=0) @@ -110,7 +111,7 @@ def test_dqn_exploration_and_soft_q_config(self): # Test against all frameworks. for _ in framework_iterator(config): # Default EpsilonGreedy setup. - trainer = dqn.DQN(config=config, env="FrozenLake-v1") + trainer = config.build() # Setting explore=False should always return the same action. a_ = trainer.compute_single_action(obs, explore=False) for _ in range(50): @@ -128,7 +129,7 @@ def test_dqn_exploration_and_soft_q_config(self): config.exploration( exploration_config={"type": "SoftQ", "temperature": 0.000001} ) - trainer = dqn.DQN(config=config, env="FrozenLake-v1") + trainer = config.build() # Due to the low temp, always expect the same action. actions = [trainer.compute_single_action(obs)] for _ in range(50): @@ -138,7 +139,7 @@ def test_dqn_exploration_and_soft_q_config(self): # Higher softmax temperature. config.exploration_config["temperature"] = 1.0 - trainer = dqn.DQN(config=config, env="FrozenLake-v1") + trainer = config.build() # Even with the higher temperature, if we set explore=False, we # should expect the same actions always. @@ -157,7 +158,7 @@ def test_dqn_exploration_and_soft_q_config(self): # With Random exploration. config.exploration(exploration_config={"type": "Random"}, explore=True) - trainer = dqn.DQN(config=config, env="FrozenLake-v1") + trainer = config.build() actions = [] for _ in range(300): actions.append(trainer.compute_single_action(obs)) diff --git a/rllib/algorithms/dreamer/dreamer.py b/rllib/algorithms/dreamer/dreamer.py index a2f93fc83b5c1..223088e0d76a1 100644 --- a/rllib/algorithms/dreamer/dreamer.py +++ b/rllib/algorithms/dreamer/dreamer.py @@ -41,8 +41,8 @@ class DreamerConfig(AlgorithmConfig): ... .rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. 
- >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray import air @@ -113,11 +113,10 @@ def __init__(self): self.num_steps_sampled_before_learning_starts = 0 # .environment() - self.env_config = { + self.env_config.update({ # Repeats action send by policy for frame_skip times in env "frame_skip": 2, - } - + }) # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/dreamer/tests/test_dreamer.py b/rllib/algorithms/dreamer/tests/test_dreamer.py index 779fef241d8bf..0d0ee01da79db 100644 --- a/rllib/algorithms/dreamer/tests/test_dreamer.py +++ b/rllib/algorithms/dreamer/tests/test_dreamer.py @@ -35,12 +35,12 @@ def test_dreamer_compilation(self): # Test against all frameworks. for _ in framework_iterator(config, frameworks="torch"): - trainer = config.build() + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() print(results) # check_compute_single_action(trainer, include_state=True) - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/dt/dt.py b/rllib/algorithms/dt/dt.py index bd7f471848d58..55da34b8abb96 100644 --- a/rllib/algorithms/dt/dt.py +++ b/rllib/algorithms/dt/dt.py @@ -9,6 +9,7 @@ from ray.rllib.execution.train_ops import multi_gpu_train_one_step, train_one_step from ray.rllib.policy import Policy from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override, PublicAPI from ray.rllib.utils.metrics import ( NUM_AGENT_STEPS_SAMPLED, @@ -113,9 +114,38 @@ def training( Args: replay_buffer_config: Replay buffer config. + Examples: { - "capacity": How many trajectories/episodes does the buffer hold. + "_enable_replay_buffer_api": True, + "type": "MultiAgentReplayBuffer", + "capacity": 50000, + "replay_sequence_length": 1, } + - OR - + { + "_enable_replay_buffer_api": True, + "type": "MultiAgentPrioritizedReplayBuffer", + "capacity": 50000, + "prioritized_replay_alpha": 0.6, + "prioritized_replay_beta": 0.4, + "prioritized_replay_eps": 1e-6, + "replay_sequence_length": 1, + } + - Where - + prioritized_replay_alpha: Alpha parameter controls the degree of + prioritization in the buffer. In other words, when a buffer sample has + a higher temporal-difference error, with how much more probability + should it drawn to use to update the parametrized Q-network. 0.0 + corresponds to uniform probability. Setting much above 1.0 may quickly + result as the sampling distribution could become heavily “pointy” with + low entropy. + prioritized_replay_beta: Beta parameter controls the degree of + importance sampling which suppresses the influence of gradient updates + from samples that have higher probability of being sampled via alpha + parameter and the temporal-difference error. + prioritized_replay_eps: Epsilon parameter sets the baseline probability + for sampling so that when the temporal-difference error of a sample is + zero, there is still a chance of drawing the sample. embed_dim: Dimension of the embeddings in the GPT model. num_layers: Number of attention layers in the GPT model. num_heads: Number of attention heads in the GPT model. Must divide @@ -142,7 +172,16 @@ def training( """ super().training(**kwargs) if replay_buffer_config is not None: - self.replay_buffer_config = replay_buffer_config + # Override entire `replay_buffer_config` if `type` key changes. 
+ # Update, if `type` key remains the same or is not specified. + new_replay_buffer_config = deep_update( + {"replay_buffer_config": self.replay_buffer_config}, + {"replay_buffer_config": replay_buffer_config}, + False, + ["replay_buffer_config"], + ["replay_buffer_config"], + ) + self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"] if embed_dim is not None: self.embed_dim = embed_dim if num_layers is not None: diff --git a/rllib/algorithms/es/es.py b/rllib/algorithms/es/es.py index 3a6c9c13a9ee9..eab1267955a1e 100644 --- a/rllib/algorithms/es/es.py +++ b/rllib/algorithms/es/es.py @@ -51,8 +51,8 @@ class ESConfig(AlgorithmConfig): ... .rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.es import ESConfig @@ -96,13 +96,17 @@ def __init__(self): self.train_batch_size = 10000 self.num_workers = 10 self.observation_filter = "MeanStdFilter" - # ARS will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0). + + # ES will use Algorithm's evaluation WorkerSet (if evaluation_interval > 0). # Therefore, we must be careful not to use more than 1 env per eval worker - # (would break ARSPolicy's compute_single_action method) and to not do + # (would break ESPolicy's compute_single_action method) and to not do # obs-filtering. - self.evaluation_config["num_envs_per_worker"] = 1 - self.evaluation_config["observation_filter"] = "NoFilter" - + self.evaluation( + evaluation_config={ + "num_envs_per_worker": 1, + "observation_filter": "NoFilter", + } + ) # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 97f15997cd639..4d06325b92c34 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -65,8 +65,8 @@ class ImpalaConfig(AlgorithmConfig): ... .rollouts(num_rollout_workers=64) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.impala import ImpalaConfig @@ -104,8 +104,7 @@ def __init__(self, algo_class=None): self.minibatch_buffer_size = 1 self.num_sgd_iter = 1 self.replay_proportion = 0.0 - self.replay_ratio = ((1 / self.replay_proportion) - if self.replay_proportion > 0 else 0.0) + self.replay_ratio = 0.0 self.replay_buffer_num_slots = 0 self.learner_queue_size = 16 self.learner_queue_timeout = 300 @@ -205,8 +204,7 @@ def training( minibatching. This conf only has an effect if `num_sgd_iter > 1`. num_sgd_iter: Number of passes to make over each train batch. replay_proportion: Set >0 to enable experience replay. Saved samples will - be replayed with a p:1 proportion to new data samples. Used in the - execution plan API. + be replayed with a p:1 proportion to new data samples. replay_buffer_num_slots: Number of sample batches to store for replay. The number of transitions saved total will be (replay_buffer_num_slots * rollout_fragment_length). 
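The docstring above describes `replay_proportion` ("replayed with a p:1 proportion to new data samples"), and the hunk that follows recomputes the derived `replay_ratio` whenever a new proportion is passed to `ImpalaConfig.training()`. A minimal sketch of the resulting values (illustrative numbers only, not part of the patch):

from ray.rllib.algorithms.impala import ImpalaConfig

# replay_ratio is recomputed as 1 / replay_proportion (or 0.0 when the
# proportion is 0), per the formula added in the next hunk.
config = ImpalaConfig().training(replay_proportion=0.5)
print(config.replay_ratio)  # -> 2.0 (= 1 / 0.5)

config.training(replay_proportion=0.0)
print(config.replay_ratio)  # -> 0.0 (replay disabled)
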
@@ -288,6 +286,9 @@ def training( self.num_sgd_iter = num_sgd_iter if replay_proportion is not None: self.replay_proportion = replay_proportion + self.replay_ratio = ( + (1 / self.replay_proportion) if self.replay_proportion > 0 else 0.0 + ) if replay_buffer_num_slots is not None: self.replay_buffer_num_slots = replay_buffer_num_slots if learner_queue_size is not None: diff --git a/rllib/algorithms/impala/tests/test_impala.py b/rllib/algorithms/impala/tests/test_impala.py index bb63791d7ec09..8b1723eac6c5e 100644 --- a/rllib/algorithms/impala/tests/test_impala.py +++ b/rllib/algorithms/impala/tests/test_impala.py @@ -50,18 +50,18 @@ def test_impala_compilation(self): ) # Test with and w/o aggregation workers (this has nothing # to do with LSTMs, though). - trainer = config.build(env=env) + algo = config.build(env=env) for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) check_compute_single_action( - trainer, + algo, include_state=lstm, include_prev_action_reward=lstm, ) - trainer.stop() + algo.stop() def test_impala_lr_schedule(self): # Test whether we correctly ignore the "lr" setting. @@ -87,8 +87,8 @@ def get_lr(result): ] for fw in framework_iterator(config): - trainer = config.build() - policy = trainer.get_policy() + algo = config.build() + policy = algo.get_policy() try: if fw == "tf": @@ -96,11 +96,11 @@ def get_lr(result): else: check(policy.cur_lr, 0.05) for _ in range(1): - r1 = trainer.train() + r1 = algo.train() for _ in range(2): - r2 = trainer.train() + r2 = algo.train() for _ in range(2): - r3 = trainer.train() + r3 = algo.train() # Due to the asynch'ness of IMPALA, learner-stats metrics # could be delayed by one iteration. Do 3 train() calls here # and measure guaranteed decrease in lr between 1st and 3rd. @@ -111,7 +111,7 @@ def get_lr(result): assert lr3 <= lr2, (lr2, lr3) assert lr3 < lr1, (lr1, lr3) finally: - trainer.stop() + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/maddpg/maddpg_tf_policy.py b/rllib/algorithms/maddpg/maddpg_tf_policy.py index 28b4bcce3177b..f2f63774c7c36 100644 --- a/rllib/algorithms/maddpg/maddpg_tf_policy.py +++ b/rllib/algorithms/maddpg/maddpg_tf_policy.py @@ -68,13 +68,19 @@ def _make_continuous_space(space): "Space {} is not supported.".format(space) ) + from ray.rllib.algorithms.maddpg.maddpg import MADDPGConfig + + policies, _ = ( + MADDPGConfig.from_dict(config) + .environment(observation_space=obs_space, action_space=act_space) + .get_multi_agent_setup() + ) obs_space_n = [ - _make_continuous_space(spec.observation_space or obs_space) - for _, spec in config["multiagent"]["policies"].items() + _make_continuous_space(spec.observation_space) + for _, spec in policies.items() ] act_space_n = [ - _make_continuous_space(spec.action_space or act_space) - for _, spec in config["multiagent"]["policies"].items() + _make_continuous_space(spec.action_space) for _, spec in policies.items() ] # _____ Placeholders diff --git a/rllib/algorithms/maml/maml.py b/rllib/algorithms/maml/maml.py index bdf96bd89ef67..0768172556fb2 100644 --- a/rllib/algorithms/maml/maml.py +++ b/rllib/algorithms/maml/maml.py @@ -34,8 +34,8 @@ class MAMLConfig(AlgorithmConfig): >>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. 
- >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.maml import MAMLConfig diff --git a/rllib/algorithms/maml/tests/test_maml.py b/rllib/algorithms/maml/tests/test_maml.py index 611cda9b24dc8..55abf71ce955b 100644 --- a/rllib/algorithms/maml/tests/test_maml.py +++ b/rllib/algorithms/maml/tests/test_maml.py @@ -34,13 +34,13 @@ def test_maml_compilation(self): continue print("env={}".format(env)) env_ = "ray.rllib.examples.env.{}".format(env) - trainer = config.build(env=env_) + algo = config.build(env=env_) for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer, include_prev_action_reward=True) - trainer.stop() + check_compute_single_action(algo, include_prev_action_reward=True) + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/marwil/tests/test_marwil.py b/rllib/algorithms/marwil/tests/test_marwil.py index 248b929824b18..88001a7bbc3f0 100644 --- a/rllib/algorithms/marwil/tests/test_marwil.py +++ b/rllib/algorithms/marwil/tests/test_marwil.py @@ -64,10 +64,10 @@ def test_marwil_compilation_and_learning_from_offline_file(self): # Test for all frameworks. for _ in framework_iterator(config, frameworks=("tf", "torch")): - trainer = config.build() + algo = config.build() learnt = False for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) @@ -88,9 +88,9 @@ def test_marwil_compilation_and_learning_from_offline_file(self): "offline data!".format(min_reward) ) - check_compute_single_action(trainer, include_prev_action_reward=True) + check_compute_single_action(algo, include_prev_action_reward=True) - trainer.stop() + algo.stop() def test_marwil_cont_actions_from_offline_file(self): """Test whether MARWILTrainer runs with cont. actions. @@ -128,10 +128,10 @@ def test_marwil_cont_actions_from_offline_file(self): # Test for all frameworks. for _ in framework_iterator(config, frameworks=("tf", "torch")): - trainer = config.build(env="Pendulum-v1") + algo = config.build(env="Pendulum-v1") for i in range(num_iterations): - print(trainer.train()) - trainer.stop() + print(algo.train()) + algo.stop() def test_marwil_loss_function(self): """ @@ -155,8 +155,8 @@ def test_marwil_loss_function(self): reader = JsonReader(inputs=[data_file]) batch = reader.next() - trainer = config.build(env="CartPole-v0") - policy = trainer.get_policy() + algo = config.build(env="CartPole-v0") + policy = algo.get_policy() model = policy.model # Calculate our own expected values (to then compare against the diff --git a/rllib/algorithms/mbmpo/mbmpo.py b/rllib/algorithms/mbmpo/mbmpo.py index 0a1c86d8d2cd7..550ccefa94419 100644 --- a/rllib/algorithms/mbmpo/mbmpo.py +++ b/rllib/algorithms/mbmpo/mbmpo.py @@ -45,8 +45,8 @@ class MBMPOConfig(AlgorithmConfig): ... .rollouts(num_rollout_workers=64) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. 
- >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.mbmpo import MBMPOConfig @@ -231,7 +231,7 @@ def training( if horizon is not None: self.horizon = horizon if dynamics_model is not None: - self.dynamics_model = dynamics_model + self.dynamics_model.update(dynamics_model) if custom_vector_env is not None: self.custom_vector_env = custom_vector_env if num_maml_steps is not None: diff --git a/rllib/algorithms/mbmpo/tests/test_mbmpo.py b/rllib/algorithms/mbmpo/tests/test_mbmpo.py index 03abbaacdd38c..8fa75763b183b 100644 --- a/rllib/algorithms/mbmpo/tests/test_mbmpo.py +++ b/rllib/algorithms/mbmpo/tests/test_mbmpo.py @@ -30,15 +30,15 @@ def test_mbmpo_compilation(self): # Test for torch framework (tf not implemented yet). for _ in framework_iterator(config, frameworks="torch"): - trainer = config.build() + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer, include_prev_action_reward=False) - trainer.stop() + check_compute_single_action(algo, include_prev_action_reward=False) + algo.stop() if __name__ == "__main__": diff --git a/rllib/algorithms/mock.py b/rllib/algorithms/mock.py index 48a218146a075..a6f1848ece969 100644 --- a/rllib/algorithms/mock.py +++ b/rllib/algorithms/mock.py @@ -34,7 +34,7 @@ def setup(self, config): # Setup our config: Merge the user-supplied config (which could # be a partial config dict with the class' default). self.config = self.merge_trainer_configs( - self.get_default_config(), config, self._allow_unknown_configs + self.get_default_config(), config, True ) self.config["env"] = self._env_id diff --git a/rllib/algorithms/pg/pg.py b/rllib/algorithms/pg/pg.py index 402b29df026fa..a04442c6eee5d 100644 --- a/rllib/algorithms/pg/pg.py +++ b/rllib/algorithms/pg/pg.py @@ -5,7 +5,6 @@ from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import Deprecated -from ray.rllib.utils.typing import AlgorithmConfigDict class PGConfig(AlgorithmConfig): @@ -16,8 +15,8 @@ class PGConfig(AlgorithmConfig): >>> config = PGConfig().training(lr=0.01).resources(num_gpus=1) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. 
- >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.pg import PGConfig @@ -103,11 +102,11 @@ class PG(Algorithm): @classmethod @override(Algorithm) - def get_default_config(cls) -> AlgorithmConfigDict: - return PGConfig().to_dict() + def get_default_config(cls) -> AlgorithmConfig: + return PGConfig() @override(Algorithm) - def get_default_policy_class(self, config) -> Type[Policy]: + def get_default_policy_class(self, config: AlgorithmConfig) -> Type[Policy]: if config["framework"] == "torch": from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy diff --git a/rllib/algorithms/pg/pg_tf_policy.py b/rllib/algorithms/pg/pg_tf_policy.py index 5682caace2e7d..06881a8a72cb6 100644 --- a/rllib/algorithms/pg/pg_tf_policy.py +++ b/rllib/algorithms/pg/pg_tf_policy.py @@ -5,11 +5,10 @@ import logging from typing import Dict, List, Type, Union, Optional, Tuple -import ray - from ray.rllib.evaluation.episode import Episode from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.algorithms.pg.pg import PGConfig from ray.rllib.algorithms.pg.utils import post_process_advantages from ray.rllib.utils.typing import AgentID from ray.rllib.utils.annotations import override @@ -49,14 +48,16 @@ def __init__( self, observation_space, action_space, - config, + config: PGConfig, existing_model=None, existing_inputs=None, ): # First thing first, enable eager execution if necessary. base.enable_eager_execution_if_necessary() - config = dict(ray.rllib.algorithms.pg.PGConfig().to_dict(), **config) + # Enforce AlgorithmConfig for PG Policies. + if isinstance(config, dict): + config = PGConfig.from_dict(config) # Initialize base class. base.__init__( @@ -68,7 +69,7 @@ def __init__( existing_model=existing_model, ) - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + LearningRateSchedule.__init__(self, config.lr, config.lr_schedule) # Note: this is a bit ugly, but loss and optimizer initialization must # happen after all the MixIns are initialized. diff --git a/rllib/algorithms/pg/pg_torch_policy.py b/rllib/algorithms/pg/pg_torch_policy.py index f8ed99a6e7690..a04f4834ca4ed 100644 --- a/rllib/algorithms/pg/pg_torch_policy.py +++ b/rllib/algorithms/pg/pg_torch_policy.py @@ -4,13 +4,12 @@ import logging from typing import Dict, List, Type, Union, Optional, Tuple -import ray - from ray.rllib.evaluation.episode import Episode from ray.rllib.utils.typing import AgentID from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 from ray.rllib.utils.annotations import override from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.algorithms.pg.pg import PGConfig from ray.rllib.algorithms.pg.utils import post_process_advantages from ray.rllib.evaluation.postprocessing import Postprocessing from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper @@ -29,19 +28,21 @@ class PGTorchPolicy(LearningRateSchedule, TorchPolicyV2): """PyTorch policy class used with PGTrainer.""" - def __init__(self, observation_space, action_space, config): + def __init__(self, observation_space, action_space, config: PGConfig): - config = dict(ray.rllib.algorithms.pg.PGConfig().to_dict(), **config) + # Enforce AlgorithmConfig for PG Policies. 
+ if isinstance(config, dict): + config = PGConfig.from_dict(config) TorchPolicyV2.__init__( self, observation_space, action_space, config, - max_seq_len=config["model"]["max_seq_len"], + max_seq_len=config.model["max_seq_len"], ) - LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + LearningRateSchedule.__init__(self, config.lr, config.lr_schedule) # TODO: Don't require users to call this manually. self._initialize_loss_from_dummy_batch() diff --git a/rllib/algorithms/pg/tests/test_pg.py b/rllib/algorithms/pg/tests/test_pg.py index ed16ef701ad84..f9bc557e65f6e 100644 --- a/rllib/algorithms/pg/tests/test_pg.py +++ b/rllib/algorithms/pg/tests/test_pg.py @@ -85,13 +85,15 @@ def test_pg_compilation(self): "FrozenLake-v1", ]: print(f"env={env}") - trainer = config.build(env=env) + config.environment(env) + + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer, include_prev_action_reward=True) + check_compute_single_action(algo, include_prev_action_reward=True) def test_pg_loss_functions(self): """Tests the PG loss function math.""" @@ -123,8 +125,8 @@ def test_pg_loss_functions(self): for fw, sess in framework_iterator(config, session=True): dist_cls = Categorical if fw != "torch" else TorchCategorical - trainer = config.build(env="CartPole-v0") - policy = trainer.get_policy() + algo = config.build(env="CartPole-v0") + policy = algo.get_policy() vars = policy.model.trainable_variables() if sess: vars = policy.get_session().run(vars) diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 281301ced1629..97f1fb4ff8ff3 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -54,8 +54,8 @@ class PPOConfig(AlgorithmConfig): ... .rollouts(num_rollout_workers=4) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.ppo import PPOConfig @@ -341,7 +341,8 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: * config["rollout_fragment_length"] ) if ( - config["train_batch_size"] > 0 + not config.get("in_evaluation") + and config["train_batch_size"] > 0 and config["train_batch_size"] % calculated_min_rollout_size != 0 ): new_rollout_fragment_length = math.ceil( @@ -366,7 +367,11 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: # `postprocessing_fn`), iff generalized advantage estimation is used # (value function estimate at end of truncated episode to estimate # remaining value). 
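Circling back to the PG policy hunks above (`pg_tf_policy.py` / `pg_torch_policy.py`): they now coerce a legacy config dict into a `PGConfig` via `from_dict()` and read settings as attributes (`config.lr`, `config.model["max_seq_len"]`) instead of dict lookups. A minimal sketch of that conversion; the dict values are illustrative only and not from the patch:

from ray.rllib.algorithms.pg import PGConfig

# A partial old-style config dict, as a policy might still receive it.
legacy = {"lr": 0.0005, "train_batch_size": 256}
config = PGConfig.from_dict(legacy)

# Settings are now attributes rather than dict keys.
print(config.lr, config.train_batch_size)  # -> 0.0005 256
print(config.model["max_seq_len"])         # nested model options stay dict-like
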
- if config["batch_mode"] == "truncate_episodes" and not config["use_gae"]: + if ( + not config.get("in_evaluation") + and config["batch_mode"] == "truncate_episodes" + and not config["use_gae"] + ): raise ValueError( "Episode truncation is not supported without a value " "function (to estimate the return at the end of the truncated" diff --git a/rllib/algorithms/ppo/tests/test_ppo.py b/rllib/algorithms/ppo/tests/test_ppo.py index 2d0c3c54675c6..00013d060d50a 100644 --- a/rllib/algorithms/ppo/tests/test_ppo.py +++ b/rllib/algorithms/ppo/tests/test_ppo.py @@ -133,9 +133,9 @@ def test_ppo_compilation_and_schedule_mixins(self): ) ) - trainer = config.build(env=env) - policy = trainer.get_policy() - entropy_coeff = trainer.get_policy().entropy_coeff + algo = config.build(env=env) + policy = algo.get_policy() + entropy_coeff = algo.get_policy().entropy_coeff lr = policy.cur_lr if fw == "tf": entropy_coeff, lr = policy.get_session().run( @@ -145,20 +145,21 @@ def test_ppo_compilation_and_schedule_mixins(self): check(lr, config.lr) for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) check_compute_single_action( - trainer, include_prev_action_reward=True, include_state=lstm + algo, include_prev_action_reward=True, include_state=lstm ) - trainer.stop() + algo.stop() def test_ppo_exploration_setup(self): """Tests, whether PPO runs with different exploration setups.""" config = ( ppo.PPOConfig() .environment( + "FrozenLake-v1", env_config={"is_slippery": False, "map_name": "4x4"}, ) .rollouts( @@ -171,7 +172,7 @@ def test_ppo_exploration_setup(self): # Test against all frameworks. for fw in framework_iterator(config): # Default Agent should be setup with StochasticSampling. - trainer = ppo.PPO(config=config, env="FrozenLake-v1") + trainer = config.build() # explore=False, always expect the same (deterministic) action. a_ = trainer.compute_single_action( obs, explore=False, prev_action=np.array(2), prev_reward=np.array(1.0) @@ -207,6 +208,7 @@ def test_ppo_free_log_std(self): """Tests the free log std option works.""" config = ( ppo.PPOConfig() + .environment("CartPole-v0") .rollouts( num_rollout_workers=0, ) @@ -222,7 +224,7 @@ def test_ppo_free_log_std(self): ) for fw, sess in framework_iterator(config, session=True): - trainer = ppo.PPO(config=config, env="CartPole-v0") + trainer = config.build() policy = trainer.get_policy() # Check the free log std var is created. @@ -265,6 +267,7 @@ def test_ppo_loss_function(self): """Tests the PPO loss function math.""" config = ( ppo.PPOConfig() + .environment("CartPole-v0") .rollouts( num_rollout_workers=0, ) @@ -279,7 +282,7 @@ def test_ppo_loss_function(self): ) for fw, sess in framework_iterator(config, session=True): - trainer = ppo.PPO(config=config, env="CartPole-v0") + trainer = config.build() policy = trainer.get_policy() # Check no free log std var by default. diff --git a/rllib/algorithms/qmix/qmix.py b/rllib/algorithms/qmix/qmix.py index 5a5ac7037fdb4..c07bc9bf9ebfa 100644 --- a/rllib/algorithms/qmix/qmix.py +++ b/rllib/algorithms/qmix/qmix.py @@ -136,11 +136,9 @@ def __init__(self): # The evaluation stats will be reported under the "evaluation" metric key. # Note that evaluation is currently not parallelized, and that for Ape-X # metrics are already only reported for the lowest epsilon workers. 
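The QMIX hunk just below (like the ES, SimpleQ, and SlateQ hunks elsewhere in this patch) replaces direct `self.evaluation_config = {...}` assignments with the `evaluation()` setter. A small, hedged sketch of that pattern (not taken from the patch):

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

config = AlgorithmConfig()
# Route evaluation overrides through the setter instead of assigning the raw dict.
config.evaluation(evaluation_config={"explore": False})
print(config.evaluation_config)  # the {"explore": False} override is stored here
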
- self.evaluation_interval = None - self.evaluation_duration = 10 - self.evaluation_config = { - "explore": False, - } + self.evaluation( + evaluation_config={"explore": False} + ) # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/qmix/tests/test_qmix.py b/rllib/algorithms/qmix/tests/test_qmix.py index 65f0360ead59f..28c9053a0e429 100644 --- a/rllib/algorithms/qmix/tests/test_qmix.py +++ b/rllib/algorithms/qmix/tests/test_qmix.py @@ -105,13 +105,13 @@ def test_avail_actions_qmix(self): .rollouts(num_envs_per_worker=5) ) # Test with vectorization on. - trainer = config.build() + algo = config.build() for _ in range(4): - trainer.train() # OK if it doesn't trip the action assertion error + algo.train() # OK if it doesn't trip the action assertion error - assert trainer.train()["episode_reward_mean"] == 30.0 - trainer.stop() + assert algo.train()["episode_reward_mean"] == 30.0 + algo.stop() ray.shutdown() diff --git a/rllib/algorithms/r2d2/r2d2.py b/rllib/algorithms/r2d2/r2d2.py index 7d5cdf36379bb..38385a115d025 100644 --- a/rllib/algorithms/r2d2/r2d2.py +++ b/rllib/algorithms/r2d2/r2d2.py @@ -208,7 +208,10 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: # Call super's validation method. super().validate_config(config) - if config["replay_buffer_config"].get("replay_sequence_length", -1) != -1: + if ( + not config.get("in_evaluation") + and config["replay_buffer_config"].get("replay_sequence_length", -1) != -1 + ): raise ValueError( "`replay_sequence_length` is calculated automatically to be " "model->max_seq_len + burn_in!" diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py index 798add5ce6e1d..01a32b99f93f7 100644 --- a/rllib/algorithms/sac/sac.py +++ b/rllib/algorithms/sac/sac.py @@ -5,6 +5,7 @@ from ray.rllib.algorithms.dqn.dqn import DQN from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy from ray.rllib.policy.policy import Policy +from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import ( DEPRECATED_VALUE, @@ -240,9 +241,9 @@ def training( if twin_q is not None: self.twin_q = twin_q if q_model_config is not None: - self.q_model_config = q_model_config + self.q_model_config.update(q_model_config) if policy_model_config is not None: - self.policy_model_config = policy_model_config + self.policy_model_config.update(policy_model_config) if tau is not None: self.tau = tau if initial_alpha is not None: @@ -254,7 +255,16 @@ def training( if store_buffer_in_checkpoints is not None: self.store_buffer_in_checkpoints = store_buffer_in_checkpoints if replay_buffer_config is not None: - self.replay_buffer_config = replay_buffer_config + # Override entire `replay_buffer_config` if `type` key changes. + # Update, if `type` key remains the same or is not specified. 
+ new_replay_buffer_config = deep_update( + {"replay_buffer_config": self.replay_buffer_config}, + {"replay_buffer_config": replay_buffer_config}, + False, + ["replay_buffer_config"], + ["replay_buffer_config"], + ) + self.replay_buffer_config = new_replay_buffer_config["replay_buffer_config"] if training_intensity is not None: self.training_intensity = training_intensity if clip_actions is not None: diff --git a/rllib/algorithms/sac/tests/test_rnnsac.py b/rllib/algorithms/sac/tests/test_rnnsac.py index 6cf62ebaa0215..67b1f035dbb61 100644 --- a/rllib/algorithms/sac/tests/test_rnnsac.py +++ b/rllib/algorithms/sac/tests/test_rnnsac.py @@ -22,6 +22,7 @@ def test_rnnsac_compilation(self): """Test whether RNNSAC can be built on all frameworks.""" config = ( sac.RNNSACConfig() + .environment("CartPole-v0") .rollouts(num_rollout_workers=0) .training( # Wrap with an LSTM and use a very simple base-model. @@ -53,7 +54,7 @@ def test_rnnsac_compilation(self): # Test building an RNNSAC agent in all frameworks. for _ in framework_iterator(config, frameworks="torch"): - algo = config.build(env="CartPole-v0") + algo = config.build() for i in range(num_iterations): results = algo.train() print(results) diff --git a/rllib/algorithms/sac/tests/test_sac.py b/rllib/algorithms/sac/tests/test_sac.py index 4c2a2952417c6..122f22cbb2781 100644 --- a/rllib/algorithms/sac/tests/test_sac.py +++ b/rllib/algorithms/sac/tests/test_sac.py @@ -132,6 +132,7 @@ def test_sac_compilation(self): "CartPole-v0", ]: print("Env={}".format(env)) + config.environment(env) # Test making the Q-model a custom one for CartPole, otherwise, # use the default model. config.q_model_config["custom_model"] = ( @@ -139,36 +140,40 @@ def test_sac_compilation(self): if env == "CartPole-v0" else None ) - trainer = config.build(env=env) + algo = config.build() for i in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) + check_compute_single_action(algo) # Test, whether the replay buffer is saved along with # a checkpoint (no point in doing it for all frameworks since # this is framework agnostic). if fw == "tf" and env == "CartPole-v0": - checkpoint = trainer.save() - new_trainer = sac.SAC(config, env=env) - new_trainer.restore(checkpoint) + checkpoint = algo.save() + new_algo = config.build() + new_algo.restore(checkpoint) # Get some data from the buffer and compare. 
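Looking back at the `deep_update`-based `replay_buffer_config` handling added to `SACConfig.training()` above (and to `DTConfig.training()` earlier in this patch): per the comments there, keys are merged when the buffer `type` stays the same (or is omitted) and the whole sub-dict is replaced when `type` changes. A small sketch with illustrative dicts; "rb" is just a placeholder key:

from ray.rllib.utils import deep_update

# Same (or unspecified) "type": individual keys are merged into the defaults.
merged = deep_update(
    {"rb": {"type": "MultiAgentReplayBuffer", "capacity": 50000}},
    {"rb": {"capacity": 100}},
    False, ["rb"], ["rb"],
)
print(merged)  # {"rb": {"type": "MultiAgentReplayBuffer", "capacity": 100}}

# Different "type": the entire sub-dict is replaced.
replaced = deep_update(
    {"rb": {"type": "MultiAgentReplayBuffer", "capacity": 50000}},
    {"rb": {"type": "MultiAgentPrioritizedReplayBuffer"}},
    False, ["rb"], ["rb"],
)
print(replaced)  # {"rb": {"type": "MultiAgentPrioritizedReplayBuffer"}}
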
- data = trainer.local_replay_buffer.replay_buffers[ + data = algo.local_replay_buffer.replay_buffers[ "default_policy" ]._storage[: 42 + 42] - new_data = new_trainer.local_replay_buffer.replay_buffers[ + new_data = new_algo.local_replay_buffer.replay_buffers[ "default_policy" ]._storage[: 42 + 42] check(data, new_data) - new_trainer.stop() + new_algo.stop() - trainer.stop() + algo.stop() def test_sac_loss_function(self): """Tests SAC loss function results across all frameworks.""" config = ( sac.SACConfig() + .environment( + SimpleEnv, + env_config={"simplex_actions": True}, + ) .training( twin_q=False, gamma=0.99, @@ -181,9 +186,6 @@ def test_sac_loss_function(self): .reporting( min_time_s_per_iteration=0, ) - .environment( - env_config={"simplex_actions": True}, - ) .debugging(seed=42) ) @@ -230,7 +232,6 @@ def test_sac_loss_function(self): "default_policy/log_alpha_1": "log_alpha", } - env = SimpleEnv batch_size = 64 obs_size = (batch_size, 1) actions = np.random.random(size=(batch_size, 2)) @@ -250,8 +251,8 @@ def test_sac_loss_function(self): config, frameworks=("tf", "torch"), session=True ): # Generate Algorithm and get its default Policy object. - trainer = config.build(env=env) - policy = trainer.get_policy() + algo = config.build() + policy = algo.get_policy() p_sess = None if sess: p_sess = policy.get_session() @@ -446,9 +447,9 @@ def test_sac_loss_function(self): tf_inputs.append(in_) # Set a fake-batch to use # (instead of sampling from replay buffer). - buf = trainer.local_replay_buffer + buf = algo.local_replay_buffer patch_buffer_with_fake_sampling_method(buf, in_) - trainer.train() + algo.train() updated_weights = policy.get_weights() # Net must have changed. if tf_updated_weights: @@ -465,9 +466,9 @@ def test_sac_loss_function(self): in_ = tf_inputs[update_iteration] # Set a fake-batch to use # (instead of sampling from replay buffer). - buf = trainer.local_replay_buffer + buf = algo.local_replay_buffer patch_buffer_with_fake_sampling_method(buf, in_) - trainer.train() + algo.train() # Compare updated model. for tf_key in sorted(tf_weights.keys()): if re.search("_[23]|alpha", tf_key): @@ -500,7 +501,7 @@ def test_sac_loss_function(self): ) else: check(tf_var, torch_var, atol=0.003) - trainer.stop() + algo.stop() def test_sac_dict_obs_order(self): dict_space = Dict( @@ -534,6 +535,7 @@ def step(self, action): tune.register_env("nested", lambda _: NestedDictEnv()) config = ( sac.SACConfig() + .environment("nested") .training( replay_buffer_config={ "capacity": 10, @@ -550,12 +552,12 @@ def step(self, action): num_iterations = 1 for _ in framework_iterator(config, with_eager_tracing=True): - trainer = config.build(env="nested") + algo = config.build() for _ in range(num_iterations): - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) + check_compute_single_action(algo) def _get_batch_helper(self, obs_size, actions, batch_size): return SampleBatch( diff --git a/rllib/algorithms/simple_q/simple_q.py b/rllib/algorithms/simple_q/simple_q.py index 5459885f73cfb..5f278439b9173 100644 --- a/rllib/algorithms/simple_q/simple_q.py +++ b/rllib/algorithms/simple_q/simple_q.py @@ -142,7 +142,7 @@ def __init__(self, algo_class=None): } # `evaluation()` - self.evaluation_config = {"explore": False} + self.evaluation(evaluation_config={"explore": False}) # `reporting()` self.min_time_s_per_iteration = None @@ -176,8 +176,6 @@ def training( """Sets the training related configuration. 
Args: - timesteps_per_iteration: Minimum env steps to optimize for per train call. - This value does not affect learning, only the length of iterations. target_network_update_freq: Update the target network every `target_network_update_freq` sample steps. replay_buffer_config: Replay buffer config. @@ -300,7 +298,8 @@ def validate_config(self, config: AlgorithmConfigDict) -> None: " used at the same time!" ) - validate_buffer_config(config) + if not config.get("in_evaluation"): + validate_buffer_config(config) # Multi-agent mode and multi-GPU optimizer. if config["multiagent"]["policies"] and not config["simple_optimizer"]: diff --git a/rllib/algorithms/simple_q/tests/test_simple_q.py b/rllib/algorithms/simple_q/tests/test_simple_q.py index 482bb29c07d9e..fe3880c57e1e0 100644 --- a/rllib/algorithms/simple_q/tests/test_simple_q.py +++ b/rllib/algorithms/simple_q/tests/test_simple_q.py @@ -45,16 +45,16 @@ def test_simple_q_compilation(self): num_iterations = 2 for _ in framework_iterator(config, with_eager_tracing=True): - trainer = config.build(env="CartPole-v0") - rw = trainer.workers.local_worker() + algo = config.build(env="CartPole-v0") + rw = algo.workers.local_worker() for i in range(num_iterations): sb = rw.sample() assert sb.count == config.rollout_fragment_length - results = trainer.train() + results = algo.train() check_train_results(results) print(results) - check_compute_single_action(trainer) + check_compute_single_action(algo) def test_simple_q_loss_function(self): """Tests the Simple-Q loss function results on all frameworks.""" @@ -66,11 +66,11 @@ def test_simple_q_loss_function(self): "fcnet_activation": "linear", }, num_steps_sampled_before_learning_starts=0, - ) + ).environment("CartPole-v0") for fw in framework_iterator(config): # Generate Algorithm and get its default Policy object. - trainer = simple_q.SimpleQ(config=config, env="CartPole-v0") + trainer = config.build() policy = trainer.get_policy() # Batch of size=2. input_ = SampleBatch( diff --git a/rllib/algorithms/slateq/slateq.py b/rllib/algorithms/slateq/slateq.py index 0384b9f9faf78..608ab18824156 100644 --- a/rllib/algorithms/slateq/slateq.py +++ b/rllib/algorithms/slateq/slateq.py @@ -35,8 +35,8 @@ class SlateQConfig(AlgorithmConfig): >>> config = SlateQConfig().training(lr=0.01).resources(num_gpus=1) >>> print(config.to_dict()) >>> # Build a Algorithm object from the config and run 1 training iteration. - >>> trainer = config.build(env="CartPole-v1") - >>> trainer.train() + >>> algo = config.build(env="CartPole-v1") + >>> algo.train() Example: >>> from ray.rllib.algorithms.slateq import SlateQConfig @@ -110,8 +110,6 @@ def __init__(self): "epsilon_timesteps": 250000, "final_epsilon": 0.01, } - # Switch to greedy actions in evaluation workers. - self.evaluation_config = {"explore": False} self.num_workers = 0 self.rollout_fragment_length = 4 self.train_batch_size = 32 @@ -120,6 +118,8 @@ def __init__(self): self.min_time_s_per_iteration = 1 self.compress_observations = False self._disable_preprocessor_api = True + # Switch to greedy actions in evaluation workers. 
+ self.evaluation(evaluation_config={"explore": False}) # __sphinx_doc_end__ # fmt: on diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index 6ad2341026e45..56b173ec2a296 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -9,7 +9,7 @@ import ray import ray.rllib.algorithms.a3c as a3c import ray.rllib.algorithms.dqn as dqn -from ray.rllib.algorithms.bc import BC, BCConfig +from ray.rllib.algorithms.bc import BCConfig import ray.rllib.algorithms.pg as pg from ray.rllib.examples.env.multi_agent import MultiAgentCartPole from ray.rllib.examples.parallel_evaluation_and_training import AssertEvalCallback @@ -429,7 +429,7 @@ def test_no_env_but_eval_workers_do_have_env(self): .offline_data(input_=[input_file]) ) - bc = BC(config=offline_rl_config) + bc = offline_rl_config.build() bc.train() bc.stop() diff --git a/rllib/algorithms/tests/test_worker_failures.py b/rllib/algorithms/tests/test_worker_failures.py index 5bf70480ff671..0077c9701acf6 100644 --- a/rllib/algorithms/tests/test_worker_failures.py +++ b/rllib/algorithms/tests/test_worker_failures.py @@ -1,9 +1,8 @@ -import time -import unittest from collections import defaultdict - import gym import numpy as np +import time +import unittest import ray from ray.rllib.algorithms.algorithm_config import AlgorithmConfig @@ -12,7 +11,8 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.algorithms.dqn.dqn import DQNConfig from ray.rllib.algorithms.impala import ImpalaConfig -from ray.rllib.algorithms.pg import PG, PGConfig +from ray.rllib.algorithms.pg import PGConfig +from ray.rllib.algorithms.pg.pg_tf_policy import PGTF2Policy from ray.rllib.algorithms.pg.pg_torch_policy import PGTorchPolicy from ray.rllib.algorithms.ppo.ppo import PPOConfig from ray.rllib.env.multi_agent_env import make_multi_agent @@ -328,6 +328,7 @@ def test_recreate_eval_workers_parallel_to_training_w_async_req_manager(self): config = ( PGConfig() .evaluation( + evaluation_num_workers=1, enable_async_evaluation=True, evaluation_parallel_to_training=True, evaluation_duration="auto", @@ -409,7 +410,11 @@ def on_algorithm_init(self, *, algorithm, **kwargs): # Add a custom policy to algorithm algorithm.add_policy( policy_id="test_policy", - policy_cls=PGTorchPolicy, + policy_cls=( + PGTorchPolicy + if algorithm.config.framework_str == "torch" + else PGTF2Policy + ), observation_space=gym.spaces.Box(low=0, high=1, shape=(8,)), action_space=gym.spaces.Discrete(2), config={}, @@ -464,7 +469,7 @@ def on_algorithm_init(self, *, algorithm, **kwargs): ) for _ in framework_iterator(config, frameworks=("tf2", "torch")): - # Reset interaciton counter. + # Reset interaction counter. ray.wait([counter.reset.remote()]) a = config.build() @@ -637,6 +642,7 @@ def test_eval_workers_fault_but_restore_env(self): config = ( PGConfig() + .environment("fault_env") .rollouts( num_rollout_workers=2, ignore_worker_failures=True, # Ignore failure. @@ -677,7 +683,7 @@ def test_eval_workers_fault_but_restore_env(self): # Reset interaciton counter. ray.wait([counter.reset.remote()]) - a = PG(config=config, env="fault_env") + a = config.build() # Before train loop, workers are fresh and not recreated. 
self.assertTrue( diff --git a/rllib/env/policy_client.py b/rllib/env/policy_client.py index 9eb4819818756..d921e0bb4e656 100644 --- a/rllib/env/policy_client.py +++ b/rllib/env/policy_client.py @@ -16,7 +16,6 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.policy.sample_batch import MultiAgentBatch from ray.rllib.utils.annotations import PublicAPI -from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent from ray.rllib.utils.typing import ( MultiAgentDict, EnvInfoDict, @@ -367,8 +366,8 @@ def _create_embedded_rollout_worker(kwargs, send_fn): """Create a local rollout worker and a thread that samples from it. Args: - kwargs: args for the RolloutWorker constructor. - send_fn: function to send a JSON request to the server. + kwargs: Args for the RolloutWorker constructor. + send_fn: Function to send a JSON request to the server. """ # Since the server acts as an input datasource, we have to reset the @@ -384,18 +383,18 @@ def _create_embedded_rollout_worker(kwargs, send_fn): # If server has no env (which is the expected case): # Generate a dummy ExternalEnv here using RandomEnv and the # given observation/action spaces. - if kwargs["policy_config"].get("env") is None: + if kwargs["config"].env is None: from ray.rllib.examples.env.random_env import RandomEnv, RandomMultiAgentEnv config = { - "action_space": kwargs["policy_config"]["action_space"], - "observation_space": kwargs["policy_config"]["observation_space"], + "action_space": kwargs["config"].action_space, + "observation_space": kwargs["config"].observation_space, } - _, is_ma = check_multi_agent(kwargs["policy_config"]) + is_ma = kwargs["config"].is_multi_agent() kwargs["env_creator"] = _auto_wrap_external( lambda _: (RandomMultiAgentEnv if is_ma else RandomEnv)(config) ) - kwargs["policy_config"]["env"] = True + # kwargs["config"].env = True # Otherwise, use the env specified by the server args. 
else: real_env_creator = kwargs["env_creator"] diff --git a/rllib/env/tests/test_external_env.py b/rllib/env/tests/test_external_env.py index b88f1bdc454e8..309bf0bcb8d99 100644 --- a/rllib/env/tests/test_external_env.py +++ b/rllib/env/tests/test_external_env.py @@ -5,8 +5,9 @@ import uuid import ray -from ray.rllib.algorithms.dqn import DQN -from ray.rllib.algorithms.pg import PG +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig +from ray.rllib.algorithms.dqn import DQNConfig +from ray.rllib.algorithms.pg import PGConfig from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env.external_env import ExternalEnv from ray.rllib.evaluation.tests.test_rollout_worker import BadPolicy, MockPolicy @@ -125,9 +126,12 @@ def tearDownClass(cls) -> None: def test_external_env_complete_episodes(self): ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_spec=MockPolicy, - rollout_fragment_length=40, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=40, + batch_mode="complete_episodes", + num_rollout_workers=0, + ), ) for _ in range(3): batch = ev.sample() @@ -136,9 +140,11 @@ def test_external_env_complete_episodes(self): def test_external_env_truncate_episodes(self): ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_spec=MockPolicy, - rollout_fragment_length=40, - batch_mode="truncate_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=40, + num_rollout_workers=0, + ), ) for _ in range(3): batch = ev.sample() @@ -147,9 +153,12 @@ def test_external_env_truncate_episodes(self): def test_external_env_off_policy(self): ev = RolloutWorker( env_creator=lambda _: SimpleOffPolicyServing(MockEnv(25), 42), - policy_spec=MockPolicy, - rollout_fragment_length=40, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=40, + batch_mode="complete_episodes", + num_rollout_workers=0, + ), ) for _ in range(3): batch = ev.sample() @@ -160,10 +169,12 @@ def test_external_env_off_policy(self): def test_external_env_bad_actions(self): ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_spec=BadPolicy, - sample_async=True, - rollout_fragment_length=40, - batch_mode="truncate_episodes", + default_policy_class=BadPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=40, + num_rollout_workers=0, + sample_async=True, + ), ) self.assertRaises(Exception, lambda: ev.sample()) @@ -172,12 +183,14 @@ def test_train_cartpole_off_policy(self): "test3", lambda _: PartOffPolicyServing(gym.make("CartPole-v0"), off_pol_frac=0.2), ) - config = { - "num_workers": 0, - "exploration_config": {"epsilon_timesteps": 100}, - } + config = ( + DQNConfig() + .environment("test3") + .rollouts(num_rollout_workers=0) + .exploration(exploration_config={"epsilon_timesteps": 100}) + ) for _ in framework_iterator(config, frameworks=("tf", "torch")): - dqn = DQN(env="test3", config=config) + dqn = config.build() reached = False for i in range(50): result = dqn.train() @@ -194,9 +207,9 @@ def test_train_cartpole_off_policy(self): def test_train_cartpole(self): register_env("test", lambda _: SimpleServing(gym.make("CartPole-v0"))) - config = {"num_workers": 0} + config = PGConfig().environment("test").rollouts(num_rollout_workers=0) for _ in framework_iterator(config, frameworks=("tf", "torch")): - pg = 
PG(env="test", config=config) + pg = config.build() reached = False for i in range(80): result = pg.train() @@ -213,9 +226,9 @@ def test_train_cartpole(self): def test_train_cartpole_multi(self): register_env("test2", lambda _: MultiServing(lambda: gym.make("CartPole-v0"))) - config = {"num_workers": 0} + config = PGConfig().environment("test2").rollouts(num_rollout_workers=0) for _ in framework_iterator(config, frameworks=("tf", "torch")): - pg = PG(env="test2", config=config) + pg = config.build() reached = False for i in range(80): result = pg.train() @@ -233,10 +246,13 @@ def test_train_cartpole_multi(self): def test_external_env_horizon_not_supported(self): ev = RolloutWorker( env_creator=lambda _: SimpleServing(MockEnv(25)), - policy_spec=MockPolicy, - episode_horizon=20, - rollout_fragment_length=10, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=10, + horizon=20, + batch_mode="complete_episodes", + num_rollout_workers=0, + ), ) self.assertRaises(ValueError, lambda: ev.sample()) diff --git a/rllib/env/tests/test_external_multi_agent_env.py b/rllib/env/tests/test_external_multi_agent_env.py index 30c8e883d6a9f..596e364b7f977 100644 --- a/rllib/env/tests/test_external_multi_agent_env.py +++ b/rllib/env/tests/test_external_multi_agent_env.py @@ -1,8 +1,8 @@ -import gym import numpy as np import unittest import ray +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv from ray.rllib.env.tests.test_external_env import make_simple_serving from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -26,9 +26,12 @@ def test_external_multi_agent_env_complete_episodes(self): agents = 4 ev = RolloutWorker( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_spec=MockPolicy, - rollout_fragment_length=40, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=40, + num_rollout_workers=0, + batch_mode="complete_episodes", + ), ) for _ in range(3): batch = ev.sample() @@ -39,9 +42,11 @@ def test_external_multi_agent_env_truncate_episodes(self): agents = 4 ev = RolloutWorker( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_spec=MockPolicy, - rollout_fragment_length=40, - batch_mode="truncate_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=40, + num_rollout_workers=0, + ), ) for _ in range(3): batch = ev.sample() @@ -50,16 +55,19 @@ def test_external_multi_agent_env_truncate_episodes(self): def test_external_multi_agent_env_sample(self): agents = 2 - act_space = gym.spaces.Discrete(2) - obs_space = gym.spaces.Discrete(2) ev = RolloutWorker( env_creator=lambda _: SimpleMultiServing(BasicMultiAgent(agents)), - policy_spec={ - "p0": (MockPolicy, obs_space, act_space, {}), - "p1": (MockPolicy, obs_space, act_space, {}), - }, - policy_mapping_fn=lambda aid, **kwargs: "p{}".format(aid % 2), - rollout_fragment_length=50, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=50, + num_rollout_workers=0, + batch_mode="complete_episodes", + ) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=lambda aid, **kwargs: "p{}".format(aid % 2), + ), ) batch = ev.sample() self.assertEqual(batch.count, 50) diff --git a/rllib/env/tests/test_multi_agent_env.py b/rllib/env/tests/test_multi_agent_env.py 
index eebfbf2b6807d..11f698e29f419 100644 --- a/rllib/env/tests/test_multi_agent_env.py +++ b/rllib/env/tests/test_multi_agent_env.py @@ -5,6 +5,7 @@ import ray from ray.tune.registry import register_env +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.pg import PG from ray.rllib.env.multi_agent_env import make_multi_agent, MultiAgentEnvWrapper @@ -139,12 +140,13 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), - policy_spec={ - "p0": PolicySpec(policy_class=MockPolicy), - "p1": PolicySpec(policy_class=MockPolicy), - }, - policy_mapping_fn=policy_mapping_fn, - rollout_fragment_length=50, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts(rollout_fragment_length=50, num_rollout_workers=0) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=policy_mapping_fn, + ), ) batch = ev.sample() self.assertEqual(batch.count, 50) @@ -155,18 +157,22 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): def test_multi_agent_sample_sync_remote(self): ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), - policy_spec={ - "p0": PolicySpec(policy_class=MockPolicy), - "p1": PolicySpec(policy_class=MockPolicy), - }, + default_policy_class=MockPolicy, # This signature will raise a soft-deprecation warning due # to the new signature we are using (agent_id, episode, **kwargs), # but should not break this test. - policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)), - rollout_fragment_length=50, - num_envs=4, - remote_worker_envs=True, - remote_env_batch_wait_ms=99999999, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=50, + num_rollout_workers=0, + num_envs_per_worker=4, + remote_worker_envs=True, + remote_env_batch_wait_ms=99999999, + ) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=(lambda agent_id: "p{}".format(agent_id % 2)), + ), ) batch = ev.sample() self.assertEqual(batch.count, 200) @@ -174,14 +180,18 @@ def test_multi_agent_sample_sync_remote(self): def test_multi_agent_sample_async_remote(self): ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), - policy_spec={ - "p0": PolicySpec(policy_class=MockPolicy), - "p1": PolicySpec(policy_class=MockPolicy), - }, - policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)), - rollout_fragment_length=50, - num_envs=4, - remote_worker_envs=True, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=50, + num_rollout_workers=0, + num_envs_per_worker=4, + remote_worker_envs=True, + ) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)), + ), ) batch = ev.sample() self.assertEqual(batch.count, 200) @@ -189,13 +199,17 @@ def test_multi_agent_sample_async_remote(self): def test_multi_agent_sample_with_horizon(self): ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(5), - policy_spec={ - "p0": PolicySpec(policy_class=MockPolicy), - "p1": PolicySpec(policy_class=MockPolicy), - }, - policy_mapping_fn=(lambda aid, **kwarg: "p{}".format(aid % 2)), - episode_horizon=10, # test with episode horizon set - rollout_fragment_length=50, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=50, + num_rollout_workers=0, + horizon=10, # test with episode horizon set + ) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=(lambda aid, 
**kwarg: "p{}".format(aid % 2)), + ), ) batch = ev.sample() self.assertEqual(batch.count, 50) @@ -203,13 +217,17 @@ def test_multi_agent_sample_with_horizon(self): def test_sample_from_early_done_env(self): ev = RolloutWorker( env_creator=lambda _: EarlyDoneMultiAgent(), - policy_spec={ - "p0": PolicySpec(policy_class=MockPolicy), - "p1": PolicySpec(policy_class=MockPolicy), - }, - policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)), - batch_mode="complete_episodes", - rollout_fragment_length=1, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=1, + num_rollout_workers=0, + batch_mode="complete_episodes", + ) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=(lambda aid, **kwargs: "p{}".format(aid % 2)), + ), ) # This used to raise an Error due to the EarlyDoneMultiAgent # terminating at e.g. agent0 w/o publishing the observation for @@ -248,11 +266,16 @@ def test_multi_agent_with_flex_agents(self): def test_multi_agent_sample_round_robin(self): ev = RolloutWorker( env_creator=lambda _: RoundRobinMultiAgent(5, increment_obs=True), - policy_spec={ - "p0": PolicySpec(policy_class=MockPolicy), - }, - policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0", - rollout_fragment_length=50, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=50, + num_rollout_workers=0, + ) + .multi_agent( + policies={"p0"}, + policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0", + ), ) batch = ev.sample() self.assertEqual(batch.count, 50) @@ -302,15 +325,23 @@ def get_initial_state(self): ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=StatefulPolicy, - rollout_fragment_length=5, + default_policy_class=StatefulPolicy, + config=( + AlgorithmConfig().rollouts( + rollout_fragment_length=5, num_rollout_workers=0 + ) + # Force `state_in_0` to be repeated every ts in the collected batch + # (even though we don't even have a model that would care about this). + .training(model={"max_seq_len": 1}) + ), ) batch = ev.sample() self.assertEqual(batch.count, 5) self.assertEqual(batch["state_in_0"][0], {}) self.assertEqual(batch["state_out_0"][0], h) - self.assertEqual(batch["state_in_0"][1], h) - self.assertEqual(batch["state_out_0"][1], h) + for i in range(1, 5): + self.assertEqual(batch["state_in_0"][i], h) + self.assertEqual(batch["state_out_0"][i], h) def test_returning_model_based_rollouts_data(self): class ModelBasedPolicy(DQNTFPolicy): @@ -365,12 +396,16 @@ def compute_actions_from_input_dict( ev = RolloutWorker( env_creator=lambda _: MultiAgentCartPole({"num_agents": 2}), - policy_spec={ - "p0": PolicySpec(policy_class=ModelBasedPolicy), - "p1": PolicySpec(policy_class=ModelBasedPolicy), - }, - policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0", - rollout_fragment_length=5, + default_policy_class=ModelBasedPolicy, + config=AlgorithmConfig() + .rollouts( + rollout_fragment_length=5, + num_rollout_workers=0, + ) + .multi_agent( + policies={"p0", "p1"}, + policy_mapping_fn=lambda agent_id, episode, **kwargs: "p0", + ), ) batch = ev.sample() # 5 environment steps (rollout_fragment_length). 
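The test changes above all build `RolloutWorker`s from an `AlgorithmConfig` plus `default_policy_class`, rather than the old `policy_spec`/`policy_mapping_fn`/`rollout_fragment_length` keyword arguments. A minimal sketch of that construction pattern, mirroring the tests above (it assumes a running Ray session, as the test suites set up in `setUpClass`):

import ray
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.examples.env.multi_agent import BasicMultiAgent

ray.init(ignore_reinit_error=True)

worker = RolloutWorker(
    env_creator=lambda _: BasicMultiAgent(2),
    default_policy_class=MockPolicy,
    config=(
        AlgorithmConfig()
        .rollouts(rollout_fragment_length=50, num_rollout_workers=0)
        .multi_agent(
            policies={"p0", "p1"},
            policy_mapping_fn=lambda aid, **kwargs: "p{}".format(aid % 2),
        )
    ),
)
print(worker.sample().count)  # -> 50 env steps (rollout_fragment_length), as asserted above
worker.stop()
ray.shutdown()
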
diff --git a/rllib/env/wrappers/model_vector_env.py b/rllib/env/wrappers/model_vector_env.py index 5837fce1e700f..628fc2b8bad12 100644 --- a/rllib/env/wrappers/model_vector_env.py +++ b/rllib/env/wrappers/model_vector_env.py @@ -29,14 +29,14 @@ def model_vector_env(env: EnvType) -> BaseEnv: env = _VectorizedModelGymEnv( make_env=worker.make_sub_env_fn, existing_envs=[env], - num_envs=worker.num_envs, + num_envs=worker.config.num_envs_per_worker, observation_space=env.observation_space, action_space=env.action_space, ) return convert_to_base_env( env, make_env=worker.make_sub_env_fn, - num_envs=worker.num_envs, + num_envs=worker.config.num_envs_per_worker, remote_envs=False, remote_env_batch_wait_ms=0, ) diff --git a/rllib/evaluation/collectors/agent_collector.py b/rllib/evaluation/collectors/agent_collector.py index 35355ee46dc0b..61c97ac661af7 100644 --- a/rllib/evaluation/collectors/agent_collector.py +++ b/rllib/evaluation/collectors/agent_collector.py @@ -325,9 +325,9 @@ def build_for_training( data_col, view_req, build_for_inference=False ) - # we need to skip this view_col if it does not exist in the buffers and + # We need to skip this view_col if it does not exist in the buffers and # is not an RNN state because it could be the special keys that gets - # added by policy's postprocessing function for trianing. + # added by policy's postprocessing function for training. if not is_state: continue diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index a3854e721af61..06e7efd9b48b9 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -1,8 +1,10 @@ from collections import defaultdict import copy +import importlib.util import logging import os import platform +from types import FunctionType from typing import ( TYPE_CHECKING, Any, @@ -17,10 +19,9 @@ Union, ) -import gym +from gym.spaces import Discrete, MultiDiscrete, Space import numpy as np import tree # pip install dm_tree -from gym.spaces import Discrete, MultiDiscrete, Space import ray from ray import ObjectRef @@ -38,10 +39,23 @@ from ray.rllib.evaluation.sampler import AsyncSampler, SyncSampler from ray.rllib.models import ModelCatalog from ray.rllib.models.preprocessors import Preprocessor -from ray.rllib.offline import NoopOutput, IOContext, OutputWriter, InputReader +from ray.rllib.offline import ( + D4RLReader, + DatasetReader, + DatasetWriter, + IOContext, + InputReader, + JsonReader, + JsonWriter, + MixedInput, + NoopOutput, + OutputWriter, + ShuffledInput, +) from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.policy.policy_map import PolicyMap from ray.rllib.utils.filter import NoFilter +from ray.rllib.utils.from_config import from_config from ray.rllib.policy.sample_batch import ( DEFAULT_POLICY_ID, MultiAgentBatch, @@ -49,10 +63,14 @@ ) from ray.rllib.policy.torch_policy import TorchPolicy from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 -from ray.rllib.utils import check_env, force_list, merge_dicts +from ray.rllib.utils import check_env, force_list from ray.rllib.utils.annotations import DeveloperAPI from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary -from ray.rllib.utils.deprecation import Deprecated, deprecation_warning +from ray.rllib.utils.deprecation import ( + Deprecated, + DEPRECATED_VALUE, + deprecation_warning, +) from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG from ray.rllib.utils.filter import Filter, get_filter from ray.rllib.utils.framework import 
try_import_tf, try_import_torch @@ -62,10 +80,8 @@ from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices from ray.rllib.utils.typing import ( AgentID, - EnvConfigDict, EnvCreator, EnvType, - ModelConfigDict, ModelGradients, ModelWeights, MultiAgentPolicyConfigDict, @@ -78,11 +94,12 @@ from ray.util.annotations import PublicAPI from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once from ray.util.iter import ParallelIteratorWorker +from ray.tune.registry import registry_contains_input, registry_get_input if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.callbacks import DefaultCallbacks # noqa from ray.rllib.evaluation.episode import Episode - from ray.rllib.evaluation.observation_function import ObservationFunction tf1, tf, tfv = try_import_tf() torch, _ = try_import_torch() @@ -92,7 +109,7 @@ # Handle to the current rollout worker, which will be set to the most recently # created RolloutWorker in this process. This can be helpful to access in # custom env or policy classes for debugging or advanced use cases. -_global_worker: "RolloutWorker" = None +_global_worker: Optional["RolloutWorker"] = None @DeveloperAPI @@ -148,7 +165,7 @@ class RolloutWorker(ParallelIteratorWorker): >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: gym.make("CartPole-v0"), # doctest: +SKIP - ... policy_spec=PGTF1Policy) # doctest: +SKIP + ... default_policy_class=PGTF1Policy) # doctest: +SKIP >>> print(worker.sample()) # doctest: +SKIP SampleBatch({ "obs": [[...]], "actions": [[...]], "rewards": [[...]], @@ -159,7 +176,8 @@ class RolloutWorker(ParallelIteratorWorker): >>> MultiAgentTrafficGrid = ... # doctest: +SKIP >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), - ... policy_spec={ # doctest: +SKIP + ... config=AlgorithmConfig().multi_agent( + ... policies={ # doctest: +SKIP ... # Use an ensemble of two policies for car agents ... "car_policy1": # doctest: +SKIP ... (PGTFPolicy, Box(...), Discrete(...), {"gamma": 0.99}), @@ -168,10 +186,13 @@ class RolloutWorker(ParallelIteratorWorker): ... # Use a single shared policy for all traffic lights ... "traffic_light_policy": ... (PGTFPolicy, Box(...), Discrete(...), {}), - ... }, - ... policy_mapping_fn=lambda agent_id, episode, **kwargs: - ... random.choice(["car_policy1", "car_policy2"]) - ... if agent_id.startswith("car_") else "traffic_light_policy") + ... }, + ... policy_mapping_fn=( + ... lambda agent_id, episode, **kwargs: + ... random.choice(["car_policy1", "car_policy2"]) + ... if agent_id.startswith("car_") else "traffic_light_policy"), + ... ), + .. 
) >>> print(worker.sample()) # doctest: +SKIP MultiAgentBatch({ "car_policy1": SampleBatch(...), @@ -216,49 +237,47 @@ def __init__( *, env_creator: EnvCreator, validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None, - policy_spec: Optional[Union[type, Dict[PolicyID, PolicySpec]]] = None, - policy_mapping_fn: Optional[Callable[[AgentID, "Episode"], PolicyID]] = None, - policies_to_train: Union[ - Container[PolicyID], Callable[[PolicyID, SampleBatchType], bool] - ] = None, tf_session_creator: Optional[Callable[[], "tf1.Session"]] = None, - rollout_fragment_length: int = 100, - count_steps_by: str = "env_steps", - batch_mode: str = "truncate_episodes", - episode_horizon: Optional[int] = None, - preprocessor_pref: str = "deepmind", - sample_async: bool = False, - compress_observations: bool = False, - num_envs: int = 1, - observation_fn: Optional["ObservationFunction"] = None, - clip_rewards: Optional[Union[bool, float]] = None, - normalize_actions: bool = True, - clip_actions: bool = False, - env_config: Optional[EnvConfigDict] = None, - model_config: Optional[ModelConfigDict] = None, - policy_config: Optional[PartialAlgorithmConfigDict] = None, + config: Optional["AlgorithmConfig"] = None, worker_index: int = 0, - num_workers: int = 0, + num_workers: Optional[int] = None, recreated_worker: bool = False, log_dir: Optional[str] = None, - log_level: Optional[str] = None, - callbacks: Type["DefaultCallbacks"] = None, - input_creator: Callable[ - [IOContext], InputReader - ] = lambda ioctx: ioctx.default_sampler_input(), - output_creator: Callable[ - [IOContext], OutputWriter - ] = lambda ioctx: NoopOutput(), - remote_worker_envs: bool = False, - remote_env_batch_wait_ms: int = 0, - soft_horizon: bool = False, - no_done_at_end: bool = False, - seed: int = None, - extra_python_environs: Optional[dict] = None, - fake_sampler: bool = False, spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, - policy=None, - disable_env_checking=False, + default_policy_class: Optional[Type[Policy]] = None, + dataset_shards: Optional[List[ray.data.dataset.Dataset]] = None, + # Deprecated: This is all specified in `config` anyways. + policy_config=DEPRECATED_VALUE, + input_creator=DEPRECATED_VALUE, + output_creator=DEPRECATED_VALUE, + rollout_fragment_length=DEPRECATED_VALUE, + count_steps_by=DEPRECATED_VALUE, + batch_mode=DEPRECATED_VALUE, + episode_horizon=DEPRECATED_VALUE, + preprocessor_pref=DEPRECATED_VALUE, + sample_async=DEPRECATED_VALUE, + compress_observations=DEPRECATED_VALUE, + num_envs=DEPRECATED_VALUE, + observation_fn=DEPRECATED_VALUE, + clip_rewards=DEPRECATED_VALUE, + normalize_actions=DEPRECATED_VALUE, + clip_actions=DEPRECATED_VALUE, + env_config=DEPRECATED_VALUE, + model_config=DEPRECATED_VALUE, + remote_worker_envs=DEPRECATED_VALUE, + remote_env_batch_wait_ms=DEPRECATED_VALUE, + soft_horizon=DEPRECATED_VALUE, + no_done_at_end=DEPRECATED_VALUE, + fake_sampler=DEPRECATED_VALUE, + seed=DEPRECATED_VALUE, + log_level=DEPRECATED_VALUE, + callbacks=DEPRECATED_VALUE, + disable_env_checking=DEPRECATED_VALUE, + policy_spec=DEPRECATED_VALUE, + policy_mapping_fn=DEPRECATED_VALUE, + policies_to_train=DEPRECATED_VALUE, + extra_python_environs=DEPRECATED_VALUE, + policy=DEPRECATED_VALUE, ): """Initializes a RolloutWorker instance. @@ -267,123 +286,165 @@ def __init__( wrapped configuration. validate_env: Optional callable to validate the generated environment (only on worker=0). 
- policy_spec: The MultiAgentPolicyConfigDict mapping policy IDs - (str) to PolicySpec's or a single policy class to use. - If a dict is specified, then we are in multi-agent mode and a - policy_mapping_fn can also be set (if not, will map all agents - to DEFAULT_POLICY_ID). - policy_mapping_fn: A callable that maps agent ids to policy ids in - multi-agent mode. This function will be called each time a new - agent appears in an episode, to bind that agent to a policy - for the duration of the episode. If not provided, will map all - agents to DEFAULT_POLICY_ID. - policies_to_train: Optional container of policies to train (None - for all policies), or a callable taking PolicyID and - SampleBatchType and returning a bool (trainable or not?). tf_session_creator: A function that returns a TF session. This is optional and only useful with TFPolicy. - rollout_fragment_length: The target number of steps - (measured in `count_steps_by`) to include in each sample - batch returned from this worker. - count_steps_by: The unit in which to count fragment - lengths. One of env_steps or agent_steps. - batch_mode: One of the following batch modes: - - "truncate_episodes": Each call to sample() will return a - batch of at most `rollout_fragment_length * num_envs` in size. - The batch will be exactly `rollout_fragment_length * num_envs` - in size if postprocessing does not change batch sizes. Episodes - may be truncated in order to meet this size requirement. - - "complete_episodes": Each call to sample() will return a - batch of at least `rollout_fragment_length * num_envs` in - size. Episodes will not be truncated, but multiple episodes - may be packed within one batch to meet the batch size. Note - that when `num_envs > 1`, episode steps will be buffered - until the episode completes, and hence batches may contain - significant amounts of off-policy data. - episode_horizon: Horizon at which to stop episodes (even if the - environment itself has not returned a "done" signal). - preprocessor_pref: Whether to use RLlib preprocessors - ("rllib") or deepmind ("deepmind"), when applicable. - sample_async: Whether to compute samples asynchronously in - the background, which improves throughput but can cause samples - to be slightly off-policy. - compress_observations: If true, compress the observations. - They can be decompressed with rllib/utils/compression. - num_envs: If more than one, will create multiple envs - and vectorize the computation of actions. This has no effect if - if the env already implements VectorEnv. - observation_fn: Optional multi-agent observation function. - clip_rewards: True for clipping rewards to [-1.0, 1.0] prior - to experience postprocessing. None: Clip for Atari only. - float: Clip to [-clip_rewards; +clip_rewards]. - normalize_actions: Whether to normalize actions to the - action space's bounds. - clip_actions: Whether to clip action values to the range - specified by the policy action space. - env_config: Config to pass to the env creator. - model_config: Config to use when creating the policy model. - policy_config: Config to pass to the - policy. In the multi-agent case, this config will be merged - with the per-policy configs specified by `policy_spec`. worker_index: For remote workers, this should be set to a non-zero and unique value. This index is passed to created envs through EnvContext so that envs can be configured per worker. - num_workers: For remote workers, how many workers altogether - have been created? recreated_worker: Whether this worker is a recreated one. 
Workers are recreated by an Algorithm (via WorkerSet) in case `recreate_failed_workers=True` and one of the original workers (or an already recreated one) has failed. They don't differ from original workers other than the value of this flag (`self.recreated_worker`). log_dir: Directory where logs can be placed. - log_level: Set the root log level on creation. - callbacks: Custom sub-class of - DefaultCallbacks for training/policy/rollout-worker callbacks. - input_creator: Function that returns an InputReader object for - loading previous generated experiences. - output_creator: Function that returns an OutputWriter object for - saving generated experiences. - remote_worker_envs: If using num_envs_per_worker > 1, - whether to create those new envs in remote processes instead of - in the current process. This adds overheads, but can make sense - if your envs are expensive to step/reset (e.g., for StarCraft). - Use this cautiously, overheads are significant! - remote_env_batch_wait_ms: Timeout that remote workers - are waiting when polling environments. 0 (continue when at - least one env is ready) is a reasonable default, but optimal - value could be obtained by measuring your environment - step / reset and model inference perf. - soft_horizon: Calculate rewards but don't reset the - environment when the horizon is hit. - no_done_at_end: Ignore the done=True at the end of the - episode and instead record done=False. - seed: Set the seed of both np and tf to this value to - to ensure each remote worker has unique exploration behavior. - extra_python_environs: Extra python environments need to be set. - fake_sampler: Use a fake (inf speed) sampler for testing. spaces: An optional space dict mapping policy IDs to (obs_space, action_space)-tuples. This is used in case no Env is created on this RolloutWorker. - policy: Obsoleted arg. Use `policy_spec` instead. - disable_env_checking: If True, disables the env checking module that - validates the properties of the passed environment. """ + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig # Deprecated args. - if policy is not None: - deprecation_warning("policy", "policy_spec", error=False) - policy_spec = policy - assert ( - policy_spec is not None - ), "Must provide `policy_spec` when creating RolloutWorker!" - - # Do quick translation into MultiAgentPolicyConfigDict. 
- if not isinstance(policy_spec, dict): - policy_spec = {DEFAULT_POLICY_ID: PolicySpec(policy_class=policy_spec)} - policy_spec = { - pid: spec if isinstance(spec, PolicySpec) else PolicySpec(*spec) - for pid, spec in policy_spec.copy().items() - } + if policy != DEPRECATED_VALUE: + deprecation_warning("policy", "policy_spec", error=True) + if policy_spec != DEPRECATED_VALUE: + deprecation_warning( + "policy_spec", + "RolloutWorker(default_policy_class=...)", + error=True, + ) + if policy_config != DEPRECATED_VALUE: + deprecation_warning("policy_config", "config", error=True) + if input_creator != DEPRECATED_VALUE: + deprecation_warning( + "input_creator", + "config.offline_data(input_=..)", + error=True, + ) + if output_creator != DEPRECATED_VALUE: + deprecation_warning( + "output_creator", + "config.offline_data(output=..)", + error=True, + ) + if rollout_fragment_length != DEPRECATED_VALUE: + deprecation_warning( + "rollout_fragment_length", + "config.rollouts(rollout_fragment_length=..)", + error=True, + ) + if count_steps_by != DEPRECATED_VALUE: + deprecation_warning( + "count_steps_by", "config.multi_agent(count_steps_by=..)", error=True + ) + if batch_mode != DEPRECATED_VALUE: + deprecation_warning( + "batch_mode", "config.rollouts(batch_mode=..)", error=True + ) + if episode_horizon != DEPRECATED_VALUE: + deprecation_warning( + "episode_horizon", "config.rollouts(horizon=..)", error=True + ) + if preprocessor_pref != DEPRECATED_VALUE: + deprecation_warning( + "preprocessor_pref", "config.rollouts(preprocessor_pref=..)", error=True + ) + if sample_async != DEPRECATED_VALUE: + deprecation_warning( + "sample_async", "config.rollouts(sample_async=..)", error=True + ) + if compress_observations != DEPRECATED_VALUE: + deprecation_warning( + "compress_observations", + "config.rollouts(compress_observations=..)", + error=True, + ) + if num_envs != DEPRECATED_VALUE: + deprecation_warning( + "num_envs", "config.rollouts(num_envs_per_worker=..)", error=True + ) + if observation_fn != DEPRECATED_VALUE: + deprecation_warning( + "observation_fn", "config.multi_agent(observation_fn=..)", error=True + ) + if clip_rewards != DEPRECATED_VALUE: + deprecation_warning( + "clip_rewards", "config.environment(clip_rewards=..)", error=True + ) + if normalize_actions != DEPRECATED_VALUE: + deprecation_warning( + "normalize_actions", + "config.environment(normalize_actions=..)", + error=True, + ) + if clip_actions != DEPRECATED_VALUE: + deprecation_warning( + "clip_actions", "config.environment(clip_actions=..)", error=True + ) + if env_config != DEPRECATED_VALUE: + deprecation_warning( + "env_config", "config.environment(env_config=..)", error=True + ) + if model_config != DEPRECATED_VALUE: + deprecation_warning("model_config", "config.training(model=..)", error=True) + if remote_worker_envs != DEPRECATED_VALUE: + deprecation_warning( + "remote_worker_envs", + "config.rollouts(remote_worker_envs=..)", + error=True, + ) + if remote_env_batch_wait_ms != DEPRECATED_VALUE: + deprecation_warning( + "remote_env_batch_wait_ms", + "config.rollouts(remote_env_batch_wait_ms=..)", + error=True, + ) + if soft_horizon != DEPRECATED_VALUE: + deprecation_warning( + "soft_horizon", "config.rollouts(soft_horizon=..)", error=True + ) + if no_done_at_end != DEPRECATED_VALUE: + deprecation_warning( + "no_done_at_end", "config.rollouts(no_done_at_end=..)", error=True + ) + if fake_sampler != DEPRECATED_VALUE: + deprecation_warning( + "fake_sampler", "config.rollouts(fake_sampler=..)", error=True + ) + if seed != 
DEPRECATED_VALUE: + deprecation_warning("seed", "config.debugging(seed=..)", error=True) + if log_level != DEPRECATED_VALUE: + deprecation_warning( + "log_level", "config.debugging(log_level=..)", error=True + ) + if callbacks != DEPRECATED_VALUE: + deprecation_warning( + "callbacks", "config.callbacks([DefaultCallbacks subclass])", error=True + ) + if disable_env_checking != DEPRECATED_VALUE: + deprecation_warning( + "disable_env_checking", + "config.environment(disable_env_checking=..)", + error=True, + ) + if policy_mapping_fn != DEPRECATED_VALUE: + deprecation_warning( + "policy_mapping_fn", + "config.multi_agent(policy_mapping_fn=..)", + error=True, + ) + if policies_to_train != DEPRECATED_VALUE: + deprecation_warning( + "policies_to_train", + "config.multi_agent(policies_to_train=..)", + error=True, + ) + if extra_python_environs != DEPRECATED_VALUE: + deprecation_warning( + "extra_python_environs", + "config.python_environment(extra_python_environs_for_driver=.., " + "extra_python_environs_for_worker=..)", + error=True, + ) self._original_kwargs: dict = locals().copy() del self._original_kwargs["self"] @@ -391,9 +452,18 @@ def __init__( global _global_worker _global_worker = self - # set extra environs first - if extra_python_environs: - for key, value in extra_python_environs.items(): + # Default config needed? + if config is None or isinstance(config, dict): + config = AlgorithmConfig().update_from_dict(config or {}) + # Freeze config, so no one else can alter it from here on. + config.freeze() + + # Set extra python env variables before calling super constructor. + if config.extra_python_environs_for_driver and worker_index == 0: + for key, value in config.extra_python_environs_for_driver.items(): + os.environ[key] = str(value) + elif config.extra_python_environs_for_worker and worker_index > 0: + for key, value in config.extra_python_environs_for_worker.items(): os.environ[key] = str(value) def gen_rollouts(): @@ -402,12 +472,23 @@ def gen_rollouts(): ParallelIteratorWorker.__init__(self, gen_rollouts, False) - policy_config = policy_config or {} + self.config = config + # TODO: Remove this backward compatibility. + # This property (old-style python config dict) should no longer be used! + self.policy_config = config.to_dict() + + self.num_workers = ( + num_workers if num_workers is not None else self.config.num_workers + ) + # In case we are reading from distributed datasets, store the shards here + # and pick our shard by our worker-index. + self._ds_shards = dataset_shards + self.worker_index: int = worker_index + if ( tf1 and ( - policy_config.get("framework") in ["tf2", "tfe"] - or policy_config.get("enable_tf1_exec_eagerly") + config.framework_str in ["tf2", "tfe"] or config.enable_tf1_exec_eagerly ) # This eager check is necessary for certain all-framework tests # that use tf's eager_mode() context generator. 
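For orientation, a minimal sketch of how settings that were previously passed to this constructor as individual keyword arguments are now expressed through AlgorithmConfig's builder methods; the concrete values below are illustrative only, not defaults taken from this patch:

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

config = (
    AlgorithmConfig()
    # Formerly: rollout_fragment_length=, batch_mode=, num_envs=, sample_async=
    .rollouts(
        rollout_fragment_length=200,
        batch_mode="truncate_episodes",
        num_envs_per_worker=2,
        sample_async=False,
    )
    # Formerly: env_config=, clip_rewards=, normalize_actions=, clip_actions=
    .environment(env_config={"my_key": 1}, clip_rewards=True, normalize_actions=True)
    # Formerly: seed=, log_level=
    .debugging(seed=42, log_level="INFO")
    # Formerly: input_creator= (the "sampler" setting keeps on-policy sampling)
    .offline_data(input_="sampler")
)
# RolloutWorker(config=config, ...) freezes this config internally.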
@@ -415,63 +496,57 @@ def gen_rollouts(): ): tf1.enable_eager_execution() - if log_level: - logging.getLogger("ray.rllib").setLevel(log_level) + if self.config.log_level: + logging.getLogger("ray.rllib").setLevel(self.config.log_level) - if worker_index > 1: + if self.worker_index > 1: disable_log_once_globally() # only need 1 worker to log - elif log_level == "DEBUG": + elif self.config.log_level == "DEBUG": enable_periodic_logging() env_context = EnvContext( - env_config or {}, - worker_index=worker_index, + self.config.env_config, + worker_index=self.worker_index, vector_index=0, - num_workers=num_workers, - remote=remote_worker_envs, + num_workers=self.num_workers, + remote=self.config.remote_worker_envs, recreated_worker=recreated_worker, ) self.env_context = env_context - self.policy_config: PartialAlgorithmConfigDict = policy_config - if callbacks: - self.callbacks: "DefaultCallbacks" = callbacks() - else: - from ray.rllib.algorithms.callbacks import DefaultCallbacks # noqa - - self.callbacks: DefaultCallbacks = DefaultCallbacks() - self.worker_index: int = worker_index - self.num_workers: int = num_workers + self.config: AlgorithmConfig = config + self.callbacks: DefaultCallbacks = self.config.callbacks_class() self.recreated_worker: bool = recreated_worker - model_config: ModelConfigDict = ( - model_config or self.policy_config.get("model") or {} - ) - # Default policy mapping fn is to always return DEFAULT_POLICY_ID, - # independent on the agent ID and the episode passed in. + # Setup current policy_mapping_fn. Start with the one from the config, which + # might be None in older checkpoints (nowadays AlgorithmConfig has a proper + # default for this); Need to cover this situation via the backup lambda here. self.policy_mapping_fn = ( - lambda agent_id, episode, worker, **kwargs: DEFAULT_POLICY_ID + lambda agent_id, episode, worker, **kw: DEFAULT_POLICY_ID ) - # If provided, set it here. - self.set_policy_mapping_fn(policy_mapping_fn) + self.set_policy_mapping_fn(self.config.policy_mapping_fn) self.env_creator: EnvCreator = env_creator - self.rollout_fragment_length: int = rollout_fragment_length * num_envs - self.count_steps_by: str = count_steps_by - self.batch_mode: str = batch_mode - self.compress_observations: bool = compress_observations - self.preprocessing_enabled: bool = not policy_config.get( - "_disable_preprocessor_api" + self.total_rollout_fragment_length: int = ( + self.config.rollout_fragment_length * self.config.num_envs_per_worker ) + self.preprocessing_enabled: bool = not config._disable_preprocessor_api self.last_batch: Optional[SampleBatchType] = None self.global_vars: Optional[dict] = None - self.fake_sampler: bool = fake_sampler - self._disable_env_checking: bool = disable_env_checking + + # If seed is provided, add worker index to it and 10k iff evaluation worker. + self.seed = ( + None + if self.config.seed is None + else self.config.seed + + self.worker_index + + self.config.in_evaluation * 10000 + ) # Update the global seed for numpy/random/tf-eager/torch if we are not # the local worker, otherwise, this was already done in the Trainer # object itself. if self.worker_index > 0: - update_global_seed_if_necessary(policy_config.get("framework"), seed) + update_global_seed_if_necessary(self.config.framework_str, self.seed) # A single environment provided by the user (via config.env). This may # also remain None. @@ -486,16 +561,18 @@ def gen_rollouts(): # Create a (single) env for this worker. 
if not ( - worker_index == 0 - and num_workers > 0 - and not policy_config.get("create_env_on_driver") + self.worker_index == 0 + and self.num_workers > 0 + and not self.config.create_env_on_local_worker ): # Run the `env_creator` function passing the EnvContext. self.env = env_creator(copy.deepcopy(self.env_context)) + clip_rewards = self.config.clip_rewards + if self.env is not None: # Validate environment (general validation function). - if not self._disable_env_checking: + if not self.config.disable_env_checking: check_env(self.env) # Custom validation function given, typically a function attribute of the # algorithm trainer. @@ -511,29 +588,29 @@ def wrap(env): # Atari type env and "deepmind" preprocessor pref. elif ( is_atari(self.env) - and not model_config.get("custom_preprocessor") - and preprocessor_pref == "deepmind" + and not self.config.model.get("custom_preprocessor") + and self.config.preprocessor_pref == "deepmind" ): # Deepmind wrappers already handle all preprocessing. self.preprocessing_enabled = False # If clip_rewards not explicitly set to False, switch it # on here (clip between -1.0 and 1.0). - if clip_rewards is None: + if self.config.clip_rewards is None: clip_rewards = True # Framestacking is used. - use_framestack = model_config.get("framestack") is True + use_framestack = self.config.model.get("framestack") is True def wrap(env): env = wrap_deepmind( - env, dim=model_config.get("dim"), framestack=use_framestack + env, dim=self.config.model.get("dim"), framestack=use_framestack ) return env elif ( - not model_config.get("custom_preprocessor") - and preprocessor_pref is None + not self.config.model.get("custom_preprocessor") + and self.config.preprocessor_pref is None ): # Only turn off preprocessing self.preprocessing_enabled = False @@ -552,7 +629,7 @@ def wrap(env): # to create self.env, but wrap(env) and self.env has a cyclic # dependency on each other right now, so we would settle on # duplicating the random seed setting logic for now. - _update_env_seed_if_necessary(self.env, seed, worker_index, 0) + _update_env_seed_if_necessary(self.env, self.seed, self.worker_index, 0) # Call custom callback function `on_sub_environment_created`. self.callbacks.on_sub_environment_created( worker=self, @@ -561,48 +638,38 @@ def wrap(env): ) self.make_sub_env_fn = self._get_make_sub_env_fn( - env_creator, env_context, validate_env, wrap, seed + env_creator, env_context, validate_env, wrap, self.seed ) self.spaces = spaces - - self.policy_dict = _determine_spaces_for_multi_agent_dict( - policy_spec, self.env, spaces=self.spaces, policy_config=policy_config + self.default_policy_class = default_policy_class + self.policy_dict, self.is_policy_to_train = self.config.get_multi_agent_setup( + env=self.env, + spaces=self.spaces, + default_policy_class=self.default_policy_class, ) - # Set of IDs of those policies, which should be trained. This property - # is optional and mainly used for backward compatibility. - self.policies_to_train = policies_to_train - self.is_policy_to_train: Callable[[PolicyID, SampleBatchType], bool] - - # By default (None), use the set of all policies found in the - # policy_dict. - if self.policies_to_train is None: - self.policies_to_train = set(self.policy_dict.keys()) - - self.set_is_policy_to_train(self.policies_to_train) - self.policy_map: PolicyMap = None # TODO(jungong) : clean up after non-connector env_runner is fully deprecated. self.preprocessors: Dict[PolicyID, Preprocessor] = None # Check available number of GPUs. 
num_gpus = ( - policy_config.get("num_gpus", 0) + self.config.num_gpus if self.worker_index == 0 - else policy_config.get("num_gpus_per_worker", 0) + else self.config.num_gpus_per_worker ) # Error if we don't find enough GPUs. if ( ray.is_initialized() and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE - and not policy_config.get("_fake_gpus") + and not config._fake_gpus ): devices = [] - if policy_config.get("framework") in ["tf2", "tf", "tfe"]: + if self.config.framework_str in ["tf2", "tf", "tfe"]: devices = get_tf_gpu_devices() - elif policy_config.get("framework") == "torch": + elif self.config.framework_str == "torch": devices = list(range(torch.cuda.device_count())) if len(devices) < num_gpus: @@ -615,7 +682,7 @@ def wrap(env): ray.is_initialized() and ray._private.worker._mode() == ray._private.worker.LOCAL_MODE and num_gpus > 0 - and not policy_config.get("_fake_gpus") + and not self.config._fake_gpus ): logger.warning( "You are running ray with `local_mode=True`, but have " @@ -628,9 +695,9 @@ def wrap(env): self._build_policy_map( self.policy_dict, - policy_config, + config=self.config, session_creator=tf_session_creator, - seed=seed, + seed=self.seed, ) # Update Policy's view requirements from Model, only if Policy directly @@ -656,73 +723,71 @@ def wrap(env): if self.worker_index == 0: logger.info("Built filter map: {}".format(self.filters)) - # Vectorize environment, if any. - self.num_envs: int = num_envs # This RolloutWorker has no env. if self.env is None: self.async_env = None # Use a custom env-vectorizer and call it providing self.env. - elif "custom_vector_env" in policy_config: - self.async_env = policy_config["custom_vector_env"](self.env) + elif "custom_vector_env" in self.config: + self.async_env = self.config.custom_vector_env(self.env) # Default: Vectorize self.env via the make_sub_env function. This adds # further clones of self.env and creates a RLlib BaseEnv (which is # vectorized under the hood). else: - # Always use vector env for consistency even if num_envs = 1. + # Always use vector env for consistency even if num_envs_per_worker=1. self.async_env: BaseEnv = convert_to_base_env( self.env, make_env=self.make_sub_env_fn, - num_envs=num_envs, - remote_envs=remote_worker_envs, - remote_env_batch_wait_ms=remote_env_batch_wait_ms, + num_envs=self.config.num_envs_per_worker, + remote_envs=self.config.remote_worker_envs, + remote_env_batch_wait_ms=self.config.remote_env_batch_wait_ms, worker=self, - restart_failed_sub_environments=self.policy_config.get( - "restart_failed_sub_environments", False + restart_failed_sub_environments=( + self.config.restart_failed_sub_environments ), ) # `truncate_episodes`: Allow a batch to contain more than one episode # (fragments) and always make the batch `rollout_fragment_length` # long. - if self.batch_mode == "truncate_episodes": + rollout_fragment_length_for_sampler = self.config.rollout_fragment_length + if self.config.batch_mode == "truncate_episodes": pack = True # `complete_episodes`: Never cut episodes and sampler will return # exactly one (complete) episode per poll. - elif self.batch_mode == "complete_episodes": - rollout_fragment_length = float("inf") - pack = False else: - raise ValueError("Unsupported batch mode: {}".format(self.batch_mode)) + assert self.config.batch_mode == "complete_episodes" + rollout_fragment_length_for_sampler = float("inf") + pack = False # Create the IOContext for this worker. 
self.io_context: IOContext = IOContext( - log_dir, policy_config, worker_index, self + log_dir, self.config, self.worker_index, self ) render = False - if policy_config.get("render_env") is True and ( - num_workers == 0 or worker_index == 1 + if self.config.render_env is True and ( + self.num_workers == 0 or self.worker_index == 1 ): render = True if self.env is None: self.sampler = None - elif sample_async: + elif self.config.sample_async: self.sampler = AsyncSampler( worker=self, env=self.async_env, clip_rewards=clip_rewards, - rollout_fragment_length=rollout_fragment_length, - count_steps_by=count_steps_by, + rollout_fragment_length=rollout_fragment_length_for_sampler, + count_steps_by=self.config.count_steps_by, callbacks=self.callbacks, - horizon=episode_horizon, + horizon=self.config.horizon, multiple_episodes_in_batch=pack, - normalize_actions=normalize_actions, - clip_actions=clip_actions, - soft_horizon=soft_horizon, - no_done_at_end=no_done_at_end, - observation_fn=observation_fn, - sample_collector_class=policy_config.get("sample_collector"), + normalize_actions=self.config.normalize_actions, + clip_actions=self.config.clip_actions, + soft_horizon=self.config.soft_horizon, + no_done_at_end=self.config.no_done_at_end, + observation_fn=self.config.observation_fn, + sample_collector_class=self.config.sample_collector, render=render, ) # Start the Sampler thread. @@ -732,22 +797,26 @@ def wrap(env): worker=self, env=self.async_env, clip_rewards=clip_rewards, - rollout_fragment_length=rollout_fragment_length, - count_steps_by=count_steps_by, + rollout_fragment_length=rollout_fragment_length_for_sampler, + count_steps_by=self.config.count_steps_by, callbacks=self.callbacks, - horizon=episode_horizon, + horizon=self.config.horizon, multiple_episodes_in_batch=pack, - normalize_actions=normalize_actions, - clip_actions=clip_actions, - soft_horizon=soft_horizon, - no_done_at_end=no_done_at_end, - observation_fn=observation_fn, - sample_collector_class=policy_config.get("sample_collector"), + normalize_actions=self.config.normalize_actions, + clip_actions=self.config.clip_actions, + soft_horizon=self.config.soft_horizon, + no_done_at_end=self.config.no_done_at_end, + observation_fn=self.config.observation_fn, + sample_collector_class=self.config.sample_collector, render=render, ) - self.input_reader: InputReader = input_creator(self.io_context) - self.output_writer: OutputWriter = output_creator(self.io_context) + self.input_reader: InputReader = self._get_input_creator_from_config()( + self.io_context + ) + self.output_writer: OutputWriter = self._get_output_creator_from_config()( + self.io_context + ) # The current weights sequence number (version). May remain None for when # not tracking weights versions. @@ -790,11 +859,13 @@ def sample(self) -> SampleBatchType: >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: gym.make("CartPole-v0"), # doctest: +SKIP - ... policy_spec=PGTF1Policy) # doctest: +SKIP + ... default_policy_class=PGTF1Policy, # doctest: +SKIP + ... config=AlgorithmConfig(), # doctest: +SKIP + ... 
) >>> print(worker.sample()) # doctest: +SKIP SampleBatch({"obs": [...], "action": [...], ...}) """ - if self.fake_sampler and self.last_batch is not None: + if self.config.fake_sampler and self.last_batch is not None: return self.last_batch elif self.input_reader is None: raise ValueError( @@ -806,30 +877,33 @@ def sample(self) -> SampleBatchType: if log_once("sample_start"): logger.info( "Generating sample batch of size {}".format( - self.rollout_fragment_length + self.total_rollout_fragment_length ) ) batches = [self.input_reader.next()] steps_so_far = ( batches[0].count - if self.count_steps_by == "env_steps" + if self.config.count_steps_by == "env_steps" else batches[0].agent_steps() ) # In truncate_episodes mode, never pull more than 1 batch per env. # This avoids over-running the target batch size. - if self.batch_mode == "truncate_episodes": - max_batches = self.num_envs + if ( + self.config.batch_mode == "truncate_episodes" + and not self.config.offline_sampling + ): + max_batches = self.config.num_envs_per_worker else: max_batches = float("inf") - while steps_so_far < self.rollout_fragment_length and ( - len(batches) < max_batches or self.policy_config.get("offline_sampling") + while steps_so_far < self.total_rollout_fragment_length and ( + len(batches) < max_batches ): batch = self.input_reader.next() steps_so_far += ( batch.count - if self.count_steps_by == "env_steps" + if self.config.count_steps_by == "env_steps" else batch.agent_steps() ) batches.append(batch) @@ -844,10 +918,10 @@ def sample(self) -> SampleBatchType: if log_once("sample_end"): logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch))) - if self.compress_observations: - batch.compress(bulk=self.compress_observations == "bulk") + if self.config.compress_observations: + batch.compress(bulk=self.config.compress_observations == "bulk") - if self.fake_sampler: + if self.config.fake_sampler: self.last_batch = batch return batch @@ -866,7 +940,7 @@ def sample_with_count(self) -> Tuple[SampleBatchType, int]: >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: gym.make("CartPole-v0"), # doctest: +SKIP - ... policy_spec=PGTFPolicy) # doctest: +SKIP + ... default_policy_class=PGTFPolicy) # doctest: +SKIP >>> print(worker.sample_with_count()) # doctest: +SKIP (SampleBatch({"obs": [...], "action": [...], ...}), 3) """ @@ -892,7 +966,7 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict: >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: gym.make("CartPole-v0"), # doctest: +SKIP - ... policy_spec=PGTF1Policy) # doctest: +SKIP + ... default_policy_class=PGTF1Policy) # doctest: +SKIP >>> batch = worker.sample() # doctest: +SKIP >>> info = worker.learn_on_batch(samples) # doctest: +SKIP """ @@ -908,7 +982,9 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict: builders = {} to_fetch = {} for pid, batch in samples.policy_batches.items(): - if not self.is_policy_to_train(pid, samples): + if self.is_policy_to_train is not None and not self.is_policy_to_train( + pid, samples + ): continue # Decompress SampleBatch, in case some columns are compressed. 
batch.decompress_if_needed() @@ -921,7 +997,9 @@ def learn_on_batch(self, samples: SampleBatchType) -> Dict: info_out[pid] = policy.learn_on_batch(batch) info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()}) else: - if self.is_policy_to_train(DEFAULT_POLICY_ID, samples): + if self.is_policy_to_train is None or self.is_policy_to_train( + DEFAULT_POLICY_ID, samples + ): info_out.update( { DEFAULT_POLICY_ID: self.policy_map[ @@ -1009,7 +1087,7 @@ def compute_gradients( >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: gym.make("CartPole-v0"), # doctest: +SKIP - ... policy_spec=PGTF1Policy) # doctest: +SKIP + ... default_policy_class=PGTF1Policy) # doctest: +SKIP >>> batch = worker.sample() # doctest: +SKIP >>> grads, info = worker.compute_gradients(samples) # doctest: +SKIP """ @@ -1032,9 +1110,11 @@ def compute_gradients( # Calculate gradients for all policies. grad_out, info_out = {}, {} - if self.policy_config.get("framework") == "tf": + if self.config.framework_str == "tf": for pid, batch in samples.policy_batches.items(): - if not self.is_policy_to_train(pid, samples): + if self.is_policy_to_train is not None and not self.is_policy_to_train( + pid, samples + ): continue policy = self.policy_map[pid] builder = _TFRunBuilder(policy.get_session(), "compute_gradients") @@ -1045,7 +1125,9 @@ def compute_gradients( info_out = {k: builder.get(v) for k, v in info_out.items()} else: for pid, batch in samples.policy_batches.items(): - if not self.is_policy_to_train(pid, samples): + if self.is_policy_to_train is not None and not self.is_policy_to_train( + pid, samples + ): continue grad_out[pid], info_out[pid] = self.policy_map[pid].compute_gradients( batch @@ -1078,7 +1160,7 @@ def apply_gradients( >>> from ray.rllib.algorithms.pg.pg_tf_policy import PGTF1Policy >>> worker = RolloutWorker( # doctest: +SKIP ... env_creator=lambda _: gym.make("CartPole-v0"), # doctest: +SKIP - ... policy_spec=PGTF1Policy) # doctest: +SKIP + ... default_policy_class=PGTF1Policy) # doctest: +SKIP >>> samples = worker.sample() # doctest: +SKIP >>> grads, info = worker.compute_gradients(samples) # doctest: +SKIP >>> worker.apply_gradients(grads) # doctest: +SKIP @@ -1089,10 +1171,14 @@ def apply_gradients( # Multi-agent case. if isinstance(grads, dict): for pid, g in grads.items(): - if self.is_policy_to_train(pid, None): + if self.is_policy_to_train is None or self.is_policy_to_train( + pid, None + ): self.policy_map[pid].apply_gradients(g) # Grads is a ModelGradients type. Single-agent case. 
- elif self.is_policy_to_train(DEFAULT_POLICY_ID, None): + elif self.is_policy_to_train is None or self.is_policy_to_train( + DEFAULT_POLICY_ID, None + ): self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) @DeveloperAPI @@ -1243,15 +1329,15 @@ def add_policy( ) if policy is None: - policy_dict_to_add = _determine_spaces_for_multi_agent_dict( - { + policy_dict_to_add, _ = self.config.get_multi_agent_setup( + policies={ policy_id: PolicySpec( - policy_cls, observation_space, action_space, config or {} + policy_cls, observation_space, action_space, config ) }, - self.env, + env=self.env, spaces=self.spaces, - policy_config=self.policy_config, + default_policy_class=self.default_policy_class, ) else: policy_dict_to_add = { @@ -1266,9 +1352,9 @@ def add_policy( self.policy_dict.update(policy_dict_to_add) self._build_policy_map( policy_dict=policy_dict_to_add, - policy_config=self.policy_config, + config=self.config, policy=policy, - seed=self.policy_config.get("seed"), + seed=self.seed, ) new_policy = self.policy_map[policy_id] @@ -1352,8 +1438,8 @@ def set_is_policy_to_train( # If container given, construct a simple default callable returning True # if the PolicyID is found in the list/set of IDs. if not callable(is_policy_to_train): - assert isinstance(is_policy_to_train, Container), ( - "ERROR: `is_policy_to_train`must be a container or a " + assert isinstance(is_policy_to_train, (list, set, tuple)), ( + "ERROR: `is_policy_to_train`must be a [list|set|tuple] or a " "callable taking PolicyID and SampleBatch and returning " "True|False (trainable or not?)." ) @@ -1382,7 +1468,9 @@ def get_policies_to_train( `batch`. """ return { - pid for pid in self.policy_map.keys() if self.is_policy_to_train(pid, batch) + pid + for pid in self.policy_map.keys() + if self.is_policy_to_train is None or self.is_policy_to_train(pid, batch) } @DeveloperAPI @@ -1456,7 +1544,7 @@ def foreach_policy_to_train( # unnecessary, making subsequent disk access unnecessary. func(self.policy_map[pid], pid, **kwargs) for pid in self.policy_map.keys() - if self.is_policy_to_train(pid, None) + if self.is_policy_to_train is None or self.is_policy_to_train(pid, None) ] @DeveloperAPI @@ -1540,7 +1628,7 @@ def set_state(self, state: dict) -> None: # key in `state` entirely (will be part of the policies then). self.sync_filters(state["filters"]) - connector_enabled = self.policy_config.get("enable_connectors", False) + connector_enabled = self.config.enable_connectors # Support older checkpoint versions (< 1.0), in which the policy_map # was stored under the "state" key, not "policy_states". @@ -1563,7 +1651,9 @@ def set_state(self, state: dict) -> None: ) else: policy_spec = ( - PolicySpec.deserialize(spec) if connector_enabled else spec + PolicySpec.deserialize(spec) + if connector_enabled or isinstance(spec, dict) + else spec ) self.add_policy( policy_id=pid, @@ -1578,7 +1668,7 @@ def set_state(self, state: dict) -> None: # Also restore mapping fn and which policies to train. 
if "policy_mapping_fn" in state: self.set_policy_mapping_fn(state["policy_mapping_fn"]) - if "is_policy_to_train" in state: + if state.get("is_policy_to_train") is not None: self.set_is_policy_to_train(state["is_policy_to_train"]) @DeveloperAPI @@ -1788,7 +1878,7 @@ def __del__(self): def _build_policy_map( self, policy_dict: MultiAgentPolicyConfigDict, - policy_config: PartialAlgorithmConfigDict, + config: "AlgorithmConfig", policy: Optional[Policy] = None, session_creator: Optional[Callable[[], "tf1.Session"]] = None, seed: Optional[int] = None, @@ -1798,7 +1888,7 @@ def _build_policy_map( Args: policy_dict: The MultiAgentPolicyConfigDict to be added to this worker's PolicyMap. - policy_config: The general policy config to use. May be updated + config: The general AlgorithmConfig to use. May be updated by individual policy config overrides in the given multi-agent `policy_dict`. policy: If the policy to add already exists, user can provide it here. @@ -1807,15 +1897,14 @@ def _build_policy_map( seed: An optional random seed to pass to PolicyMap's constructor. """ - ma_config = policy_config.get("multiagent", {}) + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig # If our policy_map does not exist yet, create it here. self.policy_map = self.policy_map or PolicyMap( worker_index=self.worker_index, num_workers=self.num_workers, - capacity=ma_config.get("policy_map_capacity"), - path=ma_config.get("policy_map_cache"), - policy_config=policy_config, + capacity=config.policy_map_capacity, + path=config.policy_map_cache, session_creator=session_creator, seed=seed, ) @@ -1825,15 +1914,18 @@ def _build_policy_map( # Loop through given policy-dict and add each entry to our map. for name, policy_spec in sorted(policy_dict.items()): logger.debug("Creating policy for {}".format(name)) - # Update the general policy_config with the specific config + # Update the general config with the specific config # for this particular policy. - merged_conf = merge_dicts(policy_config, policy_spec.config or {}) + merged_conf: "AlgorithmConfig" = config.copy(copy_frozen=False) + update_dict = ( + policy_spec.config.to_dict() + if isinstance(policy_spec.config, AlgorithmConfig) + else policy_spec.config + ) + merged_conf.update_from_dict(update_dict or {}) # Update num_workers and worker_index. - merged_conf["num_workers"] = self.num_workers - merged_conf["worker_index"] = self.worker_index - - connectors_enabled = policy_config.get("enable_connectors", False) + merged_conf.worker_index = self.worker_index # Preprocessors. obs_space = policy_spec.observation_space @@ -1843,14 +1935,14 @@ def _build_policy_map( # Policies should deal with preprocessed (automatically flattened) # observations if preprocessing is enabled. preprocessor = ModelCatalog.get_preprocessor_for_space( - obs_space, merged_conf.get("model") + obs_space, merged_conf.model ) # Original observation space should be accessible at # obs_space.original_space after this step. if preprocessor is not None: obs_space = preprocessor.observation_space - if not connectors_enabled: + if not merged_conf.enable_connectors: # If connectors are not enabled, rollout worker will handle # the running of these preprocessors. self.preprocessors[name] = preprocessor @@ -1864,12 +1956,12 @@ def _build_policy_map( policy_spec.policy_class, obs_space, policy_spec.action_space, - policy_spec.config, # overrides. 
- merged_conf, + config_override=None, + merged_config=merged_conf, ) new_policy = self.policy_map[name] - if connectors_enabled: + if merged_conf.enable_connectors: create_connectors_for_policy(new_policy, merged_conf) maybe_get_filters_for_syncing(self, name) else: @@ -1883,7 +1975,7 @@ def _build_policy_map( ) self.filters[name] = get_filter( - (merged_conf or {}).get("observation_filter", "NoFilter"), + merged_conf.observation_filter, filter_shape, ) @@ -1896,10 +1988,93 @@ def _build_policy_map( logger.info(f"Built policy map: {self.policy_map}") logger.info(f"Built preprocessor map: {self.preprocessors}") + def _get_input_creator_from_config(self): + def valid_module(class_path): + if ( + isinstance(class_path, str) + and not os.path.isfile(class_path) + and "." in class_path + ): + module_path, class_name = class_path.rsplit(".", 1) + try: + spec = importlib.util.find_spec(module_path) + if spec is not None: + return True + except (ModuleNotFoundError, ValueError): + print( + f"module {module_path} not found while trying to get " + f"input {class_path}" + ) + return False + + # A callable returning an InputReader object to use. + if isinstance(self.config.input_, FunctionType): + return self.config.input_ + # Use RLlib's Sampler classes (SyncSampler or AsynchSampler, depending + # on `config.sample_async` setting). + elif self.config.input_ == "sampler": + return lambda ioctx: ioctx.default_sampler_input() + # Ray Dataset input -> Use `config.input_config` to construct DatasetReader. + elif self.config.input_ == "dataset": + assert self._ds_shards is not None + # Input dataset shards should have already been prepared. + # We just need to take the proper shard here. + return lambda ioctx: DatasetReader( + self._ds_shards[self.worker_index], ioctx + ) + # Dict: Mix of different input methods with different ratios. + elif isinstance(self.config.input_, dict): + return lambda ioctx: ShuffledInput( + MixedInput(self.config.input_, ioctx), self.config.shuffle_buffer_size + ) + # A pre-registered input descriptor (str). + elif isinstance(self.config.input_, str) and registry_contains_input( + self.config.input_ + ): + return registry_get_input(self.config.input_) + # D4RL input. + elif "d4rl" in self.config.input_: + env_name = self.config.input_.split(".")[-1] + return lambda ioctx: D4RLReader(env_name, ioctx) + # Valid python module (class path) -> Create using `from_config`. + elif valid_module(self.config.input_): + return lambda ioctx: ShuffledInput( + from_config(self.config.input_, ioctx=ioctx) + ) + # JSON file or list of JSON files -> Use JsonReader (shuffled). 
+ else: + return lambda ioctx: ShuffledInput( + JsonReader(self.config.input_, ioctx), self.config.shuffle_buffer_size + ) + + def _get_output_creator_from_config(self): + if isinstance(self.config.output, FunctionType): + return self.config.output + elif self.config.output is None: + return lambda ioctx: NoopOutput() + elif self.config.output == "dataset": + return lambda ioctx: DatasetWriter( + ioctx, compress_columns=self.config.output_compress_columns + ) + elif self.config.output == "logdir": + return lambda ioctx: JsonWriter( + ioctx.log_dir, + ioctx, + max_file_size=self.config.output_max_file_size, + compress_columns=self.config.output_compress_columns, + ) + else: + return lambda ioctx: JsonWriter( + self.config.output, + ioctx, + max_file_size=self.config.output_max_file_size, + compress_columns=self.config.output_compress_columns, + ) + def _get_make_sub_env_fn( self, env_creator, env_context, validate_env, env_wrapper, seed ): - disable_env_checking = self._disable_env_checking + disable_env_checking = self.config.disable_env_checking def _make_sub_env_local(vector_index): # Used to created additional environments during environment @@ -2001,154 +2176,3 @@ def save(self): def restore(self, objs): state_dict = pickle.loads(objs) self.set_state(state_dict) - - -def _determine_spaces_for_multi_agent_dict( - multi_agent_policies_dict: MultiAgentPolicyConfigDict, - env: Optional[EnvType] = None, - spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, - policy_config: Optional[PartialAlgorithmConfigDict] = None, -) -> MultiAgentPolicyConfigDict: - """Infers the observation- and action spaces in a multi-agent policy dict. - - Args: - multi_agent_policies_dict: The multi-agent `policies` dict mapping policy IDs - to PolicySpec objects. Note that the `observation_space` and `action_space` - properties in these PolicySpecs may be None and must therefore be inferred - here. - env: An optional env instance, from which to infer the different spaces for - the different policies. - spaces: Optional dict mapping policy IDs to tuples of 1) observation space - and 2) action space that should be used for the respective policy. - These spaces were usually provided by an already instantiated remote worker. - policy_config: Optional partial config dict of the Trainer. - - Returns: - The updated MultiAgentPolicyConfigDict (changed in-place from the incoming - `multi_agent_policies_dict` arg). - """ - policy_config = policy_config or {} - - # Try extracting spaces from env or from given spaces dict. - env_obs_space = None - env_act_space = None - - # Env is a ray.remote: Get spaces via its (automatically added) - # `_get_spaces()` method. - if isinstance(env, ray.actor.ActorHandle): - env_obs_space, env_act_space = ray.get(env._get_spaces.remote()) - # Normal env (gym.Env or MultiAgentEnv): These should have the - # `observation_space` and `action_space` properties. - elif env is not None: - if hasattr(env, "observation_space") and isinstance( - env.observation_space, gym.Space - ): - env_obs_space = env.observation_space - - if hasattr(env, "action_space") and isinstance(env.action_space, gym.Space): - env_act_space = env.action_space - # Last resort: Try getting the env's spaces from the spaces - # dict's special __env__ key. 
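Since both the input and the output creator are now derived from the config, here is a hedged sketch of how common offline_data() choices map onto the reader/writer branches above (the file path is a placeholder):

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig

# Default on-policy sampling -> ioctx.default_sampler_input().
sampler_config = AlgorithmConfig().offline_data(input_="sampler")

# A JSON file or directory path -> ShuffledInput(JsonReader(...)).
json_config = AlgorithmConfig().offline_data(input_="/tmp/my-offline-data")

# A dict mapping sources to ratios -> ShuffledInput(MixedInput(...)).
mixed_config = AlgorithmConfig().offline_data(
    input_={"sampler": 0.5, "/tmp/my-offline-data": 0.5}
)

# Write sampled experiences as JSON into the worker's log dir -> JsonWriter.
logdir_config = AlgorithmConfig().offline_data(output="logdir")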
- if spaces is not None: - if env_obs_space is None: - env_obs_space = spaces.get("__env__", [None])[0] - if env_act_space is None: - env_act_space = spaces.get("__env__", [None, None])[1] - - for pid, policy_spec in multi_agent_policies_dict.copy().items(): - if policy_spec.observation_space is None: - if spaces is not None and pid in spaces: - obs_space = spaces[pid][0] - elif env_obs_space is not None: - # Multi-agent case AND different agents have different spaces: - # Need to reverse map spaces (for the different agents) to certain - # policy IDs. - if ( - isinstance(env, MultiAgentEnv) - and hasattr(env, "_spaces_in_preferred_format") - and env._spaces_in_preferred_format - ): - obs_space = None - mapping_fn = policy_config.get("multiagent", {}).get( - "policy_mapping_fn", None - ) - if mapping_fn: - for aid in env.get_agent_ids(): - # Match: Assign spaces for this agentID to the policy ID. - if mapping_fn(aid, None, None) == pid: - # Make sure, different agents that map to the same - # policy don't have different spaces. - if ( - obs_space is not None - and env_obs_space[aid] != obs_space - ): - raise ValueError( - "Two agents in your environment map to the same" - " policyID (as per your `policy_mapping_fn`), " - "however, these agents also have different " - "observation spaces!" - ) - obs_space = env_obs_space[aid] - # Otherwise, just use env's obs space as-is. - else: - obs_space = env_obs_space - # Space given directly in config. - elif policy_config.get("observation_space"): - obs_space = policy_config["observation_space"] - else: - raise ValueError( - "`observation_space` not provided in PolicySpec for " - f"{pid} and env does not have an observation space OR " - "no spaces received from other workers' env(s) OR no " - "`observation_space` specified in config!" - ) - - multi_agent_policies_dict[pid].observation_space = obs_space - - if policy_spec.action_space is None: - if spaces is not None and pid in spaces: - act_space = spaces[pid][1] - elif env_act_space is not None: - # Multi-agent case AND different agents have different spaces: - # Need to reverse map spaces (for the different agents) to certain - # policy IDs. - if ( - isinstance(env, MultiAgentEnv) - and hasattr(env, "_spaces_in_preferred_format") - and env._spaces_in_preferred_format - ): - act_space = None - mapping_fn = policy_config.get("multiagent", {}).get( - "policy_mapping_fn", None - ) - if mapping_fn: - for aid in env.get_agent_ids(): - # Match: Assign spaces for this agentID to the policy ID. - if mapping_fn(aid, None, None) == pid: - # Make sure, different agents that map to the same - # policy don't have different spaces. - if ( - act_space is not None - and env_act_space[aid] != act_space - ): - raise ValueError( - "Two agents in your environment map to the same" - " policyID (as per your `policy_mapping_fn`), " - "however, these agents also have different " - "action spaces!" - ) - act_space = env_act_space[aid] - # Otherwise, just use env's action space as-is. - else: - act_space = env_act_space - elif policy_config.get("action_space"): - act_space = policy_config["action_space"] - else: - raise ValueError( - "`action_space` not provided in PolicySpec for " - f"{pid} and env does not have an action space OR " - "no spaces received from other workers' env(s) OR no " - "`action_space` specified in config!" 
- ) - multi_agent_policies_dict[pid].action_space = act_space - return multi_agent_policies_dict diff --git a/rllib/evaluation/sampler.py b/rllib/evaluation/sampler.py index 28037a215afcb..e14fe8ccf00a1 100644 --- a/rllib/evaluation/sampler.py +++ b/rllib/evaluation/sampler.py @@ -228,7 +228,7 @@ def __init__( self.horizon = horizon self.extra_batches = queue.Queue() self.perf_stats = _PerfStats( - ema_coef=worker.policy_config.get("sampler_perf_stats_ema_coef"), + ema_coef=worker.config.sampler_perf_stats_ema_coef, ) if not sample_collector_class: sample_collector_class = SimpleListCollector @@ -242,7 +242,7 @@ def __init__( ) self.render = render - if worker.policy_config.get("enable_connectors", False): + if worker.config.enable_connectors: # Keep a reference to the underlying EnvRunnerV2 instance for # unit testing purpose. self._env_runner_obj = EnvRunnerV2( @@ -425,7 +425,7 @@ def __init__( self.soft_horizon = soft_horizon self.no_done_at_end = no_done_at_end self.perf_stats = _PerfStats( - ema_coef=worker.policy_config.get("sampler_perf_stats_ema_coef"), + ema_coef=worker.config.sampler_perf_stats_ema_coef, ) self.shutdown = False self.observation_fn = observation_fn @@ -454,7 +454,7 @@ def _run(self): # We are in a thread: Switch on eager execution mode, iff framework==tf2|tfe. if ( tf1 - and self.worker.policy_config.get("framework", "tf") in ["tf2", "tfe"] + and self.worker.config.framework_str in ["tf2", "tfe"] and not tf1.executing_eagerly() ): tf1.enable_eager_execution() @@ -465,7 +465,7 @@ def _run(self): else: queue_putter = self.queue.put extra_batches_putter = lambda x: self.extra_batches.put(x, timeout=600.0) - if self.worker.policy_config.get("enable_connectors", False): + if self.worker.config.enable_connectors: env_runner = EnvRunnerV2( worker=self.worker, base_env=self.base_env, diff --git a/rllib/evaluation/tests/test_episode.py b/rllib/evaluation/tests/test_episode.py index 89f64dfe80a05..42f7e0b893ca5 100644 --- a/rllib/evaluation/tests/test_episode.py +++ b/rllib/evaluation/tests/test_episode.py @@ -1,6 +1,7 @@ import ray import unittest import numpy as np +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -107,29 +108,27 @@ def setUpClass(cls): def tearDownClass(cls): ray.shutdown() - def test_singleagent_env(self): + def test_single_agent_env(self): ev = RolloutWorker( env_creator=lambda _: MockEnv3(NUM_STEPS), - policy_spec=EchoPolicy, - callbacks=LastInfoCallback, + default_policy_class=EchoPolicy, + config=AlgorithmConfig() + .rollouts(num_rollout_workers=0) + .callbacks(LastInfoCallback), ) ev.sample() - def test_multiagent_env(self): - temp_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS) + def test_multi_agent_env(self): ev = RolloutWorker( env_creator=lambda _: EpisodeEnv(NUM_STEPS, NUM_AGENTS), - policy_spec={ - str(agent_id): ( - EchoPolicy, - temp_env.observation_space, - temp_env.action_space, - {}, - ) - for agent_id in range(NUM_AGENTS) - }, - policy_mapping_fn=lambda aid, eps, **kwargs: str(aid), - callbacks=LastInfoCallback, + default_policy_class=EchoPolicy, + config=AlgorithmConfig() + .rollouts(num_rollout_workers=0) + .callbacks(LastInfoCallback) + .multi_agent( + policies={str(agent_id) for agent_id in range(NUM_AGENTS)}, + policy_mapping_fn=lambda aid, eps, **kwargs: str(aid), + ), ) ev.sample() diff --git a/rllib/evaluation/tests/test_episode_v2.py 
b/rllib/evaluation/tests/test_episode_v2.py index 43058b6bbbf7a..84c0dfa085bda 100644 --- a/rllib/evaluation/tests/test_episode_v2.py +++ b/rllib/evaluation/tests/test_episode_v2.py @@ -1,6 +1,7 @@ import unittest import ray +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.examples.env.mock_env import MockEnv3 @@ -66,34 +67,32 @@ def tearDownClass(cls): def test_singleagent_env(self): ev = RolloutWorker( env_creator=lambda _: MockEnv3(NUM_STEPS), - policy_spec=EchoPolicy, + default_policy_class=EchoPolicy, + config=AlgorithmConfig().rollouts(num_rollout_workers=0), ) sample_batch = ev.sample() - self.assertEqual(sample_batch.count, 100) + self.assertEqual(sample_batch.count, 200) # A batch of 100. 4 episodes, each 25. - self.assertEqual(len(set(sample_batch["eps_id"])), 4) + self.assertEqual(len(set(sample_batch["eps_id"])), 8) def test_multiagent_env(self): temp_env = EpisodeEnv(NUM_STEPS, NUM_AGENTS) ev = RolloutWorker( env_creator=lambda _: temp_env, - policy_spec={ - str(agent_id): ( - EchoPolicy, - temp_env.observation_space, - temp_env.action_space, - {}, - ) - for agent_id in range(NUM_AGENTS) - }, - policy_mapping_fn=lambda aid, eps, **kwargs: str(aid), + default_policy_class=EchoPolicy, + config=AlgorithmConfig() + .multi_agent( + policies={str(agent_id) for agent_id in range(NUM_AGENTS)}, + policy_mapping_fn=lambda aid, eps, **kwargs: str(aid), + ) + .rollouts(num_rollout_workers=0), ) sample_batches = ev.sample() self.assertEqual(len(sample_batches.policy_batches), 4) for agent_id, sample_batch in sample_batches.policy_batches.items(): - self.assertEqual(sample_batch.count, 100) + self.assertEqual(sample_batch.count, 200) # A batch of 100. 4 episodes, each 25. 
- self.assertEqual(len(set(sample_batch["eps_id"])), 4) + self.assertEqual(len(set(sample_batch["eps_id"])), 8) if __name__ == "__main__": diff --git a/rllib/evaluation/tests/test_rollout_worker.py b/rllib/evaluation/tests/test_rollout_worker.py index 04f834b686018..89fed21a6d687 100644 --- a/rllib/evaluation/tests/test_rollout_worker.py +++ b/rllib/evaluation/tests/test_rollout_worker.py @@ -10,6 +10,7 @@ import ray from ray.rllib.algorithms.a2c import A2C +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.pg import PG from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -98,7 +99,9 @@ def tearDownClass(cls): def test_basic(self): ev = RolloutWorker( - env_creator=lambda _: gym.make("CartPole-v0"), policy_spec=MockPolicy + env_creator=lambda _: gym.make("CartPole-v0"), + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts(num_rollout_workers=0), ) batch = ev.sample() for key in [ @@ -129,8 +132,10 @@ def test_batch_ids(self): fragment_len = 100 ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=MockPolicy, - rollout_fragment_length=fragment_len, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=fragment_len, num_rollout_workers=0 + ), ) batch1 = ev.sample() batch2 = ev.sample() @@ -211,9 +216,11 @@ def test_query_evaluators(self): "create_env_on_driver": True, }, ) - results = pg.workers.foreach_worker(lambda ev: ev.rollout_fragment_length) + results = pg.workers.foreach_worker( + lambda ev: ev.total_rollout_fragment_length + ) results2 = pg.workers.foreach_worker_with_index( - lambda ev, i: (i, ev.rollout_fragment_length) + lambda ev, i: (i, ev.total_rollout_fragment_length) ) results3 = pg.workers.foreach_worker( lambda ev: ev.foreach_env(lambda env: 1) @@ -238,14 +245,18 @@ def test_action_clipping(self): check_action_bounds=True, ) ), - policy_spec=RandomPolicy, - policy_config=dict( - action_space=action_space, - ignore_action_bounds=True, + config=AlgorithmConfig() + .multi_agent( + policies={ + "default_policy": PolicySpec( + policy_class=RandomPolicy, config={"ignore_action_bounds": True} + ) + } + ) + .rollouts(num_rollout_workers=0, batch_mode="complete_episodes") + .environment( + action_space=action_space, normalize_actions=False, clip_actions=True ), - normalize_actions=False, - clip_actions=True, - batch_mode="complete_episodes", ) sample = ev.sample() # Check, whether the action bounds have been breached (expected). @@ -266,16 +277,22 @@ def test_action_clipping(self): check_action_bounds=True, ) ), - policy_spec=RandomPolicy, - policy_config=dict( - action_space=action_space, - ignore_action_bounds=True, - ), # No normalization (+clipping) and no clipping -> # Should lead to Env complaining. 
- normalize_actions=False, - clip_actions=False, - batch_mode="complete_episodes", + config=AlgorithmConfig() + .environment( + normalize_actions=False, + clip_actions=False, + action_space=action_space, + ) + .rollouts(batch_mode="complete_episodes", num_rollout_workers=0) + .multi_agent( + policies={ + "default_policy": PolicySpec( + policy_class=RandomPolicy, config={"ignore_action_bounds": True} + ) + } + ), ) self.assertRaisesRegex(ValueError, r"Illegal action", ev2.sample) ev2.stop() @@ -291,12 +308,14 @@ def test_action_clipping(self): check_action_bounds=True, ) ), - policy_spec=RandomPolicy, - policy_config=dict(action_space=action_space), + default_policy_class=RandomPolicy, + config=AlgorithmConfig().rollouts( + num_rollout_workers=0, batch_mode="complete_episodes" + ) # Should not be a problem as RandomPolicy abides to bounds. - normalize_actions=False, - clip_actions=False, - batch_mode="complete_episodes", + .environment( + action_space=action_space, normalize_actions=False, clip_actions=False + ), ) sample = ev3.sample() self.assertGreater(np.min(sample["actions"]), action_space.low[0]) @@ -318,14 +337,18 @@ def test_action_normalization(self): check_action_bounds=True, ) ), - policy_spec=RandomPolicy, - policy_config=dict( - action_space=action_space, - ignore_action_bounds=True, + config=AlgorithmConfig() + .multi_agent( + policies={ + "default_policy": PolicySpec( + policy_class=RandomPolicy, config={"ignore_action_bounds": True} + ) + } + ) + .rollouts(num_rollout_workers=0, batch_mode="complete_episodes") + .environment( + action_space=action_space, normalize_actions=True, clip_actions=False ), - normalize_actions=True, - clip_actions=False, - batch_mode="complete_episodes", ) sample = ev.sample() # Check, whether the action bounds have been breached (expected). @@ -384,16 +407,22 @@ def json_reader_creator(ioctx): for actions_in_input_normalized, normalize_actions in parameters: ev = RolloutWorker( env_creator=lambda _: env, - policy_spec=MockPolicy, - policy_config=dict( - actions_in_input_normalized=actions_in_input_normalized, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + num_rollout_workers=0, + rollout_fragment_length=1, + ) + .environment( normalize_actions=normalize_actions, clip_actions=False, + ) + .training(train_batch_size=1) + .offline_data( offline_sampling=True, - train_batch_size=1, + actions_in_input_normalized=actions_in_input_normalized, + input_=input_creator, ), - rollout_fragment_length=1, - input_creator=input_creator, ) sample = ev.sample() @@ -445,13 +474,16 @@ def step(self, action): check_action_bounds=True, ) ), - policy_spec=RandomPolicy, - policy_config=dict( - action_space=action_space, - ignore_action_bounds=True, - ), - clip_actions=False, - batch_mode="complete_episodes", + config=AlgorithmConfig() + .multi_agent( + policies={ + "default_policy": PolicySpec( + policy_class=RandomPolicy, config={"ignore_action_bounds": True} + ) + } + ) + .environment(action_space=action_space, clip_actions=False) + .rollouts(batch_mode="complete_episodes", num_rollout_workers=0), ) ev.sample() ev.stop() @@ -460,9 +492,10 @@ def test_reward_clipping(self): # Clipping: True (clip between -1.0 and 1.0). 
ev = RolloutWorker( env_creator=lambda _: MockEnv2(episode_length=10), - policy_spec=MockPolicy, - clip_rewards=True, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts(num_rollout_workers=0, batch_mode="complete_episodes") + .environment(clip_rewards=True), ) self.assertEqual(max(ev.sample()["rewards"]), 1) result = collect_metrics(ev, []) @@ -480,9 +513,10 @@ def test_reward_clipping(self): max_episode_len=10, ) ), - policy_spec=MockPolicy, - clip_rewards=2.0, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts(num_rollout_workers=0, batch_mode="complete_episodes") + .environment(clip_rewards=2.0), ) sample = ev2.sample() self.assertEqual(max(sample["rewards"]), 2.0) @@ -494,9 +528,10 @@ def test_reward_clipping(self): # Clipping: Off. ev2 = RolloutWorker( env_creator=lambda _: MockEnv2(episode_length=10), - policy_spec=MockPolicy, - clip_rewards=False, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts(num_rollout_workers=0, batch_mode="complete_episodes") + .environment(clip_rewards=False), ) self.assertEqual(max(ev2.sample()["rewards"]), 100) result2 = collect_metrics(ev2, []) @@ -506,11 +541,14 @@ def test_reward_clipping(self): def test_hard_horizon(self): ev = RolloutWorker( env_creator=lambda _: MockEnv2(episode_length=10), - policy_spec=MockPolicy, - batch_mode="complete_episodes", - rollout_fragment_length=10, - episode_horizon=4, - soft_horizon=False, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + num_rollout_workers=0, + batch_mode="complete_episodes", + rollout_fragment_length=10, + horizon=4, + soft_horizon=False, + ), ) samples = ev.sample() # Three logical episodes and correct episode resets (always after 4 @@ -526,11 +564,14 @@ def test_hard_horizon(self): # A gym env's max_episode_steps is smaller than Algorithm's horizon. ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=MockPolicy, - batch_mode="complete_episodes", - rollout_fragment_length=10, - episode_horizon=6, - soft_horizon=False, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + num_rollout_workers=0, + batch_mode="complete_episodes", + rollout_fragment_length=10, + horizon=6, + soft_horizon=False, + ), ) samples = ev.sample() # 12 steps due to `complete_episodes` batch_mode. 
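For reference, a minimal sketch of the constructor pattern these tests migrate to: a RolloutWorker now receives its rollout settings through an AlgorithmConfig object instead of individual keyword arguments. Class, module, and argument names are taken from this patch; the CartPole env and the specific rollout settings are illustrative only:

import gym
import ray

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy

ray.init()

# Local (non-remote) worker; settings such as batch_mode and
# rollout_fragment_length now live in the config's `rollouts()` group.
worker = RolloutWorker(
    env_creator=lambda _: gym.make("CartPole-v0"),
    default_policy_class=MockPolicy,
    config=AlgorithmConfig().rollouts(
        num_rollout_workers=0,
        rollout_fragment_length=100,
        batch_mode="complete_episodes",
    ),
)
batch = worker.sample()  # At least 100 env steps (only complete episodes).
worker.stop()
ray.shutdown()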
@@ -561,11 +602,14 @@ def test_hard_horizon(self): def test_soft_horizon(self): ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=10), - policy_spec=MockPolicy, - batch_mode="complete_episodes", - rollout_fragment_length=10, - episode_horizon=4, - soft_horizon=True, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + num_rollout_workers=0, + batch_mode="complete_episodes", + rollout_fragment_length=10, + horizon=4, + soft_horizon=True, + ), ) samples = ev.sample() # three logical episodes @@ -577,13 +621,21 @@ def test_soft_horizon(self): def test_metrics(self): ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=10), - policy_spec=MockPolicy, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=100, + num_rollout_workers=0, + batch_mode="complete_episodes", + ), ) remote_ev = RolloutWorker.as_remote().remote( env_creator=lambda _: MockEnv(episode_length=10), - policy_spec=MockPolicy, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=100, + num_rollout_workers=0, + batch_mode="complete_episodes", + ), ) ev.sample() ray.get(remote_ev.sample.remote()) @@ -595,8 +647,8 @@ def test_metrics(self): def test_async(self): ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - sample_async=True, - policy_spec=MockPolicy, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts(sample_async=True, num_rollout_workers=0), ) batch = ev.sample() for key in ["obs", "actions", "rewards", "dones", "advantages"]: @@ -607,10 +659,13 @@ def test_async(self): def test_auto_vectorization(self): ev = RolloutWorker( env_creator=lambda cfg: MockEnv(episode_length=20, config=cfg), - policy_spec=MockPolicy, - batch_mode="truncate_episodes", - rollout_fragment_length=2, - num_envs=8, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=2, + num_envs_per_worker=8, + num_rollout_workers=0, + batch_mode="truncate_episodes", + ), ) for _ in range(8): batch = ev.sample() @@ -632,10 +687,13 @@ def test_auto_vectorization(self): def test_batches_larger_when_vectorized(self): ev = RolloutWorker( env_creator=lambda _: MockEnv(episode_length=8), - policy_spec=MockPolicy, - batch_mode="truncate_episodes", - rollout_fragment_length=4, - num_envs=4, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=4, + num_envs_per_worker=4, + num_rollout_workers=0, + batch_mode="truncate_episodes", + ), ) batch = ev.sample() self.assertEqual(batch.count, 16) @@ -651,9 +709,12 @@ def test_vector_env_support(self): # (MockEnv instances). ev = RolloutWorker( env_creator=(lambda _: VectorizedMockEnv(episode_length=20, num_envs=8)), - policy_spec=MockPolicy, - batch_mode="truncate_episodes", - rollout_fragment_length=10, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=10, + num_rollout_workers=0, + batch_mode="truncate_episodes", + ), ) for _ in range(8): batch = ev.sample() @@ -671,9 +732,12 @@ def test_vector_env_support(self): # only has 1 (CartPole). 
ev = RolloutWorker( env_creator=(lambda _: MockVectorEnv(20, mocked_num_envs=4)), - policy_spec=MockPolicy, - batch_mode="truncate_episodes", - rollout_fragment_length=10, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=10, + num_rollout_workers=0, + batch_mode="truncate_episodes", + ), ) for _ in range(8): batch = ev.sample() @@ -690,9 +754,12 @@ def test_vector_env_support(self): def test_truncate_episodes(self): ev_env_steps = RolloutWorker( env_creator=lambda _: MockEnv(10), - policy_spec=MockPolicy, - rollout_fragment_length=15, - batch_mode="truncate_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=15, + num_rollout_workers=0, + batch_mode="truncate_episodes", + ), ) batch = ev_env_steps.sample() self.assertEqual(batch.count, 15) @@ -703,16 +770,22 @@ def test_truncate_episodes(self): obs_space = Box(float("-inf"), float("inf"), (4,), dtype=np.float32) ev_agent_steps = RolloutWorker( env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}), - policy_spec={ - "pol0": (MockPolicy, obs_space, action_space, {}), - "pol1": (MockPolicy, obs_space, action_space, {}), - }, - policy_mapping_fn=lambda agent_id, episode, **kwargs: "pol0" - if agent_id == 0 - else "pol1", - rollout_fragment_length=301, - count_steps_by="env_steps", - batch_mode="truncate_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + num_rollout_workers=0, + batch_mode="truncate_episodes", + rollout_fragment_length=301, + ) + .multi_agent( + policies={"pol0", "pol1"}, + policy_mapping_fn=( + lambda agent_id, episode, **kwargs: "pol0" + if agent_id == 0 + else "pol1" + ), + ) + .environment(action_space=action_space, observation_space=obs_space), ) batch = ev_agent_steps.sample() self.assertTrue(isinstance(batch, MultiAgentBatch)) @@ -722,16 +795,21 @@ def test_truncate_episodes(self): ev_agent_steps = RolloutWorker( env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}), - policy_spec={ - "pol0": (MockPolicy, obs_space, action_space, {}), - "pol1": (MockPolicy, obs_space, action_space, {}), - }, - policy_mapping_fn=lambda agent_id, episode, **kwargs: "pol0" - if agent_id == 0 - else "pol1", - rollout_fragment_length=301, - count_steps_by="agent_steps", - batch_mode="truncate_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts( + num_rollout_workers=0, + rollout_fragment_length=301, + ) + .multi_agent( + count_steps_by="agent_steps", + policies={"pol0", "pol1"}, + policy_mapping_fn=( + lambda agent_id, episode, **kwargs: "pol0" + if agent_id == 0 + else "pol1" + ), + ), ) batch = ev_agent_steps.sample() self.assertTrue(isinstance(batch, MultiAgentBatch)) @@ -746,9 +824,12 @@ def test_truncate_episodes(self): def test_complete_episodes(self): ev = RolloutWorker( env_creator=lambda _: MockEnv(10), - policy_spec=MockPolicy, - rollout_fragment_length=5, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=5, + num_rollout_workers=0, + batch_mode="complete_episodes", + ), ) batch = ev.sample() self.assertEqual(batch.count, 10) @@ -757,9 +838,12 @@ def test_complete_episodes(self): def test_complete_episodes_packing(self): ev = RolloutWorker( env_creator=lambda _: MockEnv(10), - policy_spec=MockPolicy, - rollout_fragment_length=15, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=15, + 
num_rollout_workers=0, + batch_mode="complete_episodes", + ), ) batch = ev.sample() self.assertEqual(batch.count, 20) @@ -772,9 +856,12 @@ def test_complete_episodes_packing(self): def test_filter_sync(self): ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=MockPolicy, - sample_async=True, - policy_config={"observation_filter": "ConcurrentMeanStdFilter"}, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + sample_async=True, + num_rollout_workers=0, + observation_filter="ConcurrentMeanStdFilter", + ), ) time.sleep(2) ev.sample() @@ -787,9 +874,12 @@ def test_filter_sync(self): def test_get_filters(self): ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=MockPolicy, - sample_async=True, - policy_config={"observation_filter": "ConcurrentMeanStdFilter"}, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + observation_filter="ConcurrentMeanStdFilter", + num_rollout_workers=0, + sample_async=True, + ), ) self.sample_and_flush(ev) filters = ev.get_filters(flush_after=False) @@ -804,9 +894,12 @@ def test_get_filters(self): def test_sync_filter(self): ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=MockPolicy, - sample_async=True, - policy_config={"observation_filter": "ConcurrentMeanStdFilter"}, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + observation_filter="ConcurrentMeanStdFilter", + num_rollout_workers=0, + sample_async=True, + ), ) obs_f = self.sample_and_flush(ev) @@ -831,8 +924,10 @@ def test_extra_python_envs(self): self.assertFalse("env_key_2" in os.environ) ev = RolloutWorker( env_creator=lambda _: MockEnv(10), - policy_spec=MockPolicy, - extra_python_environs=extra_envs, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .python_environment(extra_python_environs_for_driver=extra_envs) + .rollouts(num_rollout_workers=0), ) self.assertTrue("env_key_1" in os.environ) self.assertTrue("env_key_2" in os.environ) @@ -845,8 +940,8 @@ def test_extra_python_envs(self): def test_no_env_seed(self): ev = RolloutWorker( env_creator=lambda _: MockVectorEnv(20, mocked_num_envs=8), - policy_spec=MockPolicy, - seed=1, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts(num_rollout_workers=0).debugging(seed=1), ) assert not hasattr(ev.env, "seed") ev.stop() @@ -854,9 +949,10 @@ def test_no_env_seed(self): def test_multi_env_seed(self): ev = RolloutWorker( env_creator=lambda _: MockEnv2(100), - num_envs=3, - policy_spec=MockPolicy, - seed=1, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts(num_envs_per_worker=3, num_rollout_workers=0) + .debugging(seed=1), ) # Make sure we can properly sample from the wrapped env. ev.sample() @@ -888,25 +984,26 @@ def step(self, action_dict): ev = RolloutWorker( env_creator=lambda _: MockMultiAgentEnv(), - num_envs=3, - policy_spec={ - "policy_1": PolicySpec(policy_class=MockPolicy), - "policy_2": PolicySpec(policy_class=MockPolicy), - }, - seed=1, + default_policy_class=MockPolicy, + config=AlgorithmConfig() + .rollouts(num_envs_per_worker=3, num_rollout_workers=0) + .multi_agent(policies={"policy_1", "policy_2"}) + .debugging(seed=1), ) # The fact that this RolloutWorker can be created without throwing - # exceptions means _determine_spaces_for_multi_agent_dict() is - # handling multiagent user environments properly. + # exceptions means AlgorithmConfig.get_multi_agent_setup() is + # handling multi-agent user environments properly. 
self.assertIsNotNone(ev) def test_wrap_multi_agent_env(self): ev = RolloutWorker( env_creator=lambda _: BasicMultiAgent(10), - policy_spec=MockPolicy, - policy_config={ - "in_evaluation": False, - }, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=5, + batch_mode="complete_episodes", + num_rollout_workers=0, + ), ) # Make sure we can properly sample from the wrapped env. ev.sample() @@ -932,9 +1029,12 @@ def step(self, action): ev = RolloutWorker( env_creator=lambda _: NoTrainingEnv(10, True), - policy_spec=MockPolicy, - rollout_fragment_length=5, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=5, + batch_mode="complete_episodes", + num_rollout_workers=0, + ), ) batch = ev.sample() self.assertEqual(batch.count, 10) @@ -943,9 +1043,12 @@ def step(self, action): ev = RolloutWorker( env_creator=lambda _: NoTrainingEnv(10, False), - policy_spec=MockPolicy, - rollout_fragment_length=5, - batch_mode="complete_episodes", + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=5, + batch_mode="complete_episodes", + num_rollout_workers=0, + ), ) batch = ev.sample() self.assertTrue(isinstance(batch, MultiAgentBatch)) diff --git a/rllib/evaluation/tests/test_trajectory_view_api.py b/rllib/evaluation/tests/test_trajectory_view_api.py index 123add500a4db..4a1a7cd4e118a 100644 --- a/rllib/evaluation/tests/test_trajectory_view_api.py +++ b/rllib/evaluation/tests/test_trajectory_view_api.py @@ -199,11 +199,10 @@ def test_traj_view_next_action(self): action_space = Discrete(2) rollout_worker_w_api = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_config=ppo.DEFAULT_CONFIG, - rollout_fragment_length=200, - policy_spec=ppo.PPOTorchPolicy, - policy_mapping_fn=None, - num_envs=1, + default_policy_class=ppo.PPOTorchPolicy, + config=ppo.PPOConfig().rollouts( + rollout_fragment_length=200, num_rollout_workers=0 + ), ) # Add the next action (a') and 2nd next action (a'') to the view # requirements of the policy. 
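The multi-agent cases above follow the same scheme: the policies dict and the policy_mapping_fn move from RolloutWorker keyword arguments into AlgorithmConfig.multi_agent(). Below is a minimal sketch of that pattern, with Ray initialized as in the previous sketch; the MultiAgentCartPole import path is assumed from RLlib's examples package, and the policy IDs and mapping function are illustrative only:

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.evaluation.rollout_worker import RolloutWorker
from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole

config = (
    AlgorithmConfig()
    .rollouts(num_rollout_workers=0)
    .multi_agent(
        # Bare policy IDs: specs (class, spaces, config) get inferred; the
        # policy class falls back to `default_policy_class`.
        policies={"pol0", "pol1"},
        policy_mapping_fn=(
            lambda agent_id, episode, **kwargs: "pol0" if agent_id == 0 else "pol1"
        ),
    )
)
worker = RolloutWorker(
    env_creator=lambda _: MultiAgentCartPole({"num_agents": 4}),
    default_policy_class=MockPolicy,
    config=config,
)
ma_batch = worker.sample()  # MultiAgentBatch: one SampleBatch per policy ID.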
@@ -271,25 +270,24 @@ def test_traj_view_lstm_functionality(self): def policy_fn(agent_id, episode, **kwargs): return "pol0" - config = { - "multiagent": { - "policies": policies, - "policy_mapping_fn": policy_fn, - }, - "model": { - "use_lstm": True, - "max_seq_len": max_seq_len, - }, - } - rw = RolloutWorker( env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), - policy_config=config, - rollout_fragment_length=rollout_fragment_length, - policy_spec=policies, - policy_mapping_fn=policy_fn, - normalize_actions=False, - num_envs=1, + config=ppo.PPOConfig() + .rollouts( + rollout_fragment_length=rollout_fragment_length, + num_rollout_workers=0, + ) + .multi_agent( + policies=policies, + policy_mapping_fn=policy_fn, + ) + .environment(normalize_actions=False) + .training( + model={ + "use_lstm": True, + "max_seq_len": max_seq_len, + } + ), ) for iteration in range(20): @@ -315,24 +313,20 @@ def test_traj_view_attention_functionality(self): def policy_fn(agent_id, episode, **kwargs): return "pol0" - config = { - "multiagent": { - "policies": policies, - "policy_mapping_fn": policy_fn, - }, - "model": { - "max_seq_len": max_seq_len, - }, - } + config = ( + ppo.PPOConfig() + .multi_agent(policies=policies, policy_mapping_fn=policy_fn) + .training(model={"max_seq_len": max_seq_len}) + .rollouts( + num_rollout_workers=0, + rollout_fragment_length=rollout_fragment_length, + ) + .environment(normalize_actions=False) + ) rollout_worker_w_api = RolloutWorker( env_creator=lambda _: MultiAgentDebugCounterEnv({"num_agents": 4}), - policy_config=config, - rollout_fragment_length=rollout_fragment_length, - policy_spec=policies, - policy_mapping_fn=policy_fn, - normalize_actions=False, - num_envs=1, + config=config, ) batch = rollout_worker_w_api.sample() # noqa: F841 diff --git a/rllib/evaluation/worker_set.py b/rllib/evaluation/worker_set.py index f6d7fa5564228..c2b50f38646b6 100644 --- a/rllib/evaluation/worker_set.py +++ b/rllib/evaluation/worker_set.py @@ -2,7 +2,6 @@ import logging import importlib.util import os -from types import FunctionType from typing import ( Callable, Container, @@ -11,6 +10,7 @@ Optional, Tuple, Type, + TYPE_CHECKING, TypeVar, Union, ) @@ -21,23 +21,15 @@ from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.env.base_env import BaseEnv from ray.rllib.env.env_context import EnvContext -from ray.rllib.offline import ( - NoopOutput, - JsonReader, - MixedInput, - JsonWriter, - ShuffledInput, - D4RLReader, - DatasetReader, - DatasetWriter, - get_dataset_and_shards, -) -from ray.rllib.policy.policy import Policy, PolicySpec, PolicyState -from ray.rllib.utils import merge_dicts +from ray.rllib.offline import get_dataset_and_shards +from ray.rllib.policy.policy import Policy, PolicyState from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.deprecation import ( + Deprecated, + deprecation_warning, + DEPRECATED_VALUE, +) from ray.rllib.utils.framework import try_import_tf -from ray.rllib.utils.from_config import from_config from ray.rllib.utils.policy import validate_policy_id from ray.rllib.utils.typing import ( AgentID, @@ -50,7 +42,9 @@ SampleBatchType, TensorType, ) -from ray.tune.registry import registry_contains_input, registry_get_input + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig tf1, tf, tfv = try_import_tf() @@ -72,12 +66,15 @@ def __init__( *, env_creator: Optional[EnvCreator] = None, validate_env: 
Optional[Callable[[EnvType], None]] = None, - policy_class: Optional[Type[Policy]] = None, - trainer_config: Optional[AlgorithmConfigDict] = None, + default_policy_class: Optional[Type[Policy]] = None, + config: Optional[Union["AlgorithmConfig", AlgorithmConfigDict]] = None, num_workers: int = 0, local_worker: bool = True, logdir: Optional[str] = None, _setup: bool = True, + # deprecated args. + policy_class=DEPRECATED_VALUE, + trainer_config=DEPRECATED_VALUE, ): """Initializes a WorkerSet instance. @@ -85,11 +82,12 @@ def __init__( env_creator: Function that returns env given env config. validate_env: Optional callable to validate the generated environment (only on worker=0). - policy_class: An optional Policy class. If None, PolicySpecs can be - generated automatically by using the Algorithm's default class - of via a given multi-agent policy config dict. - trainer_config: Optional dict that extends the common config of - the Algorithm class. + default_policy_class: An optional default Policy class to use inside + the (multi-agent) `policies` dict. In case the PolicySpecs in there + have no class defined, use this `default_policy_class`. + If None, PolicySpecs will be using the Algorithm's default Policy + class. + config: Optional AlgorithmConfig (or config dict). num_workers: Number of remote rollout workers to create. local_worker: Whether to create a local (non @ray.remote) worker in the returned set as well (default: True). If `num_workers` @@ -97,19 +95,36 @@ def __init__( logdir: Optional logging directory for workers. _setup: Whether to setup workers. This is only for testing. """ + if policy_class != DEPRECATED_VALUE: + deprecation_warning( + old="WorkerSet(policy_class=..)", + new="WorkerSet(default_policy_class=..)", + error=False, + ) + default_policy_class = policy_class + if trainer_config != DEPRECATED_VALUE: + deprecation_warning( + old="WorkerSet(trainer_config=..)", + new="WorkerSet(config=..)", + error=False, + ) + config = trainer_config - if not trainer_config: - from ray.rllib.algorithms.algorithm import COMMON_CONFIG + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig - trainer_config = COMMON_CONFIG + # Make sure `config` is an AlgorithmConfig object. + if not config: + config = AlgorithmConfig() + elif isinstance(config, dict): + config = AlgorithmConfig.from_dict(config) self._env_creator = env_creator - self._policy_class = policy_class - self._remote_config = trainer_config + self._policy_class = default_policy_class + self._remote_config = config self._remote_args = { - "num_cpus": self._remote_config["num_cpus_per_worker"], - "num_gpus": self._remote_config["num_gpus_per_worker"], - "resources": self._remote_config["custom_resources_per_worker"], + "num_cpus": self._remote_config.num_cpus_per_worker, + "num_gpus": self._remote_config.num_gpus_per_worker, + "resources": self._remote_config.custom_resources_per_worker, } self._cls = RolloutWorker.as_remote(**self._remote_args).remote self._logdir = logdir @@ -119,17 +134,14 @@ def __init__( self._local_worker = None if num_workers == 0: local_worker = True - self._local_config = merge_dicts( - trainer_config, - {"tf_session_args": trainer_config["local_tf_session_args"]}, + self._local_config = config.copy(copy_frozen=False).framework( + tf_session_args=config.local_tf_session_args ) - if trainer_config["input"] == "dataset": + if config.input_ == "dataset": # Create the set of dataset readers to be shared by all the # rollout workers. 
- self._ds, self._ds_shards = get_dataset_and_shards( - trainer_config, num_workers - ) + self._ds, self._ds_shards = get_dataset_and_shards(config, num_workers) else: self._ds = None self._ds_shards = None @@ -138,7 +150,7 @@ def __init__( self._remote_workers = [] self.add_workers( num_workers, - validate=trainer_config.get("validate_workers_after_construction"), + validate=config.validate_workers_after_construction, ) # Create a local worker, if needed. @@ -148,11 +160,8 @@ def __init__( if ( local_worker and self._remote_workers - and not trainer_config.get("create_env_on_driver") - and ( - not trainer_config.get("observation_space") - or not trainer_config.get("action_space") - ) + and not config.create_env_on_local_worker + and (not config.observation_space or not config.action_space) ): remote_spaces = ray.get( self.remote_workers()[0].foreach_policy.remote( @@ -186,9 +195,6 @@ def __init__( cls=RolloutWorker, env_creator=env_creator, validate_env=validate_env, - policy_cls=self._policy_class, - # Initially, policy_specs will be inferred from config dict. - policy_specs=None, worker_index=0, num_workers=num_workers, config=self._local_config, @@ -253,7 +259,7 @@ def add_policy( *, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, - config: Optional[PartialAlgorithmConfigDict] = None, + config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None, policy_state: Optional[PolicyState] = None, policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, policies_to_train: Optional[ @@ -279,7 +285,7 @@ def add_policy( If None, try to infer this space from the environment. action_space: The action space of the policy to add. If None, try to infer this space from the environment. - config: The config overrides for the policy to add. + config: The config object or overrides for the policy to add. policy_state: Optional state dict to apply to the new policy instance, right after its construction. policy_mapping_fn: An optional (updated) policy mapping function @@ -340,7 +346,7 @@ def add_policy_to_workers( *, observation_space: Optional[gym.spaces.Space] = None, action_space: Optional[gym.spaces.Space] = None, - config: Optional[PartialAlgorithmConfigDict] = None, + config: Optional[Union["AlgorithmConfig", PartialAlgorithmConfigDict]] = None, policy_state: Optional[PolicyState] = None, policy_mapping_fn: Optional[Callable[[AgentID, EpisodeID], PolicyID]] = None, policies_to_train: Optional[ @@ -367,7 +373,7 @@ def add_policy_to_workers( If None, try to infer this space from the environment. action_space: The action space of the policy to add. If None, try to infer this space from the environment. - config: The config overrides for the policy to add. + config: The config object or overrides for the policy to add. policy_state: Optional state dict to apply to the new policy instance, right after its construction. policy_mapping_fn: An optional (updated) policy mapping function @@ -470,10 +476,6 @@ def add_workers(self, num_workers: int, validate: bool = False) -> None: cls=self._cls, env_creator=self._env_creator, validate_env=None, - policy_cls=self._policy_class, - # Setup remote workers with policy_specs inferred from config dict. - # Simply provide None here. 
- policy_specs=None, worker_index=old_num_workers + i + 1, num_workers=old_num_workers + num_workers, config=self._remote_config, @@ -548,19 +550,12 @@ def recreate_failed_workers( worker.__ray_terminate__.remote() except Exception: logger.exception("Error terminating faulty worker.") + # Try to recreate the failed worker (start a new one). new_worker = self._make_worker( cls=self._cls, env_creator=self._env_creator, validate_env=None, - policy_cls=self._policy_class, - # For recreated remote workers, we need to sync the entire - # policy specs dict from local_worker_for_synching. - # We can not let self._make_worker() infer policy specs - # from self._remote_config dict because custom policies - # may be added to both rollout and evaluation workers - # while the training job progresses. - policy_specs=local_worker_for_synching.policy_dict, worker_index=worker_index, num_workers=len(self._remote_workers), recreated_worker=True, @@ -569,9 +564,8 @@ def recreate_failed_workers( # Sync new worker from provided one (or local one). # Restore weights and global variables. - new_worker.set_weights.remote( - weights=local_worker_for_synching.get_weights(), - global_vars=local_worker_for_synching.get_global_vars(), + new_worker.set_state.remote( + state=local_worker_for_synching.get_state(), ) # Add new worker to list of remote workers. @@ -600,6 +594,8 @@ def is_policy_to_train( """Whether given PolicyID (optionally inside some batch) is trainable.""" local_worker = self.local_worker() if local_worker: + if local_worker.is_policy_to_train is None: + return True return local_worker.is_policy_to_train(policy_id, batch) else: raise NotImplementedError @@ -762,7 +758,7 @@ def _from_existing( local_worker: RolloutWorker, remote_workers: List[ActorHandle] = None ): workers = WorkerSet( - env_creator=None, policy_class=None, trainer_config={}, _setup=False + env_creator=None, default_policy_class=None, config=None, _setup=False ) workers._local_worker = local_worker workers._remote_workers = remote_workers or [] @@ -774,12 +770,10 @@ def _make_worker( cls: Callable, env_creator: EnvCreator, validate_env: Optional[Callable[[EnvType], None]], - policy_cls: Type[Policy], - policy_specs: Optional[Dict[str, PolicySpec]] = None, worker_index: int, num_workers: int, recreated_worker: bool = False, - config: AlgorithmConfigDict, + config: "AlgorithmConfig", spaces: Optional[ Dict[PolicyID, Tuple[gym.spaces.Space, gym.spaces.Space]] ] = None, @@ -788,147 +782,18 @@ def session_creator(): logger.debug("Creating TF session {}".format(config["tf_session_args"])) return tf1.Session(config=tf1.ConfigProto(**config["tf_session_args"])) - def valid_module(class_path): - if ( - isinstance(class_path, str) - and not os.path.isfile(class_path) - and "." in class_path - ): - module_path, class_name = class_path.rsplit(".", 1) - try: - spec = importlib.util.find_spec(module_path) - if spec is not None: - return True - except (ModuleNotFoundError, ValueError): - print( - f"module {module_path} not found while trying to get " - f"input {class_path}" - ) - return False - - # A callable returning an InputReader object to use. - if isinstance(config["input"], FunctionType): - input_creator = config["input"] - # Use RLlib's Sampler classes (SyncSampler or AsynchSampler, depending - # on `config.sample_async` setting). - elif config["input"] == "sampler": - input_creator = lambda ioctx: ioctx.default_sampler_input() - # Ray Dataset input -> Use `config.input_config` to construct DatasetReader. 
- elif config["input"] == "dataset": - # Input dataset shards should have already been prepared. - # We just need to take the proper shard here. - input_creator = lambda ioctx: DatasetReader( - self._ds_shards[worker_index], ioctx - ) - # Dict: Mix of different input methods with different ratios. - elif isinstance(config["input"], dict): - input_creator = lambda ioctx: ShuffledInput( - MixedInput(config["input"], ioctx), config["shuffle_buffer_size"] - ) - # A pre-registered input descriptor (str). - elif isinstance(config["input"], str) and registry_contains_input( - config["input"] - ): - input_creator = registry_get_input(config["input"]) - # D4RL input. - elif "d4rl" in config["input"]: - env_name = config["input"].split(".")[-1] - input_creator = lambda ioctx: D4RLReader(env_name, ioctx) - # Valid python module (class path) -> Create using `from_config`. - elif valid_module(config["input"]): - input_creator = lambda ioctx: ShuffledInput( - from_config(config["input"], ioctx=ioctx) - ) - # JSON file or list of JSON files -> Use JsonReader (shuffled). - else: - input_creator = lambda ioctx: ShuffledInput( - JsonReader(config["input"], ioctx), config["shuffle_buffer_size"] - ) - - if isinstance(config["output"], FunctionType): - output_creator = config["output"] - elif config["output"] is None: - output_creator = lambda ioctx: NoopOutput() - elif config["output"] == "dataset": - output_creator = lambda ioctx: DatasetWriter( - ioctx, compress_columns=config["output_compress_columns"] - ) - elif config["output"] == "logdir": - output_creator = lambda ioctx: JsonWriter( - ioctx.log_dir, - ioctx, - max_file_size=config["output_max_file_size"], - compress_columns=config["output_compress_columns"], - ) - else: - output_creator = lambda ioctx: JsonWriter( - config["output"], - ioctx, - max_file_size=config["output_max_file_size"], - compress_columns=config["output_compress_columns"], - ) - - if not policy_specs: - # Infer policy specs from multiagent.policies dict. - if config["multiagent"]["policies"]: - # Make a copy so we don't modify the original multiagent config dict - # by accident. - policy_specs = config["multiagent"]["policies"].copy() - # Assert everything is correct in "multiagent" config dict (if given). - for policy_spec in policy_specs.values(): - assert isinstance(policy_spec, PolicySpec) - # Class is None -> Use `policy_cls`. - if policy_spec.policy_class is None: - policy_spec.policy_class = policy_cls - # Use the only policy class as policy specs. 
- else: - policy_specs = policy_cls - - if worker_index == 0: - extra_python_environs = config.get("extra_python_environs_for_driver", None) - else: - extra_python_environs = config.get("extra_python_environs_for_worker", None) - worker = cls( env_creator=env_creator, validate_env=validate_env, - policy_spec=policy_specs, - policy_mapping_fn=config["multiagent"]["policy_mapping_fn"], - policies_to_train=config["multiagent"]["policies_to_train"], + default_policy_class=self._policy_class, tf_session_creator=(session_creator if config["tf_session_args"] else None), - rollout_fragment_length=config["rollout_fragment_length"], - count_steps_by=config["multiagent"]["count_steps_by"], - batch_mode=config["batch_mode"], - episode_horizon=config["horizon"], - preprocessor_pref=config["preprocessor_pref"], - sample_async=config["sample_async"], - compress_observations=config["compress_observations"], - num_envs=config["num_envs_per_worker"], - observation_fn=config["multiagent"]["observation_fn"], - clip_rewards=config["clip_rewards"], - normalize_actions=config["normalize_actions"], - clip_actions=config["clip_actions"], - env_config=config["env_config"], - policy_config=config, + config=config, worker_index=worker_index, num_workers=num_workers, recreated_worker=recreated_worker, log_dir=self._logdir, - log_level=config["log_level"], - callbacks=config["callbacks"], - input_creator=input_creator, - output_creator=output_creator, - remote_worker_envs=config["remote_worker_envs"], - remote_env_batch_wait_ms=config["remote_env_batch_wait_ms"], - soft_horizon=config["soft_horizon"], - no_done_at_end=config["no_done_at_end"], - seed=(config["seed"] + worker_index) - if config["seed"] is not None - else None, - fake_sampler=config["fake_sampler"], - extra_python_environs=extra_python_environs, spaces=spaces, - disable_env_checking=config["disable_env_checking"], + dataset_shards=self._ds_shards, ) return worker @@ -986,12 +851,4 @@ def foreach_trainable_policy(self, func): @Deprecated(new="WorkerSet.is_policy_to_train([pid], [batch]?)", error=True) def trainable_policies(self): - local_worker = self.local_worker() - if local_worker is not None: - return [ - pid - for pid in local_worker.policy_map.keys() - if local_worker.is_policy_to_train(pid, None) - ] - else: - raise NotImplementedError + pass diff --git a/rllib/examples/documentation/replay_buffer_demo.py b/rllib/examples/documentation/replay_buffer_demo.py index c6e204c917ae2..9a18e5ba29e63 100644 --- a/rllib/examples/documentation/replay_buffer_demo.py +++ b/rllib/examples/documentation/replay_buffer_demo.py @@ -15,19 +15,13 @@ # __sphinx_doc_replay_buffer_type_specification__begin__ -config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer}).to_dict() +config = DQNConfig().training(replay_buffer_config={"type": ReplayBuffer}) -another_config = ( - DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"}).to_dict() -) +another_config = DQNConfig().training(replay_buffer_config={"type": "ReplayBuffer"}) -yet_another_config = ( - DQNConfig() - .training( - replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"} - ) - .to_dict() +yet_another_config = DQNConfig().training( + replay_buffer_config={"type": "ray.rllib.utils.replay_buffers.ReplayBuffer"} ) validate_buffer_config(config) @@ -35,7 +29,12 @@ validate_buffer_config(yet_another_config) # After validation, all three configs yield the same effective config -assert config == another_config == yet_another_config +assert ( + 
config.replay_buffer_config + == another_config.replay_buffer_config + == yet_another_config.replay_buffer_config +) + # __sphinx_doc_replay_buffer_type_specification__end__ @@ -120,21 +119,25 @@ def sample( # __sphinx_doc_replay_buffer_advanced_usage_underlying_buffers__begin__ -config = { - "env": "CartPole-v1", - "replay_buffer_config": { - "type": "MultiAgentReplayBuffer", - "underlying_replay_buffer_config": { - "type": LessSampledReplayBuffer, - "evict_sampled_more_then": 20 # We can specify the default call argument - # for the sample method of the underlying buffer method here - }, - }, -} +config = ( + DQNConfig() + .training( + replay_buffer_config={ + "type": "MultiAgentReplayBuffer", + "underlying_replay_buffer_config": { + "type": LessSampledReplayBuffer, + # We can specify the default call argument + # for the sample method of the underlying buffer method here. + "evict_sampled_more_then": 20, + }, + } + ) + .environment(env="CartPole-v1") +) tune.Tuner( "DQN", - param_space=config, + param_space=config.to_dict(), run_config=air.RunConfig( stop={"episode_reward_mean": 50, "training_iteration": 10} ), diff --git a/rllib/examples/documentation/saving_and_loading_algos_and_policies.py b/rllib/examples/documentation/saving_and_loading_algos_and_policies.py index f45422e8ac4e9..093f7721ff366 100644 --- a/rllib/examples/documentation/saving_and_loading_algos_and_policies.py +++ b/rllib/examples/documentation/saving_and_loading_algos_and_policies.py @@ -115,9 +115,9 @@ # to avoid a runtime error). Now both agents ("agent0" and "agent1") map to # the same policy. policy_mapping_fn=lambda agent_id, episode, worker, **kw: "pol1", - # Since we defined this above, we have to de-define it here with the updated + # Since we defined this above, we have to re-define it here with the updated # PolicyIDs, otherwise, RLlib will throw an error (it will think that there is an - # unknown PolicyID in this list (pol2)). + # unknown PolicyID in this list ("pol2")). policies_to_train=["pol1"], ) diff --git a/rllib/examples/hierarchical_training.py b/rllib/examples/hierarchical_training.py index 751ac38ffd9b5..d24da537ec76b 100644 --- a/rllib/examples/hierarchical_training.py +++ b/rllib/examples/hierarchical_training.py @@ -55,12 +55,17 @@ parser.add_argument( "--stop-reward", type=float, default=0.0, help="Reward at which we stop training." ) +parser.add_argument( + "--local-mode", + action="store_true", + help="Init Ray in local mode for easier debugging.", +) logger = logging.getLogger(__name__) if __name__ == "__main__": args = parser.parse_args() - ray.init() + ray.init(local_mode=args.local_mode) stop = { "training_iteration": args.stop_iters, diff --git a/rllib/execution/multi_gpu_learner_thread.py b/rllib/execution/multi_gpu_learner_thread.py index 90fe7d901c255..706c76853f98a 100644 --- a/rllib/execution/multi_gpu_learner_thread.py +++ b/rllib/execution/multi_gpu_learner_thread.py @@ -153,7 +153,10 @@ def step(self) -> None: for pid in self.policy_map.keys(): # Not a policy-to-train. - if not self.local_worker.is_policy_to_train(pid): + if ( + self.local_worker.is_policy_to_train is not None + and not self.local_worker.is_policy_to_train(pid) + ): continue policy = self.policy_map[pid] default_policy_results = policy.learn_on_loaded_batch( @@ -213,7 +216,10 @@ def _step(self) -> None: # Load the batch into the idle stack. 
with self.load_timer: for pid in policy_map.keys(): - if not s.local_worker.is_policy_to_train(pid, batch): + if ( + s.local_worker.is_policy_to_train is not None + and not s.local_worker.is_policy_to_train(pid, batch) + ): continue policy = policy_map[pid] if isinstance(batch, SampleBatch): diff --git a/rllib/execution/rollout_ops.py b/rllib/execution/rollout_ops.py index 3dad49e3242b8..573d8c9c4c529 100644 --- a/rllib/execution/rollout_ops.py +++ b/rllib/execution/rollout_ops.py @@ -369,7 +369,10 @@ def __call__(self, samples: SampleBatchType) -> SampleBatchType: { pid: batch for pid, batch in samples.policy_batches.items() - if self.local_worker.is_policy_to_train(pid, batch) + if ( + self.local_worker.is_policy_to_train is None + or self.local_worker.is_policy_to_train(pid, batch) + ) }, samples.count, ) diff --git a/rllib/execution/train_ops.py b/rllib/execution/train_ops.py index e5d9a3c1256d7..9f494a1027b9e 100644 --- a/rllib/execution/train_ops.py +++ b/rllib/execution/train_ops.py @@ -140,7 +140,10 @@ def multi_gpu_train_one_step(algorithm, train_batch) -> Dict: num_loaded_samples = {} for policy_id, batch in train_batch.policy_batches.items(): # Not a policy-to-train. - if not local_worker.is_policy_to_train(policy_id, train_batch): + if ( + local_worker.is_policy_to_train is not None + and not local_worker.is_policy_to_train(policy_id, train_batch) + ): continue # Decompress SampleBatch, in case some columns are compressed. @@ -315,7 +318,10 @@ def __call__(self, samples: SampleBatchType) -> (SampleBatchType, List[dict]): num_loaded_samples = {} for policy_id, batch in samples.policy_batches.items(): # Not a policy-to-train. - if not self.local_worker.is_policy_to_train(policy_id, samples): + if ( + self.local_worker.is_policy_to_train is not None + and not self.local_worker.is_policy_to_train(policy_id, samples) + ): continue # Decompress SampleBatch, in case some columns are compressed. 
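The repeated guard in the execution-ops changes above encodes a new convention: `local_worker.is_policy_to_train` may now be None, meaning that every policy counts as trainable. A tiny helper sketching that check; the helper name and signature are illustrative and not part of this patch:

from typing import Optional

from ray.rllib.policy.sample_batch import SampleBatch


def should_train(
    local_worker, policy_id: str, batch: Optional[SampleBatch] = None
) -> bool:
    # None means: no filter was configured -> every policy is a policy-to-train.
    if local_worker.is_policy_to_train is None:
        return True
    # Otherwise defer to the configured callable.
    return local_worker.is_policy_to_train(policy_id, batch)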
diff --git a/rllib/offline/estimators/tests/utils.py b/rllib/offline/estimators/tests/utils.py index d30b4e06f6b9c..28785ec7e8c1d 100644 --- a/rllib/offline/estimators/tests/utils.py +++ b/rllib/offline/estimators/tests/utils.py @@ -51,8 +51,8 @@ def get_cliff_walking_wall_policy_and_data( ) workers = WorkerSet( env_creator=lambda env_config: CliffWalkingWallEnv(), - policy_class=CliffWalkingWallPolicy, - trainer_config=config, + default_policy_class=CliffWalkingWallPolicy, + config=config, num_workers=4, ) ep_ret = [] diff --git a/rllib/offline/tests/test_feature_importance.py b/rllib/offline/tests/test_feature_importance.py index 14b3d60752cfb..7ad021f8a9206 100644 --- a/rllib/offline/tests/test_feature_importance.py +++ b/rllib/offline/tests/test_feature_importance.py @@ -14,8 +14,8 @@ def tearDown(self): ray.shutdown() def test_feat_importance_cartpole(self): - config = CRRConfig().framework("torch") - runner = CRR(config, env="CartPole-v0") + config = CRRConfig().environment("CartPole-v0").framework("torch") + runner = CRR(config) policy = runner.workers.local_worker().get_policy() sample_batch = synchronous_parallel_sample(worker_set=runner.workers) @@ -25,7 +25,7 @@ def test_feat_importance_cartpole(self): estimate = evaluator.estimate(sample_batch) # check if the estimate is positive - assert all([val > 0 for val in estimate.values()]) + assert all(val > 0 for val in estimate.values()) if __name__ == "__main__": diff --git a/rllib/policy/dynamic_tf_policy_v2.py b/rllib/policy/dynamic_tf_policy_v2.py index 874d0ee06ee33..847bc380755c0 100644 --- a/rllib/policy/dynamic_tf_policy_v2.py +++ b/rllib/policy/dynamic_tf_policy_v2.py @@ -61,7 +61,6 @@ def __init__( ): self.observation_space = obs_space self.action_space = action_space - config = dict(self.get_default_config(), **config) self.config = config self.framework = "tf" self._seq_lens = None @@ -141,11 +140,6 @@ def enable_eager_execution_if_necessary(): # Simply do nothing. pass - @DeveloperAPI - @OverrideToImplementCustomLogic - def get_default_config(self) -> AlgorithmConfigDict: - return {} - @DeveloperAPI @OverrideToImplementCustomLogic def validate_spaces( diff --git a/rllib/policy/eager_tf_policy_v2.py b/rllib/policy/eager_tf_policy_v2.py index 9adaac40bf375..3607a64823dec 100644 --- a/rllib/policy/eager_tf_policy_v2.py +++ b/rllib/policy/eager_tf_policy_v2.py @@ -76,9 +76,6 @@ def __init__( Policy.__init__(self, observation_space, action_space, config) - config = dict(self.get_default_config(), **config) - self.config = config - self._is_training = False # Global timestep should be a tensor. self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64) @@ -101,9 +98,9 @@ def __init__( self._loss = None self.batch_divisibility_req = self.get_batch_divisibility_req() - self._max_seq_len = config["model"]["max_seq_len"] + self._max_seq_len = self.config["model"]["max_seq_len"] - self.validate_spaces(observation_space, action_space, config) + self.validate_spaces(observation_space, action_space, self.config) # If using default make_model(), dist_class will get updated when # the model is created next. 
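The feature-importance test above also shows the config-first construction style this patch standardizes on: environment and framework are set on the config object, and the algorithm is built from that config rather than receiving an `env` argument in its constructor. A minimal sketch of the same pattern, using PG purely for illustration (the test itself uses CRR):

from ray.rllib.algorithms.pg import PGConfig

config = (
    PGConfig()
    .environment("CartPole-v0")  # env now lives on the config ...
    .framework("torch")
    .rollouts(num_rollout_workers=0)
)
algo = config.build()  # ... instead of PG(config=..., env="CartPole-v0")
policy = algo.workers.local_worker().get_policy()
algo.stop()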
@@ -144,11 +141,6 @@ def enable_eager_execution_if_necessary(): if tf1 and not tf1.executing_eagerly(): tf1.enable_eager_execution() - @DeveloperAPI - @OverrideToImplementCustomLogic - def get_default_config(self) -> AlgorithmConfigDict: - return {} - @DeveloperAPI @OverrideToImplementCustomLogic def validate_spaces( diff --git a/rllib/policy/policy_map.py b/rllib/policy/policy_map.py index 6e109ff3dbd6e..fe9bc88f219b9 100644 --- a/rllib/policy/policy_map.py +++ b/rllib/policy/policy_map.py @@ -2,24 +2,26 @@ import gym import os import threading -from typing import Callable, Dict, Optional, Set, Type +from typing import Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union import ray.cloudpickle as pickle from ray.rllib.policy.policy import Policy, PolicySpec from ray.rllib.utils.annotations import PublicAPI, override +from ray.rllib.utils.deprecation import deprecation_warning from ray.rllib.utils.framework import try_import_tf from ray.rllib.utils.policy import create_policy_for_framework from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary from ray.rllib.utils.threading import with_lock from ray.rllib.utils.typing import ( AlgorithmConfigDict, - PartialAlgorithmConfigDict, PolicyID, ) -from ray.tune.utils.util import merge_dicts tf1, tf, tfv = try_import_tf() +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + @PublicAPI class PolicyMap(dict): @@ -36,7 +38,7 @@ def __init__( num_workers: int, capacity: Optional[int] = None, path: Optional[str] = None, - policy_config: Optional[AlgorithmConfigDict] = None, + policy_config=None, # deprecated arg: All Policies now bring their own config. session_creator: Optional[Callable[[], "tf1.Session"]] = None, seed: Optional[int] = None, ): @@ -52,11 +54,16 @@ def __init__( when needed. path: The path to store the policy pickle files to. Files will have the name: [policy_id].[worker idx].policy.pkl. - policy_config: The Algorithm's base config dict. session_creator: An optional tf1.Session creation callable. seed: An optional seed (used to seed tf policies). """ + if policy_config is not None: + deprecation_warning( + old="PolicyMap(policy_config=..)", + error=True, + ) + super().__init__() self.worker_index = worker_index @@ -76,9 +83,6 @@ def __init__( self.deque = deque(maxlen=capacity or 10) # The file path where to store overflowing policies. self.path = path or "." - # The core config to use. Each single policy's config override is - # added on top of this. - self.policy_config: AlgorithmConfigDict = policy_config or {} # The orig classes/obs+act spaces, and config overrides of the # Policies. 
self.policy_specs: Dict[PolicyID, PolicySpec] = {} @@ -89,8 +93,14 @@ def __init__( self._lock = threading.RLock() def insert_policy( - self, policy_id: PolicyID, policy: Policy, config_override=None + self, policy_id: PolicyID, policy: Policy, config_override=None # deprecated ) -> None: + if config_override is not None: + deprecation_warning( + old="PolicyMap.insert_policy(config_override=..)", + error=True, + ) + self[policy_id] = policy # Store spec (class, obs-space, act-space, and config overrides) such @@ -100,7 +110,7 @@ def insert_policy( policy_class=type(policy), observation_space=policy.observation_space, action_space=policy.action_space, - config=config_override if config_override is not None else policy.config, + config=policy.config, ) def create_policy( @@ -109,8 +119,8 @@ def create_policy( policy_cls: Type["Policy"], observation_space: gym.Space, action_space: gym.Space, - config_override: PartialAlgorithmConfigDict, - merged_config: AlgorithmConfigDict, + config_override, # deprecated arg + merged_config: Union["AlgorithmConfig", AlgorithmConfigDict], ) -> None: """Creates a new policy and stores it to the cache. @@ -123,12 +133,14 @@ def create_policy( observation_space: The observation space of the policy. action_space: The action space of the policy. - config_override: The config override - dict for this policy. This is the partial dict provided by - the user. - merged_config: The entire config (merged - default config + `config_override`). + merged_config: The config object (or complete config dict) for the policy + to use. """ + if config_override is not None: + deprecation_warning( + old="PolicyMap.create_policy(config_override=..)", + error=True, + ) _class = get_tf_eager_cls_if_necessary(policy_cls, merged_config) policy = create_policy_for_framework( @@ -141,7 +153,7 @@ def create_policy( self.session_creator, self.seed, ) - self.insert_policy(policy_id, policy, config_override) + self.insert_policy(policy_id, policy) @with_lock @override(dict) @@ -279,6 +291,8 @@ def _stash_to_disk(self): def _read_from_disk(self, policy_id): """Reads a policy ID from disk and re-adds it to the cache.""" + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + # Make sure this policy ID is not in the cache right now. assert policy_id not in self.cache # Read policy state from disk. @@ -286,9 +300,9 @@ def _read_from_disk(self, policy_id): policy_state = pickle.load(f) # Get class and config override. - merged_conf = merge_dicts( - self.policy_config, self.policy_specs[policy_id].config - ) + config = self.policy_specs[policy_id].config + if isinstance(config, AlgorithmConfig): + config = config.to_dict() # Create policy object (from its spec: cls, obs-space, act-space, # config). @@ -297,8 +311,8 @@ def _read_from_disk(self, policy_id): self.policy_specs[policy_id].policy_class, self.policy_specs[policy_id].observation_space, self.policy_specs[policy_id].action_space, - self.policy_specs[policy_id].config, - merged_conf, + config_override=None, # deprecated, must be None + merged_config=config, ) # Restore policy's state. 
policy = self[policy_id] diff --git a/rllib/tests/backward_compat/checkpoints/create_checkpoints.py b/rllib/tests/backward_compat/checkpoints/create_checkpoints.py index c496dffc0e760..9197156d7b969 100644 --- a/rllib/tests/backward_compat/checkpoints/create_checkpoints.py +++ b/rllib/tests/backward_compat/checkpoints/create_checkpoints.py @@ -18,7 +18,7 @@ ) for fw in framework_iterator(config, with_eager_tracing=True): - trainer = config.build() - results = trainer.train() - trainer.save() - trainer.stop() + algo = config.build() + results = algo.train() + algo.save() + algo.stop() diff --git a/rllib/tests/test_execution.py b/rllib/tests/test_execution.py index 5b9dcaa046de0..7ca2c027300cb 100644 --- a/rllib/tests/test_execution.py +++ b/rllib/tests/test_execution.py @@ -5,6 +5,7 @@ import unittest import ray +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy from ray.rllib.evaluation.worker_set import WorkerSet from ray.rllib.evaluation.rollout_worker import RolloutWorker @@ -38,14 +39,18 @@ def iter_list(values): def make_workers(n): local = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=PPOTF1Policy, - rollout_fragment_length=100, + default_policy_class=PPOTF1Policy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=100, num_rollout_workers=0 + ), ) remotes = [ RolloutWorker.as_remote().remote( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=PPOTF1Policy, - rollout_fragment_length=100, + default_policy_class=PPOTF1Policy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=100, num_rollout_workers=0 + ), ) for _ in range(n) ] diff --git a/rllib/tests/test_perf.py b/rllib/tests/test_perf.py index e9f0f8394e0fa..9f6faf0102224 100644 --- a/rllib/tests/test_perf.py +++ b/rllib/tests/test_perf.py @@ -3,6 +3,7 @@ import unittest import ray +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.evaluation.tests.test_rollout_worker import MockPolicy @@ -23,8 +24,11 @@ def test_baseline_performance(self): for _ in range(20): ev = RolloutWorker( env_creator=lambda _: gym.make("CartPole-v0"), - policy_spec=MockPolicy, - rollout_fragment_length=100, + default_policy_class=MockPolicy, + config=AlgorithmConfig().rollouts( + rollout_fragment_length=100, + num_rollout_workers=0, + ), ) start = time.time() count = 0 diff --git a/rllib/utils/debug/memory.py b/rllib/utils/debug/memory.py index 7e991262e84f2..46e3d07d9610e 100644 --- a/rllib/utils/debug/memory.py +++ b/rllib/utils/debug/memory.py @@ -189,7 +189,7 @@ def code(): init=None, code=code, # How many times to repeat the function call? - repeats=repeats or 200, + repeats=repeats or 100, # How many times to re-try if we find a suspicious memory # allocation? max_num_trials=max_num_trials, diff --git a/rllib/utils/policy.py b/rllib/utils/policy.py index 895ec15e09387..48cb30417982e 100644 --- a/rllib/utils/policy.py +++ b/rllib/utils/policy.py @@ -83,6 +83,11 @@ def create_policy_for_framework( session_creator: An optional tf1.Session creation callable. seed: Optional random seed. """ + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + if isinstance(merged_config, AlgorithmConfig): + merged_config = merged_config.to_dict() + framework = merged_config.get("framework", "tf") # Tf. 
if framework in ["tf2", "tf", "tfe"]: diff --git a/rllib/utils/pre_checks/multi_agent.py b/rllib/utils/pre_checks/multi_agent.py deleted file mode 100644 index 576a9a95c9210..0000000000000 --- a/rllib/utils/pre_checks/multi_agent.py +++ /dev/null @@ -1,137 +0,0 @@ -import logging -from typing import Tuple - -from ray.rllib.policy.policy import PolicySpec -from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID -from ray.rllib.utils.annotations import DeveloperAPI -from ray.rllib.utils.from_config import from_config -from ray.rllib.utils.policy import validate_policy_id -from ray.rllib.utils.typing import ( - MultiAgentPolicyConfigDict, - PartialAlgorithmConfigDict, -) - -logger = logging.getLogger(__name__) - - -@DeveloperAPI -def check_multi_agent( - config: PartialAlgorithmConfigDict, -) -> Tuple[MultiAgentPolicyConfigDict, bool]: - """Checks, whether a (partial) config defines a multi-agent setup. - - Args: - config: The user/Algorithm/Policy config to check for multi-agent. - - Returns: - Tuple consisting of the resulting (all fixed) multi-agent policy - dict and bool indicating whether we have a multi-agent setup or not. - - Raises: - KeyError: If `config` does not contain a "multiagent" key or if there - is an invalid key inside the "multiagent" config or if any policy - in the "policies" dict has a non-str ID (key). - ValueError: If any subkey of the "multiagent" dict has an invalid - value. - """ - if "multiagent" not in config: - raise KeyError( - "Your `config` to be checked for a multi-agent setup must have " - "the 'multiagent' key defined!" - ) - multiagent_config = config["multiagent"] - - policies = multiagent_config.get("policies") - - # Check for invalid sub-keys of multiagent config. - from ray.rllib.algorithms.algorithm import COMMON_CONFIG - - allowed = list(COMMON_CONFIG["multiagent"].keys()) - if ( - "replay_mode" in multiagent_config - and multiagent_config["replay_mode"] == "independent" - ): - multiagent_config.pop("replay_mode") - if any(k not in allowed for k in multiagent_config.keys()): - raise KeyError( - f"You have invalid keys in your 'multiagent' config dict! " - f"The only allowed keys are: {allowed}." - ) - - # Nothing specified in config dict -> Assume simple single agent setup - # with DEFAULT_POLICY_ID as only policy. - if not policies: - policies = {DEFAULT_POLICY_ID} - # Policies given as set/list/tuple (of PolicyIDs) -> Setup each policy - # automatically via empty PolicySpec (will make RLlib infer obs- and action spaces - # as well as the Policy's class). - if isinstance(policies, (set, list, tuple)): - policies = multiagent_config["policies"] = { - pid: PolicySpec() for pid in policies - } - - # Check each defined policy ID and spec. - for pid, policy_spec in policies.copy().items(): - # Make sure our Policy ID is ok. - validate_policy_id(pid, error=False) - - # Policy IDs must be strings. - if not isinstance(pid, str): - raise KeyError(f"Policy IDs must always be of type `str`, got {type(pid)}") - # Convert to PolicySpec if plain list/tuple. - if not isinstance(policy_spec, PolicySpec): - # Values must be lists/tuples of len 4. - if not isinstance(policy_spec, (list, tuple)) or len(policy_spec) != 4: - raise ValueError( - "Policy specs must be tuples/lists of " - "(cls or None, obs_space, action_space, config), " - f"got {policy_spec}" - ) - policies[pid] = PolicySpec(*policy_spec) - - # Config is None -> Set to {}. - if policies[pid].config is None: - policies[pid].config = {} - # Config not a dict. 
- elif not isinstance(policies[pid].config, dict): - raise ValueError( - f"Multiagent policy config for {pid} must be a dict, " - f"but got {type(policies[pid].config)}!" - ) - - # Check other "multiagent" sub-keys' values. - if multiagent_config.get("count_steps_by", "env_steps") not in [ - "env_steps", - "agent_steps", - ]: - raise ValueError( - "config.multiagent.count_steps_by must be one of " - "[env_steps|agent_steps], not " - f"{multiagent_config['count_steps_by']}!" - ) - - # Attempt to create a `policy_mapping_fn` from config dict. Helpful - # is users would like to specify custom callable classes in yaml files. - if isinstance(multiagent_config.get("policy_mapping_fn"), dict): - multiagent_config["policy_mapping_fn"] = from_config( - multiagent_config["policy_mapping_fn"] - ) - # Check `policies_to_train` for invalid entries. - if isinstance(multiagent_config["policies_to_train"], (list, set, tuple)): - if len(multiagent_config["policies_to_train"]) == 0: - logger.warning( - "`config.multiagent.policies_to_train` is empty! " - "Make sure - if you would like to learn at least one policy - " - "to add its ID to that list." - ) - for pid in multiagent_config["policies_to_train"]: - if pid not in policies: - logger.warning( - "`config.multiagent.policies_to_train` contains policy " - f"ID ({pid}) that was not defined in `config.multiagent.policies!" - ) - - # Is this a multi-agent setup? True, iff DEFAULT_POLICY_ID is only - # PolicyID found in policies dict. - is_multiagent = len(policies) > 1 or DEFAULT_POLICY_ID not in policies - return policies, is_multiagent diff --git a/rllib/utils/serialization.py b/rllib/utils/serialization.py index ce2e06e7b303f..72787cfbdf4c0 100644 --- a/rllib/utils/serialization.py +++ b/rllib/utils/serialization.py @@ -156,30 +156,33 @@ def gym_space_from_dict(d: Dict) -> gym.spaces.Space: def __common(d: Dict): """Common updates to the dict before we use it to construct spaces""" - del d["space"] - if "dtype" in d: - d["dtype"] = np.dtype(d["dtype"]) - return d + ret = d.copy() + del ret["space"] + if "dtype" in ret: + ret["dtype"] = np.dtype(ret["dtype"]) + return ret def _box(d: Dict) -> gym.spaces.Box: - d.update( + ret = d.copy() + ret.update( { "low": _deserialize_ndarray(d["low"]), "high": _deserialize_ndarray(d["high"]), } ) - return gym.spaces.Box(**__common(d)) + return gym.spaces.Box(**__common(ret)) def _discrete(d: Dict) -> gym.spaces.Discrete: return gym.spaces.Discrete(**__common(d)) def _multi_discrete(d: Dict) -> gym.spaces.Discrete: - d.update( + ret = d.copy() + ret.update( { - "nvec": _deserialize_ndarray(d["nvec"]), + "nvec": _deserialize_ndarray(ret["nvec"]), } ) - return gym.spaces.MultiDiscrete(**__common(d)) + return gym.spaces.MultiDiscrete(**__common(ret)) def _tuple(d: Dict) -> gym.spaces.Discrete: spaces = [gym_space_from_dict(sp) for sp in d["spaces"]] @@ -197,8 +200,7 @@ def _repeated(d: Dict) -> Repeated: return Repeated(child_space=child_space, max_len=d["max_len"]) def _flex_dict(d: Dict) -> FlexDict: - del d["space"] - spaces = {k: gym_space_from_dict(s) for k, s in d.items()} + spaces = {k: gym_space_from_dict(s) for k, s in d.items() if k != "space"} return FlexDict(spaces=spaces) space_map = { diff --git a/rllib/utils/test_utils.py b/rllib/utils/test_utils.py index 71ec57392ea91..b9d515d0b5549 100644 --- a/rllib/utils/test_utils.py +++ b/rllib/utils/test_utils.py @@ -521,7 +521,6 @@ def check_train_results(train_results): # Import these here to avoid circular dependencies. 
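The `serialization.py` hunks above all apply the same fix: the per-space helpers used to mutate the caller's dict in place (deleting the "space" key, overwriting "low"/"high"/"nvec" with deserialized ndarrays), so passing the same dict to `gym_space_from_dict()` twice, or reusing it afterwards, could fail. They now work on shallow copies. A rough sketch of the usage this enables, assuming the `gym_space_to_dict()` counterpart in the same module:

import gym
import numpy as np
from ray.rllib.utils.serialization import gym_space_from_dict, gym_space_to_dict

space = gym.spaces.Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float32)
d = gym_space_to_dict(space)
snapshot = dict(d)  # shallow copy, kept for comparison

# Deserializing twice from the same dict is now safe: the helpers copy `d`
# instead of popping keys from it.
first = gym_space_from_dict(d)
second = gym_space_from_dict(d)

assert d == snapshot          # input dict left untouched
assert first == second == space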
from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY - from ray.rllib.utils.pre_checks.multi_agent import check_multi_agent # Assert that some keys are where we would expect them. for key in [ @@ -554,7 +553,7 @@ def check_train_results(train_results): key in train_results ), f"'{key}' not found in `train_results` ({train_results})!" - _, is_multi_agent = check_multi_agent(train_results["config"]) + is_multi_agent = train_results["config"].is_multi_agent() # Check in particular the "info" dict. info = train_results["info"] diff --git a/rllib/utils/tests/test_errors.py b/rllib/utils/tests/test_errors.py index 04a321b4a5fad..aa69a3b6abf2b 100644 --- a/rllib/utils/tests/test_errors.py +++ b/rllib/utils/tests/test_errors.py @@ -24,56 +24,56 @@ def test_no_gpus_error(self): This test will only work ok on a CPU-only machine. """ - config = impala.DEFAULT_CONFIG.copy() - env = "CartPole-v0" + config = impala.ImpalaConfig().environment("CartPole-v0") for _ in framework_iterator(config): self.assertRaisesRegex( RuntimeError, # (?s): "dot matches all" (also newlines). "(?s)Found 0 GPUs on your machine.+To change the config", - lambda: impala.Impala(config=config, env=env), + lambda: config.build(), ) def test_bad_envs(self): """Tests different "bad env" errors.""" - config = pg.PGConfig() - config.rollouts(num_rollout_workers=0) + config = ( + pg.PGConfig().rollouts(num_rollout_workers=0) + # Non existing/non-registered gym env string. + .environment("Alien-Attack-v42") + ) - # Non existing/non-registered gym env string. - env = "Alien-Attack-v42" for _ in framework_iterator(config): self.assertRaisesRegex( EnvError, - f"The env string you provided \\('{env}'\\) is", - lambda: config.build(env=env), + f"The env string you provided \\('{config.env}'\\) is", + lambda: config.build(), ) # Malformed gym env string (must have v\d at end). - env = "Alien-Attack-part-42" + config.environment("Alien-Attack-part-42") for _ in framework_iterator(config): self.assertRaisesRegex( EnvError, - f"The env string you provided \\('{env}'\\) is", - lambda: pg.PG(config=config, env=env), + f"The env string you provided \\('{config.env}'\\) is", + lambda: config.build(), ) # Non-existing class in a full-class-path. - env = "ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist" + config.environment("ray.rllib.examples.env.random_env.RandomEnvThatDoesntExist") for _ in framework_iterator(config): self.assertRaisesRegex( EnvError, - f"The env string you provided \\('{env}'\\) is", - lambda: pg.PG(config=config, env=env), + f"The env string you provided \\('{config.env}'\\) is", + lambda: config.build(), ) # Non-existing module inside a full-class-path. 
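With `check_multi_agent()` gone, the question "is this config multi-agent?" is answered by the `AlgorithmConfig` object itself, as the `test_utils.py` change above shows. A hedged sketch of the replacement pattern (using `PPOConfig` as a stand-in for any `AlgorithmConfig` subclass; the mapping-fn signature shown is the one current at the time of this patch):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.policy.policy import PolicySpec

# Single-agent: only the default policy exists.
single = PPOConfig().environment("CartPole-v0")
assert not single.is_multi_agent()

# Multi-agent: more than one policy (or a non-default policy ID) is defined.
multi = PPOConfig().multi_agent(
    policies={"p0": PolicySpec(), "p1": PolicySpec()},
    policy_mapping_fn=lambda agent_id, episode, worker, **kwargs: "p0",
)
assert multi.is_multi_agent()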
- env = "ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv" + config.environment("ray.rllib.examples.env.module_that_doesnt_exist.SomeEnv") for _ in framework_iterator(config): self.assertRaisesRegex( EnvError, - f"The env string you provided \\('{env}'\\) is", - lambda: pg.PG(config=config, env=env), + f"The env string you provided \\('{config.env}'\\) is", + lambda: config.build(), ) diff --git a/rllib/utils/tf_utils.py b/rllib/utils/tf_utils.py index d83e5eb9d6131..fda9da630192c 100644 --- a/rllib/utils/tf_utils.py +++ b/rllib/utils/tf_utils.py @@ -18,6 +18,7 @@ ) if TYPE_CHECKING: + from ray.rllib.policy.policy import Policy from ray.rllib.policy.tf_policy import TFPolicy logger = logging.getLogger(__name__) @@ -227,8 +228,8 @@ def get_placeholder( @PublicAPI def get_tf_eager_cls_if_necessary( - orig_cls: Type["TFPolicy"], config: PartialAlgorithmConfigDict -) -> Type["TFPolicy"]: + orig_cls: Type["Policy"], config: PartialAlgorithmConfigDict +) -> Type["Policy"]: """Returns the corresponding tf-eager class for a given TFPolicy class. Args:
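The `test_errors.py` migration follows the same direction as the rest of the patch: the environment is set on the config via `.environment()` and the algorithm is created with `config.build()`, instead of passing `config=` and `env=` to the Algorithm constructor; the assertions now read the env string back from `config.env`. (The `tf_utils.py` hunk, as far as it is visible here, is a typing-only change: `get_tf_eager_cls_if_necessary()` is annotated to accept and return any `Policy` subclass rather than only `TFPolicy`.) A minimal sketch of the builder pattern these tests now use:

from ray.rllib.algorithms.pg import PGConfig

config = (
    PGConfig()
    .rollouts(num_rollout_workers=0)
    .environment("CartPole-v0")
)
print(config.env)  # the env id now lives on the config object itself

algo = config.build()  # replaces the old pg.PG(config=config, env=env)
result = algo.train()
algo.stop()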