diff --git a/gtsam/discrete/DecisionTreeFactor.cpp b/gtsam/discrete/DecisionTreeFactor.cpp index e95b8fe374..acd4d4af2b 100644 --- a/gtsam/discrete/DecisionTreeFactor.cpp +++ b/gtsam/discrete/DecisionTreeFactor.cpp @@ -286,5 +286,39 @@ namespace gtsam { AlgebraicDecisionTree(keys, table), cardinalities_(keys.cardinalities()) {} + /* ************************************************************************ */ + DecisionTreeFactor DecisionTreeFactor::prune(size_t maxNrLeaves) const { + const size_t N = maxNrLeaves; + + // Get the probabilities in the decision tree so we can threshold. + std::vector probabilities; + this->visit([&](const double& prob) { probabilities.emplace_back(prob); }); + + // The number of probabilities can be lower than max_leaves + if (probabilities.size() <= N) { + return *this; + } + + std::sort(probabilities.begin(), probabilities.end(), + std::greater{}); + + double threshold = probabilities[N - 1]; + + // Now threshold the decision tree + size_t total = 0; + auto thresholdFunc = [threshold, &total, N](const double& value) { + if (value < threshold || total >= N) { + return 0.0; + } else { + total += 1; + return value; + } + }; + DecisionTree thresholded(*this, thresholdFunc); + + // Create pruned decision tree factor and return. + return DecisionTreeFactor(this->discreteKeys(), thresholded); + } + /* ************************************************************************ */ } // namespace gtsam diff --git a/gtsam/discrete/DecisionTreeFactor.h b/gtsam/discrete/DecisionTreeFactor.h index 91fa7c4849..1f3d692921 100644 --- a/gtsam/discrete/DecisionTreeFactor.h +++ b/gtsam/discrete/DecisionTreeFactor.h @@ -170,6 +170,18 @@ namespace gtsam { /// Return all the discrete keys associated with this factor. DiscreteKeys discreteKeys() const; + /** + * @brief Prune the decision tree of discrete variables. + * + * Pruning will set the leaves to be "pruned" to 0 indicating a 0 + * probability. + * A leaf is pruned if it is not in the top `maxNrLeaves` values. + * + * @param maxNrLeaves The maximum number of leaves to keep. + * @return DecisionTreeFactor + */ + DecisionTreeFactor prune(size_t maxNrLeaves) const; + /// @} /// @name Wrapper support /// @{ diff --git a/gtsam/discrete/tests/testDecisionTreeFactor.cpp b/gtsam/discrete/tests/testDecisionTreeFactor.cpp index 846653c383..83b586bbb2 100644 --- a/gtsam/discrete/tests/testDecisionTreeFactor.cpp +++ b/gtsam/discrete/tests/testDecisionTreeFactor.cpp @@ -106,6 +106,27 @@ TEST(DecisionTreeFactor, enumerate) { EXPECT(actual == expected); } +/* ************************************************************************* */ +// Check pruning of the decision tree works as expected. +TEST(DecisionTreeFactor, Prune) { + DiscreteKey A(1, 2), B(2, 2), C(3, 2); + DecisionTreeFactor f(A & B & C, "1 5 3 7 2 6 4 8"); + + // Only keep the leaves with the top 5 values. + size_t maxNrLeaves = 5; + auto pruned5 = f.prune(maxNrLeaves); + + // Pruned leaves should be 0 + DecisionTreeFactor expected(A & B & C, "0 5 0 7 0 6 4 8"); + EXPECT(assert_equal(expected, pruned5)); + + // Check for more extreme pruning where we only keep the top 2 leaves + maxNrLeaves = 2; + auto pruned2 = f.prune(maxNrLeaves); + DecisionTreeFactor expected2(A & B & C, "0 0 0 7 0 0 0 8"); + EXPECT(assert_equal(expected2, pruned2)); +} + /* ************************************************************************* */ TEST(DecisionTreeFactor, DotWithNames) { DiscreteKey A(12, 3), B(5, 2);