Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Set Membership in TreeEnsemble #21222

Closed
wants to merge 11 commits into from
Closed
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/cpu/ml/ml_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ enum NODE_MODE : uint8_t {
BRANCH_GTE = 6,
BRANCH_GT = 8,
BRANCH_EQ = 10,
BRANCH_NEQ = 12
BRANCH_NEQ = 12,
BRANCH_MEMBER = 14
};

static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
Expand All @@ -49,6 +50,9 @@ static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
if (input == "BRANCH_EQ") {
return NODE_MODE::BRANCH_EQ;
}
if (input == "BRANCH_MEMBER") {
return NODE_MODE::BRANCH_MEMBER;
}
return NODE_MODE::BRANCH_NEQ;
}

Expand Down
176 changes: 154 additions & 22 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,48 @@
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,

Check warning on line 90 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:90: Lines should be <= 120 characters long [whitespace/line_length] [2]
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,

Check warning on line 91 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:91: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,

Check warning on line 92 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:92: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,

Check warning on line 93 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:93: Lines should be <= 120 characters long [whitespace/line_length] [2]
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);

Check warning on line 94 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:94: Lines should be <= 120 characters long [whitespace/line_length] [2]
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,

Check warning on line 99 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:99: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);

Check warning on line 100 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:100: Lines should be <= 120 characters long [whitespace/line_length] [2]
};

// Below is simple implementation of `bit_cast` as it is supported from c++20 and the current supported version is c++17
// Remove it when that is not the case
template <class To, class From>
std::enable_if_t<
sizeof(To) == sizeof(From) &&
std::is_trivially_copyable_v<From> &&
std::is_trivially_copyable_v<To>,
To>
// constexpr support needs compiler magic
static bit_cast(const From& src) noexcept {
static_assert(std::is_trivially_constructible_v<To>,
"This implementation additionally requires "
"destination type to be trivially constructible");

To dst;
std::memcpy(&dst, &src, sizeof(To));
return dst;
}

template <typename T>
std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t> bit_cast_int(T val) {
if constexpr (sizeof(T) == sizeof(uint32_t)) {
return bit_cast<uint32_t>(val);
} else if constexpr (sizeof(T) == sizeof(uint64_t)) {
return bit_cast<uint64_t>(val);
}
static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t));
}

template <typename InputType, typename ThresholdType, typename OutputType>
Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
Expand Down Expand Up @@ -270,6 +305,16 @@
}
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i);
}

std::sort(indices.begin(), indices.end());

// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping will translates the old position of each node to the new node position in nodes_.
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
Expand All @@ -280,26 +325,13 @@
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
}

n_trees_ = roots_.size();
if (((int64_t)nodes_.size()) != n_nodes_) {
ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());

TreeNodeElementId ind;
SparseValue<ThresholdType> w;
Expand Down Expand Up @@ -341,13 +373,56 @@
return Status::OK();
}

template <typename InputType, typename ThresholdType, typename OutputType>
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,

Check warning on line 379 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:379: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {

Check warning on line 382 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:382: Lines should be <= 120 characters long [whitespace/line_length] [2]
// Leaves have values set at 0
if (cmodes[left_id] != cmodes[right_id] || nodes_featureids[left_id] != nodes_featureids[right_id] || (!nodes_values_as_tensor.empty() && nodes_values_as_tensor[left_id] != nodes_values_as_tensor[right_id]) || (nodes_values_as_tensor.empty() && node_values[left_id] != node_values[right_id])) {

Check warning on line 384 in onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h:384: Lines should be <= 120 characters long [whitespace/line_length] [2]
return false;
}

if (cmodes[left_id] == NODE_MODE::LEAF) {
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second;
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second;

if (target_class_weights_as_tensor.empty()) {
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
} else {
return target_class_weights_as_tensor[left_target_node] == target_class_weights_as_tensor[right_target_node];
}
}

return CheckIfSubtreesAreEqual(falsenode_ids[left_id], falsenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices) &&
CheckIfSubtreesAreEqual(truenode_ids[left_id], truenode_ids[right_id], tree_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids,
nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices);
}

inline void UpdateThreshold(double val, double& mask) {
uint64_t new_mask = bit_cast<uint64_t>(mask) | (1ll << (static_cast<uint32_t>(val) - 1));
mask = bit_cast<double>(new_mask);
}

inline void UpdateThreshold(float val, float& mask) {
uint32_t new_mask = bit_cast<uint32_t>(mask) | (1 << (static_cast<uint32_t>(val) - 1));
mask = bit_cast<float>(new_mask);
}

#define BITCOUNT(T) int64_t(sizeof(T) * 8)
#define CANMASK(v, T) (v >= 1 && v <= BITCOUNT(T)) && v == std::floor(v)

template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
Expand All @@ -369,23 +444,54 @@
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];

node.value_or_unique_weight = 0;
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_MEMBER;
} else {
node.value_or_unique_weight = node_threshold;
}

if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
size_t falsenode_id = falsenode_ids[i];

// Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain
// Below we are folding together these nodes into one of mode `BRANCH_MEMBER`
// The threshold of this node should be interpreted as a bitmask showing which categoricals values were found in the chain
// Afterwards, when looking whether a feature is included we can do an `and` with the mask of the node
// and the one of the feature (the mask has only one bit set on the place for its value)
// Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_MEMBER) {
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
CANMASK(falsenode_threshold, ThresholdType) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {
UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight);
falsenode_id = falsenode_ids[falsenode_id];
falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
}
}

size_t false_branch =
AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch =
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
Expand Down Expand Up @@ -684,6 +790,13 @@
} \
}

// Check whether the feature value is set true in the mask
template <typename T1, typename T2>
inline bool SetMembershipCheck(T1 val, T2 mask) {
const int64_t val_as_int = static_cast<int64_t>(val);
return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0);
}

inline bool _isnan_(float x) { return std::isnan(x); }
inline bool _isnan_(double x) { return std::isnan(x); }
inline bool _isnan_(int64_t) { return false; }
Expand Down Expand Up @@ -726,6 +839,20 @@
case NODE_MODE::BRANCH_NEQ:
TREE_FIND_VALUE(!=)
break;
case NODE_MODE::BRANCH_MEMBER:
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
}
}
case NODE_MODE::LEAF:
break;
}
Expand Down Expand Up @@ -759,6 +886,11 @@
root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::BRANCH_MEMBER:
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::LEAF:
return root;
}
Expand Down
84 changes: 84 additions & 0 deletions onnxruntime/test/providers/cpu/ml/treeregressor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) {
GenTreeAndRunTest1_as_tensor_precision(3);
}

TEST(MLOpTest, TreeRegressorCategoricals) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 4, 0, 5.5, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {3, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {-4.699999809265137, 17.700000762939453, 11.100000381469727};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
std::vector<float> Y = {17.700000762939453, 11.100000381469727, -4.699999809265137};
test.AddInput<float>("X", {3, 2}, X);
test.AddOutput<float>("Y", {3, 1}, Y);
test.Run();
}

TEST(MLOpTest, TreeRegressorCategoricalsFolding) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

// tree
int64_t n_targets = 1;
std::vector<int64_t> nodes_featureids = {0, 0, 1, 1, 0, 0, 0};
std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"};
std::vector<float> nodes_values = {1, 3, 2, 3, 0, 0, 0};

std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0};
std::vector<int64_t> nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0};

std::string post_transform = "NONE";
std::vector<int64_t> target_ids = {0, 0, 0};
std::vector<int64_t> target_nodeids = {4, 5, 6};
std::vector<int64_t> target_treeids = {0, 0, 0};
std::vector<float> target_weights = {17.700000762939453, 11.100000381469727, -4.699999809265137};

// add attributes
test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
test.AddAttribute("nodes_treeids", nodes_treeids);
test.AddAttribute("nodes_nodeids", nodes_nodeids);
test.AddAttribute("nodes_featureids", nodes_featureids);
test.AddAttribute("nodes_values", nodes_values);
test.AddAttribute("nodes_modes", nodes_modes);
test.AddAttribute("target_treeids", target_treeids);
test.AddAttribute("target_nodeids", target_nodeids);
test.AddAttribute("target_ids", target_ids);
test.AddAttribute("target_weights", target_weights);
test.AddAttribute("n_targets", n_targets);

// fill input data
std::vector<float> X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
std::vector<float> Y = {11.100000381469727, 11.100000381469727, -4.699999809265137, 17.700000762939453};
test.AddInput<float>("X", {4, 2}, X);
test.AddOutput<float>("Y", {4, 1}, Y);
test.Run();
}

TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

Expand Down
Loading