From 22a2e3629acfeec7f1e90459ffa338f9ccbd5112 Mon Sep 17 00:00:00 2001 From: Nolwen Date: Tue, 2 Apr 2024 12:21:10 +0200 Subject: [PATCH 1/2] Put solve_from() in a dedicated solver characteristic - generic characteristic: FromInitState - specific characteristic: FromAnyState: - the solver have a solve_from() method - solve() use it by default to solve from initial state (if domain is initializable else raise a ValueError if not implemented) - solve() (and _solve) gets also an optional argument from_memory to solve from another state (and call solve_from accordingly) Consequences: - add from_memory arg in Domain.solve_with() - Solver inherits from FromInitState (to have solve() and _solve() in its api) - FromAnyState inherits from FromInitState and overrides solve() and _solve() - init_solve() added to FromAnyState to be called in _solve(), to ensure solve_from() will already have the domain (and for instance the underlying c++ solver) initialized when called from _solve. Minor modifications to avoid circular imports - We import Domain for annotations only when type checking - We use a dedicated logger in parallel_domains.py to avoid havind to import domains --- .../solver/fromanystatesolvability.py | 189 ++++++++++++++++++ skdecide/builders/solver/parallelability.py | 6 +- skdecide/domains.py | 18 +- skdecide/parallel_domains.py | 3 +- skdecide/solvers.py | 73 +------ 5 files changed, 213 insertions(+), 76 deletions(-) create mode 100644 skdecide/builders/solver/fromanystatesolvability.py diff --git a/skdecide/builders/solver/fromanystatesolvability.py b/skdecide/builders/solver/fromanystatesolvability.py new file mode 100644 index 0000000000..2fa6905ed3 --- /dev/null +++ b/skdecide/builders/solver/fromanystatesolvability.py @@ -0,0 +1,189 @@ +# Copyright (c) AIRBUS and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, Optional + +from skdecide.builders.domain.initialization import Initializable +from skdecide.core import D, autocast_all, autocastable + +if TYPE_CHECKING: # avoids circular imports + from skdecide.domains import Domain + +__all__ = ["FromInitialState", "FromAnyState"] + + +class FromInitialState: + """ "A solver must inherit this class if it can solve only from the initial state""" + + def solve( + self, + domain_factory: Callable[[], Domain], + ) -> None: + """Run the solving process. + + By default, #FromInitialState.solve() provides some boilerplate code and internally calls #FromInitialState._solve(). The + boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + + def cast_domain_factory(): + domain = domain_factory() + autocast_all(domain, domain, self.T_domain) + return domain + + return self._solve(cast_domain_factory) + + def _solve( + self, + domain_factory: Callable[[], Domain], + ) -> None: + """Run the solving process. + + This is a helper function called by default from #FromInitialState.solve(), the difference being that the domain factory + here returns domains auto-cast to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level). + + !!! tip + domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level). + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + raise NotImplementedError + + +class FromAnyState(FromInitialState): + """A solver must inherit this class if it can solve from any given state.""" + + def solve( + self, + domain_factory: Callable[[], Domain], + from_memory: Optional[D.T_memory[D.T_state]] = None, + ) -> None: + """Run the solving process. + + By default, #FromInitialState.solve() provides some boilerplate code and internally calls #FromInitialState._solve(). The + boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). + from_memory: The source memory (state or history) from which we begin the solving process. + If None, initial state is used if the domain is initializable, else a ValueError is raised. + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + + def cast_domain_factory(): + domain = domain_factory() + autocast_all(domain, domain, self.T_domain) + return domain + + return self._solve(cast_domain_factory, from_memory=from_memory) + + def _solve( + self, + domain_factory: Callable[[], Domain], + from_memory: Optional[D.T_memory[D.T_state]] = None, + ) -> None: + """Run the solving process. + + This is a helper function called by default from #FromInitState.solve(), the difference being that the domain factory + here returns domains auto-cast to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level). + from_memory: The source memory (state or history) from which we begin the solving process. + If None, initial state is used if the domain is initializable, else a ValueError is raised. + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + self._init_solve(domain_factory=domain_factory) + if from_memory is None: + domain = domain_factory() + if not isinstance(domain, Initializable): + raise ValueError( + "from_memory cannot be None if the domain is not initializable." + ) + domain.reset() + from_memory = domain._memory # updated by domain.reset() + + self._solve_from(from_memory) + + @autocastable + def solve_from(self, memory: D.T_memory[D.T_state]) -> None: + """Run the solving process from a given state. + + !!! tip + Create the domain first by calling the @FromAnyState.init_solve() method + + # Parameters + memory: The source memory (state or history) of the transition. + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + return self._solve_from(memory) + + def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: + """Run the solving process from a given state. + + !!! tip + Create the domain first by calling the @FromAnyState._init_solve() method + + # Parameters + memory: The source memory (state or history) of the transition. + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + raise NotImplementedError + + def init_solve(self, domain_factory: Callable[[], Domain]) -> None: + """Initialize solver before calling `solve_from()` + + In particular, initialize the underlying domain. + + By default, #FromAnyState.init_solve() provides some boilerplate code and internally calls #FromAnyState._init_solve(). The + boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). + + """ + + def cast_domain_factory(): + domain = domain_factory() + autocast_all(domain, domain, self.T_domain) + return domain + + return self._init_solve(cast_domain_factory) + + def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: + """Initialize solver before calling `solve_from()` + + In particular, initialize the underlying domain. + + This is a helper function called by default from #FromAnyState.init_solve(), the difference being that the domain factory + here returns domains auto-cast to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). + + """ + raise NotImplementedError diff --git a/skdecide/builders/solver/parallelability.py b/skdecide/builders/solver/parallelability.py index 651acd018a..a8fcffebe0 100644 --- a/skdecide/builders/solver/parallelability.py +++ b/skdecide/builders/solver/parallelability.py @@ -4,11 +4,13 @@ from __future__ import annotations -from typing import Callable, List +from typing import TYPE_CHECKING, Callable, List -from skdecide.domains import Domain from skdecide.parallel_domains import PipeParallelDomain, ShmParallelDomain +if TYPE_CHECKING: # avoids circular imports + from skdecide.domains import Domain + __all__ = ["ParallelSolver"] diff --git a/skdecide/domains.py b/skdecide/domains.py index 573bc505c4..db64bb5db9 100644 --- a/skdecide/domains.py +++ b/skdecide/domains.py @@ -30,7 +30,8 @@ TransformedObservable, ) from skdecide.builders.domain.value import PositiveCosts, Rewards -from skdecide.core import autocast_all +from skdecide.builders.solver.fromanystatesolvability import FromAnyState +from skdecide.core import D, autocast_all if ( False @@ -117,6 +118,7 @@ def solve_with( solver: Solver, domain_factory: Optional[Callable[[], Domain]] = None, load_path: Optional[str] = None, + from_memory: Optional[D.T_memory[T_state]] = None, ) -> Solver: """Solve the domain with a new or loaded solver and return it auto-cast to the level of the domain. @@ -129,6 +131,10 @@ def solve_with( solver: The solver. domain_factory: A callable with no argument returning the domain to solve (factory is the domain class if None). load_path: The path to restore the solver state from (if None, the solving process will be launched instead). + from_memory: The source memory (state or history) from which we begin the solving process. + To be used, if the solving process must begin from a specific state, + and only for solvers having the characteristic #FromAnyState, else raise a ValueError. + Ignored if load_path is used. # Returns The new solver (auto-cast to the level of the domain). @@ -136,7 +142,6 @@ def solve_with( if domain_factory is None: domain_factory = cls if load_path is not None: - # TODO: avoid repeating this code somehow (identical in solver.solve(...))? Is factory necessary (vs cls)? def cast_domain_factory(): domain = domain_factory() @@ -145,7 +150,14 @@ def cast_domain_factory(): solver.load(load_path, cast_domain_factory) else: - solver.solve(domain_factory) + if isinstance(solver, FromAnyState): + solver.solve(domain_factory, from_memory=from_memory) + elif from_memory is None: + solver.solve(domain_factory) + else: + raise ValueError( + f"`from_memory` must be None when used with a solver not having {FromAnyState.__name__} characteristic." + ) autocast_all(solver, solver.T_domain, cls) return solver diff --git a/skdecide/parallel_domains.py b/skdecide/parallel_domains.py index 6094d9777f..60547c0e06 100644 --- a/skdecide/parallel_domains.py +++ b/skdecide/parallel_domains.py @@ -6,6 +6,7 @@ from __future__ import annotations +import logging import os import tempfile @@ -15,7 +16,7 @@ from pathos.helpers import mp from pynng import Push0 -from skdecide.domains import logger +logger = logging.getLogger(__name__) dill.settings["byref"] = True diff --git a/skdecide/solvers.py b/skdecide/solvers.py index 5f3b40a0ae..a8757d438d 100644 --- a/skdecide/solvers.py +++ b/skdecide/solvers.py @@ -5,10 +5,10 @@ """This module contains base classes for quickly building solvers.""" from __future__ import annotations -from typing import Callable, List +from typing import List +from skdecide.builders.solver.fromanystatesolvability import FromInitialState from skdecide.builders.solver.policy import DeterministicPolicies -from skdecide.core import D, autocast_all, autocastable from skdecide.domains import Domain __all__ = ["Solver", "DeterministicPolicySolver"] @@ -17,7 +17,7 @@ # MAIN BASE CLASS -class Solver: +class Solver(FromInitialState): """This is the highest level solver class (inheriting top-level class for each mandatory solver characteristic). This helper class can be used as the main base class for solvers. @@ -127,73 +127,6 @@ def _reset(self) -> None: """ pass - def solve(self, domain_factory: Callable[[], Domain]) -> None: - """Run the solving process. - - By default, #Solver.solve() provides some boilerplate code and internally calls #Solver._solve(). The - boilerplate code transforms the domain factory to auto-cast the new domains to the level expected by the solver. - - # Parameters - domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). - - !!! tip - The nature of the solutions produced here depends on other solver's characteristics like - #policy and #assessibility. - """ - - def cast_domain_factory(): - domain = domain_factory() - autocast_all(domain, domain, self.T_domain) - return domain - - return self._solve(cast_domain_factory) - - def _solve(self, domain_factory: Callable[[], T_domain]) -> None: - """Run the solving process. - - This is a helper function called by default from #Solver.solve(), the difference being that the domain factory - here returns domains auto-cast to the level expected by the solver. - - # Parameters - domain_factory: A callable with no argument returning the domain to solve (auto-cast to expected level). - - !!! tip - The nature of the solutions produced here depends on other solver's characteristics like - #policy and #assessibility. - """ - raise NotImplementedError - - @autocastable - def solve_from(self, memory: D.T_memory[D.T_state]) -> None: - """Run the solving process from a given state. - - !!! tip - Create the domain first by calling the @Solver.reset() method - - # Parameters - memory: The source memory (state or history) of the transition. - - !!! tip - The nature of the solutions produced here depends on other solver's characteristics like - #policy and #assessibility. - """ - return self._solve_from(memory) - - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: - """Run the solving process from a given state. - - !!! tip - Create the domain first by calling the @Solver.reset() method - - # Parameters - memory: The source memory (state or history) of the transition. - - !!! tip - The nature of the solutions produced here depends on other solver's characteristics like - #policy and #assessibility. - """ - pass - def _initialize(self): """Runs long-lasting initialization code here, or code to be executed at the entering of a 'with' context statement. From 044bc12336c913a28c8f37b2d277f26cfc157b88 Mon Sep 17 00:00:00 2001 From: Nolwen Date: Tue, 2 Apr 2024 15:52:40 +0200 Subject: [PATCH 2/2] Update solvers to match new api - for solvers implementing solve_from(): - derive from FromAnyState - remove _solve implementation (already in FromAnyState) - for LazyAstar and LRTAstar: - remove attribute self._from_state - derive from FromAnyState - initialize self._domain in _init_solve() - convert _solve into _solve_from - for MAHD: - derive FromAnyState, but raises an error if chosen solver cannot do it - for other solvers: nothing changes. --- examples/gym_jsbsim_greedy.py | 7 +-- examples/scheduling/toy_rcpsp_examples.py | 8 ++-- notebooks/13_scheduling_tuto.ipynb | 10 ++--- skdecide/builders/solver/__init__.py | 1 + skdecide/hub/solver/aostar/aostar.py | 14 +++--- skdecide/hub/solver/astar/astar.py | 12 ++--- skdecide/hub/solver/bfws/bfws.py | 12 ++--- skdecide/hub/solver/ilaostar/ilaostar.py | 14 +++--- skdecide/hub/solver/iw/iw.py | 12 ++--- skdecide/hub/solver/lazy_astar/lazy_astar.py | 42 ++++++++++++----- skdecide/hub/solver/lrtastar/lrtastar.py | 47 ++++++++++++++------ skdecide/hub/solver/lrtdp/lrtdp.py | 12 ++--- skdecide/hub/solver/mahd/mahd.py | 26 ++++++++--- skdecide/hub/solver/martdp/martdp.py | 14 +++--- skdecide/hub/solver/mcts/mcts.py | 16 ++++--- skdecide/hub/solver/riw/riw.py | 12 ++--- tests/scheduling/test_scheduling.py | 4 +- 17 files changed, 171 insertions(+), 92 deletions(-) diff --git a/examples/gym_jsbsim_greedy.py b/examples/gym_jsbsim_greedy.py index fd2289d624..143b5d487c 100644 --- a/examples/gym_jsbsim_greedy.py +++ b/examples/gym_jsbsim_greedy.py @@ -11,7 +11,7 @@ from gym_jsbsim.envs.taxi_utils import * from skdecide import Solver -from skdecide.builders.solver import DeterministicPolicies, Utilities +from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities from skdecide.hub.domain.gym import DeterministicGymDomain, GymDiscreteActionDomain from skdecide.utils import rollout @@ -100,7 +100,7 @@ def _render_from(self, memory: D.T_memory[D.T_state], **kwargs: Any) -> Any: self._map.save("gym_jsbsim_map_update.html") -class GreedyPlanner(Solver, DeterministicPolicies, Utilities): +class GreedyPlanner(Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__(self): @@ -116,9 +116,6 @@ def _init_solve(self, domain_factory: Callable[[], D]) -> None: lat = self._domain._gym_env.sim.get_property_value(prp.position_lat_geod_deg) self._current_pos = (lat, lon) - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._best_action = None self._best_reward = -float("inf") diff --git a/examples/scheduling/toy_rcpsp_examples.py b/examples/scheduling/toy_rcpsp_examples.py index 88db0534fc..dc372627da 100644 --- a/examples/scheduling/toy_rcpsp_examples.py +++ b/examples/scheduling/toy_rcpsp_examples.py @@ -435,8 +435,8 @@ def run_example(): solver = None # UNCOMMENT BELOW TO USE ASTAR # domain.set_inplace_environment(False) - # solver = lazy_astar.LazyAstar(from_state=state, heuristic=None, verbose=True) - # solver.solve(domain_factory=lambda: domain) + # solver = lazy_astar.LazyAstar(heuristic=None, verbose=True) + # solver.solve(domain_factory=lambda: domain, from_memory=state) states, actions, values = rollout_episode( domain=domain, max_steps=1000, @@ -455,8 +455,8 @@ def run_astar(): domain.set_inplace_environment(False) state = domain.get_initial_state() print("Initial state : ", state) - solver = LazyAstar(from_state=state, heuristic=None, verbose=True) - solver.solve(domain_factory=lambda: domain) + solver = LazyAstar(heuristic=None, verbose=True) + solver.solve(domain_factory=lambda: domain, from_memory=state) states, actions, values = rollout_episode( domain=domain, max_steps=1000, diff --git a/notebooks/13_scheduling_tuto.ipynb b/notebooks/13_scheduling_tuto.ipynb index ebef654392..a5ef05c15f 100644 --- a/notebooks/13_scheduling_tuto.ipynb +++ b/notebooks/13_scheduling_tuto.ipynb @@ -360,8 +360,8 @@ "source": [ "from skdecide.hub.solver.lazy_astar import LazyAstar\n", "\n", - "solver = LazyAstar(from_state=state, heuristic=None)\n", - "solver.solve(domain_factory=lambda: domain)" + "solver = LazyAstar(heuristic=None)\n", + "solver.solve(domain_factory=lambda: domain, from_memory=state)" ] }, { @@ -669,9 +669,9 @@ "# Wait for 300 seconds\n", "signal.alarm(300)\n", "\n", - "solver = LazyAstar(from_state=state, heuristic=None)\n", + "solver = LazyAstar(heuristic=None)\n", "try:\n", - " solver.solve(domain_factory=lambda: domain)\n", + " solver.solve(domain_factory=lambda: domain, from_memory=state)\n", "except Exception:\n", " print(\"the algorithm could not finish\")\n", "finally:\n", @@ -1030,7 +1030,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.10.13" }, "toc": { "base_numbering": 1, diff --git a/skdecide/builders/solver/__init__.py b/skdecide/builders/solver/__init__.py index 50578be7e4..2af8a744d5 100644 --- a/skdecide/builders/solver/__init__.py +++ b/skdecide/builders/solver/__init__.py @@ -3,6 +3,7 @@ # LICENSE file in the root directory of this source tree. from skdecide.builders.solver.assessability import * +from skdecide.builders.solver.fromanystatesolvability import * from skdecide.builders.solver.parallelability import * from skdecide.builders.solver.policy import * from skdecide.builders.solver.restorability import * diff --git a/skdecide/hub/solver/aostar/aostar.py b/skdecide/hub/solver/aostar/aostar.py index 9f9776daef..36adf75e35 100644 --- a/skdecide/hub/solver/aostar/aostar.py +++ b/skdecide/hub/solver/aostar/aostar.py @@ -19,7 +19,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -45,7 +50,9 @@ class D( ): pass - class AOstar(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class AOstar( + ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState + ): T_domain = D def __init__( @@ -104,9 +111,6 @@ def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/astar/astar.py b/skdecide/hub/solver/astar/astar.py index e5079a780a..5de815c820 100644 --- a/skdecide/hub/solver/astar/astar.py +++ b/skdecide/hub/solver/astar/astar.py @@ -19,7 +19,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -45,7 +50,7 @@ class D( ): pass - class Astar(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class Astar(ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( @@ -95,9 +100,6 @@ def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/bfws/bfws.py b/skdecide/hub/solver/bfws/bfws.py index 682efd272f..eadb2a34aa 100644 --- a/skdecide/hub/solver/bfws/bfws.py +++ b/skdecide/hub/solver/bfws/bfws.py @@ -19,7 +19,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -44,7 +49,7 @@ class D( ): # TODO: check why DeterministicInitialized & PositiveCosts/Rewards? pass - class BFWS(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class BFWS(ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( @@ -108,9 +113,6 @@ def _init_solve(self, domain_factory: Callable[[], D]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/ilaostar/ilaostar.py b/skdecide/hub/solver/ilaostar/ilaostar.py index 90bf0767ae..6cc572d117 100644 --- a/skdecide/hub/solver/ilaostar/ilaostar.py +++ b/skdecide/hub/solver/ilaostar/ilaostar.py @@ -19,7 +19,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -45,7 +50,9 @@ class D( ): pass - class ILAOstar(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class ILAOstar( + ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState + ): T_domain = D def __init__( @@ -101,9 +108,6 @@ def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/iw/iw.py b/skdecide/hub/solver/iw/iw.py index 11c38ffe2d..dd40629f5a 100644 --- a/skdecide/hub/solver/iw/iw.py +++ b/skdecide/hub/solver/iw/iw.py @@ -19,7 +19,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.hub.space.gym import ListSpace record_sys_path = sys.path @@ -44,7 +49,7 @@ class D( ): # TODO: check why DeterministicInitialized & PositiveCosts/Rewards? pass - class IW(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class IW(ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( @@ -98,9 +103,6 @@ def _init_solve(self, domain_factory: Callable[[], D]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/lazy_astar/lazy_astar.py b/skdecide/hub/solver/lazy_astar/lazy_astar.py index c49f869a7d..b7d4715d82 100644 --- a/skdecide/hub/solver/lazy_astar/lazy_astar.py +++ b/skdecide/hub/solver/lazy_astar/lazy_astar.py @@ -19,7 +19,7 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, Utilities +from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities # TODO: remove Markovian req? @@ -37,12 +37,11 @@ class D( pass -class LazyAstar(Solver, DeterministicPolicies, Utilities): +class LazyAstar(Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( self, - from_state: Optional[D.T_state] = None, heuristic: Optional[ Callable[[Domain, D.T_state], D.T_agent[Value[D.T_value]]] ] = None, @@ -51,7 +50,6 @@ def __init__( render: bool = False, ) -> None: - self._from_state = from_state self._heuristic = ( (lambda _, __: Value(cost=0.0)) if heuristic is None else heuristic ) @@ -61,9 +59,37 @@ def __init__( self._values = {} self._plan = [] - def _solve(self, domain_factory: Callable[[], D]) -> None: + def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: + """Initialize solver before calling `solve_from()` + + In particular, initialize the underlying domain. + + This is a helper function called by default from #FromAnyState.init_solve(), the difference being that the domain factory + here returns domains auto-cast to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). + + """ self._domain = domain_factory() + def _solve_from( + self, + memory: D.T_state, + ) -> None: + """Run the solving process from a given state. + + !!! tip + Create the domain first by calling the @FromAnyState._init_solve() method + + # Parameters + memory: The source memory (state or history) of the transition. + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ + def extender(node, label, explored): result = [] for a in self._domain.get_applicable_actions(node).get_elements(): @@ -80,11 +106,7 @@ def extender(node, label, explored): push = heappush pop = heappop - if self._from_state is None: - # get initial observation from domain (assuming DeterministicInitialized) - sources = [self._domain.get_initial_state()] - else: - sources = [self._from_state] + sources = [memory] # targets = list(self._domain.get_goals().get_elements()) # The queue is the OPEN list. diff --git a/skdecide/hub/solver/lrtastar/lrtastar.py b/skdecide/hub/solver/lrtastar/lrtastar.py index 9df6b6e9ab..4d88bbf498 100644 --- a/skdecide/hub/solver/lrtastar/lrtastar.py +++ b/skdecide/hub/solver/lrtastar/lrtastar.py @@ -18,7 +18,7 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, Utilities +from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities class D( @@ -35,7 +35,7 @@ class D( pass -class LRTAstar(Solver, DeterministicPolicies, Utilities): +class LRTAstar(Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def _get_next_action( @@ -53,7 +53,6 @@ def _get_utility(self, observation: D.T_agent[D.T_observation]) -> D.T_value: def __init__( self, - from_state: Optional[D.T_state] = None, heuristic: Optional[ Callable[[Domain, D.T_state], D.T_agent[Value[D.T_value]]] ] = None, @@ -62,7 +61,6 @@ def __init__( max_iter=5000, max_depth=200, ) -> None: - self._from_state = from_state self._heuristic = ( (lambda _, __: Value(cost=0.0)) if heuristic is None else heuristic ) @@ -77,22 +75,43 @@ def __init__( self.heuristic_changed = False self._policy = {} - def _solve(self, domain_factory: Callable[[], D]) -> None: + def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: + """Initialize solver before calling `solve_from()` + + In particular, initialize the underlying domain. + + This is a helper function called by default from #FromAnyState.init_solve(), the difference being that the domain factory + here returns domains auto-cast to the level expected by the solver. + + # Parameters + domain_factory: A callable with no argument returning the domain to solve (can be just a domain class). + + """ self._domain = domain_factory() + + def _solve_from( + self, + memory: D.T_state, + ) -> None: + """Run the solving process from a given state. + + !!! tip + Create the domain first by calling the @FromAnyState._init_solve() method + + # Parameters + memory: The source memory (state or history) of the transition. + + !!! tip + The nature of the solutions produced here depends on other solver's characteristics like + #policy and #assessibility. + """ self.values = {} iteration = 0 best_cost = float("inf") - if self._from_state is None: - # get initial observation from domain (assuming DeterministicInitialized) - from_observation = self._domain.get_initial_state() - else: - from_observation = self._from_state # best_path = None while True: - print(from_observation) - dead_end, cumulated_cost, current_roll, list_action = self.doTrial( - from_observation - ) + print(memory) + dead_end, cumulated_cost, current_roll, list_action = self.doTrial(memory) if self._verbose: print( "iter ", diff --git a/skdecide/hub/solver/lrtdp/lrtdp.py b/skdecide/hub/solver/lrtdp/lrtdp.py index 6bc3f24c92..34dcfca9ef 100644 --- a/skdecide/hub/solver/lrtdp/lrtdp.py +++ b/skdecide/hub/solver/lrtdp/lrtdp.py @@ -19,7 +19,12 @@ SingleAgent, UncertainTransitions, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -45,7 +50,7 @@ class D( ): pass - class LRTDP(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class LRTDP(ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( @@ -131,9 +136,6 @@ def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/mahd/mahd.py b/skdecide/hub/solver/mahd/mahd.py index c11051504d..ea5f54a276 100644 --- a/skdecide/hub/solver/mahd/mahd.py +++ b/skdecide/hub/solver/mahd/mahd.py @@ -4,11 +4,11 @@ from __future__ import annotations -from typing import Any, Callable, Set, Tuple +from typing import Any, Callable, Optional, Set, Tuple from skdecide import Domain, Solver from skdecide.builders.domain import MultiAgent, Sequential, SingleAgent -from skdecide.builders.solver import DeterministicPolicies, Utilities +from skdecide.builders.solver import DeterministicPolicies, FromAnyState, Utilities from skdecide.core import Value @@ -17,7 +17,7 @@ class D(Domain, MultiAgent, Sequential): pass -class MAHD(Solver, DeterministicPolicies, Utilities): +class MAHD(Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( @@ -74,11 +74,27 @@ def __init__( a: {} for a in self._multiagent_domain.get_agents() } - def _solve(self, domain_factory: Callable[[], D]) -> None: + def _solve( + self, + domain_factory: Callable[[], D], + from_memory: Optional[D.T_memory[D.T_state]] = None, + ) -> None: + self._multiagent_domain_class.solve_with( + solver=self._multiagent_solver, + domain_factory=domain_factory, + from_memory=from_memory, + ) + + def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._multiagent_domain_class.solve_with( - solver=self._multiagent_solver, domain_factory=domain_factory + solver=self._multiagent_solver, + domain_factory=self._domain_factory, + from_memory=memory, ) + def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: + self._domain_factory = domain_factory + def _get_next_action( self, observation: D.T_agent[D.T_observation] ) -> D.T_agent[D.T_concurrency[D.T_event]]: diff --git a/skdecide/hub/solver/martdp/martdp.py b/skdecide/hub/solver/martdp/martdp.py index 493dc1fcc2..204348998a 100644 --- a/skdecide/hub/solver/martdp/martdp.py +++ b/skdecide/hub/solver/martdp/martdp.py @@ -19,7 +19,12 @@ Sequential, Simulation, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -45,7 +50,9 @@ class D( ): pass - class MARTDP(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class MARTDP( + ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState + ): T_domain = D def __init__( @@ -147,9 +154,6 @@ def _init_solve(self, domain_factory: Callable[[], Domain]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/skdecide/hub/solver/mcts/mcts.py b/skdecide/hub/solver/mcts/mcts.py index 8dd2f138c2..31e334a4cd 100644 --- a/skdecide/hub/solver/mcts/mcts.py +++ b/skdecide/hub/solver/mcts/mcts.py @@ -21,7 +21,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) from skdecide.core import Value record_sys_path = sys.path @@ -47,7 +52,7 @@ class D( ): # TODO: check why DeterministicInitialized & PositiveCosts/Rewards? pass - class MCTS(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class MCTS(ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D Options = mcts_options @@ -170,9 +175,6 @@ def _init_solve(self, domain_factory: Callable[[], D]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) @@ -293,8 +295,8 @@ def __init__( self._action_choice_noise = action_choice_noise self._heuristic_records = {} - def _solve(self, domain_factory: Callable[[], D]) -> None: - super()._solve(domain_factory=domain_factory) + def _init_solve(self, domain_factory: Callable[[], D]) -> None: + super()._init_solve(domain_factory) self._heuristic_records = {} def _value_heuristic( diff --git a/skdecide/hub/solver/riw/riw.py b/skdecide/hub/solver/riw/riw.py index 82e27f27f9..24da638b3a 100644 --- a/skdecide/hub/solver/riw/riw.py +++ b/skdecide/hub/solver/riw/riw.py @@ -19,7 +19,12 @@ Sequential, SingleAgent, ) -from skdecide.builders.solver import DeterministicPolicies, ParallelSolver, Utilities +from skdecide.builders.solver import ( + DeterministicPolicies, + FromAnyState, + ParallelSolver, + Utilities, +) record_sys_path = sys.path skdecide_cpp_extension_lib_path = os.path.abspath(hub.__path__[0]) @@ -43,7 +48,7 @@ class D( ): # TODO: check why DeterministicInitialized & PositiveCosts/Rewards? pass - class RIW(ParallelSolver, Solver, DeterministicPolicies, Utilities): + class RIW(ParallelSolver, Solver, DeterministicPolicies, Utilities, FromAnyState): T_domain = D def __init__( @@ -128,9 +133,6 @@ def _init_solve(self, domain_factory: Callable[[], D]) -> None: ) self._solver.clear() - def _solve(self, domain_factory: Callable[[], D]) -> None: - self._init_solve(domain_factory) - def _solve_from(self, memory: D.T_memory[D.T_state]) -> None: self._solver.solve(memory) diff --git a/tests/scheduling/test_scheduling.py b/tests/scheduling/test_scheduling.py index 318c996cf3..da8a4a3b4f 100644 --- a/tests/scheduling/test_scheduling.py +++ b/tests/scheduling/test_scheduling.py @@ -704,8 +704,8 @@ def test_planning_algos(domain, solver_str): state = domain.get_initial_state() print("Initial state : ", state) if solver_str == "LazyAstar": - solver = LazyAstar(from_state=state, heuristic=None, verbose=False) - solver.solve(domain_factory=lambda: domain) + solver = LazyAstar(heuristic=None, verbose=False) + solver.solve(domain_factory=lambda: domain, from_memory=state) states, actions, values = rollout_episode( domain=domain, max_steps=1000,