diff --git a/doc/source/rllib/doc_code/catalog_guide.py b/doc/source/rllib/doc_code/catalog_guide.py index 6a9a5ef1f0839..e399616935090 100644 --- a/doc/source/rllib/doc_code/catalog_guide.py +++ b/doc/source/rllib/doc_code/catalog_guide.py @@ -102,7 +102,7 @@ # __sphinx_doc_algo_configs_begin__ from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec class MyPPOCatalog(PPOCatalog): @@ -119,9 +119,7 @@ def __init__(self, *args, **kwargs): ) # Specify the catalog to use for the PPORLModule. -config = config.rl_module( - rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyPPOCatalog) -) +config = config.rl_module(rl_module_spec=RLModuleSpec(catalog_class=MyPPOCatalog)) # This is how RLlib constructs a PPORLModule # It will say "Hi from within PPORLModule!". ppo = config.build() diff --git a/doc/source/rllib/doc_code/rlmodule_guide.py b/doc/source/rllib/doc_code/rlmodule_guide.py index 95b2245c6c4d1..00844d48a443f 100644 --- a/doc/source/rllib/doc_code/rlmodule_guide.py +++ b/doc/source/rllib/doc_code/rlmodule_guide.py @@ -28,12 +28,12 @@ # __constructing-rlmodules-sa-begin__ import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule env = gym.make("CartPole-v1") -spec = SingleAgentRLModuleSpec( +spec = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=env.observation_space, action_space=env.action_space, @@ -46,19 +46,19 @@ # __constructing-rlmodules-ma-begin__ import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule -spec = MultiAgentRLModuleSpec( +spec = MultiRLModuleSpec( module_specs={ - "module_1": SingleAgentRLModuleSpec( + "module_1": RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=gym.spaces.Box(low=-1, high=1, shape=(10,)), action_space=gym.spaces.Discrete(2), model_config_dict={"fcnet_hiddens": [32]}, ), - "module_2": SingleAgentRLModuleSpec( + "module_2": RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=gym.spaces.Box(low=-1, high=1, shape=(5,)), action_space=gym.spaces.Discrete(2), @@ -67,13 +67,13 @@ }, ) -marl_module = spec.build() +multi_rl_module = spec.build() # __constructing-rlmodules-ma-end__ # __pass-specs-to-configs-sa-begin__ import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule from ray.rllib.core.testing.bc_algorithm import BCConfigTest @@ -84,7 +84,7 @@ .environment("CartPole-v1") .rl_module( model_config_dict={"fcnet_hiddens": [32, 32]}, - rl_module_spec=SingleAgentRLModuleSpec(module_class=DiscreteBCTorchModule), + rl_module_spec=RLModuleSpec(module_class=DiscreteBCTorchModule), ) ) @@ -94,8 +94,8 @@ # __pass-specs-to-configs-ma-begin__ import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module 
import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule from ray.rllib.core.testing.bc_algorithm import BCConfigTest from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole @@ -107,8 +107,8 @@ .environment(MultiAgentCartPole, env_config={"num_agents": 2}) .rl_module( model_config_dict={"fcnet_hiddens": [32, 32]}, - rl_module_spec=MultiAgentRLModuleSpec( - module_specs=SingleAgentRLModuleSpec(module_class=DiscreteBCTorchModule) + rl_module_spec=MultiRLModuleSpec( + module_specs=RLModuleSpec(module_class=DiscreteBCTorchModule) ), ) ) @@ -117,11 +117,11 @@ # __convert-sa-to-ma-begin__ import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule env = gym.make("CartPole-v1") -spec = SingleAgentRLModuleSpec( +spec = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=env.observation_space, action_space=env.action_space, @@ -129,7 +129,7 @@ ) module = spec.build() -marl_module = module.as_multi_agent() +multi_rl_module = module.as_multi_rl_module() # __convert-sa-to-ma-end__ @@ -279,12 +279,9 @@ def output_specs_exploration(self) -> SpecType: # __extend-spec-checking-type-specs-end__ -# __write-custom-marlmodule-shared-enc-begin__ +# __write-custom-multirlmodule-shared-enc-begin__ from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModuleConfig, - MultiAgentRLModule, -) +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleConfig, MultiRLModule import torch import torch.nn as nn @@ -325,8 +322,8 @@ def _common_forward(self, batch): return {"action_dist": torch.distributions.Categorical(logits=action_logits)} -class BCTorchMultiAgentModuleWithSharedEncoder(MultiAgentRLModule): - def __init__(self, config: MultiAgentRLModuleConfig) -> None: +class BCTorchMultiAgentModuleWithSharedEncoder(MultiRLModule): + def __init__(self, config: MultiRLModuleConfig) -> None: super().__init__(config) def setup(self): @@ -353,18 +350,18 @@ def setup(self): self._rl_modules = rl_modules -# __write-custom-marlmodule-shared-enc-end__ +# __write-custom-multirlmodule-shared-enc-end__ -# __pass-custom-marlmodule-shared-enc-begin__ +# __pass-custom-multirlmodule-shared-enc-begin__ import gymnasium as gym -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec -spec = MultiAgentRLModuleSpec( - marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, +spec = MultiRLModuleSpec( + multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, module_specs={ - "local_2d": SingleAgentRLModuleSpec( + "local_2d": RLModuleSpec( observation_space=gym.spaces.Dict( { "global": gym.spaces.Box(low=-1, high=1, shape=(2,)), @@ -374,7 +371,7 @@ def setup(self): action_space=gym.spaces.Discrete(2), model_config_dict={"fcnet_hiddens": [64]}, ), - "local_5d": SingleAgentRLModuleSpec( + "local_5d": RLModuleSpec( observation_space=gym.spaces.Dict( { "global": gym.spaces.Box(low=-1, high=1, shape=(2,)), @@ -388,7 +385,7 @@ def setup(self): ) 
module = spec.build() -# __pass-custom-marlmodule-shared-enc-end__ +# __pass-custom-multirlmodule-shared-enc-end__ # __checkpointing-begin__ @@ -398,7 +395,7 @@ def setup(self): from ray.rllib.algorithms.ppo import PPOConfig from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec config = ( PPOConfig() @@ -407,7 +404,7 @@ def setup(self): ) env = gym.make("CartPole-v1") # Create an RL Module that we would like to checkpoint -module_spec = SingleAgentRLModuleSpec( +module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, diff --git a/doc/source/rllib/key-concepts.rst b/doc/source/rllib/key-concepts.rst index c9703333de216..a9c9e56809e3f 100644 --- a/doc/source/rllib/key-concepts.rst +++ b/doc/source/rllib/key-concepts.rst @@ -122,7 +122,7 @@ implement reinforcement learning policies in RLlib and can therefore be found in where their exploration and inference logic is used to sample from an environment. The second place in RLlib where RL Modules commonly occur is the :py:class:`~ray.rllib.core.learner.learner.Learner`, where their training logic is used in training the neural network. -RL Modules extend to the multi-agent case, where a single :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` +RL Modules extend to the multi-agent case, where a single :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModule` contains multiple RL Modules. The following figure is a rough sketch of how the above can look in practice: .. image:: images/rllib-concepts-rlmodules-sketch.png diff --git a/doc/source/rllib/package_ref/algorithm.rst b/doc/source/rllib/package_ref/algorithm.rst index af69e4f8eedae..726f5e1d61f39 100644 --- a/doc/source/rllib/package_ref/algorithm.rst +++ b/doc/source/rllib/package_ref/algorithm.rst @@ -99,7 +99,7 @@ Getter methods ~AlgorithmConfig.get_default_learner_class ~AlgorithmConfig.get_default_rl_module_spec ~AlgorithmConfig.get_evaluation_config_object - ~AlgorithmConfig.get_marl_module_spec + ~AlgorithmConfig.get_multi_rl_module_spec ~AlgorithmConfig.get_multi_agent_setup ~AlgorithmConfig.get_rollout_fragment_length diff --git a/doc/source/rllib/package_ref/rl_modules.rst b/doc/source/rllib/package_ref/rl_modules.rst index f6d410aa174c2..e6147162cd572 100644 --- a/doc/source/rllib/package_ref/rl_modules.rst +++ b/doc/source/rllib/package_ref/rl_modules.rst @@ -23,9 +23,9 @@ Single Agent :nosignatures: :toctree: doc/ - SingleAgentRLModuleSpec - SingleAgentRLModuleSpec.build - SingleAgentRLModuleSpec.get_rl_module_config + RLModuleSpec + RLModuleSpec.build + RLModuleSpec.get_rl_module_config RLModule Configuration +++++++++++++++++++++++ @@ -39,18 +39,18 @@ RLModule Configuration RLModuleConfig.from_dict RLModuleConfig.get_catalog -Multi Agent -++++++++++++ +Multi RLModule (multi-agent) +++++++++++++++++++++++++++++ -.. currentmodule:: ray.rllib.core.rl_module.marl_module +.. currentmodule:: ray.rllib.core.rl_module.multi_rl_module .. 
autosummary:: :nosignatures: :toctree: doc/ - MultiAgentRLModuleSpec - MultiAgentRLModuleSpec.build - MultiAgentRLModuleSpec.get_marl_config + MultiRLModuleSpec + MultiRLModuleSpec.build + MultiRLModuleSpec.get_multi_rl_module_config @@ -68,7 +68,7 @@ Constructor :toctree: doc/ RLModule - RLModule.as_multi_agent + RLModule.as_multi_rl_module Forward methods @@ -119,7 +119,7 @@ Saving and Loading Multi Agent RL Module API ------------------------- -.. currentmodule:: ray.rllib.core.rl_module.marl_module +.. currentmodule:: ray.rllib.core.rl_module.multi_rl_module Constructor +++++++++++ @@ -128,9 +128,9 @@ Constructor :nosignatures: :toctree: doc/ - MultiAgentRLModule - MultiAgentRLModule.setup - MultiAgentRLModule.as_multi_agent + MultiRLModule + MultiRLModule.setup + MultiRLModule.as_multi_rl_module Modifying the underlying RL modules ++++++++++++++++++++++++++++++++++++ @@ -139,8 +139,8 @@ Modifying the underlying RL modules :nosignatures: :toctree: doc/ - ~MultiAgentRLModule.add_module - ~MultiAgentRLModule.remove_module + ~MultiRLModule.add_module + ~MultiRLModule.remove_module Saving and Loading ++++++++++++++++++++++ @@ -149,5 +149,5 @@ Saving and Loading :nosignatures: :toctree: doc/ - ~MultiAgentRLModule.save_state - ~MultiAgentRLModule.load_state + ~MultiRLModule.save_to_path + ~MultiRLModule.restore_from_path diff --git a/doc/source/rllib/rllib-catalogs.rst b/doc/source/rllib/rllib-catalogs.rst index 056f404dbf9f6..eaef6248706cd 100644 --- a/doc/source/rllib/rllib-catalogs.rst +++ b/doc/source/rllib/rllib-catalogs.rst @@ -146,9 +146,9 @@ Since Catalogs effectively control what ``models`` and ``distributions`` RLlib u they are also part of RLlib’s configurations. As the primary entry point for configuring RLlib, :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` is the place where you can configure the Catalogs of the RLModules that are created. -You set the ``catalog class`` by going through the :py:class:`~ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec` -or :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec` of an AlgorithmConfig. -For example, in heterogeneous multi-agent cases, you modify the MultiAgentRLModuleSpec. +You set the ``catalog class`` by going through the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModuleSpec` +or :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec` of an AlgorithmConfig. +For example, in heterogeneous multi-agent cases, you modify the MultiRLModuleSpec. .. image:: images/catalog/catalog_rlmspecs_diagram.svg :align: center diff --git a/doc/source/rllib/rllib-learner.rst b/doc/source/rllib/rllib-learner.rst index edc69dbfbeffb..d5fa5c3f6280d 100644 --- a/doc/source/rllib/rllib-learner.rst +++ b/doc/source/rllib/rllib-learner.rst @@ -115,7 +115,7 @@ and :py:class:`~ray.rllib.core.learner.learner.Learner` APIs via the :py:class:` import ray from ray.rllib.algorithms.ppo import PPOConfig - from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec + from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.learner.learner_group import LearnerGroup diff --git a/doc/source/rllib/rllib-rlmodule.rst b/doc/source/rllib/rllib-rlmodule.rst index 9f5c994bdcc96..a944a5050ab91 100644 --- a/doc/source/rllib/rllib-rlmodule.rst +++ b/doc/source/rllib/rllib-rlmodule.rst @@ -75,7 +75,7 @@ Constructing RL Modules ----------------------- The RLModule API provides a unified way to define custom reinforcement learning models in RLlib. 
This API enables you to design and implement your own models to suit specific needs. -To maintain consistency and usability, RLlib offers a standardized approach for defining module objects for both single-agent and multi-agent reinforcement learning environments. This is achieved through the :py:class:`~ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec` and :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec` classes. The built-in RLModules in RLlib follow this consistent design pattern, making it easier for you to understand and utilize these modules. +To maintain consistency and usability, RLlib offers a standardized approach for defining module objects for both single-agent and multi-agent reinforcement learning environments. This is achieved through the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModuleSpec` and :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec` classes. The built-in RLModules in RLlib follow this consistent design pattern, making it easier for you to understand and utilize these modules. .. tab-set:: @@ -122,9 +122,9 @@ You can pass RL Module specs to the algorithm configuration to be used by the al Writing Custom Single Agent RL Modules -------------------------------------- -For single-agent algorithms (e.g., PPO, DQN) or independent multi-agent algorithms (e.g., PPO-MultiAgent), use :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`. For more advanced multi-agent use cases with a shared communication between agents, extend the :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` class. +For single-agent algorithms (e.g., PPO, DQN) or independent multi-agent algorithms (e.g., PPO-MultiAgent), use :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`. For more advanced multi-agent use cases with a shared communication between agents, extend the :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModule` class. -RLlib treats single-agent modules as a special case of :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` with only one module. Create the multi-agent representation of all RLModules by calling :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.as_multi_agent`. For example: +RLlib treats single-agent modules as a special case of :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModule` with only one module. Create the multi-agent representation of all RLModules by calling :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.as_multi_rl_module`. For example: .. literalinclude:: doc_code/rlmodule_guide.py :language: python @@ -309,26 +309,26 @@ To learn more, see the `SpecType` documentation. Writing Custom Multi-Agent RL Modules (Advanced) ------------------------------------------------ -For multi-agent modules, RLlib implements :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule`, which is a dictionary of :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` objects, one for each policy, and possibly some shared modules. The base-class implementation works for most of use cases that need to define independent neural networks for sub-groups of agents. For more complex, multi-agent use cases, where the agents share some part of their neural network, you should inherit from this class and override the default implementation. 
+For multi-agent modules, RLlib implements :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModule`, which is a dictionary of :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` objects, one for each policy, and possibly some shared modules. The base-class implementation works for most use cases that need to define independent neural networks for sub-groups of agents. For more complex multi-agent use cases, where the agents share some part of their neural network, you should inherit from this class and override the default implementation. -The :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` offers an API for constructing custom models tailored to specific needs. The key method for this customization is :py:meth:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule`.build. +The :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModule` offers an API for constructing custom models tailored to specific needs. The key method for this customization is :py:meth:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModule`.build. The following example creates a custom multi-agent RL module with underlying modules. The modules share an encoder, which gets applied to the global part of the observations space. The local part passes through a separate encoder, specific to each policy. .. literalinclude:: doc_code/rlmodule_guide.py :language: python - :start-after: __write-custom-marlmodule-shared-enc-begin__ - :end-before: __write-custom-marlmodule-shared-enc-end__ + :start-after: __write-custom-multirlmodule-shared-enc-begin__ + :end-before: __write-custom-multirlmodule-shared-enc-end__ -To construct this custom multi-agent RL module, pass the class to the :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec` constructor. Also, pass the :py:class:`~ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec` for each agent because RLlib requires the observation, action spaces, and model hyper-parameters for each agent. +To construct this custom multi-agent RL module, pass the class to the :py:class:`~ray.rllib.core.rl_module.multi_rl_module.MultiRLModuleSpec` constructor. Also, pass the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModuleSpec` for each agent because RLlib requires the observation, action spaces, and model hyper-parameters for each agent. ..
literalinclude:: doc_code/rlmodule_guide.py :language: python - :start-after: __pass-custom-marlmodule-shared-enc-begin__ - :end-before: __pass-custom-marlmodule-shared-enc-end__ + :start-after: __pass-custom-multirlmodule-shared-enc-begin__ + :end-before: __pass-custom-multirlmodule-shared-enc-end__ Extending Existing RLlib RL Modules @@ -361,7 +361,7 @@ There are two possible ways to extend existing RL Modules: # Pass in the custom RL Module class to the spec algo_config = algo_config.rl_module( - rl_module_spec=SingleAgentRLModuleSpec(module_class=MyPPORLModule) + rl_module_spec=RLModuleSpec(module_class=MyPPORLModule) ) A concrete example: If you want to replace the default encoder that RLlib builds for torch, PPO and a given observation space, @@ -402,7 +402,7 @@ There are two possible ways to extend existing RL Modules: # Pass in the custom catalog class to the spec algo_config = algo_config.rl_module( - rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyAwesomeCatalog) + rl_module_spec=RLModuleSpec(catalog_class=MyAwesomeCatalog) ) diff --git a/rllib/BUILD b/rllib/BUILD index 36ad7ab1902b9..6078c6cb1f3fd 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -1384,10 +1384,10 @@ py_test( ) py_test( - name = "test_marl_module", + name = "test_multi_rl_module", tags = ["team:rllib", "core"], size = "medium", - srcs = ["core/rl_module/tests/test_marl_module.py"] + srcs = ["core/rl_module/tests/test_multi_rl_module.py"] ) py_test( diff --git a/rllib/algorithms/algorithm.py b/rllib/algorithms/algorithm.py index 3527ce43f52a6..0d5d47c20ca14 100644 --- a/rllib/algorithms/algorithm.py +++ b/rllib/algorithms/algorithm.py @@ -51,12 +51,12 @@ DEFAULT_MODULE_ID, ) from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) from ray.rllib.core.rl_module import validate_module_id -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.rllib.env.env_context import EnvContext from ray.rllib.env.env_runner import EnvRunner from ray.rllib.env.env_runner_group import EnvRunnerGroup @@ -789,27 +789,30 @@ def setup(self, config: AlgorithmConfig) -> None: local_env_runner = self.env_runner_group.local_env_runner env = spaces = None # EnvRunners have a `module` property, which stores the RLModule - # (or MARLModule, which is a subclass of RLModule, in the multi-agent case). + # (or MultiRLModule, which is a subclass of RLModule, in the multi-module + # case, e.g. for multi-agent). if ( hasattr(local_env_runner, "module") and local_env_runner.module is not None ): - marl_module_dict = dict(local_env_runner.module.as_multi_agent()) + multi_rl_module_dict = dict( + local_env_runner.module.as_multi_rl_module() + ) env = local_env_runner.env spaces = { mid: (mod.config.observation_space, mod.config.action_space) - for mid, mod in marl_module_dict.items() + for mid, mod in multi_rl_module_dict.items() } policy_dict, _ = self.config.get_multi_agent_setup( env=env, spaces=spaces ) - module_spec: MultiAgentRLModuleSpec = self.config.get_marl_module_spec( + module_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec( policy_dict=policy_dict ) # TODO (Sven): Deprecate this path: Old stack API RolloutWorkers and - # DreamerV3's EnvRunners have a `marl_module_spec` property. 
- elif hasattr(local_env_runner, "marl_module_spec"): - module_spec: MultiAgentRLModuleSpec = local_env_runner.marl_module_spec + # DreamerV3's EnvRunners have a `multi_rl_module_spec` property. + elif hasattr(local_env_runner, "multi_rl_module_spec"): + module_spec: MultiRLModuleSpec = local_env_runner.multi_rl_module_spec else: raise AttributeError( "Your local EnvRunner/RolloutWorker does NOT have any property " @@ -821,14 +824,14 @@ def setup(self, config: AlgorithmConfig) -> None: # Check if there are modules to load from the `module_spec`. rl_module_ckpt_dirs = {} - marl_module_ckpt_dir = module_spec.load_state_path + multi_rl_module_ckpt_dir = module_spec.load_state_path modules_to_load = module_spec.modules_to_load for module_id, sub_module_spec in module_spec.module_specs.items(): if sub_module_spec.load_state_path: rl_module_ckpt_dirs[module_id] = sub_module_spec.load_state_path - if marl_module_ckpt_dir or rl_module_ckpt_dirs: + if multi_rl_module_ckpt_dir or rl_module_ckpt_dirs: self.learner_group.load_module_state( - marl_module_ckpt_dir=marl_module_ckpt_dir, + multi_rl_module_ckpt_dir=multi_rl_module_ckpt_dir, modules_to_load=modules_to_load, rl_module_ckpt_dirs=rl_module_ckpt_dirs, ) @@ -1776,7 +1779,7 @@ def get_module(self, module_id: ModuleID = DEFAULT_MODULE_ID) -> RLModule: local worker's (EnvRunner's) MARLModule. """ module = self.env_runner.module - if isinstance(module, MultiAgentRLModule): + if isinstance(module, MultiRLModule): return module[module_id] else: return module @@ -1785,7 +1788,7 @@ def get_module(self, module_id: ModuleID = DEFAULT_MODULE_ID) -> RLModule: def add_module( self, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + module_spec: RLModuleSpec, *, config_overrides: Optional[Dict] = None, new_agent_to_module_mapping_fn: Optional[AgentToModuleMappingFn] = None, @@ -1793,7 +1796,7 @@ def add_module( add_to_learners: bool = True, add_to_env_runners: bool = True, add_to_eval_env_runners: bool = True, - ) -> MultiAgentRLModuleSpec: + ) -> MultiRLModuleSpec: """Adds a new (single-agent) RLModule to this Algorithm's MARLModule. Note that an Algorithm has up to 3 different components to which to add @@ -1833,7 +1836,7 @@ def add_module( validate_module_id(module_id, error=True) # The to-be-returned new MultiAgentRLModuleSpec. - marl_spec = None + multi_rl_module_spec = None if not self.config.is_multi_agent(): raise RuntimeError( @@ -1851,7 +1854,7 @@ def add_module( # Add to Learners and sync weights. if add_to_learners: - marl_spec = self.learner_group.add_module( + multi_rl_module_spec = self.learner_group.add_module( module_id=module_id, module_spec=module_spec, config_overrides=config_overrides, @@ -1869,7 +1872,7 @@ def add_module( ) if new_agent_to_module_mapping_fn is not None: self.config.multi_agent(policy_mapping_fn=new_agent_to_module_mapping_fn) - self.config.rl_module(rl_module_spec=marl_spec) + self.config.rl_module(rl_module_spec=multi_rl_module_spec) if new_should_module_be_updated is not None: self.config.multi_agent(policies_to_train=new_should_module_be_updated) self.config.freeze() @@ -1884,12 +1887,12 @@ def _add(_env_runner, _module_spec=module_spec): _env_runner.config.multi_agent( policy_mapping_fn=new_agent_to_module_mapping_fn ) - return MultiAgentRLModuleSpec.from_module(_env_runner.module) + return MultiRLModuleSpec.from_module(_env_runner.module) # Add to (training) EnvRunners and sync weights. 
if add_to_env_runners: - if marl_spec is None: - marl_spec = self.env_runner_group.foreach_worker(_add)[0] + if multi_rl_module_spec is None: + multi_rl_module_spec = self.env_runner_group.foreach_worker(_add)[0] else: self.env_runner_group.foreach_worker(_add) self.env_runner_group.sync_weights( @@ -1898,8 +1901,10 @@ def _add(_env_runner, _module_spec=module_spec): ) # Add to eval EnvRunners and sync weights. if add_to_eval_env_runners is True and self.eval_env_runner_group is not None: - if marl_spec is None: - marl_spec = self.eval_env_runner_group.foreach_worker(_add)[0] + if multi_rl_module_spec is None: + multi_rl_module_spec = self.eval_env_runner_group.foreach_worker(_add)[ + 0 + ] else: self.eval_env_runner_group.foreach_worker(_add) self.eval_env_runner_group.sync_weights( @@ -1907,7 +1912,7 @@ def _add(_env_runner, _module_spec=module_spec): inference_only=True, ) - return marl_spec + return multi_rl_module_spec @PublicAPI def remove_module( @@ -1948,11 +1953,11 @@ def remove_module( The new MultiAgentRLModuleSpec (after the RLModule has been removed). """ # The to-be-returned new MultiAgentRLModuleSpec. - marl_spec = None + multi_rl_module_spec = None # Remove RLModule from the LearnerGroup. if remove_from_learners: - marl_spec = self.learner_group.remove_module( + multi_rl_module_spec = self.learner_group.remove_module( module_id=module_id, new_should_module_be_updated=new_should_module_be_updated, ) @@ -1965,7 +1970,7 @@ def remove_module( self.config.algorithm_config_overrides_per_module.pop(module_id, None) if new_agent_to_module_mapping_fn is not None: self.config.multi_agent(policy_mapping_fn=new_agent_to_module_mapping_fn) - self.config.rl_module(rl_module_spec=marl_spec) + self.config.rl_module(rl_module_spec=multi_rl_module_spec) if new_should_module_be_updated is not None: self.config.multi_agent(policies_to_train=new_should_module_be_updated) self.config.freeze() @@ -1978,12 +1983,12 @@ def _remove(_env_runner): _env_runner.config.multi_agent( policy_mapping_fn=new_agent_to_module_mapping_fn ) - return MultiAgentRLModuleSpec.from_module(_env_runner.module) + return MultiRLModuleSpec.from_module(_env_runner.module) # Remove from (training) EnvRunners and sync weights. if remove_from_env_runners: - if marl_spec is None: - marl_spec = self.env_runner_group.foreach_worker(_remove)[0] + if multi_rl_module_spec is None: + multi_rl_module_spec = self.env_runner_group.foreach_worker(_remove)[0] else: self.env_runner_group.foreach_worker(_remove) self.env_runner_group.sync_weights( @@ -1996,8 +2001,10 @@ def _remove(_env_runner): remove_from_eval_env_runners is True and self.eval_env_runner_group is not None ): - if marl_spec is None: - marl_spec = self.eval_env_runner_group.foreach_worker(_remove)[0] + if multi_rl_module_spec is None: + multi_rl_module_spec = self.eval_env_runner_group.foreach_worker( + _remove + )[0] else: self.eval_env_runner_group.foreach_worker(_remove) self.eval_env_runner_group.sync_weights( @@ -2005,7 +2012,7 @@ def _remove(_env_runner): inference_only=True, ) - return marl_spec + return multi_rl_module_spec @OldAPIStack def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Policy: @@ -2393,7 +2400,7 @@ def add_policy( ] ] = None, evaluation_workers: bool = True, - module_spec: Optional[SingleAgentRLModuleSpec] = None, + module_spec: Optional[RLModuleSpec] = None, ) -> Optional[Policy]: """Adds a new policy to this Algorithm. 
@@ -2466,7 +2473,7 @@ def add_policy( module = policy.model self.learner_group.add_module( module_id=policy_id, - module_spec=SingleAgentRLModuleSpec.from_module(module), + module_spec=RLModuleSpec.from_module(module), ) # Update each Learner's `policies_to_train` information, but only diff --git a/rllib/algorithms/algorithm_config.py b/rllib/algorithms/algorithm_config.py index 245ede211e3e2..d104b4525429a 100644 --- a/rllib/algorithms/algorithm_config.py +++ b/rllib/algorithms/algorithm_config.py @@ -24,8 +24,8 @@ from ray.rllib.algorithms.callbacks import DefaultCallbacks from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.rl_module import validate_module_id -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.env_context import EnvContext from ray.rllib.env.multi_agent_env import MultiAgentEnv from ray.rllib.env.wrappers.atari_wrappers import is_atari @@ -63,7 +63,7 @@ PartialAlgorithmConfigDict, PolicyID, ResultDict, - RLModuleSpec, + RLModuleSpecType, SampleBatchType, ) from ray.tune.logger import Logger @@ -104,11 +104,11 @@ logger = logging.getLogger(__name__) -def _check_rl_module_spec(module_spec: RLModuleSpec) -> None: - if not isinstance(module_spec, (SingleAgentRLModuleSpec, MultiAgentRLModuleSpec)): +def _check_rl_module_spec(module_spec: RLModuleSpecType) -> None: + if not isinstance(module_spec, (RLModuleSpec, MultiRLModuleSpec)): raise ValueError( "rl_module_spec must be an instance of " - "SingleAgentRLModuleSpec or MultiAgentRLModuleSpec." + "RLModuleSpec or MultiRLModuleSpec." f"Got {type(module_spec)} instead." ) @@ -919,7 +919,7 @@ def build_env_to_module_connector(self, env): AgentToModuleMapping( module_specs=( self.rl_module_spec.module_specs - if isinstance(self.rl_module_spec, MultiAgentRLModuleSpec) + if isinstance(self.rl_module_spec, MultiRLModuleSpec) else set(self.policies) ), agent_to_module_mapping_fn=self.policy_mapping_fn, @@ -1075,7 +1075,7 @@ def build_learner_connector( AgentToModuleMapping( module_specs=( self.rl_module_spec.module_specs - if isinstance(self.rl_module_spec, MultiAgentRLModuleSpec) + if isinstance(self.rl_module_spec, MultiRLModuleSpec) else set(self.policies) ), agent_to_module_mapping_fn=self.policy_mapping_fn, @@ -1092,7 +1092,7 @@ def build_learner_group( *, env: Optional[EnvType] = None, spaces: Optional[Dict[ModuleID, Tuple[gym.Space, gym.Space]]] = None, - rl_module_spec: Optional[RLModuleSpec] = None, + rl_module_spec: Optional[RLModuleSpecType] = None, ) -> "LearnerGroup": """Builds and returns a new LearnerGroup object based on settings in `self`. @@ -1118,10 +1118,10 @@ def build_learner_group( """ from ray.rllib.core.learner.learner_group import LearnerGroup - # If `spaces` or `env` provided -> Create a MARL Module Spec first to be + # If `spaces` or `env` provided -> Create a MultiRLModuleSpec first to be # passed into the LearnerGroup constructor. if rl_module_spec is None and (env is not None or spaces is not None): - rl_module_spec = self.get_marl_module_spec(env=env, spaces=spaces) + rl_module_spec = self.get_multi_rl_module_spec(env=env, spaces=spaces) # Construct the actual LearnerGroup. 
learner_group = LearnerGroup(config=self.copy(), module_spec=rl_module_spec) @@ -1154,11 +1154,11 @@ def build_learner( Returns: The newly created (and already built) Learner object. """ - # If `spaces` or `env` provided -> Create a MARL Module Spec first to be + # If `spaces` or `env` provided -> Create a MultiRLModuleSpec first to be # passed into the LearnerGroup constructor. rl_module_spec = None if env is not None or spaces is not None: - rl_module_spec = self.get_marl_module_spec(env=env, spaces=spaces) + rl_module_spec = self.get_multi_rl_module_spec(env=env, spaces=spaces) # Construct the actual Learner object. learner = self.learner_class(config=self, module_spec=rl_module_spec) # `build()` the Learner (internal structures such as RLModule, etc..). @@ -2605,10 +2605,10 @@ def multi_agent( A mapping from ModuleIDs to per-module AlgorithmConfig override dicts, which apply certain settings, e.g. the learning rate, from the main AlgorithmConfig only to this - particular module (within a MultiAgentRLModule). + particular module (within a MultiRLModule). You can create override dicts by using the `AlgorithmConfig.overrides` utility. For example, to override your learning rate and (PPO) lambda - setting just for a single RLModule with your MultiAgentRLModule, do: + setting just for a single RLModule with your MultiRLModule, do: config.multi_agent(algorithm_config_overrides_per_module={ "module_1": PPOConfig.overrides(lr=0.0002, lambda_=0.75), }) @@ -3066,7 +3066,7 @@ def rl_module( self, *, model_config_dict: Optional[Dict[str, Any]] = NotProvided, - rl_module_spec: Optional[RLModuleSpec] = NotProvided, + rl_module_spec: Optional[RLModuleSpecType] = NotProvided, # Deprecated arg. _enable_rl_module_api=DEPRECATED_VALUE, ) -> "AlgorithmConfig": @@ -3077,7 +3077,7 @@ def rl_module( will be used for any `RLModule` if not otherwise specified in the `rl_module_spec`. rl_module_spec: The RLModule spec to use for this config. It can be either - a SingleAgentRLModuleSpec or a MultiAgentRLModuleSpec. If the + a RLModuleSpec or a MultiRLModuleSpec. If the observation_space, action_space, catalog_class, or the model config is not specified it will be inferred from the env and other parts of the algorithm config object. @@ -3166,19 +3166,18 @@ def rl_module_spec(self): _check_rl_module_spec(self._rl_module_spec) # Merge given spec with default one (in case items are missing, such as # spaces, module class, etc.) - if isinstance(self._rl_module_spec, SingleAgentRLModuleSpec): - if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec): + if isinstance(self._rl_module_spec, RLModuleSpec): + if isinstance(default_rl_module_spec, RLModuleSpec): default_rl_module_spec.update(self._rl_module_spec) return default_rl_module_spec - elif isinstance(default_rl_module_spec, MultiAgentRLModuleSpec): + elif isinstance(default_rl_module_spec, MultiRLModuleSpec): raise ValueError( - "Cannot merge MultiAgentRLModuleSpec with " - "SingleAgentRLModuleSpec!" + "Cannot merge MultiRLModuleSpec with " "RLModuleSpec!" ) else: - marl_module_spec = copy.deepcopy(self._rl_module_spec) - marl_module_spec.update(default_rl_module_spec) - return marl_module_spec + multi_rl_module_spec = copy.deepcopy(self._rl_module_spec) + multi_rl_module_spec.update(default_rl_module_spec) + return multi_rl_module_spec # `self._rl_module_spec` has not been user defined -> return default one. 
else: @@ -3661,15 +3660,15 @@ def get_torch_compile_worker_config(self): torch_dynamo_mode=self.torch_compile_worker_dynamo_mode, ) - def get_default_rl_module_spec(self) -> RLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpecType: """Returns the RLModule spec to use for this algorithm. Override this method in the sub-class to return the RLModule spec given the input framework. Returns: - The RLModuleSpec (SingleAgentRLModuleSpec or MultiAgentRLModuleSpec) to use - for this algorithm's RLModule. + The RLModuleSpec (or MultiRLModuleSpec) to + use for this algorithm's RLModule. """ raise NotImplementedError @@ -3685,19 +3684,19 @@ def get_default_learner_class(self) -> Union[Type["Learner"], str]: """ raise NotImplementedError - def get_marl_module_spec( + def get_multi_rl_module_spec( self, *, policy_dict: Optional[Dict[str, PolicySpec]] = None, - single_agent_rl_module_spec: Optional[SingleAgentRLModuleSpec] = None, + single_agent_rl_module_spec: Optional[RLModuleSpec] = None, env: Optional[EnvType] = None, spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, inference_only: bool = False, - ) -> MultiAgentRLModuleSpec: - """Returns the MultiAgentRLModule spec based on the given policy spec dict. + ) -> MultiRLModuleSpec: + """Returns the MultiRLModuleSpec based on the given policy spec dict. policy_dict could be a partial dict of the policies that we need to turn into - an equivalent multi-agent RLModule spec. + an equivalent `MultiRLModuleSpec`. Args: policy_dict: The policy spec dict. Using this dict, we can determine the @@ -3706,8 +3705,8 @@ def get_marl_module_spec( they will get auto-filled with these values obtrained from the policy spec dict. Here we are relying on the policy's logic for infering these values from other sources of information (e.g. environement) - single_agent_rl_module_spec: The SingleAgentRLModuleSpec to use for - constructing a MultiAgentRLModuleSpec. If None, the already + single_agent_rl_module_spec: The RLModuleSpec to use for + constructing a MultiRLModuleSpec. If None, the already configured spec (`self._rl_module_spec`) or the default RLModuleSpec for this algorithm (`self.get_default_rl_module_spec()`) will be used. env: An optional env instance, from which to infer the different spaces for @@ -3727,7 +3726,7 @@ def get_marl_module_spec( environment (no target or critic networks). """ # TODO (Kourosh,sven): When we replace policy entirely there will be no need for - # this function to map policy_dict to marl_module_specs anymore. The module + # this function to map policy_dict to multi_rl_module_specs anymore. The module # spec will be directly given by the user or inferred from env and spaces. if policy_dict is None: policy_dict, _ = self.get_multi_agent_setup(env=env, spaces=spaces) @@ -3742,16 +3741,16 @@ def get_marl_module_spec( current_rl_module_spec = self._rl_module_spec or default_rl_module_spec # Algorithm is currently setup as a single-agent one. - if isinstance(current_rl_module_spec, SingleAgentRLModuleSpec): + if isinstance(current_rl_module_spec, RLModuleSpec): # Use either the provided `single_agent_rl_module_spec` (a - # SingleAgentRLModuleSpec), the currently configured one of this + # RLModuleSpec), the currently configured one of this # AlgorithmConfig object, or the default one. single_agent_rl_module_spec = ( single_agent_rl_module_spec or current_rl_module_spec ) single_agent_rl_module_spec.inference_only = inference_only - # Now construct the proper MultiAgentRLModuleSpec. 
- marl_module_spec = MultiAgentRLModuleSpec( + # Now construct the proper MultiRLModuleSpec. + multi_rl_module_spec = MultiRLModuleSpec( module_specs={ k: copy.deepcopy(single_agent_rl_module_spec) for k in policy_dict.keys() @@ -3762,17 +3761,15 @@ def get_marl_module_spec( else: # The user currently has a MultiAgentSpec setup (either via # self._rl_module_spec or the default spec of this AlgorithmConfig). - assert isinstance(current_rl_module_spec, MultiAgentRLModuleSpec) + assert isinstance(current_rl_module_spec, MultiRLModuleSpec) # Default is single-agent but the user has provided a multi-agent spec # so the use-case is multi-agent. - if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec): + if isinstance(default_rl_module_spec, RLModuleSpec): # The individual (single-agent) module specs are defined by the user - # in the currently setup MultiAgentRLModuleSpec -> Use that - # SingleAgentRLModuleSpec. - if isinstance( - current_rl_module_spec.module_specs, SingleAgentRLModuleSpec - ): + # in the currently setup MultiRLModuleSpec -> Use that + # RLModuleSpec. + if isinstance(current_rl_module_spec.module_specs, RLModuleSpec): single_agent_spec = single_agent_rl_module_spec or ( current_rl_module_spec.module_specs ) @@ -3783,7 +3780,7 @@ def get_marl_module_spec( # The individual (single-agent) module specs have not been configured # via this AlgorithmConfig object -> Use provided single-agent spec or - # the the default spec (which is also a SingleAgentRLModuleSpec in this + # the the default spec (which is also a RLModuleSpec in this # case). else: single_agent_spec = ( @@ -3799,11 +3796,11 @@ def get_marl_module_spec( for k in policy_dict.keys() } - # Now construct the proper MultiAgentRLModuleSpec. + # Now construct the proper MultiRLModuleSpec. # We need to infer the multi-agent class from `current_rl_module_spec` # and fill in the module_specs dict. - marl_module_spec = current_rl_module_spec.__class__( - marl_module_class=current_rl_module_spec.marl_module_class, + multi_rl_module_spec = current_rl_module_spec.__class__( + multi_rl_module_class=current_rl_module_spec.multi_rl_module_class, module_specs=module_specs, modules_to_load=current_rl_module_spec.modules_to_load, load_state_path=current_rl_module_spec.load_state_path, @@ -3812,41 +3809,39 @@ def get_marl_module_spec( # Default is multi-agent and user wants to override it -> Don't use the # default. else: - # Use has given an override SingleAgentRLModuleSpec -> Use this to - # construct the individual RLModules within the MultiAgentRLModuleSpec. + # Use has given an override RLModuleSpec -> Use this to + # construct the individual RLModules within the MultiRLModuleSpec. if single_agent_rl_module_spec is not None: pass - # User has NOT provided an override SingleAgentRLModuleSpec. + # User has NOT provided an override RLModuleSpec. else: # But the currently setup multi-agent spec has a SingleAgentRLModule # spec defined -> Use that to construct the individual RLModules - # within the MultiAgentRLModuleSpec. - if isinstance( - current_rl_module_spec.module_specs, SingleAgentRLModuleSpec - ): + # within the MultiRLModuleSpec. 
+ if isinstance(current_rl_module_spec.module_specs, RLModuleSpec): # The individual module specs are not given, it is given as one - # SingleAgentRLModuleSpec to be re-used for all + # RLModuleSpec to be re-used for all single_agent_rl_module_spec = ( current_rl_module_spec.module_specs ) # The currently setup multi-agent spec has NO - # SingleAgentRLModuleSpec in it -> Error (there is no way we can + # RLModuleSpec in it -> Error (there is no way we can # infer this information from anywhere at this point). else: raise ValueError( - "We have a MultiAgentRLModuleSpec " + "We have a MultiRLModuleSpec " f"({current_rl_module_spec}), but no " - "`SingleAgentRLModuleSpec`s to compile the individual " + "`RLModuleSpec`s to compile the individual " "RLModules' specs! Use " - "`AlgorithmConfig.get_marl_module_spec(" + "`AlgorithmConfig.get_multi_rl_module_spec(" "policy_dict=.., single_agent_rl_module_spec=..)`." ) single_agent_rl_module_spec.inference_only = inference_only - # Now construct the proper MultiAgentRLModuleSpec. - marl_module_spec = current_rl_module_spec.__class__( - marl_module_class=current_rl_module_spec.marl_module_class, + # Now construct the proper MultiRLModuleSpec. + multi_rl_module_spec = current_rl_module_spec.__class__( + multi_rl_module_class=current_rl_module_spec.multi_rl_module_class, module_specs={ k: copy.deepcopy(single_agent_rl_module_spec) for k in policy_dict.keys() @@ -3855,12 +3850,12 @@ def get_marl_module_spec( load_state_path=current_rl_module_spec.load_state_path, ) - # Make sure that policy_dict and marl_module_spec have similar keys - if set(policy_dict.keys()) != set(marl_module_spec.module_specs.keys()): + # Make sure that policy_dict and multi_rl_module_spec have similar keys + if set(policy_dict.keys()) != set(multi_rl_module_spec.module_specs.keys()): raise ValueError( "Policy dict and module spec have different keys! \n" f"policy_dict keys: {list(policy_dict.keys())} \n" - f"module_spec keys: {list(marl_module_spec.module_specs.keys())}" + f"module_spec keys: {list(multi_rl_module_spec.module_specs.keys())}" ) # Fill in the missing values from the specs that we already have. By combining @@ -3868,20 +3863,18 @@ def get_marl_module_spec( for module_id in policy_dict: policy_spec = policy_dict[module_id] - module_spec = marl_module_spec.module_specs[module_id] + module_spec = multi_rl_module_spec.module_specs[module_id] if module_spec.module_class is None: - if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec): + if isinstance(default_rl_module_spec, RLModuleSpec): module_spec.module_class = default_rl_module_spec.module_class - elif isinstance( - default_rl_module_spec.module_specs, SingleAgentRLModuleSpec - ): + elif isinstance(default_rl_module_spec.module_specs, RLModuleSpec): module_class = default_rl_module_spec.module_specs.module_class # This should be already checked in validate() but we check it # again here just in case if module_class is None: raise ValueError( "The default rl_module spec cannot have an empty " - "module_class under its SingleAgentRLModuleSpec." + "module_class under its RLModuleSpec." ) module_spec.module_class = module_class elif module_id in default_rl_module_spec.module_specs: @@ -3896,11 +3889,9 @@ def get_marl_module_spec( "the algorithm." 
) if module_spec.catalog_class is None: - if isinstance(default_rl_module_spec, SingleAgentRLModuleSpec): + if isinstance(default_rl_module_spec, RLModuleSpec): module_spec.catalog_class = default_rl_module_spec.catalog_class - elif isinstance( - default_rl_module_spec.module_specs, SingleAgentRLModuleSpec - ): + elif isinstance(default_rl_module_spec.module_specs, RLModuleSpec): catalog_class = default_rl_module_spec.module_specs.catalog_class module_spec.catalog_class = catalog_class elif module_id in default_rl_module_spec.module_specs: @@ -3932,7 +3923,7 @@ def get_marl_module_spec( self.model_config | module_spec.model_config_dict ) - return marl_module_spec + return multi_rl_module_spec def __setattr__(self, key, value): """Gatekeeper in case we are in frozen state and need to error.""" @@ -4647,6 +4638,10 @@ def _resolve_tf_settings(self, _tf1, _tfv): "speed as with static-graph mode." ) + @Deprecated(new="AlgorithmConfig.get_multi_rl_module_spec()", error=False) + def get_marl_module_spec(self, *args, **kwargs): + return self.get_multi_rl_module_spec(*args, **kwargs) + @Deprecated(new="AlgorithmConfig.env_runners(..)", error=False) def rollouts(self, *args, **kwargs): return self.env_runners(*args, **kwargs) diff --git a/rllib/algorithms/appo/appo.py b/rllib/algorithms/appo/appo.py index 7316855754af3..570e40087f98b 100644 --- a/rllib/algorithms/appo/appo.py +++ b/rllib/algorithms/appo/appo.py @@ -15,7 +15,7 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided from ray.rllib.algorithms.impala.impala import IMPALA, IMPALAConfig -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.policy.policy import Policy from ray.rllib.utils.annotations import override from ray.rllib.utils.deprecation import DEPRECATED_VALUE, deprecation_warning @@ -249,7 +249,7 @@ def get_default_learner_class(self): ) @override(IMPALAConfig) - def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpec: if self.framework_str == "torch": from ray.rllib.algorithms.appo.torch.appo_torch_rl_module import ( APPOTorchRLModule as RLModule, @@ -266,7 +266,7 @@ def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: from ray.rllib.algorithms.appo.appo_catalog import APPOCatalog - return SingleAgentRLModuleSpec(module_class=RLModule, catalog_class=APPOCatalog) + return RLModuleSpec(module_class=RLModule, catalog_class=APPOCatalog) @property @override(AlgorithmConfig) diff --git a/rllib/algorithms/appo/appo_learner.py b/rllib/algorithms/appo/appo_learner.py index 38496107f99f6..9440dd9c33ca1 100644 --- a/rllib/algorithms/appo/appo_learner.py +++ b/rllib/algorithms/appo/appo_learner.py @@ -6,8 +6,8 @@ from ray.rllib.core.learner.learner import Learner from ray.rllib.core.learner.utils import update_target_network from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import override from ray.rllib.utils.lambda_defaultdict import LambdaDefaultDict from ray.rllib.utils.metrics import ( @@ -54,10 +54,10 @@ def add_module( self, *, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + module_spec: 
RLModuleSpec, config_overrides: Optional[Dict] = None, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: + ) -> MultiRLModuleSpec: marl_spec = super().add_module(module_id=module_id) # Create target networks for added Module, if applicable. if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI): @@ -65,7 +65,7 @@ def add_module( return marl_spec @override(IMPALALearner) - def remove_module(self, module_id: str) -> MultiAgentRLModuleSpec: + def remove_module(self, module_id: str) -> MultiRLModuleSpec: marl_spec = super().remove_module(module_id) self.curr_kl_coeffs_per_module.pop(module_id) return marl_spec diff --git a/rllib/algorithms/bc/bc.py b/rllib/algorithms/bc/bc.py index 4f2f8b73ebfbf..0fa34a1db0bc4 100644 --- a/rllib/algorithms/bc/bc.py +++ b/rllib/algorithms/bc/bc.py @@ -3,7 +3,7 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.algorithms.bc.bc_catalog import BCCatalog from ray.rllib.algorithms.marwil.marwil import MARWIL, MARWILConfig -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.utils.annotations import override from ray.rllib.utils.metrics import ( @@ -22,7 +22,7 @@ SYNCH_WORKER_WEIGHTS_TIMER, TIMERS, ) -from ray.rllib.utils.typing import RLModuleSpec, ResultDict +from ray.rllib.utils.typing import RLModuleSpecType, ResultDict if TYPE_CHECKING: from ray.rllib.core.learner import Learner @@ -85,18 +85,18 @@ def __init__(self, algo_class=None): # fmt: on @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> RLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpecType: if self.framework_str == "torch": from ray.rllib.algorithms.bc.torch.bc_torch_rl_module import BCTorchRLModule - return SingleAgentRLModuleSpec( + return RLModuleSpec( module_class=BCTorchRLModule, catalog_class=BCCatalog, ) elif self.framework_str == "tf2": from ray.rllib.algorithms.bc.tf.bc_tf_rl_module import BCTfRLModule - return SingleAgentRLModuleSpec( + return RLModuleSpec( module_class=BCTfRLModule, catalog_class=BCCatalog, ) diff --git a/rllib/algorithms/callbacks.py b/rllib/algorithms/callbacks.py index bf178bc457072..29ac52b871594 100644 --- a/rllib/algorithms/callbacks.py +++ b/rllib/algorithms/callbacks.py @@ -317,7 +317,7 @@ def on_episode_start( (within the vector of sub-environments of the BaseEnv). rl_module: The RLModule used to compute actions for stepping the env. In a single-agent setup, this is a (single-agent) RLModule, in a multi- - agent setup, this will be a MultiAgentRLModule. + agent setup, this will be a MultiRLModule. kwargs: Forward compatibility placeholder. """ pass @@ -360,7 +360,7 @@ def on_episode_step( env_index: The index of the sub-environment that has just been stepped. rl_module: The RLModule used to compute actions for stepping the env. In a single-agent setup, this is a (single-agent) RLModule, in a multi- - agent setup, this will be a MultiAgentRLModule. + agent setup, this will be a MultiRLModule. kwargs: Forward compatibility placeholder. """ pass @@ -420,7 +420,7 @@ def on_episode_end( or truncated. rl_module: The RLModule used to compute actions for stepping the env. In a single-agent setup, this is a (single-agent) RLModule, in a multi- - agent setup, this will be a MultiAgentRLModule. + agent setup, this will be a MultiRLModule. kwargs: Forward compatibility placeholder. 
""" pass diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py index 7a4265a121b68..27520ab4bacc8 100644 --- a/rllib/algorithms/dqn/dqn.py +++ b/rllib/algorithms/dqn/dqn.py @@ -19,7 +19,7 @@ from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy from ray.rllib.core.learner import Learner -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.execution.rollout_ops import ( synchronous_parallel_sample, ) @@ -67,7 +67,7 @@ ) from ray.rllib.utils.deprecation import DEPRECATED_VALUE from ray.rllib.utils.replay_buffers.utils import sample_min_n_steps_from_buffer -from ray.rllib.utils.typing import RLModuleSpec, SampleBatchType +from ray.rllib.utils.typing import RLModuleSpecType, SampleBatchType logger = logging.getLogger(__name__) @@ -490,7 +490,7 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int: return self.rollout_fragment_length @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> RLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpecType: from ray.rllib.algorithms.dqn.dqn_rainbow_catalog import DQNRainbowCatalog if self.framework_str == "torch": @@ -498,7 +498,7 @@ def get_default_rl_module_spec(self) -> RLModuleSpec: DQNRainbowTorchRLModule, ) - return SingleAgentRLModuleSpec( + return RLModuleSpec( module_class=DQNRainbowTorchRLModule, catalog_class=DQNRainbowCatalog, model_config_dict=self.model_config, diff --git a/rllib/algorithms/dqn/dqn_rainbow_learner.py b/rllib/algorithms/dqn/dqn_rainbow_learner.py index 9728eeef84de0..4b6ec4dbd40ef 100644 --- a/rllib/algorithms/dqn/dqn_rainbow_learner.py +++ b/rllib/algorithms/dqn/dqn_rainbow_learner.py @@ -4,8 +4,8 @@ from ray.rllib.core.learner.learner import Learner from ray.rllib.core.learner.utils import update_target_network from ray.rllib.core.rl_module.apis.target_network_api import TargetNetworkAPI -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import ( AddObservationsFromEpisodesToBatch, ) @@ -68,10 +68,10 @@ def add_module( self, *, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + module_spec: RLModuleSpec, config_overrides: Optional[Dict] = None, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: + ) -> MultiRLModuleSpec: marl_spec = super().add_module(module_id=module_id) # Create target networks for added Module, if applicable. 
if isinstance(self.module[module_id].unwrapped(), TargetNetworkAPI): diff --git a/rllib/algorithms/dreamerv3/README.md b/rllib/algorithms/dreamerv3/README.md index 95c2ee302e249..a92918273f64d 100644 --- a/rllib/algorithms/dreamerv3/README.md +++ b/rllib/algorithms/dreamerv3/README.md @@ -136,10 +136,10 @@ new catalog via your ``DreamerV3Config`` object as follows: ```python from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_rl_module import DreamerV3TfRLModule -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec config.rl_module( - rl_module_spec=SingleAgentRLModuleSpec( + rl_module_spec=RLModuleSpec( module_class=DreamerV3TfRLModule, catalog_class=[your DreamerV3Catalog subclass], ) diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py index 984aa994fd3f6..84e0cc6ae2522 100644 --- a/rllib/algorithms/dreamerv3/dreamerv3.py +++ b/rllib/algorithms/dreamerv3/dreamerv3.py @@ -26,7 +26,7 @@ ) from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.execution.rollout_ops import synchronous_parallel_sample from ray.rllib.policy.sample_batch import SampleBatch from ray.rllib.utils import deep_update @@ -437,13 +437,13 @@ def get_default_learner_class(self): raise ValueError(f"The framework {self.framework_str} is not supported.") @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpec: if self.framework_str == "tf2": from ray.rllib.algorithms.dreamerv3.tf.dreamerv3_tf_rl_module import ( DreamerV3TfRLModule, ) - return SingleAgentRLModuleSpec( + return RLModuleSpec( module_class=DreamerV3TfRLModule, catalog_class=DreamerV3Catalog ) else: diff --git a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py index cec88c7a2777b..f9919816ea136 100644 --- a/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py +++ b/rllib/algorithms/dreamerv3/tests/test_dreamerv3.py @@ -205,7 +205,9 @@ def test_dreamerv3_dreamer_model_sizes(self): # Create our RLModule to compute actions with. policy_dict, _ = config.get_multi_agent_setup() - module_spec = config.get_marl_module_spec(policy_dict=policy_dict) + module_spec = config.get_multi_rl_module_spec( + policy_dict=policy_dict + ) rl_module = module_spec.build()[DEFAULT_MODULE_ID] # Count the generated RLModule's parameters and compare to the diff --git a/rllib/algorithms/dreamerv3/utils/env_runner.py b/rllib/algorithms/dreamerv3/utils/env_runner.py index 43c267dbc6f04..9014a03ace3cc 100644 --- a/rllib/algorithms/dreamerv3/utils/env_runner.py +++ b/rllib/algorithms/dreamerv3/utils/env_runner.py @@ -172,7 +172,7 @@ def _entry_point(): # Create our RLModule to compute actions with. policy_dict, _ = self.config.get_multi_agent_setup(env=self.env) - self.marl_module_spec = self.config.get_marl_module_spec( + self.multi_rl_module_spec = self.config.get_multi_rl_module_spec( policy_dict=policy_dict ) if self.config.share_module_between_env_runner_and_learner: @@ -182,7 +182,7 @@ def _entry_point(): # weight-synched each iteration). else: # TODO (sven): DreamerV3 is currently single-agent only. 
- self.module = self.marl_module_spec.build()[DEFAULT_MODULE_ID] + self.module = self.multi_rl_module_spec.build()[DEFAULT_MODULE_ID] self.metrics = MetricsLogger() diff --git a/rllib/algorithms/impala/impala.py b/rllib/algorithms/impala/impala.py index 966b14d06f1e8..4bdff4a0e3f41 100644 --- a/rllib/algorithms/impala/impala.py +++ b/rllib/algorithms/impala/impala.py @@ -18,7 +18,7 @@ COMPONENT_ENV_TO_MODULE_CONNECTOR, COMPONENT_MODULE_TO_ENV_CONNECTOR, ) -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.env_runner_group import _handle_remote_call_result_errors from ray.rllib.execution.buffers.mixin_replay_buffer import MixInMultiAgentReplayBuffer from ray.rllib.execution.learner_thread import LearnerThread @@ -507,23 +507,19 @@ def get_default_learner_class(self): ) @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpec: from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog if self.framework_str == "tf2": from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule - return SingleAgentRLModuleSpec( - module_class=PPOTfRLModule, catalog_class=PPOCatalog - ) + return RLModuleSpec(module_class=PPOTfRLModule, catalog_class=PPOCatalog) elif self.framework_str == "torch": from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import ( PPOTorchRLModule, ) - return SingleAgentRLModuleSpec( - module_class=PPOTorchRLModule, catalog_class=PPOCatalog - ) + return RLModuleSpec(module_class=PPOTorchRLModule, catalog_class=PPOCatalog) else: raise ValueError( f"The framework {self.framework_str} is not supported. " diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py index 0a76daa8d0a54..a87b8461fe4c2 100644 --- a/rllib/algorithms/ppo/ppo.py +++ b/rllib/algorithms/ppo/ppo.py @@ -16,7 +16,7 @@ from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.execution.rollout_ops import ( standardize_fields, synchronous_parallel_sample, @@ -159,7 +159,7 @@ def __init__(self, algo_class=None): } @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpec: from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog if self.framework_str == "torch": @@ -167,15 +167,11 @@ def get_default_rl_module_spec(self) -> SingleAgentRLModuleSpec: PPOTorchRLModule, ) - return SingleAgentRLModuleSpec( - module_class=PPOTorchRLModule, catalog_class=PPOCatalog - ) + return RLModuleSpec(module_class=PPOTorchRLModule, catalog_class=PPOCatalog) elif self.framework_str == "tf2": from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule - return SingleAgentRLModuleSpec( - module_class=PPOTfRLModule, catalog_class=PPOCatalog - ) + return RLModuleSpec(module_class=PPOTfRLModule, catalog_class=PPOCatalog) else: raise ValueError( f"The framework {self.framework_str} is not supported. 
" diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py index 67f1925f36427..2f3bc8d11489b 100644 --- a/rllib/algorithms/sac/sac.py +++ b/rllib/algorithms/sac/sac.py @@ -5,7 +5,7 @@ from ray.rllib.algorithms.dqn.dqn import DQN from ray.rllib.algorithms.sac.sac_tf_policy import SACTFPolicy from ray.rllib.core.learner import Learner -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.policy.policy import Policy from ray.rllib.utils import deep_update from ray.rllib.utils.annotations import override @@ -14,7 +14,7 @@ deprecation_warning, ) from ray.rllib.utils.framework import try_import_tf, try_import_tfp -from ray.rllib.utils.typing import RLModuleSpec, ResultDict +from ray.rllib.utils.typing import RLModuleSpecType, ResultDict tf1, tf, tfv = try_import_tf() tfp = try_import_tfp() @@ -374,7 +374,7 @@ def get_rollout_fragment_length(self, worker_index: int = 0) -> int: return self.rollout_fragment_length @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> RLModuleSpec: + def get_default_rl_module_spec(self) -> RLModuleSpecType: from ray.rllib.algorithms.sac.sac_catalog import SACCatalog if self.framework_str == "torch": @@ -382,9 +382,7 @@ def get_default_rl_module_spec(self) -> RLModuleSpec: SACTorchRLModule, ) - return SingleAgentRLModuleSpec( - module_class=SACTorchRLModule, catalog_class=SACCatalog - ) + return RLModuleSpec(module_class=SACTorchRLModule, catalog_class=SACCatalog) else: raise ValueError( f"The framework {self.framework_str} is not supported. " "Use `torch`." diff --git a/rllib/algorithms/tests/test_algorithm.py b/rllib/algorithms/tests/test_algorithm.py index 320c469dabf36..ffe45ea858b35 100644 --- a/rllib/algorithms/tests/test_algorithm.py +++ b/rllib/algorithms/tests/test_algorithm.py @@ -12,7 +12,7 @@ from ray.rllib.algorithms.bc import BCConfig import ray.rllib.algorithms.ppo as ppo from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole from ray.rllib.examples.evaluation.evaluation_parallel_to_training import ( AssertEvalCallback, @@ -93,7 +93,7 @@ def new_mapping_fn(agent_id, episode, i=i, **kwargs): print(f"Adding new RLModule {mid} ...") new_marl_spec = algo.add_module( module_id=mid, - module_spec=SingleAgentRLModuleSpec.from_module(mod0), + module_spec=RLModuleSpec.from_module(mod0), # Test changing the mapping fn. new_agent_to_module_mapping_fn=new_mapping_fn, # Change the list of modules to train. @@ -111,11 +111,11 @@ def new_mapping_fn(agent_id, episode, i=i, **kwargs): # Assert new policy is part of local worker (eval worker set does NOT # have a local worker, only the main EnvRunnerGroup does). 
- marl_module = algo.env_runner.module + multi_rl_module = algo.env_runner.module self.assertTrue(new_module is not mod0) for j in range(i + 1): - self.assertTrue(f"p{j}" in marl_module) - self.assertTrue(len(marl_module) == i + 1) + self.assertTrue(f"p{j}" in multi_rl_module) + self.assertTrue(len(multi_rl_module) == i + 1) algo.train() checkpoint = algo.save_to_path() diff --git a/rllib/algorithms/tests/test_algorithm_config.py b/rllib/algorithms/tests/test_algorithm_config.py index 8db90c7d38074..03ec44a9aad9d 100644 --- a/rllib/algorithms/tests/test_algorithm_config.py +++ b/rllib/algorithms/tests/test_algorithm_config.py @@ -8,10 +8,10 @@ from ray.rllib.algorithms.ppo import PPO, PPOConfig from ray.rllib.algorithms.ppo.tf.ppo_tf_learner import PPOTfLearner from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec, RLModule -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModule +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) from ray.rllib.utils.test_utils import check @@ -182,7 +182,7 @@ def test_rl_module_api(self): class A: pass - config = config.rl_module(rl_module_spec=SingleAgentRLModuleSpec(A)) + config = config.rl_module(rl_module_spec=RLModuleSpec(A)) self.assertEqual(config.rl_module_spec.module_class, A) def test_config_per_module(self): @@ -240,7 +240,7 @@ def test_learner_api(self): self.assertEqual(config.learner_class, PPOTfLearner) def _assertEqualMARLSpecs(self, spec1, spec2): - self.assertEqual(spec1.marl_module_class, spec2.marl_module_class) + self.assertEqual(spec1.multi_rl_module_class, spec2.multi_rl_module_class) self.assertEqual(set(spec1.module_specs.keys()), set(spec2.module_specs.keys())) for k, module_spec1 in spec1.module_specs.items(): @@ -260,30 +260,29 @@ def _get_expected_marl_spec( config: AlgorithmConfig, expected_module_class: Type[RLModule], passed_module_class: Type[RLModule] = None, - expected_marl_module_class: Type[MultiAgentRLModule] = None, + expected_multi_rl_module_class: Type[MultiRLModule] = None, ): """This is a utility function that retrieves the expected marl specs. Args: config: The algorithm config. expected_module_class: This is the expected RLModule class that is going to - be reference in the SingleAgentRLModuleSpec parts of the - MultiAgentRLModuleSpec. + be referenced in the RLModuleSpec parts of the MultiRLModuleSpec. passed_module_class: This is the RLModule class that is passed into the - module_spec argument of get_marl_module_spec. The function is + module_spec argument of get_multi_rl_module_spec. The function is designed so that it will use the passed in module_spec for the - SingleAgentRLModuleSpec parts of the MultiAgentRLModuleSpec. - expected_marl_module_class: This is the expected MultiAgentRLModule class - that is going to be reference in the MultiAgentRLModuleSpec. + RLModuleSpec parts of the MultiRLModuleSpec. + expected_multi_rl_module_class: This is the expected MultiRLModule class + that is going to be referenced in the MultiRLModuleSpec. Returns: - Tuple of the returned MultiAgentRLModuleSpec from config. - get_marl_module_spec() and the expected MultiAgentRLModuleSpec. + Tuple of the returned MultiRLModuleSpec from config. + get_multi_rl_module_spec() and the expected MultiRLModuleSpec. 
""" from ray.rllib.policy.policy import PolicySpec - if expected_marl_module_class is None: - expected_marl_module_class = MultiAgentRLModule + if expected_multi_rl_module_class is None: + expected_multi_rl_module_class = MultiRLModule env = gym.make("CartPole-v1") policy_spec_ph = PolicySpec( @@ -292,25 +291,23 @@ def _get_expected_marl_spec( config=AlgorithmConfig(), ) - marl_spec = config.get_marl_module_spec( + marl_spec = config.get_multi_rl_module_spec( policy_dict={"p1": policy_spec_ph, "p2": policy_spec_ph}, - single_agent_rl_module_spec=SingleAgentRLModuleSpec( - module_class=passed_module_class - ) + single_agent_rl_module_spec=RLModuleSpec(module_class=passed_module_class) if passed_module_class else None, ) - expected_marl_spec = MultiAgentRLModuleSpec( - marl_module_class=expected_marl_module_class, + expected_marl_spec = MultiRLModuleSpec( + multi_rl_module_class=expected_multi_rl_module_class, module_specs={ - "p1": SingleAgentRLModuleSpec( + "p1": RLModuleSpec( module_class=expected_module_class, observation_space=env.observation_space, action_space=env.action_space, model_config_dict=AlgorithmConfig().model_config, ), - "p2": SingleAgentRLModuleSpec( + "p2": RLModuleSpec( module_class=expected_module_class, observation_space=env.observation_space, action_space=env.action_space, @@ -321,8 +318,8 @@ def _get_expected_marl_spec( return marl_spec, expected_marl_spec - def test_get_marl_module_spec(self): - """Tests whether the get_marl_module_spec() method works properly.""" + def test_get_multi_rl_module_spec(self): + """Tests whether the get_multi_rl_module_spec() method works properly.""" from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule class CustomRLModule1(DiscreteBCTorchModule): @@ -334,32 +331,30 @@ class CustomRLModule2(DiscreteBCTorchModule): class CustomRLModule3(DiscreteBCTorchModule): pass - class CustomMARLModule1(MultiAgentRLModule): + class CustomMultiRLModule1(MultiRLModule): pass ######################################## # single agent class SingleAgentAlgoConfig(AlgorithmConfig): def get_default_rl_module_spec(self): - return SingleAgentRLModuleSpec(module_class=DiscreteBCTorchModule) + return RLModuleSpec(module_class=DiscreteBCTorchModule) # multi-agent class MultiAgentAlgoConfigWithNoSingleAgentSpec(AlgorithmConfig): def get_default_rl_module_spec(self): - return MultiAgentRLModuleSpec(marl_module_class=CustomMARLModule1) + return MultiRLModuleSpec(multi_rl_module_class=CustomMultiRLModule1) class MultiAgentAlgoConfig(AlgorithmConfig): def get_default_rl_module_spec(self): - return MultiAgentRLModuleSpec( - marl_module_class=CustomMARLModule1, - module_specs=SingleAgentRLModuleSpec( - module_class=DiscreteBCTorchModule - ), + return MultiRLModuleSpec( + multi_rl_module_class=CustomMultiRLModule1, + module_specs=RLModuleSpec(module_class=DiscreteBCTorchModule), ) ######################################## - # This is the simplest case where we have to construct the marl module based on - # the default specs only. + # This is the simplest case where we have to construct the MultiRLModule based + # on the default specs only. 
config = SingleAgentAlgoConfig().api_stack(enable_rl_module_and_learner=True) spec, expected = self._get_expected_marl_spec(config, DiscreteBCTorchModule) @@ -372,16 +367,16 @@ def get_default_rl_module_spec(self): self._assertEqualMARLSpecs(spec, expected) ######################################## - # This is the case where we pass in a multi-agent RLModuleSpec that asks the + # This is the case where we pass in a `MultiRLModuleSpec` that asks the # algorithm to assign a specific type of RLModule class to certain module_ids. config = ( SingleAgentAlgoConfig() .api_stack(enable_rl_module_and_learner=True) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1), - "p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1), + "p1": RLModuleSpec(module_class=CustomRLModule1), + "p2": RLModuleSpec(module_class=CustomRLModule1), }, ), ) @@ -397,7 +392,7 @@ def get_default_rl_module_spec(self): SingleAgentAlgoConfig() .api_stack(enable_rl_module_and_learner=True) .rl_module( - rl_module_spec=SingleAgentRLModuleSpec(module_class=CustomRLModule1), + rl_module_spec=RLModuleSpec(module_class=CustomRLModule1), ) ) @@ -416,8 +411,8 @@ def get_default_rl_module_spec(self): SingleAgentAlgoConfig() .api_stack(enable_rl_module_and_learner=True) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( - module_specs=SingleAgentRLModuleSpec(module_class=CustomRLModule1) + rl_module_spec=MultiRLModuleSpec( + module_specs=RLModuleSpec(module_class=CustomRLModule1) ), ) ) @@ -433,47 +428,46 @@ def get_default_rl_module_spec(self): ######################################## # This is not only assigning a specific type of RLModule class to EACH - # module_id, but also defining a new custom MultiAgentRLModule class to be used + # module_id, but also defining a new custom MultiRLModule class to be used # in the multi-agent scenario. config = ( SingleAgentAlgoConfig() .api_stack(enable_rl_module_and_learner=True) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( - marl_module_class=CustomMARLModule1, + rl_module_spec=MultiRLModuleSpec( + multi_rl_module_class=CustomMultiRLModule1, module_specs={ - "p1": SingleAgentRLModuleSpec(module_class=CustomRLModule1), - "p2": SingleAgentRLModuleSpec(module_class=CustomRLModule1), + "p1": RLModuleSpec(module_class=CustomRLModule1), + "p2": RLModuleSpec(module_class=CustomRLModule1), }, ), ) ) spec, expected = self._get_expected_marl_spec( - config, CustomRLModule1, expected_marl_module_class=CustomMARLModule1 + config, CustomRLModule1, expected_multi_rl_module_class=CustomMultiRLModule1 ) self._assertEqualMARLSpecs(spec, expected) # This is expected to return CustomRLModule1 instead of CustomRLModule3 which # is passed in. Because the default for p1, p2 is to use CustomRLModule1. The # passed module_spec only sets a default to fall back onto in case the - # module_id is not specified in the original MultiAgentRLModuleSpec. Since P1 + # module_id is not specified in the original MultiRLModuleSpec. Since P1 # and P2 are both assigned to CustomeRLModule1, the passed module_spec will not # be used. This is the expected behavior for adding a new modules to a - # multi-agent RLModule that is not defined in the original - # MultiAgentRLModuleSpec. + # `MultiRLModule` that is not defined in the original MultiRLModuleSpec. 
spec, expected = self._get_expected_marl_spec( config, CustomRLModule1, passed_module_class=CustomRLModule3, - expected_marl_module_class=CustomMARLModule1, + expected_multi_rl_module_class=CustomMultiRLModule1, ) self._assertEqualMARLSpecs(spec, expected) ######################################## # This is the case where we ask the algorithm to use its default - # MultiAgentRLModuleSpec, but the MultiAgentRLModuleSpec has not defined its - # SingleAgentRLmoduleSpecs. + # MultiRLModuleSpec, but the MultiRLModuleSpec has not defined its + # RLModuleSpecs. config = MultiAgentAlgoConfigWithNoSingleAgentSpec().api_stack( enable_rl_module_and_learner=True ) @@ -486,12 +480,14 @@ def get_default_rl_module_spec(self): ######################################## # This is the case where we ask the algorithm to use its default - # MultiAgentRLModuleSpec, and the MultiAgentRLModuleSpec has defined its - # SingleAgentRLmoduleSpecs. + # MultiRLModuleSpec, and the MultiRLModuleSpec has defined its + # RLModuleSpecs. config = MultiAgentAlgoConfig().api_stack(enable_rl_module_and_learner=True) spec, expected = self._get_expected_marl_spec( - config, DiscreteBCTorchModule, expected_marl_module_class=CustomMARLModule1 + config, + DiscreteBCTorchModule, + expected_multi_rl_module_class=CustomMultiRLModule1, ) self._assertEqualMARLSpecs(spec, expected) @@ -499,7 +495,7 @@ def get_default_rl_module_spec(self): config, CustomRLModule1, passed_module_class=CustomRLModule1, - expected_marl_module_class=CustomMARLModule1, + expected_multi_rl_module_class=CustomMultiRLModule1, ) self._assertEqualMARLSpecs(spec, expected) diff --git a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py index 4db004048b19c..d13caa90766c9 100644 --- a/rllib/algorithms/tests/test_algorithm_rl_module_restore.py +++ b/rllib/algorithms/tests/test_algorithm_rl_module_restore.py @@ -11,10 +11,10 @@ from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.core import DEFAULT_MODULE_ID -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModuleSpec, - MultiAgentRLModule, +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModuleSpec, + MultiRLModule, ) from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole from ray.rllib.utils.test_utils import check, framework_iterator @@ -59,16 +59,16 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): ) return config - def test_e2e_load_simple_marl_module(self): - """Test if we can train a PPO algorithm with a checkpointed MARL module e2e.""" + def test_e2e_load_simple_multi_rl_module(self): + """Test if we can train a PPO algo with a checkpointed MultiRLModule e2e.""" config = self.get_ppo_config() env = MultiAgentCartPole({"num_agents": NUM_AGENTS}) for fw in framework_iterator(config, frameworks=["tf2", "torch"]): - # create a marl_module to load and save it to a checkpoint directory + # create a multi_rl_module to load and save it to a checkpoint directory module_specs = {} module_class = PPO_MODULES[fw] for i in range(NUM_AGENTS): - module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec( + module_specs[f"policy_{i}"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -78,41 
+78,41 @@ def test_e2e_load_simple_marl_module(self): | {"fcnet_hiddens": [32 * (i + 1)]}, catalog_class=PPOCatalog, ) - marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) - marl_module = marl_module_spec.build() - marl_module_weights = convert_to_numpy(marl_module.get_state()) + multi_rl_module_spec = MultiRLModuleSpec(module_specs=module_specs) + multi_rl_module = multi_rl_module_spec.build() + multi_rl_module_weights = convert_to_numpy(multi_rl_module.get_state()) marl_checkpoint_path = tempfile.mkdtemp() - marl_module.save_to_path(marl_checkpoint_path) + multi_rl_module.save_to_path(marl_checkpoint_path) # create a new MARL_spec with the checkpoint from the previous one - marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec( + multi_rl_module_spec_from_checkpoint = MultiRLModuleSpec( module_specs=module_specs, load_state_path=marl_checkpoint_path, ) config = config.api_stack(enable_rl_module_and_learner=True).rl_module( - rl_module_spec=marl_module_spec_from_checkpoint, + rl_module_spec=multi_rl_module_spec_from_checkpoint, ) # create the algorithm with multiple nodes and check if the weights - # are the same as the original MARL Module + # are the same as the original MultiRLModule algo = config.build() algo_module_weights = algo.learner_group.get_weights() - check(algo_module_weights, marl_module_weights) + check(algo_module_weights, multi_rl_module_weights) algo.train() algo.stop() del algo shutil.rmtree(marl_checkpoint_path) - def test_e2e_load_complex_marl_module(self): + def test_e2e_load_complex_multi_rl_module(self): """Test if we can train a PPO algorithm with a cpkt MARL and RL module e2e.""" config = self.get_ppo_config() env = MultiAgentCartPole({"num_agents": NUM_AGENTS}) for fw in framework_iterator(config, frameworks=["tf2", "torch"]): - # create a marl_module to load and save it to a checkpoint directory + # create a multi_rl_module to load and save it to a checkpoint directory module_specs = {} module_class = PPO_MODULES[fw] for i in range(NUM_AGENTS): - module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec( + module_specs[f"policy_{i}"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -122,13 +122,13 @@ def test_e2e_load_complex_marl_module(self): | {"fcnet_hiddens": [32 * (i + 1)]}, catalog_class=PPOCatalog, ) - marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) - marl_module = marl_module_spec.build() + multi_rl_module_spec = MultiRLModuleSpec(module_specs=module_specs) + multi_rl_module = multi_rl_module_spec.build() marl_checkpoint_path = tempfile.mkdtemp() - marl_module.save_to_path(marl_checkpoint_path) + multi_rl_module.save_to_path(marl_checkpoint_path) # create a RLModule to load and override the "policy_1" module with - module_to_swap_in = SingleAgentRLModuleSpec( + module_to_swap_in = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -143,7 +143,7 @@ def test_e2e_load_complex_marl_module(self): # create a new MARL_spec with the checkpoint from the marl_checkpoint # and the module_to_swap_in_checkpoint - module_specs["policy_1"] = SingleAgentRLModuleSpec( + module_specs["policy_1"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -151,28 +151,30 @@ def test_e2e_load_complex_marl_module(self): catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, ) - marl_module_spec_from_checkpoint = 
MultiAgentRLModuleSpec( + multi_rl_module_spec_from_checkpoint = MultiRLModuleSpec( module_specs=module_specs, load_state_path=marl_checkpoint_path, ) config = config.api_stack(enable_rl_module_and_learner=True).rl_module( - rl_module_spec=marl_module_spec_from_checkpoint, + rl_module_spec=multi_rl_module_spec_from_checkpoint, ) # create the algorithm with multiple nodes and check if the weights - # are the same as the original MARL Module + # are the same as the original MultiRLModule algo = config.build() algo_module_weights = algo.learner_group.get_weights() - marl_module_with_swapped_in_module = MultiAgentRLModule() - marl_module_with_swapped_in_module.add_module( - "policy_0", marl_module["policy_0"] + multi_rl_module_with_swapped_in_module = MultiRLModule() + multi_rl_module_with_swapped_in_module.add_module( + "policy_0", multi_rl_module["policy_0"] + ) + multi_rl_module_with_swapped_in_module.add_module( + "policy_1", module_to_swap_in ) - marl_module_with_swapped_in_module.add_module("policy_1", module_to_swap_in) check( algo_module_weights, - convert_to_numpy(marl_module_with_swapped_in_module.get_state()), + convert_to_numpy(multi_rl_module_with_swapped_in_module.get_state()), ) algo.train() algo.stop() @@ -196,9 +198,9 @@ def test_e2e_load_rl_module(self): ) env = gym.make("CartPole-v1") for fw in framework_iterator(config, frameworks=["tf2", "torch"]): - # create a marl_module to load and save it to a checkpoint directory + # create a multi_rl_module to load and save it to a checkpoint directory module_class = PPO_MODULES[fw] - module_spec = SingleAgentRLModuleSpec( + module_spec = RLModuleSpec( module_class=module_class, observation_space=env.observation_space, action_space=env.action_space, @@ -212,7 +214,7 @@ def test_e2e_load_rl_module(self): module_ckpt_path = tempfile.mkdtemp() module.save_to_path(module_ckpt_path) - module_to_load_spec = SingleAgentRLModuleSpec( + module_to_load_spec = RLModuleSpec( module_class=module_class, observation_space=env.observation_space, action_space=env.action_space, @@ -226,7 +228,7 @@ def test_e2e_load_rl_module(self): ) # create the algorithm with multiple nodes and check if the weights - # are the same as the original MARL Module + # are the same as the original MultiRLModule algo = config.build() algo_module_weights = algo.learner_group.get_weights() @@ -239,22 +241,22 @@ def test_e2e_load_rl_module(self): del algo shutil.rmtree(module_ckpt_path) - def test_e2e_load_complex_marl_module_with_modules_to_load(self): + def test_e2e_load_complex_multi_rl_module_with_modules_to_load(self): """Test if we can train a PPO algorithm with a cpkt MARL and RL module e2e. Additionally, check if we can set modules to load so that we can exclude - a module from our ckpted MARL module from being loaded. + a module from our ckpted MultiRLModule from being loaded. 
""" num_agents = 3 config = self.get_ppo_config(num_agents=num_agents) env = MultiAgentCartPole({"num_agents": num_agents}) for fw in framework_iterator(config, frameworks=["tf2", "torch"]): - # create a marl_module to load and save it to a checkpoint directory + # create a multi_rl_module to load and save it to a checkpoint directory module_specs = {} module_class = PPO_MODULES[fw] for i in range(num_agents): - module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec( + module_specs[f"policy_{i}"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -264,13 +266,13 @@ def test_e2e_load_complex_marl_module_with_modules_to_load(self): | {"fcnet_hiddens": [32 * (i + 1)]}, catalog_class=PPOCatalog, ) - marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) - marl_module = marl_module_spec.build() + multi_rl_module_spec = MultiRLModuleSpec(module_specs=module_specs) + multi_rl_module = multi_rl_module_spec.build() marl_checkpoint_path = tempfile.mkdtemp() - marl_module.save_to_path(marl_checkpoint_path) + multi_rl_module.save_to_path(marl_checkpoint_path) # create a RLModule to load and override the "policy_1" module with - module_to_swap_in = SingleAgentRLModuleSpec( + module_to_swap_in = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -285,7 +287,7 @@ def test_e2e_load_complex_marl_module_with_modules_to_load(self): # create a new MARL_spec with the checkpoint from the marl_checkpoint # and the module_to_swap_in_checkpoint - module_specs["policy_1"] = SingleAgentRLModuleSpec( + module_specs["policy_1"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space[0], action_space=env.action_space[0], @@ -293,7 +295,7 @@ def test_e2e_load_complex_marl_module_with_modules_to_load(self): catalog_class=PPOCatalog, load_state_path=module_to_swap_in_path, ) - marl_module_spec_from_checkpoint = MultiAgentRLModuleSpec( + multi_rl_module_spec_from_checkpoint = MultiRLModuleSpec( module_specs=module_specs, load_state_path=marl_checkpoint_path, modules_to_load={ @@ -301,28 +303,28 @@ def test_e2e_load_complex_marl_module_with_modules_to_load(self): }, ) config = config.api_stack(enable_rl_module_and_learner=True).rl_module( - rl_module_spec=marl_module_spec_from_checkpoint, + rl_module_spec=multi_rl_module_spec_from_checkpoint, ) # create the algorithm with multiple nodes and check if the weights - # are the same as the original MARL Module + # are the same as the original MultiRLModule algo = config.build() algo_module_weights = algo.learner_group.get_weights() - # weights of "policy_0" should be the same as in the loaded marl module + # weights of "policy_0" should be the same as in the loaded MultiRLModule # since we specified it as being apart of the modules_to_load check( algo_module_weights["policy_0"], - convert_to_numpy(marl_module["policy_0"].get_state()), + convert_to_numpy(multi_rl_module["policy_0"].get_state()), ) # weights of "policy_1" should be the same as in the module_to_swap_in since # we specified its load path separately in an rl_module_spec inside of the - # marl_module_spec_from_checkpoint + # multi_rl_module_spec_from_checkpoint check( algo_module_weights["policy_1"], convert_to_numpy(module_to_swap_in.get_state()), ) - # weights of "policy_2" should be different from the loaded marl module + # weights of "policy_2" should be different from the loaded MultiRLModule # since we didn't specify it as being 
apart of the modules_to_load policy_2_algo_module_weight_sum = np.sum( [ @@ -332,17 +334,17 @@ def test_e2e_load_complex_marl_module_with_modules_to_load(self): ) ] ) - policy_2_marl_module_weight_sum = np.sum( + policy_2_multi_rl_module_weight_sum = np.sum( [ np.sum(s) for s in tree.flatten( - convert_to_numpy(marl_module["policy_2"].get_state()) + convert_to_numpy(multi_rl_module["policy_2"].get_state()) ) ] ) check( policy_2_algo_module_weight_sum, - policy_2_marl_module_weight_sum, + policy_2_multi_rl_module_weight_sum, false=True, ) diff --git a/rllib/benchmarks/torch_compile/run_inference_bm.py b/rllib/benchmarks/torch_compile/run_inference_bm.py index fecf12219b548..d5da1f57a18ac 100644 --- a/rllib/benchmarks/torch_compile/run_inference_bm.py +++ b/rllib/benchmarks/torch_compile/run_inference_bm.py @@ -13,7 +13,7 @@ from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.benchmarks.torch_compile.utils import get_ppo_batch_for_env, timed -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.rl_module.torch.torch_rl_module import TorchCompileConfig from ray.rllib.env.wrappers.atari_wrappers import wrap_deepmind from ray.rllib.models.catalog import MODEL_DEFAULTS @@ -96,7 +96,7 @@ def main(pargs): # setup RLModule model_cfg = MODEL_DEFAULTS.copy() - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, diff --git a/rllib/connectors/common/add_states_from_episodes_to_batch.py b/rllib/connectors/common/add_states_from_episodes_to_batch.py index 49f4217d0cd08..bc9ccf12f0ca3 100644 --- a/rllib/connectors/common/add_states_from_episodes_to_batch.py +++ b/rllib/connectors/common/add_states_from_episodes_to_batch.py @@ -9,7 +9,7 @@ from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import override from ray.rllib.utils.numpy import convert_to_numpy @@ -245,7 +245,7 @@ def __call__( else: sa_module = ( rl_module[DEFAULT_MODULE_ID] - if isinstance(rl_module, MultiAgentRLModule) + if isinstance(rl_module, MultiRLModule) else rl_module ) # This single-agent RLModule is NOT stateful -> Skip. 
diff --git a/rllib/connectors/common/agent_to_module_mapping.py b/rllib/connectors/common/agent_to_module_mapping.py index b54a20bb050fb..ee6a738ee41e4 100644 --- a/rllib/connectors/common/agent_to_module_mapping.py +++ b/rllib/connectors/common/agent_to_module_mapping.py @@ -4,7 +4,7 @@ import gymnasium as gym from ray.rllib.connectors.connector_v2 import ConnectorV2 -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.rllib.env.multi_agent_episode import MultiAgentEpisode from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import EpisodeType, ModuleID @@ -114,7 +114,7 @@ def __init__( input_observation_space: Optional[gym.Space] = None, input_action_space: Optional[gym.Space] = None, *, - module_specs: Dict[ModuleID, SingleAgentRLModuleSpec], + module_specs: Dict[ModuleID, RLModuleSpec], agent_to_module_mapping_fn, ): super().__init__(input_observation_space, input_action_space) @@ -239,8 +239,8 @@ def _map_space_if_necessary(self, space, which: str = "obs"): "mapping function is stochastic (such that for some agent A, " "more than one ModuleID might be returned somewhat randomly). " f"Fix this error by providing {which}-space information using " - "`config.rl_module(rl_module_spec=MultiAgentRLModuleSpec(" - f"module_specs={{'{module_id}': SingleAgentRLModuleSpec(" + "`config.rl_module(rl_module_spec=MultiRLModuleSpec(" + f"module_specs={{'{module_id}': RLModuleSpec(" "observation_space=..., action_space=...)}}))" ) diff --git a/rllib/connectors/common/batch_individual_items.py b/rllib/connectors/common/batch_individual_items.py index b095d4d77a7ad..1a9c671eb6c82 100644 --- a/rllib/connectors/common/batch_individual_items.py +++ b/rllib/connectors/common/batch_individual_items.py @@ -5,7 +5,7 @@ from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import override from ray.rllib.utils.spaces.space_utils import batch @@ -45,7 +45,7 @@ def __call__( shared_data: Optional[dict] = None, **kwargs, ) -> Any: - is_marl_module = isinstance(rl_module, MultiAgentRLModule) + is_multi_rl_module = isinstance(rl_module, MultiRLModule) # Convert lists of individual items into properly batched data. for column, column_data in data.copy().items(): @@ -53,7 +53,7 @@ def __call__( # the AgentToModuleMapping connector has already been applied, leading # to a batch structure of: # [module_id] -> [col0] -> [list of items] - if is_marl_module and column in rl_module: + if is_multi_rl_module and column in rl_module: # Case, in which a column has already been properly batched before this # connector piece is called. 
if not self._multi_agent: @@ -98,7 +98,7 @@ def __call__( individual_items_already_have_batch_dim="auto", ) ) - if is_marl_module: + if is_multi_rl_module: if DEFAULT_MODULE_ID not in data: data[DEFAULT_MODULE_ID] = {} data[DEFAULT_MODULE_ID][column] = data.pop(column) diff --git a/rllib/connectors/common/numpy_to_tensor.py b/rllib/connectors/common/numpy_to_tensor.py index b3d2c44d5f0ab..78cb9a02ad39b 100644 --- a/rllib/connectors/common/numpy_to_tensor.py +++ b/rllib/connectors/common/numpy_to_tensor.py @@ -5,7 +5,7 @@ from ray.rllib.connectors.connector_v2 import ConnectorV2 from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule from ray.rllib.core.rl_module.rl_module import RLModule from ray.rllib.utils.annotations import override from ray.rllib.utils.torch_utils import convert_to_torch_tensor @@ -62,9 +62,9 @@ def __call__( **kwargs, ) -> Any: is_single_agent = False - is_marl_module = isinstance(rl_module, MultiAgentRLModule) + is_multi_rl_module = isinstance(rl_module, MultiRLModule) # `data` already a ModuleID to batch mapping format. - if not (is_marl_module and all(c in rl_module._rl_modules for c in data)): + if not (is_multi_rl_module and all(c in rl_module._rl_modules for c in data)): is_single_agent = True data = {DEFAULT_MODULE_ID: data} diff --git a/rllib/core/learner/learner.py b/rllib/core/learner/learner.py index d7f8116ee6491..c92a87d92bebb 100644 --- a/rllib/core/learner/learner.py +++ b/rllib/core/learner/learner.py @@ -27,11 +27,11 @@ ) from ray.rllib.core import COMPONENT_OPTIMIZER, COMPONENT_RL_MODULE, DEFAULT_MODULE_ID from ray.rllib.core.rl_module import validate_module_id -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch from ray.rllib.utils.annotations import ( @@ -132,8 +132,8 @@ class Learner(Checkpointable): If the module is a single agent module, after building the module it will be converted to a multi-agent module with a default key. Can be none if the module is provided directly via the `module` argument. Refer to - ray.rllib.core.rl_module.SingleAgentRLModuleSpec - or ray.rllib.core.rl_module.MultiAgentRLModuleSpec for more info. + ray.rllib.core.rl_module.RLModuleSpec + or ray.rllib.core.rl_module.MultiRLModuleSpec for more info. module: If learner is being used stand-alone, the RLModule can be optionally passed in directly instead of the through the `module_spec`. @@ -151,7 +151,7 @@ class Learner(Checkpointable): PPOTorchRLModule ) from ray.rllib.core import COMPONENT_RL_MODULE - from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec + from ray.rllib.core.rl_module.rl_module import RLModuleSpec env = gym.make("CartPole-v1") @@ -173,7 +173,7 @@ class Learner(Checkpointable): # Add a new module, perhaps for league based training. 
learner.add_module( module_id="new_player", - module_spec=SingleAgentRLModuleSpec( + module_spec=RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, @@ -197,10 +197,10 @@ class Learner(Checkpointable): # Set the state of the learner. learner.set_state(state) - # Get the weights of the underlying multi-agent RLModule. + # Get the weights of the underlying MultiRLModule. weights = learner.get_state(components=COMPONENT_RL_MODULE) - # Set the weights of the underlying multi-agent RLModule. + # Set the weights of the underlying MultiRLModule. learner.set_state({COMPONENT_RL_MODULE: weights}) @@ -225,9 +225,7 @@ def __init__( self, *, config: "AlgorithmConfig", - module_spec: Optional[ - Union[SingleAgentRLModuleSpec, MultiAgentRLModuleSpec] - ] = None, + module_spec: Optional[Union[RLModuleSpec, MultiRLModuleSpec]] = None, module: Optional[RLModule] = None, ): # TODO (sven): Figure out how to do this @@ -250,8 +248,8 @@ def __init__( # These are the attributes that are set during build. - # The actual MARLModule used by this Learner. - self._module: Optional[MultiAgentRLModule] = None + # The actual MultiRLModule used by this Learner. + self._module: Optional[MultiRLModule] = None # Our Learner connector pipeline. self._learner_connector: Optional[LearnerConnectorPipeline] = None # These are set for properly applying optimizers and adding or removing modules. @@ -294,7 +292,7 @@ def build(self) -> None: # TODO (sven): Figure out which space to provide here. For now, # it doesn't matter, as the default connector piece doesn't use # this information anyway. - # module_spec = self._module_spec.as_multi_agent() + # module_spec = self._module_spec.as_multi_rl_module_spec() self._learner_connector = self.config.build_learner_connector( input_observation_space=None, input_action_space=None, @@ -316,8 +314,8 @@ def distributed(self) -> bool: return self._distributed @property - def module(self) -> MultiAgentRLModule: - """The multi-agent RLModule that is being trained.""" + def module(self) -> MultiRLModule: + """The MultiRLModule that is being trained.""" return self._module def register_optimizer( @@ -428,7 +426,7 @@ def configure_optimizers_for_module( ) -> None: """Configures an optimizer for the given module_id. - This method is called for each RLModule in the Multi-Agent RLModule being + This method is called for each RLModule in the MultiRLModule being trained by the Learner, as well as any new module added during training via `self.add_module()`. It should configure and construct one or more optimizers and register them via calls to `self.register_optimizer()` along with the @@ -472,7 +470,7 @@ def postprocess_gradients(self, gradients_dict: ParamDict) -> ParamDict: algorithm specific gradient postprocessing steps. This default implementation calls `self.postprocess_gradients_for_module()` - on each of the sub-modules in our MultiAgentRLModule: `self.module` and + on each of the sub-modules in our MultiRLModule: `self.module` and returns the accumulated gradients dicts. Args: @@ -566,7 +564,7 @@ def postprocess_gradients_for_module( @OverrideToImplementCustomLogic @abc.abstractmethod def apply_gradients(self, gradients_dict: ParamDict) -> None: - """Applies the gradients to the MultiAgentRLModule parameters. + """Applies the gradients to the MultiRLModule parameters. 
Args: gradients_dict: A dictionary of gradients in the same (flat) format as @@ -696,11 +694,11 @@ def add_module( self, *, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + module_spec: RLModuleSpec, config_overrides: Optional[Dict] = None, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: - """Adds a module to the underlying MultiAgentRLModule. + ) -> MultiRLModuleSpec: + """Adds a module to the underlying MultiRLModule. Changes this Learner's config in order to make this architectural change permanent wrt. to checkpointing. @@ -718,7 +716,7 @@ def add_module( returns False) will not be updated. Returns: - The new MultiAgentRLModuleSpec (after the RLModule has been added). + The new MultiRLModuleSpec (after the RLModule has been added). """ validate_module_id(module_id, error=True) self._check_is_built() @@ -739,9 +737,7 @@ def add_module( self.config.multi_agent( algorithm_config_overrides_per_module={module_id: config_overrides} ) - self.config.rl_module( - rl_module_spec=MultiAgentRLModuleSpec.from_module(self.module) - ) + self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module)) if new_should_module_be_updated is not None: self.config.multi_agent(policies_to_train=new_should_module_be_updated) @@ -758,7 +754,7 @@ def remove_module( module_id: ModuleID, *, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: + ) -> MultiRLModuleSpec: """Removes a module from the Learner. Args: @@ -771,7 +767,7 @@ def remove_module( returns False) will not be updated. Returns: - The new MultiAgentRLModuleSpec (after the RLModule has been removed). + The new MultiRLModuleSpec (after the RLModule has been removed). """ self._check_is_built() module = self.module[module_id] @@ -791,7 +787,7 @@ def remove_module( del self._optimizer_lr_schedules[optimizer] del self._module_optimizers[module_id] - # Remove the module from the MARLModule. + # Remove the module from the MultiRLModule. self.module.remove_module(module_id) # Change self.config to reflect the new architecture. @@ -801,9 +797,7 @@ def remove_module( self.config.algorithm_config_overrides_per_module.pop(module_id, None) if new_should_module_be_updated is not None: self.config.multi_agent(policies_to_train=new_should_module_be_updated) - self.config.rl_module( - rl_module_spec=MultiAgentRLModuleSpec.from_module(self.module) - ) + self.config.rl_module(rl_module_spec=MultiRLModuleSpec.from_module(self.module)) # Remove all stats from the module from our metrics logger, so we don't report # results from this module again. @@ -846,7 +840,7 @@ def compute_loss( specify the specific loss computation logic. If the algorithm is single agent `compute_loss_for_module()` should be overridden instead. `fwd_out` is the output of the `forward_train()` method of the underlying - MultiAgentRLModule. `batch` is the data that was used to compute `fwd_out`. + MultiRLModule. `batch` is the data that was used to compute `fwd_out`. The returned dictionary must contain a key called ALL_MODULES, which will be used to compute gradients. It is recommended to not compute any forward passes within this method, and to use the @@ -936,7 +930,7 @@ def update_from_batch( You can use this method to take more than one backward pass on the batch. The same `minibatch_size` and `num_iters` will be used for all module ids in - MultiAgentRLModule. + MultiRLModule. Args: batch: A batch of training data to update from. 
@@ -989,7 +983,7 @@ def update_from_episodes( You can use this method to take more than one backward pass on the batch. The same `minibatch_size` and `num_iters` will be used for all module ids in - MultiAgentRLModule. + MultiRLModule. Args: episodes: An list of episode objects to update from. @@ -1472,7 +1466,7 @@ def _is_module_compatible_with_learner(self, module: RLModule) -> bool: True if the module is compatible with the learner. """ - def _make_module(self) -> MultiAgentRLModule: + def _make_module(self) -> MultiRLModule: """Construct the multi-agent RL module for the learner. This method uses `self._module_specs` or `self._module_obj` to construct the @@ -1481,7 +1475,7 @@ def _make_module(self) -> MultiAgentRLModule: need to happen for instantiation of the module. Returns: - A constructed MultiAgentRLModule. + A constructed MultiRLModule. """ # Module was provided directly through constructor -> Use as-is. if self._module_obj is not None: @@ -1495,8 +1489,8 @@ def _make_module(self) -> MultiAgentRLModule: else: module = self.config.get_multi_agent_module_spec().build() - # If not already, convert to MultiAgentRLModule. - module = module.as_multi_agent() + # If not already, convert to MultiRLModule. + module = module.as_multi_rl_module() return module diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index 5b36e71f07d2c..525e15a081c05 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -23,8 +23,8 @@ from ray.rllib.core import COMPONENT_LEARNER, COMPONENT_RL_MODULE from ray.rllib.core.learner.learner import Learner from ray.rllib.core.rl_module import validate_module_id -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.multi_agent_episode import MultiAgentEpisode from ray.rllib.policy.policy import PolicySpec from ray.rllib.policy.sample_batch import MultiAgentBatch @@ -49,7 +49,7 @@ from ray.rllib.utils.typing import ( EpisodeType, ModuleID, - RLModuleSpec, + RLModuleSpecType, ShouldModuleBeUpdatedFn, StateDict, T, @@ -91,7 +91,8 @@ def __init__( self, *, config: "AlgorithmConfig", - module_spec: Optional[RLModuleSpec] = None, + # TODO (sven): Rename into `rl_module_spec`. + module_spec: Optional[RLModuleSpecType] = None, ): """Initializes a LearnerGroup instance. @@ -113,7 +114,7 @@ def __init__( self._module_spec = module_spec learner_class = self.config.learner_class - module_spec = module_spec or self.config.get_marl_module_spec() + module_spec = module_spec or self.config.get_multi_rl_module_spec() self._learner = None self._workers = None @@ -672,11 +673,11 @@ def add_module( self, *, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + module_spec: RLModuleSpec, config_overrides: Optional[Dict] = None, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: - """Adds a module to the underlying MultiAgentRLModule. + ) -> MultiRLModuleSpec: + """Adds a module to the underlying MultiRLModule. Changes this Learner's config in order to make this architectural change permanent wrt. to checkpointing. @@ -694,7 +695,7 @@ def add_module( returns False) will not be updated. Returns: - The new MultiAgentRLModuleSpec (after the change has been performed). 
+ The new MultiRLModuleSpec (after the change has been performed). """ validate_module_id(module_id, error=True) @@ -731,7 +732,7 @@ def remove_module( module_id: ModuleID, *, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: + ) -> MultiRLModuleSpec: """Removes a module from the Learner. Args: @@ -744,7 +745,7 @@ def remove_module( returns False) will not be updated. Returns: - The new MultiAgentRLModuleSpec (after the change has been performed). + The new MultiRLModuleSpec (after the change has been performed). """ # Remove all stats from the module from our metrics logger (hybrid API stack # only), so we don't report results from this module again. @@ -848,7 +849,7 @@ def set_weights(self, weights) -> None: """Convenience method instead of self.set_state({'learner': {'rl_module': ..}}). Args: - weights: The weights dict of the MARLModule of a Learner inside this + weights: The weights dict of the MultiRLModule of a Learner inside this LearnerGroup. """ self.set_state({COMPONENT_LEARNER: {COMPONENT_RL_MODULE: weights}}) @@ -982,59 +983,59 @@ def load_state(self, *args, **kwargs): def load_module_state( self, *, - marl_module_ckpt_dir: Optional[str] = None, + multi_rl_module_ckpt_dir: Optional[str] = None, modules_to_load: Optional[Set[str]] = None, rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None, ) -> None: """Load the checkpoints of the modules being trained by this LearnerGroup. `load_module_state` can be used 3 ways: - 1. Load a checkpoint for the MultiAgentRLModule being trained by this + 1. Load a checkpoint for the MultiRLModule being trained by this LearnerGroup. Limit the modules that are loaded from the checkpoint by specifying the `modules_to_load` argument. 2. Load the checkpoint(s) for single agent RLModules that - are in the MultiAgentRLModule being trained by this LearnerGroup. - 3. Load a checkpoint for the MultiAgentRLModule being trained by this + are in the MultiRLModule being trained by this LearnerGroup. + 3. Load a checkpoint for the MultiRLModule being trained by this LearnerGroup and load the checkpoint(s) for single agent RLModules - that are in the MultiAgentRLModule. The checkpoints for the single + that are in the MultiRLModule. The checkpoints for the single agent RLModules take precedence over the module states in the - MultiAgentRLModule checkpoint. + MultiRLModule checkpoint. - NOTE: At lease one of marl_module_ckpt_dir or rl_module_ckpt_dirs is + NOTE: At lease one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs is must be specified. modules_to_load can only be specified if - marl_module_ckpt_dir is specified. + multi_rl_module_ckpt_dir is specified. Args: - marl_module_ckpt_dir: The path to the checkpoint for the - MultiAgentRLModule. + multi_rl_module_ckpt_dir: The path to the checkpoint for the + MultiRLModule. modules_to_load: A set of module ids to load from the checkpoint. rl_module_ckpt_dirs: A mapping from module ids to the path to a checkpoint for a single agent RLModule. """ - if not (marl_module_ckpt_dir or rl_module_ckpt_dirs): + if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs): raise ValueError( - "At least one of `marl_module_ckpt_dir` or " + "At least one of `multi_rl_module_ckpt_dir` or " "`rl_module_ckpt_dirs` must be provided!" 
) - if marl_module_ckpt_dir: - marl_module_ckpt_dir = pathlib.Path(marl_module_ckpt_dir) + if multi_rl_module_ckpt_dir: + multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir) if rl_module_ckpt_dirs: for module_id, path in rl_module_ckpt_dirs.items(): rl_module_ckpt_dirs[module_id] = pathlib.Path(path) - # MARLModule checkpoint is provided. - if marl_module_ckpt_dir: - # Restore the entire MARLModule state. + # MultiRLModule checkpoint is provided. + if multi_rl_module_ckpt_dir: + # Restore the entire MultiRLModule state. if modules_to_load is None: self.restore_from_path( - marl_module_ckpt_dir, + multi_rl_module_ckpt_dir, component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE, ) # Restore individual module IDs. else: for module_id in modules_to_load: self.restore_from_path( - marl_module_ckpt_dir / module_id, + multi_rl_module_ckpt_dir / module_id, component=( COMPONENT_LEARNER + "/" diff --git a/rllib/core/learner/tests/test_learner_group.py b/rllib/core/learner/tests/test_learner_group.py index 73f3e3354237e..430c26c11b2b2 100644 --- a/rllib/core/learner/tests/test_learner_group.py +++ b/rllib/core/learner/tests/test_learner_group.py @@ -16,8 +16,8 @@ DEFAULT_MODULE_ID, ) from ray.rllib.core.learner.learner import Learner -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule from ray.rllib.core.testing.utils import ( @@ -204,7 +204,7 @@ def test_learner_group_build_from_algorithm_config(self): BaseTestingAlgorithmConfig() .training(learner_class=BCTorchLearner) .rl_module( - rl_module_spec=SingleAgentRLModuleSpec( + rl_module_spec=RLModuleSpec( module_class=DiscreteBCTorchModule, model_config_dict={"fcnet_hiddens": [32]}, ) @@ -340,8 +340,8 @@ def setUpClass(cls) -> None: def tearDownClass(cls) -> None: ray.shutdown() - def test_restore_from_path_marl_module_and_individual_modules(self): - """Tests whether MARLModule- and single RLModule states can be restored.""" + def test_restore_from_path_multi_rl_module_and_individual_modules(self): + """Tests whether MultiRLModule- and single RLModule states can be restored.""" fws = ["torch", "tf2"] # this is expanded to more scaling modes on the release ci. scaling_modes = ["local-cpu", "multi-gpu-ddp"] @@ -357,28 +357,30 @@ def test_restore_from_path_marl_module_and_individual_modules(self): ) config = BaseTestingAlgorithmConfig().update_from_dict(config_overrides) learner_group = config.build_learner_group(env=env) - spec = config.get_marl_module_spec(env=env).module_specs[DEFAULT_MODULE_ID] + spec = config.get_multi_rl_module_spec(env=env).module_specs[ + DEFAULT_MODULE_ID + ] learner_group.add_module(module_id="0", module_spec=spec) learner_group.add_module(module_id="1", module_spec=spec) learner_group.remove_module(DEFAULT_MODULE_ID) module_0 = spec.build() module_1 = spec.build() - marl_module = MultiAgentRLModule() - marl_module.add_module(module_id="0", module=module_0) - marl_module.add_module(module_id="1", module=module_1) + multi_rl_module = MultiRLModule() + multi_rl_module.add_module(module_id="0", module=module_0) + multi_rl_module.add_module(module_id="1", module=module_1) - # Check if we can load just the MARL Module. 
+ # Check if we can load just the MultiRLModule. with tempfile.TemporaryDirectory() as tmpdir: - marl_module.save_to_path(tmpdir) + multi_rl_module.save_to_path(tmpdir) old_learner_weights = learner_group.get_weights() learner_group.restore_from_path( tmpdir, component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE, ) # Check the weights of the module in the learner group are the - # same as the weights of the newly created marl module - check(learner_group.get_weights(), marl_module.get_state()) + # same as the weights of the newly created MultiRLModule + check(learner_group.get_weights(), multi_rl_module.get_state()) learner_group.set_state( { COMPONENT_LEARNER: {COMPONENT_RL_MODULE: old_learner_weights}, @@ -403,22 +405,22 @@ def test_restore_from_path_marl_module_and_individual_modules(self): component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/1", ) # check the weights of the module in the learner group are the - # same as the weights of the newly created marl module - new_marl_module = MultiAgentRLModule() - new_marl_module.add_module(module_id="0", module=module_0) - new_marl_module.add_module(module_id="1", module=temp_module) - check(learner_group.get_weights(), new_marl_module.get_state()) + # same as the weights of the newly created MultiRLModule + new_multi_rl_module = MultiRLModule() + new_multi_rl_module.add_module(module_id="0", module=module_0) + new_multi_rl_module.add_module(module_id="1", module=temp_module) + check(learner_group.get_weights(), new_multi_rl_module.get_state()) learner_group.set_weights(old_learner_weights) - # Check if we can first load a MARLModule, then a single agent RLModule - # (within that MARLModule). Check that the single agent RL Module is loaded - # over the matching submodule in the MARL Module. + # Check if we can first load a MultiRLModule, then a single agent RLModule + # (within that MultiRLModule). Check that the single agent RL Module is + # loaded over the matching submodule in the MultiRLModule. 
with tempfile.TemporaryDirectory() as tmpdir: module_0 = spec.build() - marl_module = MultiAgentRLModule() - marl_module.add_module(module_id="0", module=module_0) - marl_module.add_module(module_id="1", module=spec.build()) - marl_module.save_to_path(tmpdir) + multi_rl_module = MultiRLModule() + multi_rl_module.add_module(module_id="0", module=module_0) + multi_rl_module.add_module(module_id="1", module=spec.build()) + multi_rl_module.save_to_path(tmpdir) with tempfile.TemporaryDirectory() as tmpdir2: module_1 = spec.build() module_1.save_to_path(tmpdir2) @@ -430,10 +432,10 @@ def test_restore_from_path_marl_module_and_individual_modules(self): tmpdir2, component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/1", ) - new_marl_module = MultiAgentRLModule() - new_marl_module.add_module(module_id="0", module=module_0) - new_marl_module.add_module(module_id="1", module=module_1) - check(learner_group.get_weights(), new_marl_module.get_state()) + new_multi_rl_module = MultiRLModule() + new_multi_rl_module.add_module(module_id="0", module=module_0) + new_multi_rl_module.add_module(module_id="1", module=module_1) + check(learner_group.get_weights(), new_multi_rl_module.get_state()) del learner_group diff --git a/rllib/core/learner/tf/tf_learner.py b/rllib/core/learner/tf/tf_learner.py index 65716349c752e..07caf7419f2cb 100644 --- a/rllib/core/learner/tf/tf_learner.py +++ b/rllib/core/learner/tf/tf_learner.py @@ -12,10 +12,10 @@ ) from ray.rllib.core.learner.learner import Learner -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.rl_module.rl_module import ( RLModule, - SingleAgentRLModuleSpec, + RLModuleSpec, ) from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.policy.eager_tf_policy import _convert_to_tf @@ -122,8 +122,8 @@ def apply_gradients(self, gradients_dict: ParamDict) -> None: @override(Learner) def restore_from_path(self, path: Union[str, pathlib.Path]) -> None: - # This operation is potentially very costly because a MARL Module is created at - # build time, destroyed, and then a new one is created from a checkpoint. + # This operation is potentially very costly because a MultiRLModule is created + # at build time, destroyed, and then a new one is created from a checkpoint. # However, it is necessary due to complications with the way that Ray Tune # restores failed trials. When Tune restores a failed trial, it reconstructs the # entire experiment from the initial config. 
Therefore, to reflect any changes @@ -194,7 +194,7 @@ def add_module( self, *, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + module_spec: RLModuleSpec, ) -> None: # TODO(Avnishn): # WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead @@ -215,7 +215,7 @@ def add_module( ) @override(Learner) - def remove_module(self, module_id: ModuleID, **kwargs) -> MultiAgentRLModuleSpec: + def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec: with self._strategy.scope(): marl_spec = super().remove_module(module_id, **kwargs) diff --git a/rllib/core/learner/torch/torch_learner.py b/rllib/core/learner/torch/torch_learner.py index 9a85f6ed387e5..ff643ffb6098f 100644 --- a/rllib/core/learner/torch/torch_learner.py +++ b/rllib/core/learner/torch/torch_learner.py @@ -14,13 +14,13 @@ TorchCompileWhatToCompile, ) from ray.rllib.core.learner.learner import Learner -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) from ray.rllib.core.rl_module.rl_module import ( RLModule, - SingleAgentRLModuleSpec, + RLModuleSpec, ) from ray.rllib.core.rl_module.torch.torch_rl_module import ( TorchCompileConfig, @@ -208,10 +208,11 @@ def add_module( self, *, module_id: ModuleID, - module_spec: SingleAgentRLModuleSpec, + # TODO (sven): Rename to `rl_module_spec`. + module_spec: RLModuleSpec, config_overrides: Optional[Dict] = None, new_should_module_be_updated: Optional[ShouldModuleBeUpdatedFn] = None, - ) -> MultiAgentRLModuleSpec: + ) -> MultiRLModuleSpec: # Call super's add_module method. marl_spec = super().add_module( module_id=module_id, @@ -256,7 +257,7 @@ def add_module( return marl_spec @override(Learner) - def remove_module(self, module_id: ModuleID, **kwargs) -> MultiAgentRLModuleSpec: + def remove_module(self, module_id: ModuleID, **kwargs) -> MultiRLModuleSpec: marl_spec = super().remove_module(module_id, **kwargs) if self._torch_compile_complete_update: @@ -329,7 +330,7 @@ def build(self) -> None: if self._torch_compile_forward_train: if isinstance(self._module, TorchRLModule): self._module.compile(self._torch_compile_cfg) - elif isinstance(self._module, MultiAgentRLModule): + elif isinstance(self._module, MultiRLModule): for module in self._module._rl_modules.values(): # Compile only TorchRLModules, e.g. we don't want to compile # a RandomRLModule. @@ -338,7 +339,7 @@ def build(self) -> None: else: raise ValueError( "Torch compile is only supported for TorchRLModule and " - "MultiAgentRLModule." + "MultiRLModule." ) self._possibly_compiled_update = self._uncompiled_update @@ -396,11 +397,11 @@ def _update(self, batch: Dict[str, Any]) -> Tuple[Any, Any, Any]: def _make_modules_ddp_if_necessary(self) -> None: """Default logic for (maybe) making all Modules within self._module DDP.""" - # If the module is a MultiAgentRLModule and nn.Module we can simply assume + # If the module is a MultiRLModule and nn.Module we can simply assume # all the submodules are registered. Otherwise, we need to loop through # each submodule and move it to the correct device. # TODO (Kourosh): This can result in missing modules if the user does not - # register them in the MultiAgentRLModule. We should find a better way to + # register them in the MultiRLModule. We should find a better way to # handle this. if self._distributed: # Single agent module: Convert to `TorchDDPRLModule`. 
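The Learner-side hunks above only swap type names; the renamed MultiRLModule container itself still behaves as a plain mapping from ModuleID to RLModule. A minimal standalone sketch of that container API, not part of this patch and assuming RLlib's testing class DiscreteBCTorchModule (also used by the test files in this diff), could look like this:

import tempfile

import gymnasium as gym

from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule

env = gym.make("CartPole-v1")


def make_module() -> DiscreteBCTorchModule:
    # One single-agent RLModule, configured directly via RLModuleConfig.
    return DiscreteBCTorchModule(
        config=RLModuleConfig(
            env.observation_space,
            env.action_space,
            model_config_dict={"fcnet_hiddens": [32]},
        )
    )


# Build an empty MultiRLModule and register sub-modules at run time,
# mirroring the learner-group test above.
multi_rl_module = MultiRLModule()
multi_rl_module.add_module(module_id="0", module=make_module())
multi_rl_module.add_module(module_id="1", module=make_module())

# The container is a mapping from ModuleID to RLModule, and its state dict
# is keyed the same way.
assert set(multi_rl_module.keys()) == {"0", "1"}
state = multi_rl_module.get_state()
assert set(state.keys()) == {"0", "1"}

# Checkpointing goes through the generic Checkpointable API.
with tempfile.TemporaryDirectory() as tmpdir:
    multi_rl_module.save_to_path(tmpdir)

# Sub-modules can also be removed again at run time.
multi_rl_module.remove_module("1")
assert "1" not in multi_rl_module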
@@ -408,7 +409,7 @@ def _make_modules_ddp_if_necessary(self) -> None: self._module = TorchDDPRLModule(self._module) # Multi agent module: Convert each submodule to `TorchDDPRLModule`. else: - assert isinstance(self._module, MultiAgentRLModule) + assert isinstance(self._module, MultiRLModule) for key in self._module.keys(): sub_module = self._module[key] if isinstance(sub_module, TorchRLModule): @@ -440,12 +441,12 @@ def _check_registered_optimizer( ) @override(Learner) - def _make_module(self) -> MultiAgentRLModule: + def _make_module(self) -> MultiRLModule: module = super()._make_module() self._map_module_to_device(module) return module - def _map_module_to_device(self, module: MultiAgentRLModule) -> None: + def _map_module_to_device(self, module: MultiRLModule) -> None: """Moves the module to the correct device.""" if isinstance(module, torch.nn.Module): module.to(self._device) diff --git a/rllib/core/models/tests/test_base_models.py b/rllib/core/models/tests/test_base_models.py index ea9e3344a7543..bd0d1cb203110 100644 --- a/rllib/core/models/tests/test_base_models.py +++ b/rllib/core/models/tests/test_base_models.py @@ -8,7 +8,7 @@ from ray.rllib.core.models.specs.specs_base import TensorSpec from ray.rllib.core.models.specs.specs_dict import SpecDict from ray.rllib.core.models.torch.base import TorchModel -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.framework import try_import_tf, try_import_torch from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog @@ -245,7 +245,7 @@ def compile_me(input_dict): def test_torch_compile_forwards(self): """Test if logic around TorchCompileConfig works as intended.""" - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=gym.spaces.Box(low=0, high=1, shape=(32,)), action_space=gym.spaces.Box(low=0, high=1, shape=(1,)), diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py index f7f81074b32a0..17790278a0f62 100644 --- a/rllib/core/models/tests/test_catalog.py +++ b/rllib/core/models/tests/test_catalog.py @@ -27,7 +27,7 @@ CNNEncoderConfig, ) from ray.rllib.core.models.torch.base import TorchModel -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.models import MODEL_DEFAULTS from ray.rllib.models.tf.tf_distributions import ( TfCategorical, @@ -390,7 +390,7 @@ def build_vf_head(self, framework): PPOConfig() .api_stack(enable_rl_module_and_learner=True) .rl_module( - rl_module_spec=SingleAgentRLModuleSpec(catalog_class=MyCatalog), + rl_module_spec=RLModuleSpec(catalog_class=MyCatalog), ) .framework("torch") ) @@ -405,7 +405,7 @@ def build_vf_head(self, framework): config = ( PPOConfig() .rl_module( - rl_module_spec=SingleAgentRLModuleSpec( + rl_module_spec=RLModuleSpec( module_class=PPOTorchRLModule, catalog_class=MyCatalog ) ) @@ -455,7 +455,7 @@ def _determine_components(self): input_dims=self.observation_space.shape, ) - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, diff --git a/rllib/core/rl_module/__init__.py b/rllib/core/rl_module/__init__.py index 7ab9e59df3082..df1a65a284be3 100644 --- a/rllib/core/rl_module/__init__.py +++ b/rllib/core/rl_module/__init__.py @@ 
-1,11 +1,11 @@ import logging import re -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.util import log_once from ray.util.annotations import PublicAPI @@ -45,9 +45,9 @@ def validate_module_id(policy_id: str, error: bool = False) -> None: __all__ = [ - "MultiAgentRLModule", - "MultiAgentRLModuleSpec", + "MultiRLModule", + "MultiRLModuleSpec", "RLModule", - "SingleAgentRLModuleSpec", + "RLModuleSpec", "validate_module_id", ] diff --git a/rllib/core/rl_module/marl_module.py b/rllib/core/rl_module/marl_module.py index 2934692506da6..4f35e4b7b0762 100644 --- a/rllib/core/rl_module/marl_module.py +++ b/rllib/core/rl_module/marl_module.py @@ -1,645 +1,18 @@ -from dataclasses import dataclass, field -import logging -import pprint -from typing import ( - Any, - Callable, - Collection, - Dict, - KeysView, - List, - Optional, - Set, - Tuple, - Type, - Union, -) - -from ray.rllib.core.models.specs.typing import SpecType -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.utils.deprecation import deprecation_warning -from ray.rllib.policy.sample_batch import MultiAgentBatch -from ray.rllib.utils.annotations import ( - ExperimentalAPI, - override, - OverrideToImplementCustomLogic, +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, + MultiRLModuleConfig, ) -from ray.rllib.utils.checkpoints import Checkpointable -from ray.rllib.utils.serialization import serialize_type, deserialize_type -from ray.rllib.utils.typing import ModuleID, StateDict, T -from ray.util.annotations import PublicAPI - -logger = logging.getLogger("ray.rllib") - - -@PublicAPI(stability="alpha") -class MultiAgentRLModule(RLModule): - """Base class for multi-agent RLModules. - - This class holds a mapping from module_ids to the underlying RLModules. It provides - a convenient way of accessing each individual module, as well as accessing all of - them with only one API call. Whether a given module is trainable is - determined by the caller of this class (not the instance of this class itself). - - The extension of this class can include any arbitrary neural networks as part of - the multi-agent module. For example, a multi-agent module can include a shared - encoder network that is used by all the individual RLModules. It is up to the user - to decide how to implement this class. - - The default implementation assumes the data communicated as input and output of - the APIs in this class are `MultiAgentBatch` types. The `MultiAgentRLModule` simply - loops through each `module_id`, and runs the forward pass of the corresponding - `RLModule` object with the associated `SampleBatch` within the `MultiAgentBatch`. - It also assumes that the underlying RLModules do not share any parameters or - communication with one another. The behavior of modules with such advanced - communication would be undefined by default. To share parameters or communication - between the underlying RLModules, you should implement your own - `MultiAgentRLModule` subclass. - """ - - def __init__(self, config: Optional["MultiAgentRLModuleConfig"] = None) -> None: - """Initializes a MultiagentRLModule instance. - - Args: - config: An optional MultiAgentRLModuleConfig to use. 
If None, will use - `MultiAgentRLModuleConfig()` as default config. - """ - super().__init__(config or MultiAgentRLModuleConfig()) - - @override(RLModule) - def setup(self): - """Sets up the underlying RLModules.""" - self._rl_modules = {} - self.__check_module_configs(self.config.modules) - # Make sure all individual RLModules have the same framework OR framework=None. - framework = None - for module_id, module_spec in self.config.modules.items(): - self._rl_modules[module_id] = module_spec.build() - if framework is None: - framework = self._rl_modules[module_id].framework - else: - assert self._rl_modules[module_id].framework in [None, framework] - self.framework = framework - - @OverrideToImplementCustomLogic - @override(RLModule) - def get_initial_state(self) -> Any: - # TODO (sven): Replace by call to `self.foreach_module`, but only if this method - # supports returning dicts. - ret = {} - for module_id, module in self._rl_modules.items(): - ret[module_id] = module.get_initial_state() - return ret - - @OverrideToImplementCustomLogic - @override(RLModule) - def is_stateful(self) -> bool: - initial_state = self.get_initial_state() - assert isinstance(initial_state, dict), ( - "The initial state of an RLModule must be a dict, but is " - f"{type(initial_state)} instead." - ) - return bool(any(sa_init_state for sa_init_state in initial_state.values())) - - @classmethod - def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]): - """Checks the module configs for validity. - - The module_configs be a mapping from module_ids to SingleAgentRLModuleSpec - objects. - - Args: - module_configs: The module configs to check. - - Raises: - ValueError: If the module configs are invalid. - """ - for module_id, module_spec in module_configs.items(): - if not isinstance(module_spec, SingleAgentRLModuleSpec): - raise ValueError( - f"Module {module_id} is not a SingleAgentRLModuleSpec object." - ) - - def keys(self) -> KeysView[ModuleID]: - """Returns a keys view over the module IDs in this MultiAgentRLModule.""" - return self._rl_modules.keys() - - def __len__(self) -> int: - """Returns the number of RLModules within this MultiAgentRLModule.""" - return len(self._rl_modules) - - @override(RLModule) - def as_multi_agent(self) -> "MultiAgentRLModule": - """Returns a multi-agent wrapper around this module. - - This method is overridden to avoid double wrapping. - - Returns: - The instance itself. - """ - return self - - def add_module( - self, - module_id: ModuleID, - module: RLModule, - *, - override: bool = False, - ) -> None: - """Adds a module at run time to the multi-agent module. - - Args: - module_id: The module ID to add. If the module ID already exists and - override is False, an error is raised. If override is True, the module - is replaced. - module: The module to add. - override: Whether to override the module if it already exists. - - Raises: - ValueError: If the module ID already exists and override is False. - Warnings are raised if the module id is not valid according to the - logic of ``validate_module_id()``. - """ - from ray.rllib.core.rl_module import validate_module_id - - validate_module_id(module_id) - - if module_id in self._rl_modules and not override: - raise ValueError( - f"Module ID {module_id} already exists. If your intention is to " - "override, set override=True." - ) - - # Set our own inference_only flag to False as soon as any added Module - # has `inference_only=False`. 
- if not module.config.inference_only: - self.config.inference_only = False - self._rl_modules[module_id] = module - # Update our `MultiAgentRLModuleConfig`, such that - if written to disk - - # it'll allow for proper restoring this instance through `.from_checkpoint()`. - self.config.modules[module_id] = SingleAgentRLModuleSpec.from_module(module) - - def remove_module( - self, module_id: ModuleID, *, raise_err_if_not_found: bool = True - ) -> None: - """Removes a module at run time from the multi-agent module. - - Args: - module_id: The module ID to remove. - raise_err_if_not_found: Whether to raise an error if the module ID is not - found. - Raises: - ValueError: If the module ID does not exist and raise_err_if_not_found is - True. - """ - if raise_err_if_not_found: - self._check_module_exists(module_id) - del self._rl_modules[module_id] - del self.config.modules[module_id] - - def foreach_module( - self, func: Callable[[ModuleID, RLModule, Optional[Any]], T], **kwargs - ) -> List[T]: - """Calls the given function with each (module_id, module). - - Args: - func: The function to call with each (module_id, module) tuple. - - Returns: - The lsit of return values of all calls to - `func([module_id, module, **kwargs])`. - """ - return [ - func(module_id, module.unwrapped(), **kwargs) - for module_id, module in self._rl_modules.items() - ] - - def __contains__(self, item) -> bool: - """Returns whether the given `item` (ModuleID) is present in self.""" - return item in self._rl_modules - - def __getitem__(self, module_id: ModuleID) -> RLModule: - """Returns the RLModule with the given module ID. - - Args: - module_id: The module ID to get. - - Returns: - The RLModule with the given module ID. - - Raises: - KeyError: If `module_id` cannot be found in self. - """ - self._check_module_exists(module_id) - return self._rl_modules[module_id] - - def get( - self, - module_id: ModuleID, - default: Optional[RLModule] = None, - ) -> Optional[RLModule]: - """Returns the module with the given module ID or default if not found in self. - - Args: - module_id: The module ID to get. - - Returns: - The RLModule with the given module ID or `default` if `module_id` not found - in `self`. - """ - if module_id not in self._rl_modules: - return default - return self._rl_modules[module_id] - - @override(RLModule) - def output_specs_train(self) -> SpecType: - return [] - - @override(RLModule) - def output_specs_inference(self) -> SpecType: - return [] - - @override(RLModule) - def output_specs_exploration(self) -> SpecType: - return [] - - @override(RLModule) - def _default_input_specs(self) -> SpecType: - """Multi-agent RLModule should not check the input specs. - - The underlying single-agent RLModules will check the input specs. - """ - return [] - - @override(RLModule) - def _forward_train( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_train pass. - - TODO(avnishn, kourosh): Review type hints for forward methods. - - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). - - Returns: - The output of the forward_train pass the specified modules. - """ - return self._run_forward_pass("forward_train", batch, **kwargs) - - @override(RLModule) - def _forward_inference( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_inference pass. - TODO(avnishn, kourosh): Review type hints for forward methods. 
- Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). +MultiAgentRLModule = MultiRLModule +MultiAgentRLModuleConfig = MultiRLModuleConfig +MultiAgentRLModuleSpec = MultiRLModuleSpec - Returns: - The output of the forward_inference pass the specified modules. - """ - return self._run_forward_pass("forward_inference", batch, **kwargs) - - @override(RLModule) - def _forward_exploration( - self, batch: MultiAgentBatch, **kwargs - ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: - """Runs the forward_exploration pass. - - TODO(avnishn, kourosh): Review type hints for forward methods. - - Args: - batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). - - Returns: - The output of the forward_exploration pass the specified modules. - """ - return self._run_forward_pass("forward_exploration", batch, **kwargs) - - @override(RLModule) - def get_state( - self, - components: Optional[Union[str, Collection[str]]] = None, - *, - not_components: Optional[Union[str, Collection[str]]] = None, - inference_only: bool = False, - **kwargs, - ) -> StateDict: - state = {} - - for module_id, rl_module in self.get_checkpointable_components(): - if self._check_component(module_id, components, not_components): - state[module_id] = rl_module.get_state( - components=self._get_subcomponents(module_id, components), - not_components=self._get_subcomponents(module_id, not_components), - inference_only=inference_only, - ) - return state - - @override(RLModule) - def set_state(self, state: StateDict) -> None: - """Sets the state of the multi-agent module. - - It is assumed that the state_dict is a mapping from module IDs to the - corresponding module's state. This method sets the state of each module by - calling their set_state method. If you want to set the state of some of the - RLModules within this MultiAgentRLModule your state_dict can only include the - state of those RLModules. Override this method to customize the state_dict for - custom more advanced multi-agent use cases. - - Args: - state: The state dict to set. - """ - for module_id, module_state in state.items(): - if module_id in self: - self._rl_modules[module_id].set_state(module_state) - - @override(Checkpointable) - def get_checkpointable_components(self) -> List[Tuple[str, Checkpointable]]: - return list(self._rl_modules.items()) - - def __repr__(self) -> str: - return f"MARL({pprint.pformat(self._rl_modules)})" - - def _run_forward_pass( - self, - forward_fn_name: str, - batch: Dict[ModuleID, Any], - **kwargs, - ) -> Dict[ModuleID, Dict[ModuleID, Any]]: - """This is a helper method that runs the forward pass for the given module. - - It uses forward_fn_name to get the forward pass method from the RLModule - (e.g. forward_train vs. forward_exploration) and runs it on the given batch. - - Args: - forward_fn_name: The name of the forward pass method to run. - batch: The batch of multi-agent data (i.e. mapping from module ids to - SampleBaches). - **kwargs: Additional keyword arguments to pass to the forward function. - - Returns: - The output of the forward pass the specified modules. The output is a - mapping from module ID to the output of the forward pass. 
- """ - - outputs = {} - for module_id in batch.keys(): - self._check_module_exists(module_id) - rl_module = self._rl_modules[module_id] - forward_fn = getattr(rl_module, forward_fn_name) - outputs[module_id] = forward_fn(batch[module_id], **kwargs) - - return outputs - - def _check_module_exists(self, module_id: ModuleID) -> None: - if module_id not in self._rl_modules: - raise KeyError( - f"Module with module_id {module_id} not found. " - f"Available modules: {set(self.keys())}" - ) - - -@PublicAPI(stability="alpha") -@dataclass -class MultiAgentRLModuleSpec: - """A utility spec class to make it constructing MARL modules easier. - - Users can extend this class to modify the behavior of base class. For example to - share neural networks across the modules, the build method can be overriden to - create the shared module first and then pass it to custom module classes that would - then use it as a shared module. - - Args: - marl_module_class: The class of the multi-agent RLModule to construct. By - default it is set to MultiAgentRLModule class. This class simply loops - throught each module and calls their foward methods. - module_specs: The module specs for each individual module. It can be either a - SingleAgentRLModuleSpec used for all module_ids or a dictionary mapping - from module IDs to SingleAgentRLModuleSpecs for each individual module. - load_state_path: The path to the module state to load from. NOTE: This must be - an absolute path. NOTE: If the load_state_path of this spec is set, and - the load_state_path of one of the SingleAgentRLModuleSpecs' is also set, - the weights of that RL Module will be loaded from the path specified in - the SingleAgentRLModuleSpec. This is useful if you want to load the weights - of a MARL module and also manually load the weights of some of the RL - modules within that MARL module from other checkpoints. - modules_to_load: A set of module ids to load from the checkpoint. This is - only used if load_state_path is set. If this is None, all modules are - loaded. - """ - - marl_module_class: Type[MultiAgentRLModule] = MultiAgentRLModule - inference_only: bool = False - module_specs: Union[ - SingleAgentRLModuleSpec, Dict[ModuleID, SingleAgentRLModuleSpec] - ] = None - load_state_path: Optional[str] = None - modules_to_load: Optional[Set[ModuleID]] = None - - def __post_init__(self): - if self.module_specs is None: - raise ValueError( - "Module_specs cannot be None. It should be either a " - "SingleAgentRLModuleSpec or a dictionary mapping from module IDs to " - "SingleAgentRLModuleSpecs for each individual module." - ) - - def get_marl_config(self) -> "MultiAgentRLModuleConfig": - """Returns the MultiAgentRLModuleConfig for this spec.""" - return MultiAgentRLModuleConfig( - # Only set `inference_only=True` if all single-agent specs are - # `inference_only`. - inference_only=all( - spec.inference_only for spec in self.module_specs.values() - ), - modules=self.module_specs, - ) - - @OverrideToImplementCustomLogic - def build(self, module_id: Optional[ModuleID] = None) -> RLModule: - """Builds either the multi-agent module or the single-agent module. - - If module_id is None, it builds the multi-agent module. Otherwise, it builds - the single-agent module with the given module_id. - - Note: If when build is called the module_specs is not a dictionary, it will - raise an error, since it should have been updated by the caller to inform us - about the module_ids. - - Args: - module_id: The module_id of the single-agent module to build. 
If None, it - builds the multi-agent module. - - Returns: - The built module. If module_id is None, it returns the multi-agent module. - """ - self._check_before_build() - - # ModuleID provided, return single-agent RLModule. - if module_id: - return self.module_specs[module_id].build() - - # Return MultiAgentRLModule. - module_config = self.get_marl_config() - module = self.marl_module_class(module_config) - return module - - def add_modules( - self, - module_specs: Dict[ModuleID, SingleAgentRLModuleSpec], - override: bool = True, - ) -> None: - """Add new module specs to the spec or updates existing ones. - - Args: - module_specs: The mapping for the module_id to the single-agent module - specs to be added to this multi-agent module spec. - override: Whether to override the existing module specs if they already - exist. If False, they are only updated. - """ - if self.module_specs is None: - self.module_specs = {} - for module_id, module_spec in module_specs.items(): - if override or module_id not in self.module_specs: - # Disable our `inference_only` as soon as any single-agent module has - # `inference_only=False`. - if not module_spec.inference_only: - self.inference_only = False - self.module_specs[module_id] = module_spec - else: - self.module_specs[module_id].update(module_spec) - - @classmethod - def from_module(self, module: MultiAgentRLModule) -> "MultiAgentRLModuleSpec": - """Creates a MultiAgentRLModuleSpec from a MultiAgentRLModule. - - Args: - module: The MultiAgentRLModule to create the spec from. - - Returns: - The MultiAgentRLModuleSpec. - """ - # we want to get the spec of the underlying unwrapped module that way we can - # easily reconstruct it. The only wrappers that we expect to support today are - # wrappers that allow us to do distributed training. Those will be added back - # by the learner if necessary. - module_specs = { - module_id: SingleAgentRLModuleSpec.from_module(rl_module.unwrapped()) - for module_id, rl_module in module._rl_modules.items() - } - marl_module_class = module.__class__ - return MultiAgentRLModuleSpec( - marl_module_class=marl_module_class, - inference_only=module.config.inference_only, - module_specs=module_specs, - ) - - def _check_before_build(self): - if not isinstance(self.module_specs, dict): - raise ValueError( - f"When build() is called on {self.__class__}, the module_specs " - "should be a dictionary mapping from module IDs to " - "SingleAgentRLModuleSpecs for each individual module." - ) - - def to_dict(self) -> Dict[str, Any]: - """Converts the MultiAgentRLModuleSpec to a dictionary.""" - return { - "marl_module_class": serialize_type(self.marl_module_class), - "inference_only": self.inference_only, - "module_specs": { - module_id: module_spec.to_dict() - for module_id, module_spec in self.module_specs.items() - }, - } - - @classmethod - def from_dict(cls, d) -> "MultiAgentRLModuleSpec": - """Creates a MultiAgentRLModuleSpec from a dictionary.""" - return MultiAgentRLModuleSpec( - marl_module_class=deserialize_type(d["marl_module_class"]), - inference_only=d["inference_only"], - module_specs={ - module_id: SingleAgentRLModuleSpec.from_dict(module_spec) - for module_id, module_spec in d["module_specs"].items() - }, - ) - - def update( - self, - other: Union["MultiAgentRLModuleSpec", SingleAgentRLModuleSpec], - override=False, - ) -> None: - """Updates this spec with the other spec. - - Traverses this MultiAgentRLModuleSpec's module_specs and updates them with - the module specs from the other MultiAgentRLModuleSpec. 
- - Args: - other: The other spec to update this spec with. - override: Whether to override the existing module specs if they already - exist. If False, they are only updated. - """ - if isinstance(other, SingleAgentRLModuleSpec): - # Disable our `inference_only` as soon as any single-agent module has - # `inference_only=False`. - if not other.inference_only: - self.inference_only = False - for mid, spec in self.module_specs.items(): - self.module_specs[mid].update(other, override=False) - elif isinstance(other.module_specs, dict): - self.add_modules(other.module_specs, override=override) - else: - assert isinstance(other, MultiAgentRLModuleSpec) - if not self.module_specs: - self.inference_only = other.inference_only - self.module_specs = other.module_specs - else: - if not other.inference_only: - self.inference_only = False - self.module_specs.update(other.module_specs) - - def as_multi_agent(self) -> "MultiAgentRLModuleSpec": - """Returns self to match `SingleAgentRLModuleSpec.as_multi_agent()`.""" - return self - - def __contains__(self, item) -> bool: - """Returns whether the given `item` (ModuleID) is present in self.""" - return item in self.module_specs - - -# TODO (sven): Shouldn't we simply use this class inside MultiAgentRLModuleSpec instead -# of duplicating all data records (e.g. `inference_only`) in `MultiAgentRLModuleSpec`? -# Same for SingleAgentRLModuleSpec, which should use RLModuleConfig instead of -# duplicating all settings, e.g. `observation_space`, `inference_only`, ... -@ExperimentalAPI -@dataclass -class MultiAgentRLModuleConfig: - inference_only: bool = False - modules: Dict[ModuleID, SingleAgentRLModuleSpec] = field(default_factory=dict) - - def to_dict(self): - return { - "inference_only": self.inference_only, - "modules": { - module_id: module_spec.to_dict() - for module_id, module_spec in self.modules.items() - }, - } - - @classmethod - def from_dict(cls, d) -> "MultiAgentRLModuleConfig": - return cls( - inference_only=d["inference_only"], - modules={ - module_id: SingleAgentRLModuleSpec.from_dict(module_spec) - for module_id, module_spec in d["modules"].items() - }, - ) +deprecation_warning( + old="ray.rllib.core.rl_module.marl_module", + new="ray.rllib.core.rl_module.multi_rl_module", + error=False, +) diff --git a/rllib/core/rl_module/multi_rl_module.py b/rllib/core/rl_module/multi_rl_module.py new file mode 100644 index 0000000000000..0775433e4fe48 --- /dev/null +++ b/rllib/core/rl_module/multi_rl_module.py @@ -0,0 +1,652 @@ +from dataclasses import dataclass, field +import logging +import pprint +from typing import ( + Any, + Callable, + Collection, + Dict, + KeysView, + List, + Optional, + Set, + Tuple, + Type, + Union, +) + +from ray.rllib.core.models.specs.typing import SpecType +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec + +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import ( + ExperimentalAPI, + override, + OverrideToImplementCustomLogic, +) +from ray.rllib.utils.checkpoints import Checkpointable +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.serialization import serialize_type, deserialize_type +from ray.rllib.utils.typing import ModuleID, StateDict, T +from ray.util.annotations import PublicAPI + +logger = logging.getLogger("ray.rllib") + + +@PublicAPI(stability="alpha") +class MultiRLModule(RLModule): + """Base class for an RLModule that contains n sub-RLModules. + + This class holds a mapping from ModuleID to underlying RLModules. 
It provides + a convenient way of accessing each individual module, as well as accessing all of + them with only one API call. Whether a given module is trainable is + determined by the caller of this class (not the instance of this class itself). + + The extension of this class can include any arbitrary neural networks as part of + the MultiRLModule. For example, a MultiRLModule can include a shared encoder network + that is used by all the individual (single-agent) RLModules. It is up to the user + to decide how to implement this class. + + The default implementation assumes the data communicated as input and output of + the APIs in this class are `MultiAgentBatch` types. The `MultiRLModule` simply + loops through each `module_id`, and runs the forward pass of the corresponding + `RLModule` object with the associated `SampleBatch` within the `MultiAgentBatch`. + It also assumes that the underlying RLModules do not share any parameters or + communication with one another. The behavior of modules with such advanced + communication would be undefined by default. To share parameters or communication + between the underlying RLModules, you should implement your own + `MultiRLModule` subclass. + """ + + def __init__(self, config: Optional["MultiRLModuleConfig"] = None) -> None: + """Initializes a MultiRLModule instance. + + Args: + config: An optional MultiRLModuleConfig to use. If None, will use + `MultiRLModuleConfig()` as default config. + """ + super().__init__(config or MultiRLModuleConfig()) + + @override(RLModule) + def setup(self): + """Sets up the underlying RLModules.""" + self._rl_modules = {} + self.__check_module_configs(self.config.modules) + # Make sure all individual RLModules have the same framework OR framework=None. + framework = None + for module_id, module_spec in self.config.modules.items(): + self._rl_modules[module_id] = module_spec.build() + if framework is None: + framework = self._rl_modules[module_id].framework + else: + assert self._rl_modules[module_id].framework in [None, framework] + self.framework = framework + + @OverrideToImplementCustomLogic + @override(RLModule) + def get_initial_state(self) -> Any: + # TODO (sven): Replace by call to `self.foreach_module`, but only if this method + # supports returning dicts. + ret = {} + for module_id, module in self._rl_modules.items(): + ret[module_id] = module.get_initial_state() + return ret + + @OverrideToImplementCustomLogic + @override(RLModule) + def is_stateful(self) -> bool: + initial_state = self.get_initial_state() + assert isinstance(initial_state, dict), ( + "The initial state of an RLModule must be a dict, but is " + f"{type(initial_state)} instead." + ) + return bool(any(sa_init_state for sa_init_state in initial_state.values())) + + @classmethod + def __check_module_configs(cls, module_configs: Dict[ModuleID, Any]): + """Checks the module configs for validity. + + The module_configs be a mapping from module_ids to RLModuleSpec + objects. + + Args: + module_configs: The module configs to check. + + Raises: + ValueError: If the module configs are invalid. 
+ """ + for module_id, module_spec in module_configs.items(): + if not isinstance(module_spec, RLModuleSpec): + raise ValueError(f"Module {module_id} is not a RLModuleSpec object.") + + def keys(self) -> KeysView[ModuleID]: + """Returns a keys view over the module IDs in this MultiRLModule.""" + return self._rl_modules.keys() + + def __len__(self) -> int: + """Returns the number of RLModules within this MultiRLModule.""" + return len(self._rl_modules) + + @override(RLModule) + def as_multi_rl_module(self) -> "MultiRLModule": + """Returns self in order to match `RLModule.as_multi_rl_module()` behavior. + + This method is overridden to avoid double wrapping. + + Returns: + The instance itself. + """ + return self + + def add_module( + self, + module_id: ModuleID, + module: RLModule, + *, + override: bool = False, + ) -> None: + """Adds a module at run time to the multi-agent module. + + Args: + module_id: The module ID to add. If the module ID already exists and + override is False, an error is raised. If override is True, the module + is replaced. + module: The module to add. + override: Whether to override the module if it already exists. + + Raises: + ValueError: If the module ID already exists and override is False. + Warnings are raised if the module id is not valid according to the + logic of ``validate_module_id()``. + """ + from ray.rllib.core.rl_module import validate_module_id + + validate_module_id(module_id) + + if module_id in self._rl_modules and not override: + raise ValueError( + f"Module ID {module_id} already exists. If your intention is to " + "override, set override=True." + ) + # Set our own inference_only flag to False as soon as any added Module + # has `inference_only=False`. + if not module.config.inference_only: + self.config.inference_only = False + self._rl_modules[module_id] = module + # Update our `MultiRLModuleConfig`, such that - if written to disk - + # it'll allow for proper restoring this instance through `.from_checkpoint()`. + self.config.modules[module_id] = RLModuleSpec.from_module(module) + + def remove_module( + self, module_id: ModuleID, *, raise_err_if_not_found: bool = True + ) -> None: + """Removes a module at run time from the multi-agent module. + + Args: + module_id: The module ID to remove. + raise_err_if_not_found: Whether to raise an error if the module ID is not + found. + Raises: + ValueError: If the module ID does not exist and raise_err_if_not_found is + True. + """ + if raise_err_if_not_found: + self._check_module_exists(module_id) + del self._rl_modules[module_id] + del self.config.modules[module_id] + + def foreach_module( + self, func: Callable[[ModuleID, RLModule, Optional[Any]], T], **kwargs + ) -> List[T]: + """Calls the given function with each (module_id, module). + + Args: + func: The function to call with each (module_id, module) tuple. + + Returns: + The lsit of return values of all calls to + `func([module_id, module, **kwargs])`. + """ + return [ + func(module_id, module.unwrapped(), **kwargs) + for module_id, module in self._rl_modules.items() + ] + + def __contains__(self, item) -> bool: + """Returns whether the given `item` (ModuleID) is present in self.""" + return item in self._rl_modules + + def __getitem__(self, module_id: ModuleID) -> RLModule: + """Returns the RLModule with the given module ID. + + Args: + module_id: The module ID to get. + + Returns: + The RLModule with the given module ID. + + Raises: + KeyError: If `module_id` cannot be found in self. 
+ """ + self._check_module_exists(module_id) + return self._rl_modules[module_id] + + def get( + self, + module_id: ModuleID, + default: Optional[RLModule] = None, + ) -> Optional[RLModule]: + """Returns the module with the given module ID or default if not found in self. + + Args: + module_id: The module ID to get. + + Returns: + The RLModule with the given module ID or `default` if `module_id` not found + in `self`. + """ + if module_id not in self._rl_modules: + return default + return self._rl_modules[module_id] + + @override(RLModule) + def output_specs_train(self) -> SpecType: + return [] + + @override(RLModule) + def output_specs_inference(self) -> SpecType: + return [] + + @override(RLModule) + def output_specs_exploration(self) -> SpecType: + return [] + + @override(RLModule) + def _default_input_specs(self) -> SpecType: + """MultiRLModule should not check the input specs. + + The underlying single-agent RLModules will check the input specs. + """ + return [] + + @override(RLModule) + def _forward_train( + self, batch: MultiAgentBatch, **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Runs the forward_train pass. + + TODO(avnishn, kourosh): Review type hints for forward methods. + + Args: + batch: The batch of multi-agent data (i.e. mapping from module ids to + SampleBaches). + + Returns: + The output of the forward_train pass the specified modules. + """ + return self._run_forward_pass("forward_train", batch, **kwargs) + + @override(RLModule) + def _forward_inference( + self, batch: MultiAgentBatch, **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Runs the forward_inference pass. + + TODO(avnishn, kourosh): Review type hints for forward methods. + + Args: + batch: The batch of multi-agent data (i.e. mapping from module ids to + SampleBaches). + + Returns: + The output of the forward_inference pass the specified modules. + """ + return self._run_forward_pass("forward_inference", batch, **kwargs) + + @override(RLModule) + def _forward_exploration( + self, batch: MultiAgentBatch, **kwargs + ) -> Union[Dict[str, Any], Dict[ModuleID, Dict[str, Any]]]: + """Runs the forward_exploration pass. + + TODO(avnishn, kourosh): Review type hints for forward methods. + + Args: + batch: The batch of multi-agent data (i.e. mapping from module ids to + SampleBaches). + + Returns: + The output of the forward_exploration pass the specified modules. + """ + return self._run_forward_pass("forward_exploration", batch, **kwargs) + + @override(RLModule) + def get_state( + self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + inference_only: bool = False, + **kwargs, + ) -> StateDict: + state = {} + + for module_id, rl_module in self.get_checkpointable_components(): + if self._check_component(module_id, components, not_components): + state[module_id] = rl_module.get_state( + components=self._get_subcomponents(module_id, components), + not_components=self._get_subcomponents(module_id, not_components), + inference_only=inference_only, + ) + return state + + @override(RLModule) + def set_state(self, state: StateDict) -> None: + """Sets the state of the multi-agent module. + + It is assumed that the state_dict is a mapping from module IDs to the + corresponding module's state. This method sets the state of each module by + calling their set_state method. 
If you want to set the state of some of the + RLModules within this MultiRLModule your state_dict can only include the + state of those RLModules. Override this method to customize the state_dict for + custom more advanced multi-agent use cases. + + Args: + state: The state dict to set. + """ + for module_id, module_state in state.items(): + if module_id in self: + self._rl_modules[module_id].set_state(module_state) + + @override(Checkpointable) + def get_checkpointable_components(self) -> List[Tuple[str, Checkpointable]]: + return list(self._rl_modules.items()) + + def __repr__(self) -> str: + return f"MARL({pprint.pformat(self._rl_modules)})" + + def _run_forward_pass( + self, + forward_fn_name: str, + batch: Dict[ModuleID, Any], + **kwargs, + ) -> Dict[ModuleID, Dict[ModuleID, Any]]: + """This is a helper method that runs the forward pass for the given module. + + It uses forward_fn_name to get the forward pass method from the RLModule + (e.g. forward_train vs. forward_exploration) and runs it on the given batch. + + Args: + forward_fn_name: The name of the forward pass method to run. + batch: The batch of multi-agent data (i.e. mapping from module ids to + SampleBaches). + **kwargs: Additional keyword arguments to pass to the forward function. + + Returns: + The output of the forward pass the specified modules. The output is a + mapping from module ID to the output of the forward pass. + """ + + outputs = {} + for module_id in batch.keys(): + self._check_module_exists(module_id) + rl_module = self._rl_modules[module_id] + forward_fn = getattr(rl_module, forward_fn_name) + outputs[module_id] = forward_fn(batch[module_id], **kwargs) + + return outputs + + def _check_module_exists(self, module_id: ModuleID) -> None: + if module_id not in self._rl_modules: + raise KeyError( + f"Module with module_id {module_id} not found. " + f"Available modules: {set(self.keys())}" + ) + + +@PublicAPI(stability="alpha") +@dataclass +class MultiRLModuleSpec: + """A utility spec class to make it constructing MultiRLModules easier. + + Users can extend this class to modify the behavior of base class. For example to + share neural networks across the modules, the build method can be overriden to + create the shared module first and then pass it to custom module classes that would + then use it as a shared module. + + Args: + multi_rl_module_class: The class of the MultiRLModule to construct. By + default it is set to MultiRLModule class. This class simply loops + throught each module and calls their foward methods. + module_specs: The module specs for each individual module. It can be either a + RLModuleSpec used for all module_ids or a dictionary mapping + from module IDs to RLModuleSpecs for each individual module. + load_state_path: The path to the module state to load from. NOTE: This must be + an absolute path. NOTE: If the load_state_path of this spec is set, and + the load_state_path of one of the RLModuleSpecs' is also set, + the weights of that RL Module will be loaded from the path specified in + the RLModuleSpec. This is useful if you want to load the weights + of a MultiRLModule and also manually load the weights of some of the RL + modules within that MultiRLModule from other checkpoints. + modules_to_load: A set of module ids to load from the checkpoint. This is + only used if load_state_path is set. If this is None, all modules are + loaded. 
+ """ + + multi_rl_module_class: Type[MultiRLModule] = MultiRLModule + inference_only: bool = False + module_specs: Union[RLModuleSpec, Dict[ModuleID, RLModuleSpec]] = None + load_state_path: Optional[str] = None + modules_to_load: Optional[Set[ModuleID]] = None + + # To be deprecated (same as `multi_rl_module_class`). + marl_module_class: Type[MultiRLModule] = MultiRLModule + + def __post_init__(self): + if self.module_specs is None: + raise ValueError( + "Module_specs cannot be None. It should be either a " + "RLModuleSpec or a dictionary mapping from module IDs to " + "RLModuleSpecs for each individual module." + ) + + def get_multi_rl_module_config(self) -> "MultiRLModuleConfig": + """Returns the MultiRLModuleConfig for this spec.""" + return MultiRLModuleConfig( + # Only set `inference_only=True` if all single-agent specs are + # `inference_only`. + inference_only=all( + spec.inference_only for spec in self.module_specs.values() + ), + modules=self.module_specs, + ) + + @OverrideToImplementCustomLogic + def build(self, module_id: Optional[ModuleID] = None) -> RLModule: + """Builds either the multi-agent module or the single-agent module. + + If module_id is None, it builds the multi-agent module. Otherwise, it builds + the single-agent module with the given module_id. + + Note: If when build is called the module_specs is not a dictionary, it will + raise an error, since it should have been updated by the caller to inform us + about the module_ids. + + Args: + module_id: The module_id of the single-agent module to build. If None, it + builds the multi-agent module. + + Returns: + The built module. If module_id is None, it returns the multi-agent module. + """ + self._check_before_build() + + # ModuleID provided, return single-agent RLModule. + if module_id: + return self.module_specs[module_id].build() + + # Return MultiRLModule. + module_config = self.get_multi_rl_module_config() + module = self.multi_rl_module_class(module_config) + return module + + def add_modules( + self, + module_specs: Dict[ModuleID, RLModuleSpec], + override: bool = True, + ) -> None: + """Add new module specs to the spec or updates existing ones. + + Args: + module_specs: The mapping for the module_id to the single-agent module + specs to be added to this multi-agent module spec. + override: Whether to override the existing module specs if they already + exist. If False, they are only updated. + """ + if self.module_specs is None: + self.module_specs = {} + for module_id, module_spec in module_specs.items(): + if override or module_id not in self.module_specs: + # Disable our `inference_only` as soon as any single-agent module has + # `inference_only=False`. + if not module_spec.inference_only: + self.inference_only = False + self.module_specs[module_id] = module_spec + else: + self.module_specs[module_id].update(module_spec) + + @classmethod + def from_module(self, module: MultiRLModule) -> "MultiRLModuleSpec": + """Creates a MultiRLModuleSpec from a MultiRLModule. + + Args: + module: The MultiRLModule to create the spec from. + + Returns: + The MultiRLModuleSpec. + """ + # we want to get the spec of the underlying unwrapped module that way we can + # easily reconstruct it. The only wrappers that we expect to support today are + # wrappers that allow us to do distributed training. Those will be added back + # by the learner if necessary. 
+ module_specs = { + module_id: RLModuleSpec.from_module(rl_module.unwrapped()) + for module_id, rl_module in module._rl_modules.items() + } + multi_rl_module_class = module.__class__ + return MultiRLModuleSpec( + multi_rl_module_class=multi_rl_module_class, + inference_only=module.config.inference_only, + module_specs=module_specs, + ) + + def _check_before_build(self): + if not isinstance(self.module_specs, dict): + raise ValueError( + f"When build() is called on {self.__class__}, the module_specs " + "should be a dictionary mapping from module IDs to " + "RLModuleSpecs for each individual module." + ) + + def to_dict(self) -> Dict[str, Any]: + """Converts the MultiRLModuleSpec to a dictionary.""" + return { + "multi_rl_module_class": serialize_type(self.multi_rl_module_class), + "inference_only": self.inference_only, + "module_specs": { + module_id: module_spec.to_dict() + for module_id, module_spec in self.module_specs.items() + }, + } + + @classmethod + def from_dict(cls, d) -> "MultiRLModuleSpec": + """Creates a MultiRLModuleSpec from a dictionary.""" + return MultiRLModuleSpec( + multi_rl_module_class=deserialize_type(d["multi_rl_module_class"]), + inference_only=d["inference_only"], + module_specs={ + module_id: RLModuleSpec.from_dict(module_spec) + for module_id, module_spec in d["module_specs"].items() + }, + ) + + def update( + self, + other: Union["MultiRLModuleSpec", RLModuleSpec], + override=False, + ) -> None: + """Updates this spec with the other spec. + + Traverses this MultiRLModuleSpec's module_specs and updates them with + the module specs from the other MultiRLModuleSpec. + + Args: + other: The other spec to update this spec with. + override: Whether to override the existing module specs if they already + exist. If False, they are only updated. + """ + if isinstance(other, RLModuleSpec): + # Disable our `inference_only` as soon as any single-agent module has + # `inference_only=False`. + if not other.inference_only: + self.inference_only = False + for mid, spec in self.module_specs.items(): + self.module_specs[mid].update(other, override=False) + elif isinstance(other.module_specs, dict): + self.add_modules(other.module_specs, override=override) + else: + assert isinstance(other, MultiRLModuleSpec) + if not self.module_specs: + self.inference_only = other.inference_only + self.module_specs = other.module_specs + else: + if not other.inference_only: + self.inference_only = False + self.module_specs.update(other.module_specs) + + def as_multi_rl_module_spec(self) -> "MultiRLModuleSpec": + """Returns self in order to match `RLModuleSpec.as_multi_rl_module_spec()`.""" + return self + + def __contains__(self, item) -> bool: + """Returns whether the given `item` (ModuleID) is present in self.""" + return item in self.module_specs + + @Deprecated(new="MultiRLModuleSpec.as_multi_rl_module_spec()", error=True) + def as_multi_agent(self): + pass + + @Deprecated(new="MultiRLModuleSpec.get_multi_rl_module_config", error=True) + def get_marl_config(self, *args, **kwargs): + pass + + +# TODO (sven): Shouldn't we simply use this class inside MultiRLModuleSpec instead +# of duplicating all data records (e.g. `inference_only`) in `MultiRLModuleSpec`? +# Same for RLModuleSpec, which should use RLModuleConfig instead of +# duplicating all settings, e.g. `observation_space`, `inference_only`, ... 
+@ExperimentalAPI +@dataclass +class MultiRLModuleConfig: + inference_only: bool = False + modules: Dict[ModuleID, RLModuleSpec] = field(default_factory=dict) + + def to_dict(self): + return { + "inference_only": self.inference_only, + "modules": { + module_id: module_spec.to_dict() + for module_id, module_spec in self.modules.items() + }, + } + + @classmethod + def from_dict(cls, d) -> "MultiRLModuleConfig": + return cls( + inference_only=d["inference_only"], + modules={ + module_id: RLModuleSpec.from_dict(module_spec) + for module_id, module_spec in d["modules"].items() + }, + ) diff --git a/rllib/core/rl_module/rl_module.py b/rllib/core/rl_module/rl_module.py index b19193d80f1a2..253db06161981 100644 --- a/rllib/core/rl_module/rl_module.py +++ b/rllib/core/rl_module/rl_module.py @@ -6,9 +6,9 @@ import tree # pip install dm_tree if TYPE_CHECKING: - from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, + from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) from ray.rllib.core.models.catalog import Catalog @@ -45,7 +45,7 @@ @PublicAPI(stability="alpha") @dataclass -class SingleAgentRLModuleSpec: +class RLModuleSpec: """Utility spec class to make constructing RLModules (in single-agent case) easier. Args: @@ -96,15 +96,13 @@ def build(self) -> "RLModule": return module @classmethod - def from_module(cls, module: "RLModule") -> "SingleAgentRLModuleSpec": - from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule + def from_module(cls, module: "RLModule") -> "RLModuleSpec": + from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule - if isinstance(module, MultiAgentRLModule): - raise ValueError( - "MultiAgentRLModule cannot be converted to SingleAgentRLModuleSpec." - ) + if isinstance(module, MultiRLModule): + raise ValueError("MultiRLModule cannot be converted to RLModuleSpec.") - return SingleAgentRLModuleSpec( + return RLModuleSpec( module_class=type(module), observation_space=module.config.observation_space, action_space=module.config.action_space, @@ -127,7 +125,7 @@ def from_dict(cls, d): module_class = deserialize_type(d["module_class"]) module_config = RLModuleConfig.from_dict(d["module_config"]) - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=module_class, observation_space=module_config.observation_space, action_space=module_config.action_space, @@ -145,8 +143,8 @@ def update(self, other, override: bool = True) -> None: override: Whether to update all properties in `self` with those of `other. If False, only update those properties in `self` that are not None. """ - if not isinstance(other, SingleAgentRLModuleSpec): - raise ValueError("Can only update with another SingleAgentRLModuleSpec.") + if not isinstance(other, RLModuleSpec): + raise ValueError("Can only update with another RLModuleSpec.") # If the field is None in the other, keep the current field, otherwise update # with the new value. 
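The rl_module.py hunks above rename SingleAgentRLModuleSpec to RLModuleSpec without changing its behavior. As a rough usage sketch under that assumption (again relying on the DiscreteBCTorchModule testing class), building a module from a spec and recovering a spec from a built module works as before:

import gymnasium as gym

from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule

env = gym.make("CartPole-v1")

# Construct the renamed single-module spec and build the module from it.
spec = RLModuleSpec(
    module_class=DiscreteBCTorchModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [32]},
)
module = spec.build()

# The inverse direction: recover an equivalent spec from a built module.
recovered = RLModuleSpec.from_module(module)
assert recovered.module_class is DiscreteBCTorchModule
assert recovered.observation_space == env.observation_space

# Specs also round-trip through plain dicts, which MultiRLModuleSpec above
# relies on for serialization; the exact dict keys are an implementation detail.
restored = RLModuleSpec.from_dict(spec.to_dict())
assert restored.module_class is DiscreteBCTorchModule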
@@ -167,15 +165,19 @@ def update(self, other, override: bool = True) -> None: self.catalog_class = self.catalog_class or other.catalog_class self.load_state_path = self.load_state_path or other.load_state_path - def as_multi_agent(self) -> "MultiAgentRLModuleSpec": - """Returns a MultiAgentRLModuleSpec (`self` under DEFAULT_MODULE_ID key).""" - from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec + def as_multi_rl_module_spec(self) -> "MultiRLModuleSpec": + """Returns a MultiRLModuleSpec (`self` under DEFAULT_MODULE_ID key).""" + from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec - return MultiAgentRLModuleSpec( + return MultiRLModuleSpec( module_specs={DEFAULT_MODULE_ID: self}, load_state_path=self.load_state_path, ) + @Deprecated(new="RLModuleSpec.as_multi_rl_module_spec()", error=True) + def as_multi_agent(self, *args, **kwargs): + pass + @ExperimentalAPI @dataclass @@ -265,7 +267,7 @@ class RLModule(Checkpointable, abc.ABC): env = gym.make("CartPole-v1") # Create a single agent RL module spec. - module_spec = SingleAgentRLModuleSpec( + module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, @@ -302,7 +304,7 @@ class RLModule(Checkpointable, abc.ABC): env = gym.make("CartPole-v1") # Create a single agent RL module spec. - module_spec = SingleAgentRLModuleSpec( + module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, @@ -330,7 +332,7 @@ class RLModule(Checkpointable, abc.ABC): env = gym.make("CartPole-v1") # Create a single agent RL module spec. - module_spec = SingleAgentRLModuleSpec( + module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=env.observation_space, action_space=env.action_space, @@ -726,13 +728,13 @@ def get_ctor_args_and_kwargs(self): {}, # **kwargs ) - def as_multi_agent(self) -> "MultiAgentRLModule": + def as_multi_rl_module(self) -> "MultiRLModule": """Returns a multi-agent wrapper around this module.""" - from ray.rllib.core.rl_module.marl_module import MultiAgentRLModule + from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule - marl_module = MultiAgentRLModule() - marl_module.add_module(DEFAULT_MODULE_ID, self) - return marl_module + multi_rl_module = MultiRLModule() + multi_rl_module.add_module(DEFAULT_MODULE_ID, self) + return multi_rl_module def unwrapped(self) -> "RLModule": """Returns the underlying module if this module is a wrapper. 
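(Reviewer note: the deprecated `as_multi_agent()` call sites map onto `as_multi_rl_module()` as in this small sketch; the CartPole spaces are illustrative and DiscreteBCTorchModule is again the testing module from this diff.)

import gymnasium as gym

from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.rl_module.rl_module import RLModuleConfig
from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule

env = gym.make("CartPole-v1")  # illustrative spaces only

module = DiscreteBCTorchModule(
    config=RLModuleConfig(
        env.observation_space,
        env.action_space,
        model_config_dict={"fcnet_hiddens": [32]},
    )
)

# New API: wrap the single module in a MultiRLModule under DEFAULT_MODULE_ID
# (this replaces the old `module.as_multi_agent()` call).
multi_rl_module = module.as_multi_rl_module()
assert set(multi_rl_module.keys()) == {DEFAULT_MODULE_ID}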
@@ -745,6 +747,10 @@ def unwrapped(self) -> "RLModule": """ return self + @Deprecated(new="RLModule.as_multi_rl_module()", error=True) + def as_multi_agent(self, *args, **kwargs): + pass + @Deprecated(new="RLModule.save_to_path(...)", error=True) def save_state(self, *args, **kwargs): pass diff --git a/rllib/core/rl_module/tests/test_marl_module.py b/rllib/core/rl_module/tests/test_multi_rl_module.py similarity index 80% rename from rllib/core/rl_module/tests/test_marl_module.py rename to rllib/core/rl_module/tests/test_multi_rl_module.py index 46c4e6246e617..3ec8f47882477 100644 --- a/rllib/core/rl_module/tests/test_marl_module.py +++ b/rllib/core/rl_module/tests/test_multi_rl_module.py @@ -2,64 +2,59 @@ import unittest from ray.rllib.core import DEFAULT_MODULE_ID -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec, RLModuleConfig -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleConfig, -) +from ray.rllib.core.rl_module.rl_module import RLModuleSpec, RLModuleConfig +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig from ray.rllib.core.testing.torch.bc_module import DiscreteBCTorchModule from ray.rllib.env.multi_agent_env import make_multi_agent from ray.rllib.utils.test_utils import check -class TestMARLModule(unittest.TestCase): +class TestMultiRLModule(unittest.TestCase): def test_from_config(self): - """Tests whether a MultiAgentRLModule can be constructed from a config.""" + """Tests whether a MultiRLModule can be constructed from a config.""" env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - module1 = SingleAgentRLModuleSpec( + module1 = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=env.observation_space[0], action_space=env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - module2 = SingleAgentRLModuleSpec( + module2 = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=env.observation_space[0], action_space=env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - config = MultiAgentRLModuleConfig( - modules={"module1": module1, "module2": module2} - ) - marl_module = MultiAgentRLModule(config) + config = MultiRLModuleConfig(modules={"module1": module1, "module2": module2}) + multi_rl_module = MultiRLModule(config) - self.assertEqual(set(marl_module.keys()), {"module1", "module2"}) - self.assertIsInstance(marl_module["module1"], DiscreteBCTorchModule) - self.assertIsInstance(marl_module["module2"], DiscreteBCTorchModule) + self.assertEqual(set(multi_rl_module.keys()), {"module1", "module2"}) + self.assertIsInstance(multi_rl_module["module1"], DiscreteBCTorchModule) + self.assertIsInstance(multi_rl_module["module2"], DiscreteBCTorchModule) - def test_as_multi_agent(self): + def test_as_multi_rl_module(self): env_class = make_multi_agent("CartPole-v0") env = env_class({"num_agents": 2}) - marl_module = DiscreteBCTorchModule( + multi_rl_module = DiscreteBCTorchModule( config=RLModuleConfig( env.observation_space[0], env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - ).as_multi_agent() + ).as_multi_rl_module() - self.assertNotIsInstance(marl_module, DiscreteBCTorchModule) - self.assertIsInstance(marl_module, MultiAgentRLModule) - self.assertEqual({DEFAULT_MODULE_ID}, set(marl_module.keys())) + self.assertNotIsInstance(multi_rl_module, DiscreteBCTorchModule) + self.assertIsInstance(multi_rl_module, MultiRLModule) + self.assertEqual({DEFAULT_MODULE_ID}, 
set(multi_rl_module.keys())) - # check as_multi_agent() for the second time - marl_module2 = marl_module.as_multi_agent() - self.assertEqual(id(marl_module), id(marl_module2)) + # Check as_multi_rl_module() for the second time + multi_rl_module2 = multi_rl_module.as_multi_rl_module() + self.assertEqual(id(multi_rl_module), id(multi_rl_module2)) def test_get_state_and_set_state(self): @@ -72,7 +67,7 @@ def test_get_state_and_set_state(self): env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - ).as_multi_agent() + ).as_multi_rl_module() state = module.get_state() self.assertIsInstance(state, dict) @@ -88,7 +83,7 @@ def test_get_state_and_set_state(self): env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - ).as_multi_agent() + ).as_multi_rl_module() state2 = module2.get_state() check(state, state2, false=True) @@ -108,7 +103,7 @@ def test_add_remove_modules(self): env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - ).as_multi_agent() + ).as_multi_rl_module() module.add_module( "test", @@ -161,7 +156,7 @@ def test_save_to_path_and_from_checkpoint(self): env.action_space[0], model_config_dict={"fcnet_hiddens": [32]}, ) - ).as_multi_agent() + ).as_multi_rl_module() module.add_module( "test", @@ -186,7 +181,7 @@ def test_save_to_path_and_from_checkpoint(self): with tempfile.TemporaryDirectory() as tmpdir: module.save_to_path(tmpdir) - module2 = MultiAgentRLModule.from_checkpoint(tmpdir) + module2 = MultiRLModule.from_checkpoint(tmpdir) check(module.get_state(), module2.get_state()) self.assertEqual(module.keys(), module2.keys()) self.assertEqual(module.keys(), {"test", "test2", DEFAULT_MODULE_ID}) @@ -197,7 +192,7 @@ def test_save_to_path_and_from_checkpoint(self): # Check that - after removing a module - the checkpoint is correct. with tempfile.TemporaryDirectory() as tmpdir: module.save_to_path(tmpdir) - module2 = MultiAgentRLModule.from_checkpoint(tmpdir) + module2 = MultiRLModule.from_checkpoint(tmpdir) check(module.get_state(), module2.get_state()) self.assertEqual(module.keys(), module2.keys()) self.assertEqual(module.keys(), {"test2", DEFAULT_MODULE_ID}) @@ -216,9 +211,9 @@ def test_save_to_path_and_from_checkpoint(self): ) # Check that - after adding a module - the checkpoint is correct. 
with tempfile.TemporaryDirectory() as tmpdir: - tmpdir = "/tmp/test_marl_module" + tmpdir = "/tmp/test_multi_rl_module" module.save_to_path(tmpdir) - module2 = MultiAgentRLModule.from_checkpoint(tmpdir) + module2 = MultiRLModule.from_checkpoint(tmpdir) check(module.get_state(), module2.get_state()) self.assertEqual(module.keys(), module2.keys()) self.assertEqual(module.keys(), {"test2", "test3", DEFAULT_MODULE_ID}) diff --git a/rllib/core/rl_module/tests/test_rl_module_specs.py b/rllib/core/rl_module/tests/test_rl_module_specs.py index a90ac507c7f0f..3c479eea4ced9 100644 --- a/rllib/core/rl_module/tests/test_rl_module_specs.py +++ b/rllib/core/rl_module/tests/test_rl_module_specs.py @@ -2,10 +2,10 @@ import gymnasium as gym import torch -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) from ray.rllib.core.testing.torch.bc_module import ( DiscreteBCTorchModule, @@ -23,7 +23,7 @@ "torch": BCTorchRLModuleWithSharedGlobalEncoder, "tf2": BCTfRLModuleWithSharedGlobalEncoder, } -CUSTOM_MARL_MODULES = { +CUSTOM_MULTI_RL_MODULES = { "torch": BCTorchMultiAgentModuleWithSharedEncoder, "tf2": BCTfMultiAgentModuleWithSharedEncoder, } @@ -31,10 +31,10 @@ class TestRLModuleSpecs(unittest.TestCase): def test_single_agent_spec(self): - """Tests RLlib's default SingleAgentRLModuleSpec.""" + """Tests RLlib's default RLModuleSpec.""" env = gym.make("CartPole-v1") for module_class in MODULES: - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=module_class, observation_space=env.observation_space, action_space=env.action_space, @@ -50,16 +50,16 @@ def test_multi_agent_spec(self): for module_class in MODULES: module_specs = {} for i in range(num_agents): - module_specs[f"module_{i}"] = SingleAgentRLModuleSpec( + module_specs[f"module_{i}"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space, action_space=env.action_space, model_config_dict={"fcnet_hiddens": [32 * (i + 1)]}, ) - spec = MultiAgentRLModuleSpec(module_specs=module_specs) + spec = MultiRLModuleSpec(module_specs=module_specs) module = spec.build() - self.assertIsInstance(module, MultiAgentRLModule) + self.assertIsInstance(module, MultiRLModule) def test_customized_multi_agent_module(self): """Tests creating a customized MARL BC module that owns a shared encoder.""" @@ -70,13 +70,13 @@ def test_customized_multi_agent_module(self): # TODO (Kourosh): add tf support for fw in ["torch"]: - marl_module_cls = CUSTOM_MARL_MODULES[fw] + multi_rl_module_cls = CUSTOM_MULTI_RL_MODULES[fw] rl_module_cls = CUSTOM_MODULES[fw] - spec = MultiAgentRLModuleSpec( - marl_module_class=marl_module_cls, + spec = MultiRLModuleSpec( + multi_rl_module_class=multi_rl_module_cls, module_specs={ - "agent_1": SingleAgentRLModuleSpec( + "agent_1": RLModuleSpec( module_class=rl_module_cls, observation_space=gym.spaces.Dict( { @@ -91,7 +91,7 @@ def test_customized_multi_agent_module(self): action_space=gym.spaces.Discrete(action_dims[0]), model_config_dict={"fcnet_hiddens": [128]}, ), - "agent_2": SingleAgentRLModuleSpec( + "agent_2": RLModuleSpec( module_class=rl_module_cls, observation_space=gym.spaces.Dict( { @@ -119,30 +119,30 @@ def test_customized_multi_agent_module(self): self.assertTrue(torch.allclose(model["agent_2"].encoder[0].bias, foo)) def 
test_get_spec_from_module_multi_agent(self): - """Tests wether MultiAgentRLModuleSpec.from_module() works.""" + """Tests whether MultiRLModuleSpec.from_module() works.""" env = gym.make("CartPole-v1") num_agents = 2 for module_class in MODULES: module_specs = {} for i in range(num_agents): - module_specs[f"module_{i}"] = SingleAgentRLModuleSpec( + module_specs[f"module_{i}"] = RLModuleSpec( module_class=module_class, observation_space=env.observation_space, action_space=env.action_space, model_config_dict={"fcnet_hiddens": [32 * (i + 1)]}, ) - spec = MultiAgentRLModuleSpec(module_specs=module_specs) + spec = MultiRLModuleSpec(module_specs=module_specs) module = spec.build() - spec_from_module = MultiAgentRLModuleSpec.from_module(module) + spec_from_module = MultiRLModuleSpec.from_module(module) self.assertEqual(spec, spec_from_module) def test_get_spec_from_module_single_agent(self): - """Tests wether SingleAgentRLModuleSpec.from_module() works.""" + """Tests whether RLModuleSpec.from_module() works.""" env = gym.make("CartPole-v1") for module_class in MODULES: - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=module_class, observation_space=env.observation_space, action_space=env.action_space, @@ -150,53 +150,51 @@ def test_get_spec_from_module_single_agent(self): ) module = spec.build() - spec_from_module = SingleAgentRLModuleSpec.from_module(module) + spec_from_module = RLModuleSpec.from_module(module) self.assertEqual(spec, spec_from_module) def test_update_specs(self): - """Tests wether SingleAgentRLModuleSpec.update() works.""" + """Tests whether RLModuleSpec.update() works.""" env = gym.make("CartPole-v0") - # Test if SingleAgentRLModuleSpec.update() works. - module_spec_1 = SingleAgentRLModuleSpec( + # Test if RLModuleSpec.update() works. + module_spec_1 = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=env.observation_space, action_space=env.action_space, model_config_dict="Update me!", ) - module_spec_2 = SingleAgentRLModuleSpec( - model_config_dict={"fcnet_hiddens": [32]} - ) + module_spec_2 = RLModuleSpec(model_config_dict={"fcnet_hiddens": [32]}) self.assertEqual(module_spec_1.model_config_dict, "Update me!") module_spec_1.update(module_spec_2) self.assertEqual(module_spec_1.model_config_dict, {"fcnet_hiddens": [32]}) def test_update_specs_multi_agent(self): - """Test if updating a SingleAgentRLModuleSpec in MultiAgentRLModuleSpec works. + """Test if updating an RLModuleSpec in MultiRLModuleSpec works. This tests if we can update a `model_config_dict` field through different kinds of updates: - - Create a SingleAgentRLModuleSpec and update its model_config_dict. - - Create two MultiAgentRLModuleSpecs and update the first one with the + - Create an RLModuleSpec and update its model_config_dict. + - Create two MultiRLModuleSpecs and update the first one with the second one without overwriting it. - - Check if the updated MultiAgentRLModuleSpec does not(!) have the + - Check if the updated MultiRLModuleSpec does not(!) have the updated model_config_dict. - - Create two MultiAgentRLModuleSpecs and update the first one with the + - Create two MultiRLModuleSpecs and update the first one with the second one with overwriting it. - - Check if the updated MultiAgentRLModuleSpec has(!) the updated + - Check if the updated MultiRLModuleSpec has(!) the updated model_config_dict. """ env = gym.make("CartPole-v0") - # Test if SingleAgentRLModuleSpec.update() works. - module_spec_1 = SingleAgentRLModuleSpec( + # Test if RLModuleSpec.update() works. 
+ module_spec_1 = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space="Do not update me!", action_space=env.action_space, model_config_dict="Update me!", ) - module_spec_2 = SingleAgentRLModuleSpec( + module_spec_2 = RLModuleSpec( model_config_dict={"fcnet_hiddens": [32]}, ) @@ -210,23 +208,23 @@ def test_update_specs_multi_agent(self): ) # Redefine module_spec_1 for following tests. - module_spec_1 = SingleAgentRLModuleSpec( + module_spec_1 = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space="Do not update me!", action_space=env.action_space, model_config_dict="Update me!", ) - marl_spec_1 = MultiAgentRLModuleSpec( - marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + marl_spec_1 = MultiRLModuleSpec( + multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, module_specs={"agent_1": module_spec_1}, ) - marl_spec_2 = MultiAgentRLModuleSpec( - marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + marl_spec_2 = MultiRLModuleSpec( + multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, module_specs={"agent_1": module_spec_2}, ) - # Test if updating MultiAgentRLModuleSpec with overwriting works. This means + # Test if updating MultiRLModuleSpec with overwriting works. This means # that the single agent specs should be overwritten self.assertEqual( marl_spec_1.module_specs["agent_1"].model_config_dict, "Update me!" @@ -234,10 +232,10 @@ def test_update_specs_multi_agent(self): marl_spec_1.update(marl_spec_2, override=True) self.assertEqual(marl_spec_1.module_specs["agent_1"], module_spec_2) - # Test if updating MultiAgentRLModuleSpec without overwriting works. This + # Test if updating MultiRLModuleSpec without overwriting works. This # means that the single agent specs should not be overwritten - marl_spec_3 = MultiAgentRLModuleSpec( - marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + marl_spec_3 = MultiRLModuleSpec( + multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, module_specs={"agent_1": module_spec_1}, ) @@ -251,15 +249,15 @@ def test_update_specs_multi_agent(self): marl_spec_3.module_specs["agent_1"].observation_space, "Do not update me!" ) - # Test if updating with an additional SingleAgentRLModuleSpec works. - module_spec_3 = SingleAgentRLModuleSpec( + # Test if updating with an additional RLModuleSpec works. 
+ module_spec_3 = RLModuleSpec( module_class=DiscreteBCTorchModule, observation_space=env.observation_space, action_space=env.action_space, model_config_dict="I'm new!", ) - marl_spec_3 = MultiAgentRLModuleSpec( - marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + marl_spec_3 = MultiRLModuleSpec( + multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, module_specs={"agent_2": module_spec_3}, ) self.assertEqual(marl_spec_1.module_specs.get("agent_2"), None) diff --git a/rllib/core/testing/bc_algorithm.py b/rllib/core/testing/bc_algorithm.py index db2ae109752c0..8f5c3bdbf50fb 100644 --- a/rllib/core/testing/bc_algorithm.py +++ b/rllib/core/testing/bc_algorithm.py @@ -11,7 +11,7 @@ from ray.rllib.core.testing.torch.bc_learner import BCTorchLearner from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule from ray.rllib.core.testing.tf.bc_learner import BCTfLearner -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import override from ray.rllib.utils.typing import ResultDict @@ -22,9 +22,9 @@ def __init__(self, algo_class=None): def get_default_rl_module_spec(self): if self.framework_str == "torch": - return SingleAgentRLModuleSpec(module_class=DiscreteBCTorchModule) + return RLModuleSpec(module_class=DiscreteBCTorchModule) elif self.framework_str == "tf2": - return SingleAgentRLModuleSpec(module_class=DiscreteBCTFModule) + return RLModuleSpec(module_class=DiscreteBCTFModule) def get_default_learner_class(self): if self.framework_str == "torch": diff --git a/rllib/core/testing/testing_learner.py b/rllib/core/testing/testing_learner.py index 057784a300afa..d9796f2a69028 100644 --- a/rllib/core/testing/testing_learner.py +++ b/rllib/core/testing/testing_learner.py @@ -5,14 +5,14 @@ from ray.rllib.algorithms.algorithm_config import AlgorithmConfig from ray.rllib.core import DEFAULT_MODULE_ID from ray.rllib.core.learner.learner import Learner -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleSpec, +from ray.rllib.core.rl_module.multi_rl_module import ( + MultiRLModule, + MultiRLModuleSpec, ) -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.annotations import override from ray.rllib.utils.numpy import convert_to_numpy -from ray.rllib.utils.typing import RLModuleSpec +from ray.rllib.utils.typing import RLModuleSpecType class BaseTestingAlgorithmConfig(AlgorithmConfig): @@ -33,7 +33,7 @@ def get_default_learner_class(self) -> Type["Learner"]: raise ValueError(f"Unsupported framework: {self.framework_str}") @override(AlgorithmConfig) - def get_default_rl_module_spec(self) -> "RLModuleSpec": + def get_default_rl_module_spec(self) -> "RLModuleSpecType": if self.framework_str == "tf2": from ray.rllib.core.testing.tf.bc_module import DiscreteBCTFModule @@ -45,7 +45,7 @@ def get_default_rl_module_spec(self) -> "RLModuleSpec": else: raise ValueError(f"Unsupported framework: {self.framework_str}") - spec = SingleAgentRLModuleSpec( + spec = RLModuleSpec( module_class=cls, model_config_dict={"fcnet_hiddens": [32]}, ) @@ -53,8 +53,8 @@ def get_default_rl_module_spec(self) -> "RLModuleSpec": if self.is_multi_agent(): # TODO (Kourosh): Make this more multi-agent for example with policy ids # "1" and "2". 
- return MultiAgentRLModuleSpec( - marl_module_class=MultiAgentRLModule, + return MultiRLModuleSpec( + multi_rl_module_class=MultiRLModule, module_specs={DEFAULT_MODULE_ID: spec}, ) else: diff --git a/rllib/core/testing/tests/test_bc_algorithm.py b/rllib/core/testing/tests/test_bc_algorithm.py index fe798a4a48463..9403e183eda34 100644 --- a/rllib/core/testing/tests/test_bc_algorithm.py +++ b/rllib/core/testing/tests/test_bc_algorithm.py @@ -12,8 +12,8 @@ BCTfRLModuleWithSharedGlobalEncoder, BCTfMultiAgentModuleWithSharedEncoder, ) -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.core.testing.bc_algorithm import BCConfigTest from ray.rllib.utils.test_utils import framework_iterator from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole @@ -75,23 +75,23 @@ def test_bc_algorithm_marl(self): elif fw == "tf2": assert isinstance(rl_module, DiscreteBCTFModule) - def test_bc_algorithm_w_custom_marl_module(self): + def test_bc_algorithm_w_custom_multi_rl_module(self): """Tests the independent multi-agent case with shared encoders.""" policies = {"policy_1", "policy_2"} for fw in ["torch"]: if fw == "torch": - spec = MultiAgentRLModuleSpec( - marl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, - module_specs=SingleAgentRLModuleSpec( + spec = MultiRLModuleSpec( + multi_rl_module_class=BCTorchMultiAgentModuleWithSharedEncoder, + module_specs=RLModuleSpec( module_class=BCTorchRLModuleWithSharedGlobalEncoder ), ) else: - spec = MultiAgentRLModuleSpec( - marl_module_class=BCTfMultiAgentModuleWithSharedEncoder, - module_specs=SingleAgentRLModuleSpec( + spec = MultiRLModuleSpec( + multi_rl_module_class=BCTfMultiAgentModuleWithSharedEncoder, + module_specs=RLModuleSpec( module_class=BCTfRLModuleWithSharedGlobalEncoder ), ) diff --git a/rllib/core/testing/tf/bc_module.py b/rllib/core/testing/tf/bc_module.py index 1c606d71ccc90..db276631e7f3b 100644 --- a/rllib/core/testing/tf/bc_module.py +++ b/rllib/core/testing/tf/bc_module.py @@ -4,10 +4,7 @@ from ray.rllib.core.columns import Columns from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleConfig, -) +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig from ray.rllib.core.rl_module.tf.tf_rl_module import TfRLModule from ray.rllib.models.tf.tf_distributions import TfCategorical from ray.rllib.utils.annotations import override @@ -123,8 +120,8 @@ def _common_forward(self, batch): return {Columns.ACTION_DIST_INPUTS: action_logits} -class BCTfMultiAgentModuleWithSharedEncoder(MultiAgentRLModule): - def __init__(self, config: MultiAgentRLModuleConfig) -> None: +class BCTfMultiAgentModuleWithSharedEncoder(MultiRLModule): + def __init__(self, config: MultiRLModuleConfig) -> None: super().__init__(config) def setup(self): diff --git a/rllib/core/testing/torch/bc_module.py b/rllib/core/testing/torch/bc_module.py index 5d17a27316d82..dd616b046a10d 100644 --- a/rllib/core/testing/torch/bc_module.py +++ b/rllib/core/testing/torch/bc_module.py @@ -3,10 +3,7 @@ from ray.rllib.core.columns import Columns from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig from 
ray.rllib.models.torch.torch_distributions import TorchCategorical -from ray.rllib.core.rl_module.marl_module import ( - MultiAgentRLModule, - MultiAgentRLModuleConfig, -) +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule, MultiRLModuleConfig from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule from ray.rllib.core.models.specs.typing import SpecType from ray.rllib.utils.annotations import override @@ -130,8 +127,8 @@ def _common_forward(self, batch): return {Columns.ACTION_DIST_INPUTS: action_logits} -class BCTorchMultiAgentModuleWithSharedEncoder(MultiAgentRLModule): - def __init__(self, config: MultiAgentRLModuleConfig) -> None: +class BCTorchMultiAgentModuleWithSharedEncoder(MultiRLModule): + def __init__(self, config: MultiRLModuleConfig) -> None: super().__init__(config) def setup(self): diff --git a/rllib/core/testing/utils.py b/rllib/core/testing/utils.py index 1e32dfd2e18f1..ca4878aba0b6d 100644 --- a/rllib/core/testing/utils.py +++ b/rllib/core/testing/utils.py @@ -18,7 +18,7 @@ def add_module_to_learner_or_learner_group( ): learner_group_or_learner.add_module( module_id=module_id, - module_spec=config.get_marl_module_spec(env=env).module_specs[ + module_spec=config.get_multi_rl_module_spec(env=env).module_specs[ DEFAULT_MODULE_ID ], ) diff --git a/rllib/env/env_runner_group.py b/rllib/env/env_runner_group.py index a644835e4c004..30b81ce5fa0b5 100644 --- a/rllib/env/env_runner_group.py +++ b/rllib/env/env_runner_group.py @@ -28,7 +28,7 @@ ) from ray.rllib.core.learner import LearnerGroup from ray.rllib.core.rl_module import validate_module_id -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.evaluation.rollout_worker import RolloutWorker from ray.rllib.utils.actor_manager import RemoteCallResults from ray.rllib.env.base_env import BaseEnv @@ -293,14 +293,14 @@ def _get_spaces_from_remote_worker(self): # Generic EnvRunner. else: remote_spaces = self.foreach_worker( - lambda worker: worker.marl_module.foreach_module( + lambda worker: worker.multi_rl_module.foreach_module( lambda mid, m: ( mid, m.config.observation_space, m.config.action_space, ), ) - if hasattr(worker, "marl_module") + if hasattr(worker, "multi_rl_module") else [ ( DEFAULT_POLICY_ID, @@ -655,7 +655,7 @@ def add_policy( Callable[[PolicyID, Optional[SampleBatchType]], bool], ] ] = None, - module_spec: Optional[SingleAgentRLModuleSpec] = None, + module_spec: Optional[RLModuleSpec] = None, # Deprecated. workers: Optional[List[Union[EnvRunner, ActorHandle]]] = DEPRECATED_VALUE, ) -> None: diff --git a/rllib/env/multi_agent_env_runner.py b/rllib/env/multi_agent_env_runner.py index 56e97b4bc83a2..2d3ba49027237 100644 --- a/rllib/env/multi_agent_env_runner.py +++ b/rllib/env/multi_agent_env_runner.py @@ -13,7 +13,7 @@ COMPONENT_RL_MODULE, ) from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.env.env_context import EnvContext from ray.rllib.env.env_runner import EnvRunner from ray.rllib.env.multi_agent_env import MultiAgentEnv @@ -273,7 +273,7 @@ def _sample_timesteps( ) self._cached_to_module = None - # MARLModule forward pass: Explore or not. + # MultiRLModule forward pass: Explore or not. 
if explore: env_steps_lifetime = self.metrics.peek( NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0 ) @@ -470,7 +470,7 @@ def _sample_episodes( shared_data=_shared_data, ) - # MARLModule forward pass: Explore or not. + # MultiRLModule forward pass: Explore or not. if explore: env_steps_lifetime = self.metrics.peek( NUM_ENV_STEPS_SAMPLED_LIFETIME, default=0 @@ -755,7 +755,7 @@ def get_checkpointable_components(self): def assert_healthy(self): """Checks that self.__init__() has been completed properly. - Ensures that the instances has a `MultiAgentRLModule` and an + Ensures that the instance has a `MultiRLModule` and an environment defined. Raises: @@ -855,7 +855,7 @@ def _make_module(self): for mid, o in self._env_to_module.observation_space.spaces.items() }, ) - ma_rlm_spec: MultiAgentRLModuleSpec = self.config.get_marl_module_spec( + ma_rlm_spec: MultiRLModuleSpec = self.config.get_multi_rl_module_spec( policy_dict=policy_dict, # Built only a light version of the module in sampling and inference. inference_only=True, diff --git a/rllib/env/single_agent_env_runner.py b/rllib/env/single_agent_env_runner.py index 7dc8b33d0ce31..81814b5af76f7 100644 --- a/rllib/env/single_agent_env_runner.py +++ b/rllib/env/single_agent_env_runner.py @@ -16,7 +16,7 @@ DEFAULT_MODULE_ID, ) from ray.rllib.core.columns import Columns -from ray.rllib.core.rl_module.rl_module import RLModule, SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleSpec from ray.rllib.env.env_context import EnvContext from ray.rllib.env.env_runner import EnvRunner from ray.rllib.env.single_agent_episode import SingleAgentEpisode @@ -99,7 +99,7 @@ def __init__(self, config: AlgorithmConfig, **kwargs): # Create our own instance of the (single-agent) `RLModule` (which # the needs to be weight-synched) each iteration. try: - module_spec: SingleAgentRLModuleSpec = self.config.rl_module_spec + module_spec: RLModuleSpec = self.config.rl_module_spec module_spec.observation_space = self._env_to_module.observation_space module_spec.action_space = self.env.single_action_space if module_spec.model_config_dict is None: @@ -107,7 +107,7 @@ # Only load a light version of the module, if available. This is useful # if the the module has target or critic networks not needed in sampling # or inference. - # TODO (simon): Once we use `get_marl_module_spec` here, we can remove + # TODO (simon): Once we use `get_multi_rl_module_spec` here, we can remove # this line here as the function takes care of this flag. module_spec.inference_only = True self.module: RLModule = module_spec.build() @@ -742,7 +742,7 @@ def get_checkpointable_components(self): def assert_healthy(self): """Checks that self.__init__() has been completed properly. - Ensures that the instances has a `MultiAgentRLModule` and an + Ensures that the instance has a `MultiRLModule` and an environment defined. 
Raises: diff --git a/rllib/evaluation/rollout_worker.py b/rllib/evaluation/rollout_worker.py index 851a320fceaf8..f29fe7d0528c1 100644 --- a/rllib/evaluation/rollout_worker.py +++ b/rllib/evaluation/rollout_worker.py @@ -32,7 +32,7 @@ maybe_get_filters_for_syncing, ) from ray.rllib.core.rl_module import validate_module_id -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.base_env import BaseEnv, convert_to_base_env from ray.rllib.env.env_context import EnvContext from ray.rllib.env.env_runner import EnvRunner @@ -517,8 +517,9 @@ def wrap(env): self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter) - # if RLModule API is enabled, marl_module_spec holds the specs of the RLModules - self.marl_module_spec = None + # If RLModule API is enabled, multi_rl_module_spec holds the specs of the + # RLModules. + self.multi_rl_module_spec = None self._update_policy_map(policy_dict=self.policy_dict) # Update Policy's view requirements from Model, only if Policy directly @@ -1074,7 +1075,7 @@ def add_policy( policies_to_train: Optional[ Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]] ] = None, - module_spec: Optional[SingleAgentRLModuleSpec] = None, + module_spec: Optional[RLModuleSpec] = None, ) -> Policy: """Adds a new policy to this RolloutWorker. @@ -1674,7 +1675,7 @@ def _update_policy_map( policy_dict: MultiAgentPolicyConfigDict, policy: Optional[Policy] = None, policy_states: Optional[Dict[PolicyID, PolicyState]] = None, - single_agent_rl_module_spec: Optional[SingleAgentRLModuleSpec] = None, + single_agent_rl_module_spec: Optional[RLModuleSpec] = None, ) -> None: """Updates the policy map (and other stuff) on this worker. @@ -1683,7 +1684,7 @@ def _update_policy_map( with the postprocessed observation_spaces. 2. It updates the policy_specs with the complete algorithm_config (merged with the policy_spec's config). - 3. If needed it will update the self.marl_module_spec on this worker + 3. If needed it will update the self.multi_rl_module_spec on this worker 3. It updates the policy map with the new policies 4. It updates the filter dict 5. It calls the on_create_policy() hook of the callbacks on the newly added @@ -1693,8 +1694,8 @@ def _update_policy_map( policy_dict: The policy dict to update the policy map with. policy: The policy to update the policy map with. policy_states: The policy states to update the policy map with. - single_agent_rl_module_spec: The SingleAgentRLModuleSpec to add to the - MultiAgentRLModuleSpec. If None, the config's + single_agent_rl_module_spec: The RLModuleSpec to add to the + MultiRLModuleSpec. If None, the config's `get_default_rl_module_spec` method's output will be used to create the policy with. """ @@ -1703,23 +1704,23 @@ def _update_policy_map( # merge configs. Also updates the preprocessor dict. 
updated_policy_dict = self._get_complete_policy_specs_dict(policy_dict) - # Use the updated policy dict to create the marl_module_spec if necessary + # Use the updated policy dict to create the multi_rl_module_spec if necessary if self.config.enable_rl_module_and_learner: - spec = self.config.get_marl_module_spec( + spec = self.config.get_multi_rl_module_spec( policy_dict=updated_policy_dict, single_agent_rl_module_spec=single_agent_rl_module_spec, ) - if self.marl_module_spec is None: - # this is the first time, so we should create the marl_module_spec - self.marl_module_spec = spec + if self.multi_rl_module_spec is None: + # this is the first time, so we should create the multi_rl_module_spec + self.multi_rl_module_spec = spec else: # This is adding a new policy, so we need call add_modules on the # module_specs of returned spec. - self.marl_module_spec.add_modules(spec.module_specs) + self.multi_rl_module_spec.add_modules(spec.module_specs) - # Add __marl_module_spec key into the config so that the policy can access - # it. - updated_policy_dict = self._update_policy_dict_with_marl_module( + # Add `__multi_rl_module_spec` key into the config so that the policy can + # access it. + updated_policy_dict = self._update_policy_dict_with_multi_rl_module( updated_policy_dict ) @@ -1801,11 +1802,11 @@ def _get_complete_policy_specs_dict( return updated_policy_dict - def _update_policy_dict_with_marl_module( + def _update_policy_dict_with_multi_rl_module( self, policy_dict: MultiAgentPolicyConfigDict ) -> MultiAgentPolicyConfigDict: for name, policy_spec in policy_dict.items(): - policy_spec.config["__marl_module_spec"] = self.marl_module_spec + policy_spec.config["__multi_rl_module_spec"] = self.multi_rl_module_spec return policy_dict def _build_policy_map( diff --git a/rllib/evaluation/tests/test_env_runner_v2.py b/rllib/evaluation/tests/test_env_runner_v2.py index a05e2591be770..3f3c785e251ac 100644 --- a/rllib/evaluation/tests/test_env_runner_v2.py +++ b/rllib/evaluation/tests/test_env_runner_v2.py @@ -14,8 +14,8 @@ from ray.rllib.policy.sample_batch import convert_ma_batch_to_sample_batch # The new RLModule / Learner API -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.env.tests.test_multi_agent_env import BasicMultiAgent from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule @@ -170,10 +170,10 @@ def compute_actions( # placeholder RLModule, since the compute_actions() method is # directly overridden in the policy class. 
.rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "pol1": SingleAgentRLModuleSpec(module_class=RandomRLModule), - "pol2": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "pol1": RLModuleSpec(module_class=RandomRLModule), + "pol2": RLModuleSpec(module_class=RandomRLModule), } ), ) @@ -237,10 +237,10 @@ def __init__(self, *args, **kwargs): count_steps_by="agent_steps", ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "one": SingleAgentRLModuleSpec(module_class=RandomRLModule), - "two": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "one": RLModuleSpec(module_class=RandomRLModule), + "two": RLModuleSpec(module_class=RandomRLModule), } ), ) @@ -334,10 +334,10 @@ def test_start_episode(self): count_steps_by="agent_steps", ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "one": SingleAgentRLModuleSpec(module_class=RandomRLModule), - "two": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "one": RLModuleSpec(module_class=RandomRLModule), + "two": RLModuleSpec(module_class=RandomRLModule), } ), ) @@ -390,10 +390,10 @@ def test_env_runner_output(self): count_steps_by="agent_steps", ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "one": SingleAgentRLModuleSpec(module_class=RandomRLModule), - "two": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "one": RLModuleSpec(module_class=RandomRLModule), + "two": RLModuleSpec(module_class=RandomRLModule), } ), ) @@ -449,10 +449,10 @@ def on_episode_end( count_steps_by="agent_steps", ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "one": SingleAgentRLModuleSpec(module_class=RandomRLModule), - "two": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "one": RLModuleSpec(module_class=RandomRLModule), + "two": RLModuleSpec(module_class=RandomRLModule), } ), ) diff --git a/rllib/examples/catalogs/custom_action_distribution.py b/rllib/examples/catalogs/custom_action_distribution.py index 6eb8aa234ea41..7aef265927b91 100644 --- a/rllib/examples/catalogs/custom_action_distribution.py +++ b/rllib/examples/catalogs/custom_action_distribution.py @@ -11,7 +11,7 @@ from ray.rllib.algorithms.ppo.ppo import PPOConfig from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.models.distributions import Distribution from ray.rllib.models.torch.torch_distributions import TorchDeterministic @@ -77,7 +77,7 @@ def get_action_dist_cls(self, framework): algo = ( PPOConfig() .environment("CartPole-v1") - .rl_module(rl_module_spec=SingleAgentRLModuleSpec(catalog_class=CustomPPOCatalog)) + .rl_module(rl_module_spec=RLModuleSpec(catalog_class=CustomPPOCatalog)) .build() ) results = algo.train() diff --git a/rllib/examples/catalogs/mobilenet_v2_encoder.py b/rllib/examples/catalogs/mobilenet_v2_encoder.py index 2358711a96675..119d9f6442ef6 100644 --- a/rllib/examples/catalogs/mobilenet_v2_encoder.py +++ b/rllib/examples/catalogs/mobilenet_v2_encoder.py @@ -14,7 +14,7 @@ from ray.rllib.algorithms.ppo.ppo import PPOConfig from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from 
ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import ( MobileNetV2EncoderConfig, MOBILENET_INPUT_SHAPE, @@ -48,11 +48,7 @@ def _get_encoder_config( enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True, ) - .rl_module( - rl_module_spec=SingleAgentRLModuleSpec( - catalog_class=MobileNetEnhancedPPOCatalog - ) - ) + .rl_module(rl_module_spec=RLModuleSpec(catalog_class=MobileNetEnhancedPPOCatalog)) .env_runners(num_env_runners=0) # The following training settings make it so that a training iteration is very # quick. This is just for the sake of this example. PPO will not learn properly diff --git a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py index bf6889113fed3..e38b46309dc1d 100644 --- a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py +++ b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py @@ -3,7 +3,7 @@ This example: - Runs a multi-agent `Pendulum-v1` experiment with >= 2 policies. - - Saves a checkpoint of the `MultiAgentRLModule` used every `--checkpoint-freq` + - Saves a checkpoint of the `MultiRLModule` used every `--checkpoint-freq` iterations. - Stops the experiments after the agents reach a combined return of -800. - Picks the best checkpoint by combined return and restores policy 0 from it. @@ -40,7 +40,7 @@ ----------------- You should expect a reward of -400.0 eventually being achieved by a simple single PPO policy (no tuning, just using RLlib's default settings). In the -second run of the experiment, the MARL module weights for policy 0 are +second run of the experiment, the MultiRLModule weights for policy 0 are restored from the checkpoint of the first run. The reward for a single agent should be -400.0 again, but the training time should be shorter (around 30 iterations instead of 190). @@ -48,7 +48,7 @@ import os from ray.air.constants import TRAINING_ITERATION -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, @@ -137,10 +137,10 @@ module_spec.load_state_path = p_0_module_state_path module_specs["p0"] = module_spec - # Create the MARL module. - marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs) - # Define the MARL module in the base config. - base_config.rl_module(rl_module_spec=marl_module_spec) + # Create the MultiRLModule. + multi_rl_module_spec = MultiRLModuleSpec(module_specs=module_specs) + # Define the MultiRLModule in the base config. + base_config.rl_module(rl_module_spec=multi_rl_module_spec) # We need to re-register the environment when starting a new run. register_env( "env", @@ -153,5 +153,5 @@ TRAINING_ITERATION: 30, } - # Run the experiment again with the restored MARL module. + # Run the experiment again with the restored MultiRLModule. 
run_rllib_example_script_experiment(base_config, args, stop=stop) diff --git a/rllib/examples/custom_recurrent_rnn_tokenizer.py b/rllib/examples/custom_recurrent_rnn_tokenizer.py index fe1d6c225f216..f41c432a35191 100644 --- a/rllib/examples/custom_recurrent_rnn_tokenizer.py +++ b/rllib/examples/custom_recurrent_rnn_tokenizer.py @@ -18,7 +18,7 @@ from ray.tune.registry import register_env from ray.rllib.examples.envs.classes.repeat_after_me_env import RepeatAfterMeEnv from ray.rllib.examples.envs.classes.repeat_initial_obs_env import RepeatInitialObsEnv -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.policy.sample_batch import SampleBatch from dataclasses import dataclass from ray.rllib.core.models.specs.specs_dict import SpecDict @@ -183,9 +183,7 @@ def get_tokenizer_config( entropy_coeff=0.001, vf_loss_coeff=1e-5, ) - .rl_module( - rl_module_spec=SingleAgentRLModuleSpec(catalog_class=CustomPPOCatalog) - ) + .rl_module(rl_module_spec=RLModuleSpec(catalog_class=CustomPPOCatalog)) # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0. .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0"))) ) diff --git a/rllib/examples/learners/ppo_load_rl_modules.py b/rllib/examples/learners/ppo_load_rl_modules.py index 6bfd9acfbb120..6f175c404ee9e 100644 --- a/rllib/examples/learners/ppo_load_rl_modules.py +++ b/rllib/examples/learners/ppo_load_rl_modules.py @@ -10,7 +10,7 @@ from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.algorithms.ppo.tf.ppo_tf_rl_module import PPOTfRLModule from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec def _parse_args(): @@ -37,7 +37,7 @@ def _parse_args(): # where you had enabled checkpointing, the learner api and the rl module api module_class = PPOTfRLModule if args.framework == "tf2" else PPOTorchRLModule env = gym.make("CartPole-v1") - module_to_load = SingleAgentRLModuleSpec( + module_to_load = RLModuleSpec( module_class=module_class, model_config_dict={"fcnet_hiddens": [32]}, catalog_class=PPOCatalog, @@ -49,7 +49,7 @@ def _parse_args(): module_to_load.save_to_path(CHECKPOINT_DIR) # Create a module spec to load the checkpoint - module_to_load_spec = SingleAgentRLModuleSpec( + module_to_load_spec = RLModuleSpec( module_class=module_class, model_config_dict={"fcnet_hiddens": [32]}, catalog_class=PPOCatalog, diff --git a/rllib/examples/learners/train_w_bc_finetune_w_ppo.py b/rllib/examples/learners/train_w_bc_finetune_w_ppo.py index 6e2778e0d1997..2a5a2baae7309 100644 --- a/rllib/examples/learners/train_w_bc_finetune_w_ppo.py +++ b/rllib/examples/learners/train_w_bc_finetune_w_ppo.py @@ -17,7 +17,7 @@ from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog from ray.rllib.core.models.base import ACTOR, ENCODER_OUT -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.metrics import ( EPISODE_RETURN_MEAN, ENV_RUNNER_RESULTS, @@ -66,7 +66,7 @@ def forward( def train_ppo_module_with_bc_finetune( - dataset: ray.data.Dataset, ppo_module_spec: SingleAgentRLModuleSpec + dataset: ray.data.Dataset, ppo_module_spec: RLModuleSpec ) -> str: """Train an Actor with BC finetuning on dataset. 
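(Reviewer note: for orientation, the example scripts above now build their single-module specs roughly as follows; this is only a sketch, and the checkpoint path is a placeholder.)

import gymnasium as gym

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = gym.make("CartPole-v1")  # illustrative spaces only

# Renamed spec (formerly SingleAgentRLModuleSpec) for a PPO torch module.
module_spec = RLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [32]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

# Built RLModules can be checkpointed and later restored through a fresh spec.
module.save_to_path("/tmp/ppo_rl_module_checkpoint")  # placeholder path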
@@ -107,7 +107,7 @@ def train_ppo_module_with_bc_finetune( def train_ppo_agent_from_checkpointed_module( - module_spec_from_ckpt: SingleAgentRLModuleSpec, + module_spec_from_ckpt: RLModuleSpec, ) -> float: """Trains a checkpointed RLModule using PPO. @@ -156,7 +156,7 @@ def train_ppo_agent_from_checkpointed_module( ds = ray.data.read_json("s3://rllib-oss-tests/cartpole-expert") - module_spec = SingleAgentRLModuleSpec( + module_spec = RLModuleSpec( module_class=PPOTorchRLModule, observation_space=GYM_ENV.observation_space, action_space=GYM_ENV.action_space, diff --git a/rllib/examples/multi_agent/custom_heuristic_policy.py b/rllib/examples/multi_agent/custom_heuristic_policy.py index 9da722e047fe5..3fbb760b310f4 100644 --- a/rllib/examples/multi_agent/custom_heuristic_policy.py +++ b/rllib/examples/multi_agent/custom_heuristic_policy.py @@ -42,8 +42,8 @@ """ from ray.rllib.algorithms.ppo import PPOConfig -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule from ray.rllib.utils.test_utils import ( @@ -88,10 +88,10 @@ policies_to_train=["learnable_policy"], ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "learnable_policy": SingleAgentRLModuleSpec(), - "random": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "learnable_policy": RLModuleSpec(), + "random": RLModuleSpec(module_class=RandomRLModule), } ), ) diff --git a/rllib/examples/multi_agent/pettingzoo_independent_learning.py b/rllib/examples/multi_agent/pettingzoo_independent_learning.py index d7fb8312a2bb0..9233b9fe15850 100644 --- a/rllib/examples/multi_agent/pettingzoo_independent_learning.py +++ b/rllib/examples/multi_agent/pettingzoo_independent_learning.py @@ -51,8 +51,8 @@ from pettingzoo.sisl import waterworld_v4 -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, @@ -98,8 +98,8 @@ ) .rl_module( model_config_dict={"vf_share_layers": True}, - rl_module_spec=MultiAgentRLModuleSpec( - module_specs={p: SingleAgentRLModuleSpec() for p in policies}, + rl_module_spec=MultiRLModuleSpec( + module_specs={p: RLModuleSpec() for p in policies}, ), ) ) diff --git a/rllib/examples/multi_agent/pettingzoo_parameter_sharing.py b/rllib/examples/multi_agent/pettingzoo_parameter_sharing.py index 55a1c173b3154..477f8cad2bab9 100644 --- a/rllib/examples/multi_agent/pettingzoo_parameter_sharing.py +++ b/rllib/examples/multi_agent/pettingzoo_parameter_sharing.py @@ -50,8 +50,8 @@ """ from pettingzoo.sisl import waterworld_v4 -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv from 
ray.rllib.utils.test_utils import ( add_rllib_example_script_args, @@ -96,8 +96,8 @@ vf_loss_coeff=0.005, ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( - module_specs={"p0": SingleAgentRLModuleSpec()}, + rl_module_spec=MultiRLModuleSpec( + module_specs={"p0": RLModuleSpec()}, ), ) ) diff --git a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py index 7285b9bedc6dc..e2c8bb9a4ffb9 100644 --- a/rllib/examples/multi_agent/pettingzoo_shared_value_function.py +++ b/rllib/examples/multi_agent/pettingzoo_shared_value_function.py @@ -1,6 +1,6 @@ msg = """ This script is NOT yet ready, but will be available soon at this location. It will -feature a MultiAgentRLModule with one shared value function and n policy heads for +feature a MultiRLModule with one shared value function and n policy heads for cooperative multi-agent learning. """ diff --git a/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py b/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py index c09a297627f13..813dde80205e6 100644 --- a/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py +++ b/rllib/examples/multi_agent/rock_paper_scissors_heuristic_vs_learned.py @@ -34,8 +34,8 @@ from ray.air.constants import TRAINING_ITERATION from ray.rllib.connectors.env_to_module import FlattenObservations -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv from ray.rllib.utils.metrics import ( ENV_RUNNER_RESULTS, @@ -119,19 +119,19 @@ "max_seq_len": 15, "vf_share_layers": True, }, - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "always_same": SingleAgentRLModuleSpec( + "always_same": RLModuleSpec( module_class=AlwaysSameHeuristicRLM, observation_space=gym.spaces.Discrete(4), action_space=gym.spaces.Discrete(3), ), - "beat_last": SingleAgentRLModuleSpec( + "beat_last": RLModuleSpec( module_class=BeatLastHeuristicRLM, observation_space=gym.spaces.Discrete(4), action_space=gym.spaces.Discrete(3), ), - "learned": SingleAgentRLModuleSpec(), + "learned": RLModuleSpec(), } ), ) diff --git a/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py b/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py index e3e75c9906924..6c7f70c7248b1 100644 --- a/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py +++ b/rllib/examples/multi_agent/rock_paper_scissors_learned_vs_learned.py @@ -16,8 +16,8 @@ from pettingzoo.classic import rps_v2 from ray.rllib.connectors.env_to_module import FlattenObservations -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv from ray.rllib.utils.test_utils import ( add_rllib_example_script_args, @@ -77,10 +77,10 @@ "max_seq_len": 15, "vf_share_layers": True, }, - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "p0": SingleAgentRLModuleSpec(), - "p1": SingleAgentRLModuleSpec(), + 
"p0": RLModuleSpec(), + "p1": RLModuleSpec(), } ), ) diff --git a/rllib/examples/multi_agent/self_play_league_based_with_open_spiel.py b/rllib/examples/multi_agent/self_play_league_based_with_open_spiel.py index a262d4db9e455..c4fe7e30e8147 100644 --- a/rllib/examples/multi_agent/self_play_league_based_with_open_spiel.py +++ b/rllib/examples/multi_agent/self_play_league_based_with_open_spiel.py @@ -35,8 +35,8 @@ import ray from ray.air.constants import TRAINING_ITERATION -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv from ray.rllib.examples.multi_agent.utils import ( @@ -129,7 +129,7 @@ def _get_multi_agent(): if args.enable_new_api_stack: policies = names spec = { - mid: SingleAgentRLModuleSpec( + mid: RLModuleSpec( module_class=( RandomRLModule if mid in ["main_exploiter_0", "league_exploiter_0"] @@ -205,9 +205,7 @@ def _get_multi_agent(): policies_to_train=["main"], ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( - module_specs=_get_multi_agent()["spec"] - ), + rl_module_spec=MultiRLModuleSpec(module_specs=_get_multi_agent()["spec"]), ) ) diff --git a/rllib/examples/multi_agent/self_play_with_open_spiel.py b/rllib/examples/multi_agent/self_play_with_open_spiel.py index a12b5df0c7cd2..3c01d25a244c0 100644 --- a/rllib/examples/multi_agent/self_play_with_open_spiel.py +++ b/rllib/examples/multi_agent/self_play_with_open_spiel.py @@ -23,8 +23,8 @@ import numpy as np from ray.air.constants import TRAINING_ITERATION -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec from ray.rllib.env.utils import try_import_pyspiel, try_import_open_spiel from ray.rllib.env.wrappers.open_spiel import OpenSpielEnv from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule @@ -164,10 +164,10 @@ def policy_mapping_fn(agent_id, episode, worker, **kwargs): "fcnet_hiddens": [512, 512], "uses_new_env_runners": args.enable_new_api_stack, }, - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "main": SingleAgentRLModuleSpec(), - "random": SingleAgentRLModuleSpec(module_class=RandomRLModule), + "main": RLModuleSpec(), + "random": RLModuleSpec(module_class=RandomRLModule), } ), ) diff --git a/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py b/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py index 2c94358222905..d4ac3d24c36cb 100644 --- a/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py +++ b/rllib/examples/multi_agent/two_step_game_with_grouped_agents.py @@ -41,8 +41,8 @@ """ from ray.rllib.connectors.env_to_module import FlattenObservations -from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.examples.envs.classes.two_step_game import TwoStepGameWithGroupedAgents from ray.rllib.utils.test_utils 
import ( add_rllib_example_script_args, @@ -79,9 +79,9 @@ policy_mapping_fn=lambda aid, *a, **kw: "p0", ) .rl_module( - rl_module_spec=MultiAgentRLModuleSpec( + rl_module_spec=MultiRLModuleSpec( module_specs={ - "p0": SingleAgentRLModuleSpec(), + "p0": RLModuleSpec(), }, ) ) diff --git a/rllib/examples/multi_agent/utils/self_play_callback.py b/rllib/examples/multi_agent/utils/self_play_callback.py index db5962ee1f4b6..7544d2ff8efe6 100644 --- a/rllib/examples/multi_agent/utils/self_play_callback.py +++ b/rllib/examples/multi_agent/utils/self_play_callback.py @@ -3,7 +3,7 @@ import numpy as np from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS @@ -61,7 +61,7 @@ def agent_to_module_mapping_fn(agent_id, episode, **kwargs): main_module = algorithm.get_module("main") algorithm.add_module( module_id=new_module_id, - module_spec=SingleAgentRLModuleSpec.from_module(main_module), + module_spec=RLModuleSpec.from_module(main_module), module_state=main_module.get_state(), new_agent_to_module_mapping_fn=agent_to_module_mapping_fn, ) diff --git a/rllib/examples/multi_agent/utils/self_play_league_based_callback.py b/rllib/examples/multi_agent/utils/self_play_league_based_callback.py index e920793db31d4..86f27d8b0778c 100644 --- a/rllib/examples/multi_agent/utils/self_play_league_based_callback.py +++ b/rllib/examples/multi_agent/utils/self_play_league_based_callback.py @@ -5,7 +5,7 @@ import numpy as np from ray.rllib.algorithms.callbacks import DefaultCallbacks -from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.core.rl_module.rl_module import RLModuleSpec from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS @@ -162,26 +162,26 @@ def agent_to_module_mapping_fn(agent_id, episode, **kwargs): else: return main - marl_module = local_worker.module - main_module = marl_module["main"] + multi_rl_module = local_worker.module + main_module = multi_rl_module["main"] # Set the weights of the new polic(y/ies). if initializing_exploiters: main_state = main_module.get_state() - marl_module["main_0"].set_state(main_state) - marl_module["league_exploiter_1"].set_state(main_state) - marl_module["main_exploiter_1"].set_state(main_state) + multi_rl_module["main_0"].set_state(main_state) + multi_rl_module["league_exploiter_1"].set_state(main_state) + multi_rl_module["main_exploiter_1"].set_state(main_state) # We need to sync the just copied local weights to all the # remote workers and remote Learner workers as well. 
             algorithm.env_runner_group.sync_weights(
                 policies=["main_0", "league_exploiter_1", "main_exploiter_1"]
             )
-            algorithm.learner_group.set_weights(marl_module.get_state())
+            algorithm.learner_group.set_weights(multi_rl_module.get_state())
         else:
             algorithm.add_module(
                 module_id=new_mod_id,
-                module_spec=SingleAgentRLModuleSpec.from_module(main_module),
-                module_state=marl_module[module_id].get_state(),
+                module_spec=RLModuleSpec.from_module(main_module),
+                module_state=multi_rl_module[module_id].get_state(),
             )

         algorithm.env_runner_group.foreach_worker(
diff --git a/rllib/examples/rl_modules/action_masking_rlm.py b/rllib/examples/rl_modules/action_masking_rlm.py
index ddd37d2010967..8a7ba3d6e966e 100644
--- a/rllib/examples/rl_modules/action_masking_rlm.py
+++ b/rllib/examples/rl_modules/action_masking_rlm.py
@@ -66,7 +66,7 @@
 from gymnasium.spaces import Box, Discrete

 from ray.rllib.algorithms.ppo import PPOConfig
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv
 from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
     ActionMaskingTorchRLModule,
@@ -116,7 +116,7 @@
             },
             # We need to explicitly specify here RLModule to use and
             # the catalog needed to build it.
-            rl_module_spec=SingleAgentRLModuleSpec(
+            rl_module_spec=RLModuleSpec(
                 module_class=ActionMaskingTorchRLModule,
             ),
         )
diff --git a/rllib/examples/rl_modules/autoregressive_actions_rlm.py b/rllib/examples/rl_modules/autoregressive_actions_rlm.py
index 03a6e12dbd3d8..0f0c8510f9ad3 100644
--- a/rllib/examples/rl_modules/autoregressive_actions_rlm.py
+++ b/rllib/examples/rl_modules/autoregressive_actions_rlm.py
@@ -41,7 +41,7 @@

 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.core.models.catalog import Catalog
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.examples.envs.classes.correlated_actions_env import CorrelatedActionsEnv
 from ray.rllib.examples.rl_modules.classes.autoregressive_actions_rlm import (
     AutoregressiveActionsTorchRLM,
@@ -83,7 +83,7 @@
             },
             # We need to explicitly specify here RLModule to use and
             # the catalog needed to build it.
-            rl_module_spec=SingleAgentRLModuleSpec(
+            rl_module_spec=RLModuleSpec(
                 module_class=AutoregressiveActionsTorchRLM,
                 catalog_class=Catalog,
             ),
diff --git a/rllib/examples/rl_modules/classes/mobilenet_rlm.py b/rllib/examples/rl_modules/classes/mobilenet_rlm.py
index f31ae4f1c6d46..49878ec555f9d 100644
--- a/rllib/examples/rl_modules/classes/mobilenet_rlm.py
+++ b/rllib/examples/rl_modules/classes/mobilenet_rlm.py
@@ -10,7 +10,7 @@
 from ray.rllib.algorithms.ppo.ppo import PPOConfig
 from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
 from ray.rllib.core.models.configs import MLPHeadConfig
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.examples.envs.classes.random_env import RandomEnv
 from ray.rllib.models.torch.torch_distributions import TorchCategorical
 from ray.rllib.examples._old_api_stack.models.mobilenet_v2_encoder import (
@@ -57,9 +57,7 @@ def setup(self):
 config = (
     PPOConfig()
     .api_stack(enable_rl_module_and_learner=True)
-    .rl_module(
-        rl_module_spec=SingleAgentRLModuleSpec(module_class=MobileNetTorchPPORLModule)
-    )
+    .rl_module(rl_module_spec=RLModuleSpec(module_class=MobileNetTorchPPORLModule))
     .environment(
         RandomEnv,
         env_config={
diff --git a/rllib/examples/rl_modules/custom_cnn_rl_module.py b/rllib/examples/rl_modules/custom_cnn_rl_module.py
index a171f83467d2f..a340a1ccb86c1 100644
--- a/rllib/examples/rl_modules/custom_cnn_rl_module.py
+++ b/rllib/examples/rl_modules/custom_cnn_rl_module.py
@@ -54,7 +54,7 @@
 """
 import gymnasium as gym

-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
 from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
 from ray.rllib.utils.test_utils import (
@@ -96,7 +96,7 @@
     )
     .rl_module(
         # Plug-in our custom RLModule class.
-        rl_module_spec=SingleAgentRLModuleSpec(
+        rl_module_spec=RLModuleSpec(
             module_class=TinyAtariCNN,
             # Feel free to specify your own `model_config_dict` settings below.
             # The `model_config_dict` defined here will be available inside your
diff --git a/rllib/examples/rl_modules/custom_lstm_rl_module.py b/rllib/examples/rl_modules/custom_lstm_rl_module.py
index 5612df47104d5..3d3cf285eb192 100644
--- a/rllib/examples/rl_modules/custom_lstm_rl_module.py
+++ b/rllib/examples/rl_modules/custom_lstm_rl_module.py
@@ -41,7 +41,7 @@

 You should see the following output (during the experiment) in your console:
 """
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentStatelessCartPole
 from ray.rllib.examples.rl_modules.classes.lstm_containing_rlm import (
@@ -87,7 +87,7 @@
     )
     .rl_module(
         # Plug-in our custom RLModule class.
-        rl_module_spec=SingleAgentRLModuleSpec(
+        rl_module_spec=RLModuleSpec(
             module_class=LSTMContainingRLModule,
             # Feel free to specify your own `model_config_dict` settings below.
             # The `model_config_dict` defined here will be available inside your
diff --git a/rllib/examples/rl_modules/pretraining_single_agent_training_multi_agent.py b/rllib/examples/rl_modules/pretraining_single_agent_training_multi_agent.py
index 6d8f7b3d84f4a..b49761efb2fb6 100644
--- a/rllib/examples/rl_modules/pretraining_single_agent_training_multi_agent.py
+++ b/rllib/examples/rl_modules/pretraining_single_agent_training_multi_agent.py
@@ -28,8 +28,8 @@
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
 from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
-from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
+from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
 from ray.rllib.utils.test_utils import (
     add_rllib_example_script_args,
@@ -99,12 +99,12 @@
     # Get the checkpoint path.
     module_chkpt_path = results.get_best_result().checkpoint.path

-    # Create a new MARL Module using the pre-trained module for policy 0.
+    # Create a new MultiRLModule using the pre-trained module for policy 0.
     env = gym.make("CartPole-v1")
     module_specs = {}
     module_class = PPOTorchRLModule
     for i in range(args.num_agents):
-        module_specs[f"policy_{i}"] = SingleAgentRLModuleSpec(
+        module_specs[f"policy_{i}"] = RLModuleSpec(
             module_class=PPOTorchRLModule,
             observation_space=env.observation_space,
             action_space=env.action_space,
@@ -113,7 +113,7 @@
         )

     # Swap in the pre-trained module for policy 0.
-    module_specs["policy_0"] = SingleAgentRLModuleSpec(
+    module_specs["policy_0"] = RLModuleSpec(
         module_class=PPOTorchRLModule,
         observation_space=env.observation_space,
         action_space=env.action_space,
@@ -122,7 +122,7 @@
         # Note, we load here the module directly from the checkpoint.
         load_state_path=module_chkpt_path,
     )
-    marl_module_spec = MultiAgentRLModuleSpec(module_specs=module_specs)
+    multi_rl_module_spec = MultiRLModuleSpec(module_specs=module_specs)

     # Register our environment with tune if we use multiple agents.
     register_env(
@@ -136,7 +136,7 @@
         .environment(
             "multi-agent-carpole-env" if args.num_agents > 0 else "CartPole-v1"
         )
-        .rl_module(rl_module_spec=marl_module_spec)
+        .rl_module(rl_module_spec=multi_rl_module_spec)
     )

     # Restore the user's stopping criteria for the training run.
diff --git a/rllib/offline/offline_prelearner.py b/rllib/offline/offline_prelearner.py
index 446bcc69bc5fe..c7e7a79cce11b 100644
--- a/rllib/offline/offline_prelearner.py
+++ b/rllib/offline/offline_prelearner.py
@@ -7,7 +7,7 @@
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner import Learner
-from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
+from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
 from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch
 from ray.rllib.utils.annotations import (
@@ -79,7 +79,7 @@ def __init__(
         config: AlgorithmConfig,
         learner: Union[Learner, list[ActorHandle]],
         locality_hints: Optional[list] = None,
-        module_spec: Optional[MultiAgentRLModuleSpec] = None,
+        module_spec: Optional[MultiRLModuleSpec] = None,
         module_state: Optional[Dict[ModuleID, Any]] = None,
     ):

@@ -111,7 +111,7 @@ def __init__(
             # Then choose a learner randomly.
             self._learner = learner[random.randint(0, len(learner) - 1)]
             self.learner_is_remote = True
-            # Build the module from spec. Note, this will be a MARL module.
+            # Build the module from spec. Note, this will be a MultiRLModule.
             self._module = module_spec.build()
             self._module.set_state(module_state)
             # Build the learner connector pipeline.
@@ -190,7 +190,7 @@ def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]
         return {"batch": [batch]}

     def _should_module_be_updated(self, module_id, multi_agent_batch=None):
-        """Checks which modules in a MARL module should be updated."""
+        """Checks which modules in a MultiRLModule should be updated."""
         if not self._policies_to_train:
             # In case of no update information, the module is updated.
             return True
diff --git a/rllib/policy/policy.py b/rllib/policy/policy.py
index 048d0d97455ba..8ddeda0668b05 100644
--- a/rllib/policy/policy.py
+++ b/rllib/policy/policy.py
@@ -405,7 +405,7 @@ def make_rl_module(self) -> "RLModule":
         Otherwise, RLlib will error out.
         """
         # if imported on top it creates circular dependency
-        from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+        from ray.rllib.core.rl_module.rl_module import RLModuleSpec

         if self.__policy_id is None:
             raise ValueError(
@@ -414,17 +414,17 @@ def make_rl_module(self) -> "RLModule":
                 "bug, please file a github issue."
             )

-        spec = self.config["__marl_module_spec"]
-        if isinstance(spec, SingleAgentRLModuleSpec):
+        spec = self.config["__multi_rl_module_spec"]
+        if isinstance(spec, RLModuleSpec):
             module = spec.build()
         else:
             # filter the module_spec to only contain the policy_id of this policy
             marl_spec = type(spec)(
-                marl_module_class=spec.marl_module_class,
+                multi_rl_module_class=spec.multi_rl_module_class,
                 module_specs={self.__policy_id: spec.module_specs[self.__policy_id]},
             )
-            marl_module = marl_spec.build()
-            module = marl_module[self.__policy_id]
+            multi_rl_module = marl_spec.build()
+            module = multi_rl_module[self.__policy_id]

         return module

diff --git a/rllib/tuned_examples/impala/pong_impala.py b/rllib/tuned_examples/impala/pong_impala.py
index 72e635d46d99f..2b50c8242a55e 100644
--- a/rllib/tuned_examples/impala/pong_impala.py
+++ b/rllib/tuned_examples/impala/pong_impala.py
@@ -1,7 +1,7 @@
 import gymnasium as gym

 from ray.rllib.algorithms.impala import IMPALAConfig
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
 from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
 from ray.rllib.utils.metrics import (
@@ -68,9 +68,7 @@ def _env_creator(cfg):
     )
     .rl_module(
         rl_module_spec=(
-            SingleAgentRLModuleSpec(module_class=TinyAtariCNN)
-            if args.use_tiny_cnn
-            else None
+            RLModuleSpec(module_class=TinyAtariCNN) if args.use_tiny_cnn else None
         ),
         model_config_dict=(
             {
diff --git a/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py b/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py
index e66f153874f5a..b7d5274aa4a3f 100644
--- a/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py
+++ b/rllib/tuned_examples/impala/pong_impala_pb2_hyperopt.py
@@ -1,7 +1,7 @@
 import gymnasium as gym

 from ray.rllib.algorithms.impala import IMPALAConfig
-from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
+from ray.rllib.core.rl_module.rl_module import RLModuleSpec
 from ray.rllib.env.wrappers.atari_wrappers import wrap_atari_for_new_api_stack
 from ray.rllib.examples.rl_modules.classes.tiny_atari_cnn_rlm import TinyAtariCNN
 from ray.rllib.utils.metrics import (
@@ -99,9 +99,7 @@ def _env_creator(cfg):
     )
     .rl_module(
         rl_module_spec=(
-            SingleAgentRLModuleSpec(module_class=TinyAtariCNN)
-            if args.use_tiny_cnn
-            else None
+            RLModuleSpec(module_class=TinyAtariCNN) if args.use_tiny_cnn else None
         ),
         model_config_dict=(
             {
diff --git a/rllib/utils/typing.py b/rllib/utils/typing.py
index 7ef1089be21b4..7c86ad47844b9 100644
--- a/rllib/utils/typing.py
+++ b/rllib/utils/typing.py
@@ -19,8 +19,8 @@
 from ray.rllib.utils.annotations import OldAPIStack

 if TYPE_CHECKING:
-    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
-    from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
+    from ray.rllib.core.rl_module.rl_module import RLModuleSpec
+    from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
     from ray.rllib.env.env_context import EnvContext
     from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
     from ray.rllib.env.single_agent_episode import SingleAgentEpisode
@@ -52,7 +52,7 @@
 NetworkType = Union["torch.nn.Module", "tf.keras.Model"]

 # An RLModule spec (single-agent or multi-agent).
-RLModuleSpec = Union["SingleAgentRLModuleSpec", "MultiAgentRLModuleSpec"]
+RLModuleSpecType = Union["RLModuleSpec", "MultiRLModuleSpec"]

 # A state dict of an RLlib component (e.g. EnvRunner, Learner, RLModule).
 StateDict = Dict[str, Any]