-
Notifications
You must be signed in to change notification settings - Fork 793
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1270 from borglab/hybrid/hybrid-optimize
Linear HybridBayesNet optimization
- Loading branch information
Showing
11 changed files
with
776 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file DiscreteLookupDAG.cpp | ||
* @date Aug, 2022 | ||
* @author Shangjie Xue | ||
*/ | ||
|
||
#include <gtsam/discrete/DiscreteBayesNet.h> | ||
#include <gtsam/discrete/DiscreteLookupDAG.h> | ||
#include <gtsam/discrete/DiscreteValues.h> | ||
#include <gtsam/hybrid/HybridBayesNet.h> | ||
#include <gtsam/hybrid/HybridConditional.h> | ||
#include <gtsam/hybrid/HybridLookupDAG.h> | ||
#include <gtsam/hybrid/HybridValues.h> | ||
#include <gtsam/linear/VectorValues.h> | ||
|
||
#include <string> | ||
#include <utility> | ||
|
||
using std::pair; | ||
using std::vector; | ||
|
||
namespace gtsam { | ||
|
||
/* ************************************************************************** */ | ||
void HybridLookupTable::argmaxInPlace(HybridValues* values) const { | ||
// For discrete conditional, uses argmaxInPlace() method in | ||
// DiscreteLookupTable. | ||
if (isDiscrete()) { | ||
boost::static_pointer_cast<DiscreteLookupTable>(inner_)->argmaxInPlace( | ||
&(values->discrete)); | ||
} else if (isContinuous()) { | ||
// For Gaussian conditional, uses solve() method in GaussianConditional. | ||
values->continuous.insert( | ||
boost::static_pointer_cast<GaussianConditional>(inner_)->solve( | ||
values->continuous)); | ||
} else if (isHybrid()) { | ||
// For hybrid conditional, since children should not contain discrete | ||
// variable, we can condition on the discrete variable in the parents and | ||
// solve the resulting GaussianConditional. | ||
auto conditional = | ||
boost::static_pointer_cast<GaussianMixture>(inner_)->conditionals()( | ||
values->discrete); | ||
values->continuous.insert(conditional->solve(values->continuous)); | ||
} | ||
} | ||
|
||
/* ************************************************************************** */ | ||
HybridLookupDAG HybridLookupDAG::FromBayesNet(const HybridBayesNet& bayesNet) { | ||
HybridLookupDAG dag; | ||
for (auto&& conditional : bayesNet) { | ||
HybridLookupTable hlt(*conditional); | ||
dag.push_back(hlt); | ||
} | ||
return dag; | ||
} | ||
|
||
/* ************************************************************************** */ | ||
HybridValues HybridLookupDAG::argmax(HybridValues result) const { | ||
// Argmax each node in turn in topological sort order (parents first). | ||
for (auto lookupTable : boost::adaptors::reverse(*this)) | ||
lookupTable->argmaxInPlace(&result); | ||
return result; | ||
} | ||
|
||
} // namespace gtsam |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file HybridLookupDAG.h | ||
* @date Aug, 2022 | ||
* @author Shangjie Xue | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <gtsam/discrete/DiscreteDistribution.h> | ||
#include <gtsam/discrete/DiscreteLookupDAG.h> | ||
#include <gtsam/hybrid/HybridConditional.h> | ||
#include <gtsam/hybrid/HybridValues.h> | ||
#include <gtsam/inference/BayesNet.h> | ||
#include <gtsam/inference/FactorGraph.h> | ||
|
||
#include <boost/shared_ptr.hpp> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* @brief HybridLookupTable table for max-product | ||
* | ||
* Similar to DiscreteLookupTable, inherits from hybrid conditional for | ||
* convenience. Is used in the max-product algorithm. | ||
*/ | ||
class GTSAM_EXPORT HybridLookupTable : public HybridConditional { | ||
public: | ||
using Base = HybridConditional; | ||
using This = HybridLookupTable; | ||
using shared_ptr = boost::shared_ptr<This>; | ||
using BaseConditional = Conditional<DecisionTreeFactor, This>; | ||
|
||
/** | ||
* @brief Construct a new Hybrid Lookup Table object form a HybridConditional. | ||
* | ||
* @param conditional input hybrid conditional | ||
*/ | ||
HybridLookupTable(HybridConditional& conditional) : Base(conditional){}; | ||
|
||
/** | ||
* @brief Calculate assignment for frontal variables that maximizes value. | ||
* @param (in/out) parentsValues Known assignments for the parents. | ||
*/ | ||
void argmaxInPlace(HybridValues* parentsValues) const; | ||
}; | ||
|
||
/** A DAG made from hybrid lookup tables, as defined above. Similar to | ||
* DiscreteLookupDAG */ | ||
class GTSAM_EXPORT HybridLookupDAG : public BayesNet<HybridLookupTable> { | ||
public: | ||
using Base = BayesNet<HybridLookupTable>; | ||
using This = HybridLookupDAG; | ||
using shared_ptr = boost::shared_ptr<This>; | ||
|
||
/// @name Standard Constructors | ||
/// @{ | ||
|
||
/// Construct empty DAG. | ||
HybridLookupDAG() {} | ||
|
||
/// Create from BayesNet with LookupTables | ||
static HybridLookupDAG FromBayesNet(const HybridBayesNet& bayesNet); | ||
|
||
/// Destructor | ||
virtual ~HybridLookupDAG() {} | ||
|
||
/// @} | ||
|
||
/// @name Standard Interface | ||
/// @{ | ||
|
||
/** Add a DiscreteLookupTable */ | ||
template <typename... Args> | ||
void add(Args&&... args) { | ||
emplace_shared<HybridLookupTable>(std::forward<Args>(args)...); | ||
} | ||
|
||
/** | ||
* @brief argmax by back-substitution, optionally given certain variables. | ||
* | ||
* Assumes the DAG is reverse topologically sorted, i.e. last | ||
* conditional will be optimized first *and* that the | ||
* DAG does not contain any conditionals for the given variables. If the DAG | ||
* resulted from eliminating a factor graph, this is true for the elimination | ||
* ordering. | ||
* | ||
* @return given assignment extended w. optimal assignment for all variables. | ||
*/ | ||
HybridValues argmax(HybridValues given = HybridValues()) const; | ||
/// @} | ||
|
||
private: | ||
/** Serialization function */ | ||
friend class boost::serialization::access; | ||
template <class ARCHIVE> | ||
void serialize(ARCHIVE& ar, const unsigned int /*version*/) { | ||
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(Base); | ||
} | ||
}; | ||
|
||
// traits | ||
template <> | ||
struct traits<HybridLookupDAG> : public Testable<HybridLookupDAG> {}; | ||
|
||
} // namespace gtsam |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
/* ---------------------------------------------------------------------------- | ||
* GTSAM Copyright 2010, Georgia Tech Research Corporation, | ||
* Atlanta, Georgia 30332-0415 | ||
* All Rights Reserved | ||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list) | ||
* See LICENSE for the license information | ||
* -------------------------------------------------------------------------- */ | ||
|
||
/** | ||
* @file HybridValues.h | ||
* @date Jul 28, 2022 | ||
* @author Shangjie Xue | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <gtsam/discrete/Assignment.h> | ||
#include <gtsam/discrete/DiscreteKey.h> | ||
#include <gtsam/discrete/DiscreteValues.h> | ||
#include <gtsam/inference/Key.h> | ||
#include <gtsam/linear/VectorValues.h> | ||
#include <gtsam/nonlinear/Values.h> | ||
|
||
#include <map> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace gtsam { | ||
|
||
/** | ||
* HybridValues represents a collection of DiscreteValues and VectorValues. It | ||
* is typically used to store the variables of a HybridGaussianFactorGraph. | ||
* Optimizing a HybridGaussianBayesNet returns this class. | ||
*/ | ||
class GTSAM_EXPORT HybridValues { | ||
public: | ||
// DiscreteValue stored the discrete components of the HybridValues. | ||
DiscreteValues discrete; | ||
|
||
// VectorValue stored the continuous components of the HybridValues. | ||
VectorValues continuous; | ||
|
||
// Default constructor creates an empty HybridValues. | ||
HybridValues() : discrete(), continuous(){}; | ||
|
||
// Construct from DiscreteValues and VectorValues. | ||
HybridValues(const DiscreteValues& dv, const VectorValues& cv) | ||
: discrete(dv), continuous(cv){}; | ||
|
||
// print required by Testable for unit testing | ||
void print(const std::string& s = "HybridValues", | ||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { | ||
std::cout << s << ": \n"; | ||
discrete.print(" Discrete", keyFormatter); // print discrete components | ||
continuous.print(" Continuous", | ||
keyFormatter); // print continuous components | ||
}; | ||
|
||
// equals required by Testable for unit testing | ||
bool equals(const HybridValues& other, double tol = 1e-9) const { | ||
return discrete.equals(other.discrete, tol) && | ||
continuous.equals(other.continuous, tol); | ||
} | ||
|
||
// Check whether a variable with key \c j exists in DiscreteValue. | ||
bool existsDiscrete(Key j) { return (discrete.find(j) != discrete.end()); }; | ||
|
||
// Check whether a variable with key \c j exists in VectorValue. | ||
bool existsVector(Key j) { return continuous.exists(j); }; | ||
|
||
// Check whether a variable with key \c j exists. | ||
bool exists(Key j) { return existsDiscrete(j) || existsVector(j); }; | ||
|
||
/** Insert a discrete \c value with key \c j. Replaces the existing value if | ||
* the key \c j is already used. | ||
* @param value The vector to be inserted. | ||
* @param j The index with which the value will be associated. */ | ||
void insert(Key j, int value) { discrete[j] = value; }; | ||
|
||
/** Insert a vector \c value with key \c j. Throws an invalid_argument | ||
* exception if the key \c j is already used. | ||
* @param value The vector to be inserted. | ||
* @param j The index with which the value will be associated. */ | ||
void insert(Key j, const Vector& value) { continuous.insert(j, value); } | ||
|
||
// TODO(Shangjie)- update() and insert_or_assign() , similar to Values.h | ||
|
||
/** | ||
* Read/write access to the discrete value with key \c j, throws | ||
* std::out_of_range if \c j does not exist. | ||
*/ | ||
size_t& atDiscrete(Key j) { return discrete.at(j); }; | ||
|
||
/** | ||
* Read/write access to the vector value with key \c j, throws | ||
* std::out_of_range if \c j does not exist. | ||
*/ | ||
Vector& at(Key j) { return continuous.at(j); }; | ||
|
||
/// @name Wrapper support | ||
/// @{ | ||
|
||
/** | ||
* @brief Output as a html table. | ||
* | ||
* @param keyFormatter function that formats keys. | ||
* @return string html output. | ||
*/ | ||
std::string html( | ||
const KeyFormatter& keyFormatter = DefaultKeyFormatter) const { | ||
std::stringstream ss; | ||
ss << this->discrete.html(keyFormatter); | ||
ss << this->continuous.html(keyFormatter); | ||
return ss.str(); | ||
}; | ||
|
||
/// @} | ||
}; | ||
|
||
// traits | ||
template <> | ||
struct traits<HybridValues> : public Testable<HybridValues> {}; | ||
|
||
} // namespace gtsam |
Oops, something went wrong.