diff --git a/skdecide/hub/solver/ray_rllib/custom_models.py b/skdecide/hub/solver/ray_rllib/custom_models.py index af93f358a3..44462eedc2 100644 --- a/skdecide/hub/solver/ray_rllib/custom_models.py +++ b/skdecide/hub/solver/ray_rllib/custom_models.py @@ -1,4 +1,5 @@ from gymnasium.spaces import flatten_space +from ray.rllib import SampleBatch from ray.rllib.models.tf.fcnet import FullyConnectedNetwork as TFFullyConnectedNetwork from ray.rllib.models.tf.tf_modelv2 import TFModelV2 from ray.rllib.models.torch.fcnet import ( @@ -9,6 +10,15 @@ from ray.rllib.utils.spaces.space_utils import flatten_to_single_ndarray, unbatch from ray.rllib.utils.torch_utils import FLOAT_MAX, FLOAT_MIN +from skdecide.hub.solver.ray_rllib.gnn.models.torch.complex_input_net import ( + GraphComplexInputNetwork, +) +from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + is_graph_dict_multiinput_space, + is_graph_dict_space, +) + tf1, tf, tfv = try_import_tf() torch, nn = try_import_torch() @@ -98,8 +108,20 @@ def __init__(self, obs_space, action_space, num_outputs, model_config, name, **k self.action_ids_shifted = torch.arange(1, num_outputs + 1, dtype=torch.int64) self.true_obs_space = model_config["custom_model_config"]["true_obs_space"] - self.pred_action_embed_model = TorchFullyConnectedNetwork( - flatten_space(self.true_obs_space), + if is_graph_dict_space(self.true_obs_space): + pred_action_embed_model_cls = GnnBasedModel + self.obs_with_graph = True + embed_model_obs_space = self.true_obs_space + elif is_graph_dict_multiinput_space(self.true_obs_space): + pred_action_embed_model_cls = GraphComplexInputNetwork + self.obs_with_graph = True + embed_model_obs_space = self.true_obs_space + else: + pred_action_embed_model_cls = TorchFullyConnectedNetwork + self.obs_with_graph = False + embed_model_obs_space = flatten_space(self.true_obs_space) + self.pred_action_embed_model = pred_action_embed_model_cls( + embed_model_obs_space, action_space, model_config["custom_model_config"]["action_embed_size"], model_config, @@ -115,16 +137,21 @@ def forward(self, input_dict, state, seq_lens): # Extract the available actions mask tensor from the observation. valid_avail_actions_mask = input_dict["obs"]["valid_avail_actions_mask"] - # Unbatch true observations before flattening them - unbatched_true_obs = unbatch(input_dict["obs"]["true_obs"]) + if self.obs_with_graph: + # use the obs directly (already converted to the proper format by the custom `convert_to_torch_tensor`) + embed_model_obs = input_dict["obs"]["true_obs"] + else: + # Unbatch true observations before flattening them + embed_model_obs = torch.stack( + [ + flatten_to_single_ndarray(o) + for o in unbatch(input_dict["obs"]["true_obs"]) + ] + ) # Compute the predicted action embedding pred_action_embed, _ = self.pred_action_embed_model( - { - "obs": torch.stack( - [flatten_to_single_ndarray(o) for o in unbatched_true_obs] - ) - } + SampleBatch({SampleBatch.OBS: embed_model_obs}) ) # Expand the model output to [BATCH, 1, EMBED_SIZE].
Note that the diff --git a/skdecide/hub/solver/ray_rllib/gnn/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py new file mode 100644 index 0000000000..d5d5c40f48 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/algorithms/__init__.py @@ -0,0 +1 @@ +from .ppo.ppo import GraphPPO diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py new file mode 100644 index 0000000000..75bb88bcb2 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo.py @@ -0,0 +1,21 @@ +from typing import Optional + +from ray.rllib import Policy +from ray.rllib.algorithms import PPO, AlgorithmConfig + +from skdecide.hub.solver.ray_rllib.gnn.algorithms.ppo.ppo_torch_policy import ( + PPOTorchGraphPolicy, +) + + +class GraphPPO(PPO): + @classmethod + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[type[Policy]]: + if config["framework"] == "torch": + return PPOTorchGraphPolicy + elif config["framework"] == "tf": + raise NotImplementedError("GraphPPO is only implemented for the torch framework") + else: + raise NotImplementedError("GraphPPO is only implemented for the torch framework") diff --git a/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py new file mode 100644 index 0000000000..673c3a165a --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/algorithms/ppo/ppo_torch_policy.py @@ -0,0 +1,7 @@ +from ray.rllib.algorithms.ppo import PPOTorchPolicy + +from skdecide.hub.solver.ray_rllib.gnn.policy.torch_graph_policy import TorchGraphPolicy + + +class PPOTorchGraphPolicy(TorchGraphPolicy, PPOTorchPolicy): + ... diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/models/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py new file mode 100644 index 0000000000..a3d1af2260 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/models/torch/complex_input_net.py @@ -0,0 +1,66 @@ +import gymnasium as gym +from ray.rllib import SampleBatch +from ray.rllib.models.torch.complex_input_net import ComplexInputNetwork +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.typing import TensorType +from torch import nn + +from skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn import GnnBasedModel +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + is_graph_dict_space, +) + + +class GraphComplexInputNetwork(TorchModelV2, nn.Module): + def __init__(self, obs_space, action_space, num_outputs, model_config, name): + if not model_config.get("_disable_preprocessor_api"): + raise ValueError( + "This model is intended to be used only when preprocessors are disabled."
+ ) + if not isinstance(obs_space, gym.spaces.Dict): + raise ValueError( + "This model is intended to be used only on dict observation spaces." + ) + + nn.Module.__init__(self) + super().__init__(obs_space, action_space, num_outputs, model_config, name) + + self.gnn = nn.ModuleDict() + post_graph_obs_subspaces = dict(obs_space.spaces) + for k, subspace in obs_space.spaces.items(): + if is_graph_dict_space(subspace): + submodel_name = f"gnn_{k}" + gnn = GnnBasedModel( + obs_space=subspace, + action_space=action_space, + num_outputs=None, + model_config=model_config, + framework="torch", + name=submodel_name, + ) + self.add_module(submodel_name, gnn) + self.gnn[k] = gnn + post_graph_obs_subspaces[k] = gnn.features_space + + post_graph_obs_space = gym.spaces.Dict(post_graph_obs_subspaces) + self.post_graph_model = ComplexInputNetwork( + obs_space=post_graph_obs_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=model_config, + name="post_graph_model", + ) + + def forward(self, input_dict: SampleBatch, state, seq_lens): + post_graph_input_dict = input_dict.copy(shallow=True) + obs = input_dict["obs"] + post_graph_obs = dict(obs) + for k, gnn in self.gnn.items(): + post_graph_obs[k] = gnn(SampleBatch({SampleBatch.OBS: obs[k]})) + post_graph_input_dict["obs"] = post_graph_obs + return self.post_graph_model( + input_dict=post_graph_input_dict, state=state, seq_lens=seq_lens + ) + + def value_function(self) -> TensorType: + return self.post_graph_model.value_function() diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py new file mode 100644 index 0000000000..7f235126db --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py @@ -0,0 +1,77 @@ +from collections import defaultdict +from typing import Optional + +import gymnasium as gym +import numpy as np +from ray.rllib.models.torch.fcnet import FullyConnectedNetwork +from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 +from ray.rllib.utils.typing import ModelConfigDict +from torch import nn + +from skdecide.hub.solver.ray_rllib.gnn.torch_layers import GraphFeaturesExtractor +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + convert_dict_space_to_graph_space, + is_graph_dict_space, +) + + +class GnnBasedModel(TorchModelV2, nn.Module): + def __init__( + self, + obs_space: gym.spaces.Space, + action_space: gym.spaces.Space, + num_outputs: Optional[int], + model_config: ModelConfigDict, + name: str, + **kw, + ): + nn.Module.__init__(self) + super().__init__(obs_space, action_space, num_outputs, model_config, name) + + # config for custom model + custom_config = defaultdict( + lambda: None, # will return None for missing keys + model_config.get("custom_model_config", {}), + ) + + # gnn-based feature extractor + features_extractor_kwargs = custom_config.get("features_extractor", {}) + assert is_graph_dict_space( + obs_space + ), f"{self.__class__.__name__} can only be applied to Graph observation spaces." + graph_observation_space = convert_dict_space_to_graph_space(obs_space) + self.features_extractor = GraphFeaturesExtractor( + observation_space=graph_observation_space, **features_extractor_kwargs + ) + self.features_space = gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(self.features_extractor.features_dim,) + ) + + if num_outputs is None: + # only feature extraction (e.g.
to be used by GraphComplexInputNetwork) + self.num_outputs = self.features_extractor.features_dim + self.pred_action_embed_model = None + else: + # fully connected network + self.pred_action_embed_model = FullyConnectedNetwork( + obs_space=self.features_space, + action_space=action_space, + num_outputs=num_outputs, + model_config=model_config, + name=name + "_pred_action_embed", + ) + + def forward(self, input_dict, state, seq_lens): + obs = input_dict["obs"] + features = self.features_extractor(obs) + if self.pred_action_embed_model is None: + return features, state + else: + return self.pred_action_embed_model( + input_dict={"obs": features}, + state=state, + seq_lens=seq_lens, + ) + + def value_function(self): + return self.pred_action_embed_model.value_function() diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/policy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py b/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py new file mode 100644 index 0000000000..ffe6266b88 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py @@ -0,0 +1,116 @@ +from numbers import Number +from typing import Any, Optional, Union + +import gymnasium as gym +import numpy as np +import tree +from ray.rllib import SampleBatch +from ray.rllib.policy.sample_batch import attempt_count_timesteps, tf, torch +from ray.rllib.utils.typing import ViewRequirementsDict + + +def _pop_graph_items( + full_dict: dict[Any, Any] +) -> dict[Any, Union[gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]]: + graph_dict = {} + for k, v in full_dict.items(): + if isinstance(v, gym.spaces.GraphInstance) or ( + isinstance(v, list) and isinstance(v[0], gym.spaces.GraphInstance) + ): + graph_dict[k] = v + for k in graph_dict: + full_dict.pop(k) + return graph_dict + + +def _split_graph_requirements( + full_dict: ViewRequirementsDict, +) -> tuple[ViewRequirementsDict, ViewRequirementsDict]: + graph_dict = {} + for k, v in full_dict.items(): + if isinstance(v.space, gym.spaces.Graph): + graph_dict[k] = v + wo_graph_dict = {k: v for k, v in full_dict.items() if k not in graph_dict} + return graph_dict, wo_graph_dict + + +class GraphSampleBatch(SampleBatch): + def __init__(self, *args, **kwargs): + """Constructs a sample batch with possibly graph obs. + + See `ray.rllib.SampleBatch` for more information. + + """ + # split graph samples from others. + dict_graphs = _pop_graph_items(kwargs) + dict_from_args = dict(*args) + dict_graphs.update(_pop_graph_items(dict_from_args)) + + super().__init__(dict_from_args, **kwargs) + super().update(dict_graphs) + + def copy(self, shallow: bool = False) -> "SampleBatch": + """Creates a deep or shallow copy of this SampleBatch and returns it. + + Args: + shallow: Whether the copying should be done shallowly. + + Returns: + A deep or shallow copy of this SampleBatch object. 
+ """ + copy_ = dict(self) + data = tree.map_structure( + lambda v: ( + np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v + ), + copy_, + ) + copy_ = GraphSampleBatch( + data, + _time_major=self.time_major, + _zero_padded=self.zero_padded, + _max_seq_len=self.max_seq_len, + _num_grad_updates=self.num_grad_updates, + ) + copy_.set_get_interceptor(self.get_interceptor) + copy_.added_keys = self.added_keys + copy_.deleted_keys = self.deleted_keys + copy_.accessed_keys = self.accessed_keys + return copy_ + + def get_single_step_input_dict( + self, + view_requirements: ViewRequirementsDict, + index: Union[str, int] = "last", + ) -> "SampleBatch": + ( + view_requirements_graphs, + view_requirements_wo_graphs, + ) = _split_graph_requirements(view_requirements) + # w/o graphs + sample = GraphSampleBatch( + super().get_single_step_input_dict(view_requirements_wo_graphs, index) + ) + # handle graphs + last_mappings = { + SampleBatch.OBS: SampleBatch.NEXT_OBS, + SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, + SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, + } + for view_col, view_req in view_requirements_graphs.items(): + if view_req.used_for_compute_actions is False: + continue + + # Create batches of size 1 (single-agent input-dict). + data_col = view_req.data_col or view_col + if index == "last": + data_col = last_mappings.get(data_col, data_col) + if view_req.shift_from is not None: + raise NotImplementedError() + else: + sample[view_col] = self[data_col][-1:] + else: + sample[view_col] = self[data_col][ + index : index + 1 if index != -1 else None + ] + return sample diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py new file mode 100644 index 0000000000..84b599614c --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_graph_policy.py @@ -0,0 +1,16 @@ +import functools + +from ray.rllib import SampleBatch +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 + +from skdecide.hub.solver.ray_rllib.gnn.utils.torch_utils import convert_to_torch_tensor + + +class TorchGraphPolicy(TorchPolicyV2): + def _lazy_tensor_dict(self, postprocessed_batch: SampleBatch, device=None): + if not isinstance(postprocessed_batch, SampleBatch): + postprocessed_batch = SampleBatch(postprocessed_batch) + postprocessed_batch.set_get_interceptor( + functools.partial(convert_to_torch_tensor, device=device or self.device) + ) + return postprocessed_batch diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py new file mode 100644 index 0000000000..775bc3eee2 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py @@ -0,0 +1,26 @@ +from ray.rllib.policy.torch_mixins import ValueNetworkMixin + +from skdecide.hub.solver.ray_rllib.gnn.policy.sample_batch import GraphSampleBatch + + +class ValueNetworkGraphMixin(ValueNetworkMixin): + def __init__(self, config): + if config.get("use_gae") or config.get("vtrace"): + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + + def value(**input_dict): + input_dict = GraphSampleBatch(input_dict) + input_dict = self._lazy_tensor_dict(input_dict) + model_out, _ = self.model(input_dict) + # [0] = remove the batch dim. + return self.model.value_function()[0].item() + + # When not doing GAE, we do not require the value function's output. 
+ else: + + def value(*args, **kwargs): + return 0.0 + + self._value = value diff --git a/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py b/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py new file mode 100644 index 0000000000..8180b35f88 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py @@ -0,0 +1,114 @@ +from typing import Any, Optional + +import gymnasium as gym +import numpy as np +import torch as th +import torch_geometric as thg +from torch import nn +from torch_geometric.nn import global_max_pool + + +class GraphFeaturesExtractor(nn.Module): + """Graph feature extractor for Graph observation spaces. + + Will chain a gnn with a reduction layer to extract a fixed number of features. + The user can specify both the gnn and the reduction layer. + + By default, we use: + - gnn: a 2-layer GCN + - reduction layer: global_max_pool + linear layer + relu + + Args: + observation_space: + features_dim: Number of extracted features + - If reduction_layer_class is given, should match the output dimension of that reduction layer. + - If reduction_layer_class is None, will be used by the default reduction layer as its output dimension. + gnn_out_dim: dimension of the node embedding in the gnn output + - If gnn_class is given, should not be None and should match the gnn output dimension + - If gnn_class is not given, will be used to generate the default gnn. By default, gnn_out_dim = 2 * features_dim + gnn_class: GNN class used to embed the graph observations (for instance chosen from `torch_geometric.nn.models`) + gnn_kwargs: used by `gnn_class.__init__()`. Without effect if `gnn_class` is None. + reduction_layer_class: network class to be plugged after the gnn to get a fixed number of features. + reduction_layer_kwargs: used by `reduction_layer_class.__init__()`. Without effect if `reduction_layer_class` is None. + + """ + + def __init__( + self, + observation_space: gym.spaces.Graph, + features_dim: int = 64, + gnn_out_dim: Optional[int] = None, + gnn_class: Optional[type[nn.Module]] = None, + gnn_kwargs: Optional[dict[str, Any]] = None, + reduction_layer_class: Optional[type[nn.Module]] = None, + reduction_layer_kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__() + self.features_dim = features_dim + + if gnn_out_dim is None: + if gnn_class is None: + gnn_out_dim = 2 * features_dim + else: + raise ValueError( + "`gnn_out_dim` cannot be None if `gnn_class` is not None, " + "and should match the gnn output dimension."
+ ) + + if gnn_class is None: + node_features_dim = int(np.prod(observation_space.node_space.shape)) + self.gnn = thg.nn.models.GCN( + in_channels=node_features_dim, + hidden_channels=gnn_out_dim, + num_layers=2, + dropout=0.2, + ) + else: + if gnn_kwargs is None: + gnn_kwargs = {} + self.gnn = gnn_class(**gnn_kwargs) + + if reduction_layer_class is None: + self.reduction_layer = _DefaultReductionLayer( + gnn_out_dim=gnn_out_dim, features_dim=features_dim + ) + else: + if reduction_layer_kwargs is None: + reduction_layer_kwargs = {} + self.reduction_layer = reduction_layer_class(**reduction_layer_kwargs) + + def forward(self, observations) -> th.Tensor: + x, edge_index, edge_attr, batch = ( + observations.x, + observations.edge_index, + observations.edge_attr, + observations.batch, + ) + # construct edge weights, for GNNs that need them, from the first edge feature + edge_weight = edge_attr[:, 0] + h = self.gnn( + x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr + ) + embedded_observations = thg.data.Data( + x=h, edge_index=edge_index, edge_attr=edge_attr, batch=batch + ) + h = self.reduction_layer(embedded_observations=embedded_observations) + return h + + +class _DefaultReductionLayer(nn.Module): + def __init__(self, gnn_out_dim: int, features_dim: int): + super().__init__() + self.gnn_out_dim = gnn_out_dim + self.features_dim = features_dim + self.linear_layer = nn.Linear(gnn_out_dim, features_dim) + + def forward(self, embedded_observations: thg.data.Data) -> th.Tensor: + x, edge_index, batch = ( + embedded_observations.x, + embedded_observations.edge_index, + embedded_observations.batch, + ) + h = global_max_pool(x, batch) + h = self.linear_layer(h).relu() + return h diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/__init__.py b/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py b/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py new file mode 100644 index 0000000000..11df126454 --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/utils/spaces/space_utils.py @@ -0,0 +1,129 @@ +from typing import Any, Union + +import gymnasium as gym +import numpy as np + + +def convert_graph_space_to_dict_space(space: gym.spaces.Graph) -> gym.spaces.Dict: + # arbitrarily use 2 nodes and 1 edge (for dummy samples auto-generated by ray.rllib) + return gym.spaces.Dict( + dict( + nodes=repeat_space(space.node_space, n_rep=2), + edges=repeat_space(space.edge_space, n_rep=1), + edge_links=gym.spaces.Box(low=0, high=2, shape=(1, 2), dtype=np.int_), + ) + ) + + +def repeat_space_box(space: gym.spaces.Box, n_rep: int) -> gym.spaces.Box: + rep_low = np.repeat(space.low[None, :], n_rep, axis=0) + rep_high = np.repeat(space.high[None, :], n_rep, axis=0) + rep_shape = (n_rep,) + space.shape + return gym.spaces.Box( + low=rep_low, + high=rep_high, + shape=rep_shape, + dtype=space.dtype, + ) + + +def repeat_space_discrete(space: gym.spaces.Discrete, n_rep: int) -> gym.spaces.Box: + return gym.spaces.Box( + low=space.start, + high=space.start + space.n - 1, + shape=(n_rep, 1), + dtype=space.dtype, + ) + + +def repeat_space(space: Union[gym.spaces.Box, gym.spaces.Discrete], n_rep: int): + if isinstance(space, gym.spaces.Box): + return repeat_space_box(space=space,
n_rep=n_rep) + elif isinstance(space, gym.spaces.Discrete): + return repeat_space_discrete(space=space, n_rep=n_rep) + else: + raise NotImplementedError() + + +def remove_first_axis_space(space: gym.spaces.Box) -> gym.spaces.Box: + return gym.spaces.Box( + low=space.low[0, :], + high=space.high[0, :], + shape=space.shape[1:], + dtype=space.dtype, + ) + + +def convert_graph_to_dict(x: gym.spaces.GraphInstance) -> dict[str, np.ndarray]: + return dict( + nodes=x.nodes, + edges=x.edges, + edge_links=x.edge_links, + ) + + +def convert_dict_space_to_graph_space(space: gym.spaces.Dict) -> gym.spaces.Graph: + return gym.spaces.Graph( + node_space=remove_first_axis_space(space.spaces["nodes"]), + edge_space=remove_first_axis_space(space.spaces["edges"]), + ) + + +def convert_dict_to_graph(x: dict[str, np.ndarray]) -> gym.spaces.GraphInstance: + return gym.spaces.GraphInstance( + nodes=x["nodes"], edges=x["edges"], edge_links=x["edge_links"] + ) + + +def is_graph_dict(x: Any) -> bool: + return ( + isinstance(x, dict) + and len(x) == 3 + and "nodes" in x + and "edges" in x + and "edge_links" in x + ) + + +def is_graph_dict_space(x: gym.spaces.Space) -> bool: + return ( + isinstance(x, gym.spaces.Dict) + and len(x.spaces) == 3 + and "nodes" in x.spaces + and "edges" in x.spaces + and "edge_links" in x.spaces + ) + + +def is_graph_dict_multiinput(x: Any) -> bool: + return isinstance(x, dict) and any([is_graph_dict(v) for v in x.values()]) + + +def is_graph_dict_multiinput_space(x: gym.spaces.Space) -> bool: + return isinstance(x, gym.spaces.Dict) and any( + [is_graph_dict_space(subspace) for subspace in x.values()] + ) + + +def is_masked_obs(x: Any) -> bool: + return ( + isinstance(x, dict) + and len(x) == 2 + and "true_obs" in x + and "valid_avail_actions_mask" in x + ) + + +def is_masked_obs_space(x: gym.spaces.Space) -> bool: + return ( + isinstance(x, gym.spaces.Dict) + and len(x.spaces) == 2 + and "true_obs" in x.spaces + and "valid_avail_actions_mask" in x.spaces + ) + + +def extract_graph_dict_from_batched_graph_dict( + batched_graph_dict: dict[str, np.ndarray], index: int +) -> dict[str, np.ndarray]: + return {k: v[index, :] for k, v in batched_graph_dict.items()} diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py b/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py new file mode 100644 index 0000000000..f74be0e97e --- /dev/null +++ b/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py @@ -0,0 +1,142 @@ +from typing import Optional, Union + +import gymnasium as gym +import numpy as np +import torch as th +import torch_geometric as thg +from ray.rllib.utils.torch_utils import ( + convert_to_torch_tensor as convert_to_torch_tensor_original, +) +from ray.rllib.utils.typing import TensorStructType + +from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import ( + extract_graph_dict_from_batched_graph_dict, + is_graph_dict, + is_graph_dict_multiinput, + is_masked_obs, +) + + +def graph_obs_to_thg_data( + obs: gym.spaces.GraphInstance, + device: Optional[th.device] = None, + pin_memory: bool = False, +) -> thg.data.Data: + # Node features + flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1)) + x = th.tensor(flatten_node_features).float() + # Edge features + if obs.edges is None: + edge_attr = None + else: + flatten_edge_features = obs.edges.reshape((len(obs.edges), -1)) + edge_attr = th.tensor(flatten_edge_features).float() + edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1) + # thg.Data + data = thg.data.Data(x=x, 
edge_index=edge_index, edge_attr=edge_attr) + # Pin the tensor's memory (for faster transfer to GPU later). + if pin_memory and th.cuda.is_available(): + data.pin_memory() + + return data if device is None else data.to(device) + + +def convert_to_torch_tensor( + x: Union[ + TensorStructType, + thg.data.Data, + gym.spaces.GraphInstance, + list[gym.spaces.GraphInstance], + ], + device: Optional[str] = None, + pin_memory: bool = False, +) -> Union[TensorStructType, thg.data.Data]: + """Converts any struct to torch.Tensors. + + Args: + x: Any (possibly nested) struct, the values in which will be + converted and returned as a new struct with all leaves converted + to torch tensors. + device: The device to create the tensor on. + pin_memory: If True, will call the `pin_memory()` method on the created tensors. + + Returns: + Any: A new struct with the same structure as `x`, but with all + values converted to torch Tensor types. This does not convert possibly + nested elements that are None because torch has no representation for that. + """ + if isinstance(x, gym.spaces.GraphInstance): + return graph_obs_to_thg_data(x, device=device, pin_memory=pin_memory) + elif isinstance(x, list) and isinstance(x[0], gym.spaces.GraphInstance): + return thg.data.Batch.from_data_list( + [ + graph_obs_to_thg_data(graph, device=device, pin_memory=pin_memory) + for graph in x + ] + ) + elif isinstance(x, thg.data.Data): + return x + elif is_masked_obs(x): + return { + k: convert_to_torch_tensor(v, device=device, pin_memory=pin_memory) + for k, v in x.items() + } + elif is_graph_dict(x): + return batched_graph_dict_to_thg_data(x, device=device, pin_memory=pin_memory) + elif is_graph_dict_multiinput(x): + return { + k: convert_to_torch_tensor(v, device=device, pin_memory=pin_memory) + for k, v in x.items() + } + else: + return convert_to_torch_tensor_original( + x=x, device=device, pin_memory=pin_memory + ) + + +def graph_dict_to_thg_data( + graph_dict: dict[str, np.ndarray], + device: Optional[str] = None, + pin_memory: bool = False, +): + # Node features + flatten_node_features = graph_dict["nodes"].reshape((len(graph_dict["nodes"]), -1)) + x = th.tensor(flatten_node_features).float() + # Edge features + if graph_dict["edges"] is None: + edge_attr = None + else: + flatten_edge_features = graph_dict["edges"].reshape( + (len(graph_dict["edges"]), -1) + ) + edge_attr = th.tensor(flatten_edge_features).float() + edge_index = ( + th.tensor(graph_dict["edge_links"], dtype=th.long).t().contiguous().view(2, -1) + ) + # thg.Data + data = thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + # Pin the tensor's memory (for faster transfer to GPU later). 
+ if pin_memory and th.cuda.is_available(): + data.pin_memory() + + return data if device is None else data.to(device) + + +def batched_graph_dict_to_thg_data( + batched_graph_dict: dict[str, np.ndarray], + device: Optional[str] = None, + pin_memory: bool = False, +): + batch_size = batched_graph_dict["nodes"].shape[0] + return thg.data.Batch.from_data_list( + [ + graph_dict_to_thg_data( + graph_dict=extract_graph_dict_from_batched_graph_dict( + batched_graph_dict=batched_graph_dict, index=index + ), + device=device, + pin_memory=pin_memory, + ) + for index in range(batch_size) + ] + ) diff --git a/skdecide/hub/solver/ray_rllib/ray_rllib.py b/skdecide/hub/solver/ray_rllib/ray_rllib.py index a91167c224..8dc9715035 100644 --- a/skdecide/hub/solver/ray_rllib/ray_rllib.py +++ b/skdecide/hub/solver/ray_rllib/ray_rllib.py @@ -14,7 +14,6 @@ CategoricalHyperparameter, FloatHyperparameter, IntegerHyperparameter, - SubBrickKwargsHyperparameter, ) from packaging.version import Version from ray.rllib.algorithms import DQN, PPO, SAC @@ -37,6 +36,12 @@ from skdecide.hub.space.gym import GymSpace from .custom_models import TFParametricActionsModel, TorchParametricActionsModel +from .gnn.models.torch.complex_input_net import GraphComplexInputNetwork +from .gnn.models.torch.gnn import GnnBasedModel +from .gnn.utils.spaces.space_utils import ( + convert_graph_space_to_dict_space, + convert_graph_to_dict, +) if TYPE_CHECKING: # imports useful only in annotations, may change according to releases @@ -99,6 +104,9 @@ class RayRLlib(Solver, Policies, Restorable, Maskable): ), ] + MASKABLE_ALGOS = ["APPO", "BC", "DQN", "Rainbow", "IMPALA", "MARWIL", "PPO"] + """The only algos able to handle action masking in ray[rllib]==2.9.0.""" + def __init__( self, domain_factory: Callable[[], Domain], @@ -110,8 +118,8 @@ def __init__( Callable[[str, Optional["EpisodeV2"], Optional["RolloutWorker"]], str] ] = None, action_embed_sizes: Optional[dict[str, int]] = None, - config_kwargs: Optional[dict[str, Any]] = None, callback: Callable[[RayRLlib], bool] = lambda solver: False, + graph_feature_extractors_kwargs: Optional[dict[str, Any]] = None, **kwargs, ) -> None: """Initialize Ray RLlib. @@ -125,13 +133,13 @@ def __init__( policy_configs: The mapping from policy id (str) to additional config (dict) (leave default for single policy). policy_mapping_fn: The function mapping agent ids to policy ids (leave default for single policy). action_embed_sizes: The mapping from policy id (str) to action embedding size (only used with domains filtering allowed actions per state, default to 2) - config_kwargs: keyword arguments for the `AlgorithmConfigKwargs` class used to update programmatically the algorithm config. - Will be used by hyerparameters tuners like optuna. Should probably not be used directly by the user, - who could rather directly specify the correct `config`. callback: function called at each solver iteration. If returning true, the solve process stops and exit the current train iteration. However, if train_iterations > 1, another train loop will be entered after that. (One can code its callback in such a way that further training loop are stopped directly after that.) + graph_feature_extractors_kwargs: in case of graph observations, these are the kwargs to the GraphFeaturesExtractor model + used to extract features. (See skdecide.hub.solver.ray_rllib.gnn.models.torch.gnn.GraphFeaturesExtractor) + **kwargs: used to update the algo config with kwargs automatically filled by optuna.
""" Solver.__init__(self, domain_factory=domain_factory) @@ -156,6 +164,10 @@ def __init__( raise RuntimeError( "Action embed size keys must be the same as policy config keys" ) + if graph_feature_extractors_kwargs is None: + self._graph_feature_extractors_kwargs = {} + else: + self._graph_feature_extractors_kwargs = graph_feature_extractors_kwargs ray.init(ignore_reinit_error=True) self._algo_callbacks: Optional[DefaultCallbacks] = None @@ -168,16 +180,23 @@ def __init__( self._wrapped_observation_space = domain.get_observation_space() # action masking? - domain = self._domain_factory() self._action_masking = ( (not isinstance(domain, UnrestrictedActions)) and all( isinstance(agent_action_space, EnumerableSpace) for agent_action_space in self._wrapped_action_space.values() ) - and self._algo_class.__name__ - # Only the following algos handle action masking in ray[rllib]==2.9.0 - in ["APPO", "BC", "DQN", "Rainbow", "IMPALA", "MARWIL", "PPO"] + and ( + self._algo_class.__name__ in RayRLlib.MASKABLE_ALGOS + or self._algo_class.__name__ + in [f"Graph{algo_name}" for algo_name in RayRLlib.MASKABLE_ALGOS] + ) + ) + + # graph obs? + self._is_graph_obs = _is_graph_space(self._wrapped_observation_space) + self._is_graph_multiinput_obs = _is_graph_multiinput_space( + (self._wrapped_observation_space) ) # Handle kwargs (potentially generated by optuna) @@ -268,7 +287,13 @@ def _load(self, path: str): self.set_callback() # ensure putting back actual callback def _init_algo(self) -> None: + # custom model? if self._action_masking: + if self._is_graph_obs or self._is_graph_multiinput_obs: + # let the observation pass as is + self._config.experimental( + _disable_preprocessor_api=True, + ) if self._config.get("framework") not in ["tf", "tf2", "torch"]: raise RuntimeError( "Action masking (invalid action filtering) for RLlib requires TensorFlow or PyTorch to be installed" @@ -290,6 +315,37 @@ def _init_algo(self) -> None: self._config.training( model={"vf_share_layers": True}, ) + elif self._is_graph_obs: + if self._config.get("framework") not in ["torch"]: + raise RuntimeError( + "Graph observation with RLlib requires PyTorch framework." + ) + ModelCatalog.register_custom_model( + "skdecide_rllib_graph_model", + GnnBasedModel + if self._config.get("framework") == "torch" + else NotProvided, + ) + # let the observation pass as is + self._config.experimental( + _disable_preprocessor_api=True, + ) + elif self._is_graph_multiinput_obs: + if self._config.get("framework") not in ["torch"]: + raise RuntimeError( + "Graph observation with RLlib requires PyTorch framework." 
+ ) + ModelCatalog.register_custom_model( + "skdecide_rllib_graph_multiinput_model", + GraphComplexInputNetwork + if self._config.get("framework") == "torch" + else NotProvided, + ) + # let the observation pass as is + self._config.experimental( + _disable_preprocessor_api=True, + ) + self._wrap_action = lambda action: _wrap_action( action=action, wrapped_action_space=self._wrapped_action_space ) @@ -312,23 +368,31 @@ def _init_algo(self) -> None: # Overwrite multi-agent config pol_obs_spaces = ( { - self._policy_mapping_fn(k, None, None): v.unwrapped() - for k, v in self._wrapped_observation_space.items() + self._policy_mapping_fn(agent, None, None): _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ) + for agent in self._wrapped_observation_space } if not self._action_masking else { - self._policy_mapping_fn(k, None, None): gym.spaces.Dict( + self._policy_mapping_fn(agent, None, None): gym.spaces.Dict( { - "true_obs": v.unwrapped(), + "true_obs": _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ), "valid_avail_actions_mask": gym.spaces.Box( 0, 1, - shape=(len(self._wrapped_action_space[k].get_elements()),), + shape=( + len(self._wrapped_action_space[agent].get_elements()), + ), dtype=np.int8, ), } ) - for k, v in self._wrapped_observation_space.items() + for agent in self._wrapped_observation_space } ) pol_act_spaces = { @@ -336,13 +400,14 @@ def _init_algo(self) -> None: for k, v in self._wrapped_action_space.items() } - policies = ( - { - k: (None, pol_obs_spaces[k], pol_act_spaces[k], v or {}) - for k, v in self._policy_configs.items() - } - if not self._action_masking - else { + if self._action_masking: + if self._is_graph_obs or self._is_graph_multiinput_obs: + extra_custom_model_config_kwargs = { + "features_extractor": self._graph_feature_extractors_kwargs + } + else: + extra_custom_model_config_kwargs = {} + policies = { self._policy_mapping_fn(k, None, None): ( None, pol_obs_spaces[k], @@ -357,6 +422,7 @@ def _init_algo(self) -> None: "true_obs" ], "action_embed_size": action_embed_size, + **extra_custom_model_config_kwargs, }, }, }, @@ -364,7 +430,52 @@ def _init_algo(self) -> None: ) for k, action_embed_size in self._action_embed_sizes.items() } - ) + elif self._is_graph_obs: + policies = { + self._policy_mapping_fn(k, None, None): ( + None, + pol_obs_spaces[k], + pol_act_spaces[k], + { + **(v or {}), + **{ + "model": { + "custom_model": "skdecide_rllib_graph_model", + "custom_model_config": { + "features_extractor": self._graph_feature_extractors_kwargs, # kwargs for GraphFeaturesExtractor + }, + }, + }, + }, + ) + for k, v in self._policy_configs.items() + } + elif self._is_graph_multiinput_obs: + policies = { + self._policy_mapping_fn(k, None, None): ( + None, + pol_obs_spaces[k], + pol_act_spaces[k], + { + **(v or {}), + **{ + "model": { + "custom_model": "skdecide_rllib_graph_multiinput_model", + "custom_model_config": { + "features_extractor": self._graph_feature_extractors_kwargs, + # kwargs for GraphFeaturesExtractor + }, + }, + }, + }, + ) + for k, v in self._policy_configs.items() + } + else: + policies = { + k: (None, pol_obs_spaces[k], pol_act_spaces[k], v or {}) + for k, v in self._policy_configs.items() + } self._config.multi_agent( policies=policies, policy_mapping_fn=self._policy_mapping_fn, @@ -503,80 +614,58 @@ def __init__( self, domain: D, action_masking: bool, - unwrap_spaces: bool = True, ) -> None: """Initialize AsLegacyRLlibMultiAgentEnv. 
# Parameters domain: The scikit-decide domain to wrap as a RLlib multi-agent environment. action_masking: Boolean specifying whether action masking is used - unwrap_spaces: Boolean specifying whether the action & observation spaces should be unwrapped. """ self._domain = domain self._action_masking = action_masking - self._unwrap_spaces = unwrap_spaces - self._wrapped_observation_space = domain.get_observation_space() self._wrapped_action_space = domain.get_action_space() - if unwrap_spaces: - if not self._action_masking: - self.observation_space = gym.spaces.Dict( - { - k: agent_observation_space.unwrapped() - for k, agent_observation_space in self._wrapped_observation_space.items() - } - ) - else: - self.observation_space = gym.spaces.Dict( - { - k: gym.spaces.Dict( - { - "true_obs": agent_observation_space.unwrapped(), - "valid_avail_actions_mask": gym.spaces.Box( - 0, - 1, - shape=( - len( - self._wrapped_action_space[k].get_elements() - ), - ), - dtype=np.int8, - ), - } - ) - for k, agent_observation_space in self._wrapped_observation_space.items() - } - ) - self.action_space = gym.spaces.Dict( + self._wrapped_observation_space = domain.get_observation_space() + + if not self._action_masking: + self.observation_space = gym.spaces.Dict( { - k: agent_action_space.unwrapped() - for k, agent_action_space in self._wrapped_action_space.items() + agent: _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ) + for agent in self._wrapped_observation_space } ) else: - if not self._action_masking: - self.observation_space = self._wrapped_observation_space - else: - self.observation_space = gym.spaces.Dict( - { - k: gym.spaces.Dict( - { - "true_obs": agent_observation_space, - "valid_avail_actions_mask": gym.spaces.Box( - 0, - 1, - shape=( - len( - self._wrapped_action_space[k].get_elements() - ), + self.observation_space = gym.spaces.Dict( + { + agent: gym.spaces.Dict( + { + "true_obs": _unwrap_agent_obs_space( + wrapped_observation_space=self._wrapped_observation_space, + agent=agent, + ), + "valid_avail_actions_mask": gym.spaces.Box( + 0, + 1, + shape=( + len( + self._wrapped_action_space[agent].get_elements() ), - dtype=np.int8, ), - } - ) - for k, agent_observation_space in self._wrapped_observation_space.items() - } - ) - self.action_space = self._wrapped_action_space + dtype=np.int8, + ), + } + ) + for agent in self._wrapped_observation_space + } + ) + self.action_space = gym.spaces.Dict( + { + k: agent_action_space.unwrapped() + for k, agent_action_space in self._wrapped_action_space.items() + } + ) def _wrap_action(self, action_dict: dict[str, Any]) -> dict[str, D.T_event]: return _wrap_action( @@ -676,14 +765,63 @@ def generate_rllibcallback_class( ) +def _unwrap_agent_obs_space( + wrapped_observation_space: dict[str, GymSpace[D.T_observation]], + agent: str, +) -> gym.Space: + unwrapped_agent_obs_space = wrapped_observation_space[agent].unwrapped() + if isinstance(unwrapped_agent_obs_space, gym.spaces.Graph): + return convert_graph_space_to_dict_space(unwrapped_agent_obs_space) + elif _is_graph_multiinput_unwrapped_agent_space(unwrapped_agent_obs_space): + return gym.spaces.Dict( + { + k: convert_graph_space_to_dict_space(subspace) + if isinstance(subspace, gym.spaces.Graph) + else subspace + for k, subspace in unwrapped_agent_obs_space.spaces.items() + } + ) + else: + return unwrapped_agent_obs_space + + def _unwrap_agent_obs( obs: dict[str, D.T_observation], agent: str, - wrapped_observation_space: dict[str, GymSpace], + 
wrapped_observation_space: dict[str, GymSpace[D.T_observation]], + transform_graph: bool = True, ) -> Any: - # Trick to get obs[agent]'s unwrapped value - # (no unwrapping method for single elements in enumerable spaces) - return next(iter(wrapped_observation_space[agent].to_unwrapped([obs[agent]]))) + unwrapped_agent_obs_space = wrapped_observation_space[agent].unwrapped() + if isinstance(unwrapped_agent_obs_space, gym.spaces.Graph) and transform_graph: + # get original unwrapped graph instance + unwrapped_agent_obs: gym.spaces.GraphInstance = _unwrap_agent_obs( + obs=obs, + agent=agent, + wrapped_observation_space=wrapped_observation_space, + transform_graph=False, + ) + # transform graph instance into a dict + return convert_graph_to_dict(unwrapped_agent_obs) + elif ( + _is_graph_multiinput_unwrapped_agent_space((unwrapped_agent_obs_space)) + and transform_graph + ): + unwrapped_agent_obs: dict[str, Any] = _unwrap_agent_obs( + obs=obs, + agent=agent, + wrapped_observation_space=wrapped_observation_space, + transform_graph=False, + ) + return { + k: convert_graph_to_dict(v) + if isinstance(v, gym.spaces.GraphInstance) + else v + for k, v in unwrapped_agent_obs.items() + } + else: + # Trick to get obs[agent]'s unwrapped value + # (no unwrapping method for single elements in enumerable spaces) + return next(iter(wrapped_observation_space[agent].to_unwrapped([obs[agent]]))) def _unwrap_agent_obs_with_action_masking( @@ -711,3 +849,23 @@ def _wrap_action( ) for agent, unwrapped_action in action.items() } + + +def _is_graph_space(space: dict[str, GymSpace[Any]]) -> bool: + return all( + isinstance(agent_observation_space.unwrapped(), gym.spaces.Graph) + for agent_observation_space in space.values() + ) + + +def _is_graph_multiinput_space(space: dict[str, GymSpace[Any]]) -> bool: + return all( + _is_graph_multiinput_unwrapped_agent_space(agent_observation_space.unwrapped()) + for agent_observation_space in space.values() + ) + + +def _is_graph_multiinput_unwrapped_agent_space(space: gym.spaces.Space) -> bool: + return isinstance(space, gym.spaces.Dict) and any( + isinstance(subspace, gym.spaces.Graph) for subspace in space.spaces.values() + ) diff --git a/tests/solvers/python/test_gnn_ray_rllib.py b/tests/solvers/python/test_gnn_ray_rllib.py new file mode 100644 index 0000000000..126675b6b0 --- /dev/null +++ b/tests/solvers/python/test_gnn_ray_rllib.py @@ -0,0 +1,675 @@ +import logging +import os +from typing import Any, Callable, Dict, Optional, Union + +import numpy as np +import ray +import torch as th +import torch_geometric as thg +from gymnasium.spaces import Box, Discrete, Graph, GraphInstance +from pytest_cases import fixture, param_fixture +from torch_geometric.nn import global_add_pool + +from skdecide.builders.domain import Renderable, UnrestrictedActions +from skdecide.core import Mask, Space, Value +from skdecide.domains import DeterministicPlanningDomain +from skdecide.hub.domain.maze import Maze +from skdecide.hub.domain.maze.maze import DEFAULT_MAZE, Action, State +from skdecide.hub.solver.ray_rllib import RayRLlib +from skdecide.hub.solver.ray_rllib.gnn.algorithms import GraphPPO +from skdecide.hub.space.gym import DictSpace, GymSpace, ListSpace +from skdecide.utils import rollout + + +class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of 
logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class GraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.discrete_features = discrete_features + self.maze_domain = Maze(maze_str=maze_str) + np_wall = np.array(self.maze_domain._maze) + np_y = np.array( + [ + [(i) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + np_x = np.array( + [ + [(j) for j in range(self.maze_domain._num_cols)] + for i in range(self.maze_domain._num_rows) + ] + ) + walls = np_wall.ravel() + coords = [i for i in zip(np_y.ravel(), np_x.ravel())] + np_node_id = np.reshape(range(len(walls)), np_wall.shape) + edge_links = [] + edges = [] + for i in range(self.maze_domain._num_rows): + for j in range(self.maze_domain._num_cols): + current_coord = (i, j) + if i > 0: + next_coord = (i - 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if i < self.maze_domain._num_rows - 1: + next_coord = (i + 1, j) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j > 0: + next_coord = (i, j - 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + if j < self.maze_domain._num_cols - 1: + next_coord = (i, j + 1) + edge_links.append( + (np_node_id[current_coord], np_node_id[next_coord]) + ) + edges.append(np_wall[current_coord] * np_wall[next_coord]) + self.edges = np.array(edges) + self.edge_links = np.array(edge_links) + self.walls = walls + self.node_ids = np_node_id + self.coords = coords + + def _mazestate2graph(self, state: State) -> D.T_state: + x, y = state + agent_presence = np.zeros(self.walls.shape, dtype=self.walls.dtype) + agent_presence[self.node_ids[y, x]] = 1 + nodes = np.stack([self.walls, agent_presence], axis=-1) + if self.discrete_features: + return GraphInstance( + nodes=nodes, edges=self.edges, edge_links=self.edge_links + ) + else: + return GraphInstance( + nodes=nodes, edges=self.edges[:, None], edge_links=self.edge_links + ) + + def _graph2mazestate(self, graph: D.T_state) -> State: + y, x = self.coords[graph.nodes[:, 1].nonzero()[0][0]] + return State(x=x, y=y) + + def _is_terminal(self, state: D.T_state) -> D.T_predicate: + return self.maze_domain._is_terminal(self._graph2mazestate(state)) + + def _get_next_state(self, memory: D.T_state, action: D.T_event) -> D.T_state: + maze_memory = self._graph2mazestate(memory) + maze_next_state = self.maze_domain._get_next_state( + memory=maze_memory, action=action + ) + return self._mazestate2graph(maze_next_state) + + def _get_transition_value( + self, + memory: D.T_state, + action: D.T_event, + next_state: Optional[D.T_state] = None, + ) -> Value[D.T_value]: + maze_memory = self._graph2mazestate(memory) + if next_state is None: + maze_next_state = None + else: + maze_next_state = self._graph2mazestate(next_state) + return self.maze_domain._get_transition_value( + memory=maze_memory, action=action, next_state=maze_next_state + ) + + def _get_action_space_(self) -> Space[D.T_event]: + return self.maze_domain._get_action_space_() + + def _get_goals_(self) -> Space[D.T_observation]: + return ListSpace([self._mazestate2graph(self.maze_domain._goal)]) + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> 
D.T_agent[D.T_predicate]: + return self.maze_domain._is_goal(self._graph2mazestate(observation)) + + def _get_initial_state_(self) -> D.T_state: + return self._mazestate2graph(self.maze_domain._get_initial_state_()) + + def _get_observation_space_(self) -> Space[D.T_observation]: + if self.discrete_features: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Discrete(2), + ) + ) + else: + return GymSpace( + Graph( + node_space=Box(low=0, high=1, shape=(2,), dtype=self.walls.dtype), + edge_space=Box(low=0, high=1, shape=(1,), dtype=self.edges.dtype), + ) + ) + + def _render_from(self, memory: D.T_state, **kwargs: Any) -> Any: + maze_memory = self._graph2mazestate(memory) + self.maze_domain._render_from(memory=maze_memory, **kwargs) + + +class D(DeterministicPlanningDomain, Renderable): + T_state = GraphInstance # Type of states + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class MaskedGraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.graph_maze = GraphMaze( + maze_str=maze_str, discrete_features=discrete_features + ) + + def _get_next_state( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + ) -> D.T_state: + return self.graph_maze._get_next_state(memory=memory, action=action) + + def _get_transition_value( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + next_state: Optional[D.T_state] = None, + ) -> D.T_agent[Value[D.T_value]]: + return self.graph_maze._get_transition_value( + memory=memory, action=action, next_state=next_state + ) + + def _is_terminal(self, state: D.T_state) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_terminal(state=state) + + def _get_action_space_(self) -> D.T_agent[Space[D.T_event]]: + return self.graph_maze._get_action_space_() + + def _get_goals_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_goals_() + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_goal(observation=observation) + + def _get_initial_state_(self) -> D.T_state: + return self.graph_maze._get_initial_state_() + + def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_observation_space_() + + def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: + return self.graph_maze._render_from(memory=memory, **kwargs) + + def _get_action_mask( + self, memory: Optional[D.T_memory[D.T_state]] = None + ) -> D.T_agent[Mask]: + # a different way to display applicable actions + # we could also override only _get_applicable_action() but it will be more computationally efficient to + # implement directly get_action_mask() + if memory is None: + memory = self._memory + mazestate_memory = self.graph_maze._graph2mazestate(memory) + return np.array( + [ + self.graph_maze._graph2mazestate( + self._get_next_state(action=action, memory=memory) + ) + != mazestate_memory + for action in self._get_action_space().get_elements() + ] + ) + + def _get_applicable_actions_from( + self, memory: D.T_memory[D.T_state] + ) -> D.T_agent[Space[D.T_event]]: + return ListSpace( + [ + action + for action, mask in zip( 
+ self._get_action_space().get_elements(), + self._get_action_mask(memory=memory), + ) + if mask + ] + ) + + +class D(GraphMaze): + T_state = dict[str, Any] + + +class DictGraphMaze(D): + def _get_observation_space_(self) -> Space[D.T_observation]: + return DictSpace( + spaces=dict( + graph=super()._get_observation_space_(), + static=Box(low=0.0, high=1.0, dtype=np.float_), + ) + ) + + def _mazestate2graph(self, state: State) -> D.T_state: + return dict( + graph=super()._mazestate2graph(state), + static=np.array([0.5], dtype=np.float_), + ) + + def _graph2mazestate(self, graph: D.T_state) -> State: + return super()._graph2mazestate(graph["graph"]) + + +class D(DeterministicPlanningDomain, Renderable): + T_state = dict[str, Any] + T_observation = T_state # Type of observations + T_event = Action # Type of events + T_value = float # Type of transition values (rewards or costs) + T_predicate = bool # Type of logical checks + T_info = ( + None # Type of additional information given as part of an environment outcome + ) + + +class MaskedDictGraphMaze(D): + def __init__(self, maze_str: str = DEFAULT_MAZE, discrete_features: bool = False): + self.graph_maze = DictGraphMaze( + maze_str=maze_str, discrete_features=discrete_features + ) + + def _get_next_state( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + ) -> D.T_state: + return self.graph_maze._get_next_state(memory=memory, action=action) + + def _get_transition_value( + self, + memory: D.T_memory[D.T_state], + action: D.T_agent[D.T_concurrency[D.T_event]], + next_state: Optional[D.T_state] = None, + ) -> D.T_agent[Value[D.T_value]]: + return self.graph_maze._get_transition_value( + memory=memory, action=action, next_state=next_state + ) + + def _is_terminal(self, state: D.T_state) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_terminal(state=state) + + def _get_action_space_(self) -> D.T_agent[Space[D.T_event]]: + return self.graph_maze._get_action_space_() + + def _get_goals_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_goals_() + + def _is_goal( + self, observation: D.T_agent[D.T_observation] + ) -> D.T_agent[D.T_predicate]: + return self.graph_maze._is_goal(observation=observation) + + def _get_initial_state_(self) -> D.T_state: + return self.graph_maze._get_initial_state_() + + def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]: + return self.graph_maze._get_observation_space_() + + def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: + return self.graph_maze._render_from(memory=memory, **kwargs) + + def _get_action_mask( + self, memory: Optional[D.T_memory[D.T_state]] = None + ) -> D.T_agent[Mask]: + # a different way to display applicable actions + # we could also override only _get_applicable_action() but it will be more computationally efficient to + # implement directly get_action_mask() + if memory is None: + memory = self._memory + mazestate_memory = self.graph_maze._graph2mazestate(memory) + return np.array( + [ + self.graph_maze._graph2mazestate( + self._get_next_state(action=action, memory=memory) + ) + != mazestate_memory + for action in self._get_action_space().get_elements() + ] + ) + + def _get_applicable_actions_from( + self, memory: D.T_memory[D.T_state] + ) -> D.T_agent[Space[D.T_event]]: + return ListSpace( + [ + action + for action, mask in zip( + self._get_action_space().get_elements(), + self._get_action_mask(memory=memory), + ) + if mask + ] + ) + + +discrete_features = 
param_fixture("discrete_features", [False, True]) + + +@fixture +def domain_factory(discrete_features): + return lambda: GraphMaze(discrete_features=discrete_features) + + +def test_observation_space(domain_factory): + domain = domain_factory() + assert domain.reset() in domain.get_observation_space() + rollout(domain=domain, num_episodes=1, max_steps=3, render=False, verbose=False) + + +def test_dict_observation_space(): + domain = DictGraphMaze() + assert domain.reset() in domain.get_observation_space() + rollout(domain=domain, num_episodes=1, max_steps=3, render=False, verbose=False) + + +@fixture +def ray_init(): + # add module test_gnn_ray_rllib and thus GraphMaze to ray runtimeenv + ray.init( + ignore_reinit_error=True, + runtime_env={"working_dir": os.path.dirname(__file__)}, + # local_mode=True, # uncomment this line and comment the one above to debug more easily + ) + + +@fixture +def graphppo_config(): + return ( + GraphPPO.get_default_config() + # set num of CPU<1 to avoid hanging for ever in github actions on macos 11 + .resources( + num_cpus_per_worker=0.5, + ) + # small number to increase speed of the unit test + .training(train_batch_size=256) + ) + + +def test_ppo(domain_factory, graphppo_config, ray_init): + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +class MyGNN(thg.nn.models.GAT): + + LOG_SENTENCE = "Using custom GNN." + + def __init__( + self, + in_channels: int, + hidden_channels: int, + num_layers: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + act: Union[str, Callable, None] = "relu", + act_first: bool = False, + act_kwargs: Optional[Dict[str, Any]] = None, + norm: Union[str, Callable, None] = None, + norm_kwargs: Optional[Dict[str, Any]] = None, + jk: Optional[str] = None, + **kwargs + ): + super().__init__( + in_channels, + hidden_channels, + num_layers, + out_channels, + dropout, + act, + act_first, + act_kwargs, + norm, + norm_kwargs, + jk, + **kwargs + ) + logging.warning(MyGNN.LOG_SENTENCE) + + +def test_ppo_user_gnn(domain_factory, graphppo_config, ray_init, caplog): + domain = domain_factory() + node_features_dim = int( + np.prod(domain.get_observation_space().unwrapped().node_space.shape) + ) + + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_class=MyGNN, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + with caplog.at_level(logging.WARNING): + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + assert MyGNN.LOG_SENTENCE in caplog.text + + +class MyReductionLayer(th.nn.Module): + LOG_SENTENCE = "Using custom reduction layer." 
+ + def __init__(self, gnn_out_dim: int, features_dim: int): + super().__init__() + self.gnn_out_dim = gnn_out_dim + self.features_dim = features_dim + self.linear_layer = th.nn.Linear(gnn_out_dim, features_dim) + logging.warning(MyReductionLayer.LOG_SENTENCE) + + def forward(self, embedded_observations: thg.data.Data) -> th.Tensor: + x, edge_index, batch = ( + embedded_observations.x, + embedded_observations.edge_index, + embedded_observations.batch, + ) + h = global_add_pool(x, batch) + h = self.linear_layer(h).relu() + return h + + +def test_ppo_user_reduction_layer(domain_factory, graphppo_config, ray_init, caplog): + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_out_dim=128, + features_dim=64, + reduction_layer_class=MyReductionLayer, + reduction_layer_kwargs=dict( + gnn_out_dim=128, + features_dim=64, + ), + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + with caplog.at_level(logging.WARNING): + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + assert MyReductionLayer.LOG_SENTENCE in caplog.text + + +def test_dict_ppo(graphppo_config, ray_init): + domain_factory = DictGraphMaze + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_ppo_masked(discrete_features, graphppo_config, ray_init): + domain_factory = lambda: MaskedGraphMaze(discrete_features=discrete_features) + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_obs + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_dict_ppo_masked(discrete_features, graphppo_config, ray_init): + domain_factory = lambda: MaskedDictGraphMaze(discrete_features=discrete_features) + solver_kwargs = dict( + algo_class=GraphPPO, train_iterations=1 # , gamma=0.95, train_batch_size_log2=8 + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_multiinput_obs + solver.solve() + rollout( + domain=domain_factory(), + solver=solver, + max_steps=100, + num_episodes=1, + render=False, + ) + + +def test_ppo_masked_user_gnn(graphppo_config, ray_init, caplog): + domain_factory = MaskedGraphMaze + node_features_dim = int( + np.prod(domain_factory().get_observation_space().unwrapped().node_space.shape) + ) + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_class=MyGNN, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_obs + with caplog.at_level(logging.WARNING): + solver.solve() + assert MyGNN.LOG_SENTENCE in caplog.text + + +def 
test_dict_ppo_masked_user_gnn(graphppo_config, ray_init, caplog): + domain_factory = MaskedDictGraphMaze + node_features_dim = int( + np.prod( + domain_factory() + .get_observation_space() + .unwrapped()["graph"] + .node_space.shape + ) + ) + solver_kwargs = dict( + algo_class=GraphPPO, + train_iterations=1, + graph_feature_extractors_kwargs=dict( + gnn_class=MyGNN, + gnn_kwargs=dict( + in_channels=node_features_dim, + hidden_channels=64, + num_layers=2, + dropout=0.2, + ), + gnn_out_dim=64, + features_dim=64, + ), + ) + with RayRLlib( + domain_factory=domain_factory, config=graphppo_config, **solver_kwargs + ) as solver: + assert solver._action_masking and solver._is_graph_multiinput_obs + with caplog.at_level(logging.WARNING): + solver.solve() + assert MyGNN.LOG_SENTENCE in caplog.text
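
For reference, a minimal usage sketch of the new graph-observation path, mirroring `test_ppo` above. It assumes the `GraphMaze` domain defined in `tests/solvers/python/test_gnn_ray_rllib.py` is importable (the import path below is hypothetical); all other names come from this diff. Custom GNNs and reduction layers can be plugged in the same way through `graph_feature_extractors_kwargs` (see `test_ppo_user_gnn` and `test_ppo_user_reduction_layer`).

import ray

from skdecide.hub.solver.ray_rllib import RayRLlib
from skdecide.hub.solver.ray_rllib.gnn.algorithms import GraphPPO
from skdecide.utils import rollout

# Hypothetical import: GraphMaze is the graph-observation test domain defined above.
from test_gnn_ray_rllib import GraphMaze

ray.init(ignore_reinit_error=True)

domain_factory = lambda: GraphMaze()

# Plain PPO config; GraphPPO only swaps in the torch graph policy.
config = GraphPPO.get_default_config().training(train_batch_size=256)

# RayRLlib detects the gym.spaces.Graph observation space, disables the
# RLlib preprocessor API and registers GnnBasedModel as the custom model.
with RayRLlib(
    domain_factory=domain_factory,
    config=config,
    algo_class=GraphPPO,
    train_iterations=1,
) as solver:
    solver.solve()
    rollout(
        domain=domain_factory(),
        solver=solver,
        max_steps=100,
        num_episodes=1,
        render=False,
    )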