From 324ebf4fb7d1c8d0f508057294d625542a22bcae Mon Sep 17 00:00:00 2001 From: Nolwen Date: Fri, 20 Dec 2024 09:56:40 +0100 Subject: [PATCH] Implement a GNN PPO for ray-rllib We follow the same guidelines as for the sb3 wrapper: - GNN based on pytorch-geometric - Feature extraction via GNN + reduction layer to a fixed number of feature - Observation = Graph or dict whose values contains at least one Graph - Action masks are taken into account if available - User must use GraphPPO instead of PPO as algorithm: GraphPPO overrides PPO to change the way obs is converted to pytorch format Worth noticing: - We use the old api stack as the RLlib wrapper is currently using it - For graph observations, the model is gnn extractor followed by a FullyConnectedNetwork - For dict of graphs (and other) observations, the model is - preprocess obs by using gnn features extractor for graph components - apply to the prepreocessed obs a ComplexInputNetwork - action masking is automatically activated according to domain class (not UnrestrictedActions) and algo class, as it was already coded in RayRLlib wrapper. The algo to be used is still GraphPPO as masking is managed by a custom model at RayRLlib wrapper level. --- .../hub/solver/ray_rllib/custom_models.py | 45 +- skdecide/hub/solver/ray_rllib/gnn/__init__.py | 0 .../ray_rllib/gnn/algorithms/__init__.py | 1 + .../ray_rllib/gnn/algorithms/ppo/__init__.py | 0 .../ray_rllib/gnn/algorithms/ppo/ppo.py | 21 + .../gnn/algorithms/ppo/ppo_torch_policy.py | 7 + .../solver/ray_rllib/gnn/models/__init__.py | 0 .../ray_rllib/gnn/models/torch/__init__.py | 0 .../gnn/models/torch/complex_input_net.py | 66 ++ .../solver/ray_rllib/gnn/models/torch/gnn.py | 77 ++ .../solver/ray_rllib/gnn/policy/__init__.py | 0 .../ray_rllib/gnn/policy/sample_batch.py | 116 +++ .../gnn/policy/torch_graph_policy.py | 16 + .../ray_rllib/gnn/policy/torch_mixins.py | 26 + .../hub/solver/ray_rllib/gnn/torch_layers.py | 114 +++ .../solver/ray_rllib/gnn/utils/__init__.py | 0 .../ray_rllib/gnn/utils/spaces/__init__.py | 0 .../ray_rllib/gnn/utils/spaces/space_utils.py | 129 ++++ .../solver/ray_rllib/gnn/utils/torch_utils.py | 142 ++++ skdecide/hub/solver/ray_rllib/ray_rllib.py | 328 ++++++--- tests/solvers/python/test_gnn_ray_rllib.py | 675 ++++++++++++++++++ 21 files changed, 1669 insertions(+), 94 deletions(-) create mode 100644 skdecide/hub/solver/ray_rllib/gnn/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/models/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/models/torch/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/policy/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/torch_layers.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/utils/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/utils/spaces/__init__.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py create mode 100644 skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py create mode 100644 tests/solvers/python/test_gnn_ray_rllib.py diff --git a/skdecide/hub/solver/ray_rllib/custom_models.py b/skdecide/hub/solver/ray_rllib/custom_models.py index af93f358a3..44462eedc2 100644 --- a/skdecide/hub/solver/ray_rllib/custom_models.py +++ b/skdecide/hub/solver/ray_rllib/custom_models.py @@ -1,4 +1,5 @@ from gymnasium.spaces import flatten_space +from ray.rllib import SampleBatch from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as TFFullyConnectedNetwork from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.fcnet import ( @@ -9,6 +10,15 @@ from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, unbatch from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN +from skdecide.hub.solver.ray_rllib.gnn.models.torch.complex_input_net import ( + GraphComplexInputNetwork, +) +from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + is_graph_dict_multiinput_space, + is_graph_dict_space, +) + tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() @@ -98,8 +108,20 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name, **k self.action_ids_shifted = torch.arange(1, num_outputs + 1, dtype=torch.int64) self.true_obs_space = model_config["custom_model_config"]["true_obs_space"] - self.pred_action_embed_model = TorchFullyConnectedNetwork( - flatten_space(self.true_obs_space), + if is_graph_dict_space(self.true_obs_space): + pred_action_embed_model_cls = GnnBasedModel + self.obs_with_graph = True + embed_model_obs_space = self.true_obs_space + elif is_graph_dict_multiinput_space(self.true_obs_space): + pred_action_embed_model_cls = GraphComplexInputNetwork + self.obs_with_graph = True + embed_model_obs_space = self.true_obs_space + else: + pred_action_embed_model_cls = TorchFullyConnectedNetwork + self.obs_with_graph = False + embed_model_obs_space = flatten_space(self.true_obs_space) + self.pred_action_embed_model = pred_action_embed_model_cls( + embed_model_obs_space, action_space, model_config["custom_model_config"]["action_embed_size"], model_config, @@ -115,16 +137,21 @@ def forward(self, input_dict, state, seq_lens): # Extract the available actions mask tensor from the observation. valid_avail_actions_mask = input_dict["obs"]["valid_avail_actions_mask"] - # Unbatch true observations before flattening them - unbatched_true_obs = unbatch(input_dict["obs"]["true_obs"]) + if self.obs_with_graph: + # use directly the obs (already converted at proper format by custom `convert_to_torch_tensor`) + embed_model_obs = input_dict["obs"]["true_obs"] + else: + # Unbatch true observations before flattening them + embed_model_obs = torch.stack( + [ + flatten_to_single_ndarray(o) + for o in unbatch(input_dict["obs"]["true_obs"]) + ] + ) # Compute the predicted action embedding pred_action_embed, _ = self.pred_action_embed_model( - { - "obs": torch.stack( - [flatten_to_single_ndarray(o) for o in unbatched_true_obs] - ) - } + SampleBatch({SampleBatch.OBS: embed_model_obs}) ) # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the diff --git a/skdecide/hub/solver/ray_rllib/gnn/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py new file mode 100644 index 0000000000..d5d5c40f48 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py @@ -0,0 +1 @@ +from .ppo.ppo import GraphPPO diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py new file mode 100644 index 0000000000..75bb88bcb2 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py @@ -0,0 +1,21 @@ +from typing import Optional + +from ray.rllib import Policy +from ray.rllib.algorithms import PPO, AlgorithmConfig + +from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import ( + PPOTorchGraphPolicy, +) + + +class GraphPPO(PPO): + @classmethod + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[type[Policy]]: + if config["framework"] == "torch": + return PPOTorchGraphPolicy + elif config["framework"] == "tf": + raise NotImplementedError("GraphPPO implemented for torch context") + else: + raise NotImplementedError("GraphPPO implemented for torch context") diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py new file mode 100644 index 0000000000..673c3a165a --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py @@ -0,0 +1,7 @@ +from ray.rllib.algorithms.ppo import PPOTorchPolicy + +from skdecide.hub.solver.ray_rllib.gnn.policy.torch_graph_policy import TorchGraphPolicy + + +class PPOTorchGraphPolicy(TorchGraphPolicy, PPOTorchPolicy): + ... diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py new file mode 100644 index 0000000000..a3d1af2260 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py @@ -0,0 +1,66 @@ +import gymnasium as gym +from ray.rllib import SampleBatch +from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.typing import TensorType +from torch import nn + +from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + is_graph_dict_space, +) + + +class GraphComplexInputNetwork(TorchModelV2, nn.Module): + def __init__(self, obs_space, action_space, num_outputs, model_config, name): + if not model_config.get("_disable_preprocessor_api"): + raise ValueError( + "This model is intent to be used only when preprocessors are disabled." + ) + if not isinstance(obs_space, gym.spaces.Dict): + raise ValueError( + "This model is intent to be used only on dict observation space." + ) + + nn.Module.__init__(self) + super().__init__(obs_space, action_space, num_outputs, model_config, name) + + self.gnn = nn.ModuleDict() + post_graph_obs_subspaces = dict(obs_space.spaces) + for k, subspace in obs_space.spaces.items(): + if is_graph_dict_space(subspace): + submodel_name = f"gnn_{k}" + gnn = GnnBasedModel( + obs_space=subspace, + action_space=action_space, + num_outputs=None, + model_config=model_config, + framework="torch", + name=submodel_name, + ) + self.add_module(submodel_name, gnn) + self.gnn[k] = gnn + post_graph_obs_subspaces[k] = gnn.features_space + + post_graph_obs_space = gym.spaces.Dict(post_graph_obs_subspaces) + self.post_graph_model = ComplexInputNetwork( + obs_space=post_graph_obs_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=model_config, + name="post_graph_model", + ) + + def forward(self, input_dict: SampleBatch, state, seq_lens): + post_graph_input_dict = input_dict.copy(shallow=True) + obs = input_dict["obs"] + post_graph_obs = dict(obs) + for k, gnn in self.gnn.items(): + post_graph_obs[k] = gnn(SampleBatch({SampleBatch.OBS: obs[k]})) + post_graph_input_dict["obs"] = post_graph_obs + return self.post_graph_model( + input_dict=post_graph_input_dict, state=state, seq_lens=seq_lens + ) + + def value_function(self) -> TensorType: + return self.post_graph_model.value_function() diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py new file mode 100644 index 0000000000..7f235126db --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py @@ -0,0 +1,77 @@ +from collections import defaultdict +from typing import Optional + +import gymnasium as gym +import numpy as np +from ray.rllib.models.torch.fcnet import FullyConnectedNetwork +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.typing import ModelConfigDict +from torch import nn + +from skdecide.hub.solver.ray_rllib.gnn.torch_layers import GraphFeaturesExtractor +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + convert_dict_space_to_graph_space, + is_graph_dict_space, +) + + +class GnnBasedModel(TorchModelV2, nn.Module): + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: Optional[int], + model_config: ModelConfigDict, + name: str, + **kw, + ): + nn.Module.__init__(self) + super().__init__(obs_space, action_space, num_outputs, model_config, name) + + # config for custom model + custom_config = defaultdict( + lambda: None, # will return None for missing keys + model_config.get("custom_model_config", {}), + ) + + # gnn-based feature extractor + features_extractor_kwargs = custom_config.get("features_extractor", {}) + assert is_graph_dict_space( + obs_space + ), f"{self.__class__.__name__} can only be applied to Graph observation spaces." + graph_observation_space = convert_dict_space_to_graph_space(obs_space) + self.features_extractor = GraphFeaturesExtractor( + observation_space=graph_observation_space, **features_extractor_kwargs + ) + self.features_space = gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(self.features_extractor.features_dim,) + ) + + if num_outputs is None: + # only feature extraction (e.g. to be used by GraphComplexInputNetwork) + self.num_outputs = self.features_extractor.features_dim + self.pred_action_embed_model = None + else: + # fully connected network + self.pred_action_embed_model = FullyConnectedNetwork( + obs_space=self.features_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=model_config, + name=name + "_pred_action_embed", + ) + + def forward(self, input_dict, state, seq_lens): + obs = input_dict["obs"] + features = self.features_extractor(obs) + if self.pred_action_embed_model is None: + return features, state + else: + return self.pred_action_embed_model( + input_dict={"obs": features}, + state=state, + seq_lens=seq_lens, + ) + + def value_function(self): + return self.pred_action_embed_model.value_function() diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/policy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py b/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py new file mode 100644 index 0000000000..ffe6266b88 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py @@ -0,0 +1,116 @@ +from numbers import Number +from typing import Any, Optional, Union + +import gymnasium as gym +import numpy as np +import tree +from ray.rllib import SampleBatch +from ray.rllib.policy.sample_batch import attempt_count_timesteps, tf, torch +from ray.rllib.utils.typing import ViewRequirementsDict + + +def _pop_graph_items( + full_dict: dict[Any, Any] +) -> dict[Any, Union[gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]]: + graph_dict = {} + for k, v in full_dict.items(): + if isinstance(v, gym.spaces.GraphInstance) or ( + isinstance(v, list) and isinstance(v[0], gym.spaces.GraphInstance) + ): + graph_dict[k] = v + for k in graph_dict: + full_dict.pop(k) + return graph_dict + + +def _split_graph_requirements( + full_dict: ViewRequirementsDict, +) -> tuple[ViewRequirementsDict, ViewRequirementsDict]: + graph_dict = {} + for k, v in full_dict.items(): + if isinstance(v.space, gym.spaces.Graph): + graph_dict[k] = v + wo_graph_dict = {k: v for k, v in full_dict.items() if k not in graph_dict} + return graph_dict, wo_graph_dict + + +class GraphSampleBatch(SampleBatch): + def __init__(self, *args, **kwargs): + """Constructs a sample batch with possibly graph obs. + + See `ray.rllib.SampleBatch` for more information. + + """ + # split graph samples from others. + dict_graphs = _pop_graph_items(kwargs) + dict_from_args = dict(*args) + dict_graphs.update(_pop_graph_items(dict_from_args)) + + super().__init__(dict_from_args, **kwargs) + super().update(dict_graphs) + + def copy(self, shallow: bool = False) -> "SampleBatch": + """Creates a deep or shallow copy of this SampleBatch and returns it. + + Args: + shallow: Whether the copying should be done shallowly. + + Returns: + A deep or shallow copy of this SampleBatch object. + """ + copy_ = dict(self) + data = tree.map_structure( + lambda v: ( + np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v + ), + copy_, + ) + copy_ = GraphSampleBatch( + data, + _time_major=self.time_major, + _zero_padded=self.zero_padded, + _max_seq_len=self.max_seq_len, + _num_grad_updates=self.num_grad_updates, + ) + copy_.set_get_interceptor(self.get_interceptor) + copy_.added_keys = self.added_keys + copy_.deleted_keys = self.deleted_keys + copy_.accessed_keys = self.accessed_keys + return copy_ + + def get_single_step_input_dict( + self, + view_requirements: ViewRequirementsDict, + index: Union[str, int] = "last", + ) -> "SampleBatch": + ( + view_requirements_graphs, + view_requirements_wo_graphs, + ) = _split_graph_requirements(view_requirements) + # w/o graphs + sample = GraphSampleBatch( + super().get_single_step_input_dict(view_requirements_wo_graphs, index) + ) + # handle graphs + last_mappings = { + SampleBatch.OBS: SampleBatch.NEXT_OBS, + SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, + SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, + } + for view_col, view_req in view_requirements_graphs.items(): + if view_req.used_for_compute_actions is False: + continue + + # Create batches of size 1 (single-agent input-dict). + data_col = view_req.data_col or view_col + if index == "last": + data_col = last_mappings.get(data_col, data_col) + if view_req.shift_from is not None: + raise NotImplementedError() + else: + sample[view_col] = self[data_col][-1:] + else: + sample[view_col] = self[data_col][ + index : index + 1 if index != -1 else None + ] + return sample diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py new file mode 100644 index 0000000000..84b599614c --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py @@ -0,0 +1,16 @@ +import functools + +from ray.rllib import SampleBatch +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 + +from skdecide.hub.solver.ray_rllib.gnn.utils.torch_utils import convert_to_torch_tensor + + +class TorchGraphPolicy(TorchPolicyV2): + def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None): + if not isinstance(postprocessed_batch, SampleBatch): + postprocessed_batch = SampleBatch(postprocessed_batch) + postprocessed_batch.set_get_interceptor( + functools.partial(convert_to_torch_tensor, device=device or self.device) + ) + return postprocessed_batch diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py new file mode 100644 index 0000000000..775bc3eee2 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py @@ -0,0 +1,26 @@ +from ray.rllib.policy.torch_mixins import ValueNetworkMixin + +from skdecide.hub.solver.ray_rllib.gnn.policy.sample_batch import GraphSampleBatch + + +class ValueNetworkGraphMixin(ValueNetworkMixin): + def __init__(self, config): + if config.get("use_gae") or config.get("vtrace"): + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + + def value(**input_dict): + input_dict = GraphSampleBatch(input_dict) + input_dict = self._lazy_tensor_dict(input_dict) + model_out, _ = self.model(input_dict) + # [0] = remove the batch dim. + return self.model.value_function()[0].item() + + # When not doing GAE, we do not require the value function's output. + else: + + def value(*args, **kwargs): + return 0.0 + + self._value = value diff --git a/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py b/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py new file mode 100644 index 0000000000..8180b35f88 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py @@ -0,0 +1,114 @@ +from typing import Any, Optional + +import gymnasium as gym +import numpy as np +import torch as th +import torch_geometric as thg +from torch import nn +from torch_geometric.nn import global_max_pool + + +class GraphFeaturesExtractor(nn.Module): + """Graph feature extractor for Graph observation spaces. + + Will chain a gnn with a reduction layer to extract a fixed number of features. + The user can specify both the gnn and reduction layer. + + By default, we use: + - gnn: a 2-layers GCN + - reduction layer: global_max_pool + linear layer + relu + + Args: + observation_space: + features_dim: Number of extracted features + - If reduction_layer_class is given, should match the output of this network. + - If reduction_layer is None, will be used by the default network as its output dimension. + gnn_out_dim: dimension of the node embedding in gnn output + - If gnn is given, should not be None and should match the output of gnn + - If gnn is not given, will be used to generate it. By default, gnn_out_dim = 2 * features_dim + gnn_class: GNN network class (for instance chosen from `torch_geometric.nn.models` used to embed the graph observations) + gnn_kwargs: used by `gnn_class.__init__()`. Without effect if `gnn_class` is None. + reduction_layer_class: network class to be plugged after the gnn to get a fixed number of features. + reduction_layer_kwargs: used by `reduction_layer_class.__init__()`. Without effect if `reduction_layer_class` is None. + + """ + + def __init__( + self, + observation_space: gym.spaces.Graph, + features_dim: int = 64, + gnn_out_dim: Optional[int] = None, + gnn_class: Optional[type[nn.Module]] = None, + gnn_kwargs: Optional[dict[str, Any]] = None, + reduction_layer_class: Optional[type[nn.Module]] = None, + reduction_layer_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__() + self.features_dim = features_dim + + if gnn_out_dim is None: + if gnn_class is None: + gnn_out_dim = 2 * features_dim + else: + raise ValueError( + "`gnn_out_dim` cannot be None if `gnn` is not None, " + "and should match `gnn` output." + ) + + if gnn_class is None: + node_features_dim = int(np.prod(observation_space.node_space.shape)) + self.gnn = thg.nn.models.GCN( + in_channels=node_features_dim, + hidden_channels=gnn_out_dim, + num_layers=2, + dropout=0.2, + ) + else: + if gnn_kwargs is None: + gnn_kwargs = {} + self.gnn = gnn_class(**gnn_kwargs) + + if reduction_layer_class is None: + self.reduction_layer = _DefaultReductionLayer( + gnn_out_dim=gnn_out_dim, features_dim=features_dim + ) + else: + if reduction_layer_kwargs is None: + reduction_layer_kwargs = {} + self.reduction_layer = reduction_layer_class(**reduction_layer_kwargs) + + def forward(self, observations) -> th.Tensor: + x, edge_index, edge_attr, batch = ( + observations.x, + observations.edge_index, + observations.edge_attr, + observations.batch, + ) + # construct edge weights, for GNNs needing it, as the first edge feature + edge_weight = edge_attr[:, 0] + h = self.gnn( + x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr + ) + embedded_observations = thg.data.Data( + x=h, edge_index=edge_index, edge_attr=edge_attr, batch=batch + ) + h = self.reduction_layer(embedded_observations=embedded_observations) + return h + + +class _DefaultReductionLayer(nn.Module): + def __init__(self, gnn_out_dim: int, features_dim: int): + super().__init__() + self.gnn_out_dim = gnn_out_dim + self.features_dim = features_dim + self.linear_layer = nn.Linear(gnn_out_dim, features_dim) + + def forward(self, embedded_observations: thg.data.Data) -> th.Tensor: + x, edge_index, batch = ( + embedded_observations.x, + embedded_observations.edge_index, + embedded_observations.batch, + ) + h = global_max_pool(x, batch) + h = self.linear_layer(h).relu() + return h diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py b/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py new file mode 100644 index 0000000000..11df126454 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py @@ -0,0 +1,129 @@ +from typing import Any, Union + +import gymnasium as gym +import numpy as np + + +def convert_graph_space_to_dict_space(space: gym.spaces.Graph) -> gym.spaces.Dict: + # artificially decide of 2 nodes and 1 edge (for dummy samples auto-generated by ray.rllib) + return gym.spaces.Dict( + dict( + nodes=repeat_space(space.node_space, n_rep=2), + edges=repeat_space(space.edge_space, n_rep=1), + edge_links=gym.spaces.Box(low=0, high=2, shape=(1, 2), dtype=np.int_), + ) + ) + + +def repeat_space_box(space: gym.spaces.Box, n_rep: int) -> gym.spaces.Box: + rep_low = np.repeat(space.low[None, :], n_rep, axis=0) + rep_high = np.repeat(space.high[None, :], n_rep, axis=0) + rep_shape = (n_rep,) + space.shape + return gym.spaces.Box( + low=rep_low, + high=rep_high, + shape=rep_shape, + dtype=space.dtype, + ) + + +def repeat_space_discrete(space: gym.spaces.Discrete, n_rep: int) -> gym.spaces.Box: + return gym.spaces.Box( + low=space.start, + high=space.start + space.n - 1, + shape=(n_rep, 1), + dtype=space.dtype, + ) + + +def repeat_space(space: Union[gym.spaces.Box, gym.spaces.Discrete], n_rep: int): + if isinstance(space, gym.spaces.Box): + return repeat_space_box(space=space, n_rep=n_rep) + elif isinstance(space, gym.spaces.Discrete): + return repeat_space_discrete(space=space, n_rep=n_rep) + else: + raise NotImplementedError() + + +def remove_first_axis_space(space: gym.spaces.Box) -> gym.spaces.Box: + return gym.spaces.Box( + low=space.low[0, :], + high=space.high[0, :], + shape=space.shape[1:], + dtype=space.dtype, + ) + + +def convert_graph_to_dict(x: gym.spaces.GraphInstance) -> dict[str, np.ndarray]: + return dict( + nodes=x.nodes, + edges=x.edges, + edge_links=x.edge_links, + ) + + +def convert_dict_space_to_graph_space(space: gym.spaces.Dict) -> gym.spaces.Graph: + return gym.spaces.Graph( + node_space=remove_first_axis_space(space.spaces["nodes"]), + edge_space=remove_first_axis_space(space.spaces["edges"]), + ) + + +def convert_dict_to_graph(x: dict[str, np.ndarray]) -> gym.spaces.GraphInstance: + return gym.spaces.GraphInstance( + nodes=x["nodes"], edges=x["edges"], edge_links=x["edge_links"] + ) + + +def is_graph_dict(x: Any) -> bool: + return ( + isinstance(x, dict) + and len(x) == 3 + and "nodes" in x + and "edges" in x + and "edge_links" in x + ) + + +def is_graph_dict_space(x: gym.spaces.Space) -> bool: + return ( + isinstance(x, gym.spaces.Dict) + and len(x.spaces) == 3 + and "nodes" in x.spaces + and "edges" in x.spaces + and "edge_links" in x.spaces + ) + + +def is_graph_dict_multiinput(x: Any) -> bool: + return isinstance(x, dict) and any([is_graph_dict(v) for v in x.values()]) + + +def is_graph_dict_multiinput_space(x: gym.spaces.Space) -> bool: + return isinstance(x, gym.spaces.Dict) and any( + [is_graph_dict_space(subspace) for subspace in x.values()] + ) + + +def is_masked_obs(x: Any) -> bool: + return ( + isinstance(x, dict) + and len(x) == 2 + and "true_obs" in x + and "valid_avail_actions_mask" in x + ) + + +def is_masked_obs_space(x: gym.spaces.Space) -> bool: + return ( + isinstance(x, gym.spaces.Dict) + and len(x.spaces) == 2 + and "true_obs" in x.spaces + and "valid_avail_actions_mask" in x.spaces + ) + + +def extract_graph_dict_from_batched_graph_dict( + batched_graph_dict: dict[str, np.ndarray], index: int +) -> dict[str, np.ndarray]: + return {k: v[index, :] for k, v in batched_graph_dict.items()} diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py b/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py new file mode 100644 index 0000000000..f74be0e97e --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py @@ -0,0 +1,142 @@ +from typing import Optional, Union + +import gymnasium as gym +import numpy as np +import torch as th +import torch_geometric as thg +from ray.rllib.utils.torch_utils import ( + convert_to_torch_tensor as convert_to_torch_tensor_original, +) +from ray.rllib.utils.typing import TensorStructType + +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + extract_graph_dict_from_batched_graph_dict, + is_graph_dict, + is_graph_dict_multiinput, + is_masked_obs, +) + + +def graph_obs_to_thg_data( + obs: gym.spaces.GraphInstance, + device: Optional[th.device] = None, + pin_memory: bool = False, +) -> thg.data.Data: + # Node features + flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1)) + x = th.tensor(flatten_node_features).float() + # Edge features + if obs.edges is None: + edge_attr = None + else: + flatten_edge_features = obs.edges.reshape((len(obs.edges), -1)) + edge_attr = th.tensor(flatten_edge_features).float() + edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1) + # thg.Data + data = thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + # Pin the tensor's memory (for faster transfer to GPU later). + if pin_memory and th.cuda.is_available(): + data.pin_memory() + + return data if device is None else data.to(device) + + +def convert_to_torch_tensor( + x: Union[ + TensorStructType, + thg.data.Data, + gym.spaces.GraphInstance, + list[gym.spaces.GraphInstance], + ], + device: Optional[str] = None, + pin_memory: bool = False, +) -> Union[TensorStructType, thg.data.Data]: + """Converts any struct to torch.Tensors. + + Args: + x: Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all leaves converted + to torch tensors. + device: The device to create the tensor on. + pin_memory: If True, will call the `pin_memory()` method on the created tensors. + + Returns: + Any: A new struct with the same structure as `x`, but with all + values converted to torch Tensor types. This does not convert possibly + nested elements that are None because torch has no representation for that. + """ + if isinstance(x, gym.spaces.GraphInstance): + return graph_obs_to_thg_data(x, device=device, pin_memory=pin_memory) + elif isinstance(x, list) and isinstance(x[0], gym.spaces.GraphInstance): + return thg.data.Batch.from_data_list( + [ + graph_obs_to_thg_data(graph, device=device, pin_memory=pin_memory) + for graph in x + ] + ) + elif isinstance(x, thg.data.Data): + return x + elif is_masked_obs(x): + return { + k: convert_to_torch_tensor(v, device=device, pin_memory=pin_memory) + for k, v in x.items() + } + elif is_graph_dict(x): + return batched_graph_dict_to_thg_data(x, device=device, pin_memory=pin_memory) + elif is_graph_dict_multiinput(x): + return { + k: convert_to_torch_tensor(v, device=device, pin_memory=pin_memory) + for k, v in x.items() + } + else: + return convert_to_torch_tensor_original( + x=x, device=device, pin_memory=pin_memory + ) + + +def graph_dict_to_thg_data( + graph_dict: dict[str, np.ndarray], + device: Optional[str] = None, + pin_memory: bool = False, +): + # Node features + flatten_node_features = graph_dict["nodes"].reshape((len(graph_dict["nodes"]), -1)) + x = th.tensor(flatten_node_features).float() + # Edge features + if graph_dict["edges"] is None: + edge_attr = None + else: + flatten_edge_features = graph_dict["edges"].reshape( + (len(graph_dict["edges"]), -1) + ) + edge_attr = th.tensor(flatten_edge_features).float() + edge_index = ( + th.tensor(graph_dict["edge_links"], dtype=th.long).t().contiguous().view(2, -1) + ) + # thg.Data + data = thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + # Pin the tensor's memory (for faster transfer to GPU later). + if pin_memory and th.cuda.is_available(): + data.pin_memory() + + return data if device is None else data.to(device) + + +def batched_graph_dict_to_thg_data( + batched_graph_dict: dict[str, np.ndarray], + device: Optional[str] = None, + pin_memory: bool = False, +): + batch_size = batched_graph_dict["nodes"].shape[0] + return thg.data.Batch.from_data_list( + [ + graph_dict_to_thg_data( + graph_dict=extract_graph_dict_from_batched_graph_dict( + batched_graph_dict=batched_graph_dict, index=index + ), + device=device, + pin_memory=pin_memory, + ) + for index in range(batch_size) + ] + ) diff --git a/skdecide/hub/solver/ray_rllib/ray_rllib.py b/skdecide/hub/solver/ray_rllib/ray_rllib.py index a91167c224..8dc9715035 100644 --- a/skdecide/hub/solver/ray_rllib/ray_rllib.py +++ b/skdecide/hub/solver/ray_rllib/ray_rllib.py @@ -14,7 +14,6 @@ CategoricalHyperparameter, FloatHyperparameter, IntegerHyperparameter, - SubBrickKwargsHyperparameter, ) from packaging.version import Version from ray.rllib.algorithms import DQN, PPO, SAC @@ -37,6 +36,12 @@ from skdecide.hub.space.gym import GymSpace from .custom_models import TFParametricActionsModel, TorchParametricActionsModel +from .gnn.models.torch.complex_input_net import GraphComplexInputNetwork +from .gnn.models.torch.gnn import GnnBasedModel +from .gnn.utils.spaces.space_utils import ( + convert_graph_space_to_dict_space, + convert_graph_to_dict, +) if TYPE_CHECKING: # imports useful only in annotations, may change according to releases @@ -99,6 +104,9 @@ class RayRLlib(Solver, Policies, Restorable, Maskable): ), ] + MASKABLE_ALGOS = ["APPO", "BC", "DQN", "Rainbow", "IMPALA", "MARWIL", "PPO"] + """The only algos being able to handle action masking in ray[rllib]==2.9.0.""" + def __init__( self, domain_factory: Callable[[], Domain], @@ -110,8 +118,8 @@ def __init__( Callable[[str, Optional["EpisodeV2"], Optional["RolloutWorker"]], str] ] = None, action_embed_sizes: Optional[dict[str, int]] = None, - config_kwargs: Optional[dict[str, Any]] = None, callback: Callable[[RayRLlib], bool] = lambda solver: False, + graph_feature_extractors_kwargs: Optional[dict[str, Any]] = None, **kwargs, ) -> None: """Initialize Ray RLlib. @@ -125,13 +133,13 @@ def __init__( policy_configs: The mapping from policy id (str) to additional config (dict) (leave default for single policy). policy_mapping_fn: The function mapping agent ids to policy ids (leave default for single policy). action_embed_sizes: The mapping from policy id (str) to action embedding size (only used with domains filtering allowed actions per state, default to 2) - config_kwargs: keyword arguments for the `AlgorithmConfigKwargs` class used to update programmatically the algorithm config. - Will be used by hyerparameters tuners like optuna. Should probably not be used directly by the user, - who could rather directly specify the correct `config`. callback: function called at each solver iteration. If returning true, the solve process stops and exit the current train iteration. However, if train_iterations > 1, another train loop will be entered after that. (One can code its callback in such a way that further training loop are stopped directly after that.) + graph_feature_extractors_kwargs: in case of graph observations, these are the kwargs to the GraphFreaturesExtractor model + used to extract features. (See skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn.GraphFeaturesExtractor) + **kwargs: used to update the algo config with kwargs automatically filled by optuna. """ Solver.__init__(self, domain_factory=domain_factory) @@ -156,6 +164,10 @@ def __init__( raise RuntimeError( "Action embed size keys must be the same as policy config keys" ) + if graph_feature_extractors_kwargs is None: + self._graph_feature_extractors_kwargs = {} + else: + self._graph_feature_extractors_kwargs = graph_feature_extractors_kwargs ray.init(ignore_reinit_error=True) self._algo_callbacks: Optional[DefaultCallbacks] = None @@ -168,16 +180,23 @@ def __init__( self._wrapped_observation_space = domain.get_observation_space() # action masking? - domain = self._domain_factory() self._action_masking = ( (not isinstance(domain, UnrestrictedActions)) and all( isinstance(agent_action_space, EnumerableSpace) for agent_action_space in self._wrapped_action_space.values() ) - and self._algo_class.__name__ - # Only the following algos handle action masking in ray[rllib]==2.9.0 - in ["APPO", "BC", "DQN", "Rainbow", "IMPALA", "MARWIL", "PPO"] + and ( + self._algo_class.__name__ in RayRLlib.MASKABLE_ALGOS + or self._algo_class.__name__ + in [f"Graph{algo_name}" for algo_name in RayRLlib.MASKABLE_ALGOS] + ) + ) + + # graph obs? + self._is_graph_obs = _is_graph_space(self._wrapped_observation_space) + self._is_graph_multiinput_obs = _is_graph_multiinput_space( + (self._wrapped_observation_space) ) # Handle kwargs (potentially generated by optuna) @@ -268,7 +287,13 @@ def _load(self, path: str): self.set_callback() # ensure putting back actual callback def _init_algo(self) -> None: + # custom model? if self._action_masking: + if self._is_graph_obs or self._is_graph_multiinput_obs: + # let the observation pass as is + self._config.experimental( + _disable_preprocessor_api=True, + ) if self._config.get("framework") not in ["tf", "tf2", "torch"]: raise RuntimeError( "Action masking (invalid action filtering) for RLlib requires TensorFlow or PyTorch to be installed" @@ -290,6 +315,37 @@ def _init_algo(self) -> None: self._config.training( model={"vf_share_layers": True}, ) + elif self._is_graph_obs: + if self._config.get("framework") not in ["torch"]: + raise RuntimeError( + "Graph observation with RLlib requires PyTorch framework." + ) + ModelCatalog.register_custom_model( + "skdecide_rllib_graph_model", + GnnBasedModel + if self._config.get("framework") == "torch" + else NotProvided, + ) + # let the observation pass as is + self._config.experimental( + _disable_preprocessor_api=True, + ) + elif self._is_graph_multiinput_obs: + if self._config.get("framework") not in ["torch"]: + raise RuntimeError( + "Graph observation with RLlib requires PyTorch framework." + ) + ModelCatalog.register_custom_model( + "skdecide_rllib_graph_multiinput_model", + GraphComplexInputNetwork + if self._config.get("framework") == "torch" + else NotProvided, + ) + # let the observation pass as is + self._config.experimental( + _disable_preprocessor_api=True, + ) + self._wrap_action = lambda action: _wrap_action( action=action, wrapped_action_space=self._wrapped_action_space ) @@ -312,23 +368,31 @@ def _init_algo(self) -> None: # Overwrite multi-agent config pol_obs_spaces = ( { - self._policy_mapping_fn(k, None, None): v.unwrapped() - for k, v in self._wrapped_observation_space.items() + self._policy_mapping_fn(agent, None, None): _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ) + for agent in self._wrapped_observation_space } if not self._action_masking else { - self._policy_mapping_fn(k, None, None): gym.spaces.Dict( + self._policy_mapping_fn(agent, None, None): gym.spaces.Dict( { - "true_obs": v.unwrapped(), + "true_obs": _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ), "valid_avail_actions_mask": gym.spaces.Box( 0, 1, - shape=(len(self._wrapped_action_space[k].get_elements()),), + shape=( + len(self._wrapped_action_space[agent].get_elements()), + ), dtype=np.int8, ), } ) - for k, v in self._wrapped_observation_space.items() + for agent in self._wrapped_observation_space } ) pol_act_spaces = { @@ -336,13 +400,14 @@ def _init_algo(self) -> None: for k, v in self._wrapped_action_space.items() } - policies = ( - { - k: (None, pol_obs_spaces[k], pol_act_spaces[k], v or {}) - for k, v in self._policy_configs.items() - } - if not self._action_masking - else { + if self._action_masking: + if self._is_graph_obs or self._is_graph_multiinput_obs: + extra_custom_model_config_kwargs = { + "features_extractor": self._graph_feature_extractors_kwargs + } + else: + extra_custom_model_config_kwargs = {} + policies = { self._policy_mapping_fn(k, None, None): ( None, pol_obs_spaces[k], @@ -357,6 +422,7 @@ def _init_algo(self) -> None: "true_obs" ], "action_embed_size": action_embed_size, + **extra_custom_model_config_kwargs, }, }, }, @@ -364,7 +430,52 @@ def _init_algo(self) -> None: ) for k, action_embed_size in self._action_embed_sizes.items() } - ) + elif self._is_graph_obs: + policies = { + self._policy_mapping_fn(k, None, None): ( + None, + pol_obs_spaces[k], + pol_act_spaces[k], + { + **(v or {}), + **{ + "model": { + "custom_model": "skdecide_rllib_graph_model", + "custom_model_config": { + "features_extractor": self._graph_feature_extractors_kwargs, # kwargs for GraphFeaturesExtractor + }, + }, + }, + }, + ) + for k, v in self._policy_configs.items() + } + elif self._is_graph_multiinput_obs: + policies = { + self._policy_mapping_fn(k, None, None): ( + None, + pol_obs_spaces[k], + pol_act_spaces[k], + { + **(v or {}), + **{ + "model": { + "custom_model": "skdecide_rllib_graph_multiinput_model", + "custom_model_config": { + "features_extractor": self._graph_feature_extractors_kwargs, + # kwargs for GraphFeaturesExtractor + }, + }, + }, + }, + ) + for k, v in self._policy_configs.items() + } + else: + policies = { + k: (None, pol_obs_spaces[k], pol_act_spaces[k], v or {}) + for k, v in self._policy_configs.items() + } self._config.multi_agent( policies=policies, policy_mapping_fn=self._policy_mapping_fn, @@ -503,80 +614,58 @@ def __init__( self, domain: D, action_masking: bool, - unwrap_spaces: bool = True, ) -> None: """Initialize AsLegacyRLlibMultiAgentEnv. # Parameters domain: The scikit-decide domain to wrap as a RLlib multi-agent environment. action_masking: Boolean specifying whether action masking is used - unwrap_spaces: Boolean specifying whether the action & observation spaces should be unwrapped. """ self._domain = domain self._action_masking = action_masking - self._unwrap_spaces = unwrap_spaces - self._wrapped_observation_space = domain.get_observation_space() self._wrapped_action_space = domain.get_action_space() - if unwrap_spaces: - if not self._action_masking: - self.observation_space = gym.spaces.Dict( - { - k: agent_observation_space.unwrapped() - for k, agent_observation_space in self._wrapped_observation_space.items() - } - ) - else: - self.observation_space = gym.spaces.Dict( - { - k: gym.spaces.Dict( - { - "true_obs": agent_observation_space.unwrapped(), - "valid_avail_actions_mask": gym.spaces.Box( - 0, - 1, - shape=( - len( - self._wrapped_action_space[k].get_elements() - ), - ), - dtype=np.int8, - ), - } - ) - for k, agent_observation_space in self._wrapped_observation_space.items() - } - ) - self.action_space = gym.spaces.Dict( + self._wrapped_observation_space = domain.get_observation_space() + + if not self._action_masking: + self.observation_space = gym.spaces.Dict( { - k: agent_action_space.unwrapped() - for k, agent_action_space in self._wrapped_action_space.items() + agent: _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ) + for agent in self._wrapped_observation_space } ) else: - if not self._action_masking: - self.observation_space = self._wrapped_observation_space - else: - self.observation_space = gym.spaces.Dict( - { - k: gym.spaces.Dict( - { - "true_obs": agent_observation_space, - "valid_avail_actions_mask": gym.spaces.Box( - 0, - 1, - shape=( - len( - self._wrapped_action_space[k].get_elements() - ), + self.observation_space = gym.spaces.Dict( + { + agent: gym.spaces.Dict( + { + "true_obs": _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ), + "valid_avail_actions_mask": gym.spaces.Box( + 0, + 1, + shape=( + len( + self._wrapped_action_space[agent].get_elements() ), - dtype=np.int8, ), - } - ) - for k, agent_observation_space in self._wrapped_observation_space.items() - } - ) - self.action_space = self._wrapped_action_space + dtype=np.int8, + ), + } + ) + for agent in self._wrapped_observation_space + } + ) + self.action_space = gym.spaces.Dict( + { + k: agent_action_space.unwrapped() + for k, agent_action_space in self._wrapped_action_space.items() + } + ) def _wrap_action(self, action_dict: dict[str, Any]) -> dict[str, D.T_event]: return _wrap_action( @@ -676,14 +765,63 @@ def generate_rllibcallback_class( ) +def _unwrap_agent_obs_space( + wrapped_observation_space: dict[str, GymSpace[D.T_observation]], + agent: str, +) -> gym.Space: + unwrapped_agent_obs_space = wrapped_observation_space[agent].unwrapped() + if isinstance(unwrapped_agent_obs_space, gym.spaces.Graph): + return convert_graph_space_to_dict_space(unwrapped_agent_obs_space) + elif _is_graph_multiinput_unwrapped_agent_space(unwrapped_agent_obs_space): + return gym.spaces.Dict( + { + k: convert_graph_space_to_dict_space(subspace) + if isinstance(subspace, gym.spaces.Graph) + else subspace + for k, subspace in unwrapped_agent_obs_space.spaces.items() + } + ) + else: + return unwrapped_agent_obs_space + + def _unwrap_agent_obs( obs: dict[str, D.T_observation], agent: str, - wrapped_observation_space: dict[str, GymSpace], + wrapped_observation_space: dict[str, GymSpace[D.T_observation]], + transform_graph: bool = True, ) -> Any: - # Trick to get obs[agent]'s unwrapped value - # (no unwrapping method for single elements in enumerable spaces) - return next(iter(wrapped_observation_space[agent].to_unwrapped([obs[agent]]))) + unwrapped_agent_obs_space = wrapped_observation_space[agent].unwrapped() + if isinstance(unwrapped_agent_obs_space, gym.spaces.Graph) and transform_graph: + # get original unwrapped graph instance + unwrapped_agent_obs: gym.spaces.GraphInstance = _unwrap_agent_obs( + obs=obs, + agent=agent, + wrapped_observation_space=wrapped_observation_space, + transform_graph=False, + ) + # transform graph instance into a dict + return convert_graph_to_dict(unwrapped_agent_obs) + elif ( + _is_graph_multiinput_unwrapped_agent_space((unwrapped_agent_obs_space)) + and transform_graph + ): + unwrapped_agent_obs: dict[str, Any] = _unwrap_agent_obs( + obs=obs, + agent=agent, + wrapped_observation_space=wrapped_observation_space, + transform_graph=False, + ) + return { + k: convert_graph_to_dict(v) + if isinstance(v, gym.spaces.GraphInstance) + else v + for k, v in unwrapped_agent_obs.items() + } + else: + # Trick to get obs[agent]'s unwrapped value + # (no unwrapping method for single elements in enumerable spaces) + return next(iter(wrapped_observation_space[agent].to_unwrapped([obs[agent]]))) def _unwrap_agent_obs_with_action_masking( @@ -711,3 +849,23 @@ def _wrap_action( ) for agent, unwrapped_action in action.items() } + + +def _is_graph_space(space: dict[str, GymSpace[Any]]) -> bool: + return all( + isinstance(agent_observation_space.unwrapped(), gym.spaces.Graph) + for agent_observation_space in space.values() + ) + + +def _is_graph_multiinput_space(space: dict[str, GymSpace[Any]]) -> bool: + return all( + _is_graph_multiinput_unwrapped_agent_space(agent_observation_space.unwrapped()) + for agent_observation_space in space.values() + ) + + +def _is_graph_multiinput_unwrapped_agent_space(space: gym.spaces.Space) -> bool: + return isinstance(space, gym.spaces.Dict) and any( + isinstance(subspace, gym.spaces.Graph) for subspace in space.spaces.values() + ) diff --git a/tests/solvers/python/test_gnn_ray_rllib.py b/tests/solvers/python/test_gnn_ray_rllib.py new file mode 100644 index 0000000000..126675b6b0 --- /dev/null +++ b/tests/solvers/python/test_gnn_ray_rllib.py @@ -0,0 +1,675 @@ +import logging +import os +from typing import Any, Callable, Dict, Optional, Union + +import numpy as np +import ray +import torch as th +import torch_geometric as thg +from gymnasium.spaces import Box, Discrete, Graph, GraphInstance +from pytest_cases import fixture, param_fixture +from torch_geometric.nn import global_add_pool + +from skdecide.builders.domain import Renderable, UnrestrictedActions +from skdecide.core import Mask, Space, Value +from skdecide.domains import DeterministicPlanningDomain +from skdecide.hub.domain.maze import Maze +from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State +from skdecide.hub.solver.ray_rllib import RayRLlib +from skdecide.hub.solver.ray_rllib.gnn.algorithms import GraphPPO +from skdecide.hub.space.gym import DictSpace, GymSpace, ListSpace +from skdecide.utils import rollout + + +class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class GraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.discrete_features = discrete_features + self.maze_domain = Maze(maze_str=maze_str) + np_wall = np.array(self.maze_domain._maze) + np_y = np.array( + [ + [(i) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + np_x = np.array( + [ + [(j) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + walls = np_wall.ravel() + coords = [i for i in zip(np_y.ravel(), np_x.ravel())] + np_node_id = np.reshape(range(len(walls)), np_wall.shape) + edge_links = [] + edges = [] + for i in range(self.maze_domain._num_rows): + for j in range(self.maze_domain._num_cols): + current_coord = (i, j) + if i > 0: + next_coord = (i - 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if i < self.maze_domain._num_rows - 1: + next_coord = (i + 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j > 0: + next_coord = (i, j - 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j < self.maze_domain._num_cols - 1: + next_coord = (i, j + 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + self.edges = np.array(edges) + self.edge_links = np.array(edge_links) + self.walls = walls + self.node_ids = np_node_id + self.coords = coords + + def _mazestate2graph(self, state: State) -> D.T_state: + x, y = state + agent_presence = np.zeros(self.walls.shape, dtype=self.walls.dtype) + agent_presence[self.node_ids[y, x]] = 1 + nodes = np.stack([self.walls, agent_presence], axis=-1) + if self.discrete_features: + return GraphInstance( + nodes=nodes, edges=self.edges, edge_links=self.edge_links + ) + else: + return GraphInstance( + nodes=nodes, edges=self.edges[:, None], edge_links=self.edge_links + ) + + def _graph2mazestate(self, graph: D.T_state) -> State: + y, x = self.coords[graph.nodes[:, 1].nonzero()[0][0]] + return State(x=x, y=y) + + def _is_terminal(self, state: D.T_state) -> D.T_predicate: + return self.maze_domain._is_terminal(self._graph2mazestate(state)) + + def _get_next_state(self, memory: D.T_state, action: D.T_event) -> D.T_state: + maze_memory = self._graph2mazestate(memory) + maze_next_state = self.maze_domain._get_next_state( + memory=maze_memory, action=action + ) + return self._mazestate2graph(maze_next_state) + + def _get_transition_value( + self, + memory: D.T_state, + action: D.T_event, + next_state: Optional[D.T_state] = None, + ) -> Value[D.T_value]: + maze_memory = self._graph2mazestate(memory) + if next_state is None: + maze_next_state = None + else: + maze_next_state = self._graph2mazestate(next_state) + return self.maze_domain._get_transition_value( + memory=maze_memory, action=action, next_state=maze_next_state + ) + + def _get_action_space_(self) -> Space[D.T_event]: + return self.maze_domain._get_action_space_() + + def _get_goals_(self) -> Space[D.T_observation]: + return ListSpace([self._mazestate2graph(self.maze_domain._goal)]) + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.maze_domain._is_goal(self._graph2mazestate(observation)) + + def _get_initial_state_(self) -> D.T_state: + return self._mazestate2graph(self.maze_domain._get_initial_state_()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + if self.discrete_features: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Discrete(2), + ) + ) + else: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Box(low=0, high=1, shape=(1,), dtype=self.edges.dtype), + ) + ) + + def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: + maze_memory = self._graph2mazestate(memory) + self.maze_domain._render_from(memory=maze_memory, **kwargs) + + +class D(DeterministicPlanningDomain, Renderable): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class MaskedGraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.graph_maze = GraphMaze( + maze_str=maze_str, discrete_features=discrete_features + ) + + def _get_next_state( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + ) -> D.T_state: + return self.graph_maze._get_next_state(memory=memory, action=action) + + def _get_transition_value( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + next_state: Optional[D.T_state] = None, + ) -> D.T_agent[Value[D.T_value]]: + return self.graph_maze._get_transition_value( + memory=memory, action=action, next_state=next_state + ) + + def _is_terminal(self, state: D.T_state) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_terminal(state=state) + + def _get_action_space_(self) -> D.T_agent[Space[D.T_event]]: + return self.graph_maze._get_action_space_() + + def _get_goals_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_goals_() + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_goal(observation=observation) + + def _get_initial_state_(self) -> D.T_state: + return self.graph_maze._get_initial_state_() + + def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_observation_space_() + + def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: + return self.graph_maze._render_from(memory=memory, **kwargs) + + def _get_action_mask( + self, memory: Optional[D.T_memory[D.T_state]] = None + ) -> D.T_agent[Mask]: + # a different way to display applicable actions + # we could also override only _get_applicable_action() but it will be more computationally efficient to + # implement directly get_action_mask() + if memory is None: + memory = self._memory + mazestate_memory = self.graph_maze._graph2mazestate(memory) + return np.array( + [ + self.graph_maze._graph2mazestate( + self._get_next_state(action=action, memory=memory) + ) + != mazestate_memory + for action in self._get_action_space().get_elements() + ] + ) + + def _get_applicable_actions_from( + self, memory: D.T_memory[D.T_state] + ) -> D.T_agent[Space[D.T_event]]: + return ListSpace( + [ + action + for action, mask in zip( + self._get_action_space().get_elements(), + self._get_action_mask(memory=memory), + ) + if mask + ] + ) + + +class D(GraphMaze): + T_state = dict[str, Any] + + +class DictGraphMaze(D): + def _get_observation_space_(self) -> Space[D.T_observation]: + return DictSpace( + spaces=dict( + graph=super()._get_observation_space_(), + static=Box(low=0.0, high=1.0, dtype=np.float_), + ) + ) + + def _mazestate2graph(self, state: State) -> D.T_state: + return dict( + graph=super()._mazestate2graph(state), + static=np.array([0.5], dtype=np.float_), + ) + + def _graph2mazestate(self, graph: D.T_state) -> State: + return super()._graph2mazestate(graph["graph"]) + + +class D(DeterministicPlanningDomain, Renderable): + T_state = dict[str, Any] + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class MaskedDictGraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.graph_maze = DictGraphMaze( + maze_str=maze_str, discrete_features=discrete_features + ) + + def _get_next_state( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + ) -> D.T_state: + return self.graph_maze._get_next_state(memory=memory, action=action) + + def _get_transition_value( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + next_state: Optional[D.T_state] = None, + ) -> D.T_agent[Value[D.T_value]]: + return self.graph_maze._get_transition_value( + memory=memory, action=action, next_state=next_state + ) + + def _is_terminal(self, state: D.T_state) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_terminal(state=state) + + def _get_action_space_(self) -> D.T_agent[Space[D.T_event]]: + return self.graph_maze._get_action_space_() + + def _get_goals_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_goals_() + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_goal(observation=observation) + + def _get_initial_state_(self) -> D.T_state: + return self.graph_maze._get_initial_state_() + + def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_observation_space_() + + def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: + return self.graph_maze._render_from(memory=memory, **kwargs) + + def _get_action_mask( + self, memory: Optional[D.T_memory[D.T_state]] = None + ) -> D.T_agent[Mask]: + # a different way to display applicable actions + # we could also override only _get_applicable_action() but it will be more computationally efficient to + # implement directly get_action_mask() + if memory is None: + memory = self._memory + mazestate_memory = self.graph_maze._graph2mazestate(memory) + return np.array( + [ + self.graph_maze._graph2mazestate( + self._get_next_state(action=action, memory=memory) + ) + != mazestate_memory + for action in self._get_action_space().get_elements() + ] + ) + + def _get_applicable_actions_from( + self, memory: D.T_memory[D.T_state] + ) -> D.T_agent[Space[D.T_event]]: + return ListSpace( + [ + action + for action, mask in zip( + self._get_action_space().get_elements(), + self._get_action_mask(memory=memory), + ) + if mask + ] + ) + + +discrete_features = param_fixture("discrete_features", [False, True]) + + +@fixture +def domain_factory(discrete_features): + return lambda: GraphMaze(discrete_features=discrete_features) + + +def test_observation_space(domain_factory): + domain = domain_factory() + assert domain.reset() in domain.get_observation_space() + rollout(domain=domain, num_episodes=1, max_steps=3, render=False, verbose=False) + + +def test_dict_observation_space(): + domain = DictGraphMaze() + assert domain.reset() in domain.get_observation_space() + rollout(domain=domain, num_episodes=1, max_steps=3, render=False, verbose=False) + + +@fixture +def ray_init(): + # add module test_gnn_ray_rllib and thus GraphMaze to ray runtimeenv + ray.init( + ignore_reinit_error=True, + runtime_env={"working_dir": os.path.dirname(__file__)}, + # local_mode=True, # uncomment this line and comment the one above to debug more easily + ) + + +@fixture +def graphppo_config(): + return ( + GraphPPO.get_default_config() + # set num of CPU<1 to avoid hanging for ever in github actions on macos 11 + .resources( + num_cpus_per_worker=0.5, + ) + # small number to increase speed of the unit test + .training(train_batch_size=256) + ) + + +def test_ppo(domain_factory, graphppo_config, ray_init): + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +class MyGNN(thg.nn.models.GAT): + + LOG_SENTENCE = "Using custom GNN." + + def __init__( + self, + in_channels: int, + hidden_channels: int, + num_layers: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + act: Union[str, Callable, None] = "relu", + act_first: bool = False, + act_kwargs: Optional[Dict[str, Any]] = None, + norm: Union[str, Callable, None] = None, + norm_kwargs: Optional[Dict[str, Any]] = None, + jk: Optional[str] = None, + **kwargs + ): + super().__init__( + in_channels, + hidden_channels, + num_layers, + out_channels, + dropout, + act, + act_first, + act_kwargs, + norm, + norm_kwargs, + jk, + **kwargs + ) + logging.warning(MyGNN.LOG_SENTENCE) + + +def test_ppo_user_gnn(domain_factory, graphppo_config, ray_init, caplog): + domain = domain_factory() + node_features_dim = int( + np.prod(domain.get_observation_space().unwrapped().node_space.shape) + ) + + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_class=MyGNN, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + with caplog.at_level(logging.WARNING): + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + assert MyGNN.LOG_SENTENCE in caplog.text + + +class MyReductionLayer(th.nn.Module): + LOG_SENTENCE = "Using custom reduction layer." + + def __init__(self, gnn_out_dim: int, features_dim: int): + super().__init__() + self.gnn_out_dim = gnn_out_dim + self.features_dim = features_dim + self.linear_layer = th.nn.Linear(gnn_out_dim, features_dim) + logging.warning(MyReductionLayer.LOG_SENTENCE) + + def forward(self, embedded_observations: thg.data.Data) -> th.Tensor: + x, edge_index, batch = ( + embedded_observations.x, + embedded_observations.edge_index, + embedded_observations.batch, + ) + h = global_add_pool(x, batch) + h = self.linear_layer(h).relu() + return h + + +def test_ppo_user_reduction_layer(domain_factory, graphppo_config, ray_init, caplog): + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_out_dim=128, + features_dim=64, + reduction_layer_class=MyReductionLayer, + reduction_layer_kwargs=dict( + gnn_out_dim=128, + features_dim=64, + ), + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + with caplog.at_level(logging.WARNING): + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + assert MyReductionLayer.LOG_SENTENCE in caplog.text + + +def test_dict_ppo(graphppo_config, ray_init): + domain_factory = DictGraphMaze + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_ppo_masked(discrete_features, graphppo_config, ray_init): + domain_factory = lambda: MaskedGraphMaze(discrete_features=discrete_features) + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_obs + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_dict_ppo_masked(discrete_features, graphppo_config, ray_init): + domain_factory = lambda: MaskedDictGraphMaze(discrete_features=discrete_features) + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_multiinput_obs + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_ppo_masked_user_gnn(graphppo_config, ray_init, caplog): + domain_factory = MaskedGraphMaze + node_features_dim = int( + np.prod(domain_factory().get_observation_space().unwrapped().node_space.shape) + ) + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_class=MyGNN, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_obs + with caplog.at_level(logging.WARNING): + solver.solve() + assert MyGNN.LOG_SENTENCE in caplog.text + + +def test_dict_ppo_masked_user_gnn(graphppo_config, ray_init, caplog): + domain_factory = MaskedDictGraphMaze + node_features_dim = int( + np.prod( + domain_factory() + .get_observation_space() + .unwrapped()["graph"] + .node_space.shape + ) + ) + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_class=MyGNN, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_multiinput_obs + with caplog.at_level(logging.WARNING): + solver.solve() + assert MyGNN.LOG_SENTENCE in caplog.text