diff --git a/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py b/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py
index 7f235126db..916485cc00 100644
--- a/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py
+++ b/skdecide/hub/solver/ray_rllib/gnn/models/torch/gnn.py
@@ -8,11 +8,11 @@
 from ray.rllib.utils.typing import ModelConfigDict
 from torch import nn
 
-from skdecide.hub.solver.ray_rllib.gnn.torch_layers import GraphFeaturesExtractor
 from skdecide.hub.solver.ray_rllib.gnn.utils.spaces.space_utils import (
     convert_dict_space_to_graph_space,
     is_graph_dict_space,
 )
+from skdecide.hub.solver.utils.gnn.torch_layers import GraphFeaturesExtractor
 
 
 class GnnBasedModel(TorchModelV2, nn.Module):
diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py b/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py
deleted file mode 100644
index ffe6266b88..0000000000
--- a/skdecide/hub/solver/ray_rllib/gnn/policy/sample_batch.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from numbers import Number
-from typing import Any, Optional, Union
-
-import gymnasium as gym
-import numpy as np
-import tree
-from ray.rllib import SampleBatch
-from ray.rllib.policy.sample_batch import attempt_count_timesteps, tf, torch
-from ray.rllib.utils.typing import ViewRequirementsDict
-
-
-def _pop_graph_items(
-    full_dict: dict[Any, Any]
-) -> dict[Any, Union[gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]]:
-    graph_dict = {}
-    for k, v in full_dict.items():
-        if isinstance(v, gym.spaces.GraphInstance) or (
-            isinstance(v, list) and isinstance(v[0], gym.spaces.GraphInstance)
-        ):
-            graph_dict[k] = v
-    for k in graph_dict:
-        full_dict.pop(k)
-    return graph_dict
-
-
-def _split_graph_requirements(
-    full_dict: ViewRequirementsDict,
-) -> tuple[ViewRequirementsDict, ViewRequirementsDict]:
-    graph_dict = {}
-    for k, v in full_dict.items():
-        if isinstance(v.space, gym.spaces.Graph):
-            graph_dict[k] = v
-    wo_graph_dict = {k: v for k, v in full_dict.items() if k not in graph_dict}
-    return graph_dict, wo_graph_dict
-
-
-class GraphSampleBatch(SampleBatch):
-    def __init__(self, *args, **kwargs):
-        """Constructs a sample batch with possibly graph obs.
-
-        See `ray.rllib.SampleBatch` for more information.
-
-        """
-        # split graph samples from others.
-        dict_graphs = _pop_graph_items(kwargs)
-        dict_from_args = dict(*args)
-        dict_graphs.update(_pop_graph_items(dict_from_args))
-
-        super().__init__(dict_from_args, **kwargs)
-        super().update(dict_graphs)
-
-    def copy(self, shallow: bool = False) -> "SampleBatch":
-        """Creates a deep or shallow copy of this SampleBatch and returns it.
-
-        Args:
-            shallow: Whether the copying should be done shallowly.
-
-        Returns:
-            A deep or shallow copy of this SampleBatch object.
- """ - copy_ = dict(self) - data = tree.map_structure( - lambda v: ( - np.array(v, copy=not shallow) if isinstance(v, np.ndarray) else v - ), - copy_, - ) - copy_ = GraphSampleBatch( - data, - _time_major=self.time_major, - _zero_padded=self.zero_padded, - _max_seq_len=self.max_seq_len, - _num_grad_updates=self.num_grad_updates, - ) - copy_.set_get_interceptor(self.get_interceptor) - copy_.added_keys = self.added_keys - copy_.deleted_keys = self.deleted_keys - copy_.accessed_keys = self.accessed_keys - return copy_ - - def get_single_step_input_dict( - self, - view_requirements: ViewRequirementsDict, - index: Union[str, int] = "last", - ) -> "SampleBatch": - ( - view_requirements_graphs, - view_requirements_wo_graphs, - ) = _split_graph_requirements(view_requirements) - # w/o graphs - sample = GraphSampleBatch( - super().get_single_step_input_dict(view_requirements_wo_graphs, index) - ) - # handle graphs - last_mappings = { - SampleBatch.OBS: SampleBatch.NEXT_OBS, - SampleBatch.PREV_ACTIONS: SampleBatch.ACTIONS, - SampleBatch.PREV_REWARDS: SampleBatch.REWARDS, - } - for view_col, view_req in view_requirements_graphs.items(): - if view_req.used_for_compute_actions is False: - continue - - # Create batches of size 1 (single-agent input-dict). - data_col = view_req.data_col or view_col - if index == "last": - data_col = last_mappings.get(data_col, data_col) - if view_req.shift_from is not None: - raise NotImplementedError() - else: - sample[view_col] = self[data_col][-1:] - else: - sample[view_col] = self[data_col][ - index : index + 1 if index != -1 else None - ] - return sample diff --git a/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py b/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py deleted file mode 100644 index 775bc3eee2..0000000000 --- a/skdecide/hub/solver/ray_rllib/gnn/policy/torch_mixins.py +++ /dev/null @@ -1,26 +0,0 @@ -from ray.rllib.policy.torch_mixins import ValueNetworkMixin - -from skdecide.hub.solver.ray_rllib.gnn.policy.sample_batch import GraphSampleBatch - - -class ValueNetworkGraphMixin(ValueNetworkMixin): - def __init__(self, config): - if config.get("use_gae") or config.get("vtrace"): - # Input dict is provided to us automatically via the Model's - # requirements. It's a single-timestep (last one in trajectory) - # input_dict. - - def value(**input_dict): - input_dict = GraphSampleBatch(input_dict) - input_dict = self._lazy_tensor_dict(input_dict) - model_out, _ = self.model(input_dict) - # [0] = remove the batch dim. - return self.model.value_function()[0].item() - - # When not doing GAE, we do not require the value function's output. 
-        else:
-
-            def value(*args, **kwargs):
-                return 0.0
-
-        self._value = value
diff --git a/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py b/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py
index f74be0e97e..a9e5549761 100644
--- a/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py
+++ b/skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py
@@ -15,30 +15,7 @@
     is_graph_dict_multiinput,
     is_masked_obs,
 )
-
-
-def graph_obs_to_thg_data(
-    obs: gym.spaces.GraphInstance,
-    device: Optional[th.device] = None,
-    pin_memory: bool = False,
-) -> thg.data.Data:
-    # Node features
-    flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1))
-    x = th.tensor(flatten_node_features).float()
-    # Edge features
-    if obs.edges is None:
-        edge_attr = None
-    else:
-        flatten_edge_features = obs.edges.reshape((len(obs.edges), -1))
-        edge_attr = th.tensor(flatten_edge_features).float()
-    edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1)
-    # thg.Data
-    data = thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
-    # Pin the tensor's memory (for faster transfer to GPU later).
-    if pin_memory and th.cuda.is_available():
-        data.pin_memory()
-
-    return data if device is None else data.to(device)
+from skdecide.hub.solver.utils.gnn.torch_utils import graph_obs_to_thg_data
 
 
 def convert_to_torch_tensor(
diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py b/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py
index 18aa0d7612..91cbf5aacd 100644
--- a/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py
+++ b/skdecide/hub/solver/stable_baselines/gnn/common/torch_layers.py
@@ -1,13 +1,13 @@
 from typing import Any, Optional, Union
 
 import gymnasium as gym
-import numpy as np
 import torch as th
 import torch_geometric as thg
 from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space
 from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, NatureCNN
 from torch import nn
-from torch_geometric.nn import global_max_pool
+
+from skdecide.hub.solver.utils.gnn import torch_layers
 
 
 class GraphFeaturesExtractor(BaseFeaturesExtractor):
@@ -20,6 +20,10 @@ class GraphFeaturesExtractor(BaseFeaturesExtractor):
     - gnn: a 2-layers GCN
     - reduction layer: global_max_pool + linear layer + relu
 
+    This merely wraps `skdecide.hub.solver.utils.gnn.torch_layers.GraphFeaturesExtractor` to
+    make it a `stable_baselines3.common.torch_layers.BaseFeaturesExtractor`. See the former's
+    documentation for more details about its arguments.
+
     Args:
         observation_space:
         features_dim: Number of extracted features
@@ -45,75 +49,19 @@ def __init__(
         reduction_layer_class: Optional[type[nn.Module]] = None,
         reduction_layer_kwargs: Optional[dict[str, Any]] = None,
     ):
-
-        super().__init__(observation_space, features_dim=features_dim)
-
-        if gnn_out_dim is None:
-            if gnn_class is None:
-                gnn_out_dim = 2 * features_dim
-            else:
-                raise ValueError(
-                    "`gnn_out_dim` cannot be None if `gnn` is not None, "
-                    "and should match `gnn` output."
-                )
-
-        if gnn_class is None:
-            node_features_dim = int(np.prod(observation_space.node_space.shape))
-            self.gnn = thg.nn.models.GCN(
-                in_channels=node_features_dim,
-                hidden_channels=gnn_out_dim,
-                num_layers=2,
-                dropout=0.2,
-            )
-        else:
-            if gnn_kwargs is None:
-                gnn_kwargs = {}
-            self.gnn = gnn_class(**gnn_kwargs)
-
-        if reduction_layer_class is None:
-            self.reduction_layer = _DefaultReductionLayer(
-                gnn_out_dim=gnn_out_dim, features_dim=features_dim
-            )
-        else:
-            if reduction_layer_kwargs is None:
-                reduction_layer_kwargs = {}
-            self.reduction_layer = reduction_layer_class(**reduction_layer_kwargs)
+        super().__init__(observation_space=observation_space, features_dim=features_dim)
+        self._extractor = torch_layers.GraphFeaturesExtractor(
+            observation_space=observation_space,
+            features_dim=features_dim,
+            gnn_out_dim=gnn_out_dim,
+            gnn_class=gnn_class,
+            gnn_kwargs=gnn_kwargs,
+            reduction_layer_class=reduction_layer_class,
+            reduction_layer_kwargs=reduction_layer_kwargs,
+        )
 
     def forward(self, observations: thg.data.Data) -> th.Tensor:
-        x, edge_index, edge_attr, batch = (
-            observations.x,
-            observations.edge_index,
-            observations.edge_attr,
-            observations.batch,
-        )
-        # construct edge weights, for GNNs needing it, as the first edge feature
-        edge_weight = edge_attr[:, 0]
-        h = self.gnn(
-            x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=edge_attr
-        )
-        embedded_observations = thg.data.Data(
-            x=h, edge_index=edge_index, edge_attr=edge_attr, batch=batch
-        )
-        h = self.reduction_layer(embedded_observations=embedded_observations)
-        return h
-
-
-class _DefaultReductionLayer(nn.Module):
-    def __init__(self, gnn_out_dim: int, features_dim: int):
-        super().__init__()
-        self.gnn_out_dim = gnn_out_dim
-        self.features_dim = features_dim
-        self.linear_layer = nn.Linear(gnn_out_dim, features_dim)
-
-    def forward(self, embedded_observations: thg.data.Data) -> th.Tensor:
-        x, edge_index, batch = (
-            embedded_observations.x,
-            embedded_observations.edge_index,
-            embedded_observations.batch,
-        )
-        h = global_max_pool(x, batch)
-        h = self.linear_layer(h).relu()
-        return h
+        return self._extractor.forward(observations=observations)
 
 
 class CombinedFeaturesExtractor(BaseFeaturesExtractor):
diff --git a/skdecide/hub/solver/stable_baselines/gnn/common/utils.py b/skdecide/hub/solver/stable_baselines/gnn/common/utils.py
index e0abcef987..0a58860e06 100644
--- a/skdecide/hub/solver/stable_baselines/gnn/common/utils.py
+++ b/skdecide/hub/solver/stable_baselines/gnn/common/utils.py
@@ -6,6 +6,8 @@
 import torch as th
 import torch_geometric as thg
 
+from skdecide.hub.solver.utils.gnn.torch_utils import graph_obs_to_thg_data
+
 SubObsType = Union[np.ndarray, gym.spaces.GraphInstance, list[gym.spaces.GraphInstance]]
 ObsType = Union[SubObsType, dict[str, SubObsType]]
 TorchSubObsType = Union[th.Tensor, thg.data.Data]
@@ -27,22 +29,6 @@ def copy_np_array_or_list_of_graph_instances(
     return np.copy(obs)
 
 
-def graph_obs_to_thg_data(
-    obs: gym.spaces.GraphInstance, device: th.device
-) -> thg.data.Data:
-    # Node features
-    flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1))
-    x = th.tensor(flatten_node_features).float()
-    # Edge features
-    if obs.edges is None:
-        edge_attr = None
-    else:
-        flatten_edge_features = obs.edges.reshape((len(obs.edges), -1))
-        edge_attr = th.tensor(flatten_edge_features).float()
-    edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1)
-    return thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr).to(device)
-
-
 def obs_as_tensor(
     obs: ObsType,
     device: th.device,
diff --git a/skdecide/hub/solver/utils/__init__.py b/skdecide/hub/solver/utils/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/skdecide/hub/solver/utils/gnn/__init__.py b/skdecide/hub/solver/utils/gnn/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py b/skdecide/hub/solver/utils/gnn/torch_layers.py
similarity index 98%
rename from skdecide/hub/solver/ray_rllib/gnn/torch_layers.py
rename to skdecide/hub/solver/utils/gnn/torch_layers.py
index 8180b35f88..3949565343 100644
--- a/skdecide/hub/solver/ray_rllib/gnn/torch_layers.py
+++ b/skdecide/hub/solver/utils/gnn/torch_layers.py
@@ -77,7 +77,7 @@ def __init__(
                 reduction_layer_kwargs = {}
             self.reduction_layer = reduction_layer_class(**reduction_layer_kwargs)
 
-    def forward(self, observations) -> th.Tensor:
+    def forward(self, observations: thg.data.Data) -> th.Tensor:
         x, edge_index, edge_attr, batch = (
             observations.x,
             observations.edge_index,
diff --git a/skdecide/hub/solver/utils/gnn/torch_utils.py b/skdecide/hub/solver/utils/gnn/torch_utils.py
new file mode 100644
index 0000000000..e020dcb3fa
--- /dev/null
+++ b/skdecide/hub/solver/utils/gnn/torch_utils.py
@@ -0,0 +1,29 @@
+from typing import Optional
+
+import gymnasium as gym
+import torch as th
+import torch_geometric as thg
+
+
+def graph_obs_to_thg_data(
+    obs: gym.spaces.GraphInstance,
+    device: Optional[th.device] = None,
+    pin_memory: bool = False,
+) -> thg.data.Data:
+    # Node features
+    flatten_node_features = obs.nodes.reshape((len(obs.nodes), -1))
+    x = th.tensor(flatten_node_features).float()
+    # Edge features
+    if obs.edges is None:
+        edge_attr = None
+    else:
+        flatten_edge_features = obs.edges.reshape((len(obs.edges), -1))
+        edge_attr = th.tensor(flatten_edge_features).float()
+    edge_index = th.tensor(obs.edge_links, dtype=th.long).t().contiguous().view(2, -1)
+    # thg.Data
+    data = thg.data.Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
+    # Pin the tensor's memory (for faster transfer to GPU later).
+    if pin_memory and th.cuda.is_available():
+        data.pin_memory()
+
+    return data if device is None else data.to(device)
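
Note (not part of the patch): a minimal usage sketch of the relocated `graph_obs_to_thg_data` helper, which both the RLlib and stable-baselines3 GNN wrappers now route through. The toy graph values below are illustrative, not taken from the codebase:

```python
import gymnasium as gym
import numpy as np
import torch as th

from skdecide.hub.solver.utils.gnn.torch_utils import graph_obs_to_thg_data

# Toy graph observation: 3 nodes with 4 features each, 2 edges with 1 feature each.
obs = gym.spaces.GraphInstance(
    nodes=np.ones((3, 4), dtype=np.float32),
    edges=np.ones((2, 1), dtype=np.float32),
    edge_links=np.array([[0, 1], [1, 2]]),  # (source, target) node index pairs
)

data = graph_obs_to_thg_data(obs, device=th.device("cpu"))
# data.x         -> float tensor of shape (3, 4)
# data.edge_index -> long tensor of shape (2, 2)
# data.edge_attr  -> float tensor of shape (2, 1)
```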