Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a solver characteristic for solvers implementing solve_from() #320

Merged
merged 2 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions examples/gym_jsbsim_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gym_jsbsim.envs.taxi_utils import *

from skdecide import Solver
from skdecide.builders.solver import DeterministicPolicies, Utilities
from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities
from skdecide.hub.domain.gym import DeterministicGymDomain, GymDiscreteActionDomain
from skdecide.utils import rollout

Expand Down Expand Up @@ -100,7 +100,7 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any:
self._map.save("gym_jsbsim_map_update.html")


class GreedyPlanner(Solver, DeterministicPolicies, Utilities):
class GreedyPlanner(Solver, DeterministicPolicies, Utilities, FromAnyState):
T_domain = D

def __init__(self):
Expand All @@ -116,9 +116,6 @@ def _init_solve(self, domain_factory: Callable[[], D]) -> None:
lat = self._domain._gym_env.sim.get_property_value(prp.position_lat_geod_deg)
self._current_pos = (lat, lon)

def _solve(self, domain_factory: Callable[[], D]) -> None:
self._init_solve(domain_factory)

def _solve_from(self, memory: D.T_memory[D.T_state]) -> None:
self._best_action = None
self._best_reward = -float("inf")
Expand Down
8 changes: 4 additions & 4 deletions examples/scheduling/toy_rcpsp_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ def run_example():
solver = None
# UNCOMMENT BELOW TO USE ASTAR
# domain.set_inplace_environment(False)
# solver = lazy_astar.LazyAstar(from_state=state, heuristic=None, verbose=True)
# solver.solve(domain_factory=lambda: domain)
# solver = lazy_astar.LazyAstar(heuristic=None, verbose=True)
# solver.solve(domain_factory=lambda: domain, from_memory=state)
states, actions, values = rollout_episode(
domain=domain,
max_steps=1000,
Expand All @@ -455,8 +455,8 @@ def run_astar():
domain.set_inplace_environment(False)
state = domain.get_initial_state()
print("Initial state : ", state)
solver = LazyAstar(from_state=state, heuristic=None, verbose=True)
solver.solve(domain_factory=lambda: domain)
solver = LazyAstar(heuristic=None, verbose=True)
solver.solve(domain_factory=lambda: domain, from_memory=state)
states, actions, values = rollout_episode(
domain=domain,
max_steps=1000,
Expand Down
10 changes: 5 additions & 5 deletions notebooks/13_scheduling_tuto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,8 @@
"source": [
"from skdecide.hub.solver.lazy_astar import LazyAstar\n",
"\n",
"solver = LazyAstar(from_state=state, heuristic=None)\n",
"solver.solve(domain_factory=lambda: domain)"
"solver = LazyAstar(heuristic=None)\n",
"solver.solve(domain_factory=lambda: domain, from_memory=state)"
]
},
{
Expand Down Expand Up @@ -669,9 +669,9 @@
"# Wait for 300 seconds\n",
"signal.alarm(300)\n",
"\n",
"solver = LazyAstar(from_state=state, heuristic=None)\n",
"solver = LazyAstar(heuristic=None)\n",
"try:\n",
" solver.solve(domain_factory=lambda: domain)\n",
" solver.solve(domain_factory=lambda: domain, from_memory=state)\n",
"except Exception:\n",
" print(\"the algorithm could not finish\")\n",
"finally:\n",
Expand Down Expand Up @@ -1030,7 +1030,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.13"
},
"toc": {
"base_numbering": 1,
Expand Down
1 change: 1 addition & 0 deletions skdecide/builders/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# LICENSE file in the root directory of this source tree.

from skdecide.builders.solver.assessability import *
from skdecide.builders.solver.fromanystatesolvability import *
from skdecide.builders.solver.parallelability import *
from skdecide.builders.solver.policy import *
from skdecide.builders.solver.restorability import *
189 changes: 189 additions & 0 deletions skdecide/builders/solver/fromanystatesolvability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
# Copyright (c) AIRBUS and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional

from skdecide.builders.domain.initialization import Initializable
from skdecide.core import D, autocast_all, autocastable

if TYPE_CHECKING: # avoids circular imports
from skdecide.domains import Domain

__all__ = ["FromInitialState", "FromAnyState"]


class FromInitialState:
""" "A solver must inherit this class if it can solve only from the initial state"""

def solve(
self,
domain_factory: Callable[[], Domain],
) -> None:
"""Run the solving process.

By default, #FromInitialState.solve() provides some boilerplate code and internally calls #FromInitialState._solve(). The
boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver.

# Parameters
domain_factory: A callable with no argument returning the domain to solve (can be just a domain class).

!!! tip
The nature of the solutions produced here depends on other solver's characteristics like
#policy and #assessibility.
"""

def cast_domain_factory():
domain = domain_factory()
autocast_all(domain, domain, self.T_domain)
return domain

return self._solve(cast_domain_factory)

def _solve(
self,
domain_factory: Callable[[], Domain],
) -> None:
"""Run the solving process.

This is a helper function called by default from #FromInitialState.solve(), the difference being that the domain factory
here returns domains auto-cast to the level expected by the solver.

# Parameters
domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level).

!!! tip
domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level).
The nature of the solutions produced here depends on other solver's characteristics like
#policy and #assessibility.
"""
raise NotImplementedError


class FromAnyState(FromInitialState):
"""A solver must inherit this class if it can solve from any given state."""

def solve(
self,
domain_factory: Callable[[], Domain],
from_memory: Optional[D.T_memory[D.T_state]] = None,
) -> None:
"""Run the solving process.

By default, #FromInitialState.solve() provides some boilerplate code and internally calls #FromInitialState._solve(). The
boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver.

# Parameters
domain_factory: A callable with no argument returning the domain to solve (can be just a domain class).
from_memory: The source memory (state or history) from which we begin the solving process.
If None, initial state is used if the domain is initializable, else a ValueError is raised.

!!! tip
The nature of the solutions produced here depends on other solver's characteristics like
#policy and #assessibility.
"""

def cast_domain_factory():
domain = domain_factory()
autocast_all(domain, domain, self.T_domain)
return domain

return self._solve(cast_domain_factory, from_memory=from_memory)

def _solve(
self,
domain_factory: Callable[[], Domain],
from_memory: Optional[D.T_memory[D.T_state]] = None,
) -> None:
"""Run the solving process.

This is a helper function called by default from #FromInitState.solve(), the difference being that the domain factory
here returns domains auto-cast to the level expected by the solver.

# Parameters
domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level).
from_memory: The source memory (state or history) from which we begin the solving process.
If None, initial state is used if the domain is initializable, else a ValueError is raised.

!!! tip
The nature of the solutions produced here depends on other solver's characteristics like
#policy and #assessibility.
"""
self._init_solve(domain_factory=domain_factory)
if from_memory is None:
domain = domain_factory()
if not isinstance(domain, Initializable):
raise ValueError(
"from_memory cannot be None if the domain is not initializable."
)
domain.reset()
from_memory = domain._memory # updated by domain.reset()

self._solve_from(from_memory)

@autocastable
def solve_from(self, memory: D.T_memory[D.T_state]) -> None:
"""Run the solving process from a given state.

!!! tip
Create the domain first by calling the @FromAnyState.init_solve() method

# Parameters
memory: The source memory (state or history) of the transition.

!!! tip
The nature of the solutions produced here depends on other solver's characteristics like
#policy and #assessibility.
"""
return self._solve_from(memory)

def _solve_from(self, memory: D.T_memory[D.T_state]) -> None:
"""Run the solving process from a given state.

!!! tip
Create the domain first by calling the @FromAnyState._init_solve() method

# Parameters
memory: The source memory (state or history) of the transition.

!!! tip
The nature of the solutions produced here depends on other solver's characteristics like
#policy and #assessibility.
"""
raise NotImplementedError

def init_solve(self, domain_factory: Callable[[], Domain]) -> None:
"""Initialize solver before calling `solve_from()`

In particular, initialize the underlying domain.

By default, #FromAnyState.init_solve() provides some boilerplate code and internally calls #FromAnyState._init_solve(). The
boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver.

# Parameters
domain_factory: A callable with no argument returning the domain to solve (can be just a domain class).

"""

def cast_domain_factory():
domain = domain_factory()
autocast_all(domain, domain, self.T_domain)
return domain

return self._init_solve(cast_domain_factory)

def _init_solve(self, domain_factory: Callable[[], Domain]) -> None:
"""Initialize solver before calling `solve_from()`

In particular, initialize the underlying domain.

This is a helper function called by default from #FromAnyState.init_solve(), the difference being that the domain factory
here returns domains auto-cast to the level expected by the solver.

# Parameters
domain_factory: A callable with no argument returning the domain to solve (can be just a domain class).

"""
raise NotImplementedError
6 changes: 4 additions & 2 deletions skdecide/builders/solver/parallelability.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@

from __future__ import annotations

from typing import Callable, List
from typing import TYPE_CHECKING, Callable, List

from skdecide.domains import Domain
from skdecide.parallel_domains import PipeParallelDomain, ShmParallelDomain

if TYPE_CHECKING: # avoids circular imports
from skdecide.domains import Domain

__all__ = ["ParallelSolver"]


Expand Down
18 changes: 15 additions & 3 deletions skdecide/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
TransformedObservable,
)
from skdecide.builders.domain.value import PositiveCosts, Rewards
from skdecide.core import autocast_all
from skdecide.builders.solver.fromanystatesolvability import FromAnyState
from skdecide.core import D, autocast_all

if (
False
Expand Down Expand Up @@ -117,6 +118,7 @@ def solve_with(
solver: Solver,
domain_factory: Optional[Callable[[], Domain]] = None,
load_path: Optional[str] = None,
from_memory: Optional[D.T_memory[T_state]] = None,
) -> Solver:
"""Solve the domain with a new or loaded solver and return it auto-cast to the level of the domain.

Expand All @@ -129,14 +131,17 @@ def solve_with(
solver: The solver.
domain_factory: A callable with no argument returning the domain to solve (factory is the domain class if None).
load_path: The path to restore the solver state from (if None, the solving process will be launched instead).
from_memory: The source memory (state or history) from which we begin the solving process.
To be used, if the solving process must begin from a specific state,
and only for solvers having the characteristic #FromAnyState, else raise a ValueError.
Ignored if load_path is used.

# Returns
The new solver (auto-cast to the level of the domain).
"""
if domain_factory is None:
domain_factory = cls
if load_path is not None:

# TODO: avoid repeating this code somehow (identical in solver.solve(...))? Is factory necessary (vs cls)?
def cast_domain_factory():
domain = domain_factory()
Expand All @@ -145,7 +150,14 @@ def cast_domain_factory():

solver.load(load_path, cast_domain_factory)
else:
solver.solve(domain_factory)
if isinstance(solver, FromAnyState):
solver.solve(domain_factory, from_memory=from_memory)
elif from_memory is None:
solver.solve(domain_factory)
else:
raise ValueError(
f"`from_memory` must be None when used with a solver not having {FromAnyState.__name__} characteristic."
)
autocast_all(solver, solver.T_domain, cls)
return solver

Expand Down
14 changes: 9 additions & 5 deletions skdecide/hub/solver/aostar/aostar.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@
Sequential,
SingleAgent,
)
from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities
from skdecide.builders.solver import (
DeterministicPolicies,
FromAnyState,
ParallelSolver,
Utilities,
)
from skdecide.core import Value

record_sys_path = sys.path
Expand All @@ -45,7 +50,9 @@ class D(
):
pass

class AOstar(ParallelSolver, Solver, DeterministicPolicies, Utilities):
class AOstar(
ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState
):
T_domain = D

def __init__(
Expand Down Expand Up @@ -104,9 +111,6 @@ def _init_solve(self, domain_factory: Callable[[], Domain]) -> None:
)
self._solver.clear()

def _solve(self, domain_factory: Callable[[], D]) -> None:
self._init_solve(domain_factory)

def _solve_from(self, memory: D.T_memory[D.T_state]) -> None:
self._solver.solve(memory)

Expand Down
Loading
Loading