Harmonize the C++ solvers API #362

Merged
merged 28 commits into from
May 17, 2024
28 commits
729e630
Refactored LRTDP
fteicht May 5, 2024
767e31d
LRTDP refactoring fixes
fteicht May 6, 2024
f19290e
RIW refactoring
fteicht May 6, 2024
ddbb556
Correct LRTDP reference paper year
fteicht May 6, 2024
def8a0a
Remove useless comment in LRTDP header class
fteicht May 6, 2024
77c62e4
Improve LRTDP and RIW documentation
fteicht May 6, 2024
45f3ac6
rename residual moving averages in LRTDP and RIW
fteicht May 6, 2024
10a53d2
rename moving average window in LRTDP and RIW
fteicht May 6, 2024
70fadc0
Refactor MCTS
fteicht May 7, 2024
79028a1
Some more harmonizations
fteicht May 7, 2024
8637fe8
Debug callbacks with parallel domains
fteicht May 12, 2024
e648ae6
Progress solver API harmonization
fteicht May 13, 2024
73b41b3
Protect solver getters from race conditions when called from the call…
fteicht May 13, 2024
fb32aad
Merge branch 'airbus:master' into sanitarize-and-document-solver-api
fteicht May 13, 2024
542ebc7
Progress new solver api
fteicht May 14, 2024
87afbf0
Add tests for cpp solvers with callbacks
nhuet May 14, 2024
70c82ab
Update aostar.hh
fteicht May 14, 2024
674971c
WIP c++ solver API change
fteicht May 14, 2024
92324dd
Correct typos in AOstar docs
fteicht May 14, 2024
bf15714
Progressing on solver API harmonization
fteicht May 15, 2024
c060f3b
WIP solver API refactoring
fteicht May 15, 2024
df661e5
Progressing on solver API harmonization
fteicht May 15, 2024
85d7395
WiP solver API harmonization
fteicht May 16, 2024
ed141f1
Rename debug_logs as verbose and properly doc stringify for autodoc.py
fteicht May 16, 2024
3c32ea6
WiP solver API harmonization
fteicht May 16, 2024
98deda4
WiP solver API harmonization
fteicht May 17, 2024
e495119
Merge branch 'airbus:master' into sanitarize-and-document-solver-api
fteicht May 17, 2024
36eba3d
Update callback unit test for cpp solvers
fteicht May 17, 2024
160 changes: 147 additions & 13 deletions cpp/src/hub/solver/aostar/aostar.hh
@@ -9,12 +9,24 @@
#include <memory>
#include <unordered_set>
#include <list>
#include <queue>

#include "utils/associative_container_deducer.hh"
#include "utils/execution.hh"

namespace skdecide {

/**
* @brief This is the skdecide implementation of the AO* algorithm for searching
* cost-minimal policies in additive AND/OR graphs with admissible heuristics
* as described in "Principles of Artificial Intelligence" by Nilsson, N. (1980)
*
* @tparam Tdomain Type of the domain class
* @tparam Texecution_policy Type of the execution policy (one of
* 'SequentialExecution' to generate state-action transitions in sequence,
* or 'ParallelExecution' to generate state-action transitions in parallel on
* different threads)
*/
template <typename Tdomain, typename Texecution_policy = SequentialExecution>
class AOStarSolver {
public:
@@ -25,31 +37,146 @@ public:
typedef typename Domain::Value Value;
typedef Texecution_policy ExecutionPolicy;

typedef std::function<Predicate(Domain &, const State &)> GoalCheckerFunctor;
typedef std::function<Value(Domain &, const State &)> HeuristicFunctor;
typedef std::function<bool(const AOStarSolver &, Domain &)> CallbackFunctor;

/**
* @brief Construct a new AOStarSolver object
*
* @param domain The domain instance
* @param goal_checker Functor taking as arguments the domain and a state
* object, and returning true if the state is the goal
* @param heuristic Functor taking as arguments the domain and a state object,
* and returning the heuristic estimate from the state to the goal
* @param discount Value function's discount factor
* @param max_tip_expansions Maximum number of states to extract from the
* priority queue at each iteration before recomputing the policy graph
* @param detect_cycles Boolean indicating whether cycles in the search graph
* should be automatically detected (true) or not (false); note that the AO*
* algorithm is not designed for graphs with cycles, in which it can get
* trapped indefinitely
* @param callback Functor called before popping the next state from the
* priority queue, taking as arguments the solver and the domain, and
* returning true if the solver must be stopped
* @param verbose Boolean indicating whether verbose messages should be
* logged (true) or not (false)
*/
AOStarSolver(
Domain &domain,
const std::function<Predicate(Domain &, const State &)> &goal_checker,
const std::function<Value(Domain &, const State &)> &heuristic,
double discount = 1.0, std::size_t max_tip_expansions = 1,
bool detect_cycles = false, bool debug_logs = false);

// clears the solver (clears the search graph, thus preventing from reusing
// previous search results)
Domain &domain, const GoalCheckerFunctor &goal_checker,
const HeuristicFunctor &heuristic, double discount = 1.0,
std::size_t max_tip_expansions = 1, bool detect_cycles = false,
const CallbackFunctor &callback = [](const AOStarSolver &,
Domain &) { return false; },
bool verbose = false);

/**
* @brief Clears the search graph, thus preventing the reuse of previous
* search results
*
*/
void clear();

// solves from state s using heuristic function h
/**
* @brief Call the AO* algorithm
*
* @param s Root state of the search from which AO* graph traversals are
* performed
*/
void solve(const State &s);

/**
* @brief Indicates whether the solution policy is defined for a given state
*
* @param s State for which an entry is searched in the policy graph
* @return true If the state has been explored and an action is defined in
* this state
* @return false If the state has not been explored or no action is defined in
* this state
*/
bool is_solution_defined_for(const State &s) const;

/**
* @brief Get the best computed action in terms of best Q-value in a given
* state (throws a runtime error exception if no action is defined in the
* given state, which is why it is advised to call
* AOStarSolver::is_solution_defined_for before).
*
* @param s State for which the best action is requested
* @return const Action& Best computed action
*/
const Action &get_best_action(const State &s) const;
const double &get_best_value(const State &s) const;

/**
* @brief Get the best Q-value in a given state (throws a runtime
* error exception if no action is defined in the given state, which is why it
* is advised to call AOStarSolver::is_solution_defined_for before)
*
* @param s State from which the best Q-value is requested
* @return Value Minimum Q-value of the given state over the applicable
* actions in this state
*/
Value get_best_value(const State &s) const;

/**
* @brief Get the number of states present in the search graph
*
* @return std::size_t Number of states present in the search graph
*/
std::size_t get_nb_explored_states() const;

/**
* @brief Get the set of states present in the search graph (i.e. the raw
* states held by the graph's state nodes, without the node wrappers or their
* children)
*
* @return SetTypeDeducer<State>::Set Set of states present in the search
* graph
*/
typename SetTypeDeducer<State>::Set get_explored_states() const;

/**
* @brief Get the number of states present in the priority queue (i.e. those
* explored states that have not yet been expanded)
*
* @return std::size_t Number of states present in the priority queue
*/
std::size_t get_nb_tip_states() const;

/**
* @brief Get the top tip state, i.e. the tip state with the lowest value
* function
*
* @return const State& Next tip state to be expanded by the algorithm
*/
const State &get_top_tip_state() const;

/**
* @brief Get the solving time in milliseconds since the beginning of the
* search from the root solving state
*
* @return std::size_t Solving time in milliseconds
*/
std::size_t get_solving_time() const;

/**
* @brief Get the (partial) solution policy defined for the states for which
* the Q-value has been updated at least once (which is optimal for the
* non-tip states reachable by this policy)
*
* @return Mapping from states to pairs of action and best Q-value
*/
typename MapTypeDeducer<State, std::pair<Action, Value>>::Map
get_policy() const;

private:
Domain &_domain;
std::function<bool(Domain &, const State &)> _goal_checker;
std::function<Value(Domain &, const State &)> _heuristic;
GoalCheckerFunctor _goal_checker;
HeuristicFunctor _heuristic;
double _discount;
std::size_t _max_tip_expansions;
bool _detect_cycles;
bool _debug_logs;
CallbackFunctor _callback;
bool _verbose;
ExecutionPolicy _execution_policy;

struct ActionNode;
@@ -85,6 +212,13 @@ private:

typedef typename SetTypeDeducer<StateNode, State>::Set Graph;
Graph _graph;

typedef std::priority_queue<StateNode *, std::vector<StateNode *>,
StateNodeCompare>
PriorityQueue;
PriorityQueue _priority_queue;

std::chrono::time_point<std::chrono::high_resolution_clock> _start_time;
};

} // namespace skdecide
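
The callback contract documented in the constructor above (a functor receiving the solver and the domain, called before each step, returning true to stop) can be illustrated with a minimal standalone sketch. `ToyDomain`, `ToySolver`, `max_iterations` and `run_with_iteration_budget` are hypothetical names invented for this illustration and are not part of skdecide; the loop body stands in for the actual AO* expansion, which is not shown in this diff.

```cpp
#include <chrono>
#include <cstddef>
#include <functional>

// Hypothetical stand-in for a domain class (not the skdecide domain).
struct ToyDomain {};

// Hypothetical solver that only mimics the callback and timing conventions.
class ToySolver {
public:
  typedef std::function<bool(const ToySolver &, ToyDomain &)> CallbackFunctor;

  // The default callback never asks the solver to stop.
  ToySolver(
      ToyDomain &domain, std::size_t max_iterations,
      const CallbackFunctor &callback = [](const ToySolver &,
                                           ToyDomain &) { return false; })
      : _domain(domain), _max_iterations(max_iterations),
        _callback(callback), _nb_iterations(0) {}

  void solve() {
    _nb_iterations = 0;
    _start_time = std::chrono::high_resolution_clock::now();
    while (_nb_iterations < _max_iterations) {
      // The callback is consulted before each step; true means "stop now".
      if (_callback(*this, _domain)) {
        break;
      }
      ++_nb_iterations; // placeholder for one unit of search work
    }
  }

  std::size_t get_nb_iterations() const { return _nb_iterations; }

  // Elapsed time in milliseconds since solve() started.
  std::size_t get_solving_time() const {
    return static_cast<std::size_t>(
        std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::high_resolution_clock::now() - _start_time)
            .count());
  }

private:
  ToyDomain &_domain;
  std::size_t _max_iterations;
  CallbackFunctor _callback;
  std::size_t _nb_iterations;
  std::chrono::time_point<std::chrono::high_resolution_clock> _start_time;
};

// Stops the toy solver once three iterations have completed.
std::size_t run_with_iteration_budget() {
  ToyDomain domain;
  ToySolver solver(domain, 100, [](const ToySolver &s, ToyDomain &) {
    return s.get_nb_iterations() >= 3;
  });
  solver.solve();
  return solver.get_nb_iterations();
}
```

Because the callback receives a const reference to the solver, it can only use const getters, which is presumably why the harmonized API exposes its statistics (explored states, tip states, solving time) through const member functions.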
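
The `PriorityQueue` typedef added in this diff stores node pointers in a `std::priority_queue` ordered by a comparator, and the documentation of `get_top_tip_state` says the top is the tip state with the lowest value. The body of `StateNodeCompare` is not visible here, so the sketch below only assumes an ordering on a stored value; `ToyNode`, `ToyNodeCompare`, `ToyQueue` and `lowest_value_state` are hypothetical names used for illustration.

```cpp
#include <queue>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical node type: a state label plus its current best value.
struct ToyNode {
  std::string state;
  double best_value;
};

// std::priority_queue keeps the "largest" element on top according to the
// comparator, so comparing with '>' puts the lowest value on top.
struct ToyNodeCompare {
  bool operator()(const ToyNode *a, const ToyNode *b) const {
    return a->best_value > b->best_value;
  }
};

typedef std::priority_queue<ToyNode *, std::vector<ToyNode *>, ToyNodeCompare>
    ToyQueue;

// Returns the label of the node with the lowest value, or throws a runtime
// error when there is no node at all.
std::string lowest_value_state(std::vector<ToyNode> &nodes) {
  if (nodes.empty()) {
    throw std::runtime_error("no tip state available");
  }
  ToyQueue queue;
  for (ToyNode &node : nodes) {
    queue.push(&node); // the queue only holds pointers, not copies
  }
  return queue.top()->state;
}
```

Storing pointers keeps the queue cheap to reorder, but the pointed-to nodes must outlive the queue and their values should not change while they are queued, otherwise the heap ordering is no longer valid.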