diff --git a/src/common/quantile.h b/src/common/quantile.h index e189b259b159..97eeb593b869 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -748,37 +748,79 @@ std::vector CalcColumnSize(Batch const &batch, bst_feature_t const n_ return entries_per_columns; } +struct WLBalance { + explicit WLBalance(size_t n_columns) : is_column_splited(n_columns) {} + + struct ThreadWorkLoad { + std::vector columns; + size_t split_idx = 0; + size_t n_splits = 1; + + ThreadWorkLoad() : columns() {} + }; + + std::vector baskets; + std::vector is_column_splited; + bool has_splitted = false; +}; + + template -std::vector LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, - size_t const nthreads, IsValid&& is_valid) { - /* Some sparse datasets have their mass concentrating on small number of features. To - * avoid waiting for a few threads running forever, we here distribute different number - * of columns to different threads according to number of entries. +WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns, + size_t const nthreads, IsValid&& is_valid) { + /* Some datasets have long columns. It is beneficial to split such columns between threads and + * than collect the result if number of threads is high enourth. In this case, each thread being + * involved in processing of splitted columns works only with a single column. + * + * Columns that are too small for splitting are distributed between threads. In this case each thread + * can process multiple columns. The range of columns indexes for all the rthreads in this case don't + * overlap with each other. */ + WLBalance wl_balance(n_columns); + if (nnz == 0) return wl_balance; + auto& wl_baskets = wl_balance.baskets; + size_t const total_entries = nnz; size_t const entries_per_thread = DivRoundUp(total_entries, nthreads); // Need to calculate the size for each batch. std::vector entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid); - std::vector cols_ptr(nthreads + 1, 0); - size_t count{0}; - size_t current_thread{1}; - for (auto col : entries_per_columns) { - cols_ptr.at(current_thread)++; // add one column to thread - count += col; - CHECK_LE(count, total_entries); - if (count > entries_per_thread) { - current_thread++; - count = 0; - cols_ptr.at(current_thread) = cols_ptr[current_thread - 1]; + size_t count = 0; + for (size_t column_idx = 0; column_idx < n_columns; ++column_idx) { + size_t n_entries = entries_per_columns[column_idx]; + + if (n_entries > 0) { + size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries); + constexpr size_t kMinBlockSize = (1u << 16); + if ((n_splits > 1) && (kMinBlockSize * n_splits < n_entries)) { + // Split column between threads + wl_balance.has_splitted = true; + wl_balance.is_column_splited[column_idx] = true; + for (size_t split_idx = 0; split_idx < n_splits; split_idx++) { + wl_baskets.emplace_back(); + + auto& wl = wl_baskets.back(); + wl.columns.push_back(column_idx); + wl.split_idx = split_idx; + wl.n_splits = n_splits; + } + } else { + if (wl_baskets.empty() || count > entries_per_thread) { + wl_baskets.emplace_back(); + count = 0; + } + count += n_entries; + + auto& wl = wl_baskets.back(); + wl.columns.push_back(column_idx); + wl_balance.is_column_splited[column_idx] = false; + } } } - // Idle threads. - for (; current_thread < cols_ptr.size() - 1; ++current_thread) { - cols_ptr[current_thread + 1] = cols_ptr[current_thread]; - } - return cols_ptr; + + CHECK_LE(wl_baskets.size(), nthreads); + return wl_balance; } /*! @@ -840,46 +882,126 @@ class SketchContainerImpl { template void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz, size_t n_features, bool is_dense, IsValid is_valid) { - auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); + auto threads_wl = LoadBalance(batch, nnz, n_features, n_threads_, is_valid); + if (threads_wl.baskets.empty()) return; + + std::vector> categories_buff; + std::vector sketches_buff; + std::vector buff_was_used; + + if (threads_wl.has_splitted) { + sketches_buff.resize(threads_wl.baskets.size()); + categories_buff.resize(threads_wl.baskets.size()); + buff_was_used.resize(threads_wl.baskets.size(), 0); + } dmlc::OMPException exc; -#pragma omp parallel num_threads(n_threads_) + #pragma omp parallel num_threads(threads_wl.baskets.size()) { exc.Run([&]() { auto tid = static_cast(omp_get_thread_num()); - auto const begin = thread_columns_ptr[tid]; - auto const end = thread_columns_ptr[tid + 1]; + const auto& wl = threads_wl.baskets[tid]; + if (wl.n_splits > 1) { + // We process only a single column in this case + size_t column = wl.columns.front(); + + auto n_bins = std::min(static_cast(max_bins_), columns_size_[column]); + auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); + sketches_buff[tid].Init(columns_size_[column], eps); - // do not iterate if no columns are assigned to the thread - if (begin < end && end <= n_features) { + size_t split_size = DivRoundUp(batch.Size(), wl.n_splits); + size_t begin = wl.split_idx * split_size; + size_t end = std::min(begin + split_size, batch.Size()); + + for (size_t ridx = begin; ridx < end; ++ridx) { + auto const &line = batch.GetLine(ridx); + auto w = weights[ridx + base_rowid]; + if (is_dense) { + auto const &elem = line.GetElement(column); + /* elem.column_idx == column */ + if (is_valid(elem)) { + buff_was_used[tid] = 1; + PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w); + } + } else { + size_t n_columns_with_high_idx = n_features - column; + size_t col_begin = line.Size() < n_columns_with_high_idx ? 0 + : line.Size() - n_columns_with_high_idx; + size_t col_end = std::min(column + 1, line.Size()); + for (size_t i = col_begin; i < col_end; ++i) { + auto const &elem = line.GetElement(i); + if (is_valid(elem) && (elem.column_idx == column)) { + buff_was_used[tid] = 1; + PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w); + } + } + } + } + } else { for (size_t ridx = 0; ridx < batch.Size(); ++ridx) { auto const &line = batch.GetLine(ridx); auto w = weights[ridx + base_rowid]; if (is_dense) { - for (size_t ii = begin; ii < end; ii++) { - auto elem = line.GetElement(ii); - if (is_valid(elem)) { - if (IsCat(feature_types_, ii)) { - categories_[ii].emplace(elem.value); - } else { - sketches_[ii].Push(elem.value, w); + for (size_t ii = wl.columns.front(); ii <= wl.columns.back(); ++ii) { + if (!threads_wl.is_column_splited[ii]) { + auto const &elem = line.GetElement(ii); + /* elem.column_idx == ii */ + if (is_valid(elem)) { + PushElement(elem, &categories_[ii], &sketches_[ii], w); } } } } else { - for (size_t i = 0; i < line.Size(); ++i) { + // number of columns with idx >= wl.columns.front() + size_t n_columns_with_high_idx = n_features - wl.columns.front(); + size_t col_begin = line.Size() < n_columns_with_high_idx + ? 0 : line.Size() - n_columns_with_high_idx; + size_t col_end = std::min(wl.columns.back() + 1, line.Size()); + for (size_t i = col_begin; i < col_end; ++i) { auto const &elem = line.GetElement(i); - if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) { - if (IsCat(feature_types_, elem.column_idx)) { - categories_[elem.column_idx].emplace(elem.value); - } else { - sketches_[elem.column_idx].Push(elem.value, w); + if (is_valid(elem)) { + if (!threads_wl.is_column_splited[elem.column_idx] && + (elem.column_idx >= wl.columns.front()) && + (elem.column_idx <= wl.columns.back())) { + PushElement(elem, &categories_[elem.column_idx], + &sketches_[elem.column_idx], w); } } } } } } + + #pragma omp barrier + if (wl.n_splits > 1 && wl.split_idx == 0) { + /* The thread being responsible for the first block in split + * collect info from the other ones. + */ + size_t column_idx = wl.columns.front(); + + typename WQSketch::SummaryContainer main_summary; + main_summary.Reserve(sketches_[column_idx].limit_size); + typename WQSketch::SummaryContainer split_summary; + split_summary.Reserve(2 * sketches_[column_idx].limit_size); + typename WQSketch::SummaryContainer comb_summary; + comb_summary.Reserve(3 * sketches_[column_idx].limit_size); + + for (size_t th = tid + 0; th < tid + wl.n_splits; ++th) { + CHECK_LT(th, threads_wl.baskets.size()); + // Make shure some work was done by thread + if (buff_was_used[th] > 0) { + if (IsCat(feature_types_, column_idx)) { + categories_[column_idx].merge(categories_buff[th]); + } else { + sketches_buff[th].GetSummary(&split_summary); + + comb_summary.SetCombine(main_summary, split_summary); + main_summary.SetPrune(comb_summary, sketches_[column_idx].limit_size); + } + } + } + sketches_[column_idx].PushSummary(main_summary); + } }); } exc.Rethrow(); @@ -893,6 +1015,18 @@ class SketchContainerImpl { private: // Merge all categories from other workers. void AllreduceCategories(Context const* ctx, MetaInfo const& info); + + template + void PushElement(const ElemType& elem, + std::set* categorie, + WQSketch* sketch, + float w) { + if (IsCat(feature_types_, elem.column_idx)) { + categorie->emplace(elem.value); + } else { + sketch->Push(elem.value, w); + } + } }; class HostSketchContainer : public SketchContainerImpl> { diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index 1ef6572599fc..b54c2a96391b 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -17,15 +17,18 @@ namespace xgboost::common { TEST(Quantile, LoadBalance) { size_t constexpr kRows = 1000, kCols = 100; auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); - std::vector cols_ptr; + WLBalance threads_wl(kCols); Context ctx; for (auto const& page : m->GetBatches(&ctx)) { data::SparsePageAdapterBatch adapter{page.GetView()}; - cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; }); + threads_wl = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; }); } size_t n_cols = 0; - for (size_t i = 1; i < cols_ptr.size(); ++i) { - n_cols += cols_ptr[i] - cols_ptr[i - 1]; + for (const auto& basket : threads_wl.baskets) { + n_cols += basket.columns.size(); + for (size_t column : basket.columns) { + CHECK_LT(column, kCols); + } } CHECK_EQ(n_cols, kCols); } @@ -160,6 +163,12 @@ TEST(Quantile, DistributedBasic) { TestDistributedQuantile(kRows, kCols); } +TEST(Quantile, DistributedRowWise) { + size_t kRows = (1u << 16) * common::OmpGetNumThreads(0); + size_t kCols = 2; + TestDistributedQuantile(kRows, kCols); +} + TEST(Quantile, Distributed) { constexpr size_t kRows = 4000, kCols = 200; TestDistributedQuantile(kRows, kCols); @@ -288,6 +297,12 @@ TEST(Quantile, ColumnSplitBasic) { TestColSplitQuantile(kRows, kCols); } +TEST(Quantile, ColumnSplitRowWise) { + size_t kRows = (1u << 16) * common::OmpGetNumThreads(0); + size_t kCols = 2; + TestColSplitQuantile(kRows, kCols); +} + TEST(Quantile, ColumnSplit) { constexpr size_t kRows = 4000, kCols = 200; TestColSplitQuantile(kRows, kCols);