Use GNN without reduction to predict graph actions with SB3 PPO
nhuet committed Jan 28, 2025
1 parent d1f63c5 commit 86181fa
Showing 15 changed files with 916 additions and 99 deletions.
200 changes: 200 additions & 0 deletions examples/gnn/full_gnn_jsp_sb3.py
@@ -0,0 +1,200 @@
from typing import Any

import gymnasium as gym
import numpy as np
import scipy.special
from graph_jsp_env.disjunctive_graph_jsp_env import DisjunctiveGraphJspEnv
from gymnasium.spaces import Box, Graph, GraphInstance

from skdecide.builders.domain import (
FullyObservable,
Initializable,
Markovian,
Renderable,
Rewards,
Sequential,
SingleAgent,
)
from skdecide.core import Space, TransitionOutcome, Value
from skdecide.domains import Domain
from skdecide.hub.solver.stable_baselines import StableBaseline
from skdecide.hub.solver.stable_baselines.gnn.ppo.ppo import Graph2GraphPPO
from skdecide.hub.space.gym import GymSpace
from skdecide.utils import rollout


class D(
Domain,
SingleAgent,
Sequential,
Initializable,
Markovian,
FullyObservable,
Renderable,
Rewards,
):
T_state = GraphInstance # Type of states
T_observation = T_state # Type of observations
T_event = GraphInstance # Type of events
T_value = float # Type of transition values (rewards or costs)
T_info = None # Type of additional information in environment outcome


class GraphJspDomain(D):
_gym_env: DisjunctiveGraphJspEnv

    def __init__(self, gym_env: DisjunctiveGraphJspEnv, deterministic: bool = False):
self._gym_env = gym_env
if self._gym_env.normalize_observation_space:
self.n_nodes_features = gym_env.n_machines + 1
else:
self.n_nodes_features = 2
self.deterministic = deterministic

def _state_reset(self) -> D.T_state:
return self._np_state2graph_state(self._gym_env.reset()[0])

def _state_step(
self, action: D.T_event
) -> TransitionOutcome[D.T_state, Value[D.T_value], D.T_predicate, D.T_info]:
env_state, reward, terminated, truncated, info = self._gym_env.step(
self._graph_action2env_action(action, deterministic=self.deterministic)
)
state = self._np_state2graph_state(env_state)
if truncated:
info["TimeLimit.truncated"] = True
return TransitionOutcome(
state=state, value=Value(reward=reward), termination=terminated, info=info
)

def _get_applicable_actions_from(
self, memory: D.T_memory[D.T_state]
) -> D.T_agent[Space[D.T_event]]:
        raise NotImplementedError(
            "`get_applicable_actions()` is not applicable here, as the graph action space is continuous"
        )

def _is_applicable_action_from(
self, action: D.T_agent[D.T_event], memory: D.T_memory[D.T_state]
) -> bool:
return self._graph_action2env_action(action) in self._gym_env.valid_actions()

def _get_observation_space_(self) -> Space[D.T_observation]:
if self._gym_env.normalize_observation_space:
original_graph_space = Graph(
node_space=Box(
low=0.0,
high=1.0,
shape=(self.n_nodes_features,),
dtype=np.float_,
),
edge_space=Box(low=0, high=1.0, dtype=np.float_),
)

else:
original_graph_space = Graph(
node_space=Box(
low=np.array([0, 0]),
high=np.array(
[
self._gym_env.n_machines,
self._gym_env.longest_processing_time,
]
),
dtype=np.int_,
),
edge_space=Box(
low=0, high=self._gym_env.longest_processing_time, dtype=np.int_
),
)
return GymSpace(original_graph_space)

    def _get_action_space(self) -> Space[D.T_event]:
if self._gym_env.normalize_observation_space:
edge_space = Box(low=0, high=1.0, dtype=np.float_)
else:
edge_space = Box(
low=0, high=self._gym_env.longest_processing_time, dtype=np.int_
)
original_graph_space = Graph(
node_space=Box(
low=-np.inf,
high=np.inf,
shape=(1,),
dtype=np.float_,
),
edge_space=edge_space,
)
return GymSpace(original_graph_space)

    def _np_state2graph_state(self, np_state: np.ndarray) -> GraphInstance:
if not self._gym_env.normalize_observation_space:
np_state = np_state.astype(np.int_)

nodes = np_state[:, -self.n_nodes_features :]
adj = np_state[:, : -self.n_nodes_features]
edge_starts_ends = adj.nonzero()
edge_links = np.transpose(edge_starts_ends)
edges = adj[edge_starts_ends][:, None]

return GraphInstance(nodes=nodes, edges=edges, edge_links=edge_links)

def _graph_action2env_action(
self, graph_action: gym.spaces.GraphInstance, deterministic=False
) -> int:
logits = graph_action.nodes
if deterministic:
node_idx = np.argmax(logits)
else:
probs = scipy.special.softmax(logits, axis=0).flatten()
node_idx = np.random.choice(a=len(probs), p=probs)
return int(node_idx)

def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
return self._gym_env.render(**kwargs)


jsp = np.array(
[
[
[0, 1, 2], # machines for job 0
[0, 2, 1], # machines for job 1
[0, 1, 2], # machines for job 2
],
[
[3, 2, 2], # task durations of job 0
[2, 1, 4], # task durations of job 1
[0, 4, 3], # task durations of job 2
],
]
)


domain_factory = lambda: GraphJspDomain(
gym_env=DisjunctiveGraphJspEnv(
jps_instance=jsp,
perform_left_shift_if_possible=True,
normalize_observation_space=False,
flat_observation_space=False,
action_mode="task",
)
)


with StableBaseline(
domain_factory=domain_factory,
algo_class=Graph2GraphPPO,
baselines_policy="GraphInputPolicy",
learn_config={
"total_timesteps": 10000,
},
# n_steps=512,
) as solver:
solver.solve()
rollout(
domain=domain_factory(),
solver=solver,
max_steps=30,
num_episodes=1,
render=True,
)
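
For quick reference, the action decoding implemented by `_graph_action2env_action` above boils down to the following self-contained sketch (the per-node logits are made-up illustration values, not the output of a trained policy):

```python
import numpy as np
import scipy.special

# Made-up per-node logits, shaped like graph_action.nodes:
# one scalar score per task node of the disjunctive graph.
logits = np.array([[0.2], [1.5], [-0.3]])

# Deterministic decoding: pick the highest-scoring node.
greedy_idx = int(np.argmax(logits))  # -> 1

# Stochastic decoding: softmax over nodes, then sample a node index.
probs = scipy.special.softmax(logits, axis=0).flatten()
sampled_idx = int(np.random.choice(a=len(probs), p=probs))
```

The sampled (or argmax) node index is then passed to the underlying env as a discrete task action, since the env is built with `action_mode="task"`.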
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions skdecide/hub/solver/ray_rllib/gnn/utils/torch_utils.py
@@ -15,7 +15,7 @@
is_graph_dict_multiinput,
is_masked_obs,
)
from skdecide.hub.solver.utils.gnn.torch_utils import graph_obs_to_thg_data
from skdecide.hub.solver.utils.gnn.torch_utils import graph_instance_to_thg_data


def convert_to_torch_tensor(
@@ -43,11 +43,11 @@ def convert_to_torch_tensor(
nested elements that are None because torch has no representation for that.
"""
if isinstance(x, gym.spaces.GraphInstance):
return graph_obs_to_thg_data(x, device=device, pin_memory=pin_memory)
return graph_instance_to_thg_data(x, device=device, pin_memory=pin_memory)
elif isinstance(x, list) and isinstance(x[0], gym.spaces.GraphInstance):
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph, device=device, pin_memory=pin_memory)
graph_instance_to_thg_data(graph, device=device, pin_memory=pin_memory)
for graph in x
]
)
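
Note: the renamed helper `graph_instance_to_thg_data` is defined outside this diff (in `skdecide.hub.solver.utils.gnn.torch_utils`). A plausible sketch of such a conversion, assuming gymnasium's `GraphInstance` fields (`nodes`, `edges`, `edge_links`) and torch-geometric's `Data` container, is given below; the function name and body are illustrative, not the library's actual implementation:

```python
from typing import Optional

import numpy as np
import torch
from gymnasium import spaces
from torch_geometric.data import Data


def graph_instance_to_thg_data_sketch(
    graph: spaces.GraphInstance,
    device: Optional[torch.device] = None,
    pin_memory: bool = False,
) -> Data:
    # Node features -> (n_nodes, n_node_features) float tensor.
    x = torch.as_tensor(np.asarray(graph.nodes), dtype=torch.float32)
    # gymnasium stores edge_links as (n_edges, 2); torch-geometric
    # expects edge_index with shape (2, n_edges).
    edge_index = (
        torch.as_tensor(np.asarray(graph.edge_links), dtype=torch.long)
        .t()
        .contiguous()
    )
    # Edge features -> (n_edges, n_edge_features) float tensor.
    edge_attr = torch.as_tensor(np.asarray(graph.edges), dtype=torch.float32)
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    if pin_memory:
        data = data.pin_memory()
    return data if device is None else data.to(device)
```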
86 changes: 42 additions & 44 deletions skdecide/hub/solver/stable_baselines/gnn/common/buffers.py
@@ -17,47 +17,15 @@
ReplayBuffer,
RolloutBuffer,
)
from stable_baselines3.common.preprocessing import get_action_dim
from stable_baselines3.common.type_aliases import (
ReplayBufferSamples,
RolloutBufferSamples,
)
from stable_baselines3.common.utils import get_device
from stable_baselines3.common.vec_env import VecNormalize

from .utils import copy_graph_instance, graph_obs_to_thg_data


def get_obs_shape(
observation_space: spaces.Space,
) -> Union[tuple[int, ...], dict[str, tuple[int, ...]]]:
"""
Get the shape of the observation (useful for the buffers).
:param observation_space:
:return:
"""
if isinstance(observation_space, spaces.Box):
return observation_space.shape
elif isinstance(observation_space, spaces.Discrete):
# Observation is an int
return (1,)
elif isinstance(observation_space, spaces.MultiDiscrete):
# Number of discrete features
return (int(len(observation_space.nvec)),)
elif isinstance(observation_space, spaces.MultiBinary):
# Number of binary features
return observation_space.shape
elif isinstance(observation_space, spaces.Graph):
# Will not be used
return observation_space.node_space.shape + observation_space.edge_space.shape
elif isinstance(observation_space, spaces.Dict):
return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc]

else:
raise NotImplementedError(
f"{observation_space} observation space is not supported"
)
from .preprocessing import get_action_dim, get_obs_shape
from .utils import copy_graph_instance, graph_instance_to_thg_data


class GraphBaseBuffer(BaseBuffer):
@@ -84,7 +52,7 @@ def _graphlist_to_torch(
) -> thg.data.Data:
return thg.data.Batch.from_data_list(
[
graph_obs_to_thg_data(graph_list[idx], device=self.device)
graph_instance_to_thg_data(graph_list[idx], device=self.device)
for idx in batch_inds
]
)
@@ -100,7 +68,7 @@ class GraphRolloutBuffer(RolloutBuffer, GraphBaseBuffer):
"""

observations: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]
tensor_names = ["actions", "values", "log_probs", "advantages", "returns"]
tensor_names = ["values", "log_probs", "advantages", "returns"]

def reset(self) -> None:
assert isinstance(
@@ -123,11 +91,7 @@ def add(
log_prob = log_prob.reshape(-1, 1)

self._add_obs(obs)

# Same reshape, for actions
action = action.reshape((self.n_envs, self.action_dim))

self.actions[self.pos] = np.array(action).copy()
self._add_action(action)
self.rewards[self.pos] = np.array(reward).copy()
self.episode_starts[self.pos] = np.array(episode_start).copy()
self.values[self.pos] = value.clone().cpu().numpy().flatten()
@@ -136,12 +100,19 @@
if self.pos == self.buffer_size:
self.full = True

def _add_action(self, action: np.ndarray) -> None:
action = action.reshape((self.n_envs, self.action_dim))
self.actions[self.pos] = np.array(action).copy()

def _add_obs(self, obs: list[spaces.GraphInstance]) -> None:
self.observations.append([copy_graph_instance(g) for g in obs])

def _swap_and_flatten_obs(self) -> None:
self.observations = _swap_and_flatten_nested_list(self.observations)

def _swap_and_flatten_action(self) -> None:
self.actions = self.swap_and_flatten(self.actions)

def get(
self, batch_size: Optional[int] = None
) -> Generator[RolloutBufferSamples, None, None]:
@@ -150,6 +121,7 @@
# Prepare the data
if not self.generator_ready:
self._swap_and_flatten_obs()
self._swap_and_flatten_action()
for tensor in self.tensor_names:
self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor])
self.generator_ready = True
@@ -167,18 +139,45 @@ def _get_samples(
self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
) -> RolloutBufferSamples:
observations = self._get_observations_samples(batch_inds)
actions = self._get_actions_samples(batch_inds)
data = (
self.actions[batch_inds],
self.values[batch_inds].flatten(),
self.log_probs[batch_inds].flatten(),
self.advantages[batch_inds].flatten(),
self.returns[batch_inds].flatten(),
)
return RolloutBufferSamples(observations, *tuple(map(self.to_torch, data)))
return RolloutBufferSamples(
observations, actions, *tuple(map(self.to_torch, data))
)

def _get_observations_samples(self, batch_inds: np.ndarray) -> thg.data.Data:
return self._graphlist_to_torch(self.observations, batch_inds=batch_inds)

def _get_actions_samples(self, batch_inds: np.ndarray) -> th.Tensor:
return self.to_torch(self.actions[batch_inds])


class Graph2GraphRolloutBuffer(GraphRolloutBuffer):
"""Rollout buffer when both observations and actions are graphs."""

actions: Union[list[spaces.GraphInstance], list[list[spaces.GraphInstance]]]

def reset(self) -> None:
assert isinstance(
self.action_space, spaces.Graph
), "Graph2GraphRolloutBuffer must be used with Graph action space only"
super().reset()
self.actions = list()

def _add_action(self, action: list[spaces.GraphInstance]) -> None:
self.actions.append([copy_graph_instance(g) for g in action])

def _swap_and_flatten_action(self) -> None:
self.actions = _swap_and_flatten_nested_list(self.actions)

def _get_actions_samples(self, batch_inds: np.ndarray) -> thg.data.Data:
return self._graphlist_to_torch(self.actions, batch_inds=batch_inds)


class DictGraphRolloutBuffer(GraphRolloutBuffer, DictRolloutBuffer):

@@ -257,7 +256,6 @@
class _BaseMaskableRolloutBuffer:

tensor_names = [
"actions",
"values",
"log_probs",
"advantages",
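
A note on the helper `_swap_and_flatten_nested_list` used above: it is defined outside the shown hunks, but presumably mirrors SB3's `BaseBuffer.swap_and_flatten`, which reshapes a `(n_steps, n_envs, ...)` array into a flat env-major `(n_steps * n_envs, ...)` one, for nested Python lists instead of numpy arrays. A sketch under that assumption:

```python
from typing import Any


def _swap_and_flatten_nested_list_sketch(nested: list[list[Any]]) -> list[Any]:
    # `nested` is indexed as nested[step][env], mirroring the
    # (n_steps, n_envs, ...) layout of SB3's numpy buffers.
    n_steps, n_envs = len(nested), len(nested[0])
    # Swap the two leading axes, then flatten: env-major order, matching
    # arr.swapaxes(0, 1).reshape(n_steps * n_envs, *rest) on an ndarray.
    return [nested[step][env] for env in range(n_envs) for step in range(n_steps)]
```

With that in place, `Graph2GraphRolloutBuffer` handles graph actions exactly as `GraphRolloutBuffer` handles graph observations: actions are kept as lists of `GraphInstance` and batched per minibatch through `_graphlist_to_torch`.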