diff --git a/include/xgboost/multi_target_tree_model.h b/include/xgboost/multi_target_tree_model.h index 676c43196263..430c5455f1e9 100644 --- a/include/xgboost/multi_target_tree_model.h +++ b/include/xgboost/multi_target_tree_model.h @@ -1,24 +1,26 @@ /** - * Copyright 2023 by XGBoost contributors + * Copyright 2023-2025, XGBoost contributors * - * \brief Core data structure for multi-target trees. + * @brief Core data structure for multi-target trees. */ #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_ #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_ -#include // for bst_node_t, bst_target_t, bst_feature_t -#include // for Context -#include // for VectorView -#include // for Model -#include // for Span -#include // for uint8_t -#include // for size_t -#include // for vector +#include // for bst_node_t, bst_target_t, bst_feature_t +#include // for Context +#include // for HostDeviceVector +#include // for VectorView +#include // for Model +#include // for Span + +#include // for size_t +#include // for uint8_t +#include // for vector namespace xgboost { struct TreeParam; /** - * \brief Tree structure for multi-target model. + * @brief Tree structure for multi-target model. */ class MultiTargetTree : public Model { public: @@ -26,47 +28,66 @@ class MultiTargetTree : public Model { private: TreeParam const* param_; - std::vector left_; - std::vector right_; - std::vector parent_; - std::vector split_index_; - std::vector default_left_; - std::vector split_conds_; - std::vector weights_; + HostDeviceVector left_; + HostDeviceVector right_; + HostDeviceVector parent_; + HostDeviceVector split_index_; + HostDeviceVector default_left_; + HostDeviceVector split_conds_; + HostDeviceVector weights_; [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) const { auto beg = nidx * this->NumTarget(); - auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); + auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumTarget()); return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size()); } [[nodiscard]] linalg::VectorView NodeWeight(bst_node_t nidx) { auto beg = nidx * this->NumTarget(); - auto v = common::Span{weights_}.subspan(beg, this->NumTarget()); + auto v = this->weights_.HostSpan().subspan(beg, this->NumTarget()); return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size()); } public: explicit MultiTargetTree(TreeParam const* param); + MultiTargetTree(MultiTargetTree const& that); + MultiTargetTree& operator=(MultiTargetTree const& that) = delete; + MultiTargetTree(MultiTargetTree&& that) = default; + MultiTargetTree& operator=(MultiTargetTree&& that) = default; + /** - * \brief Set the weight for a leaf. + * @brief Set the weight for a leaf. */ void SetLeaf(bst_node_t nidx, linalg::VectorView weight); /** - * \brief Expand a leaf into split node. + * @brief Expand a leaf into split node. */ void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView base_weight, linalg::VectorView left_weight, linalg::VectorView right_weight); - [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); } - [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); } - [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); } - [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); } + [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { + return left_.ConstHostVector()[nidx] == InvalidNodeId(); + } + [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { + return parent_.ConstHostVector().at(nidx); + } + [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { + return left_.ConstHostVector().at(nidx); + } + [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { + return right_.ConstHostVector().at(nidx); + } - [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; } - [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; } - [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; } + [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { + return split_index_.ConstHostVector()[nidx]; + } + [[nodiscard]] float SplitCond(bst_node_t nidx) const { + return split_conds_.ConstHostVector()[nidx]; + } + [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { + return default_left_.ConstHostVector()[nidx]; + } [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const { return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx); } diff --git a/src/tree/multi_target_tree_model.cc b/src/tree/multi_target_tree_model.cc index 11ee1f6dd0c4..7f3087fd6f06 100644 --- a/src/tree/multi_target_tree_model.cc +++ b/src/tree/multi_target_tree_model.cc @@ -1,19 +1,19 @@ /** - * Copyright 2023 by XGBoost Contributors + * Copyright 2023-2025, XGBoost Contributors */ #include "xgboost/multi_target_tree_model.h" -#include // for copy_n -#include // for size_t -#include // for int32_t, uint8_t -#include // for numeric_limits -#include // for string_view -#include // for move -#include // for vector - -#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ... -#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t -#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ... +#include // for copy_n +#include // for size_t +#include // for int32_t, uint8_t +#include // for numeric_limits +#include // for string_view +#include // for move +#include // for vector + +#include "io_utils.h" // for I32ArrayT, FloatArrayT, GetElem, ... +#include "xgboost/base.h" // for bst_node_t, bst_feature_t, bst_target_t +#include "xgboost/json.h" // for Json, get, Object, Number, Integer, ... #include "xgboost/logging.h" #include "xgboost/tree_model.h" // for TreeParam @@ -30,27 +30,47 @@ MultiTargetTree::MultiTargetTree(TreeParam const* param) CHECK_GT(param_->size_leaf_vector, 1); } +MultiTargetTree::MultiTargetTree(MultiTargetTree const& that) + : param_{that.param_}, + left_(that.left_.Size(), 0, that.left_.Device()), + right_(that.right_.Size(), 0, that.right_.Device()), + parent_(that.parent_.Size(), 0, that.parent_.Device()), + split_index_(that.split_index_.Size(), 0, that.split_index_.Device()), + default_left_(that.default_left_.Size(), 0, that.default_left_.Device()), + split_conds_(that.split_conds_.Size(), 0, that.split_conds_.Device()), + weights_(that.weights_.Size(), 0, that.weights_.Device()) { + this->left_.Copy(that.left_); + this->right_.Copy(that.right_); + this->parent_.Copy(that.parent_); + this->split_index_.Copy(that.split_index_); + this->default_left_.Copy(that.default_left_); + this->split_conds_.Copy(that.split_conds_); + this->weights_.Copy(that.weights_); +} + template -void LoadModelImpl(Json const& in, std::vector* p_weights, std::vector* p_lefts, - std::vector* p_rights, std::vector* p_parents, - std::vector* p_conds, std::vector* p_fidx, - std::vector* p_dft_left) { +void LoadModelImpl(Json const& in, HostDeviceVector* p_weights, + HostDeviceVector* p_lefts, HostDeviceVector* p_rights, + HostDeviceVector* p_parents, HostDeviceVector* p_conds, + HostDeviceVector* p_fidx, + HostDeviceVector* p_dft_left) { namespace tf = tree_field; - auto get_float = [&](std::string_view name, std::vector* p_out) { + auto get_float = [&](std::string_view name, HostDeviceVector* p_out) { auto& values = get>(get(in).find(name)->second); auto& out = *p_out; - out.resize(values.size()); + out.Resize(values.size()); + auto& h_out = out.HostVector(); for (std::size_t i = 0; i < values.size(); ++i) { - out[i] = GetElem(values, i); + h_out[i] = GetElem(values, i); } }; get_float(tf::kBaseWeight, p_weights); get_float(tf::kSplitCond, p_conds); - auto get_nidx = [&](std::string_view name, std::vector* p_nidx) { + auto get_nidx = [&](std::string_view name, HostDeviceVector* p_nidx) { auto& nidx = get>(get(in).find(name)->second); - auto& out_nidx = *p_nidx; + auto& out_nidx = p_nidx->HostVector(); out_nidx.resize(nidx.size()); for (std::size_t i = 0; i < nidx.size(); ++i) { out_nidx[i] = GetElem(nidx, i); @@ -61,15 +81,15 @@ void LoadModelImpl(Json const& in, std::vector* p_weights, std::vector const>(in[tf::kSplitIdx]); - p_fidx->resize(splits.size()); - auto& out_fidx = *p_fidx; + p_fidx->Resize(splits.size()); + auto& out_fidx = p_fidx->HostVector(); for (std::size_t i = 0; i < splits.size(); ++i) { out_fidx[i] = GetElem(splits, i); } auto const& dft_left = get const>(in[tf::kDftLeft]); - auto& out_dft_l = *p_dft_left; - out_dft_l.resize(dft_left.size()); + p_dft_left->Resize(dft_left.size()); + auto& out_dft_l = p_dft_left->HostVector(); for (std::size_t i = 0; i < dft_left.size(); ++i) { out_dft_l[i] = GetElem(dft_left, i); } @@ -109,19 +129,25 @@ void MultiTargetTree::SaveModel(Json* p_out) const { U8Array default_left(n_nodes); F32Array weights(n_nodes * this->NumTarget()); + auto const& h_left = this->left_.ConstHostVector(); + auto const& h_right = this->right_.ConstHostVector(); + auto const& h_parent = this->parent_.ConstHostVector(); + auto const& h_split_index = this->split_index_.ConstHostVector(); + auto const& h_split_conds = this->split_conds_.ConstHostVector(); + auto const& h_default_left = this->default_left_.ConstHostVector(); auto save_tree = [&](auto* p_indices_array) { auto& indices_array = *p_indices_array; for (bst_node_t nidx = 0; nidx < n_nodes; ++nidx) { - CHECK_LT(nidx, left_.size()); - lefts.Set(nidx, left_[nidx]); - CHECK_LT(nidx, right_.size()); - rights.Set(nidx, right_[nidx]); - CHECK_LT(nidx, parent_.size()); - parents.Set(nidx, parent_[nidx]); - CHECK_LT(nidx, split_index_.size()); - indices_array.Set(nidx, split_index_[nidx]); - conds.Set(nidx, split_conds_[nidx]); - default_left.Set(nidx, default_left_[nidx]); + CHECK_LT(nidx, left_.Size()); + lefts.Set(nidx, h_left[nidx]); + CHECK_LT(nidx, right_.Size()); + rights.Set(nidx, h_right[nidx]); + CHECK_LT(nidx, parent_.Size()); + parents.Set(nidx, h_parent[nidx]); + CHECK_LT(nidx, split_index_.Size()); + indices_array.Set(nidx, h_split_index[nidx]); + conds.Set(nidx, h_split_conds[nidx]); + default_left.Set(nidx, h_default_left[nidx]); auto in_weight = this->NodeWeight(nidx); auto weight_out = common::Span(weights.GetArray()) @@ -157,8 +183,8 @@ void MultiTargetTree::SetLeaf(bst_node_t nidx, linalg::VectorView w CHECK(this->IsLeaf(nidx)) << "Collapsing a split node to leaf " << MTNotImplemented(); auto const next_nidx = nidx + 1; CHECK_EQ(weight.Size(), this->NumTarget()); - CHECK_GE(weights_.size(), next_nidx * weight.Size()); - auto out_weight = common::Span(weights_).subspan(nidx * weight.Size(), weight.Size()); + CHECK_GE(weights_.Size(), next_nidx * weight.Size()); + auto out_weight = weights_.HostSpan().subspan(nidx * weight.Size(), weight.Size()); for (std::size_t i = 0; i < weight.Size(); ++i) { out_weight[i] = weight(i); } @@ -169,39 +195,40 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl linalg::VectorView left_weight, linalg::VectorView right_weight) { CHECK(this->IsLeaf(nidx)); - CHECK_GE(parent_.size(), 1); - CHECK_EQ(parent_.size(), left_.size()); - CHECK_EQ(left_.size(), right_.size()); + CHECK_GE(parent_.Size(), 1); + CHECK_EQ(parent_.Size(), left_.Size()); + CHECK_EQ(left_.Size(), right_.Size()); std::size_t n = param_->num_nodes + 2; CHECK_LT(split_idx, this->param_->num_feature); - left_.resize(n, InvalidNodeId()); - right_.resize(n, InvalidNodeId()); - parent_.resize(n, InvalidNodeId()); + left_.Resize(n, InvalidNodeId()); + right_.Resize(n, InvalidNodeId()); + parent_.Resize(n, InvalidNodeId()); - auto left_child = parent_.size() - 2; - auto right_child = parent_.size() - 1; + auto left_child = parent_.Size() - 2; + auto right_child = parent_.Size() - 1; - left_[nidx] = left_child; - right_[nidx] = right_child; + left_.HostVector()[nidx] = left_child; + right_.HostVector()[nidx] = right_child; + auto& h_parent = parent_.HostVector(); if (nidx != 0) { - CHECK_NE(parent_[nidx], InvalidNodeId()); + CHECK_NE(h_parent[nidx], InvalidNodeId()); } - parent_[left_child] = nidx; - parent_[right_child] = nidx; + h_parent[left_child] = nidx; + h_parent[right_child] = nidx; - split_index_.resize(n); - split_index_[nidx] = split_idx; + split_index_.Resize(n); + split_index_.HostVector()[nidx] = split_idx; - split_conds_.resize(n, std::numeric_limits::quiet_NaN()); - split_conds_[nidx] = split_cond; + split_conds_.Resize(n, std::numeric_limits::quiet_NaN()); + split_conds_.HostVector()[nidx] = split_cond; - default_left_.resize(n); - default_left_[nidx] = static_cast(default_left); + default_left_.Resize(n); + default_left_.HostVector()[nidx] = static_cast(default_left); - weights_.resize(n * this->NumTarget()); + weights_.Resize(n * this->NumTarget()); auto p_weight = this->NodeWeight(nidx); CHECK_EQ(p_weight.Size(), base_weight.Size()); auto l_weight = this->NodeWeight(left_child); @@ -217,5 +244,5 @@ void MultiTargetTree::Expand(bst_node_t nidx, bst_feature_t split_idx, float spl } bst_target_t MultiTargetTree::NumTarget() const { return param_->size_leaf_vector; } -std::size_t MultiTargetTree::Size() const { return parent_.size(); } +std::size_t MultiTargetTree::Size() const { return parent_.Size(); } } // namespace xgboost