Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix examples #366

Merged
merged 1 commit into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions examples/full_multisolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
to the proper certificate (https://stackoverflow.com/a/31060428).

"""

from dataclasses import dataclass
from math import sqrt
from typing import Any, Callable
from typing import Callable, Optional

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -86,6 +86,37 @@ def heuristic(self, s):
return Value(cost=0)


def get_state_continuous_mountain_car(env):
return env.state


def set_state_continuous_mountain_car(env, state):
env.state = state


@dataclass
class CartPoleState:
state: np.array
steps_beyond_terminated: Optional[int]

def __eq__(self, other: "CartPoleState"):
return (
np.array_equal(self.state, other.state)
and self.steps_beyond_terminated == other.steps_beyond_terminated
)


def get_state_cart_pole(env):
return CartPoleState(
state=env.state, steps_beyond_terminated=env.steps_beyond_terminated
)


def set_state_get_state_cart_pole(env, state: CartPoleState):
env.state = state.state
env.steps_beyond_terminated = state.steps_beyond_terminated


if __name__ == "__main__":

try_domains = [
Expand Down Expand Up @@ -125,6 +156,9 @@ def heuristic(self, s):
"name": "Cart Pole (Gymnasium)",
"entry": "GymDomain",
"config": {"gym_env": gym.make("CartPole-v1", render_mode="human")},
"config_gym4width": dict(
get_state=get_state_cart_pole, set_state=set_state_get_state_cart_pole
),
"rollout": {
"num_episodes": 3,
"max_steps": 1000,
Expand All @@ -139,6 +173,10 @@ def heuristic(self, s):
"config": {
"gym_env": gym.make("MountainCarContinuous-v0", render_mode="human")
},
"config_gym4width": dict(
get_state=get_state_continuous_mountain_car,
set_state=set_state_continuous_mountain_car,
),
"rollout": {
"num_episodes": 3,
"max_steps": 1000,
Expand Down Expand Up @@ -393,7 +431,7 @@ def heuristic(self, s):
else:
# Solve with selected solver
actual_domain_type = domain_type
actual_domain_config = selected_domain["config"]
actual_domain_config = dict(selected_domain["config"]) # copy
actual_domain = domain
if selected_domain["entry"].__name__ == "GymDomain" and (
selected_solver["entry"].__name__ == "IW"
Expand All @@ -404,6 +442,10 @@ def heuristic(self, s):
actual_domain_type = GymDomainForWidthSolvers
if selected_domain["name"] == "Cart Pole (Gymnasium)":
actual_domain_config["termination_is_goal"] = False
if "config_gym4width" in selected_domain:
actual_domain_config.update(
selected_domain["config_gym4width"]
)
actual_domain = actual_domain_type(**actual_domain_config)
selected_solver["config"][
"domain_factory"
Expand Down
7 changes: 6 additions & 1 deletion examples/maze_multiagent_mdp_multisolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
plt.pause(0.0001)


class FlattenedMultiAgentMaze(MultiAgentMaze, SingleAgent):
def __init__(self, maze_str=DEFAULT_MAZE, nb_agents=4):
super().__init__(maze_str=maze_str, nb_agents=nb_agents, flatten_data=True)


class D(
Domain,
SingleAgent,
Expand Down Expand Up @@ -598,7 +603,7 @@ def mcts_callback(solver, i=None):
multiagent_domain._maze, multiagent_domain._agents_goals[agent]
),
"multiagent_solver_kwargs": {
"domain_factory": lambda: MultiAgentMaze(flatten_data=True),
"domain_factory": lambda: FlattenedMultiAgentMaze(),
"time_budget": 600000,
"max_depth": 50,
"residual_moving_average_window": 10,
Expand Down
2 changes: 1 addition & 1 deletion examples/nocycle_grid_goal_mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]:
"domain_factory": domain_factory,
"parallel": False,
"discount": 1.0,
"max_tip_expanions": 1,
"max_tip_expansions": 1,
"detect_cycles": False,
"heuristic": lambda d, s: Value(
cost=sqrt(
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ GymDomain = "skdecide.hub.domain.gym:GymDomain [domains]"
DeterministicGymDomain = "skdecide.hub.domain.gym:DeterministicGymDomain [domains]"
CostDeterministicGymDomain = "skdecide.hub.domain.gym:CostDeterministicGymDomain [domains]"
GymPlanningDomain = "skdecide.hub.domain.gym:GymPlanningDomain [domains]"
GymWidthPlanningDomain = "skdecide.hub.domain.gym:GymWidthPlanningDomain [domains]"
GymWidthDomain = "skdecide.hub.domain.gym:GymWidthDomain [domains]"
UPDomain = "skdecide.hub.domain.up:UPDomain [domains]"
MasterMind = "skdecide.hub.domain.mastermind:MasterMind [domains]"
Maze = "skdecide.hub.domain.maze:Maze [domains]"
Expand Down Expand Up @@ -162,7 +162,7 @@ SimpleGreedy = "skdecide.hub.solver.simple_greedy:SimpleGreedy [solvers]"
StableBaseline = "skdecide.hub.solver.stable_baselines:StableBaseline [solvers]"
DOSolver = "skdecide.hub.solver.do_solver:DOSolver [solvers]"
GPHH = "skdecide.hub.solver.do_solver:GPHH [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy:PilePolicy [solvers]"
PilePolicy = "skdecide.hub.solver.pile_policy_scheduling:PilePolicy [solvers]"
UPSolver = "skdecide.hub.solver.up:UPSolver [solvers]"

[tool.poetry.dev-dependencies]
Expand Down
9 changes: 6 additions & 3 deletions skdecide/hub/domain/gym/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,9 +937,12 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
# gym_env.render() can modify the environment
# and generate deepcopy errors later in _get_next_state
# thus we use a copy of the env to render it instead.
gym_env_for_rendering = deepcopy(self._gym_env)
render = gym_env_for_rendering.render()
return render
if self._set_state is None or self._get_state is None:
gym_env_for_rendering = deepcopy(self._gym_env)
render = gym_env_for_rendering.render()
return render
else:
self._gym_env.render()

def close(self):
return self._gym_env.close()
Expand Down
Loading