Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix DiscreteFactorGraph::optimize to return MPE #1050

Merged
merged 21 commits into from
Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions gtsam/discrete/DecisionTreeFactor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <gtsam/base/FastSet.h>

#include <boost/make_shared.hpp>
#include <boost/format.hpp>
#include <utility>

using namespace std;
Expand Down Expand Up @@ -65,9 +66,13 @@ namespace gtsam {

/* ************************************************************************* */
void DecisionTreeFactor::print(const string& s,
const KeyFormatter& formatter) const {
const KeyFormatter& formatter) const {
cout << s;
ADT::print("Potentials:",formatter);
cout << " f[";
for (auto&& key : keys())
cout << boost::format(" (%1%,%2%),") % formatter(key) % cardinality(key);
cout << " ]" << endl;
ADT::print("Potentials:", formatter);
}

/* ************************************************************************* */
Expand Down
7 changes: 6 additions & 1 deletion gtsam/discrete/DecisionTreeFactor.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,16 @@ namespace gtsam {
return combine(keys, ADT::Ring::add);
}

/// Create new factor by maximizing over all values with the same separator values
/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(size_t nrFrontals) const {
return combine(nrFrontals, ADT::Ring::max);
}

/// Create new factor by maximizing over all values with the same separator.
shared_ptr max(const Ordering& keys) const {
return combine(keys, ADT::Ring::max);
}

/// @}
/// @name Advanced Interface
/// @{
Expand Down
7 changes: 7 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,24 @@ double DiscreteBayesNet::evaluate(const DiscreteValues& values) const {
}

/* ************************************************************************* */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
DiscreteValues DiscreteBayesNet::optimize() const {
DiscreteValues result;
return optimize(result);
}

DiscreteValues DiscreteBayesNet::optimize(DiscreteValues result) const {
// solve each node in turn in topological sort order (parents first)
#ifdef _MSC_VER
#pragma message("DiscreteBayesNet::optimize (deprecated) does not compute MPE!")
#else
#warning "DiscreteBayesNet::optimize (deprecated) does not compute MPE!"
#endif
for (auto conditional : boost::adaptors::reverse(*this))
conditional->solveInPlace(&result);
return result;
}
#endif

/* ************************************************************************* */
DiscreteValues DiscreteBayesNet::sample() const {
Expand Down
36 changes: 12 additions & 24 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@

namespace gtsam {

/** A Bayes net made from linear-Discrete densities */
/** A Bayes net made from discrete conditional distributions. */
class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional>
{
public:

typedef FactorGraph<DiscreteConditional> Base;
typedef BayesNet<DiscreteConditional> Base;
dellaert marked this conversation as resolved.
Show resolved Hide resolved
typedef DiscreteBayesNet This;
typedef DiscreteConditional ConditionalType;
typedef boost::shared_ptr<This> shared_ptr;
Expand All @@ -45,7 +45,7 @@ namespace gtsam {
/// @name Standard Constructors
/// @{

/** Construct empty factor graph */
/// Construct empty Bayes net.
DiscreteBayesNet() {}

/** Construct from iterator over conditionals */
Expand Down Expand Up @@ -98,27 +98,6 @@ namespace gtsam {
return evaluate(values);
}

/**
* @brief solve by back-substitution.
*
* Assumes the Bayes net is reverse topologically sorted, i.e. last
* conditional will be optimized first. If the Bayes net resulted from
* eliminating a factor graph, this is true for the elimination ordering.
*
* @return a sampled value for all variables.
*/
DiscreteValues optimize() const;

/**
* @brief solve by back-substitution, given certain variables.
*
* Assumes the Bayes net is reverse topologically sorted *and* that the
* Bayes net does not contain any conditionals for the given values.
*
* @return given values extended with optimized value for other variables.
*/
DiscreteValues optimize(DiscreteValues given) const;

/**
* @brief do ancestral sampling
*
Expand Down Expand Up @@ -152,7 +131,16 @@ namespace gtsam {
std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteFactor::Names& names = {}) const;

///@}

#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
/// @name Deprecated functionality
/// @{

DiscreteValues GTSAM_DEPRECATED optimize() const;
DiscreteValues GTSAM_DEPRECATED optimize(DiscreteValues given) const;
/// @}
#endif

private:
/** Serialization function */
Expand Down
79 changes: 40 additions & 39 deletions gtsam/discrete/DiscreteConditional.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,25 @@
* @author Frank Dellaert
*/

#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/Signature.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/base/Testable.h>
#include <gtsam/base/debug.h>

#include <boost/make_shared.hpp>

#include <algorithm>
#include <boost/make_shared.hpp>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
#include <utility>
#include <set>
#include <vector>

using namespace std;
using std::pair;
using std::stringstream;
using std::vector;
using std::pair;
namespace gtsam {

// Instantiate base class
Expand Down Expand Up @@ -147,7 +146,7 @@ void DiscreteConditional::print(const string& s,
cout << endl;
}

/* ******************************************************************************** */
/* ************************************************************************** */
bool DiscreteConditional::equals(const DiscreteFactor& other,
double tol) const {
if (!dynamic_cast<const DecisionTreeFactor*>(&other)) {
Expand All @@ -159,22 +158,21 @@ bool DiscreteConditional::equals(const DiscreteFactor& other,
}

/* ************************************************************************** */
static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
const DiscreteValues& given,
bool forceComplete = true) {
DiscreteConditional::ADT DiscreteConditional::choose(
const DiscreteValues& given, bool forceComplete) const {
// Get the big decision tree with all the levels, and then go down the
// branches based on the value of the parent variables.
DiscreteConditional::ADT adt(conditional);
DiscreteConditional::ADT adt(*this);
size_t value;
for (Key j : conditional.parents()) {
for (Key j : parents()) {
try {
value = given.at(j);
adt = adt.choose(j, value); // ADT keeps getting smaller.
} catch (std::out_of_range&) {
if (forceComplete) {
given.print("parentsValues: ");
throw runtime_error(
"DiscreteConditional::Choose: parent value missing");
"DiscreteConditional::choose: parent value missing");
}
}
}
Expand All @@ -184,7 +182,7 @@ static DiscreteConditional::ADT Choose(const DiscreteConditional& conditional,
/* ************************************************************************** */
DiscreteConditional::shared_ptr DiscreteConditional::choose(
const DiscreteValues& given) const {
ADT adt = Choose(*this, given, false); // P(F|S=given)
ADT adt = choose(given, false); // P(F|S=given)

// Collect all keys not in given.
DiscreteKeys dKeys;
Expand Down Expand Up @@ -225,7 +223,7 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
return boost::make_shared<DecisionTreeFactor>(discreteKeys, adt);
}

/* ******************************************************************************** */
/* ****************************************************************************/
DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
size_t parent_value) const {
if (nrFrontals() != 1)
Expand All @@ -238,8 +236,9 @@ DecisionTreeFactor::shared_ptr DiscreteConditional::likelihood(
}

/* ************************************************************************** */
#ifdef GTSAM_ALLOW_DEPRECATED_SINCE_V42
void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
ADT pFS = Choose(*this, *values); // P(F|S=parentsValues)
ADT pFS = choose(*values, true); // P(F|S=parentsValues)

// Initialize
DiscreteValues mpe;
Expand All @@ -248,59 +247,59 @@ void DiscreteConditional::solveInPlace(DiscreteValues* values) const {
// Get all Possible Configurations
const auto allPosbValues = frontalAssignments();

// Find the MPE
// Find the maximum
dellaert marked this conversation as resolved.
Show resolved Hide resolved
for (const auto& frontalVals : allPosbValues) {
double pValueS = pFS(frontalVals); // P(F=value|S=parentsValues)
// Update MPE solution if better
// Update maximum solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = frontalVals;
}
}

// set values (inPlace) to mpe
// set values (inPlace) to maximum
for (Key j : frontals()) {
(*values)[j] = mpe[j];
}
}

/* ******************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}

/* ************************************************************************** */
size_t DiscreteConditional::solve(const DiscreteValues& parentsValues) const {
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)

// Then, find the max over all remaining
// TODO, only works for one key now, seems horribly slow this way
size_t mpe = 0;
DiscreteValues frontals;
size_t max = 0;
double maxP = 0;
DiscreteValues frontals;
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
for (size_t value = 0; value < cardinality(j); value++) {
frontals[j] = value;
double pValueS = pFS(frontals); // P(F=value|S=parentsValues)
// Update MPE solution if better
// Update solution if better
if (pValueS > maxP) {
maxP = pValueS;
mpe = value;
max = value;
}
}
return mpe;
return max;
}
#endif

/* ******************************************************************************** */
/* ************************************************************************** */
void DiscreteConditional::sampleInPlace(DiscreteValues* values) const {
assert(nrFrontals() == 1);
Key j = (firstFrontalKey());
size_t sampled = sample(*values); // Sample variable given parents
(*values)[j] = sampled; // store result in partial solution
}

/* ************************************************************************** */
size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
static mt19937 rng(2); // random number generator

// Get the correct conditional density
ADT pFS = Choose(*this, parentsValues); // P(F|S=parentsValues)
ADT pFS = choose(parentsValues, true); // P(F|S=parentsValues)

// TODO(Duy): only works for one key now, seems horribly slow this way
if (nrFrontals() != 1) {
Expand All @@ -323,7 +322,8 @@ size_t DiscreteConditional::sample(const DiscreteValues& parentsValues) const {
return distribution(rng);
}

/* ******************************************************************************** */
/* ********************************************************************************
*/
dellaert marked this conversation as resolved.
Show resolved Hide resolved
size_t DiscreteConditional::sample(size_t parent_value) const {
if (nrParents() != 1)
throw std::invalid_argument(
Expand All @@ -334,7 +334,8 @@ size_t DiscreteConditional::sample(size_t parent_value) const {
return sample(values);
}

/* ******************************************************************************** */
/* ********************************************************************************
*/
size_t DiscreteConditional::sample() const {
if (nrParents() != 0)
throw std::invalid_argument(
Expand Down
Loading