Unify usage of domain_factory in API
- C++ solvers need the domain_factory in their __init__
- Some other solvers needed it during solve, and thus it was introduced
  in Domain.solve_with(), Solver.solve(), and even in Restorable.load()

Here we unify this by putting the domain_factory in every solver's
__init__.
During __init__ we take care to autocast the domain_factory so that it
produces the domain at the appropriate level for the solver (as was
previously done in solve(), which then called _solve()).
We thus avoid repeating the autocast (it was also done in load()) and
improve readability for users (who sometimes needed to pass
domain_factory to __init__ and sometimes to solve_with()).
nhuet committed May 17, 2024
1 parent 8a139be commit 2465a30
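In practice the change looks like this (a minimal sketch using LazyAstar, as in the docs below; `MyDomain` stands for any user-defined domain class):

```python
from skdecide.hub.solver.lazy_astar import LazyAstar

# Before this commit, the domain factory was supplied at solve time:
#   solver = LazyAstar()
#   solver.solve(domain_factory=lambda: MyDomain())

# After this commit, it is supplied once at construction, where it is
# autocast to the domain level expected by the solver:
solver = LazyAstar(domain_factory=lambda: MyDomain())
solver.solve()
```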
Showing 66 changed files with 401 additions and 381 deletions.
4 changes: 2 additions & 2 deletions docs/guide/README.md
@@ -74,7 +74,7 @@ print(compatible_solvers)

# select Lazy A* solver and instantiate with default parameters
from skdecide.hub.solver.lazy_astar import LazyAstar
mysolver = LazyAstar()
mysolver = LazyAstar(domain_factory=MyDomain)
```

### Compute a solution
@@ -119,7 +119,7 @@ mysolver._cleanup()
::: tip
Note that this is automatically done if you use the solver within a `with` statement:
```python
with LazyAstar() as mysolver:
with LazyAstar(domain_factory=MyDomain) as mysolver:
MyDomain.solve_with(mysolver)
utils.rollout(MyDomain(), mysolver)
```
5 changes: 2 additions & 3 deletions examples/ars_solver.py
@@ -154,11 +154,10 @@
learning_rate=learning_rate,
policy_noise=policy_noise,
reward_maximization=reward_maximization,
domain_factory=lambda: domain_type(**selected_domain["config"]),
)
with solver_factory() as solver:
GymDomain.solve_with(
solver, lambda: domain_type(**selected_domain["config"])
)
GymDomain.solve_with(solver)
# Test solver solution on domain
print("==================== TEST SOLVER ====================")
print(
10 changes: 7 additions & 3 deletions examples/baselines_solver.py
@@ -30,10 +30,14 @@
domain = domain_factory()
if StableBaseline.check_domain(domain):
solver_factory = lambda: StableBaseline(
PPO, "MlpPolicy", learn_config={"total_timesteps": 30000}, verbose=1
domain_factory=domain_factory,
algo_class=PPO,
baselines_policy="MlpPolicy",
learn_config={"total_timesteps": 30000},
verbose=1,
)
with solver_factory() as solver:
GymDomain.solve_with(solver, domain_factory)
GymDomain.solve_with(solver)
solver.save("TEMP_Baselines")
rollout(
domain,
@@ -51,7 +55,7 @@

# %%
with solver_factory() as solver:
GymDomain.solve_with(solver, domain_factory, load_path="TEMP_Baselines")
GymDomain.solve_with(solver, load_path="TEMP_Baselines")
rollout(
domain,
solver,
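Note that `load_path` stays a solve-time option; only the `domain_factory` moves to the constructor. A sketch of the resulting save/reload flow, reusing `domain_factory` as defined in this example (the StableBaseline import path is an assumption):

```python
from stable_baselines3 import PPO

from skdecide.hub.domain.gym import GymDomain
from skdecide.hub.solver.stable_baselines import StableBaseline

solver_kwargs = dict(
    domain_factory=domain_factory,  # now a constructor argument
    algo_class=PPO,
    baselines_policy="MlpPolicy",
    learn_config={"total_timesteps": 30000},
)

# Train and save.
with StableBaseline(**solver_kwargs) as solver:
    GymDomain.solve_with(solver)
    solver.save("TEMP_Baselines")

# Rebuild a solver with the same factory and reload the saved policy.
with StableBaseline(**solver_kwargs) as solver:
    GymDomain.solve_with(solver, load_path="TEMP_Baselines")
```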
6 changes: 4 additions & 2 deletions examples/cgp_solver.py
@@ -28,9 +28,11 @@
domain_factory = lambda: GymDomain(gym.make(ENV_NAME))
domain = domain_factory()
if CGP.check_domain(domain):
solver_factory = lambda: CGP("TEMP_CGP", n_it=25)
solver_factory = lambda: CGP(
domain_factory=domain_factory, folder_name="TEMP_CGP", n_it=25
)
with solver_factory() as solver:
GymDomain.solve_with(solver, domain_factory)
GymDomain.solve_with(solver)
rollout(
domain,
solver,
42 changes: 14 additions & 28 deletions examples/full_multisolve.py
@@ -165,21 +165,18 @@ def heuristic(self, s):
{
"name": "Simple greedy",
"entry": "SimpleGreedy",
"need_domain_factory": False,
"config": {},
},
# Lazy A* (classical planning)
{
"name": "Lazy A* (classical planning)",
"entry": "LazyAstar",
"need_domain_factory": False,
"config": {"heuristic": lambda d, s: d.heuristic(s), "verbose": False},
},
# A* (planning)
{
"name": "A* (planning)",
"entry": "Astar",
"need_domain_factory": True,
"config": {
"heuristic": lambda d, s: d.heuristic(s),
"parallel": False,
@@ -190,7 +187,6 @@ def heuristic(self, s):
{
"name": "LRTAStar",
"entry": "LRTAstar",
"need_domain_factory": False,
"config": {
"max_depth": 200,
"max_iter": 1000,
@@ -202,7 +198,6 @@ def heuristic(self, s):
{
"name": "UCT (reinforcement learning / search)",
"entry": "UCT",
"need_domain_factory": True,
"config": {
"time_budget": 200,
"rollout_budget": 100000,
@@ -218,7 +213,6 @@ def heuristic(self, s):
{
"name": "PPO: Proximal Policy Optimization (deep reinforcement learning)",
"entry": "StableBaseline",
"need_domain_factory": False,
"config": {
"algo_class": PPO,
"baselines_policy": "MlpPolicy",
@@ -230,21 +224,18 @@ def heuristic(self, s):
{
"name": "POMCP: Partially Observable Monte-Carlo Planning (online planning for POMDP)",
"entry": "POMCP",
"need_domain_factory": False,
"config": {},
},
# CGP: Cartesian Genetic Programming (evolution strategy)
{
"name": "CGP: Cartesian Genetic Programming (evolution strategy)",
"entry": "CGP",
"need_domain_factory": False,
"config": {"folder_name": "TEMP", "n_it": 25},
},
# Rollout-IW (classical planning)
{
"name": "Rollout-IW (classical planning)",
"entry": "RIW",
"need_domain_factory": True,
"config": {
"state_features": lambda d, s: d.state_features(s),
"use_state_feature_hash": False,
@@ -263,7 +254,6 @@ def heuristic(self, s):
{
"name": "IW (classical planning)",
"entry": "IW",
"need_domain_factory": True,
"config": {
"state_features": lambda d, s: d.state_features(s),
"node_ordering": lambda a_gscore, a_novelty, a_depth, b_gscore, b_novelty, b_depth: a_novelty
@@ -276,7 +266,6 @@ def heuristic(self, s):
{
"name": "BFWS (planning) - (num_rows * num_cols) binary encoding (1 binary variable <=> 1 cell)",
"entry": "BFWS",
"need_domain_factory": True,
"config": {
"state_features": lambda d, s: d.state_features(s),
"heuristic": lambda d, s: d.heuristic(s),
@@ -407,24 +396,21 @@ def heuristic(self, s):
actual_domain_type = domain_type
actual_domain_config = selected_domain["config"]
actual_domain = domain
if selected_solver["need_domain_factory"]:
if selected_domain["entry"].__name__ == "GymDomain" and (
selected_solver["entry"].__name__ == "IW"
or selected_solver["entry"].__name__ == "RIW"
or selected_solver["entry"].__name__ == "BFWS"
or selected_solver["entry"].__name__ == "UCT"
):
actual_domain_type = GymDomainForWidthSolvers
if selected_domain["name"] == "Cart Pole (Gymnasium)":
actual_domain_config["termination_is_goal"] = False
actual_domain = actual_domain_type(**actual_domain_config)
selected_solver["config"][
"domain_factory"
] = lambda: actual_domain_type(**actual_domain_config)
if selected_domain["entry"].__name__ == "GymDomain" and (
selected_solver["entry"].__name__ == "IW"
or selected_solver["entry"].__name__ == "RIW"
or selected_solver["entry"].__name__ == "BFWS"
or selected_solver["entry"].__name__ == "UCT"
):
actual_domain_type = GymDomainForWidthSolvers
if selected_domain["name"] == "Cart Pole (Gymnasium)":
actual_domain_config["termination_is_goal"] = False
actual_domain = actual_domain_type(**actual_domain_config)
selected_solver["config"][
"domain_factory"
] = lambda: actual_domain_type(**actual_domain_config)
with solver_type(**selected_solver["config"]) as solver:
actual_domain_type.solve_with(
solver, lambda: actual_domain_type(**actual_domain_config)
)
actual_domain_type.solve_with(solver)
rollout(actual_domain, solver, **selected_domain["rollout"])
if hasattr(domain, "close"):
domain.close()
20 changes: 8 additions & 12 deletions examples/graph_domain/example_building_graph_domain.py
@@ -29,16 +29,14 @@ def solve_with_astar_full_space_explo():
domain_maze.get_initial_state()
)
t = time.perf_counter()
solver = LazyAstar()
solver.solve(domain_factory=lambda: domain_maze)
solver = LazyAstar(domain_factory=lambda: domain_maze)
solver.solve()
print(solver.get_plan())
t1 = time.perf_counter()
print(t1 - t, " sec to solve original maze")
t = time.perf_counter()
solver = LazyAstar()
solver.solve(
domain_factory=lambda: graph_maze, from_memory=domain_maze.get_initial_state()
)
solver = LazyAstar(domain_factory=lambda: graph_maze)
solver.solve(from_memory=domain_maze.get_initial_state())
print(solver.get_plan())
t2 = time.perf_counter()
print(t2 - t, " sec to solve graph maze")
Expand All @@ -49,16 +47,14 @@ def solve_with_astar_dfs():
dfs_explorator = DFSExploration(domain=domain_maze)
graph_maze = dfs_explorator.build_graph_domain(domain_maze.get_initial_state())
t = time.perf_counter()
solver = LazyAstar()
solver.solve(domain_factory=lambda: domain_maze)
solver = LazyAstar(domain_factory=lambda: domain_maze)
solver.solve()
print(solver.get_plan())
t1 = time.perf_counter()
print(t1 - t, " sec to solve original maze")
t = time.perf_counter()
solver = LazyAstar()
solver.solve(
domain_factory=lambda: graph_maze, from_memory=domain_maze.get_initial_state()
)
solver = LazyAstar(domain_factory=lambda: graph_maze)
solver.solve(from_memory=domain_maze.get_initial_state())
print(solver.get_plan())
t2 = time.perf_counter()
print(t2 - t, " sec to solve graph maze")
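In both functions the pattern is now the same: the domain goes to the constructor, while solve-time options such as the starting state stay on `solve()`. A condensed sketch, reusing `graph_maze` and `domain_maze` built above:

```python
from skdecide.hub.solver.lazy_astar import LazyAstar

# The domain is fixed when the solver is built...
solver = LazyAstar(domain_factory=lambda: graph_maze)
# ...while the start state remains a solve() option.
solver.solve(from_memory=domain_maze.get_initial_state())
print(solver.get_plan())
```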
2 changes: 2 additions & 0 deletions examples/grid_multisolve.py
@@ -335,6 +335,7 @@ def decode(val):
"name": "Lazy A* (classical planning)",
"entry": "LazyAstar",
"config": {
"domain_factory": lambda: MyDomain(),
"heuristic": lambda d, s: Value(
cost=sqrt((d.num_cols - 1 - s.x) ** 2 + (d.num_rows - 1 - s.y) ** 2)
),
@@ -383,6 +384,7 @@ def decode(val):
"name": "PPO: Proximal Policy Optimization (deep reinforcement learning)",
"entry": "StableBaseline",
"config": {
"domain_factory": lambda: MyDomain(),
"algo_class": PPO,
"baselines_policy": "MlpPolicy",
"learn_config": {"total_timesteps": 30000},
13 changes: 7 additions & 6 deletions examples/gym_jsbsim_greedy.py
@@ -10,7 +10,7 @@
from gym_jsbsim.catalogs.catalog import Catalog as prp
from gym_jsbsim.envs.taxi_utils import *

from skdecide import Solver
from skdecide import Domain, Solver
from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities
from skdecide.hub.domain.gym import DeterministicGymDomain, GymDiscreteActionDomain
from skdecide.utils import rollout
@@ -103,14 +103,15 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
class GreedyPlanner(Solver, DeterministicPolicies, Utilities, FromAnyState):
T_domain = D

def __init__(self):
def __init__(self, domain_factory: Callable[[], Domain]):
Solver.__init__(self, domain_factory=domain_factory)
self._domain = None
self._best_action = None
self._best_reward = None
self._current_pos = None

def _init_solve(self, domain_factory: Callable[[], D]) -> None:
self._domain = domain_factory()
def _init_solve(self) -> None:
self._domain = self._domain_factory()
self._domain.reset()
lon = self._domain._gym_env.sim.get_property_value(prp.position_long_gc_deg)
lat = self._domain._gym_env.sim.get_property_value(prp.position_lat_geod_deg)
@@ -162,8 +163,8 @@ def get_current_position(self):
domain.reset()

if GreedyPlanner.check_domain(domain):
with GreedyPlanner() as solver:
GymGreedyDomain.solve_with(solver, domain_factory)
with GreedyPlanner(domain_factory=domain_factory) as solver:
GymGreedyDomain.solve_with(solver)
initial_state = solver._domain.reset()
rollout(
domain,
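Hand-written solvers follow the same convention: accept `domain_factory` in `__init__`, forward it to `Solver.__init__` (which stores and autocasts it), and read `self._domain_factory` later. A minimal skeleton distilled from the change above (`MyPlanner` is illustrative; solver-specific methods such as `_get_next_action()` are omitted):

```python
from typing import Callable

from skdecide import Domain, Solver
from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities


class MyPlanner(Solver, DeterministicPolicies, Utilities, FromAnyState):
    T_domain = Domain  # narrow to the domain class the solver targets

    def __init__(self, domain_factory: Callable[[], Domain]):
        # Forward the factory to the base class, which autocasts it so that
        # it produces the domain at the level this solver expects.
        Solver.__init__(self, domain_factory=domain_factory)
        self._domain = None

    def _init_solve(self) -> None:
        # No domain_factory parameter anymore: use the stored factory.
        self._domain = self._domain_factory()
        self._domain.reset()
```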
2 changes: 1 addition & 1 deletion examples/gym_jsbsim_iw.py
@@ -310,7 +310,7 @@ def _get_next_action(
debug_logs=False,
)
with solver_factory() as solver:
GymIWDomain.solve_with(solver, domain_factory)
GymIWDomain.solve_with(solver)
evaluation_domain = EvaluationDomain(solver._domain)
evaluation_domain.reset()
rollout(
2 changes: 1 addition & 1 deletion examples/gym_jsbsim_riw.py
@@ -352,7 +352,7 @@ def _get_next_action(
debug_logs=False,
)
with solver_factory() as solver:
GymRIWDomain.solve_with(solver, domain_factory)
GymRIWDomain.solve_with(solver)
rollout(
domain,
solver,
2 changes: 1 addition & 1 deletion examples/gym_jsbsim_uct.py
@@ -397,7 +397,7 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
debug_logs=False,
)
with solver_factory() as solver:
GymUCTRawDomain.solve_with(solver, domain_factory)
GymUCTRawDomain.solve_with(solver)
solver._domain.reset()
rollout(
domain_factory(),
2 changes: 1 addition & 1 deletion examples/gym_line_control.py
@@ -277,7 +277,7 @@ def _state_step(
debug_logs=False,
)
with solver_factory() as solver:
GymRIWDomain.solve_with(solver, domain_factory)
GymRIWDomain.solve_with(solver)
initial_state = solver._domain.reset()
rollout(
domain,
3 changes: 2 additions & 1 deletion examples/maxent_irl_solver.py
@@ -13,14 +13,15 @@
print("===>", domain.get_action_space().unwrapped())
if MaxentIRL.check_domain(domain):
solver_factory = lambda: MaxentIRL(
domain_factory=domain_factory,
n_states=400,
n_actions=3,
one_feature=20,
expert_trajectories="expert_mountain.npy",
n_epochs=10000,
)
with solver_factory() as solver:
GymDomain.solve_with(solver, domain_factory)
GymDomain.solve_with(solver)
rollout(
domain,
solver,
6 changes: 0 additions & 6 deletions examples/maze_multiagent_mdp_multisolve.py
@@ -565,9 +565,6 @@ def mcts_watchdog(elapsed_time, nb_rollouts, best_value, epsilon_moving_average)
"debug_logs": False,
},
"singleagent_solver_kwargs": {
"domain_factory": lambda: lambda multiagent_domain, agent: SingleAgentMaze(
multiagent_domain._maze, multiagent_domain._agents_goals[agent]
),
"heuristic": lambda d, s: Value(
cost=sqrt((d._goal.x - s.x) ** 2 + (d._goal.y - s.y) ** 2)
),
@@ -613,9 +610,6 @@ def mcts_watchdog(elapsed_time, nb_rollouts, best_value, epsilon_moving_average)
"debug_logs": False,
},
"singleagent_solver_kwargs": {
"domain_factory": lambda: lambda multiagent_domain, agent: SingleAgentMaze(
multiagent_domain._maze, multiagent_domain._agents_goals[agent]
),
"heuristic": lambda d, s: Value(
cost=sqrt((d._goal.x - s.x) ** 2 + (d._goal.y - s.y) ** 2)
),
6 changes: 4 additions & 2 deletions examples/rllib_solver_applicable_actions.py
@@ -110,11 +110,13 @@ def _get_observation_space_(self) -> D.T_agent[Space[D.T_observation]]:

# Check domain compatibility
if RayRLlib.check_domain(domain):
solver_factory = lambda: RayRLlib(DQN, train_iterations=5)
solver_factory = lambda: RayRLlib(
domain_factory=domain_factory, algo_class=DQN, train_iterations=5
)

# Start solving
with solver_factory() as solver:
GridWorldFilteredActions.solve_with(solver, domain_factory)
GridWorldFilteredActions.solve_with(solver)

# Test solution
rollout(