Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hybrid Bayes Net/Tree Optimize #1280

Merged
merged 7 commits into from
Aug 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions gtsam/hybrid/HybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridLookupDAG.h>
#include <gtsam/hybrid/HybridValues.h>

namespace gtsam {

Expand Down Expand Up @@ -112,22 +112,27 @@ HybridBayesNet HybridBayesNet::prune(

/* ************************************************************************* */
GaussianMixture::shared_ptr HybridBayesNet::atGaussian(size_t i) const {
return boost::dynamic_pointer_cast<GaussianMixture>(factors_.at(i)->inner());
return factors_.at(i)->asMixture();
}

/* ************************************************************************* */
DiscreteConditional::shared_ptr HybridBayesNet::atDiscrete(size_t i) const {
return boost::dynamic_pointer_cast<DiscreteConditional>(
factors_.at(i)->inner());
return factors_.at(i)->asDiscreteConditional();
}

/* ************************************************************************* */
GaussianBayesNet HybridBayesNet::choose(
const DiscreteValues &assignment) const {
GaussianBayesNet gbn;
for (size_t idx = 0; idx < size(); idx++) {
GaussianMixture gm = *this->atGaussian(idx);
gbn.push_back(gm(assignment));
try {
GaussianMixture gm = *this->atGaussian(idx);
gbn.push_back(gm(assignment));

} catch (std::exception &exc) {
// if factor at `idx` is discrete-only, just continue.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert that idx is indeed discrete only

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made the assertion in #1282

continue;
}
}
return gbn;
}
Expand All @@ -138,4 +143,10 @@ HybridValues HybridBayesNet::optimize() const {
return dag.argmax();
}

/* *******************************************************************************/
VectorValues HybridBayesNet::optimize(const DiscreteValues &assignment) const {
GaussianBayesNet gbn = this->choose(assignment);
return gbn.optimize();
}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesNet.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
/// TODO(Shangjie) do we need to create a HybridGaussianBayesNet class, and
/// put this method there?
HybridValues optimize() const;

/**
* @brief Given the discrete assignment, return the optimized estimate for the
* selected Gaussian BayesNet.
*
* @param assignment An assignment of discrete values.
* @return Values
*/
VectorValues optimize(const DiscreteValues &assignment) const;
};

} // namespace gtsam
44 changes: 44 additions & 0 deletions gtsam/hybrid/HybridBayesTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,48 @@ bool HybridBayesTree::equals(const This& other, double tol) const {
return Base::equals(other, tol);
}

/* ************************************************************************* */
VectorValues HybridBayesTree::optimize(const DiscreteValues& assignment) const {
GaussianBayesNet gbn;

KeyVector added_keys;

// Iterate over all the nodes in the BayesTree
for (auto&& node : nodes()) {
// Check if conditional being added is already in the Bayes net.
if (std::find(added_keys.begin(), added_keys.end(), node.first) ==
added_keys.end()) {
// Access the clique and get the underlying hybrid conditional
HybridBayesTreeClique::shared_ptr clique = node.second;
HybridConditional::shared_ptr conditional = clique->conditional();

KeyVector frontals(conditional->frontals().begin(),
conditional->frontals().end());

// Record the key being added
added_keys.insert(added_keys.end(), frontals.begin(), frontals.end());

// If conditional is hybrid (and not discrete-only), we get the Gaussian
// Conditional corresponding to the assignment and add it to the Gaussian
// Bayes Net.
if (conditional->isHybrid()) {
auto gm = conditional->asMixture();
GaussianConditional::shared_ptr gaussian_conditional =
(*gm)(assignment);

gbn.push_back(gaussian_conditional);
}
}
}
// If TBB is enabled, the bayes net order gets reversed,
// so we pre-reverse it
#ifdef GTSAM_USE_TBB
auto reversed = boost::adaptors::reverse(gbn);
gbn = GaussianBayesNet(reversed.begin(), reversed.end());
#endif

// Return the optimized bayes net.
return gbn.optimize();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean that gbn.optimize does not work when the bayes net has non sequential order?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The GBN needs to be topologically sorted so that it optimizes parents before children, and for some reason, the GBN is reverse topologically sorted only when TBB is enabled.

}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridBayesTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ class GTSAM_EXPORT HybridBayesTree : public BayesTree<HybridBayesTreeClique> {
/** Check equality */
bool equals(const This& other, double tol = 1e-9) const;

/**
* @brief Recursively optimize the BayesTree to produce a vector solution.
*
* @param assignment The discrete values assignment to select the Gaussian
* mixtures.
* @return VectorValues
*/
VectorValues optimize(const DiscreteValues& assignment) const;

/// @}
};

Expand Down
9 changes: 4 additions & 5 deletions gtsam/hybrid/HybridConditional.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridConditional
BaseConditional; ///< Typedef to our conditional base class

protected:
// Type-erased pointer to the inner type
/// Type-erased pointer to the inner type
boost::shared_ptr<Factor> inner_;

public:
Expand Down Expand Up @@ -127,8 +127,7 @@ class GTSAM_EXPORT HybridConditional
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional.
*/
HybridConditional(
boost::shared_ptr<GaussianMixture> gaussianMixture);
HybridConditional(boost::shared_ptr<GaussianMixture> gaussianMixture);

/**
* @brief Return HybridConditional as a GaussianMixture
Expand Down Expand Up @@ -168,10 +167,10 @@ class GTSAM_EXPORT HybridConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }

}; // DiscreteConditional
}; // HybridConditional

// traits
template <>
struct traits<HybridConditional> : public Testable<DiscreteConditional> {};
struct traits<HybridConditional> : public Testable<HybridConditional> {};

} // namespace gtsam
22 changes: 19 additions & 3 deletions gtsam/hybrid/HybridGaussianFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ continuousElimination(const HybridGaussianFactorGraph &factors,
for (auto &fp : factors) {
if (auto ptr = boost::dynamic_pointer_cast<HybridGaussianFactor>(fp)) {
gfg.push_back(ptr->inner());
} else if (auto p =
boost::static_pointer_cast<HybridConditional>(fp)->inner()) {
gfg.push_back(boost::static_pointer_cast<GaussianConditional>(p));
} else if (auto ptr = boost::static_pointer_cast<HybridConditional>(fp)) {
gfg.push_back(
boost::static_pointer_cast<GaussianConditional>(ptr->inner()));
} else {
// It is an orphan wrapped conditional
}
Expand Down Expand Up @@ -401,4 +401,20 @@ void HybridGaussianFactorGraph::add(DecisionTreeFactor::shared_ptr factor) {
FactorGraph::add(boost::make_shared<HybridDiscreteFactor>(factor));
}

/* ************************************************************************ */
const Ordering HybridGaussianFactorGraph::getHybridOrdering(
OptionalOrderingType orderingType) const {
KeySet discrete_keys;
for (auto &factor : factors_) {
for (const DiscreteKey &k : factor->discreteKeys()) {
discrete_keys.insert(k.first);
}
}

const VariableIndex index(factors_);
Ordering ordering = Ordering::ColamdConstrainedLast(
index, KeyVector(discrete_keys.begin(), discrete_keys.end()), true);
return ordering;
}

} // namespace gtsam
9 changes: 9 additions & 0 deletions gtsam/hybrid/HybridGaussianFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,15 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
Base::push_back(sharedFactor);
}
}

/**
* @brief
*
* @param orderingType
* @return const Ordering
*/
const Ordering getHybridOrdering(
OptionalOrderingType orderingType = boost::none) const;
};

} // namespace gtsam
4 changes: 2 additions & 2 deletions gtsam/hybrid/tests/Switching.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ struct Switching {
// Add "motion models".
for (size_t k = 1; k < K; k++) {
KeyVector keys = {X(k), X(k + 1)};
auto motion_models = motionModels(k);
auto motion_models = motionModels(k, between_sigma);
std::vector<NonlinearFactor::shared_ptr> components;
for (auto &&f : motion_models) {
components.push_back(boost::dynamic_pointer_cast<NonlinearFactor>(f));
Expand All @@ -155,7 +155,7 @@ struct Switching {
}

// Add measurement factors
auto measurement_noise = noiseModel::Isotropic::Sigma(1, 0.1);
auto measurement_noise = noiseModel::Isotropic::Sigma(1, prior_sigma);
for (size_t k = 2; k <= K; k++) {
nonlinearFactorGraph.emplace_nonlinear<PriorFactor<double>>(
X(k), 1.0 * (k - 1), measurement_noise);
Expand Down
71 changes: 71 additions & 0 deletions gtsam/hybrid/tests/testHybridBayesNet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/

#include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/nonlinear/NonlinearFactorGraph.h>

#include "Switching.h"

Expand Down Expand Up @@ -85,6 +86,76 @@ TEST(HybridBayesNet, Choose) {
*gbn.at(3)));
}

/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, OptimizeAssignment) {
Switching s(4);

Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}

HybridBayesNet::shared_ptr hybridBayesNet;
HybridGaussianFactorGraph::shared_ptr remainingFactorGraph;
std::tie(hybridBayesNet, remainingFactorGraph) =
s.linearizedFactorGraph.eliminatePartialSequential(ordering);

DiscreteValues assignment;
assignment[M(1)] = 1;
assignment[M(2)] = 1;
assignment[M(3)] = 1;

VectorValues delta = hybridBayesNet->optimize(assignment);

// The linearization point has the same value as the key index,
// e.g. X(1) = 1, X(2) = 2,
// but the factors specify X(k) = k-1, so delta should be -1.
VectorValues expected_delta;
expected_delta.insert(make_pair(X(1), -Vector1::Ones()));
expected_delta.insert(make_pair(X(2), -Vector1::Ones()));
expected_delta.insert(make_pair(X(3), -Vector1::Ones()));
expected_delta.insert(make_pair(X(4), -Vector1::Ones()));

EXPECT(assert_equal(expected_delta, delta));
}

/* ****************************************************************************/
// Test bayes net optimize
TEST(HybridBayesNet, Optimize) {
Switching s(4);

Ordering ordering;
for (auto&& kvp : s.linearizationPoint) {
ordering += kvp.key;
}

Ordering hybridOrdering = s.linearizedFactorGraph.getHybridOrdering();
HybridBayesNet::shared_ptr hybridBayesNet =
s.linearizedFactorGraph.eliminateSequential(hybridOrdering);

HybridValues delta = hybridBayesNet->optimize();

delta.print();
VectorValues correct;
correct.insert(X(1), 0 * Vector1::Ones());
correct.insert(X(2), 1 * Vector1::Ones());
correct.insert(X(3), 2 * Vector1::Ones());
correct.insert(X(4), 3 * Vector1::Ones());

DiscreteValues assignment111;
assignment111[M(1)] = 1;
assignment111[M(2)] = 1;
assignment111[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment111).error(correct) << std::endl;

DiscreteValues assignment101;
assignment101[M(1)] = 1;
assignment101[M(2)] = 0;
assignment101[M(3)] = 1;
std::cout << hybridBayesNet->choose(assignment101).error(correct) << std::endl;
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down
Loading