Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DiscreteFactorGraph::optimize to return MPE #1050

Merged
merged 21 commits into from
Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/scripts/python.sh
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,6 @@ cmake $GITHUB_WORKSPACE -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \
make -j2 install

cd $GITHUB_WORKSPACE/build/python
$PYTHON setup.py install --user --prefix=
pip install --user --install-option="--prefix=" .
cd $GITHUB_WORKSPACE/python/gtsam/tests
$PYTHON -m unittest discover -v
9 changes: 4 additions & 5 deletions examples/DiscreteBayesNetExample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ int main(int argc, char **argv) {
// Create solver and eliminate
Ordering ordering;
ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);

// solve
auto mpe = chordal->optimize();
auto mpe = fg.optimize();
GTSAM_PRINT(mpe);

// We can also build a Bayes tree (directed junction tree).
Expand All @@ -69,14 +68,14 @@ int main(int argc, char **argv) {
fg.add(Dyspnea, "0 1");

// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
auto mpe2 = chordal2->optimize();
auto mpe2 = fg.optimize();
GTSAM_PRINT(mpe2);

// We can also sample from it
DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
auto sample = chordal2->sample();
auto sample = chordal->sample();
GTSAM_PRINT(sample);
}
return 0;
Expand Down
8 changes: 4 additions & 4 deletions examples/DiscreteBayesNet_FG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ int main(int argc, char **argv) {
}

// "Most Probable Explanation", i.e., configuration with largest value
auto mpe = graph.eliminateSequential()->optimize();
auto mpe = graph.optimize();
cout << "\nMost Probable Explanation (MPE):" << endl;
print(mpe);

Expand All @@ -96,8 +96,7 @@ int main(int argc, char **argv) {
graph.add(Cloudy, "1 0");

// solve again, now with evidence
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto mpe_with_evidence = chordal->optimize();
auto mpe_with_evidence = graph.optimize();

cout << "\nMPE given C=0:" << endl;
print(mpe_with_evidence);
Expand All @@ -110,7 +109,8 @@ int main(int argc, char **argv) {
cout << "\nP(W=1|C=0):" << marginals.marginalProbabilities(WetGrass)[1]
<< endl;

// We can also sample from it
// We can also sample from the eliminated graph
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
cout << "\n10 samples:" << endl;
for (size_t i = 0; i < 10; i++) {
auto sample = chordal->sample();
Expand Down
8 changes: 4 additions & 4 deletions examples/HMMExample.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,16 @@ int main(int argc, char **argv) {
// Convert to factor graph
DiscreteFactorGraph factorGraph(hmm);

// Do max-prodcut
auto mpe = factorGraph.optimize();
GTSAM_PRINT(mpe);

// Create solver and eliminate
// This will create a DAG ordered with arrow of time reversed
DiscreteBayesNet::shared_ptr chordal =
factorGraph.eliminateSequential(ordering);
chordal->print("Eliminated");

// solve
auto mpe = chordal->optimize();
GTSAM_PRINT(mpe);

// We can also sample from it
cout << "\n10 samples:" << endl;
for (size_t k = 0; k < 10; k++) {
Expand Down
5 changes: 2 additions & 3 deletions examples/UGM_chain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,8 @@ int main(int argc, char** argv) {
<< graph.size() << " factors (Unary+Edge).";

// "Decoding", i.e., configuration with largest value
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto optimalDecoding = chordal->optimize();
// Uses max-product.
auto optimalDecoding = graph.optimize();
optimalDecoding.print("\nMost Probable Explanation (optimalDecoding)\n");

// "Inference" Computing marginals for each node
Expand Down
5 changes: 2 additions & 3 deletions examples/UGM_small.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,8 @@ int main(int argc, char** argv) {
}

// "Decoding", i.e., configuration with largest value (MPE)
// We use sequential variable elimination
DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
auto optimalDecoding = chordal->optimize();
// Uses max-product
auto optimalDecoding = graph.optimize();
GTSAM_PRINT(optimalDecoding);

// "Inference" Computing marginals
Expand Down
9 changes: 7 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/base/FastSet.h>

#include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility>

using namespace std;
Expand Down Expand Up @@ -65,9 +66,13 @@ namespace gtsam {

/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
cout << s;
ADT::print("Potentials:",formatter);
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
}

/* ************************************************************************* */
Expand Down
7 changes: 6 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add);
}

/// Create new factor by maximizing over all values with the same separator values
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}

/// @}
/// @name Advanced Interface
/// @{
Expand Down
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,24 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
}

/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}

DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
Expand Down
36 changes: 12 additions & 24 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@

namespace gtsam {

/** A Bayes net made from linear-Discrete densities */
/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:

typedef FactorGraph<DiscreteConditional> Base;
typedef BayesNet<DiscreteConditional> Base;
dellaert marked this conversation as resolved.
Show resolved Hide resolved
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
Expand All @@ -45,7 +45,7 @@ namespace gtsam {
/// @name Standard Constructors
/// @{

/** Construct empty factor graph */
/// Construct empty Bayes net.
DiscreteBayesNet() {}

/** Construct from iterator over conditionals */
Expand Down Expand Up @@ -98,27 +98,6 @@ namespace gtsam {
return evaluate(values);
}

/**
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;

/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;

/**
* @brief do ancestral sampling
*
Expand Down Expand Up @@ -152,7 +131,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

///@}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{

DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @}
#endif

private:
/** Serialization function */
Expand Down
Loading