diff --git a/skdecide/hub/solver/do_solver/do_solver_scheduling.py b/skdecide/hub/solver/do_solver/do_solver_scheduling.py index d246daa0f2..311434db33 100644 --- a/skdecide/hub/solver/do_solver/do_solver_scheduling.py +++ b/skdecide/hub/solver/do_solver/do_solver_scheduling.py @@ -5,9 +5,13 @@ from __future__ import annotations from enum import Enum -from typing import Any, Callable, Dict, Tuple, Union +from typing import Any, Callable, Dict, Optional, Tuple, Union +from discrete_optimization.generic_tools.callbacks.callback import Callback from discrete_optimization.generic_tools.do_solver import SolverDO +from discrete_optimization.generic_tools.result_storage.result_storage import ( + ResultStorage, +) from discrete_optimization.rcpsp.rcpsp_model import RCPSPModel, RCPSPSolution from discrete_optimization.rcpsp_multiskill.rcpsp_multiskill import ( MS_RCPSPModel, @@ -15,6 +19,7 @@ MS_RCPSPSolution_Variant, ) +from skdecide import Domain from skdecide.builders.domain.scheduling.scheduling_domains import SchedulingDomain from skdecide.hub.solver.do_solver.sgs_policies import PolicyMethodParams, PolicyRCPSP from skdecide.hub.solver.do_solver.sk_to_do_binding import build_do_domain @@ -131,8 +136,10 @@ def __init__( self, policy_method_params: PolicyMethodParams, method: SolvingMethod = SolvingMethod.PILE, - dict_params: Dict[Any, Any] = None, + dict_params: Optional[Dict[Any, Any]] = None, + callback: Optional[Callable[[Domain, DOSolver], bool]] = None, ): + self.callback = callback self.method = method self.policy_method_params = policy_method_params self.dict_params = dict_params @@ -169,12 +176,20 @@ def _solve(self, domain_factory: Callable[[], D]) -> None: if k not in self.dict_params: self.dict_params[k] = params[k] + # callbacks + if self.callback is None: + callbacks = [] + else: + callbacks = [ + _DOCallback(callback=self.callback, domain=self.domain, solver=self) + ] + self.solver = solver_class(self.do_domain, **self.dict_params) if hasattr(self.solver, "init_model") and callable(self.solver.init_model): self.solver.init_model(**self.dict_params) - result_storage = self.solver.solve(**self.dict_params) + result_storage = self.solver.solve(callbacks=callbacks, **self.dict_params) best_solution: RCPSPSolution = result_storage.get_best_solution() assert best_solution is not None @@ -206,3 +221,32 @@ def _get_next_action( def _is_policy_defined_for(self, observation: D.T_agent[D.T_observation]) -> bool: return self.policy_object.is_policy_defined_for(observation=observation) + + +class _DOCallback(Callback): + def __init__( + self, + callback: Callable[[Domain, DOSolver], bool], + domain: Domain, + solver: Solver, + ): + self.domain = domain + self.solver = solver + self.callback = callback + + def on_step_end( + self, step: int, res: ResultStorage, solver: SolverDO + ) -> Optional[bool]: + """Called at the end of an optimization step. + + Args: + step: index of step + res: current result storage + solver: solvers using the callback + + Returns: + If `True`, the optimization process is stopped, else it goes on. + + """ + stopping = self.callback(self.domain, self.solver) + return stopping diff --git a/tests/scheduling/test_scheduling.py b/tests/scheduling/test_scheduling.py index a60d5f7da4..289572f14e 100644 --- a/tests/scheduling/test_scheduling.py +++ b/tests/scheduling/test_scheduling.py @@ -1,3 +1,4 @@ +import logging import random from enum import Enum from typing import Any, Dict, List, Optional, Set, Union @@ -36,16 +37,17 @@ ) from skdecide.builders.domain.scheduling.task_duration import DeterministicTaskDuration from skdecide.hub.domain.rcpsp.rcpsp_sk import build_n_determinist_from_stochastic -from skdecide.hub.solver.do_solver.do_solver_scheduling import ( +from skdecide.hub.solver.do_solver.do_solver_scheduling import DOSolver, SolvingMethod +from skdecide.hub.solver.do_solver.gphh import GPHH, ParametersGPHH +from skdecide.hub.solver.do_solver.sgs_policies import ( BasePolicyMethod, - DOSolver, PolicyMethodParams, - SolvingMethod, ) -from skdecide.hub.solver.do_solver.gphh import GPHH, ParametersGPHH from skdecide.hub.solver.graph_explorer.DFS_Uncertain_Exploration import DFSExploration from skdecide.hub.solver.lazy_astar import LazyAstar +logger = logging.getLogger(__name__) + optimal_solutions = { "ToyRCPSPDomain": {"makespan": 10}, "ToyMS_RCPSPDomain": {"makespan": 10}, @@ -940,3 +942,60 @@ def test_sgs_policies(domain): ) print("Cost :", sum([v.cost for v in values])) check_rollout_consistency(domain, states) + + +class MyCallback: + """Callback for testing. + + - displays iteration number + - stops after max iteration reached + - check classes of domain and solver + + """ + + def __init__(self, max_iter=2): + self.max_iter = max_iter + self.iter = 0 + + def __call__(self, domain, solver): + self.iter += 1 + logger.warning(f"End of iteration #{self.iter}.") + assert isinstance(domain, ToyRCPSPDomain) + assert isinstance(solver, DOSolver) + stopping = self.iter >= self.max_iter + return stopping + + +def test_do_with_cb(caplog): + domain = ToyRCPSPDomain() + domain.set_inplace_environment(False) + state = domain.get_initial_state() + print("Initial state : ", state) + solver = DOSolver( + policy_method_params=PolicyMethodParams( + base_policy_method=BasePolicyMethod.SGS_PRECEDENCE, + delta_index_freedom=0, + delta_time_freedom=0, + ), + method=SolvingMethod.LNS_CP, + callback=MyCallback(), + # dict_params={"cp_solver_name": CPSolverName.GECODE} + ) + solver.solve(domain_factory=lambda: domain) + + # Check that 2 iterations were done and messages logged by callback + assert "End of iteration #2" in caplog.text + assert "End of iteration #3" not in caplog.text + + # action_formatter=lambda o: str(o), + # outcome_formatter=lambda o: f'{o.observation} - cost: {o.value.cost:.2f}') + states, actions, values = rollout_episode( + domain=domain, + max_steps=1000, + solver=solver, + from_memory=state, + action_formatter=None, + outcome_formatter=None, + verbose=False, + ) + check_rollout_consistency(domain, states)