Skip to content

Commit

Permalink
Merge pull request #2013 from borglab/fix/pruning
Browse files Browse the repository at this point in the history
Fix pruning
  • Loading branch information
dellaert authored Jan 30, 2025
2 parents 588a20f + 9bae03a commit 3cf1590
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 156 deletions.
54 changes: 54 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/inference/FactorGraph-inst.h>

namespace gtsam {
Expand Down Expand Up @@ -68,6 +69,59 @@ DiscreteValues DiscreteBayesNet::sample(DiscreteValues result) const {
return result;
}

/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
DiscreteBayesNet DiscreteBayesNet::prune(
size_t maxNrLeaves, const std::optional<double>& marginalThreshold,
DiscreteValues* fixedValues) const {
// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (const DiscreteConditional::shared_ptr& conditional : *this)
joint = joint * (*conditional);

// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);

DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (marginalThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
const Vector probabilities = marginals.marginalProbabilities(dkey);

int index = -1;
auto threshold = (probabilities.array() > *marginalThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}

if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}

// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);

// Set the fixed values if requested.
if (fixedValues) {
*fixedValues = deadModesValues;
}
}

// Return the resulting DiscreteBayesNet.
DiscreteBayesNet result;
if (pruned.keys().size() > 0) result.push_back(pruned);
return result;
}

/* *********************************************************************** */
std::string DiscreteBayesNet::markdown(
const KeyFormatter& keyFormatter,
Expand Down
12 changes: 12 additions & 0 deletions gtsam/discrete/DiscreteBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,18 @@ class GTSAM_EXPORT DiscreteBayesNet: public BayesNet<DiscreteConditional> {
*/
DiscreteValues sample(DiscreteValues given) const;

/**
* @brief Prune the Bayes net
*
* @param maxNrLeaves The maximum number of leaves to keep.
* @param marginalThreshold If given, threshold on marginals to prune variables.
* @param fixedValues If given, return the fixed values removed.
* @return A new DiscreteBayesNet with pruned conditionals.
*/
DiscreteBayesNet prune(size_t maxNrLeaves,
const std::optional<double>& marginalThreshold = {},
DiscreteValues* fixedValues = nullptr) const;

///@}
/// @name Wrapper support
/// @{
Expand Down
90 changes: 73 additions & 17 deletions gtsam/discrete/DiscreteValues.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,28 +68,82 @@ class GTSAM_EXPORT DiscreteValues : public Assignment<Key> {
friend std::ostream& operator<<(std::ostream& os, const DiscreteValues& x);

// insert in base class;
std::pair<iterator, bool> insert( const value_type& value ){
std::pair<iterator, bool> insert(const value_type& value) {
return Base::insert(value);
}

/**
* Insert key-assignment pair.
* Throws an invalid_argument exception if
* any keys to be inserted are already used. */
* @brief Insert key-assignment pair.
*
* @param assignment The key-assignment pair to insert.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::invalid_argument if any keys to be inserted are already used.
*/
DiscreteValues& insert(const std::pair<Key, size_t>& assignment);

/** Insert all values from \c values. Throws an invalid_argument exception if
* any keys to be inserted are already used. */
/**
* @brief Insert all values from another DiscreteValues object.
*
* @param values The DiscreteValues object containing values to insert.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::invalid_argument if any keys to be inserted are already used.
*/
DiscreteValues& insert(const DiscreteValues& values);

/** For all key/value pairs in \c values, replace values with corresponding
* keys in this object with those in \c values. Throws std::out_of_range if
* any keys in \c values are not present in this object. */
/**
* @brief Update values with corresponding keys from another DiscreteValues
* object.
*
* @param values The DiscreteValues object containing values to update.
* @return DiscreteValues& Reference to the updated DiscreteValues object.
* @throws std::out_of_range if any keys in values are not present in this
* object.
*/
DiscreteValues& update(const DiscreteValues& values);

/**
* @brief Return a vector of DiscreteValues, one for each possible
* combination of values.
* @brief Check if the DiscreteValues contains the given key.
*
* @param key The key to check for.
* @return True if the key is present, false otherwise.
*/
bool contains(Key key) const { return this->find(key) != this->end(); }

/**
* @brief Filter values by keys.
*
* @param keys The keys to filter by.
* @return DiscreteValues The filtered DiscreteValues object.
*/
DiscreteValues filter(const DiscreteKeys& keys) const {
DiscreteValues result;
for (const auto& [key, _] : keys) {
if (auto it = this->find(key); it != this->end())
result[key] = it->second;
}
return result;
}

/**
* @brief Return the keys that are not present in the DiscreteValues object.
*
* @param keys The keys to check for.
* @return DiscreteKeys Keys not present in the DiscreteValues object.
*/
DiscreteKeys missingKeys(const DiscreteKeys& keys) const {
DiscreteKeys result;
for (const auto& [key, cardinality] : keys) {
if (!this->contains(key)) result.emplace_back(key, cardinality);
}
return result;
}

/**
* @brief Return a vector of DiscreteValues, one for each possible combination
* of values.
*
* @param keys The keys to generate the Cartesian product for.
* @return std::vector<DiscreteValues> The vector of DiscreteValues.
*/
static std::vector<DiscreteValues> CartesianProduct(
const DiscreteKeys& keys) {
Expand Down Expand Up @@ -135,14 +189,16 @@ inline std::vector<DiscreteValues> cartesianProduct(const DiscreteKeys& keys) {
}

/// Free version of markdown.
std::string GTSAM_EXPORT markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});
std::string GTSAM_EXPORT
markdown(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});

/// Free version of html.
std::string GTSAM_EXPORT html(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});
std::string GTSAM_EXPORT
html(const DiscreteValues& values,
const KeyFormatter& keyFormatter = DefaultKeyFormatter,
const DiscreteValues::Names& names = {});

// traits
template <>
Expand Down
18 changes: 18 additions & 0 deletions gtsam/discrete/tests/testDiscreteValues.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,24 @@ TEST(DiscreteValues, Update) {
DiscreteValues(kExample).update({{12, 2}})));
}

/* ************************************************************************* */
// Test DiscreteValues::filter
TEST(DiscreteValues, Filter) {
DiscreteValues values = {{12, 1}, {5, 0}, {13, 2}};
DiscreteKeys keys = {{12, 0}, {13, 0}, {99, 0}}; // 99 is missing in values

EXPECT(assert_equal(DiscreteValues({{12, 1}, {13, 2}}), values.filter(keys)));
}

/* ************************************************************************* */
// Test DiscreteValues::missingKeys
TEST(DiscreteValues, MissingKeys) {
DiscreteValues values = {{12, 1}, {5, 0}};
DiscreteKeys keys = {{12, 0}, {5, 0}, {99, 0}, {42, 0}}; // 99 and 42 are missing

EXPECT(assert_equal(DiscreteKeys({{99, 0}, {42, 0}}), values.missingKeys(keys)));
}

/* ************************************************************************* */
// Check markdown representation with a value formatter.
TEST(DiscreteValues, markdownWithValueFormatter) {
Expand Down
124 changes: 34 additions & 90 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/DiscreteFactorGraph.h>
#include <gtsam/discrete/DiscreteMarginals.h>
#include <gtsam/discrete/TableDistribution.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
Expand All @@ -43,115 +42,60 @@ bool HybridBayesNet::equals(const This &bn, double tol) const {
}

/* ************************************************************************* */
// The implementation is: build the entire joint into one factor and then prune.
// TODO(Frank): This can be quite expensive *unless* the factors have already
// been pruned before. Another, possibly faster approach is branch and bound
// search to find the K-best leaves and then create a single pruned conditional.
HybridBayesNet HybridBayesNet::prune(
size_t maxNrLeaves, const std::optional<double> &deadModeThreshold) const {
size_t maxNrLeaves, const std::optional<double> &marginalThreshold,
DiscreteValues *fixedValues) const {
#if GTSAM_HYBRID_TIMING
gttic_(HybridPruning);
#endif
// Collect all the discrete conditionals. Could be small if already pruned.
const DiscreteBayesNet marginal = discreteMarginal();

// Prune discrete Bayes net
DiscreteValues fixed;
auto prunedBN = marginal.prune(maxNrLeaves, marginalThreshold, &fixed);

// Multiply into one big conditional. NOTE: possibly quite expensive.
DiscreteConditional joint;
for (auto &&conditional : marginal) {
joint = joint * (*conditional);
DiscreteConditional pruned;
for (auto &&conditional : prunedBN) pruned = pruned * (*conditional);

// Set the fixed values if requested.
if (marginalThreshold && fixedValues) {
*fixedValues = fixed;
}

// Initialize the resulting HybridBayesNet.
HybridBayesNet result;

// Prune the joint. NOTE: imperative and, again, possibly quite expensive.
DiscreteConditional pruned = joint;
pruned.prune(maxNrLeaves);

DiscreteValues deadModesValues;
// If we have a dead mode threshold and discrete variables left after pruning,
// then we run dead mode removal.
if (deadModeThreshold.has_value() && pruned.keys().size() > 0) {
DiscreteMarginals marginals(DiscreteFactorGraph{pruned});
for (auto dkey : pruned.discreteKeys()) {
Vector probabilities = marginals.marginalProbabilities(dkey);

int index = -1;
auto threshold = (probabilities.array() > *deadModeThreshold);
// If atleast 1 value is non-zero, then we can find the index
// Else if all are zero, index would be set to 0 which is incorrect
if (!threshold.isZero()) {
threshold.maxCoeff(&index);
}

if (index >= 0) {
deadModesValues.emplace(dkey.first, index);
}
}

// Remove the modes (imperative)
pruned.removeDiscreteModes(deadModesValues);

/*
If the pruned discrete conditional has any keys left,
we add it to the HybridBayesNet.
If not, it means it is an orphan so we don't add this pruned joint,
and instead add only the marginals below.
*/
if (pruned.keys().size() > 0) {
result.emplace_shared<DiscreteConditional>(pruned);
}
// Go through all the Gaussian conditionals, restrict them according to
// fixed values, and then prune further.
for (std::shared_ptr<gtsam::HybridConditional> conditional : *this) {
if (conditional->isDiscrete()) continue;

// Add the marginals for future factors
for (auto &&[key, _] : deadModesValues) {
result.push_back(
std::dynamic_pointer_cast<DiscreteConditional>(marginals(key)));
}
// No-op if not a HybridGaussianConditional.
if (marginalThreshold) conditional = conditional->restrict(fixed);

} else {
result.emplace_shared<DiscreteConditional>(pruned);
}

/* To prune, we visitWith every leaf in the HybridGaussianConditional.
* For each leaf, using the assignment we can check the discrete decision tree
* for 0.0 probability, then just set the leaf to a nullptr.
*
* We can later check the HybridGaussianConditional for just nullptrs.
*/

// Go through all the Gaussian conditionals in the Bayes Net and prune them as
// per pruned discrete joint.
for (auto &&conditional : *this) {
// Now decide on type what to do:
if (auto hgc = conditional->asHybrid()) {
// Prune the hybrid Gaussian conditional!
auto prunedHybridGaussianConditional = hgc->prune(pruned);

if (deadModeThreshold.has_value()) {
KeyVector deadKeys, conditionalDiscreteKeys;
for (const auto &kv : deadModesValues) {
deadKeys.push_back(kv.first);
}
for (auto dkey : prunedHybridGaussianConditional->discreteKeys()) {
conditionalDiscreteKeys.push_back(dkey.first);
}
// The discrete keys in the conditional are the same as the keys in the
// dead modes, then we just get the corresponding Gaussian conditional.
if (deadKeys == conditionalDiscreteKeys) {
result.push_back(
prunedHybridGaussianConditional->choose(deadModesValues));
} else {
// Add as-is
result.push_back(prunedHybridGaussianConditional);
}
} else {
// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
if (!prunedHybridGaussianConditional) {
throw std::runtime_error(
"A HybridGaussianConditional had all its conditionals pruned");
}

// Type-erase and add to the pruned Bayes Net fragment.
result.push_back(prunedHybridGaussianConditional);
} else if (auto gc = conditional->asGaussian()) {
// Add the non-HybridGaussianConditional conditional
result.push_back(gc);
}
// We ignore DiscreteConditional as they are already pruned and added.
} else
throw std::runtime_error(
"HybrdiBayesNet::prune: Unknown HybridConditional type.");
}

// Add the pruned discrete conditionals to the result.
for (const DiscreteConditional::shared_ptr &discrete : prunedBN)
result.push_back(discrete);

return result;
}

Expand Down
Loading

0 comments on commit 3cf1590

Please sign in to comment.