Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply DecisionTree op with Assignment #1137

Merged
merged 1 commit into from
Mar 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions gtsam/discrete/DecisionTree-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,13 @@ namespace gtsam {
return f;
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
NodePtr f(new Leaf(op(choices, constant_)));
return f;
}

// Apply binary operator "h = f op g" on Leaf node
// Note op is not assumed commutative so we need to keep track of order
// Simply calls apply on argument to call correct virtual method:
Expand Down Expand Up @@ -322,12 +329,48 @@ namespace gtsam {
for (const NodePtr& branch : f.branches_) push_back(branch->apply(op));
}

/**
* @brief Constructor which accepts a UnaryAssignment op and the
* corresponding assignment.
*
* @param label The label for this node.
* @param f The original choice node to apply the op on.
* @param op Function to apply on the choice node. Takes Assignment and
* value as arguments.
* @param choices The Assignment that will go to op.
*/
Choice(const L& label, const Choice& f, const UnaryAssignment& op,
const Assignment<L>& choices)
: label_(label), allSame_(true) {
branches_.reserve(f.branches_.size()); // reserve space

Assignment<L> choices_ = choices;

for (size_t i = 0; i < f.branches_.size(); i++) {
choices_[label_] = i; // Set assignment for label to i

const NodePtr branch = f.branches_[i];
push_back(branch->apply(op, choices_));

// Remove the choice so we are backtracking
auto choice_it = choices_.find(label_);
choices_.erase(choice_it);
}
}

/** apply unary operator */
NodePtr apply(const Unary& op) const override {
auto r = boost::make_shared<Choice>(label_, *this, op);
return Unique(r);
}

/// Apply unary operator with assignment
NodePtr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const override {
auto r = boost::make_shared<Choice>(label_, *this, op, choices);
return Unique(r);
}

// Apply binary operator "h = f op g" on Choice node
// Note op is not assumed commutative so we need to keep track of order
// Simply calls apply on argument to call correct virtual method:
Expand Down Expand Up @@ -739,6 +782,20 @@ namespace gtsam {
return DecisionTree(root_->apply(op));
}

/// Apply unary operator with assignment
template <typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(
const UnaryAssignment& op) const {
std::cout << "Calling the correct apply" << std::endl;
// It is unclear what should happen if tree is empty:
if (empty()) {
throw std::runtime_error(
"DecisionTree::apply(unary op) undefined for empty tree.");
}
Assignment<L> choices;
return DecisionTree(root_->apply(op, choices));
}

/****************************************************************************/
template<typename L, typename Y>
DecisionTree<L, Y> DecisionTree<L, Y>::apply(const DecisionTree& g,
Expand Down
20 changes: 20 additions & 0 deletions gtsam/discrete/DecisionTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ namespace gtsam {

/** Handy typedefs for unary and binary function types */
using Unary = std::function<Y(const Y&)>;
using UnaryAssignment = std::function<Y(const Assignment<L>&, const Y&)>;
using Binary = std::function<Y(const Y&, const Y&)>;

/** A label annotated with cardinality */
Expand Down Expand Up @@ -103,6 +104,8 @@ namespace gtsam {
&DefaultCompare) const = 0;
virtual const Y& operator()(const Assignment<L>& x) const = 0;
virtual Ptr apply(const Unary& op) const = 0;
virtual Ptr apply(const UnaryAssignment& op,
const Assignment<L>& choices) const = 0;
virtual Ptr apply_f_op_g(const Node&, const Binary&) const = 0;
virtual Ptr apply_g_op_fL(const Leaf&, const Binary&) const = 0;
virtual Ptr apply_g_op_fC(const Choice&, const Binary&) const = 0;
Expand Down Expand Up @@ -283,6 +286,16 @@ namespace gtsam {
/** apply Unary operation "op" to f */
DecisionTree apply(const Unary& op) const;

/**
* @brief Apply Unary operation "op" to f while also providing the
* corresponding assignment.
*
* @param op Function which takes Assignment<L> and Y as input and returns
* object of type Y.
* @return DecisionTree
*/
DecisionTree apply(const UnaryAssignment& op) const;

/** apply binary operation "op" to f and g */
DecisionTree apply(const DecisionTree& g, const Binary& op) const;

Expand Down Expand Up @@ -337,6 +350,13 @@ namespace gtsam {
return f.apply(op);
}

/// Apply unary operator `op` with Assignment to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
const typename DecisionTree<L, Y>::UnaryAssignment& op) {
return f.apply(op);
}

/// Apply binary operator `op` to DecisionTree `f`.
template<typename L, typename Y>
DecisionTree<L, Y> apply(const DecisionTree<L, Y>& f,
Expand Down
28 changes: 28 additions & 0 deletions gtsam/discrete/tests/testDecisionTree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ struct DT : public DecisionTree<string, int> {
auto valueFormatter = [](const int& v) {
return (boost::format("%d") % v).str();
};
std::cout << s;
Base::print("", keyFormatter, valueFormatter);
}
/// Equality method customized to int node type
Expand Down Expand Up @@ -451,6 +452,33 @@ TEST(DecisionTree, threshold) {
EXPECT_LONGS_EQUAL(2, thresholded.fold(count, 0));
}

/* ************************************************************************** */
// Test apply with assignment.
TEST(DecisionTree, ApplyWithAssignment) {
// Create three level tree
vector<DT::LabelC> keys;
keys += DT::LabelC("C", 2), DT::LabelC("B", 2), DT::LabelC("A", 2);
DT tree(keys, "1 2 3 4 5 6 7 8");

DecisionTree<string, double> probTree(
keys, "0.01 0.02 0.03 0.04 0.05 0.06 0.07 0.08");
double threshold = 0.035;

// We test pruning one tree by indexing into another.
auto pruner = [&](const Assignment<string>& choices, const int& x) {
// Prune out all the leaves with even numbers
if (probTree(choices) < threshold) {
return 0;
} else {
return x;
}
};
DT prunedTree = tree.apply(pruner);

DT expectedTree(keys, "0 0 0 4 5 6 7 8");
EXPECT(assert_equal(expectedTree, prunedTree));
}

/* ************************************************************************* */
int main() {
TestResult tr;
Expand Down