Skip to content

Commit

Permalink
Fix examples
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed May 21, 2024
1 parent b9be0e8 commit 4a2a46f
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 10 deletions.
48 changes: 45 additions & 3 deletions examples/full_multisolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
to the proper certificate (https://stackoverflow.com/a/31060428).
"""

from dataclasses import dataclass
from math import sqrt
from typing import Any, Callable
from typing import Callable, Optional

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -86,6 +86,37 @@ def heuristic(self, s):
return Value(cost=0)


def get_state_continuous_mountain_car(env):
    """Return the current internal state of a continuous mountain car env.

    Args:
        env: environment object exposing a ``state`` attribute.

    Returns:
        The object currently stored in ``env.state`` (returned as is,
        no copy is made).
    """
    return env.state


def set_state_continuous_mountain_car(env, state):
    """Overwrite the internal state of a continuous mountain car env.

    Args:
        env: environment object whose ``state`` attribute is replaced.
        state: value to store in ``env.state`` (stored as is, no copy).
    """
    env.state = state


@dataclass
class CartPoleState:
    """Snapshot of the cart-pole env attributes that are saved and restored.

    Attributes:
        state: value of the environment's ``state`` attribute.
        steps_beyond_terminated: value of the environment's
            ``steps_beyond_terminated`` attribute.
    """

    state: np.ndarray
    steps_beyond_terminated: Optional[int]

    def __eq__(self, other: object) -> bool:
        """Compare two snapshots field by field.

        ``state`` is compared with ``np.array_equal`` because ``==`` on
        arrays is element-wise and does not yield a single boolean.
        """
        if not isinstance(other, CartPoleState):
            # Let Python fall back to its default handling instead of
            # failing with AttributeError on unrelated objects.
            return NotImplemented
        return (
            np.array_equal(self.state, other.state)
            and self.steps_beyond_terminated == other.steps_beyond_terminated
        )


def get_state_cart_pole(env) -> "CartPoleState":
    """Capture the restorable state of a cart-pole env.

    Args:
        env: environment object exposing ``state`` and
            ``steps_beyond_terminated`` attributes.

    Returns:
        A ``CartPoleState`` holding both attribute values (stored as is,
        no copy is made).
    """
    return CartPoleState(
        state=env.state,
        steps_beyond_terminated=env.steps_beyond_terminated,
    )


def set_state_get_state_cart_pole(env, state: "CartPoleState"):
    """Restore a cart-pole env from a previously captured snapshot.

    Args:
        env: environment object whose ``state`` and
            ``steps_beyond_terminated`` attributes are overwritten.
        state: snapshot providing the two attribute values to write back.
    """
    env.state = state.state
    env.steps_beyond_terminated = state.steps_beyond_terminated


if __name__ == "__main__":

try_domains = [
Expand Down Expand Up @@ -125,6 +156,9 @@ def heuristic(self, s):
"name": "Cart Pole (Gymnasium)",
"entry": "GymDomain",
"config": {"gym_env": gym.make("CartPole-v1", render_mode="human")},
"config_gym4width": dict(
get_state=get_state_cart_pole, set_state=set_state_get_state_cart_pole
),
"rollout": {
"num_episodes": 3,
"max_steps": 1000,
Expand All @@ -139,6 +173,10 @@ def heuristic(self, s):
"config": {
"gym_env": gym.make("MountainCarContinuous-v0", render_mode="human")
},
"config_gym4width": dict(
get_state=get_state_continuous_mountain_car,
set_state=set_state_continuous_mountain_car,
),
"rollout": {
"num_episodes": 3,
"max_steps": 1000,
Expand Down Expand Up @@ -393,7 +431,7 @@ def heuristic(self, s):
else:
# Solve with selected solver
actual_domain_type = domain_type
actual_domain_config = selected_domain["config"]
actual_domain_config = dict(selected_domain["config"]) # copy
actual_domain = domain
if selected_domain["entry"].__name__ == "GymDomain" and (
selected_solver["entry"].__name__ == "IW"
Expand All @@ -404,6 +442,10 @@ def heuristic(self, s):
actual_domain_type = GymDomainForWidthSolvers
if selected_domain["name"] == "Cart Pole (Gymnasium)":
actual_domain_config["termination_is_goal"] = False
if "config_gym4width" in selected_domain:
actual_domain_config.update(
selected_domain["config_gym4width"]
)
actual_domain = actual_domain_type(**actual_domain_config)
selected_solver["config"][
"domain_factory"
Expand Down
7 changes: 6 additions & 1 deletion examples/maze_multiagent_mdp_multisolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
plt.pause(0.0001)


class FlattenedMultiAgentMaze(MultiAgentMaze, SingleAgent):
    """``MultiAgentMaze`` variant that always enables ``flatten_data``.

    The class also inherits from ``SingleAgent``. Both constructor
    arguments have defaults, so it can be instantiated without arguments.
    """

    def __init__(self, maze_str=DEFAULT_MAZE, nb_agents=4):
        """Build the maze, forcing ``flatten_data=True`` on the parent.

        Args:
            maze_str: maze description forwarded to the parent constructor.
            nb_agents: number of agents forwarded to the parent constructor.
        """
        super().__init__(maze_str=maze_str, nb_agents=nb_agents, flatten_data=True)


class D(
Domain,
SingleAgent,
Expand Down Expand Up @@ -598,7 +603,7 @@ def mcts_callback(solver, i=None):
multiagent_domain._maze, multiagent_domain._agents_goals[agent]
),
"multiagent_solver_kwargs": {
"domain_factory": lambda: MultiAgentMaze(flatten_data=True),
"domain_factory": lambda: FlattenedMultiAgentMaze(),
"time_budget": 600000,
"max_depth": 50,
"residual_moving_average_window": 10,
Expand Down
2 changes: 1 addition & 1 deletion examples/nocycle_grid_goal_mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]:
"domain_factory": domain_factory,
"parallel": False,
"discount": 1.0,
"max_tip_expanions": 1,
"max_tip_expansions": 1,
"detect_cycles": False,
"heuristic": lambda d, s: Value(
cost=sqrt(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ GymDomain = "skdecide.hub.domain.gym:GymDomain [domains]"
DeterministicGymDomain = "skdecide.hub.domain.gym:DeterministicGymDomain [domains]"
CostDeterministicGymDomain = "skdecide.hub.domain.gym:CostDeterministicGymDomain [domains]"
GymPlanningDomain = "skdecide.hub.domain.gym:GymPlanningDomain [domains]"
GymWidthPlanningDomain = "skdecide.hub.domain.gym:GymWidthPlanningDomain [domains]"
GymWidthDomain = "skdecide.hub.domain.gym:GymWidthDomain [domains]"
UPDomain = "skdecide.hub.domain.up:UPDomain [domains]"
MasterMind = "skdecide.hub.domain.mastermind:MasterMind [domains]"
Maze = "skdecide.hub.domain.maze:Maze [domains]"
Expand Down Expand Up @@ -162,7 +162,7 @@ SimpleGreedy = "skdecide.hub.solver.simple_greedy:SimpleGreedy [solvers]"
StableBaseline = "skdecide.hub.solver.stable_baselines:StableBaseline [solvers]"
DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]"
GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy:PilePolicy [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy_scheduling:PilePolicy [solvers]"
UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]"

[tool.poetry.dev-dependencies]
Expand Down
9 changes: 6 additions & 3 deletions skdecide/hub/domain/gym/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,9 +937,12 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
# gym_env.render() can modify the environment
# and generate deepcopy errors later in _get_next_state
# thus we use a copy of the env to render it instead.
gym_env_for_rendering = deepcopy(self._gym_env)
render = gym_env_for_rendering.render()
return render
if self._set_state is None or self._get_state is None:
gym_env_for_rendering = deepcopy(self._gym_env)
render = gym_env_for_rendering.render()
return render
else:
self._gym_env.render()

def close(self):
    """Close the wrapped gym environment.

    Returns:
        Whatever ``self._gym_env.close()`` returns.
    """
    return self._gym_env.close()
Expand Down

0 comments on commit 4a2a46f

Please sign in to comment.