Skip to content

Commit

Permalink
Use GNN without fixed features reduction for graph -> node case (#466)
Browse files Browse the repository at this point in the history
- observations = graph
- actions = node of the graph
  action space= Discrete space, an action is represented by the node index

We also:
- rename `graph_obs_to_thg_data` into `graph_instance_to_thg_data`, and add
  the reverse `thg_data_to_graph_instance`
- move `get_obs_shape` in preprocessing module
- add an utility function to extract torch model parameters values
- add a debug option to policy to store initial model parameters and
  check that model parameters have actually been updated during training
  (see examples/gnn/gnn_graph2node_sb3_jsp.py)
- move gnn examples into a new directory examples/gnn
  • Loading branch information
nhuet authored Feb 7, 2025
1 parent 95dabf2 commit ae1ba6a
Show file tree
Hide file tree
Showing 13 changed files with 779 additions and 52 deletions.
217 changes: 217 additions & 0 deletions examples/gnn/gnn_graph2node_sb3_jsp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
from typing import Any

import numpy as np
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
from gymnasium.spaces import Box, Graph, GraphInstance

from skdecide.builders.domain import (
FullyObservable,
Initializable,
Markovian,
Renderable,
Rewards,
Sequential,
SingleAgent,
)
from skdecide.core import Space, TransitionOutcome, Value
from skdecide.domains import Domain
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn.ppo.ppo import Graph2NodePPO
from skdecide.hub.solver.utils.gnn.torch_utils import extract_module_parameters_values
from skdecide.hub.space.gym import GymSpace, ListSpace
from skdecide.utils import rollout


class D(
Domain,
SingleAgent,
Sequential,
Initializable,
Markovian,
FullyObservable,
Renderable,
Rewards,
):
T_state = GraphInstance # Type of states
T_observation = T_state # Type of observations
T_event = int # Type of events
T_value = float # Type of transition values (rewards or costs)
T_info = None # Type of additional information in environment outcome


class GraphJspDomain(D):
_gym_env: DisjunctiveGraphJspEnv

def __init__(self, gym_env, deterministic=False):
self._gym_env = gym_env
if self._gym_env.normalize_observation_space:
self.n_nodes_features = gym_env.n_machines + 1
else:
self.n_nodes_features = 2
self.deterministic = deterministic

def _state_reset(self) -> D.T_state:
return self._np_state2graph_state(self._gym_env.reset()[0])

def _state_step(
self, action: D.T_event
) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
env_state, reward, terminated, truncated, info = self._gym_env.step(action)
state = self._np_state2graph_state(env_state)
if truncated:
info["TimeLimit.truncated"] = True
return TransitionOutcome(
state=state, value=Value(reward=reward), termination=terminated, info=info
)

def _get_applicable_actions_from(
self, memory: D.T_memory[D.T_state]
) -> D.T_agent[Space[D.T_event]]:
return ListSpace(np.nonzero(self._gym_env.valid_action_mask())[0])

def _get_observation_space_(self) -> Space[D.T_observation]:
if self._gym_env.normalize_observation_space:
original_graph_space = Graph(
node_space=Box(
low=0.0,
high=1.0,
shape=(self.n_nodes_features,),
dtype=np.float_,
),
edge_space=Box(low=0, high=1.0, dtype=np.float_),
)

else:
original_graph_space = Graph(
node_space=Box(
low=np.array([0, 0]),
high=np.array(
[
self._gym_env.n_machines,
self._gym_env.longest_processing_time,
]
),
dtype=np.int_,
),
edge_space=Box(
low=0, high=self._gym_env.longest_processing_time, dtype=np.int_
),
)
return GymSpace(original_graph_space)

def _get_action_space_(self) -> Space[D.T_observation]:
return GymSpace(self._gym_env.action_space)

def _np_state2graph_state(self, np_state: np.array) -> GraphInstance:
if not self._gym_env.normalize_observation_space:
np_state = np_state.astype(np.int_)

nodes = np_state[:, -self.n_nodes_features :]
adj = np_state[:, : -self.n_nodes_features]
edge_starts_ends = adj.nonzero()
edge_links = np.transpose(edge_starts_ends)
edges = adj[edge_starts_ends][:, None]

return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)

def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
return self._gym_env.render(**kwargs)


jsp = np.array(
[
[
[0, 1, 2], # machines for job 0
[0, 2, 1], # machines for job 1
[0, 1, 2], # machines for job 2
],
[
[3, 2, 2], # task durations of job 0
[2, 1, 4], # task durations of job 1
[0, 4, 3], # task durations of job 2
],
]
)


domain_factory = lambda: GraphJspDomain(
gym_env=DisjunctiveGraphJspEnv(
jps_instance=jsp,
perform_left_shift_if_possible=True,
normalize_observation_space=False,
flat_observation_space=False,
action_mode="task",
)
)


with StableBaseline(
domain_factory=domain_factory,
algo_class=Graph2NodePPO,
baselines_policy="GraphInputPolicy",
policy_kwargs=dict(debug=True),
learn_config={
"total_timesteps": 10_000,
},
) as solver:
solver.solve()
rollout(
domain=domain_factory(),
solver=solver,
max_steps=30,
num_episodes=1,
render=True,
)


# action gnn parameters
initial_parameters = solver._algo.policy.action_net.initial_parameters
final_parameters = extract_module_parameters_values(solver._algo.policy.action_net)
same_parameters: dict[str, bool] = {
name: (initial_parameters[name] == final_parameters[name]).all()
for name in final_parameters
}

if all(same_parameters.values()):
print("Action full GNN parameters have not changed during training!")
else:
unchanging_parameters = [name for name, same in same_parameters.items() if same]
print(
f"Action full GNN parameter unchanged after training: {unchanging_parameters}"
)
changing_parameters = [name for name, same in same_parameters.items() if not same]
print(
f"Action full GNN parameters having changed during training: {changing_parameters}"
)
diff_parameters = {
name: abs(initial_parameters[name] - final_parameters[name]).max()
for name in changing_parameters
}
print(diff_parameters)

# value gnn parameters
initial_parameters = solver._algo.policy.features_extractor.extractor.initial_parameters
final_parameters = extract_module_parameters_values(
solver._algo.policy.features_extractor.extractor
)
same_parameters: dict[str, bool] = {
name: (initial_parameters[name] == final_parameters[name]).all()
for name in final_parameters
}

if all(same_parameters.values()):
print("Value GNN feature extractor parameters have not changed during training!")
else:
unchanging_parameters = [name for name, same in same_parameters.items() if same]
print(
f"Value GNN feature extracto parameter unchanged after training: {unchanging_parameters}"
)
changing_parameters = [name for name, same in same_parameters.items() if not same]
print(
f"Value GNN feature extractor parameters having changed during training: {changing_parameters}"
)
diff_parameters = {
name: abs(initial_parameters[name] - final_parameters[name]).max()
for name in changing_parameters
}
print(diff_parameters)
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
is_graph_dict_multiinput,
is_masked_obs,
)
from skdecide.hub.solver.utils.gnn.torch_utils import graph_obs_to_thg_data
from skdecide.hub.solver.utils.gnn.torch_utils import graph_instance_to_thg_data


def convert_to_torch_tensor(
Expand Down Expand Up @@ -43,11 +43,11 @@ def convert_to_torch_tensor(
nested elements that are None because torch has no representation for that.
"""
if isinstance(x, gym.spaces.GraphInstance):
return graph_obs_to_thg_data(x, device=device, pin_memory=pin_memory)
return graph_instance_to_thg_data(x, device=device, pin_memory=pin_memory)
elif isinstance(x, list) and isinstance(x[0], gym.spaces.GraphInstance):
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph, device=device, pin_memory=pin_memory)
graph_instance_to_thg_data(graph, device=device, pin_memory=pin_memory)
for graph in x
]
)
Expand Down
37 changes: 3 additions & 34 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,39 +25,8 @@
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

from .utils import copy_graph_instance, graph_obs_to_thg_data


def get_obs_shape(
observation_space: spaces.Space,
) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]:
"""
Get the shape of the observation (useful for the buffers).
:param observation_space:
:return:
"""
if isinstance(observation_space, spaces.Box):
return observation_space.shape
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
return (1,)
elif isinstance(observation_space, spaces.MultiDiscrete):
# Number of discrete features
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return observation_space.shape
elif isinstance(observation_space, spaces.Graph):
# Will not be used
return observation_space.node_space.shape + observation_space.edge_space.shape
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]

else:
raise NotImplementedError(
f"{observation_space} observation space is not supported"
)
from .preprocessing import get_obs_shape
from .utils import copy_graph_instance, graph_instance_to_thg_data


class GraphBaseBuffer(BaseBuffer):
Expand All @@ -84,7 +53,7 @@ def _graphlist_to_torch(
) -> thg.data.Data:
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph_list[idx], device=self.device)
graph_instance_to_thg_data(graph_list[idx], device=self.device)
for idx in batch_inds
]
)
Expand Down
Loading

0 comments on commit ae1ba6a

Please sign in to comment.