Skip to content

Commit

Permalink
[MT] Add device storage to multi-target tree.
Browse files: browse the repository at this point in the history
  • Loading branch information
trivialfis committed Feb 22, 2025
1 parent 853e3d5 commit 1e46ff0
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 87 deletions.
79 changes: 50 additions & 29 deletions include/xgboost/multi_target_tree_model.h
Original file line number Diff line number Diff line change
@@ -1,72 +1,93 @@
/**
* Copyright 2023 by XGBoost contributors
* Copyright 2023-2025, XGBoost contributors
*
* \brief Core data structure for multi-target trees.
* @brief Core data structure for multi-target trees.
*/
#ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
#define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Model
#include <xgboost/span.h> // for Span

#include <cinttypes> // for uint8_t
#include <cstddef> // for size_t
#include <vector> // for vector
#include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
#include <xgboost/context.h> // for Context
#include <xgboost/host_device_vector.h> // for HostDeviceVector
#include <xgboost/linalg.h> // for VectorView
#include <xgboost/model.h> // for Model
#include <xgboost/span.h> // for Span

#include <cstddef> // for size_t
#include <cstdint> // for uint8_t
#include <vector> // for vector

namespace xgboost {
struct TreeParam;
/**
* \brief Tree structure for multi-target model.
* @brief Tree structure for multi-target model.
*/
class MultiTargetTree : public Model {
public:
/**
 * @brief Sentinel node index (-1).
 *
 * IsLeaf() compares the stored left-child index against this value, and Expand() uses it
 * as the fill value when growing the left/right/parent arrays.
 */
static constexpr bst_node_t InvalidNodeId() {
  return -1;
}

private:
// Tree parameters; read for size_leaf_vector, num_nodes and num_feature elsewhere in this diff.
// NOTE(review): ownership/lifetime of the pointee is not visible here -- confirm with callers.
TreeParam const* param_;
// NOTE(review): this page renders both sides of the diff without +/- markers. The
// std::vector declarations below appear to be the removed lines and the HostDeviceVector
// declarations their replacements; only one set can exist in a compilable header.
//
// All arrays except weights_ are indexed by node id.
std::vector<bst_node_t> left_;
std::vector<bst_node_t> right_;
std::vector<bst_node_t> parent_;
std::vector<bst_feature_t> split_index_;
std::vector<std::uint8_t> default_left_;
std::vector<float> split_conds_;
std::vector<float> weights_;
// Left child of each node; InvalidNodeId() marks a leaf (see IsLeaf()).
HostDeviceVector<bst_node_t> left_;
// Right child of each node.
HostDeviceVector<bst_node_t> right_;
// Parent of each node.
HostDeviceVector<bst_node_t> parent_;
// Feature index used by the split at each node.
HostDeviceVector<bst_feature_t> split_index_;
// Per-node flag stored as uint8_t; read back as bool by DefaultLeft().
HostDeviceVector<std::uint8_t> default_left_;
// Split condition (threshold) of each node.
HostDeviceVector<float> split_conds_;
// Flattened weights: NumTarget() consecutive floats per node (see NodeWeight()).
HostDeviceVector<float> weights_;

/**
 * @brief Read-only view of the weight vector stored for node `nidx`.
 *
 * The view spans NumTarget() consecutive floats starting at offset nidx * NumTarget()
 * in weights_ and is created with DeviceOrd::CPU().
 *
 * @param nidx Node index.
 * @return A 1-D view of length NumTarget() over the host weights.
 */
[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
// Offset of the first weight belonging to this node in the flattened array.
auto beg = nidx * this->NumTarget();
// NOTE(review): the two `auto v` lines below look like the removed (std::vector) and the
// added (HostDeviceVector) side of the diff shown on this page -- only one can be kept.
auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTarget());
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
}
/**
 * @brief Mutable view of the weight vector stored for node `nidx`.
 *
 * Same layout as the const overload: NumTarget() floats starting at nidx * NumTarget(),
 * wrapped with DeviceOrd::CPU().
 *
 * @param nidx Node index.
 * @return A writable 1-D view of length NumTarget() over the host weights.
 */
[[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
// Offset of the first weight belonging to this node in the flattened array.
auto beg = nidx * this->NumTarget();
// NOTE(review): the two `auto v` lines below look like the removed (std::vector) and the
// added (HostDeviceVector) side of the diff shown on this page -- only one can be kept.
auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
auto v = this->weights_.HostSpan().subspan(beg, this->NumTarget());
return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
}

public:
/**
 * @brief Construct a tree bound to the given parameters.
 *
 * @param param Tree parameters; the pointer is stored in param_.
 */
explicit MultiTargetTree(TreeParam const* param);
/**
 * @brief Copy constructor; allocates vectors matching `that` in size and device, then
 *        copies each of them.
 */
MultiTargetTree(MultiTargetTree const& that);
// Copy assignment is intentionally unavailable; moves use the compiler-generated versions.
MultiTargetTree& operator=(MultiTargetTree const& that) = delete;
MultiTargetTree(MultiTargetTree&& that) = default;
MultiTargetTree& operator=(MultiTargetTree&& that) = default;

/**
* \brief Set the weight for a leaf.
* @brief Set the weight for a leaf.
*/
void SetLeaf(bst_node_t nidx, linalg::VectorView<float const> weight);
/**
* \brief Expand a leaf into split node.
* @brief Expand a leaf into split node.
*/
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);

// NOTE(review): the four one-line accessors directly below appear to be the removed side of
// the diff (std::vector storage); the multi-line definitions after them are the versions
// reading from the host copy of the HostDeviceVector members.
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
/**
 * @brief Whether node `nidx` is a leaf, i.e. its left-child entry equals InvalidNodeId().
 *
 * Uses unchecked indexing, unlike Parent()/LeftChild()/RightChild() which use at().
 */
[[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
return left_.ConstHostVector()[nidx] == InvalidNodeId();
}
/**
 * @brief Parent index of node `nidx`, looked up with at().
 */
[[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
return parent_.ConstHostVector().at(nidx);
}
/**
 * @brief Left-child index of node `nidx`, looked up with at().
 */
[[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
return left_.ConstHostVector().at(nidx);
}
/**
 * @brief Right-child index of node `nidx`, looked up with at().
 */
[[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
return right_.ConstHostVector().at(nidx);
}

// NOTE(review): the three one-line accessors directly below appear to be the removed side of
// the diff (std::vector storage); the multi-line definitions after them are the versions
// reading from the host copy of the HostDeviceVector members.
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
[[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
/**
 * @brief Feature index stored for the split at node `nidx` (unchecked indexing).
 */
[[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
return split_index_.ConstHostVector()[nidx];
}
/**
 * @brief Split condition stored for node `nidx` (unchecked indexing).
 */
[[nodiscard]] float SplitCond(bst_node_t nidx) const {
return split_conds_.ConstHostVector()[nidx];
}
/**
 * @brief Default-left flag of node `nidx`; the stored uint8_t is converted to bool.
 */
[[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
return default_left_.ConstHostVector()[nidx];
}
/**
 * @brief Child selected by the default-direction flag of node `nidx`.
 *
 * @param nidx Node index.
 * @return LeftChild(nidx) when DefaultLeft(nidx) is true, otherwise RightChild(nidx).
 */
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
  if (this->DefaultLeft(nidx)) {
    return this->LeftChild(nidx);
  }
  return this->RightChild(nidx);
}
Expand Down
143 changes: 85 additions & 58 deletions src/tree/multi_target_tree_model.cc
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
/**
* Copyright 2023 by XGBoost Contributors
* Copyright 2023-2025, XGBoost Contributors
*/
#include "xgboost/multi_target_tree_model.h"

#include <algorithm> // for copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint8_t
#include <limits> // for numeric_limits
#include <string_view> // for string_view
#include <utility> // for move
#include <vector> // for vector

#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ...
#include <algorithm> // for copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int32_t, uint8_t
#include <limits> // for numeric_limits
#include <string_view> // for string_view
#include <utility> // for move
#include <vector> // for vector

#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ...
#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t
#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ...
#include "xgboost/logging.h"
#include "xgboost/tree_model.h" // for TreeParam

Expand All @@ -30,27 +30,47 @@ MultiTargetTree::MultiTargetTree(TreeParam const* param)
CHECK_GT(param_->size_leaf_vector, 1);
}

/**
 * @brief Copy constructor.
 *
 * Every storage vector is first created with the size and device of its counterpart in
 * `that` (zero-filled), then overwritten through Copy(). The parameter pointer is shared
 * with `that`, not duplicated.
 *
 * @param that Tree to copy from.
 */
MultiTargetTree::MultiTargetTree(MultiTargetTree const& that)
    : param_(that.param_),
      // Parenthesised initialisers on purpose: (size, fill value, device).
      left_(that.left_.Size(), 0, that.left_.Device()),
      right_(that.right_.Size(), 0, that.right_.Device()),
      parent_(that.parent_.Size(), 0, that.parent_.Device()),
      split_index_(that.split_index_.Size(), 0, that.split_index_.Device()),
      default_left_(that.default_left_.Size(), 0, that.default_left_.Device()),
      split_conds_(that.split_conds_.Size(), 0, that.split_conds_.Device()),
      weights_(that.weights_.Size(), 0, that.weights_.Device()) {
  // Presumably Copy() expects the destination to be pre-sized, hence the allocation
  // above -- TODO confirm against HostDeviceVector.

  // Tree topology.
  left_.Copy(that.left_);
  right_.Copy(that.right_);
  parent_.Copy(that.parent_);

  // Split description.
  split_index_.Copy(that.split_index_);
  default_left_.Copy(that.default_left_);
  split_conds_.Copy(that.split_conds_);

  // Per-node weights.
  weights_.Copy(that.weights_);
}

template <bool typed, bool feature_is_64>
void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bst_node_t>* p_lefts,
std::vector<bst_node_t>* p_rights, std::vector<bst_node_t>* p_parents,
std::vector<float>* p_conds, std::vector<bst_feature_t>* p_fidx,
std::vector<std::uint8_t>* p_dft_left) {
void LoadModelImpl(Json const& in, HostDeviceVector<float>* p_weights,
HostDeviceVector<bst_node_t>* p_lefts, HostDeviceVector<bst_node_t>* p_rights,
HostDeviceVector<bst_node_t>* p_parents, HostDeviceVector<float>* p_conds,
HostDeviceVector<bst_feature_t>* p_fidx,
HostDeviceVector<std::uint8_t>* p_dft_left) {
namespace tf = tree_field;

auto get_float = [&](std::string_view name, std::vector<float>* p_out) {
auto get_float = [&](std::string_view name, HostDeviceVector<float>* p_out) {
auto& values = get<FloatArrayT<typed>>(get<Object const>(in).find(name)->second);
auto& out = *p_out;
out.resize(values.size());
out.Resize(values.size());
auto& h_out = out.HostVector();
for (std::size_t i = 0; i < values.size(); ++i) {
out[i] = GetElem<Number>(values, i);
h_out[i] = GetElem<Number>(values, i);
}
};
get_float(tf::kBaseWeight, p_weights);
get_float(tf::kSplitCond, p_conds);

auto get_nidx = [&](std::string_view name, std::vector<bst_node_t>* p_nidx) {
auto get_nidx = [&](std::string_view name, HostDeviceVector<bst_node_t>* p_nidx) {
auto& nidx = get<I32ArrayT<typed>>(get<Object const>(in).find(name)->second);
auto& out_nidx = *p_nidx;
auto& out_nidx = p_nidx->HostVector();
out_nidx.resize(nidx.size());
for (std::size_t i = 0; i < nidx.size(); ++i) {
out_nidx[i] = GetElem<Integer>(nidx, i);
Expand All @@ -61,15 +81,15 @@ void LoadModelImpl(Json const& in, std::vector<float>* p_weights, std::vector<bs
get_nidx(tf::kParent, p_parents);

auto const& splits = get<IndexArrayT<typed, feature_is_64> const>(in[tf::kSplitIdx]);
p_fidx->resize(splits.size());
auto& out_fidx = *p_fidx;
p_fidx->Resize(splits.size());
auto& out_fidx = p_fidx->HostVector();
for (std::size_t i = 0; i < splits.size(); ++i) {
out_fidx[i] = GetElem<Integer>(splits, i);
}

auto const& dft_left = get<U8ArrayT<typed> const>(in[tf::kDftLeft]);
auto& out_dft_l = *p_dft_left;
out_dft_l.resize(dft_left.size());
p_dft_left->Resize(dft_left.size());
auto& out_dft_l = p_dft_left->HostVector();
for (std::size_t i = 0; i < dft_left.size(); ++i) {
out_dft_l[i] = GetElem<Boolean>(dft_left, i);
}
Expand Down Expand Up @@ -109,19 +129,25 @@ void MultiTargetTree::SaveModel(Json* p_out) const {
U8Array default_left(n_nodes);
F32Array weights(n_nodes * this->NumTarget());

auto const& h_left = this->left_.ConstHostVector();
auto const& h_right = this->right_.ConstHostVector();
auto const& h_parent = this->parent_.ConstHostVector();
auto const& h_split_index = this->split_index_.ConstHostVector();
auto const& h_split_conds = this->split_conds_.ConstHostVector();
auto const& h_default_left = this->default_left_.ConstHostVector();
auto save_tree = [&](auto* p_indices_array) {
auto& indices_array = *p_indices_array;
for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) {
CHECK_LT(nidx, left_.size());
lefts.Set(nidx, left_[nidx]);
CHECK_LT(nidx, right_.size());
rights.Set(nidx, right_[nidx]);
CHECK_LT(nidx, parent_.size());
parents.Set(nidx, parent_[nidx]);
CHECK_LT(nidx, split_index_.size());
indices_array.Set(nidx, split_index_[nidx]);
conds.Set(nidx, split_conds_[nidx]);
default_left.Set(nidx, default_left_[nidx]);
CHECK_LT(nidx, left_.Size());
lefts.Set(nidx, h_left[nidx]);
CHECK_LT(nidx, right_.Size());
rights.Set(nidx, h_right[nidx]);
CHECK_LT(nidx, parent_.Size());
parents.Set(nidx, h_parent[nidx]);
CHECK_LT(nidx, split_index_.Size());
indices_array.Set(nidx, h_split_index[nidx]);
conds.Set(nidx, h_split_conds[nidx]);
default_left.Set(nidx, h_default_left[nidx]);

auto in_weight = this->NodeWeight(nidx);
auto weight_out = common::Span<float>(weights.GetArray())
Expand Down Expand Up @@ -157,8 +183,8 @@ void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView<float const> w
CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented();
auto const next_nidx = nidx + 1;
CHECK_EQ(weight.Size(), this->NumTarget());
CHECK_GE(weights_.size(), next_nidx * weight.Size());
auto out_weight = common::Span<float>(weights_).subspan(nidx * weight.Size(), weight.Size());
CHECK_GE(weights_.Size(), next_nidx * weight.Size());
auto out_weight = weights_.HostSpan().subspan(nidx * weight.Size(), weight.Size());
for (std::size_t i = 0; i < weight.Size(); ++i) {
out_weight[i] = weight(i);
}
Expand All @@ -169,39 +195,40 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight) {
CHECK(this->IsLeaf(nidx));
CHECK_GE(parent_.size(), 1);
CHECK_EQ(parent_.size(), left_.size());
CHECK_EQ(left_.size(), right_.size());
CHECK_GE(parent_.Size(), 1);
CHECK_EQ(parent_.Size(), left_.Size());
CHECK_EQ(left_.Size(), right_.Size());

std::size_t n = param_->num_nodes + 2;
CHECK_LT(split_idx, this->param_->num_feature);
left_.resize(n, InvalidNodeId());
right_.resize(n, InvalidNodeId());
parent_.resize(n, InvalidNodeId());
left_.Resize(n, InvalidNodeId());
right_.Resize(n, InvalidNodeId());
parent_.Resize(n, InvalidNodeId());

auto left_child = parent_.size() - 2;
auto right_child = parent_.size() - 1;
auto left_child = parent_.Size() - 2;
auto right_child = parent_.Size() - 1;

left_[nidx] = left_child;
right_[nidx] = right_child;
left_.HostVector()[nidx] = left_child;
right_.HostVector()[nidx] = right_child;

auto& h_parent = parent_.HostVector();
if (nidx != 0) {
CHECK_NE(parent_[nidx], InvalidNodeId());
CHECK_NE(h_parent[nidx], InvalidNodeId());
}

parent_[left_child] = nidx;
parent_[right_child] = nidx;
h_parent[left_child] = nidx;
h_parent[right_child] = nidx;

split_index_.resize(n);
split_index_[nidx] = split_idx;
split_index_.Resize(n);
split_index_.HostVector()[nidx] = split_idx;

split_conds_.resize(n, std::numeric_limits<float>::quiet_NaN());
split_conds_[nidx] = split_cond;
split_conds_.Resize(n, std::numeric_limits<float>::quiet_NaN());
split_conds_.HostVector()[nidx] = split_cond;

default_left_.resize(n);
default_left_[nidx] = static_cast<std::uint8_t>(default_left);
default_left_.Resize(n);
default_left_.HostVector()[nidx] = static_cast<std::uint8_t>(default_left);

weights_.resize(n * this->NumTarget());
weights_.Resize(n * this->NumTarget());
auto p_weight = this->NodeWeight(nidx);
CHECK_EQ(p_weight.Size(), base_weight.Size());
auto l_weight = this->NodeWeight(left_child);
Expand All @@ -217,5 +244,5 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl
}

/**
 * @brief Number of targets, read from `size_leaf_vector` of the tree parameters.
 */
bst_target_t MultiTargetTree::NumTarget() const {
  return this->param_->size_leaf_vector;
}
// Number of nodes in the tree, taken from the length of the parent array.
// NOTE(review): the first definition looks like the removed (std::vector::size) side of the
// diff and the second its HostDeviceVector::Size replacement -- only one can be kept.
std::size_t MultiTargetTree::Size() const { return parent_.size(); }
std::size_t MultiTargetTree::Size() const { return parent_.Size(); }
} // namespace xgboost

0 comments on commit 1e46ff0

Please sign in to comment.