Skip to content

Commit

Permalink
Merge pull request #1270 from borglab/hybrid/hybrid-optimize
Browse files Browse the repository at this point in the history
Linear HybridBayesNet optimization
  • Loading branch information
xsj01 authored Aug 22, 2022
2 parents f124ccc + c4184e1 commit ef066a0
Show file tree
Hide file tree
Showing 11 changed files with 776 additions and 0 deletions.
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
* @brief A bayes net of Gaussian Conditionals indexed by discrete keys.
* @author Fan Jiang
* @author Varun Agrawal
* @author Shangjie Xue
* @date January 2022
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>

namespace gtsam {

Expand All @@ -40,4 +43,10 @@ GaussianBayesNet HybridBayesNet::choose(
return gbn;
}

/* *******************************************************************************/
HybridValues HybridBayesNet::optimize() const {
auto dag = HybridLookupDAG::FromBayesNet(*this);
return dag.argmax();
}

} // namespace gtsam
6 changes: 6 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#pragma once

#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/linear/GaussianBayesNet.h>

Expand Down Expand Up @@ -61,6 +62,11 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
* @return GaussianBayesNet
*/
GaussianBayesNet choose(const DiscreteValues &assignment) const;

/// Solve the HybridBayesNet by back-substitution.
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;
};

} // namespace gtsam
76 changes: 76 additions & 0 deletions gtsam/hybrid/HybridLookupDAG.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file DiscreteLookupDAG.cpp
* @date Aug, 2022
* @author Shangjie Xue
*/

#include <gtsam/discrete/DiscreteBayesNet.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/VectorValues.h>

#include <string>
#include <utility>

using std::pair;
using std::vector;

namespace gtsam {

/* ************************************************************************** */
void HybridLookupTable::argmaxInPlace(HybridValues* values) const {
// For discrete conditional, uses argmaxInPlace() method in
// DiscreteLookupTable.
if (isDiscrete()) {
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace(
&(values->discrete));
} else if (isContinuous()) {
// For Gaussian conditional, uses solve() method in GaussianConditional.
values->continuous.insert(
boost::static_pointer_cast<GaussianConditional>(inner_)->solve(
values->continuous));
} else if (isHybrid()) {
// For hybrid conditional, since children should not contain discrete
// variable, we can condition on the discrete variable in the parents and
// solve the resulting GaussianConditional.
auto conditional =
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()(
values->discrete);
values->continuous.insert(conditional->solve(values->continuous));
}
}

/* ************************************************************************** */
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) {
HybridLookupDAG dag;
for (auto&& conditional : bayesNet) {
HybridLookupTable hlt(*conditional);
dag.push_back(hlt);
}
return dag;
}

/* ************************************************************************** */
HybridValues HybridLookupDAG::argmax(HybridValues result) const {
// Argmax each node in turn in topological sort order (parents first).
for (auto lookupTable : boost::adaptors::reverse(*this))
lookupTable->argmaxInPlace(&result);
return result;
}

} // namespace gtsam
119 changes: 119 additions & 0 deletions gtsam/hybrid/HybridLookupDAG.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridLookupDAG.h
* @date Aug, 2022
* @author Shangjie Xue
*/

#pragma once

#include <gtsam/discrete/DiscreteDistribution.h>
#include <gtsam/discrete/DiscreteLookupDAG.h>
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
#include <gtsam/inference/FactorGraph.h>

#include <boost/shared_ptr.hpp>
#include <string>
#include <utility>
#include <vector>

namespace gtsam {

/**
* @brief HybridLookupTable table for max-product
*
* Similar to DiscreteLookupTable, inherits from hybrid conditional for
* convenience. Is used in the max-product algorithm.
*/
class GTSAM_EXPORT HybridLookupTable : public HybridConditional {
public:
using Base = HybridConditional;
using This = HybridLookupTable;
using shared_ptr = boost::shared_ptr<This>;
using BaseConditional = Conditional<DecisionTreeFactor, This>;

/**
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional.
*
* @param conditional input hybrid conditional
*/
HybridLookupTable(HybridConditional& conditional) : Base(conditional){};

/**
* @brief Calculate assignment for frontal variables that maximizes value.
* @param (in/out) parentsValues Known assignments for the parents.
*/
void argmaxInPlace(HybridValues* parentsValues) const;
};

/** A DAG made from hybrid lookup tables, as defined above. Similar to
* DiscreteLookupDAG */
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> {
public:
using Base = BayesNet<HybridLookupTable>;
using This = HybridLookupDAG;
using shared_ptr = boost::shared_ptr<This>;

/// @name Standard Constructors
/// @{

/// Construct empty DAG.
HybridLookupDAG() {}

/// Create from BayesNet with LookupTables
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet);

/// Destructor
virtual ~HybridLookupDAG() {}

/// @}

/// @name Standard Interface
/// @{

/** Add a DiscreteLookupTable */
template <typename... Args>
void add(Args&&... args) {
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...);
}

/**
* @brief argmax by back-substitution, optionally given certain variables.
*
* Assumes the DAG is reverse topologically sorted, i.e. last
* conditional will be optimized first *and* that the
* DAG does not contain any conditionals for the given variables. If the DAG
* resulted from eliminating a factor graph, this is true for the elimination
* ordering.
*
* @return given assignment extended w. optimal assignment for all variables.
*/
HybridValues argmax(HybridValues given = HybridValues()) const;
/// @}

private:
/** Serialization function */
friend class boost::serialization::access;
template <class ARCHIVE>
void serialize(ARCHIVE& ar, const unsigned int /*version*/) {
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base);
}
};

// traits
template <>
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {};

} // namespace gtsam
127 changes: 127 additions & 0 deletions gtsam/hybrid/HybridValues.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */

/**
* @file HybridValues.h
* @date Jul 28, 2022
* @author Shangjie Xue
*/

#pragma once

#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/VectorValues.h>
#include <gtsam/nonlinear/Values.h>

#include <map>
#include <string>
#include <vector>

namespace gtsam {

/**
* HybridValues represents a collection of DiscreteValues and VectorValues. It
* is typically used to store the variables of a HybridGaussianFactorGraph.
* Optimizing a HybridGaussianBayesNet returns this class.
*/
class GTSAM_EXPORT HybridValues {
public:
// DiscreteValue stored the discrete components of the HybridValues.
DiscreteValues discrete;

// VectorValue stored the continuous components of the HybridValues.
VectorValues continuous;

// Default constructor creates an empty HybridValues.
HybridValues() : discrete(), continuous(){};

// Construct from DiscreteValues and VectorValues.
HybridValues(const DiscreteValues& dv, const VectorValues& cv)
: discrete(dv), continuous(cv){};

// print required by Testable for unit testing
void print(const std::string& s = "HybridValues",
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::cout << s << ": \n";
discrete.print(" Discrete", keyFormatter); // print discrete components
continuous.print(" Continuous",
keyFormatter); // print continuous components
};

// equals required by Testable for unit testing
bool equals(const HybridValues& other, double tol = 1e-9) const {
return discrete.equals(other.discrete, tol) &&
continuous.equals(other.continuous, tol);
}

// Check whether a variable with key \c j exists in DiscreteValue.
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); };

// Check whether a variable with key \c j exists in VectorValue.
bool existsVector(Key j) { return continuous.exists(j); };

// Check whether a variable with key \c j exists.
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); };

/** Insert a discrete \c value with key \c j. Replaces the existing value if
* the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, int value) { discrete[j] = value; };

/** Insert a vector \c value with key \c j. Throws an invalid_argument
* exception if the key \c j is already used.
* @param value The vector to be inserted.
* @param j The index with which the value will be associated. */
void insert(Key j, const Vector& value) { continuous.insert(j, value); }

// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h

/**
* Read/write access to the discrete value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
size_t& atDiscrete(Key j) { return discrete.at(j); };

/**
* Read/write access to the vector value with key \c j, throws
* std::out_of_range if \c j does not exist.
*/
Vector& at(Key j) { return continuous.at(j); };

/// @name Wrapper support
/// @{

/**
* @brief Output as a html table.
*
* @param keyFormatter function that formats keys.
* @return string html output.
*/
std::string html(
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const {
std::stringstream ss;
ss << this->discrete.html(keyFormatter);
ss << this->continuous.html(keyFormatter);
return ss.str();
};

/// @}
};

// traits
template <>
struct traits<HybridValues> : public Testable<HybridValues> {};

} // namespace gtsam
Loading

0 comments on commit ef066a0

Please sign in to comment.