From c223234e40ae78340344e96411fc7a67bd056293 Mon Sep 17 00:00:00 2001 From: Denis Barbier Date: Fri, 3 Sep 2021 18:50:21 +0200 Subject: [PATCH] Apply clang-format pre-commit changes pre-commit run -a clang-format --- cpp/src/builders/domain/dynamics.hh | 271 +-- cpp/src/builders/domain/events.hh | 211 +- cpp/src/builders/domain/goals.hh | 44 +- cpp/src/builders/domain/initialization.hh | 157 +- cpp/src/builders/domain/memory.hh | 191 +- cpp/src/builders/domain/observability.hh | 140 +- cpp/src/builders/domain/renderability.hh | 8 +- cpp/src/builders/domain/value.hh | 18 +- cpp/src/core.hh | 620 +++--- cpp/src/hub/py_skdecide.cc | 38 +- cpp/src/hub/solver/aostar/aostar.hh | 129 +- cpp/src/hub/solver/aostar/impl/aostar_impl.hh | 418 ++-- cpp/src/hub/solver/aostar/py_aostar.cc | 48 +- cpp/src/hub/solver/aostar/py_aostar.hh | 338 ++-- cpp/src/hub/solver/astar/astar.hh | 104 +- cpp/src/hub/solver/astar/impl/astar_impl.hh | 337 ++-- cpp/src/hub/solver/astar/py_astar.cc | 40 +- cpp/src/hub/solver/astar/py_astar.hh | 324 ++-- cpp/src/hub/solver/bfws/bfws.hh | 182 +- cpp/src/hub/solver/bfws/impl/bfws_impl.hh | 490 ++--- cpp/src/hub/solver/bfws/py_bfws.cc | 48 +- cpp/src/hub/solver/bfws/py_bfws.hh | 440 +++-- cpp/src/hub/solver/ilaostar/ilaostar.hh | 167 +- .../hub/solver/ilaostar/impl/ilaostar_impl.hh | 712 +++---- cpp/src/hub/solver/ilaostar/py_ilaostar.cc | 55 +- cpp/src/hub/solver/ilaostar/py_ilaostar.hh | 401 ++-- cpp/src/hub/solver/iw/impl/iw_impl.hh | 824 ++++---- cpp/src/hub/solver/iw/iw.hh | 259 +-- cpp/src/hub/solver/iw/py_iw.cc | 58 +- cpp/src/hub/solver/iw/py_iw.hh | 513 ++--- cpp/src/hub/solver/lrtdp/impl/lrtdp_impl.hh | 832 ++++---- cpp/src/hub/solver/lrtdp/lrtdp.hh | 247 +-- cpp/src/hub/solver/lrtdp/py_lrtdp.cc | 76 +- cpp/src/hub/solver/lrtdp/py_lrtdp.hh | 473 ++--- cpp/src/hub/solver/martdp/impl/martdp_impl.hh | 1054 +++++----- cpp/src/hub/solver/martdp/martdp.hh | 284 +-- cpp/src/hub/solver/martdp/py_martdp.cc | 84 +- cpp/src/hub/solver/martdp/py_martdp.hh | 505 +++-- .../mcts_best_qvalue_action_selector_impl.hh | 56 +- .../impl/mcts_default_rollout_policy_impl.hh | 107 +- .../impl/mcts_default_tree_policy_impl.hh | 114 +- .../mcts_distribution_transition_mode_impl.hh | 52 +- .../solver/mcts/impl/mcts_full_expand_impl.hh | 662 ++++--- .../mcts/impl/mcts_graph_backup_impl.hh | 196 +- cpp/src/hub/solver/mcts/impl/mcts_impl.hh | 584 +++--- .../mcts/impl/mcts_partial_expand_impl.hh | 350 ++-- .../impl/mcts_sample_transition_mode_impl.hh | 50 +- .../impl/mcts_step_transition_mode_impl.hh | 62 +- .../impl/mcts_ucb1_action_selector_impl.hh | 65 +- cpp/src/hub/solver/mcts/impl/py_mcts_inst.hh | 42 +- cpp/src/hub/solver/mcts/mcts.hh | 689 +++---- cpp/src/hub/solver/mcts/py_mcts.cc | 223 ++- cpp/src/hub/solver/mcts/py_mcts.hh | 1150 +++++------ cpp/src/hub/solver/riw/impl/riw_impl.hh | 1669 ++++++++-------- cpp/src/hub/solver/riw/py_riw.cc | 81 +- cpp/src/hub/solver/riw/py_riw.hh | 599 +++--- cpp/src/hub/solver/riw/riw.hh | 512 +++-- .../utils/associative_container_deducer.hh | 227 +-- cpp/src/utils/execution.hh | 498 +++-- cpp/src/utils/impl/logging_impl.hh | 190 +- .../utils/impl/python_container_proxy_impl.hh | 402 ++-- .../impl/python_domain_proxy_base_impl.hh | 669 +++---- .../impl/python_domain_proxy_call_impl.hh | 366 ++-- .../impl/python_domain_proxy_common_impl.hh | 438 +++-- .../utils/impl/python_domain_proxy_impl.hh | 1718 +++++++++-------- cpp/src/utils/impl/python_hash_eq_impl.hh | 235 ++- cpp/src/utils/logging.hh | 42 +- cpp/src/utils/pegtl_spdlog_tracer.hh | 347 ++-- cpp/src/utils/python_container_proxy.hh | 219 +-- cpp/src/utils/python_domain_proxy.hh | 715 +++---- cpp/src/utils/python_domain_proxy_base.hh | 433 +++-- cpp/src/utils/python_gil_control.hh | 43 +- cpp/src/utils/python_globals.hh | 99 +- cpp/src/utils/python_hash_eq.hh | 28 +- cpp/src/utils/string_converter.hh | 22 +- cpp/src/utils/template_instantiator.hh | 325 ++-- 76 files changed, 13326 insertions(+), 12093 deletions(-) diff --git a/cpp/src/builders/domain/dynamics.hh b/cpp/src/builders/domain/dynamics.hh index 202ce091b6..d2dbf32108 100644 --- a/cpp/src/builders/domain/dynamics.hh +++ b/cpp/src/builders/domain/dynamics.hh @@ -17,78 +17,84 @@ template , typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class EnvironmentDomain : public virtual PartiallyObservableDomain, - public virtual HistoryDomain { -public : - typedef Tstate State; - typedef Tobservation Observation; - typedef Tvalue Value; - typedef Tinfo Info; - typedef EnvironmentOutcome EnvironmentOutcomeReturn; - typedef TransitionOutcome TransitionOutcomeReturn; - typedef TsmartPointer TransitionOutcomePtr; - typedef Tevent Event; - - EnvironmentOutcomeReturn step(const Event& event) { - const TransitionOutcomeReturn& transition_outcome = *(make_step(event)); - const State& next_state = transition_outcome.state; - Observation observation = get_observation_distribution(next_state, event).sample(); - if (this->get_memory_maxlen() > 0) { - this->_memory.push_back(next_state); - } - return EnvironmentOutcomeReturn(observation, transition_outcome.value, - transition_outcome.termination, transition_outcome.info); +class EnvironmentDomain + : public virtual PartiallyObservableDomain< + Tstate, Tobservation, Tevent, TstateSpace, TobservationSpace, + TobservationDistribution, TsmartPointer>, + public virtual HistoryDomain { +public: + typedef Tstate State; + typedef Tobservation Observation; + typedef Tvalue Value; + typedef Tinfo Info; + typedef EnvironmentOutcome + EnvironmentOutcomeReturn; + typedef TransitionOutcome TransitionOutcomeReturn; + typedef TsmartPointer TransitionOutcomePtr; + typedef Tevent Event; + + EnvironmentOutcomeReturn step(const Event &event) { + const TransitionOutcomeReturn &transition_outcome = *(make_step(event)); + const State &next_state = transition_outcome.state; + Observation observation = + get_observation_distribution(next_state, event).sample(); + if (this->get_memory_maxlen() > 0) { + this->_memory.push_back(next_state); } + return EnvironmentOutcomeReturn(observation, transition_outcome.value, + transition_outcome.termination, + transition_outcome.info); + } -protected : - virtual TransitionOutcomePtr make_step(const Event& event) = 0; +protected: + virtual TransitionOutcomePtr make_step(const Event &event) = 0; }; - template , typename TobservationSpace = Space, typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class SimulationDomain : public EnvironmentDomain { -public : - typedef Tstate State; - typedef Tobservation Observation; - typedef Tvalue Value; - typedef Tinfo Info; - typedef EnvironmentOutcome EnvironmentOutcomeReturn; - typedef TransitionOutcome TransitionOutcomeReturn; - typedef TsmartPointer TransitionOutcomePtr; - typedef Tevent Event; - typedef Memory StateMemory; - - EnvironmentOutcomeReturn sample(const StateMemory& memory, const Event& event) { - const TransitionOutcomeReturn& transition_outcome = *(make_sample(memory, event)); - const State& next_state = transition_outcome.state; - Observation observation = get_observation_distribution(next_state, event).sample(); - return EnvironmentOutcomeReturn(observation, transition_outcome.value, - transition_outcome.termination, transition_outcome.info); - } - - inline void set_memory(const StateMemory& m) { - this->_memory = m; - } - -protected : - inline virtual TransitionOutcomePtr make_step(const Event& event) { - return make_sample(this->_memory, event); - } - - virtual TransitionOutcomePtr make_sample(const StateMemory& memory, const Event& event) = 0; +class SimulationDomain + : public EnvironmentDomain { +public: + typedef Tstate State; + typedef Tobservation Observation; + typedef Tvalue Value; + typedef Tinfo Info; + typedef EnvironmentOutcome + EnvironmentOutcomeReturn; + typedef TransitionOutcome TransitionOutcomeReturn; + typedef TsmartPointer TransitionOutcomePtr; + typedef Tevent Event; + typedef Memory StateMemory; + + EnvironmentOutcomeReturn sample(const StateMemory &memory, + const Event &event) { + const TransitionOutcomeReturn &transition_outcome = + *(make_sample(memory, event)); + const State &next_state = transition_outcome.state; + Observation observation = + get_observation_distribution(next_state, event).sample(); + return EnvironmentOutcomeReturn(observation, transition_outcome.value, + transition_outcome.termination, + transition_outcome.info); + } + + inline void set_memory(const StateMemory &m) { this->_memory = m; } + +protected: + inline virtual TransitionOutcomePtr make_step(const Event &event) { + return make_sample(this->_memory, event); + } + + virtual TransitionOutcomePtr make_sample(const StateMemory &memory, + const Event &event) = 0; }; - template , @@ -96,35 +102,41 @@ template , typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class UncertainTransitionDomain : public SimulationDomain { -public : - typedef Tevent Event; - typedef Tstate State; - typedef Memory StateMemory; - typedef TstateDistribution NextStateDistribution; - typedef TsmartPointer NextStateDistributionPtr; - typedef Tinfo Info; - typedef Value TransitionValueReturn; - typedef TransitionOutcome TransitionOutcomeReturn; - typedef TsmartPointer TransitionOutcomePtr; - - virtual NextStateDistributionPtr get_next_state_distribution(const StateMemory& memory, const Event& event) = 0; - virtual TransitionValueReturn get_transition_value(const StateMemory& memory, const Event& event, const State& next_state) = 0; - virtual bool is_terminal(const State& state) = 0; - -protected : - virtual TransitionOutcomePtr make_sample(const StateMemory& memory, const Event& event) { - State next_state = get_next_state_distribution(memory, event).sample(); - TransitionValueReturn value = get_transition_value(memory, event, next_state); - bool termination = is_terminal(next_state); - return std::make_unique(next_state, value, termination); - } +class UncertainTransitionDomain + : public SimulationDomain { +public: + typedef Tevent Event; + typedef Tstate State; + typedef Memory StateMemory; + typedef TstateDistribution NextStateDistribution; + typedef TsmartPointer NextStateDistributionPtr; + typedef Tinfo Info; + typedef Value TransitionValueReturn; + typedef TransitionOutcome TransitionOutcomeReturn; + typedef TsmartPointer TransitionOutcomePtr; + + virtual NextStateDistributionPtr + get_next_state_distribution(const StateMemory &memory, + const Event &event) = 0; + virtual TransitionValueReturn + get_transition_value(const StateMemory &memory, const Event &event, + const State &next_state) = 0; + virtual bool is_terminal(const State &state) = 0; + +protected: + virtual TransitionOutcomePtr make_sample(const StateMemory &memory, + const Event &event) { + State next_state = get_next_state_distribution(memory, event).sample(); + TransitionValueReturn value = + get_transition_value(memory, event, next_state); + bool termination = is_terminal(next_state); + return std::make_unique(next_state, value, + termination); + } }; - template , @@ -132,50 +144,65 @@ template , typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class EnumerableTransitionDomain : public UncertainTransitionDomain, // Not the specialized DiscreteDistribution to allow for common base class recognition with multiple inheritance - TstateSpace, TobservationSpace, - TobservationDistribution, TsmartPointer> { - static_assert(std::is_base_of, TstateDistribution>::value, "State distribution type must be derived from skdecide::DiscreteDistribution"); - -public : - typedef Tevent Event; - typedef Tstate State; - typedef Memory StateMemory; - typedef Distribution NextStateDistribution; - typedef TsmartPointer NextStateDistributionPtr; - - virtual NextStateDistributionPtr get_next_state_distribution(const StateMemory& memory, const Event& event) = 0; +class EnumerableTransitionDomain + : public UncertainTransitionDomain< + Tstate, Tobservation, Tevent, TT, Tvalue, Tinfo, + Distribution, // Not the specialized + // DiscreteDistribution to allow for + // common base class recognition with multiple + // inheritance + TstateSpace, TobservationSpace, TobservationDistribution, + TsmartPointer> { + static_assert( + std::is_base_of, TstateDistribution>::value, + "State distribution type must be derived from " + "skdecide::DiscreteDistribution"); + +public: + typedef Tevent Event; + typedef Tstate State; + typedef Memory StateMemory; + typedef Distribution NextStateDistribution; + typedef TsmartPointer NextStateDistributionPtr; + + virtual NextStateDistributionPtr + get_next_state_distribution(const StateMemory &memory, + const Event &event) = 0; }; - template , typename TobservationSpace = Space, typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class DeterministicTransitionDomain : public EnumerableTransitionDomain, // Not the specialized SingleValueDistribution to allow for common base class recognition with multiple inheritance - TstateSpace, TobservationSpace, - TobservationDistribution, TsmartPointer> { -public : - typedef Tevent Event; - typedef Tstate State; - typedef TsmartPointer StatePtr; - typedef Memory StateMemory; - typedef Distribution NextStateDistribution; - typedef TsmartPointer NextStateDistributionPtr; - - inline virtual NextStateDistributionPtr get_next_state_distribution(const StateMemory& memory, const Event& event) { - return std::make_unique>(get_next_state(memory, event)); - } - - virtual StatePtr get_next_state(const StateMemory& memory, const Event& event) = 0; +class DeterministicTransitionDomain + : public EnumerableTransitionDomain< + Tstate, Tobservation, Tevent, TT, Tvalue, Tinfo, + Distribution, // Not the specialized + // SingleValueDistribution to allow for + // common base class recognition with multiple + // inheritance + TstateSpace, TobservationSpace, TobservationDistribution, + TsmartPointer> { +public: + typedef Tevent Event; + typedef Tstate State; + typedef TsmartPointer StatePtr; + typedef Memory StateMemory; + typedef Distribution NextStateDistribution; + typedef TsmartPointer NextStateDistributionPtr; + + inline virtual NextStateDistributionPtr + get_next_state_distribution(const StateMemory &memory, const Event &event) { + return std::make_unique>( + get_next_state(memory, event)); + } + + virtual StatePtr get_next_state(const StateMemory &memory, + const Event &event) = 0; }; -} +} // namespace skdecide #endif // SKDECIDE_DYNAMICS_HH diff --git a/cpp/src/builders/domain/events.hh b/cpp/src/builders/domain/events.hh index 29acf383c9..4517d03b80 100644 --- a/cpp/src/builders/domain/events.hh +++ b/cpp/src/builders/domain/events.hh @@ -19,129 +19,146 @@ template class TsmartPointer = std::unique_ptr> class EventDomain : public virtual HistoryDomain { - static_assert(std::is_same::value, "Event space elements must be of type Tevent"); - static_assert(std::is_base_of, TeventSpace>::value, "Event space type must be derived from skdecide::Space"); - static_assert(std::is_same::value, "Action space elements must be of type Tevent"); - static_assert(std::is_base_of, TactionSpace>::value, "Action space type must be derived from skdecide::Space"); - static_assert(std::is_same::value, "Enabled event space elements must be of type Tevent"); - static_assert(std::is_base_of, TenabledEventSpace>::value, "Enabled event space type must be derived from skdecide::Space"); - static_assert(std::is_same::value, "Applicable action space elements must be of type Tevent"); - static_assert(std::is_base_of, TapplicableActionSpace>::value, "Applicable action space type must be derived from skdecide::Space"); - -public : - typedef Tstate State; - typedef Memory StateMemory; - typedef Tevent Event; - typedef TeventSpace EventSpace; - typedef TsmartPointer EventSpacePtr; - typedef TenabledEventSpace EnabledEventSpace; - typedef TsmartPointer EnabledEventSpacePtr; - typedef Tevent Action; - typedef TactionSpace ActionSpace; - typedef TsmartPointer ActionSpacePtr; - typedef TapplicableActionSpace ApplicableActionSpace; - typedef TsmartPointer ApplicableActionSpacePtr; - - const EventSpace& get_event_space() { - if (!_event_space) { - _event_space = make_event_space(); - } - return *_event_space; + static_assert(std::is_same::value, + "Event space elements must be of type Tevent"); + static_assert( + std::is_base_of, TeventSpace>::value, + "Event space type must be derived from skdecide::Space"); + static_assert( + std::is_same::value, + "Action space elements must be of type Tevent"); + static_assert( + std::is_base_of, TactionSpace>::value, + "Action space type must be derived from skdecide::Space"); + static_assert( + std::is_same::value, + "Enabled event space elements must be of type Tevent"); + static_assert( + std::is_base_of, TenabledEventSpace>::value, + "Enabled event space type must be derived from skdecide::Space"); + static_assert(std::is_same::value, + "Applicable action space elements must be of type Tevent"); + static_assert(std::is_base_of, TapplicableActionSpace>::value, + "Applicable action space type must be derived from " + "skdecide::Space"); + +public: + typedef Tstate State; + typedef Memory StateMemory; + typedef Tevent Event; + typedef TeventSpace EventSpace; + typedef TsmartPointer EventSpacePtr; + typedef TenabledEventSpace EnabledEventSpace; + typedef TsmartPointer EnabledEventSpacePtr; + typedef Tevent Action; + typedef TactionSpace ActionSpace; + typedef TsmartPointer ActionSpacePtr; + typedef TapplicableActionSpace ApplicableActionSpace; + typedef TsmartPointer ApplicableActionSpacePtr; + + const EventSpace &get_event_space() { + if (!_event_space) { + _event_space = make_event_space(); } + return *_event_space; + } - virtual EnabledEventSpacePtr get_enabled_events(const StateMemory& memory) = 0; + virtual EnabledEventSpacePtr + get_enabled_events(const StateMemory &memory) = 0; - inline EnabledEventSpacePtr get_enabled_events() { - return get_enabled_events(this->_memory); - } + inline EnabledEventSpacePtr get_enabled_events() { + return get_enabled_events(this->_memory); + } - inline bool is_enabled_event(const Event& event, const StateMemory& memory) { - return get_enabled_events(memory).contains(event); - } + inline bool is_enabled_event(const Event &event, const StateMemory &memory) { + return get_enabled_events(memory).contains(event); + } - inline bool is_enabled_event(const Event& event) { - return is_enabled_event(event, this->_memory); - } + inline bool is_enabled_event(const Event &event) { + return is_enabled_event(event, this->_memory); + } - const ActionSpace& get_action_space() { - if (!_action_space) { - _action_space = make_action_space(); - } - return *_action_space; + const ActionSpace &get_action_space() { + if (!_action_space) { + _action_space = make_action_space(); } + return *_action_space; + } - inline bool is_action(const Event& event) { - return get_action_space().contains(event); - } + inline bool is_action(const Event &event) { + return get_action_space().contains(event); + } - virtual ApplicableActionSpacePtr get_applicable_actions(const StateMemory& memory) = 0; + virtual ApplicableActionSpacePtr + get_applicable_actions(const StateMemory &memory) = 0; - inline ApplicableActionSpacePtr get_applicable_actions() { - return get_applicable_actions(this->_memory); - } + inline ApplicableActionSpacePtr get_applicable_actions() { + return get_applicable_actions(this->_memory); + } - inline bool is_applicable_action(const Event& event, const StateMemory& memory) { - return get_applicable_actions(memory).contains(event); - } + inline bool is_applicable_action(const Event &event, + const StateMemory &memory) { + return get_applicable_actions(memory).contains(event); + } - inline bool is_applicable_action(const Event& event) { - return is_applicable_action(event, this->_memory); - } + inline bool is_applicable_action(const Event &event) { + return is_applicable_action(event, this->_memory); + } -protected : - virtual EventSpacePtr make_event_space() =0; - virtual ActionSpacePtr make_action_space() =0; +protected: + virtual EventSpacePtr make_event_space() = 0; + virtual ActionSpacePtr make_action_space() = 0; -private : - EventSpacePtr _event_space; - ActionSpacePtr _action_space; +private: + EventSpacePtr _event_space; + ActionSpacePtr _action_space; }; - template , typename TapplicableActionSpace = TactionSpace, template class TsmartPointer = std::unique_ptr> -class ActionDomain : public EventDomain { -public : - typedef Tstate State; - typedef Memory MemoryState; - typedef Taction Action; - typedef TactionSpace ActionSpace; - typedef TapplicableActionSpace ApplicableActionSpace; - typedef TsmartPointer ApplicableActionSpacePtr; - - inline virtual const ActionSpace& get_event_space() { - return this->get_action_space(); - } - - inline virtual ApplicableActionSpacePtr get_enabled_events(const MemoryState& memory) { - return this->get_enabled_actions(memory); - } +class ActionDomain : public EventDomain { +public: + typedef Tstate State; + typedef Memory MemoryState; + typedef Taction Action; + typedef TactionSpace ActionSpace; + typedef TapplicableActionSpace ApplicableActionSpace; + typedef TsmartPointer ApplicableActionSpacePtr; + + inline virtual const ActionSpace &get_event_space() { + return this->get_action_space(); + } + + inline virtual ApplicableActionSpacePtr + get_enabled_events(const MemoryState &memory) { + return this->get_enabled_actions(memory); + } }; - template , typename TapplicableActionSpace = TactionSpace, template class TsmartPointer = std::unique_ptr> -class UnrestrictedActionDomain : public ActionDomain { -public : - typedef Tstate State; - typedef Memory StateMemory; - typedef Taction Action; - typedef TactionSpace ActionSpace; - typedef TapplicableActionSpace ApplicableActionSpace; - typedef TsmartPointer ApplicableActionSpacePtr; - - inline virtual ApplicableActionSpacePtr get_applicable_actions(const StateMemory& memory) { - return this->get_action_space(); - } +class UnrestrictedActionDomain + : public ActionDomain { +public: + typedef Tstate State; + typedef Memory StateMemory; + typedef Taction Action; + typedef TactionSpace ActionSpace; + typedef TapplicableActionSpace ApplicableActionSpace; + typedef TsmartPointer ApplicableActionSpacePtr; + + inline virtual ApplicableActionSpacePtr + get_applicable_actions(const StateMemory &memory) { + return this->get_action_space(); + } }; } // namespace skdecide diff --git a/cpp/src/builders/domain/goals.hh b/cpp/src/builders/domain/goals.hh index 50de87edb5..06df28acf0 100644 --- a/cpp/src/builders/domain/goals.hh +++ b/cpp/src/builders/domain/goals.hh @@ -13,30 +13,34 @@ template , template class TsmartPointer = std::unique_ptr> class GoalDomain { - static_assert(std::is_same::value, "Observation space elements must be of type Tobservation"); - static_assert(std::is_base_of, TobservationSpace>::value, "Observation space type must be derived from skdecide::Space"); - -public : - typedef Tobservation Observation; - typedef TobservationSpace ObservationSpace; - typedef TsmartPointer ObservationSpacePtr; - - const ObservationSpace& get_goals() { - if (!_goals) { - _goals = make_goals(); - } - return *_goals; + static_assert(std::is_same::value, + "Observation space elements must be of type Tobservation"); + static_assert(std::is_base_of, TobservationSpace>::value, + "Observation space type must be derived from " + "skdecide::Space"); + +public: + typedef Tobservation Observation; + typedef TobservationSpace ObservationSpace; + typedef TsmartPointer ObservationSpacePtr; + + const ObservationSpace &get_goals() { + if (!_goals) { + _goals = make_goals(); } + return *_goals; + } - inline bool is_goal(const Observation& observation) { - return get_goals().contains(observation); - } + inline bool is_goal(const Observation &observation) { + return get_goals().contains(observation); + } -protected : - virtual ObservationSpacePtr make_goals() =0; +protected: + virtual ObservationSpacePtr make_goals() = 0; -private : - ObservationSpacePtr _goals; +private: + ObservationSpacePtr _goals; }; } // namespace skdecide diff --git a/cpp/src/builders/domain/initialization.hh b/cpp/src/builders/domain/initialization.hh index 18084bae86..3914644b4c 100644 --- a/cpp/src/builders/domain/initialization.hh +++ b/cpp/src/builders/domain/initialization.hh @@ -16,109 +16,108 @@ template , typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class InitializableDomain : public virtual PartiallyObservableDomain, - public virtual HistoryDomain { -public : - typedef Tobservation Observation; - typedef TobservationSpace ObservationSpace; - typedef TobservationDistribution ObservationDistribution; - typedef TsmartPointer ObservationDistributionPtr; - typedef Tstate State; - typedef TstateSpace StateSpace; - typedef Tevent Event; - - Observation reset() { - State initial_state = _reset(); - Observation initial_observation = this->get_observation_distribution(initial_state, nullptr)->sample(); - auto _memory = this->_init_memory({initial_state}); - return initial_observation; - } - -protected : - virtual State _reset() = 0; +class InitializableDomain + : public virtual PartiallyObservableDomain< + Tstate, Tobservation, Tevent, TstateSpace, TobservationSpace, + TobservationDistribution, TsmartPointer>, + public virtual HistoryDomain { +public: + typedef Tobservation Observation; + typedef TobservationSpace ObservationSpace; + typedef TobservationDistribution ObservationDistribution; + typedef TsmartPointer ObservationDistributionPtr; + typedef Tstate State; + typedef TstateSpace StateSpace; + typedef Tevent Event; + + Observation reset() { + State initial_state = _reset(); + Observation initial_observation = + this->get_observation_distribution(initial_state, nullptr)->sample(); + auto _memory = this->_init_memory({initial_state}); + return initial_observation; + } + +protected: + virtual State _reset() = 0; }; - template , typename TstateSpace = Space, typename TobservationSpace = Space, typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class UncertainInitializedDomain : public InitializableDomain { -public : - typedef Tobservation Observation; - typedef TobservationSpace ObservationSpace; - typedef TobservationDistribution ObservationDistribution; - typedef TsmartPointer ObservationDistributionPtr; - typedef Tstate State; - typedef TstateSpace StateSpace; - typedef Tevent Event; - typedef TinitialStateDistribution InitialStateDistribution; - typedef TsmartPointer InitialStateDistributionPtr; - - InitialStateDistribution& get_initial_state_distribution() { - if (!_initial_state_distribution) { - _initial_state_distribution = make_initial_state_distribution(); - } - return *_initial_state_distribution; +class UncertainInitializedDomain + : public InitializableDomain { +public: + typedef Tobservation Observation; + typedef TobservationSpace ObservationSpace; + typedef TobservationDistribution ObservationDistribution; + typedef TsmartPointer ObservationDistributionPtr; + typedef Tstate State; + typedef TstateSpace StateSpace; + typedef Tevent Event; + typedef TinitialStateDistribution InitialStateDistribution; + typedef TsmartPointer InitialStateDistributionPtr; + + InitialStateDistribution &get_initial_state_distribution() { + if (!_initial_state_distribution) { + _initial_state_distribution = make_initial_state_distribution(); } + return *_initial_state_distribution; + } -protected : - virtual InitialStateDistributionPtr make_initial_state_distribution() =0; +protected: + virtual InitialStateDistributionPtr make_initial_state_distribution() = 0; - virtual State _reset() { - return get_initial_state_distribution().sample(); - } + virtual State _reset() { return get_initial_state_distribution().sample(); } -private : - InitialStateDistributionPtr _initial_state_distribution; +private: + InitialStateDistributionPtr _initial_state_distribution; }; - template , typename TstateSpace = Space, typename TobservationSpace = Space, typename TobservationDistribution = Distribution, template class TsmartPointer = std::unique_ptr> -class DeterministicInitializedDomain : public UncertainInitializedDomain { -public : - typedef Tobservation Observation; - typedef TobservationSpace ObservationSpace; - typedef TobservationDistribution ObservationDistribution; - typedef TsmartPointer ObservationDistributionPtr; - typedef Tstate State; - typedef TsmartPointer StatePtr; - typedef TstateSpace StateSpace; - typedef Tevent Event; - typedef TinitialStateDistribution InitialStateDistribution; - typedef TsmartPointer InitialStateDistributionPtr; - - const State& get_initial_state() { - if (!_initial_state) { - _initial_state = make_initial_state(); - } - return *_initial_state; +class DeterministicInitializedDomain + : public UncertainInitializedDomain< + Tstate, Tobservation, Tevent, TinitialStateDistribution, TstateSpace, + TobservationSpace, TobservationDistribution, TsmartPointer> { +public: + typedef Tobservation Observation; + typedef TobservationSpace ObservationSpace; + typedef TobservationDistribution ObservationDistribution; + typedef TsmartPointer ObservationDistributionPtr; + typedef Tstate State; + typedef TsmartPointer StatePtr; + typedef TstateSpace StateSpace; + typedef Tevent Event; + typedef TinitialStateDistribution InitialStateDistribution; + typedef TsmartPointer InitialStateDistributionPtr; + + const State &get_initial_state() { + if (!_initial_state) { + _initial_state = make_initial_state(); } + return *_initial_state; + } -protected : - virtual StatePtr make_initial_state() =0; +protected: + virtual StatePtr make_initial_state() = 0; -private : - StatePtr _initial_state; +private: + StatePtr _initial_state; - virtual InitialStateDistributionPtr make_initial_state_distribution() { - return std::make_unique>(get_initial_state()); - } + virtual InitialStateDistributionPtr make_initial_state_distribution() { + return std::make_unique>( + get_initial_state()); + } }; } // namespace skdecide diff --git a/cpp/src/builders/domain/memory.hh b/cpp/src/builders/domain/memory.hh index d865869305..213d7de3f0 100644 --- a/cpp/src/builders/domain/memory.hh +++ b/cpp/src/builders/domain/memory.hh @@ -10,132 +10,121 @@ namespace skdecide { -template -class HistoryDomain { -public : - typedef Tstate State; - typedef Memory StateMemory; - typedef std::unique_ptr StateMemoryPtr; - - inline virtual bool check_memory(const StateMemory& memory) { - return true; - } - - inline virtual bool check_memory() { - if (!_memory) - throw std::invalid_argument("Uninitialized internal state memory"); - - return check_memory(*_memory); - } - - static State& get_last_state(StateMemory& memory) { - if (memory.size() > 0) { - return memory.back(); - } else { - throw std::out_of_range("Attempting to get last state of empty memory object"); - } +template class HistoryDomain { +public: + typedef Tstate State; + typedef Memory StateMemory; + typedef std::unique_ptr StateMemoryPtr; + + inline virtual bool check_memory(const StateMemory &memory) { return true; } + + inline virtual bool check_memory() { + if (!_memory) + throw std::invalid_argument("Uninitialized internal state memory"); + + return check_memory(*_memory); + } + + static State &get_last_state(StateMemory &memory) { + if (memory.size() > 0) { + return memory.back(); + } else { + throw std::out_of_range( + "Attempting to get last state of empty memory object"); } + } - inline State& get_last_state() { - if (!_memory) - throw std::invalid_argument("Uninitialized internal state memory"); + inline State &get_last_state() { + if (!_memory) + throw std::invalid_argument("Uninitialized internal state memory"); - return get_last_state(*_memory); - } + return get_last_state(*_memory); + } -protected : - StateMemoryPtr _memory; +protected: + StateMemoryPtr _memory; - /** - * Protected constructor because the class must be specialized to properly - * initialize the state memory - */ - HistoryDomain() {} + /** + * Protected constructor because the class must be specialized to properly + * initialize the state memory + */ + HistoryDomain() {} - template - inline StateMemoryPtr _init_memory(InputIt iBegin, InputIt iEnd) { - return std::make_unique(iBegin, iEnd, _get_memory_maxlen()); - } + template + inline StateMemoryPtr _init_memory(InputIt iBegin, InputIt iEnd) { + return std::make_unique(iBegin, iEnd, _get_memory_maxlen()); + } - inline StateMemoryPtr _init_memory(std::initializer_list iList) { - return std::make_unique(iList, _get_memory_maxlen()); - } + inline StateMemoryPtr _init_memory(std::initializer_list iList) { + return std::make_unique(iList, _get_memory_maxlen()); + } - inline virtual std::size_t _get_memory_maxlen() { - return std::numeric_limits::max(); - } + inline virtual std::size_t _get_memory_maxlen() { + return std::numeric_limits::max(); + } }; - template class FiniteHistoryDomain : public HistoryDomain { -public : - typedef Tstate State; - typedef Memory StateMemory; - - inline virtual bool check_memory(const StateMemory& memory) { - return memory.maxlen() == _get_memory_maxlen(); +public: + typedef Tstate State; + typedef Memory StateMemory; + + inline virtual bool check_memory(const StateMemory &memory) { + return memory.maxlen() == _get_memory_maxlen(); + } + + inline virtual bool check_memory() { + if (!(this->_memory)) + throw std::invalid_argument("Uninitialized internal state memory"); + + return check_memory(*(this->_memory)); + } + +protected: + /** + * Protected constructor because the class must be specialized to properly + * initialize the state memory + */ + FiniteHistoryDomain() {} + + inline virtual std::size_t _get_memory_maxlen() { + if (!_memory_maxlen) { + _memory_maxlen = std::make_unique(make_memory_maxlen()); } + return *_memory_maxlen; + } - inline virtual bool check_memory() { - if (!(this->_memory)) - throw std::invalid_argument("Uninitialized internal state memory"); + virtual std::size_t make_memory_maxlen() = 0; - return check_memory(*(this->_memory)); - } - -protected : - /** - * Protected constructor because the class must be specialized to properly - * initialize the state memory - */ - FiniteHistoryDomain() {} - - inline virtual std::size_t _get_memory_maxlen() { - if (!_memory_maxlen) { - _memory_maxlen = std::make_unique(make_memory_maxlen()); - } - return *_memory_maxlen; - } - - virtual std::size_t make_memory_maxlen() =0; - -private : - std::unique_ptr _memory_maxlen; +private: + std::unique_ptr _memory_maxlen; }; - template class MarkovianDomain : public FiniteHistoryDomain { -public : - typedef Tstate State; - -protected : - /** - * Protected constructor because the class must be specialized to properly - * initialize the state memory - */ - MarkovianDomain() {} - - inline virtual std::size_t make_memory_maxlen() { - return 1; - } -}; +public: + typedef Tstate State; + +protected: + /** + * Protected constructor because the class must be specialized to properly + * initialize the state memory + */ + MarkovianDomain() {} + inline virtual std::size_t make_memory_maxlen() { return 1; } +}; template class MemorylessDomain : public MarkovianDomain { -public : - typedef Tstate State; +public: + typedef Tstate State; - MemorylessDomain() { - this->_memory = this->_init_memory({}); - } + MemorylessDomain() { this->_memory = this->_init_memory({}); } -protected : - inline virtual std::size_t make_memory_maxlen() { - return 0; - } +protected: + inline virtual std::size_t make_memory_maxlen() { return 0; } }; } // namespace skdecide diff --git a/cpp/src/builders/domain/observability.hh b/cpp/src/builders/domain/observability.hh index de4d005fbc..a8a1ab5870 100644 --- a/cpp/src/builders/domain/observability.hh +++ b/cpp/src/builders/domain/observability.hh @@ -17,80 +17,98 @@ template , template class TsmartPointer = std::unique_ptr> class PartiallyObservableDomain { - static_assert(std::is_same::value, "State space elements must be of type Tstate"); - static_assert(std::is_base_of, TstateSpace>::value, "State space type must be derived from skdecide::Space"); - static_assert(std::is_same::value, "Observation space elements must be of type Tobservation"); - static_assert(std::is_base_of, TobservationSpace>::value, "Observation space type must be derived from skdecide::Space"); - static_assert(std::is_same::value, "Observation distribution elements must be of type Tobservation"); - static_assert(std::is_base_of, TobservationDistribution>::value, "State space type must be derived from skdecide::Space"); - -public : - typedef Tobservation Observation; - typedef TobservationSpace ObservationSpace; - typedef TsmartPointer ObservationSpacePtr; - typedef TobservationDistribution ObservationDistribution; - typedef TsmartPointer ObservationDistributionPtr; - typedef Tstate State; - typedef TstateSpace StateSpace; - typedef TsmartPointer StateSpacePtr; - typedef Tevent Event; - - const ObservationSpace& get_observation_space() { - if (!_observation_space) { - _observation_space = make_observation_space(); - } - return *_observation_space; + static_assert(std::is_same::value, + "State space elements must be of type Tstate"); + static_assert( + std::is_base_of, TstateSpace>::value, + "State space type must be derived from skdecide::Space"); + static_assert(std::is_same::value, + "Observation space elements must be of type Tobservation"); + static_assert(std::is_base_of, TobservationSpace>::value, + "Observation space type must be derived from " + "skdecide::Space"); + static_assert( + std::is_same::value, + "Observation distribution elements must be of type Tobservation"); + static_assert( + std::is_base_of, + TobservationDistribution>::value, + "State space type must be derived from skdecide::Space"); + +public: + typedef Tobservation Observation; + typedef TobservationSpace ObservationSpace; + typedef TsmartPointer ObservationSpacePtr; + typedef TobservationDistribution ObservationDistribution; + typedef TsmartPointer ObservationDistributionPtr; + typedef Tstate State; + typedef TstateSpace StateSpace; + typedef TsmartPointer StateSpacePtr; + typedef Tevent Event; + + const ObservationSpace &get_observation_space() { + if (!_observation_space) { + _observation_space = make_observation_space(); } + return *_observation_space; + } - const StateSpace& get_state_space() { - if (!_state_space) { - _state_space = make_state_space(); - } - return *_state_space; + const StateSpace &get_state_space() { + if (!_state_space) { + _state_space = make_state_space(); } + return *_state_space; + } - inline bool is_observation(const Observation& observation) { - return get_observation_space().contains(observation); - } + inline bool is_observation(const Observation &observation) { + return get_observation_space().contains(observation); + } - inline bool is_state(const State& state) { - return get_state_space().contains(state); - } + inline bool is_state(const State &state) { + return get_state_space().contains(state); + } - virtual ObservationDistributionPtr get_observation_distribution(const State& state, const Event& event) = 0; + virtual ObservationDistributionPtr + get_observation_distribution(const State &state, const Event &event) = 0; -protected : - virtual ObservationSpacePtr make_observation_space() =0; - virtual StateSpacePtr make_state_space() =0; +protected: + virtual ObservationSpacePtr make_observation_space() = 0; + virtual StateSpacePtr make_state_space() = 0; -private : - ObservationSpacePtr _observation_space; - StateSpacePtr _state_space; +private: + ObservationSpacePtr _observation_space; + StateSpacePtr _state_space; }; - template , template class TsmartPointer = std::unique_ptr> -class FullyObservableDomain : public virtual PartiallyObservableDomain, // Not the specialized SingleValueDistribution to allow for common base class recognition with multiple inheritance - TsmartPointer> { -public : - typedef Tstate State; - typedef TstateSpace StateSpace; - typedef TsmartPointer StateSpacePtr; - typedef Tevent Event; - typedef TsmartPointer> ObservationDistributionPtr; - - inline virtual ObservationDistributionPtr get_observation_distribution(const State& state, const Event& event) { - return std::make_unique>(state); - } - -protected : - inline virtual StateSpacePtr make_observation_space() { - return this->make_state_space(); - } +class FullyObservableDomain + : public virtual PartiallyObservableDomain< + Tstate, Tstate, Tevent, TstateSpace, TstateSpace, + Distribution, // Not the specialized + // SingleValueDistribution to allow for + // common base class recognition with multiple + // inheritance + TsmartPointer> { +public: + typedef Tstate State; + typedef TstateSpace StateSpace; + typedef TsmartPointer StateSpacePtr; + typedef Tevent Event; + typedef TsmartPointer> ObservationDistributionPtr; + + inline virtual ObservationDistributionPtr + get_observation_distribution(const State &state, const Event &event) { + return std::make_unique>(state); + } + +protected: + inline virtual StateSpacePtr make_observation_space() { + return this->make_state_space(); + } }; } // namespace skdecide diff --git a/cpp/src/builders/domain/renderability.hh b/cpp/src/builders/domain/renderability.hh index 9709fe49e6..d118397165 100644 --- a/cpp/src/builders/domain/renderability.hh +++ b/cpp/src/builders/domain/renderability.hh @@ -11,12 +11,10 @@ namespace skdecide { template class RenderableDomain : public virtual HistoryDomain { -public : - virtual void render(const Memory& memory) = 0; +public: + virtual void render(const Memory &memory) = 0; - inline void render() { - render(_memory); - } + inline void render() { render(_memory); } }; } // namespace skdecide diff --git a/cpp/src/builders/domain/value.hh b/cpp/src/builders/domain/value.hh index 4528e2973a..1db2ba9f8a 100644 --- a/cpp/src/builders/domain/value.hh +++ b/cpp/src/builders/domain/value.hh @@ -9,21 +9,17 @@ namespace skdecide { -template -class RewardDomain { -public : - inline virtual bool check_value(const Value& value) { - return true; - } +template class RewardDomain { +public: + inline virtual bool check_value(const Value &value) { return true; } }; - template class PositiveCostDomain : public RewardDomain { -public : - inline virtual bool check_value(const Value& value) { - return value.cost >= 0; - } +public: + inline virtual bool check_value(const Value &value) { + return value.cost >= 0; + } }; } // namespace skdecide diff --git a/cpp/src/core.hh b/cpp/src/core.hh index 5068c066a6..c87a81a91a 100644 --- a/cpp/src/core.hh +++ b/cpp/src/core.hh @@ -17,386 +17,370 @@ using json = nlohmann::json; namespace skdecide { -template -class Space { -public : - /** - * Type of elements of the space - */ - typedef T element_type; - - /** - * Default destructor - */ - virtual ~Space() {} - - /** - * Return boolean specifying if x is a valid member of this space - */ - virtual bool contains(const T& x) const =0; +template class Space { +public: + /** + * Type of elements of the space + */ + typedef T element_type; + + /** + * Default destructor + */ + virtual ~Space() {} + + /** + * Return boolean specifying if x is a valid member of this space + */ + virtual bool contains(const T &x) const = 0; }; - -template -class ImplicitSpace : public Space { -public : - /** - * Type of elements of the space - */ - typedef T element_type; - - /** - * Constructor - * @param containsFunctor Functor (can be a lambda expression returning - * boolean specifying if x is a valid member of - * this space) - */ - ImplicitSpace(std::function containsFunctor) - : m_containsFunctor(containsFunctor) {} - - /** - * Return boolean specifying if x is a valid member of this space - */ - virtual bool contains(const T& x) const { - return m_containsFunctor(x); - } - -private : - std::function m_containsFunctor; +template class ImplicitSpace : public Space { +public: + /** + * Type of elements of the space + */ + typedef T element_type; + + /** + * Constructor + * @param containsFunctor Functor (can be a lambda expression returning + * boolean specifying if x is a valid member of + * this space) + */ + ImplicitSpace(std::function containsFunctor) + : m_containsFunctor(containsFunctor) {} + + /** + * Return boolean specifying if x is a valid member of this space + */ + virtual bool contains(const T &x) const { return m_containsFunctor(x); } + +private: + std::function m_containsFunctor; }; - -template class Tcontainer = std::unordered_set> +template class Tcontainer = std::unordered_set> class EnumerableSpace : public Space { -public : - /** - * Type of elements of the space - */ - typedef T element_type; - - /** - * Return the elements of this space - */ - virtual const Tcontainer& get_elements() const =0; +public: + /** + * Type of elements of the space + */ + typedef T element_type; + + /** + * Return the elements of this space + */ + virtual const Tcontainer &get_elements() const = 0; }; - -template -class SamplableSpace : public Space { -public : - /** - * Type of elements of the space - */ - typedef T element_type; - - /** - * Uniformly randomly sample a random element of this space - */ - virtual T sample() const =0; +template class SamplableSpace : public Space { +public: + /** + * Type of elements of the space + */ + typedef T element_type; + + /** + * Uniformly randomly sample a random element of this space + */ + virtual T sample() const = 0; }; - -template class Container = std::unordered_set> +template class Container = std::unordered_set> class SerializableSpace : public Space { -public : - /** - * Type of elements of the space - */ - typedef T element_type; - - /** - * Convert a batch of samples from this space to a JSONable data type - */ - virtual json to_jsonable(const Container& sample_n) const { - // By default, assume identity is JSONable - // See https://github.com/nlohmann/json#arbitrary-types-conversions - return json(sample_n); - } - - /** - * Convert a JSONable data type to a batch of samples from this space - */ - virtual Container from_jsonable(const json& sample_n) const { - // By default, assume identity is JSONable - // See https://github.com/nlohmann/json#arbitrary-types-conversions - return Container(sample_n.get>()); - } +public: + /** + * Type of elements of the space + */ + typedef T element_type; + + /** + * Convert a batch of samples from this space to a JSONable data type + */ + virtual json to_jsonable(const Container &sample_n) const { + // By default, assume identity is JSONable + // See https://github.com/nlohmann/json#arbitrary-types-conversions + return json(sample_n); + } + + /** + * Convert a JSONable data type to a batch of samples from this space + */ + virtual Container from_jsonable(const json &sample_n) const { + // By default, assume identity is JSONable + // See https://github.com/nlohmann/json#arbitrary-types-conversions + return Container(sample_n.get>()); + } }; - -template -class Distribution { -public : - /** - * Type of elements of the distribution - */ - typedef T element_type; - - /** - * Destructor - */ - virtual ~Distribution() {} - - /** - * Returning a sample from the distribution - */ - virtual T sample() =0; +template class Distribution { +public: + /** + * Type of elements of the distribution + */ + typedef T element_type; + + /** + * Destructor + */ + virtual ~Distribution() {} + + /** + * Returning a sample from the distribution + */ + virtual T sample() = 0; }; - -template -class ImplicitDistribution : public Distribution { -public : - /** - * Type of elements of the distribution - */ - typedef T element_type; - - /** - * Constructor - * @param sampleFunctor Functor (can be a lambda expression) returning - * a sample from the distribution - */ - ImplicitDistribution(std::function sampleFunctor) - : m_sampleFunctor(sampleFunctor) {} - - /** - * Returning a sample from the distribution - */ - virtual T sample() { - return m_sampleFunctor(); - } - -private : - std::function m_sampleFunctor; +template class ImplicitDistribution : public Distribution { +public: + /** + * Type of elements of the distribution + */ + typedef T element_type; + + /** + * Constructor + * @param sampleFunctor Functor (can be a lambda expression) returning + * a sample from the distribution + */ + ImplicitDistribution(std::function sampleFunctor) + : m_sampleFunctor(sampleFunctor) {} + + /** + * Returning a sample from the distribution + */ + virtual T sample() { return m_sampleFunctor(); } + +private: + std::function m_sampleFunctor; }; - -template class Container = std::unordered_map, +template class Container = std::unordered_map, typename Generator = std::mt19937, typename IntType = int> class DiscreteDistribution : public Distribution { -public : - /** - * Type of elements of the distribution - */ - typedef T element_type; - - /** - * Constructor - * @param iBegin Associative container begin iterator - * @param iEnd Associative container end iterator - * @param g Random number generator - */ - template - DiscreteDistribution(InputIt iBegin, InputIt iEnd) - : m_generator(Generator(std::random_device()())) { - for (InputIt i = iBegin; i != iEnd; i++) { - std::pair::iterator, bool> r = m_values.insert(*i); - if (r.second) { - m_indexes.push_back(r.first); - } else { - r.first->second += i->second; - } - } - - if (m_values.empty()) { - m_indexes.push_back(m_values.insert(std::make_pair(T(), 1.0)).first); - } - - std::vector probabilities; - std::for_each(m_indexes.begin(), m_indexes.end(), [&](const auto& i){probabilities.push_back(i->second);}); - m_distribution.param(typename std::discrete_distribution::param_type(probabilities.begin(), probabilities.end())); - probabilities = m_distribution.probabilities(); // get normalized probabilities - - for (std::size_t i = 0 ; i < probabilities.size() ; i++) { - m_indexes[i]->second = probabilities[i]; - } +public: + /** + * Type of elements of the distribution + */ + typedef T element_type; + + /** + * Constructor + * @param iBegin Associative container begin iterator + * @param iEnd Associative container end iterator + * @param g Random number generator + */ + template + DiscreteDistribution(InputIt iBegin, InputIt iEnd) + : m_generator(Generator(std::random_device()())) { + for (InputIt i = iBegin; i != iEnd; i++) { + std::pair::iterator, bool> r = + m_values.insert(*i); + if (r.second) { + m_indexes.push_back(r.first); + } else { + r.first->second += i->second; + } } - /** - * Constructor - */ - DiscreteDistribution(std::initializer_list > iList) - : DiscreteDistribution(iList.begin(), iList.end()) {} - - /** - * Returning a sample from the distribution - */ - virtual T sample() { - return m_indexes[m_distribution(m_generator)]->first; + if (m_values.empty()) { + m_indexes.push_back(m_values.insert(std::make_pair(T(), 1.0)).first); } - /** - * Get the list of (element, probability) values - */ - const Container& get_values() const { - return m_values; + std::vector probabilities; + std::for_each(m_indexes.begin(), m_indexes.end(), + [&](const auto &i) { probabilities.push_back(i->second); }); + m_distribution.param( + typename std::discrete_distribution::param_type( + probabilities.begin(), probabilities.end())); + probabilities = + m_distribution.probabilities(); // get normalized probabilities + + for (std::size_t i = 0; i < probabilities.size(); i++) { + m_indexes[i]->second = probabilities[i]; } - -private : - Container m_values; - std::vector::iterator> m_indexes; - Generator m_generator; - std::discrete_distribution m_distribution; + } + + /** + * Constructor + */ + DiscreteDistribution(std::initializer_list> iList) + : DiscreteDistribution(iList.begin(), iList.end()) {} + + /** + * Returning a sample from the distribution + */ + virtual T sample() { return m_indexes[m_distribution(m_generator)]->first; } + + /** + * Get the list of (element, probability) values + */ + const Container &get_values() const { return m_values; } + +private: + Container m_values; + std::vector::iterator> m_indexes; + Generator m_generator; + std::discrete_distribution m_distribution; }; - template class SingleValueDistribution : public DiscreteDistribution { -public : - /** - * Type of elements of the distribution - */ - typedef T element_type; - - /** - * Constructor - */ - SingleValueDistribution(const T& value) - : DiscreteDistribution({{value, 1.0}}), m_value(value) {} - - /** - * Returning a sample from the distribution - */ - virtual T sample() { - return m_value; - } - - /** - * Returning the value - */ - const T& get_value() const { - return m_value; - } - -private : - T m_value; +public: + /** + * Type of elements of the distribution + */ + typedef T element_type; + + /** + * Constructor + */ + SingleValueDistribution(const T &value) + : DiscreteDistribution({{value, 1.0}}), m_value(value) {} + + /** + * Returning a sample from the distribution + */ + virtual T sample() { return m_value; } + + /** + * Returning the value + */ + const T &get_value() const { return m_value; } + +private: + T m_value; }; +enum class TransitionType { REWARD, COST }; -enum class TransitionType { - REWARD, - COST -}; +template +class Value; -template class Value; +template class Value { +public: + Value(const T &value) : m_value(value) {} -template -class Value { -public : - Value(const T& value) : m_value(value) {} + inline virtual T reward() const { return m_value; } + inline virtual T cost() const { return -m_value; } - inline virtual T reward() const { return m_value; } - inline virtual T cost() const { return -m_value; } - -private : - T m_value; +private: + T m_value; }; -template -class Value { -public : - Value(const T& value) : m_value(value) {} +template class Value { +public: + Value(const T &value) : m_value(value) {} - inline virtual T reward() const { return -m_value; } - inline virtual T cost() const { return m_value; } + inline virtual T reward() const { return -m_value; } + inline virtual T cost() const { return m_value; } -private : - T m_value; +private: + T m_value; }; - -template +template struct EnvironmentOutcome { - EnvironmentOutcome(const Tobservation& observation, const Tvalue& value, bool termination, const Tinfo& info= Tinfo()) - : observation(observation), value(value), termination(termination), info(info) {} - - Tobservation observation; - Value value; - bool termination; - Tinfo info; + EnvironmentOutcome(const Tobservation &observation, const Tvalue &value, + bool termination, const Tinfo &info = Tinfo()) + : observation(observation), value(value), termination(termination), + info(info) {} + + Tobservation observation; + Value value; + bool termination; + Tinfo info; }; - -template +template struct TransitionOutcome { - TransitionOutcome(const Tstate& state, const Tvalue& value, bool termination, const Tinfo& info = Tinfo()) - : state(state), value(value), termination(termination), info(info) {} - - Tstate state; - Value value; - bool termination; - Tinfo info; + TransitionOutcome(const Tstate &state, const Tvalue &value, bool termination, + const Tinfo &info = Tinfo()) + : state(state), value(value), termination(termination), info(info) {} + + Tstate state; + Value value; + bool termination; + Tinfo info; }; - /** - * Deque class with maxlen feature like python deque; only for push_back and push_front! + * Deque class with maxlen feature like python deque; only for push_back and + * push_front! */ -template -class Memory : public std::deque { -public : - Memory(std::size_t maxlen = std::numeric_limits::max()) - : std::deque(), _maxlen(maxlen) {} - - template - Memory(InputIt iBegin, InputIt iEnd, std::size_t maxlen = std::numeric_limits::max()) - : std::deque(iBegin, iEnd), _maxlen(maxlen) { - if (this->size() > maxlen) { - std::deque::erase(std::deque::begin(), std::deque::begin() + std::deque::size() - maxlen); - } +template class Memory : public std::deque { +public: + Memory(std::size_t maxlen = std::numeric_limits::max()) + : std::deque(), _maxlen(maxlen) {} + + template + Memory(InputIt iBegin, InputIt iEnd, + std::size_t maxlen = std::numeric_limits::max()) + : std::deque(iBegin, iEnd), _maxlen(maxlen) { + if (this->size() > maxlen) { + std::deque::erase(std::deque::begin(), std::deque::begin() + + std::deque::size() - + maxlen); } + } - Memory(std::initializer_list iList, std::size_t maxlen = std::numeric_limits::max()) - : Memory(iList.begin(), iList.end(), maxlen) {} + Memory(std::initializer_list iList, + std::size_t maxlen = std::numeric_limits::max()) + : Memory(iList.begin(), iList.end(), maxlen) {} - Memory(const Memory& m) - : std::deque(static_cast&>(m)), _maxlen(m._maxlen) {} + Memory(const Memory &m) + : std::deque(static_cast &>(m)), + _maxlen(m._maxlen) {} - void operator=(const Memory& m) { - static_cast&>(*this) = static_cast&>(m); - _maxlen = m._maxlen; - } + void operator=(const Memory &m) { + static_cast &>(*this) = static_cast &>(m); + _maxlen = m._maxlen; + } - bool operator==(const Memory& m) { - return static_cast&>(*this) == static_cast&>(m) - && _maxlen == m._maxlen; - } + bool operator==(const Memory &m) { + return static_cast &>(*this) == + static_cast &>(m) && + _maxlen == m._maxlen; + } - std::size_t maxlen() const { - return _maxlen; - } + std::size_t maxlen() const { return _maxlen; } - void push_back(const T& value) { - std::deque::push_back(value); - if (this->size() > _maxlen) { - std::deque::pop_front(); - } + void push_back(const T &value) { + std::deque::push_back(value); + if (this->size() > _maxlen) { + std::deque::pop_front(); } + } - void push_back(T&& value) { - std::deque::push_back(value); - if (this->size() > _maxlen) { - std::deque::pop_front(); - } + void push_back(T &&value) { + std::deque::push_back(value); + if (this->size() > _maxlen) { + std::deque::pop_front(); } + } - void push_front(const T& value) { - std::deque::push_front(value); - if (this->size() > _maxlen) { - std::deque::pop_back(); - } + void push_front(const T &value) { + std::deque::push_front(value); + if (this->size() > _maxlen) { + std::deque::pop_back(); } + } - void push_front(T&& value) { - std::deque::push_front(value); - if (this->size() > _maxlen) { - std::deque::pop_back(); - } + void push_front(T &&value) { + std::deque::push_front(value); + if (this->size() > _maxlen) { + std::deque::pop_back(); } + } -protected : - std::size_t _maxlen; +protected: + std::size_t _maxlen; }; } // namespace skdecide diff --git a/cpp/src/hub/py_skdecide.cc b/cpp/src/hub/py_skdecide.cc index 56091b4521..decd2c2b4d 100644 --- a/cpp/src/hub/py_skdecide.cc +++ b/cpp/src/hub/py_skdecide.cc @@ -8,25 +8,25 @@ namespace py = pybind11; -void init_pyaostar(py::module& m); -void init_pyastar(py::module& m); -void init_pybfws(py::module& m); -void init_pyilaostar(py::module& m); -void init_pyiw(py::module& m); -void init_pylrtdp(py::module& m); -void init_pymartdp(py::module& m); -void init_pymcts(py::module& m); -void init_pyriw(py::module& m); +void init_pyaostar(py::module &m); +void init_pyastar(py::module &m); +void init_pybfws(py::module &m); +void init_pyilaostar(py::module &m); +void init_pyiw(py::module &m); +void init_pylrtdp(py::module &m); +void init_pymartdp(py::module &m); +void init_pymcts(py::module &m); +void init_pyriw(py::module &m); PYBIND11_MODULE(__skdecide_hub_cpp, m) { - skdecide::Globals::init(); - init_pyaostar(m); - init_pyastar(m); - init_pybfws(m); - init_pyilaostar(m); - init_pyiw(m); - init_pylrtdp(m); - init_pymartdp(m); - init_pymcts(m); - init_pyriw(m); + skdecide::Globals::init(); + init_pyaostar(m); + init_pyastar(m); + init_pybfws(m); + init_pyilaostar(m); + init_pyiw(m); + init_pylrtdp(m); + init_pymartdp(m); + init_pymcts(m); + init_pyriw(m); } diff --git a/cpp/src/hub/solver/aostar/aostar.hh b/cpp/src/hub/solver/aostar/aostar.hh index a02c5244f8..914f50194b 100644 --- a/cpp/src/hub/solver/aostar/aostar.hh +++ b/cpp/src/hub/solver/aostar/aostar.hh @@ -15,77 +15,76 @@ namespace skdecide { -template +template class AOStarSolver { -public : - typedef Tdomain Domain; - typedef typename Domain::State State; - typedef typename Domain::Action Action; - typedef typename Domain::Predicate Predicate; - typedef typename Domain::Value Value; - typedef Texecution_policy ExecutionPolicy; - - AOStarSolver(Domain& domain, - const std::function& goal_checker, - const std::function& heuristic, - double discount = 1.0, - std::size_t max_tip_expansions = 1, - bool detect_cycles = false, - bool debug_logs = false); - - // clears the solver (clears the search graph, thus preventing from reusing - // previous search results) - void clear(); - - // solves from state s using heuristic function h - void solve(const State& s); - bool is_solution_defined_for(const State& s) const; - const Action& get_best_action(const State& s) const; - const double& get_best_value(const State& s) const; - -private : - Domain& _domain; - std::function _goal_checker; - std::function _heuristic; - double _discount; - std::size_t _max_tip_expansions; - bool _detect_cycles; - bool _debug_logs; - ExecutionPolicy _execution_policy; - - struct ActionNode; - - struct StateNode { - State state; - std::list> actions; - ActionNode* best_action; - double best_value; - bool solved; - std::list parents; - - StateNode(const State& s); - - struct Key { - const State& operator()(const StateNode& sn) const; - }; +public: + typedef Tdomain Domain; + typedef typename Domain::State State; + typedef typename Domain::Action Action; + typedef typename Domain::Predicate Predicate; + typedef typename Domain::Value Value; + typedef Texecution_policy ExecutionPolicy; + + AOStarSolver( + Domain &domain, + const std::function &goal_checker, + const std::function &heuristic, + double discount = 1.0, std::size_t max_tip_expansions = 1, + bool detect_cycles = false, bool debug_logs = false); + + // clears the solver (clears the search graph, thus preventing from reusing + // previous search results) + void clear(); + + // solves from state s using heuristic function h + void solve(const State &s); + bool is_solution_defined_for(const State &s) const; + const Action &get_best_action(const State &s) const; + const double &get_best_value(const State &s) const; + +private: + Domain &_domain; + std::function _goal_checker; + std::function _heuristic; + double _discount; + std::size_t _max_tip_expansions; + bool _detect_cycles; + bool _debug_logs; + ExecutionPolicy _execution_policy; + + struct ActionNode; + + struct StateNode { + State state; + std::list> actions; + ActionNode *best_action; + double best_value; + bool solved; + std::list parents; + + StateNode(const State &s); + + struct Key { + const State &operator()(const StateNode &sn) const; }; + }; - struct ActionNode { - Action action; - std::list> outcomes; // next state nodes owned by _graph - double value; - StateNode* parent; + struct ActionNode { + Action action; + std::list> + outcomes; // next state nodes owned by _graph + double value; + StateNode *parent; - ActionNode(const Action& a); - }; + ActionNode(const Action &a); + }; - struct StateNodeCompare { - bool operator()(StateNode*& a, StateNode*& b) const; - }; + struct StateNodeCompare { + bool operator()(StateNode *&a, StateNode *&b) const; + }; - typedef typename SetTypeDeducer::Set Graph; - Graph _graph; + typedef typename SetTypeDeducer::Set Graph; + Graph _graph; }; } // namespace skdecide diff --git a/cpp/src/hub/solver/aostar/impl/aostar_impl.hh b/cpp/src/hub/solver/aostar/impl/aostar_impl.hh index f513aa1a8f..fdc86e5e59 100644 --- a/cpp/src/hub/solver/aostar/impl/aostar_impl.hh +++ b/cpp/src/hub/solver/aostar/impl/aostar_impl.hh @@ -15,234 +15,278 @@ namespace skdecide { // === AOStarSolver implementation === -#define SK_AOSTAR_SOLVER_TEMPLATE_DECL \ -template +#define SK_AOSTAR_SOLVER_TEMPLATE_DECL \ + template -#define SK_AOSTAR_SOLVER_CLASS \ -AOStarSolver +#define SK_AOSTAR_SOLVER_CLASS AOStarSolver SK_AOSTAR_SOLVER_TEMPLATE_DECL -SK_AOSTAR_SOLVER_CLASS::AOStarSolver(Domain& domain, - const std::function& goal_checker, - const std::function& heuristic, - double discount, - std::size_t max_tip_expansions, - bool detect_cycles, - bool debug_logs) -: _domain(domain), _goal_checker(goal_checker), _heuristic(heuristic), - _discount(discount), _max_tip_expansions(max_tip_expansions), - _detect_cycles(detect_cycles), _debug_logs(debug_logs) { - - if (debug_logs) { - Logger::check_level(logging::debug, "algorithm AO*"); - } - +SK_AOSTAR_SOLVER_CLASS::AOStarSolver( + Domain &domain, + const std::function &goal_checker, + const std::function &heuristic, + double discount, std::size_t max_tip_expansions, bool detect_cycles, + bool debug_logs) + : _domain(domain), _goal_checker(goal_checker), _heuristic(heuristic), + _discount(discount), _max_tip_expansions(max_tip_expansions), + _detect_cycles(detect_cycles), _debug_logs(debug_logs) { + + if (debug_logs) { + Logger::check_level(logging::debug, "algorithm AO*"); + } } - SK_AOSTAR_SOLVER_TEMPLATE_DECL -void SK_AOSTAR_SOLVER_CLASS::clear() { - _graph.clear(); -} - +void SK_AOSTAR_SOLVER_CLASS::clear() { _graph.clear(); } SK_AOSTAR_SOLVER_TEMPLATE_DECL -void SK_AOSTAR_SOLVER_CLASS::solve(const State& s) { - try { - Logger::info("Running " + ExecutionPolicy::print_type() + " AO* solver from state " + s.print()); - auto start_time = std::chrono::high_resolution_clock::now(); - - auto si = _graph.emplace(s); - if (si.first->solved || _goal_checker(_domain, s)) { // problem already solved from this state (was present in _graph and already solved) - return; - } - StateNode& root_node = const_cast(*(si.first)); // we won't change the real key (StateNode::state) so we are safe - std::priority_queue, StateNodeCompare> q; // contains only non-goal unsolved tip nodes - q.push(&root_node); - - while (!q.empty()) { - if (_debug_logs) { - Logger::debug("Current number of tip nodes: " + StringConverter::from(q.size())); - Logger::debug("Current number of explored nodes: " + StringConverter::from(_graph.size())); - } - std::size_t nb_expansions = std::min(q.size(), _max_tip_expansions); - std::unordered_set frontier; - for (std::size_t cnt = 0 ; cnt < nb_expansions ; cnt++) { - // Select best tip node of best partial graph - StateNode* best_tip_node = q.top(); - q.pop(); - frontier.insert(best_tip_node); - if (_debug_logs) Logger::debug("Current best tip node: " + best_tip_node->state.print()); - - // Expand best tip node - auto applicable_actions = _domain.get_applicable_actions(best_tip_node->state).get_elements(); - std::for_each(ExecutionPolicy::policy, applicable_actions.begin(), applicable_actions.end(), [this, &best_tip_node](auto a){ - if (_debug_logs) Logger::debug("Current expanded action: " + a.print() + ExecutionPolicy::print_thread()); - _execution_policy.protect([&best_tip_node, &a]{ - best_tip_node->actions.push_back(std::make_unique(a)); - }); - ActionNode& an = *(best_tip_node->actions.back()); - an.parent = best_tip_node; - auto next_states = _domain.get_next_state_distribution(best_tip_node->state, a).get_values(); - for (auto ns : next_states) { - if (_debug_logs) Logger::debug("Current next state expansion: " + ns.state().print() + ExecutionPolicy::print_thread()); - std::pair i; - _execution_policy.protect([this, &i, &ns]{ - i = _graph.emplace(ns.state()); - }); - StateNode& next_node = const_cast(*(i.first)); // we won't change the real key (StateNode::state) so we are safe - an.outcomes.push_back(std::make_tuple(ns.probability(), _domain.get_transition_value(best_tip_node->state, a, next_node.state).cost(), &next_node)); - _execution_policy.protect([&next_node, &an]{ - next_node.parents.push_back(&an); - }); - if (i.second) { // new node - if (_goal_checker(_domain, next_node.state)) { - if (_debug_logs) Logger::debug("Found goal state " + next_node.state.print() + ExecutionPolicy::print_thread()); - next_node.solved = true; - next_node.best_value = 0.0; - } else { - next_node.best_value = _heuristic(_domain, next_node.state).cost(); - if (_debug_logs) Logger::debug("New state " + next_node.state.print() + " with heuristic value " + - StringConverter::from(next_node.best_value) + ExecutionPolicy::print_thread()); - } - } - } - }); - } - - // Back-propagate value function from best tip node - std::unique_ptr> explored_states; // only for detecting cycles - if (_detect_cycles) explored_states = std::make_unique>(frontier); - while (!frontier.empty()) { - std::unordered_set new_frontier; - std::for_each(ExecutionPolicy::policy, frontier.begin(), frontier.end(), [this, &new_frontier](const auto& fs){ - // update Q-values and V-value - fs->best_value = std::numeric_limits::infinity(); - fs->best_action = nullptr; - for (const auto& a : fs->actions) { - a->value = 0.0; - for (const auto& ns : a->outcomes) { - a->value += std::get<0>(ns) * (std::get<1>(ns) + (_discount * std::get<2>(ns)->best_value)); - } - if (a->value < fs->best_value) { +void SK_AOSTAR_SOLVER_CLASS::solve(const State &s) { + try { + Logger::info("Running " + ExecutionPolicy::print_type() + + " AO* solver from state " + s.print()); + auto start_time = std::chrono::high_resolution_clock::now(); + + auto si = _graph.emplace(s); + if (si.first->solved || + _goal_checker(_domain, + s)) { // problem already solved from this state (was + // present in _graph and already solved) + return; + } + StateNode &root_node = const_cast( + *(si.first)); // we won't change the real key (StateNode::state) so we + // are safe + std::priority_queue, StateNodeCompare> + q; // contains only non-goal unsolved tip nodes + q.push(&root_node); + + while (!q.empty()) { + if (_debug_logs) { + Logger::debug("Current number of tip nodes: " + + StringConverter::from(q.size())); + Logger::debug("Current number of explored nodes: " + + StringConverter::from(_graph.size())); + } + std::size_t nb_expansions = std::min(q.size(), _max_tip_expansions); + std::unordered_set frontier; + for (std::size_t cnt = 0; cnt < nb_expansions; cnt++) { + // Select best tip node of best partial graph + StateNode *best_tip_node = q.top(); + q.pop(); + frontier.insert(best_tip_node); + if (_debug_logs) + Logger::debug("Current best tip node: " + + best_tip_node->state.print()); + + // Expand best tip node + auto applicable_actions = + _domain.get_applicable_actions(best_tip_node->state).get_elements(); + std::for_each( + ExecutionPolicy::policy, applicable_actions.begin(), + applicable_actions.end(), [this, &best_tip_node](auto a) { + if (_debug_logs) + Logger::debug("Current expanded action: " + a.print() + + ExecutionPolicy::print_thread()); + _execution_policy.protect([&best_tip_node, &a] { + best_tip_node->actions.push_back( + std::make_unique(a)); + }); + ActionNode &an = *(best_tip_node->actions.back()); + an.parent = best_tip_node; + auto next_states = + _domain.get_next_state_distribution(best_tip_node->state, a) + .get_values(); + for (auto ns : next_states) { + if (_debug_logs) + Logger::debug( + "Current next state expansion: " + ns.state().print() + + ExecutionPolicy::print_thread()); + std::pair i; + _execution_policy.protect( + [this, &i, &ns] { i = _graph.emplace(ns.state()); }); + StateNode &next_node = const_cast( + *(i.first)); // we won't change the real key + // (StateNode::state) so we are safe + an.outcomes.push_back(std::make_tuple( + ns.probability(), + _domain + .get_transition_value(best_tip_node->state, a, + next_node.state) + .cost(), + &next_node)); + _execution_policy.protect( + [&next_node, &an] { next_node.parents.push_back(&an); }); + if (i.second) { // new node + if (_goal_checker(_domain, next_node.state)) { + if (_debug_logs) + Logger::debug("Found goal state " + + next_node.state.print() + + ExecutionPolicy::print_thread()); + next_node.solved = true; + next_node.best_value = 0.0; + } else { + next_node.best_value = + _heuristic(_domain, next_node.state).cost(); + if (_debug_logs) + Logger::debug( + "New state " + next_node.state.print() + + " with heuristic value " + + StringConverter::from(next_node.best_value) + + ExecutionPolicy::print_thread()); + } + } + } + }); + } + + // Back-propagate value function from best tip node + std::unique_ptr> + explored_states; // only for detecting cycles + if (_detect_cycles) + explored_states = + std::make_unique>(frontier); + while (!frontier.empty()) { + std::unordered_set new_frontier; + std::for_each(ExecutionPolicy::policy, frontier.begin(), frontier.end(), + [this, &new_frontier](const auto &fs) { + // update Q-values and V-value + fs->best_value = + std::numeric_limits::infinity(); + fs->best_action = nullptr; + for (const auto &a : fs->actions) { + a->value = 0.0; + for (const auto &ns : a->outcomes) { + a->value += + std::get<0>(ns) * + (std::get<1>(ns) + + (_discount * std::get<2>(ns)->best_value)); + } + if (a->value < fs->best_value) { fs->best_value = a->value; fs->best_action = a.get(); + } + fs->best_value = std::min(fs->best_value, a->value); } - fs->best_value = std::min(fs->best_value, a->value); - } - // update solved field - fs->solved = true; - for (const auto& ns : fs->best_action->outcomes) { - fs->solved = fs->solved && std::get<2>(ns)->solved; - } - // update new frontier - _execution_policy.protect([&fs, &new_frontier]{ - for (const auto& ps : fs->parents) { - new_frontier.insert(ps->parent); - } - }); - }); - frontier = new_frontier; - if (_detect_cycles) { - for (const auto& ps : new_frontier) { - if (explored_states->find(ps) != explored_states->end()) { - throw std::logic_error("SKDECIDE exception: cycle detected in the MDP graph! [with state " + ps->state.print() + "]"); + // update solved field + fs->solved = true; + for (const auto &ns : fs->best_action->outcomes) { + fs->solved = fs->solved && std::get<2>(ns)->solved; } - explored_states->insert(ps); - } - } + // update new frontier + _execution_policy.protect([&fs, &new_frontier] { + for (const auto &ps : fs->parents) { + new_frontier.insert(ps->parent); + } + }); + }); + frontier = new_frontier; + if (_detect_cycles) { + for (const auto &ps : new_frontier) { + if (explored_states->find(ps) != explored_states->end()) { + throw std::logic_error("SKDECIDE exception: cycle detected in " + "the MDP graph! [with state " + + ps->state.print() + "]"); } - - // Recompute best partial graph - q = std::priority_queue, StateNodeCompare>(); - frontier.insert(&root_node); - while (!frontier.empty()) { - std::unordered_set new_frontier; - for (const auto& fs : frontier) { - if (!(fs->solved)) { - if (fs->best_action != nullptr) { - for (const auto& ns : fs->best_action->outcomes) { - new_frontier.insert(std::get<2>(ns)); - } - } else { // tip node - q.push(fs); - } - } - } - frontier = new_frontier; + explored_states->insert(ps); + } + } + } + + // Recompute best partial graph + q = std::priority_queue, + StateNodeCompare>(); + frontier.insert(&root_node); + while (!frontier.empty()) { + std::unordered_set new_frontier; + for (const auto &fs : frontier) { + if (!(fs->solved)) { + if (fs->best_action != nullptr) { + for (const auto &ns : fs->best_action->outcomes) { + new_frontier.insert(std::get<2>(ns)); + } + } else { // tip node + q.push(fs); } + } } - - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end_time - start_time).count(); - Logger::info("AO* finished to solve from state " + s.print() + " in " + StringConverter::from((double) duration / (double) 1e9) + " seconds."); - } catch (const std::exception& e) { - Logger::error("AO* failed solving from state " + s.print() + ". Reason: " + e.what()); - throw; + frontier = new_frontier; + } } -} + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time) + .count(); + Logger::info("AO* finished to solve from state " + s.print() + " in " + + StringConverter::from((double)duration / (double)1e9) + + " seconds."); + } catch (const std::exception &e) { + Logger::error("AO* failed solving from state " + s.print() + + ". Reason: " + e.what()); + throw; + } +} SK_AOSTAR_SOLVER_TEMPLATE_DECL -bool SK_AOSTAR_SOLVER_CLASS::is_solution_defined_for(const State& s) const { - auto si = _graph.find(s); - if ((si == _graph.end()) || (si->best_action == nullptr) || (si->solved == false)) { - return false; - } else { - return true; - } +bool SK_AOSTAR_SOLVER_CLASS::is_solution_defined_for(const State &s) const { + auto si = _graph.find(s); + if ((si == _graph.end()) || (si->best_action == nullptr) || + (si->solved == false)) { + return false; + } else { + return true; + } } - SK_AOSTAR_SOLVER_TEMPLATE_DECL -const typename SK_AOSTAR_SOLVER_CLASS::Action& -SK_AOSTAR_SOLVER_CLASS::get_best_action(const State& s) const { - auto si = _graph.find(s); - if ((si == _graph.end()) || (si->best_action == nullptr)) { - throw std::runtime_error("SKDECIDE exception: no best action found in state " + s.print()); - } - return si->best_action->action; +const typename SK_AOSTAR_SOLVER_CLASS::Action & +SK_AOSTAR_SOLVER_CLASS::get_best_action(const State &s) const { + auto si = _graph.find(s); + if ((si == _graph.end()) || (si->best_action == nullptr)) { + throw std::runtime_error( + "SKDECIDE exception: no best action found in state " + s.print()); + } + return si->best_action->action; } - SK_AOSTAR_SOLVER_TEMPLATE_DECL -const double& SK_AOSTAR_SOLVER_CLASS::get_best_value(const State& s) const { - auto si = _graph.find(s); - if (si == _graph.end()) { - throw std::runtime_error("SKDECIDE exception: no best action found in state " + s.print()); - } - return si->best_value; +const double &SK_AOSTAR_SOLVER_CLASS::get_best_value(const State &s) const { + auto si = _graph.find(s); + if (si == _graph.end()) { + throw std::runtime_error( + "SKDECIDE exception: no best action found in state " + s.print()); + } + return si->best_value; } - // === AOStarSolver::StateNode implementation === SK_AOSTAR_SOLVER_TEMPLATE_DECL -SK_AOSTAR_SOLVER_CLASS::StateNode::StateNode(const State& s) -: state(s), best_action(nullptr), - best_value(std::numeric_limits::infinity()), - solved(false) {} - +SK_AOSTAR_SOLVER_CLASS::StateNode::StateNode(const State &s) + : state(s), best_action(nullptr), + best_value(std::numeric_limits::infinity()), solved(false) {} SK_AOSTAR_SOLVER_TEMPLATE_DECL -const typename SK_AOSTAR_SOLVER_CLASS::State& -SK_AOSTAR_SOLVER_CLASS::StateNode::Key::operator()(const StateNode& sn) const { - return sn.state; +const typename SK_AOSTAR_SOLVER_CLASS::State & +SK_AOSTAR_SOLVER_CLASS::StateNode::Key::operator()(const StateNode &sn) const { + return sn.state; } - // === AOStarSolver::ActionNode implementation === SK_AOSTAR_SOLVER_TEMPLATE_DECL -SK_AOSTAR_SOLVER_CLASS::ActionNode::ActionNode(const Action& a) -: action(a), value(std::numeric_limits::infinity()), parent(nullptr) {} - +SK_AOSTAR_SOLVER_CLASS::ActionNode::ActionNode(const Action &a) + : action(a), value(std::numeric_limits::infinity()), + parent(nullptr) {} // === AOStarSolver::StateNodeCompare implementation === SK_AOSTAR_SOLVER_TEMPLATE_DECL -bool SK_AOSTAR_SOLVER_CLASS::StateNodeCompare::operator()(StateNode*& a, StateNode*& b) const { - return (a->best_value) > (b->best_value); // smallest element appears at the top of the priority_queue => cost optimization +bool SK_AOSTAR_SOLVER_CLASS::StateNodeCompare::operator()(StateNode *&a, + StateNode *&b) const { + return (a->best_value) > + (b->best_value); // smallest element appears at the top of the + // priority_queue => cost optimization } } // namespace skdecide diff --git a/cpp/src/hub/solver/aostar/py_aostar.cc b/cpp/src/hub/solver/aostar/py_aostar.cc index 94c2d0aecc..ed1e68fedf 100644 --- a/cpp/src/hub/solver/aostar/py_aostar.cc +++ b/cpp/src/hub/solver/aostar/py_aostar.cc @@ -7,30 +7,26 @@ #include "py_aostar.hh" -void init_pyaostar(py::module& m) { - py::class_ py_aostar_solver(m, "_AOStarSolver_"); - py_aostar_solver - .def(py::init&, - const std::function&, - double, - std::size_t, - bool, - bool, - bool>(), - py::arg("domain"), - py::arg("goal_checker"), - py::arg("heuristic"), - py::arg("discount")=1.0, - py::arg("max_tip_expansions")=1, - py::arg("detect_cycles")=false, - py::arg("parallel")=false, - py::arg("debug_logs")=false) - .def("close", &skdecide::PyAOStarSolver::close) - .def("clear", &skdecide::PyAOStarSolver::clear) - .def("solve", &skdecide::PyAOStarSolver::solve, py::arg("state")) - .def("is_solution_defined_for", &skdecide::PyAOStarSolver::is_solution_defined_for, py::arg("state")) - .def("get_next_action", &skdecide::PyAOStarSolver::get_next_action, py::arg("state")) - .def("get_utility", &skdecide::PyAOStarSolver::get_utility, py::arg("state")) - ; +void init_pyaostar(py::module &m) { + py::class_ py_aostar_solver(m, "_AOStarSolver_"); + py_aostar_solver + .def(py::init &, + const std::function &, + double, std::size_t, bool, bool, bool>(), + py::arg("domain"), py::arg("goal_checker"), py::arg("heuristic"), + py::arg("discount") = 1.0, py::arg("max_tip_expansions") = 1, + py::arg("detect_cycles") = false, py::arg("parallel") = false, + py::arg("debug_logs") = false) + .def("close", &skdecide::PyAOStarSolver::close) + .def("clear", &skdecide::PyAOStarSolver::clear) + .def("solve", &skdecide::PyAOStarSolver::solve, py::arg("state")) + .def("is_solution_defined_for", + &skdecide::PyAOStarSolver::is_solution_defined_for, py::arg("state")) + .def("get_next_action", &skdecide::PyAOStarSolver::get_next_action, + py::arg("state")) + .def("get_utility", &skdecide::PyAOStarSolver::get_utility, + py::arg("state")); } diff --git a/cpp/src/hub/solver/aostar/py_aostar.hh b/cpp/src/hub/solver/aostar/py_aostar.hh index 5231896a5a..e072fb837e 100644 --- a/cpp/src/hub/solver/aostar/py_aostar.hh +++ b/cpp/src/hub/solver/aostar/py_aostar.hh @@ -24,199 +24,207 @@ namespace skdecide { template using PyAOStarDomain = PythonDomainProxy; - class PyAOStarSolver { -private : - - class BaseImplementation { - public : - virtual ~BaseImplementation() {} - virtual void close() = 0; - virtual void clear() = 0; - virtual void solve(const py::object& s) = 0; - virtual py::bool_ is_solution_defined_for(const py::object& s) = 0; - virtual py::object get_next_action(const py::object& s) = 0; - virtual py::float_ get_utility(const py::object& s) = 0; - }; - - template - class Implementation : public BaseImplementation { - public : - Implementation(py::object& domain, - const std::function& goal_checker, - const std::function& heuristic, - double discount = 1.0, - std::size_t max_tip_expansions = 1, - bool detect_cycles = false, - bool debug_logs = false) +private: + class BaseImplementation { + public: + virtual ~BaseImplementation() {} + virtual void close() = 0; + virtual void clear() = 0; + virtual void solve(const py::object &s) = 0; + virtual py::bool_ is_solution_defined_for(const py::object &s) = 0; + virtual py::object get_next_action(const py::object &s) = 0; + virtual py::float_ get_utility(const py::object &s) = 0; + }; + + template + class Implementation : public BaseImplementation { + public: + Implementation( + py::object &domain, + const std::function + &goal_checker, + const std::function + &heuristic, + double discount = 1.0, std::size_t max_tip_expansions = 1, + bool detect_cycles = false, bool debug_logs = false) : _goal_checker(goal_checker), _heuristic(heuristic) { - check_domain(domain); - _domain = std::make_unique>(domain); - _solver = std::make_unique, Texecution>>( - *_domain, - [this](PyAOStarDomain& d, const typename PyAOStarDomain::State& s) -> typename PyAOStarDomain::Predicate { - try { - auto fgc = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _goal_checker(dd, ss); - }; - std::unique_ptr r = d.call(nullptr, fgc, s.pyobj()); - typename skdecide::GilControl::Acquire acquire; - bool rr = r->template cast(); - r.reset(); - return rr; - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling goal checker: ") + e.what()); - throw; - } - }, - [this](PyAOStarDomain& d, const typename PyAOStarDomain::State& s) -> typename PyAOStarDomain::Value { - try { - auto fh = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _heuristic(dd, ss); - }; - return typename PyAOStarDomain::Value(d.call(nullptr, fh, s.pyobj())); - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling heuristic estimator: ") + e.what()); - throw; - } - }, - discount, - max_tip_expansions, - detect_cycles, - debug_logs); - _stdout_redirect = std::make_unique(std::cout, - py::module::import("sys").attr("stdout")); - _stderr_redirect = std::make_unique(std::cerr, - py::module::import("sys").attr("stderr")); - } - - virtual ~Implementation() {} - - void check_domain(py::object& domain) { - if (!py::hasattr(domain, "get_applicable_actions")) { - throw std::invalid_argument("SKDECIDE exception: AO* algorithm needs python domain for implementing get_applicable_actions()"); - } - if (!py::hasattr(domain, "get_next_state_distribution")) { - throw std::invalid_argument("SKDECIDE exception: AO* algorithm needs python domain for implementing get_next_state_distribution()"); + check_domain(domain); + _domain = std::make_unique>(domain); + _solver = std::make_unique< + skdecide::AOStarSolver, Texecution>>( + *_domain, + [this](PyAOStarDomain & d, + const typename PyAOStarDomain::State &s) -> + typename PyAOStarDomain::Predicate { + try { + auto fgc = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _goal_checker(dd, ss); + }; + std::unique_ptr r = d.call(nullptr, fgc, s.pyobj()); + typename skdecide::GilControl::Acquire acquire; + bool rr = r->template cast(); + r.reset(); + return rr; + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling goal checker: ") + + e.what()); + throw; } - if (!py::hasattr(domain, "get_transition_value")) { - throw std::invalid_argument("SKDECIDE exception: AO* algorithm needs python domain for implementing get_transition_value()"); + }, + [this](PyAOStarDomain & d, + const typename PyAOStarDomain::State &s) -> + typename PyAOStarDomain::Value { + try { + auto fh = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _heuristic(dd, ss); + }; + return typename PyAOStarDomain::Value( + d.call(nullptr, fh, s.pyobj())); + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling heuristic estimator: ") + + e.what()); + throw; } - } - - virtual void close() { - _domain->close(); - } - - virtual void clear() { - _solver->clear(); - } - - virtual void solve(const py::object& s) { - typename skdecide::GilControl::Release release; - _solver->solve(s); - } - - virtual py::bool_ is_solution_defined_for(const py::object& s) { - return _solver->is_solution_defined_for(s); - } + }, + discount, max_tip_expansions, detect_cycles, debug_logs); + _stdout_redirect = std::make_unique( + std::cout, py::module::import("sys").attr("stdout")); + _stderr_redirect = std::make_unique( + std::cerr, py::module::import("sys").attr("stderr")); + } - virtual py::object get_next_action(const py::object& s) { - return _solver->get_best_action(s).pyobj(); - } + virtual ~Implementation() {} + + void check_domain(py::object &domain) { + if (!py::hasattr(domain, "get_applicable_actions")) { + throw std::invalid_argument( + "SKDECIDE exception: AO* algorithm needs python domain for " + "implementing get_applicable_actions()"); + } + if (!py::hasattr(domain, "get_next_state_distribution")) { + throw std::invalid_argument( + "SKDECIDE exception: AO* algorithm needs python domain for " + "implementing get_next_state_distribution()"); + } + if (!py::hasattr(domain, "get_transition_value")) { + throw std::invalid_argument( + "SKDECIDE exception: AO* algorithm needs python domain for " + "implementing get_transition_value()"); + } + } - virtual py::float_ get_utility(const py::object& s) { - return _solver->get_best_value(s); - } + virtual void close() { _domain->close(); } - private : - std::unique_ptr> _domain; - std::unique_ptr, Texecution>> _solver; + virtual void clear() { _solver->clear(); } - std::function _goal_checker; - std::function _heuristic; + virtual void solve(const py::object &s) { + typename skdecide::GilControl::Release release; + _solver->solve(s); + } - std::unique_ptr _stdout_redirect; - std::unique_ptr _stderr_redirect; - }; + virtual py::bool_ is_solution_defined_for(const py::object &s) { + return _solver->is_solution_defined_for(s); + } - struct ExecutionSelector { - bool _parallel; + virtual py::object get_next_action(const py::object &s) { + return _solver->get_best_action(s).pyobj(); + } - ExecutionSelector(bool parallel) : _parallel(parallel) {} + virtual py::float_ get_utility(const py::object &s) { + return _solver->get_best_value(s); + } - template - struct Select { - template - Select(ExecutionSelector& This, Args... args) { - if (This._parallel) { - Propagator::template PushType::Forward(args...); - } else { - Propagator::template PushType::Forward(args...); - } - } - }; + private: + std::unique_ptr> _domain; + std::unique_ptr< + skdecide::AOStarSolver, Texecution>> + _solver; + + std::function + _goal_checker; + std::function + _heuristic; + + std::unique_ptr _stdout_redirect; + std::unique_ptr _stderr_redirect; + }; + + struct ExecutionSelector { + bool _parallel; + + ExecutionSelector(bool parallel) : _parallel(parallel) {} + + template struct Select { + template + Select(ExecutionSelector &This, Args... args) { + if (This._parallel) { + Propagator::template PushType::Forward(args...); + } else { + Propagator::template PushType::Forward(args...); + } + } }; + }; - struct SolverInstantiator { - std::unique_ptr& _implementation; + struct SolverInstantiator { + std::unique_ptr &_implementation; - SolverInstantiator(std::unique_ptr& implementation) + SolverInstantiator(std::unique_ptr &implementation) : _implementation(implementation) {} - template - struct Instantiate { - template - Instantiate(SolverInstantiator& This, Args... args) { - This._implementation = std::make_unique>(args...); - } - }; + template struct Instantiate { + template + Instantiate(SolverInstantiator &This, Args... args) { + This._implementation = + std::make_unique>(args...); + } }; + }; - std::unique_ptr _implementation; + std::unique_ptr _implementation; -public : - PyAOStarSolver(py::object& domain, - const std::function& goal_checker, - const std::function& heuristic, - double discount = 1.0, - std::size_t max_tip_expansions = 1, - bool detect_cycles = false, - bool parallel = false, - bool debug_logs = false) { +public: + PyAOStarSolver( + py::object &domain, + const std::function + &goal_checker, + const std::function + &heuristic, + double discount = 1.0, std::size_t max_tip_expansions = 1, + bool detect_cycles = false, bool parallel = false, + bool debug_logs = false) { - TemplateInstantiator::select( - ExecutionSelector(parallel), - SolverInstantiator(_implementation)).instantiate( - domain, goal_checker, heuristic, discount, - max_tip_expansions, detect_cycles, debug_logs); + TemplateInstantiator::select(ExecutionSelector(parallel), + SolverInstantiator(_implementation)) + .instantiate(domain, goal_checker, heuristic, discount, + max_tip_expansions, detect_cycles, debug_logs); + } - } + void close() { _implementation->close(); } - void close() { - _implementation->close(); - } + void clear() { _implementation->clear(); } - void clear() { - _implementation->clear(); - } + void solve(const py::object &s) { _implementation->solve(s); } - void solve(const py::object& s) { - _implementation->solve(s); - } + py::bool_ is_solution_defined_for(const py::object &s) { + return _implementation->is_solution_defined_for(s); + } - py::bool_ is_solution_defined_for(const py::object& s) { - return _implementation->is_solution_defined_for(s); - } - - py::object get_next_action(const py::object& s) { - return _implementation->get_next_action(s); - } + py::object get_next_action(const py::object &s) { + return _implementation->get_next_action(s); + } - py::float_ get_utility(const py::object& s) { - return _implementation->get_utility(s); - } + py::float_ get_utility(const py::object &s) { + return _implementation->get_utility(s); + } }; } // namespace skdecide diff --git a/cpp/src/hub/solver/astar/astar.hh b/cpp/src/hub/solver/astar/astar.hh index 3b2265ae76..a3b7b3a3c5 100644 --- a/cpp/src/hub/solver/astar/astar.hh +++ b/cpp/src/hub/solver/astar/astar.hh @@ -19,61 +19,63 @@ namespace skdecide { -template +template class AStarSolver { -public : - typedef Tdomain Domain; - typedef typename Domain::State State; - typedef typename Domain::Action Action; - typedef typename Domain::Predicate Predicate; - typedef typename Domain::Value Value; - typedef Texecution_policy ExecutionPolicy; - - AStarSolver(Domain& domain, - const std::function& goal_checker, - const std::function& heuristic, - bool debug_logs = false); - - // clears the solver (clears the search graph, thus preventing from reusing - // previous search results) - void clear(); - - // solves from state s using heuristic function h - void solve(const State& s); - - bool is_solution_defined_for(const State& s) const; - const Action& get_best_action(const State& s) const; - const double& get_best_value(const State& s) const; - -private : - Domain& _domain; - std::function _goal_checker; - std::function _heuristic; - bool _debug_logs; - ExecutionPolicy _execution_policy; - - struct Node { - State state; - std::tuple best_parent; - double gscore; - double fscore; - Action* best_action; // computed only when constructing the solution path backward from the goal state - bool solved; // set to true if on the solution path constructed backward from the goal state - - Node(const State& s); - - struct Key { - const State& operator()(const Node& sn) const; - }; +public: + typedef Tdomain Domain; + typedef typename Domain::State State; + typedef typename Domain::Action Action; + typedef typename Domain::Predicate Predicate; + typedef typename Domain::Value Value; + typedef Texecution_policy ExecutionPolicy; + + AStarSolver( + Domain &domain, + const std::function &goal_checker, + const std::function &heuristic, + bool debug_logs = false); + + // clears the solver (clears the search graph, thus preventing from reusing + // previous search results) + void clear(); + + // solves from state s using heuristic function h + void solve(const State &s); + + bool is_solution_defined_for(const State &s) const; + const Action &get_best_action(const State &s) const; + const double &get_best_value(const State &s) const; + +private: + Domain &_domain; + std::function _goal_checker; + std::function _heuristic; + bool _debug_logs; + ExecutionPolicy _execution_policy; + + struct Node { + State state; + std::tuple best_parent; + double gscore; + double fscore; + Action *best_action; // computed only when constructing the solution path + // backward from the goal state + bool solved; // set to true if on the solution path constructed backward + // from the goal state + + Node(const State &s); + + struct Key { + const State &operator()(const Node &sn) const; }; + }; - struct NodeCompare { - bool operator()(Node*& a, Node*& b) const; - }; + struct NodeCompare { + bool operator()(Node *&a, Node *&b) const; + }; - typedef typename SetTypeDeducer::Set Graph; - Graph _graph; + typedef typename SetTypeDeducer::Set Graph; + Graph _graph; }; } // namespace skdecide diff --git a/cpp/src/hub/solver/astar/impl/astar_impl.hh b/cpp/src/hub/solver/astar/impl/astar_impl.hh index 76f66446f0..435b994474 100644 --- a/cpp/src/hub/solver/astar/impl/astar_impl.hh +++ b/cpp/src/hub/solver/astar/impl/astar_impl.hh @@ -15,192 +15,225 @@ namespace skdecide { // === AStarSolver implementation === -#define SK_ASTAR_SOLVER_TEMPLATE_DECL \ -template +#define SK_ASTAR_SOLVER_TEMPLATE_DECL \ + template -#define SK_ASTAR_SOLVER_CLASS \ -AStarSolver +#define SK_ASTAR_SOLVER_CLASS AStarSolver SK_ASTAR_SOLVER_TEMPLATE_DECL -SK_ASTAR_SOLVER_CLASS::AStarSolver(Domain& domain, - const std::function& goal_checker, - const std::function& heuristic, - bool debug_logs) -: _domain(domain), _goal_checker(goal_checker), _heuristic(heuristic), _debug_logs(debug_logs) { - - if (debug_logs) { - Logger::check_level(logging::debug, "algorithm A*"); - } - +SK_ASTAR_SOLVER_CLASS::AStarSolver( + Domain &domain, + const std::function &goal_checker, + const std::function &heuristic, + bool debug_logs) + : _domain(domain), _goal_checker(goal_checker), _heuristic(heuristic), + _debug_logs(debug_logs) { + + if (debug_logs) { + Logger::check_level(logging::debug, "algorithm A*"); + } } - SK_ASTAR_SOLVER_TEMPLATE_DECL -void SK_ASTAR_SOLVER_CLASS::clear() { - _graph.clear(); -} - +void SK_ASTAR_SOLVER_CLASS::clear() { _graph.clear(); } SK_ASTAR_SOLVER_TEMPLATE_DECL -void SK_ASTAR_SOLVER_CLASS::solve(const State& s) { - try { - Logger::info("Running " + ExecutionPolicy::print_type() + " A* solver from state " + s.print()); - auto start_time = std::chrono::high_resolution_clock::now(); - - // Create the root node containing the given state s - auto si = _graph.emplace(s); - if (si.first->solved || _goal_checker(_domain, s)) { // problem already solved from this state (was present in _graph and already solved) - return; +void SK_ASTAR_SOLVER_CLASS::solve(const State &s) { + try { + Logger::info("Running " + ExecutionPolicy::print_type() + + " A* solver from state " + s.print()); + auto start_time = std::chrono::high_resolution_clock::now(); + + // Create the root node containing the given state s + auto si = _graph.emplace(s); + if (si.first->solved || + _goal_checker(_domain, + s)) { // problem already solved from this state (was + // present in _graph and already solved) + return; + } + Node &root_node = const_cast(*( + si.first)); // we won't change the real key (Node::state) so we are safe + root_node.gscore = 0; + root_node.fscore = _heuristic(_domain, root_node.state).cost(); + + // Priority queue used to sort non-goal unsolved tip nodes by increasing + // cost-to-go values (so-called OPEN container) + std::priority_queue, NodeCompare> open_queue; + open_queue.push(&root_node); + + // Set of states for which the g-value is optimal (so-called CLOSED + // container) + std::unordered_set closed_set; + + while (!open_queue.empty()) { + auto best_tip_node = open_queue.top(); + open_queue.pop(); + + // Check that the best tip node has not already been closed before + // (since this implementation's open_queue does not check for element + // uniqueness, it can contain many copies of the same node pointer that + // could have been closed earlier) + if (closed_set.find(best_tip_node) != + closed_set + .end()) { // this implementation's open_queue can contain several + continue; + } + + if (_debug_logs) + Logger::debug( + "Current best tip node: " + best_tip_node->state.print() + + ", gscore=" + StringConverter::from(best_tip_node->gscore) + + ", fscore=" + StringConverter::from(best_tip_node->fscore)); + + if (_goal_checker(_domain, best_tip_node->state) || + best_tip_node->solved) { + if (_debug_logs) + Logger::debug("Closing a goal state: " + + best_tip_node->state.print()); + auto current_node = best_tip_node; + if (!(best_tip_node->solved)) { + current_node->fscore = 0; + } // goal state + + while (current_node != &root_node) { + Node *parent_node = std::get<0>(current_node->best_parent); + parent_node->best_action = &std::get<1>(current_node->best_parent); + parent_node->fscore = + std::get<2>(current_node->best_parent) + current_node->fscore; + parent_node->solved = true; + current_node = parent_node; } - Node& root_node = const_cast(*(si.first)); // we won't change the real key (Node::state) so we are safe - root_node.gscore = 0; - root_node.fscore = _heuristic(_domain, root_node.state).cost(); - - // Priority queue used to sort non-goal unsolved tip nodes by increasing cost-to-go values (so-called OPEN container) - std::priority_queue, NodeCompare> open_queue; - open_queue.push(&root_node); - - // Set of states for which the g-value is optimal (so-called CLOSED container) - std::unordered_set closed_set; - - while (!open_queue.empty()) { - auto best_tip_node = open_queue.top(); - open_queue.pop(); - - // Check that the best tip node has not already been closed before - // (since this implementation's open_queue does not check for element uniqueness, - // it can contain many copies of the same node pointer that could have been closed earlier) - if (closed_set.find(best_tip_node) != closed_set.end()) { // this implementation's open_queue can contain several - continue; - } - - if (_debug_logs) Logger::debug("Current best tip node: " + best_tip_node->state.print() + - ", gscore=" + StringConverter::from(best_tip_node->gscore) + - ", fscore=" + StringConverter::from(best_tip_node->fscore)); - - if (_goal_checker(_domain, best_tip_node->state) || best_tip_node->solved) { - if (_debug_logs) Logger::debug("Closing a goal state: " + best_tip_node->state.print()); - auto current_node = best_tip_node; - if (!(best_tip_node->solved)) { current_node->fscore = 0; } // goal state - - while (current_node != &root_node) { - Node* parent_node = std::get<0>(current_node->best_parent); - parent_node->best_action = &std::get<1>(current_node->best_parent); - parent_node->fscore = std::get<2>(current_node->best_parent) + current_node->fscore; - parent_node->solved = true; - current_node = parent_node; - } - - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end_time - start_time).count(); - Logger::info("A* finished to solve from state " + s.print() + " in " + StringConverter::from((double) duration / (double) 1e9) + " seconds."); - return; - } - closed_set.insert(best_tip_node); - - // Expand best tip node - auto applicable_actions = _domain.get_applicable_actions(best_tip_node->state).get_elements(); - std::for_each(ExecutionPolicy::policy, applicable_actions.begin(), applicable_actions.end(), [this, &best_tip_node, &open_queue, &closed_set](auto a){ - if (_debug_logs) Logger::debug("Current expanded action: " + a.print() + ExecutionPolicy::print_thread()); - auto next_state = _domain.get_next_state(best_tip_node->state, a); - if (_debug_logs) Logger::debug("Exploring next state " + next_state.print() + ExecutionPolicy::print_thread()); - std::pair i; - _execution_policy.protect([this, &i, &next_state]{ - i = _graph.emplace(next_state); + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time) + .count(); + Logger::info("A* finished to solve from state " + s.print() + " in " + + StringConverter::from((double)duration / (double)1e9) + + " seconds."); + return; + } + + closed_set.insert(best_tip_node); + + // Expand best tip node + auto applicable_actions = + _domain.get_applicable_actions(best_tip_node->state).get_elements(); + std::for_each( + ExecutionPolicy::policy, applicable_actions.begin(), + applicable_actions.end(), + [this, &best_tip_node, &open_queue, &closed_set](auto a) { + if (_debug_logs) + Logger::debug("Current expanded action: " + a.print() + + ExecutionPolicy::print_thread()); + auto next_state = _domain.get_next_state(best_tip_node->state, a); + if (_debug_logs) + Logger::debug("Exploring next state " + next_state.print() + + ExecutionPolicy::print_thread()); + std::pair i; + _execution_policy.protect( + [this, &i, &next_state] { i = _graph.emplace(next_state); }); + Node &neighbor = const_cast( + *(i.first)); // we won't change the real key (StateNode::state) + // so we are safe + + bool neighbor_closed = false; + _execution_policy.protect( + [&closed_set, &neighbor, &neighbor_closed] { + neighbor_closed = + (closed_set.find(&neighbor) != closed_set.end()); }); - Node& neighbor = const_cast(*(i.first)); // we won't change the real key (StateNode::state) so we are safe - - bool neighbor_closed = false; - _execution_policy.protect([&closed_set, &neighbor, &neighbor_closed]{ - neighbor_closed = (closed_set.find(&neighbor) != closed_set.end()); - }); - if (neighbor_closed) { - // Ignore the neighbor which is already evaluated - return; - } - - double transition_cost = _domain.get_transition_value(best_tip_node->state, a, neighbor.state).cost(); - double tentative_gscore = best_tip_node->gscore + transition_cost; - - if ((i.second) || (tentative_gscore < neighbor.gscore)) { - neighbor.gscore = tentative_gscore; - neighbor.fscore = tentative_gscore + _heuristic(_domain, neighbor.state).cost(); - neighbor.best_parent = std::make_tuple(best_tip_node, a, transition_cost); - _execution_policy.protect([&open_queue, &neighbor]{ - open_queue.push(&neighbor); - }); - if (_debug_logs) Logger::debug("Update neighbor node: " + neighbor.state.print() + - ", gscore=" + StringConverter::from(neighbor.gscore) + - ", fscore=" + StringConverter::from(neighbor.fscore) + - ExecutionPolicy::print_thread()); - } - }); - } + if (neighbor_closed) { + // Ignore the neighbor which is already evaluated + return; + } - Logger::info("A* could not find a solution from state " + s.print()); - } catch (const std::exception& e) { - Logger::error("A* failed solving from state " + s.print() + ". Reason: " + e.what()); - throw; + double transition_cost = + _domain + .get_transition_value(best_tip_node->state, a, + neighbor.state) + .cost(); + double tentative_gscore = best_tip_node->gscore + transition_cost; + + if ((i.second) || (tentative_gscore < neighbor.gscore)) { + neighbor.gscore = tentative_gscore; + neighbor.fscore = + tentative_gscore + _heuristic(_domain, neighbor.state).cost(); + neighbor.best_parent = + std::make_tuple(best_tip_node, a, transition_cost); + _execution_policy.protect( + [&open_queue, &neighbor] { open_queue.push(&neighbor); }); + if (_debug_logs) + Logger::debug( + "Update neighbor node: " + neighbor.state.print() + + ", gscore=" + StringConverter::from(neighbor.gscore) + + ", fscore=" + StringConverter::from(neighbor.fscore) + + ExecutionPolicy::print_thread()); + } + }); } -} + Logger::info("A* could not find a solution from state " + s.print()); + } catch (const std::exception &e) { + Logger::error("A* failed solving from state " + s.print() + + ". Reason: " + e.what()); + throw; + } +} SK_ASTAR_SOLVER_TEMPLATE_DECL -bool SK_ASTAR_SOLVER_CLASS::is_solution_defined_for(const State& s) const { - auto si = _graph.find(s); - if ((si == _graph.end()) || (si->best_action == nullptr) || (si->solved == false)) { - return false; - } else { - return true; - } +bool SK_ASTAR_SOLVER_CLASS::is_solution_defined_for(const State &s) const { + auto si = _graph.find(s); + if ((si == _graph.end()) || (si->best_action == nullptr) || + (si->solved == false)) { + return false; + } else { + return true; + } } - SK_ASTAR_SOLVER_TEMPLATE_DECL -const typename SK_ASTAR_SOLVER_CLASS::Action& -SK_ASTAR_SOLVER_CLASS::get_best_action(const State& s) const { - auto si = _graph.find(s); - if ((si == _graph.end()) || (si->best_action == nullptr)) { - throw std::runtime_error("SKDECIDE exception: no best action found in state " + s.print()); - } - return *(si->best_action); +const typename SK_ASTAR_SOLVER_CLASS::Action & +SK_ASTAR_SOLVER_CLASS::get_best_action(const State &s) const { + auto si = _graph.find(s); + if ((si == _graph.end()) || (si->best_action == nullptr)) { + throw std::runtime_error( + "SKDECIDE exception: no best action found in state " + s.print()); + } + return *(si->best_action); } - SK_ASTAR_SOLVER_TEMPLATE_DECL -const double& SK_ASTAR_SOLVER_CLASS::get_best_value(const State& s) const { - auto si = _graph.find(s); - if (si == _graph.end()) { - throw std::runtime_error("SKDECIDE exception: no best action found in state " + s.print()); - } - return si->fscore; +const double &SK_ASTAR_SOLVER_CLASS::get_best_value(const State &s) const { + auto si = _graph.find(s); + if (si == _graph.end()) { + throw std::runtime_error( + "SKDECIDE exception: no best action found in state " + s.print()); + } + return si->fscore; } - // === AStarSolver::StateNode implementation === SK_ASTAR_SOLVER_TEMPLATE_DECL -SK_ASTAR_SOLVER_CLASS::Node::Node(const State& s) -: state(s), - gscore(std::numeric_limits::infinity()), - fscore(std::numeric_limits::infinity()), - best_action(nullptr), - solved(false) {} - +SK_ASTAR_SOLVER_CLASS::Node::Node(const State &s) + : state(s), gscore(std::numeric_limits::infinity()), + fscore(std::numeric_limits::infinity()), best_action(nullptr), + solved(false) {} SK_ASTAR_SOLVER_TEMPLATE_DECL -const typename SK_ASTAR_SOLVER_CLASS::State& -SK_ASTAR_SOLVER_CLASS::Node::Key::operator()(const Node& sn) const { - return sn.state; +const typename SK_ASTAR_SOLVER_CLASS::State & +SK_ASTAR_SOLVER_CLASS::Node::Key::operator()(const Node &sn) const { + return sn.state; } - // === AStarSolver::NodeCompare implementation === SK_ASTAR_SOLVER_TEMPLATE_DECL -bool SK_ASTAR_SOLVER_CLASS::NodeCompare::operator()(Node*& a, Node*& b) const { - return (a->fscore) > (b->fscore); // smallest element appears at the top of the priority_queue => cost optimization +bool SK_ASTAR_SOLVER_CLASS::NodeCompare::operator()(Node *&a, Node *&b) const { + return (a->fscore) > (b->fscore); // smallest element appears at the top of + // the priority_queue => cost optimization } } // namespace skdecide diff --git a/cpp/src/hub/solver/astar/py_astar.cc b/cpp/src/hub/solver/astar/py_astar.cc index 2837654962..c4611ec60e 100644 --- a/cpp/src/hub/solver/astar/py_astar.cc +++ b/cpp/src/hub/solver/astar/py_astar.cc @@ -7,24 +7,24 @@ #include "py_astar.hh" -void init_pyastar(py::module& m) { - py::class_ py_astar_solver(m, "_AStarSolver_"); - py_astar_solver - .def(py::init&, - const std::function&, - bool, - bool>(), - py::arg("domain"), - py::arg("goal_checker"), - py::arg("heuristic"), - py::arg("parallel")=false, - py::arg("debug_logs")=false) - .def("close", &skdecide::PyAStarSolver::close) - .def("clear", &skdecide::PyAStarSolver::clear) - .def("solve", &skdecide::PyAStarSolver::solve, py::arg("state")) - .def("is_solution_defined_for", &skdecide::PyAStarSolver::is_solution_defined_for, py::arg("state")) - .def("get_next_action", &skdecide::PyAStarSolver::get_next_action, py::arg("state")) - .def("get_utility", &skdecide::PyAStarSolver::get_utility, py::arg("state")) - ; +void init_pyastar(py::module &m) { + py::class_ py_astar_solver(m, "_AStarSolver_"); + py_astar_solver + .def(py::init &, + const std::function &, + bool, bool>(), + py::arg("domain"), py::arg("goal_checker"), py::arg("heuristic"), + py::arg("parallel") = false, py::arg("debug_logs") = false) + .def("close", &skdecide::PyAStarSolver::close) + .def("clear", &skdecide::PyAStarSolver::clear) + .def("solve", &skdecide::PyAStarSolver::solve, py::arg("state")) + .def("is_solution_defined_for", + &skdecide::PyAStarSolver::is_solution_defined_for, py::arg("state")) + .def("get_next_action", &skdecide::PyAStarSolver::get_next_action, + py::arg("state")) + .def("get_utility", &skdecide::PyAStarSolver::get_utility, + py::arg("state")); } diff --git a/cpp/src/hub/solver/astar/py_astar.hh b/cpp/src/hub/solver/astar/py_astar.hh index 756736bb48..07e181362c 100644 --- a/cpp/src/hub/solver/astar/py_astar.hh +++ b/cpp/src/hub/solver/astar/py_astar.hh @@ -24,189 +24,203 @@ namespace skdecide { template using PyAStarDomain = PythonDomainProxy; - class PyAStarSolver { -private : - - class BaseImplementation { - public : - virtual ~BaseImplementation() {} - virtual void close() = 0; - virtual void clear() = 0; - virtual void solve(const py::object& s) = 0; - virtual py::bool_ is_solution_defined_for(const py::object& s) = 0; - virtual py::object get_next_action(const py::object& s) = 0; - virtual py::float_ get_utility(const py::object& s) = 0; - }; - - template - class Implementation : public BaseImplementation { - public : - Implementation(py::object& domain, - const std::function& goal_checker, - const std::function& heuristic, - bool debug_logs = false) +private: + class BaseImplementation { + public: + virtual ~BaseImplementation() {} + virtual void close() = 0; + virtual void clear() = 0; + virtual void solve(const py::object &s) = 0; + virtual py::bool_ is_solution_defined_for(const py::object &s) = 0; + virtual py::object get_next_action(const py::object &s) = 0; + virtual py::float_ get_utility(const py::object &s) = 0; + }; + + template + class Implementation : public BaseImplementation { + public: + Implementation( + py::object &domain, + const std::function + &goal_checker, + const std::function + &heuristic, + bool debug_logs = false) : _goal_checker(goal_checker), _heuristic(heuristic) { - check_domain(domain); - _domain = std::make_unique>(domain); - _solver = std::make_unique, Texecution>>( - *_domain, - [this](PyAStarDomain& d, const typename PyAStarDomain::State& s) -> typename PyAStarDomain::Predicate { - try { - auto fgc = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _goal_checker(dd, ss); - }; - std::unique_ptr r = d.call(nullptr, fgc, s.pyobj()); - typename skdecide::GilControl::Acquire acquire; - bool rr = r->template cast(); - r.reset(); - return rr; - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling goal checker: ") + e.what()); - throw; - } - }, - [this](PyAStarDomain& d, const typename PyAStarDomain::State& s) -> typename PyAStarDomain::Value { - try { - auto fh = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _heuristic(dd, ss); - }; - return typename PyAStarDomain::Value(d.call(nullptr, fh, s.pyobj())); - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling heuristic estimator: ") + e.what()); - throw; - } - }, - debug_logs); - _stdout_redirect = std::make_unique(std::cout, - py::module::import("sys").attr("stdout")); - _stderr_redirect = std::make_unique(std::cerr, - py::module::import("sys").attr("stderr")); - } - - virtual ~Implementation() {} - - void check_domain(py::object& domain) { - if (!py::hasattr(domain, "get_applicable_actions")) { - throw std::invalid_argument("SKDECIDE exception: A* algorithm needs python domain for implementing get_applicable_actions()"); - } - if (!py::hasattr(domain, "get_next_state")) { - throw std::invalid_argument("SKDECIDE exception: A* algorithm needs python domain for implementing get_next_state()"); + check_domain(domain); + _domain = std::make_unique>(domain); + _solver = std::make_unique< + skdecide::AStarSolver, Texecution>>( + *_domain, + [this](PyAStarDomain & d, + const typename PyAStarDomain::State &s) -> + typename PyAStarDomain::Predicate { + try { + auto fgc = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _goal_checker(dd, ss); + }; + std::unique_ptr r = d.call(nullptr, fgc, s.pyobj()); + typename skdecide::GilControl::Acquire acquire; + bool rr = r->template cast(); + r.reset(); + return rr; + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling goal checker: ") + + e.what()); + throw; } - if (!py::hasattr(domain, "get_transition_value")) { - throw std::invalid_argument("SKDECIDE exception: A* algorithm needs python domain for implementing get_transition_value()"); + }, + [this](PyAStarDomain & d, + const typename PyAStarDomain::State &s) -> + typename PyAStarDomain::Value { + try { + auto fh = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _heuristic(dd, ss); + }; + return typename PyAStarDomain::Value( + d.call(nullptr, fh, s.pyobj())); + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling heuristic estimator: ") + + e.what()); + throw; } - } - - virtual void close() { - _domain->close(); - } - - virtual void clear() { - _solver->clear(); - } - - virtual void solve(const py::object& s) { - typename skdecide::GilControl::Release release; - _solver->solve(s); - } - - virtual py::bool_ is_solution_defined_for(const py::object& s) { - return _solver->is_solution_defined_for(s); - } + }, + debug_logs); + _stdout_redirect = std::make_unique( + std::cout, py::module::import("sys").attr("stdout")); + _stderr_redirect = std::make_unique( + std::cerr, py::module::import("sys").attr("stderr")); + } - virtual py::object get_next_action(const py::object& s) { - return _solver->get_best_action(s).pyobj(); - } + virtual ~Implementation() {} + + void check_domain(py::object &domain) { + if (!py::hasattr(domain, "get_applicable_actions")) { + throw std::invalid_argument( + "SKDECIDE exception: A* algorithm needs python domain for " + "implementing get_applicable_actions()"); + } + if (!py::hasattr(domain, "get_next_state")) { + throw std::invalid_argument( + "SKDECIDE exception: A* algorithm needs python domain for " + "implementing get_next_state()"); + } + if (!py::hasattr(domain, "get_transition_value")) { + throw std::invalid_argument( + "SKDECIDE exception: A* algorithm needs python domain for " + "implementing get_transition_value()"); + } + } - virtual py::float_ get_utility(const py::object& s) { - return _solver->get_best_value(s); - } + virtual void close() { _domain->close(); } - private : - std::unique_ptr> _domain; - std::unique_ptr, Texecution>> _solver; + virtual void clear() { _solver->clear(); } - std::function _goal_checker; - std::function _heuristic; + virtual void solve(const py::object &s) { + typename skdecide::GilControl::Release release; + _solver->solve(s); + } - std::unique_ptr _stdout_redirect; - std::unique_ptr _stderr_redirect; - }; + virtual py::bool_ is_solution_defined_for(const py::object &s) { + return _solver->is_solution_defined_for(s); + } - struct ExecutionSelector { - bool _parallel; + virtual py::object get_next_action(const py::object &s) { + return _solver->get_best_action(s).pyobj(); + } - ExecutionSelector(bool parallel) : _parallel(parallel) {} + virtual py::float_ get_utility(const py::object &s) { + return _solver->get_best_value(s); + } - template - struct Select { - template - Select(ExecutionSelector& This, Args... args) { - if (This._parallel) { - Propagator::template PushType::Forward(args...); - } else { - Propagator::template PushType::Forward(args...); - } - } - }; + private: + std::unique_ptr> _domain; + std::unique_ptr< + skdecide::AStarSolver, Texecution>> + _solver; + + std::function + _goal_checker; + std::function + _heuristic; + + std::unique_ptr _stdout_redirect; + std::unique_ptr _stderr_redirect; + }; + + struct ExecutionSelector { + bool _parallel; + + ExecutionSelector(bool parallel) : _parallel(parallel) {} + + template struct Select { + template + Select(ExecutionSelector &This, Args... args) { + if (This._parallel) { + Propagator::template PushType::Forward(args...); + } else { + Propagator::template PushType::Forward(args...); + } + } }; + }; - struct SolverInstantiator { - std::unique_ptr& _implementation; + struct SolverInstantiator { + std::unique_ptr &_implementation; - SolverInstantiator(std::unique_ptr& implementation) + SolverInstantiator(std::unique_ptr &implementation) : _implementation(implementation) {} - template - struct Instantiate { - template - Instantiate(SolverInstantiator& This, Args... args) { - This._implementation = std::make_unique>(args...); - } - }; + template struct Instantiate { + template + Instantiate(SolverInstantiator &This, Args... args) { + This._implementation = + std::make_unique>(args...); + } }; + }; - std::unique_ptr _implementation; + std::unique_ptr _implementation; -public : - PyAStarSolver(py::object& domain, - const std::function& goal_checker, - const std::function& heuristic, - bool parallel = false, - bool debug_logs = false) { +public: + PyAStarSolver( + py::object &domain, + const std::function + &goal_checker, + const std::function + &heuristic, + bool parallel = false, bool debug_logs = false) { - TemplateInstantiator::select( - ExecutionSelector(parallel), - SolverInstantiator(_implementation)).instantiate( - domain, goal_checker, heuristic, debug_logs); + TemplateInstantiator::select(ExecutionSelector(parallel), + SolverInstantiator(_implementation)) + .instantiate(domain, goal_checker, heuristic, debug_logs); + } - } + void close() { _implementation->close(); } - void close() { - _implementation->close(); - } + void clear() { _implementation->clear(); } - void clear() { - _implementation->clear(); - } + void solve(const py::object &s) { _implementation->solve(s); } - void solve(const py::object& s) { - _implementation->solve(s); - } + py::bool_ is_solution_defined_for(const py::object &s) { + return _implementation->is_solution_defined_for(s); + } - py::bool_ is_solution_defined_for(const py::object& s) { - return _implementation->is_solution_defined_for(s); - } - - py::object get_next_action(const py::object& s) { - return _implementation->get_next_action(s); - } + py::object get_next_action(const py::object &s) { + return _implementation->get_next_action(s); + } - py::float_ get_utility(const py::object& s) { - return _implementation->get_utility(s); - } + py::float_ get_utility(const py::object &s) { + return _implementation->get_utility(s); + } }; } // namespace skdecide diff --git a/cpp/src/hub/solver/bfws/bfws.hh b/cpp/src/hub/solver/bfws/bfws.hh index 5ba52adf9c..cd509c04a6 100644 --- a/cpp/src/hub/solver/bfws/bfws.hh +++ b/cpp/src/hub/solver/bfws/bfws.hh @@ -5,7 +5,8 @@ #ifndef SKDECIDE_BFWS_HH #define SKDECIDE_BFWS_HH -// From paper: Best-First Width Search: Exploration and Exploitation in Classical Planning +// From paper: Best-First Width Search: Exploration and Exploitation in +// Classical Planning // by Nir Lipovetsky and Hector Geffner // in proceedings of AAAI 2017 @@ -24,112 +25,115 @@ namespace skdecide { /** Use default hasher provided with domain's states */ -template -struct DomainStateHash { - typedef typename Tdomain::State Key; +template struct DomainStateHash { + typedef typename Tdomain::State Key; - template - static const Key& get_key(const Tnode& n); + template static const Key &get_key(const Tnode &n); - struct Hash { - std::size_t operator()(const Key& k) const; - }; + struct Hash { + std::size_t operator()(const Key &k) const; + }; - struct Equal { - bool operator()(const Key& k1, const Key& k2) const; - }; + struct Equal { + bool operator()(const Key &k1, const Key &k2) const; + }; }; - /** Use state binary feature vector to hash states */ -template -struct StateFeatureHash { - typedef Tfeature_vector Key; +template struct StateFeatureHash { + typedef Tfeature_vector Key; - template - static const Key& get_key(const Tnode& n); + template static const Key &get_key(const Tnode &n); - struct Hash { - std::size_t operator()(const Key& k) const; - }; + struct Hash { + std::size_t operator()(const Key &k) const; + }; - struct Equal { - bool operator()(const Key& k1, const Key& k2) const; - }; + struct Equal { + bool operator()(const Key &k1, const Key &k2) const; + }; }; -template class Thashing_policy = DomainStateHash, typename Texecution_policy = SequentialExecution> class BFWSSolver { -public : - typedef Tdomain Domain; - typedef typename Domain::State State; - typedef typename Domain::Action Action; - typedef typename Domain::Value Value; - typedef Tfeature_vector FeatureVector; - typedef Thashing_policy HashingPolicy; - typedef Texecution_policy ExecutionPolicy; - - BFWSSolver(Domain& domain, - const std::function (Domain& d, const State& s)>& state_features, - const std::function& heuristic, - const std::function& termination_checker, - bool debug_logs = false); - - // clears the solver (clears the search graph, thus preventing from reusing - // previous search results) - void clear(); - - // solves from state s - void solve(const State& s); - - bool is_solution_defined_for(const State& s) const; - const Action& get_best_action(const State& s) const; - const double& get_best_value(const State& s) const; - -private : - Domain& _domain; - std::function (Domain& d, const State& s)> _state_features; - std::function _heuristic; - std::function _termination_checker; - bool _debug_logs; - ExecutionPolicy _execution_policy; - - typedef std::pair PairType; - typedef std::unordered_map>> PairMap; - - struct Node { - State state; - std::unique_ptr features; - std::tuple best_parent; - std::size_t novelty; - double heuristic; - double gscore; - double fscore; - Action* best_action; // computed only when constructing the solution path backward from the goal state - bool solved; // set to true if on the solution path constructed backward from the goal state - - Node(const State& s, Domain& d, - const std::function (Domain& d, const State& s)>& state_features); - - struct Key { - const typename HashingPolicy::Key& operator()(const Node& n) const; - }; +public: + typedef Tdomain Domain; + typedef typename Domain::State State; + typedef typename Domain::Action Action; + typedef typename Domain::Value Value; + typedef Tfeature_vector FeatureVector; + typedef Thashing_policy HashingPolicy; + typedef Texecution_policy ExecutionPolicy; + + BFWSSolver( + Domain &domain, + const std::function( + Domain &d, const State &s)> &state_features, + const std::function &heuristic, + const std::function &termination_checker, + bool debug_logs = false); + + // clears the solver (clears the search graph, thus preventing from reusing + // previous search results) + void clear(); + + // solves from state s + void solve(const State &s); + + bool is_solution_defined_for(const State &s) const; + const Action &get_best_action(const State &s) const; + const double &get_best_value(const State &s) const; + +private: + Domain &_domain; + std::function(Domain &d, const State &s)> + _state_features; + std::function _heuristic; + std::function _termination_checker; + bool _debug_logs; + ExecutionPolicy _execution_policy; + + typedef std::pair PairType; + typedef std::unordered_map< + double, std::unordered_set>> + PairMap; + + struct Node { + State state; + std::unique_ptr features; + std::tuple best_parent; + std::size_t novelty; + double heuristic; + double gscore; + double fscore; + Action *best_action; // computed only when constructing the solution path + // backward from the goal state + bool solved; // set to true if on the solution path constructed backward + // from the goal state + + Node(const State &s, Domain &d, + const std::function( + Domain &d, const State &s)> &state_features); + + struct Key { + const typename HashingPolicy::Key &operator()(const Node &n) const; }; + }; - // we only compute novelty of 1 for complexity reasons and assign all other novelties to +infty - // see paper "Best-First Width Search: Exploration and Exploitation in Classical Planning" by Lipovetsky and Geffner - std::size_t novelty(PairMap& heuristic_features_map, - const double& heuristic_value, Node& n) const; + // we only compute novelty of 1 for complexity reasons and assign all other + // novelties to +infty see paper "Best-First Width Search: Exploration and + // Exploitation in Classical Planning" by Lipovetsky and Geffner + std::size_t novelty(PairMap &heuristic_features_map, + const double &heuristic_value, Node &n) const; - struct NodeCompare { - bool operator()(Node*& a, Node*& b) const; - }; + struct NodeCompare { + bool operator()(Node *&a, Node *&b) const; + }; - typedef typename SetTypeDeducer::Set Graph; - Graph _graph; + typedef typename SetTypeDeducer::Set Graph; + Graph _graph; }; } // namespace skdecide diff --git a/cpp/src/hub/solver/bfws/impl/bfws_impl.hh b/cpp/src/hub/solver/bfws/impl/bfws_impl.hh index a39eeeb543..5c144eb5a8 100644 --- a/cpp/src/hub/solver/bfws/impl/bfws_impl.hh +++ b/cpp/src/hub/solver/bfws/impl/bfws_impl.hh @@ -14,299 +14,341 @@ namespace skdecide { // === DomainStateHash implementation === -#define SK_BFWS_DOMAIN_STATE_HASH_TEMPLATE_DECL \ -template +#define SK_BFWS_DOMAIN_STATE_HASH_TEMPLATE_DECL \ + template -#define SK_BFWS_DOMAIN_STATE_HASH_CLASS \ -DomainStateHash +#define SK_BFWS_DOMAIN_STATE_HASH_CLASS \ + DomainStateHash SK_BFWS_DOMAIN_STATE_HASH_TEMPLATE_DECL template -const typename SK_BFWS_DOMAIN_STATE_HASH_CLASS::Key& -SK_BFWS_DOMAIN_STATE_HASH_CLASS::get_key(const Tnode& n) { - return n.state; +const typename SK_BFWS_DOMAIN_STATE_HASH_CLASS::Key & +SK_BFWS_DOMAIN_STATE_HASH_CLASS::get_key(const Tnode &n) { + return n.state; } SK_BFWS_DOMAIN_STATE_HASH_TEMPLATE_DECL -std::size_t SK_BFWS_DOMAIN_STATE_HASH_CLASS::Hash::operator()(const Key& k) const { - return typename Tdomain::State::Hash()(k); +std::size_t +SK_BFWS_DOMAIN_STATE_HASH_CLASS::Hash::operator()(const Key &k) const { + return typename Tdomain::State::Hash()(k); } SK_BFWS_DOMAIN_STATE_HASH_TEMPLATE_DECL -bool SK_BFWS_DOMAIN_STATE_HASH_CLASS::Equal::operator()(const Key& k1, const Key& k2) const { - return typename Tdomain::State::Equal()(k1, k2); +bool SK_BFWS_DOMAIN_STATE_HASH_CLASS::Equal::operator()(const Key &k1, + const Key &k2) const { + return typename Tdomain::State::Equal()(k1, k2); } // === StateFeatureHash implementation === -#define SK_BFWS_STATE_FEATURE_HASH_TEMPLATE_DECL \ -template +#define SK_BFWS_STATE_FEATURE_HASH_TEMPLATE_DECL \ + template -#define SK_BFWS_STATE_FEATURE_HASH_CLASS \ -StateFeatureHash +#define SK_BFWS_STATE_FEATURE_HASH_CLASS \ + StateFeatureHash SK_BFWS_STATE_FEATURE_HASH_TEMPLATE_DECL template -const typename SK_BFWS_STATE_FEATURE_HASH_CLASS::Key& -SK_BFWS_STATE_FEATURE_HASH_CLASS::get_key(const Tnode& n) { - return *n.features; +const typename SK_BFWS_STATE_FEATURE_HASH_CLASS::Key & +SK_BFWS_STATE_FEATURE_HASH_CLASS::get_key(const Tnode &n) { + return *n.features; } SK_BFWS_STATE_FEATURE_HASH_TEMPLATE_DECL -std::size_t SK_BFWS_STATE_FEATURE_HASH_CLASS::Hash::operator()(const Key& k) const { - std::size_t seed = 0; - for (std::size_t i = 0 ; i < k.size() ; i++) { - boost::hash_combine(seed, k[i]); - } - return seed; +std::size_t +SK_BFWS_STATE_FEATURE_HASH_CLASS::Hash::operator()(const Key &k) const { + std::size_t seed = 0; + for (std::size_t i = 0; i < k.size(); i++) { + boost::hash_combine(seed, k[i]); + } + return seed; } SK_BFWS_STATE_FEATURE_HASH_TEMPLATE_DECL -bool SK_BFWS_STATE_FEATURE_HASH_CLASS::Equal::operator()(const Key& k1, const Key& k2) const { - std::size_t size = k1.size(); - if (size != k2.size()) { - return false; - } - for (std::size_t i = 0 ; i < size ; i++) { - if (!(k1[i] == k2[i])) { - return false; - } +bool SK_BFWS_STATE_FEATURE_HASH_CLASS::Equal::operator()(const Key &k1, + const Key &k2) const { + std::size_t size = k1.size(); + if (size != k2.size()) { + return false; + } + for (std::size_t i = 0; i < size; i++) { + if (!(k1[i] == k2[i])) { + return false; } - return true; + } + return true; } - // === BFWSSolver implementation === -#define SK_BFWS_SOLVER_TEMPLATE_DECL \ -template class Thashing_policy, \ - typename Texecution_policy> +#define SK_BFWS_SOLVER_TEMPLATE_DECL \ + template class Thashing_policy, \ + typename Texecution_policy> -#define SK_BFWS_SOLVER_CLASS \ -BFWSSolver +#define SK_BFWS_SOLVER_CLASS \ + BFWSSolver SK_BFWS_SOLVER_TEMPLATE_DECL -SK_BFWS_SOLVER_CLASS::BFWSSolver(Domain& domain, - const std::function (Domain& d, const State& s)>& state_features, - const std::function& heuristic, - const std::function& termination_checker, - bool debug_logs) -: _domain(domain), _state_features(state_features), - _heuristic(heuristic), _termination_checker(termination_checker), - _debug_logs(debug_logs) { - - if (debug_logs) { - Logger::check_level(logging::debug, "algorithm BFWS"); - } - +SK_BFWS_SOLVER_CLASS::BFWSSolver( + Domain &domain, + const std::function( + Domain &d, const State &s)> &state_features, + const std::function &heuristic, + const std::function &termination_checker, + bool debug_logs) + : _domain(domain), _state_features(state_features), _heuristic(heuristic), + _termination_checker(termination_checker), _debug_logs(debug_logs) { + + if (debug_logs) { + Logger::check_level(logging::debug, "algorithm BFWS"); + } } - SK_BFWS_SOLVER_TEMPLATE_DECL -void SK_BFWS_SOLVER_CLASS::clear() { - _graph.clear(); -} - +void SK_BFWS_SOLVER_CLASS::clear() { _graph.clear(); } SK_BFWS_SOLVER_TEMPLATE_DECL -void SK_BFWS_SOLVER_CLASS::solve(const State& s) { - try { - Logger::info("Running " + ExecutionPolicy::print_type() + " BFWS solver from state " + s.print()); - auto start_time = std::chrono::high_resolution_clock::now(); - - // Map from heuristic values to set of state features with that given heuristic value - // whose value has changed at least once since the beginning of the search - // (stored by their index and value) - PairMap heuristic_features_map; - - // Create the root node containing the given state s - auto si = _graph.emplace(Node(s, _domain, _state_features)); - if (si.first->solved || _termination_checker(_domain, s)) { // problem already solved from this state (was present in _graph and already solved) - return; +void SK_BFWS_SOLVER_CLASS::solve(const State &s) { + try { + Logger::info("Running " + ExecutionPolicy::print_type() + + " BFWS solver from state " + s.print()); + auto start_time = std::chrono::high_resolution_clock::now(); + + // Map from heuristic values to set of state features with that given + // heuristic value whose value has changed at least once since the beginning + // of the search (stored by their index and value) + PairMap heuristic_features_map; + + // Create the root node containing the given state s + auto si = _graph.emplace(Node(s, _domain, _state_features)); + if (si.first->solved || + _termination_checker(_domain, + s)) { // problem already solved from this state + // (was present in _graph and already solved) + return; + } + Node &root_node = const_cast(*( + si.first)); // we won't change the real key (Node::state) so we are safe + root_node.gscore = 0; + root_node.heuristic = _heuristic(_domain, root_node.state).cost(); + root_node.novelty = + novelty(heuristic_features_map, root_node.heuristic, root_node); + + // Priority queue used to sort non-goal unsolved tip nodes by increasing + // cost-to-go values (so-called OPEN container) + std::priority_queue, NodeCompare> open_queue; + open_queue.push(&root_node); + + // Set of states that have already been explored + std::unordered_set closed_set; + + while (!open_queue.empty()) { + auto best_tip_node = open_queue.top(); + open_queue.pop(); + + // Check that the best tip node has not already been closed before + // (since this implementation's open_queue does not check for element + // uniqueness, it can contain many copies of the same node pointer that + // could have been closed earlier) + if (closed_set.find(best_tip_node) != + closed_set + .end()) { // this implementation's open_queue can contain several + continue; + } + + if (_debug_logs) + Logger::debug("Current best tip node (h=" + + StringConverter::from(best_tip_node->heuristic) + + ", n=" + StringConverter::from(best_tip_node->novelty) + + "): " + best_tip_node->state.print()); + + if (_termination_checker(_domain, best_tip_node->state) || + best_tip_node->solved) { + if (_debug_logs) + Logger::debug("Found a terminal state: " + + best_tip_node->state.print()); + auto current_node = best_tip_node; + if (!(best_tip_node->solved)) { + current_node->fscore = 0; + } // goal state + + while (current_node != &root_node) { + Node *parent_node = std::get<0>(current_node->best_parent); + parent_node->best_action = &std::get<1>(current_node->best_parent); + parent_node->fscore = + std::get<2>(current_node->best_parent) + current_node->fscore; + parent_node->solved = true; + current_node = parent_node; } - Node& root_node = const_cast(*(si.first)); // we won't change the real key (Node::state) so we are safe - root_node.gscore = 0; - root_node.heuristic = _heuristic(_domain, root_node.state).cost(); - root_node.novelty = novelty(heuristic_features_map, root_node.heuristic, root_node); - - // Priority queue used to sort non-goal unsolved tip nodes by increasing cost-to-go values (so-called OPEN container) - std::priority_queue, NodeCompare> open_queue; - open_queue.push(&root_node); - - // Set of states that have already been explored - std::unordered_set closed_set; - - while (!open_queue.empty()) { - auto best_tip_node = open_queue.top(); - open_queue.pop(); - - // Check that the best tip node has not already been closed before - // (since this implementation's open_queue does not check for element uniqueness, - // it can contain many copies of the same node pointer that could have been closed earlier) - if (closed_set.find(best_tip_node) != closed_set.end()) { // this implementation's open_queue can contain several - continue; + + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time) + .count(); + Logger::info("BFWS finished to solve from state " + s.print() + " in " + + StringConverter::from((double)duration / (double)1e9) + + " seconds."); + return; + } + + closed_set.insert(best_tip_node); + + // Expand best tip node + auto applicable_actions = + _domain.get_applicable_actions(best_tip_node->state).get_elements(); + std::for_each( + ExecutionPolicy::policy, applicable_actions.begin(), + applicable_actions.end(), + [this, &best_tip_node, &open_queue, &closed_set, + &heuristic_features_map](auto a) { + if (_debug_logs) + Logger::debug("Current expanded action: " + a.print() + + ExecutionPolicy::print_thread()); + auto next_state = _domain.get_next_state(best_tip_node->state, a); + if (_debug_logs) + Logger::debug("Exploring next state " + next_state.print() + + ExecutionPolicy::print_thread()); + std::pair i; + _execution_policy.protect([this, &i, &next_state] { + i = _graph.emplace(Node(next_state, _domain, _state_features)); + }); + Node &neighbor = const_cast( + *(i.first)); // we won't change the real key (StateNode::state) + // so we are safe + + if (closed_set.find(&neighbor) != closed_set.end()) { + // Ignore the neighbor which is already evaluated + return; } - if (_debug_logs) Logger::debug("Current best tip node (h=" + StringConverter::from(best_tip_node->heuristic) + - ", n=" + StringConverter::from(best_tip_node->novelty) + - "): " + best_tip_node->state.print()); - - if (_termination_checker(_domain, best_tip_node->state) || best_tip_node->solved) { - if (_debug_logs) Logger::debug("Found a terminal state: " + best_tip_node->state.print()); - auto current_node = best_tip_node; - if (!(best_tip_node->solved)) { current_node->fscore = 0; } // goal state - - while (current_node != &root_node) { - Node* parent_node = std::get<0>(current_node->best_parent); - parent_node->best_action = &std::get<1>(current_node->best_parent); - parent_node->fscore = std::get<2>(current_node->best_parent) + current_node->fscore; - parent_node->solved = true; - current_node = parent_node; - } - - auto end_time = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end_time - start_time).count(); - Logger::info("BFWS finished to solve from state " + s.print() + " in " + StringConverter::from((double) duration / (double) 1e9) + " seconds."); - return; + double transition_cost = + _domain + .get_transition_value(best_tip_node->state, a, + neighbor.state) + .cost(); + double tentative_gscore = best_tip_node->gscore + transition_cost; + + if ((i.second) || (tentative_gscore < neighbor.gscore)) { + if (_debug_logs) + Logger::debug("New gscore: " + + StringConverter::from(best_tip_node->gscore) + + "+" + StringConverter::from(transition_cost) + + "=" + StringConverter::from(tentative_gscore) + + ExecutionPolicy::print_thread()); + neighbor.gscore = tentative_gscore; + neighbor.best_parent = + std::make_tuple(best_tip_node, a, transition_cost); } - closed_set.insert(best_tip_node); - - // Expand best tip node - auto applicable_actions = _domain.get_applicable_actions(best_tip_node->state).get_elements(); - std::for_each(ExecutionPolicy::policy, applicable_actions.begin(), applicable_actions.end(), [this, &best_tip_node, &open_queue, &closed_set, &heuristic_features_map](auto a) { - if (_debug_logs) Logger::debug("Current expanded action: " + a.print() + ExecutionPolicy::print_thread()); - auto next_state = _domain.get_next_state(best_tip_node->state, a); - if (_debug_logs) Logger::debug("Exploring next state " + next_state.print() + ExecutionPolicy::print_thread()); - std::pair i; - _execution_policy.protect([this, &i, &next_state]{ - i = _graph.emplace(Node(next_state, _domain, _state_features)); - }); - Node& neighbor = const_cast(*(i.first)); // we won't change the real key (StateNode::state) so we are safe - - if (closed_set.find(&neighbor) != closed_set.end()) { - // Ignore the neighbor which is already evaluated - return; - } - - double transition_cost = _domain.get_transition_value(best_tip_node->state, a, neighbor.state).cost(); - double tentative_gscore = best_tip_node->gscore + transition_cost; - - if ((i.second) || (tentative_gscore < neighbor.gscore)) { - if (_debug_logs) Logger::debug("New gscore: " + - StringConverter::from(best_tip_node->gscore) + "+" + - StringConverter::from(transition_cost) + "=" + - StringConverter::from(tentative_gscore) + - ExecutionPolicy::print_thread()); - neighbor.gscore = tentative_gscore; - neighbor.best_parent = std::make_tuple(best_tip_node, a, transition_cost); - } - - neighbor.heuristic = _heuristic(_domain, neighbor.state).cost(); - if (_debug_logs) Logger::debug("Heuristic: " + StringConverter::from(neighbor.heuristic) + - ExecutionPolicy::print_thread()); - _execution_policy.protect([this, &heuristic_features_map, &open_queue, &neighbor]{ - neighbor.novelty = this->novelty(heuristic_features_map, neighbor.heuristic, neighbor); - open_queue.push(&neighbor); - if (_debug_logs) Logger::debug("Novelty: " + StringConverter::from(neighbor.novelty) + - ExecutionPolicy::print_thread()); - }); + neighbor.heuristic = _heuristic(_domain, neighbor.state).cost(); + if (_debug_logs) + Logger::debug( + "Heuristic: " + StringConverter::from(neighbor.heuristic) + + ExecutionPolicy::print_thread()); + _execution_policy.protect([this, &heuristic_features_map, + &open_queue, &neighbor] { + neighbor.novelty = this->novelty(heuristic_features_map, + neighbor.heuristic, neighbor); + open_queue.push(&neighbor); + if (_debug_logs) + Logger::debug( + "Novelty: " + StringConverter::from(neighbor.novelty) + + ExecutionPolicy::print_thread()); }); - } - - Logger::info("BFWS could not find a solution from state " + s.print()); - } catch (const std::exception& e) { - Logger::error("BFWS failed solving from state " + s.print() + ". Reason: " + e.what()); - throw; + }); } -} + Logger::info("BFWS could not find a solution from state " + s.print()); + } catch (const std::exception &e) { + Logger::error("BFWS failed solving from state " + s.print() + + ". Reason: " + e.what()); + throw; + } +} SK_BFWS_SOLVER_TEMPLATE_DECL -bool SK_BFWS_SOLVER_CLASS::is_solution_defined_for(const State& s) const { - auto si = _graph.find(Node(s, _domain, _state_features)); - if ((si == _graph.end()) || (si->best_action == nullptr) || (si->solved == false)) { - return false; - } else { - return true; - } +bool SK_BFWS_SOLVER_CLASS::is_solution_defined_for(const State &s) const { + auto si = _graph.find(Node(s, _domain, _state_features)); + if ((si == _graph.end()) || (si->best_action == nullptr) || + (si->solved == false)) { + return false; + } else { + return true; + } } - SK_BFWS_SOLVER_TEMPLATE_DECL -const typename SK_BFWS_SOLVER_CLASS::Action& -SK_BFWS_SOLVER_CLASS::get_best_action(const State& s) const { - auto si = _graph.find(Node(s, _domain, _state_features)); - if ((si == _graph.end()) || (si->best_action == nullptr)) { - throw std::runtime_error("SKDECIDE exception: no best action found in state " + s.print()); - } - return *(si->best_action); +const typename SK_BFWS_SOLVER_CLASS::Action & +SK_BFWS_SOLVER_CLASS::get_best_action(const State &s) const { + auto si = _graph.find(Node(s, _domain, _state_features)); + if ((si == _graph.end()) || (si->best_action == nullptr)) { + throw std::runtime_error( + "SKDECIDE exception: no best action found in state " + s.print()); + } + return *(si->best_action); } - SK_BFWS_SOLVER_TEMPLATE_DECL -const double& SK_BFWS_SOLVER_CLASS::get_best_value(const State& s) const { - auto si = _graph.find(Node(s, _domain, _state_features)); - if (si == _graph.end()) { - throw std::runtime_error("SKDECIDE exception: no best action found in state " + s.print()); - } - return si->fscore; +const double &SK_BFWS_SOLVER_CLASS::get_best_value(const State &s) const { + auto si = _graph.find(Node(s, _domain, _state_features)); + if (si == _graph.end()) { + throw std::runtime_error( + "SKDECIDE exception: no best action found in state " + s.print()); + } + return si->fscore; } - SK_BFWS_SOLVER_TEMPLATE_DECL -std::size_t SK_BFWS_SOLVER_CLASS::novelty(PairMap& heuristic_features_map, - const double& heuristic_value, - Node& n) const { - auto r = heuristic_features_map.emplace(heuristic_value, std::unordered_set>()); - std::unordered_set>& features = r.first->second; - std::size_t nov = 0; - const FeatureVector& state_features = *n.features; - for (std::size_t i = 0 ; i < state_features.size() ; i++) { - nov += (std::size_t) features.insert(std::make_pair(i, state_features[i])).second; - } - if (r.second) { - nov = 0; - } else if (nov == 0) { - nov = n.features->size() + 1; - } - return nov; +std::size_t SK_BFWS_SOLVER_CLASS::novelty(PairMap &heuristic_features_map, + const double &heuristic_value, + Node &n) const { + auto r = heuristic_features_map.emplace( + heuristic_value, std::unordered_set>()); + std::unordered_set> &features = + r.first->second; + std::size_t nov = 0; + const FeatureVector &state_features = *n.features; + for (std::size_t i = 0; i < state_features.size(); i++) { + nov += (std::size_t)features.insert(std::make_pair(i, state_features[i])) + .second; + } + if (r.second) { + nov = 0; + } else if (nov == 0) { + nov = n.features->size() + 1; + } + return nov; } - // === BFWSSolver::Node implementation === SK_BFWS_SOLVER_TEMPLATE_DECL -SK_BFWS_SOLVER_CLASS::Node::Node(const State& s, - Domain& d, - const std::function (Domain& d, const State& s)>& state_features) -: state(s), - novelty(std::numeric_limits::max()), - gscore(std::numeric_limits::infinity()), - fscore(std::numeric_limits::infinity()), - best_action(nullptr), - solved(false) { - features = state_features(d, s); +SK_BFWS_SOLVER_CLASS::Node::Node( + const State &s, Domain &d, + const std::function( + Domain &d, const State &s)> &state_features) + : state(s), novelty(std::numeric_limits::max()), + gscore(std::numeric_limits::infinity()), + fscore(std::numeric_limits::infinity()), best_action(nullptr), + solved(false) { + features = state_features(d, s); } - SK_BFWS_SOLVER_TEMPLATE_DECL -const typename SK_BFWS_SOLVER_CLASS::HashingPolicy::Key& -SK_BFWS_SOLVER_CLASS::Node::Key::operator()(const Node& n) const { - return HashingPolicy::get_key(n); +const typename SK_BFWS_SOLVER_CLASS::HashingPolicy::Key & +SK_BFWS_SOLVER_CLASS::Node::Key::operator()(const Node &n) const { + return HashingPolicy::get_key(n); } - // === AStarSolver::NodeCompare implementation === SK_BFWS_SOLVER_TEMPLATE_DECL -bool SK_BFWS_SOLVER_CLASS::NodeCompare::operator()(Node*& a, Node*& b) const { - // smallest element appears at the top of the priority_queue => cost optimization - // rank first by heuristic values then by novelty measures - return ((a->heuristic) > (b->heuristic)) || - (((a->heuristic) == (b->heuristic)) && ((a->novelty) > (b->novelty))); +bool SK_BFWS_SOLVER_CLASS::NodeCompare::operator()(Node *&a, Node *&b) const { + // smallest element appears at the top of the priority_queue => cost + // optimization rank first by heuristic values then by novelty measures + return ((a->heuristic) > (b->heuristic)) || + (((a->heuristic) == (b->heuristic)) && ((a->novelty) > (b->novelty))); } } // namespace skdecide diff --git a/cpp/src/hub/solver/bfws/py_bfws.cc b/cpp/src/hub/solver/bfws/py_bfws.cc index b7f9a2c77a..e3b5228c51 100644 --- a/cpp/src/hub/solver/bfws/py_bfws.cc +++ b/cpp/src/hub/solver/bfws/py_bfws.cc @@ -9,28 +9,28 @@ namespace py = pybind11; -void init_pybfws(py::module& m) { - py::class_ py_bfws_solver(m, "_BFWSSolver_"); - py_bfws_solver - .def(py::init&, - const std::function&, - const std::function&, - bool, - bool, - bool>(), - py::arg("domain"), - py::arg("state_features"), - py::arg("heuristic"), - py::arg("termination_checker"), - py::arg("use_state_feature_hash")=false, - py::arg("parallel")=false, - py::arg("debug_logs")=false) - .def("close", &skdecide::PyBFWSSolver::close) - .def("clear", &skdecide::PyBFWSSolver::clear) - .def("solve", &skdecide::PyBFWSSolver::solve, py::arg("state")) - .def("is_solution_defined_for", &skdecide::PyBFWSSolver::is_solution_defined_for, py::arg("state")) - .def("get_next_action", &skdecide::PyBFWSSolver::get_next_action, py::arg("state")) - .def("get_utility", &skdecide::PyBFWSSolver::get_utility, py::arg("state")) - ; +void init_pybfws(py::module &m) { + py::class_ py_bfws_solver(m, "_BFWSSolver_"); + py_bfws_solver + .def(py::init &, + const std::function &, + const std::function &, + bool, bool, bool>(), + py::arg("domain"), py::arg("state_features"), py::arg("heuristic"), + py::arg("termination_checker"), + py::arg("use_state_feature_hash") = false, + py::arg("parallel") = false, py::arg("debug_logs") = false) + .def("close", &skdecide::PyBFWSSolver::close) + .def("clear", &skdecide::PyBFWSSolver::clear) + .def("solve", &skdecide::PyBFWSSolver::solve, py::arg("state")) + .def("is_solution_defined_for", + &skdecide::PyBFWSSolver::is_solution_defined_for, py::arg("state")) + .def("get_next_action", &skdecide::PyBFWSSolver::get_next_action, + py::arg("state")) + .def("get_utility", &skdecide::PyBFWSSolver::get_utility, + py::arg("state")); } diff --git a/cpp/src/hub/solver/bfws/py_bfws.hh b/cpp/src/hub/solver/bfws/py_bfws.hh index 48979f7439..da834248f9 100644 --- a/cpp/src/hub/solver/bfws/py_bfws.hh +++ b/cpp/src/hub/solver/bfws/py_bfws.hh @@ -32,236 +32,264 @@ using PyBFWSFeatureVector = PythonContainerProxy; template using PyBFWSFeatureVector = skdecide::PythonContainerProxy; - class PyBFWSSolver { -private : - - class BaseImplementation { - public : - virtual ~BaseImplementation() {} - virtual void close() = 0; - virtual void clear() = 0; - virtual void solve(const py::object& s) = 0; - virtual py::bool_ is_solution_defined_for(const py::object& s) = 0; - virtual py::object get_next_action(const py::object& s) = 0; - virtual py::float_ get_utility(const py::object& s) = 0; - }; - - template class Thashing_policy> - class Implementation : public BaseImplementation { - public : - Implementation(py::object& domain, - const std::function& state_features, - const std::function& heuristic, - const std::function& termination_checker, - bool debug_logs = false) - : _state_features(state_features), _heuristic(heuristic), _termination_checker(termination_checker) { - - check_domain(domain); - _domain = std::make_unique>(domain); - _solver = std::make_unique, PyBFWSFeatureVector, Thashing_policy, Texecution>>( - *_domain, - [this](PyBFWSDomain& d, const typename PyBFWSDomain::State& s)->std::unique_ptr> { - try { - auto fsf = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _state_features(dd, ss); - }; - std::unique_ptr r = d.call(nullptr, fsf, s.pyobj()); - typename skdecide::GilControl::Acquire acquire; - std::unique_ptr> rr = std::make_unique>(*r); - r.reset(); - return rr; - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling state features: ") + e.what()); - throw; - } - }, - [this](PyBFWSDomain& d, const typename PyBFWSDomain::State& s) -> typename PyBFWSDomain::Value { - try { - auto fh = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _heuristic(dd, ss); - }; - return typename PyBFWSDomain::Value(d.call(nullptr, fh, s.pyobj())); - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling heuristic estimator: ") + e.what()); - throw; - } - }, - [this](PyBFWSDomain& d, const typename PyBFWSDomain::State& s)->bool { - try { - auto ftc = [this](const py::object& dd, const py::object& ss, [[maybe_unused]] const py::object& ii) { - return _termination_checker(dd, ss); - }; - std::unique_ptr r = d.call(nullptr, ftc, s.pyobj()); - typename skdecide::GilControl::Acquire acquire; - bool rr = r->template cast(); - r.reset(); - return rr; - } catch (const std::exception& e) { - Logger::error(std::string("SKDECIDE exception when calling termination checker: ") + e.what()); - throw; - } - }, - debug_logs); - _stdout_redirect = std::make_unique(std::cout, - py::module::import("sys").attr("stdout")); - _stderr_redirect = std::make_unique(std::cerr, - py::module::import("sys").attr("stderr")); - } - - virtual ~Implementation() {} - - void check_domain(py::object& domain) { - if (!py::hasattr(domain, "get_applicable_actions")) { - throw std::invalid_argument("SKDECIDE exception: BFWS algorithm needs python domain for implementing get_applicable_actions()"); +private: + class BaseImplementation { + public: + virtual ~BaseImplementation() {} + virtual void close() = 0; + virtual void clear() = 0; + virtual void solve(const py::object &s) = 0; + virtual py::bool_ is_solution_defined_for(const py::object &s) = 0; + virtual py::object get_next_action(const py::object &s) = 0; + virtual py::float_ get_utility(const py::object &s) = 0; + }; + + template class Thashing_policy> + class Implementation : public BaseImplementation { + public: + Implementation( + py::object &domain, + const std::function + &state_features, + const std::function + &heuristic, + const std::function + &termination_checker, + bool debug_logs = false) + : _state_features(state_features), _heuristic(heuristic), + _termination_checker(termination_checker) { + + check_domain(domain); + _domain = std::make_unique>(domain); + _solver = std::make_unique, PyBFWSFeatureVector, + Thashing_policy, Texecution>>( + *_domain, + [this](PyBFWSDomain &d, + const typename PyBFWSDomain::State &s) + -> std::unique_ptr> { + try { + auto fsf = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _state_features(dd, ss); + }; + std::unique_ptr r = d.call(nullptr, fsf, s.pyobj()); + typename skdecide::GilControl::Acquire acquire; + std::unique_ptr> rr = + std::make_unique>(*r); + r.reset(); + return rr; + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling state features: ") + + e.what()); + throw; } - if (!py::hasattr(domain, "get_next_state")) { - throw std::invalid_argument("SKDECIDE exception: BFWS algorithm needs python domain for implementing get_sample()"); + }, + [this](PyBFWSDomain & d, + const typename PyBFWSDomain::State &s) -> + typename PyBFWSDomain::Value { + try { + auto fh = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _heuristic(dd, ss); + }; + return typename PyBFWSDomain::Value( + d.call(nullptr, fh, s.pyobj())); + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling heuristic estimator: ") + + e.what()); + throw; } - if (!py::hasattr(domain, "get_transition_value")) { - throw std::invalid_argument("SKDECIDE exception: BFWS algorithm needs python domain for implementing get_transition_value()"); + }, + [this](PyBFWSDomain &d, + const typename PyBFWSDomain::State &s) -> bool { + try { + auto ftc = [this](const py::object &dd, const py::object &ss, + [[maybe_unused]] const py::object &ii) { + return _termination_checker(dd, ss); + }; + std::unique_ptr r = d.call(nullptr, ftc, s.pyobj()); + typename skdecide::GilControl::Acquire acquire; + bool rr = r->template cast(); + r.reset(); + return rr; + } catch (const std::exception &e) { + Logger::error( + std::string( + "SKDECIDE exception when calling termination checker: ") + + e.what()); + throw; } - if (!py::hasattr(domain, "is_terminal")) { - throw std::invalid_argument("SKDECIDE exception: BFWS algorithm needs python domain for implementing is_terminal()"); - } - } - - virtual void close() { - _domain->close(); - } - - virtual void clear() { - _solver->clear(); - } - - virtual void solve(const py::object& s) { - typename skdecide::GilControl::Release release; - _solver->solve(s); - } - - virtual py::bool_ is_solution_defined_for(const py::object& s) { - return _solver->is_solution_defined_for(s); - } + }, + debug_logs); + _stdout_redirect = std::make_unique( + std::cout, py::module::import("sys").attr("stdout")); + _stderr_redirect = std::make_unique( + std::cerr, py::module::import("sys").attr("stderr")); + } - virtual py::object get_next_action(const py::object& s) { - return _solver->get_best_action(s).pyobj(); - } + virtual ~Implementation() {} + + void check_domain(py::object &domain) { + if (!py::hasattr(domain, "get_applicable_actions")) { + throw std::invalid_argument( + "SKDECIDE exception: BFWS algorithm needs python domain for " + "implementing get_applicable_actions()"); + } + if (!py::hasattr(domain, "get_next_state")) { + throw std::invalid_argument( + "SKDECIDE exception: BFWS algorithm needs python domain for " + "implementing get_sample()"); + } + if (!py::hasattr(domain, "get_transition_value")) { + throw std::invalid_argument( + "SKDECIDE exception: BFWS algorithm needs python domain for " + "implementing get_transition_value()"); + } + if (!py::hasattr(domain, "is_terminal")) { + throw std::invalid_argument( + "SKDECIDE exception: BFWS algorithm needs python domain for " + "implementing is_terminal()"); + } + } - virtual py::float_ get_utility(const py::object& s) { - return _solver->get_best_value(s); - } + virtual void close() { _domain->close(); } - private : - std::unique_ptr> _domain; - std::unique_ptr, PyBFWSFeatureVector, Thashing_policy, Texecution>> _solver; + virtual void clear() { _solver->clear(); } - std::function _state_features; - std::function _heuristic; - std::function _termination_checker; + virtual void solve(const py::object &s) { + typename skdecide::GilControl::Release release; + _solver->solve(s); + } - std::unique_ptr _stdout_redirect; - std::unique_ptr _stderr_redirect; - }; + virtual py::bool_ is_solution_defined_for(const py::object &s) { + return _solver->is_solution_defined_for(s); + } - struct ExecutionSelector { - bool _parallel; + virtual py::object get_next_action(const py::object &s) { + return _solver->get_best_action(s).pyobj(); + } - ExecutionSelector(bool parallel) : _parallel(parallel) {} + virtual py::float_ get_utility(const py::object &s) { + return _solver->get_best_value(s); + } - template - struct Select { - template - Select(ExecutionSelector& This, Args... args) { - if (This._parallel) { - Propagator::template PushType::Forward(args...); - } else { - Propagator::template PushType::Forward(args...); - } - } - }; + private: + std::unique_ptr> _domain; + std::unique_ptr, + PyBFWSFeatureVector, + Thashing_policy, Texecution>> + _solver; + + std::function + _state_features; + std::function + _heuristic; + std::function + _termination_checker; + + std::unique_ptr _stdout_redirect; + std::unique_ptr _stderr_redirect; + }; + + struct ExecutionSelector { + bool _parallel; + + ExecutionSelector(bool parallel) : _parallel(parallel) {} + + template struct Select { + template + Select(ExecutionSelector &This, Args... args) { + if (This._parallel) { + Propagator::template PushType::Forward(args...); + } else { + Propagator::template PushType::Forward(args...); + } + } }; + }; - struct HashingPolicySelector { - bool _use_state_feature_hash; + struct HashingPolicySelector { + bool _use_state_feature_hash; - HashingPolicySelector(bool use_state_feature_hash) + HashingPolicySelector(bool use_state_feature_hash) : _use_state_feature_hash(use_state_feature_hash) {} - template - struct Select { - template - Select(HashingPolicySelector& This, Args... args) { - if (This._use_state_feature_hash) { - Propagator::template PushTemplate::Forward(args...); - } else { - Propagator::template PushTemplate::Forward(args...); - } - } - }; + template struct Select { + template + Select(HashingPolicySelector &This, Args... args) { + if (This._use_state_feature_hash) { + Propagator::template PushTemplate::Forward(args...); + } else { + Propagator::template PushTemplate::Forward(args...); + } + } }; + }; - struct SolverInstantiator { - std::unique_ptr& _implementation; + struct SolverInstantiator { + std::unique_ptr &_implementation; - SolverInstantiator(std::unique_ptr& implementation) + SolverInstantiator(std::unique_ptr &implementation) : _implementation(implementation) {} - template - struct TypeList { - template