Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimization of PushRowPage for high number of cpu cores #11182

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 174 additions & 40 deletions src/common/quantile.h
Original file line number Diff line number Diff line change
Expand Up @@ -748,37 +748,79 @@ std::vector<bst_idx_t> CalcColumnSize(Batch const &batch, bst_feature_t const n_
return entries_per_columns;
}

struct WLBalance {
explicit WLBalance(size_t n_columns) : is_column_splited(n_columns) {}

struct ThreadWorkLoad {
std::vector<size_t> columns;
size_t split_idx = 0;
size_t n_splits = 1;

ThreadWorkLoad() : columns() {}
};

std::vector<ThreadWorkLoad> baskets;
std::vector<bool> is_column_splited;
bool has_splitted = false;
};


template <typename Batch, typename IsValid>
std::vector<bst_feature_t> LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
size_t const nthreads, IsValid&& is_valid) {
/* Some sparse datasets have their mass concentrating on small number of features. To
* avoid waiting for a few threads running forever, we here distribute different number
* of columns to different threads according to number of entries.
WLBalance LoadBalance(Batch const &batch, size_t nnz, bst_feature_t n_columns,
size_t const nthreads, IsValid&& is_valid) {
/* Some datasets have long columns. It is beneficial to split such columns between threads and
* than collect the result if number of threads is high enourth. In this case, each thread being
* involved in processing of splitted columns works only with a single column.
*
* Columns that are too small for splitting are distributed between threads. In this case each thread
* can process multiple columns. The range of columns indexes for all the rthreads in this case don't
* overlap with each other.
*/
WLBalance wl_balance(n_columns);
if (nnz == 0) return wl_balance;
auto& wl_baskets = wl_balance.baskets;

size_t const total_entries = nnz;
size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);

// Need to calculate the size for each batch.
std::vector<bst_idx_t> entries_per_columns = CalcColumnSize(batch, n_columns, nthreads, is_valid);
std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
size_t count{0};
size_t current_thread{1};

for (auto col : entries_per_columns) {
cols_ptr.at(current_thread)++; // add one column to thread
count += col;
CHECK_LE(count, total_entries);
if (count > entries_per_thread) {
current_thread++;
count = 0;
cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
size_t count = 0;
for (size_t column_idx = 0; column_idx < n_columns; ++column_idx) {
size_t n_entries = entries_per_columns[column_idx];

if (n_entries > 0) {
size_t n_splits = std::min(nthreads * n_entries / total_entries, n_entries);
constexpr size_t kMinBlockSize = (1u << 16);
if ((n_splits > 1) && (kMinBlockSize * n_splits < n_entries)) {
// Split column between threads
wl_balance.has_splitted = true;
wl_balance.is_column_splited[column_idx] = true;
for (size_t split_idx = 0; split_idx < n_splits; split_idx++) {
wl_baskets.emplace_back();

auto& wl = wl_baskets.back();
wl.columns.push_back(column_idx);
wl.split_idx = split_idx;
wl.n_splits = n_splits;
}
} else {
if (wl_baskets.empty() || count > entries_per_thread) {
wl_baskets.emplace_back();
count = 0;
}
count += n_entries;

auto& wl = wl_baskets.back();
wl.columns.push_back(column_idx);
wl_balance.is_column_splited[column_idx] = false;
}
}
}
// Idle threads.
for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
cols_ptr[current_thread + 1] = cols_ptr[current_thread];
}
return cols_ptr;

CHECK_LE(wl_baskets.size(), nthreads);
return wl_balance;
}

/*!
Expand Down Expand Up @@ -840,46 +882,126 @@ class SketchContainerImpl {
template <typename Batch, typename IsValid>
void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
size_t n_features, bool is_dense, IsValid is_valid) {
auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
auto threads_wl = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
if (threads_wl.baskets.empty()) return;

std::vector<std::set<float>> categories_buff;
std::vector<WQSketch> sketches_buff;
std::vector<int> buff_was_used;

if (threads_wl.has_splitted) {
sketches_buff.resize(threads_wl.baskets.size());
categories_buff.resize(threads_wl.baskets.size());
buff_was_used.resize(threads_wl.baskets.size(), 0);
}

dmlc::OMPException exc;
#pragma omp parallel num_threads(n_threads_)
#pragma omp parallel num_threads(threads_wl.baskets.size())
{
exc.Run([&]() {
auto tid = static_cast<uint32_t>(omp_get_thread_num());
auto const begin = thread_columns_ptr[tid];
auto const end = thread_columns_ptr[tid + 1];
const auto& wl = threads_wl.baskets[tid];
if (wl.n_splits > 1) {
// We process only a single column in this case
size_t column = wl.columns.front();

auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[column]);
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
sketches_buff[tid].Init(columns_size_[column], eps);

// do not iterate if no columns are assigned to the thread
if (begin < end && end <= n_features) {
size_t split_size = DivRoundUp(batch.Size(), wl.n_splits);
size_t begin = wl.split_idx * split_size;
size_t end = std::min(begin + split_size, batch.Size());

for (size_t ridx = begin; ridx < end; ++ridx) {
auto const &line = batch.GetLine(ridx);
auto w = weights[ridx + base_rowid];
if (is_dense) {
auto const &elem = line.GetElement(column);
/* elem.column_idx == column */
if (is_valid(elem)) {
buff_was_used[tid] = 1;
PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w);
}
} else {
size_t n_columns_with_high_idx = n_features - column;
size_t col_begin = line.Size() < n_columns_with_high_idx ? 0
: line.Size() - n_columns_with_high_idx;
size_t col_end = std::min(column + 1, line.Size());
for (size_t i = col_begin; i < col_end; ++i) {
auto const &elem = line.GetElement(i);
if (is_valid(elem) && (elem.column_idx == column)) {
buff_was_used[tid] = 1;
PushElement(elem, &categories_buff[tid], &sketches_buff[tid], w);
}
}
}
}
} else {
for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
auto const &line = batch.GetLine(ridx);
auto w = weights[ridx + base_rowid];
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
auto elem = line.GetElement(ii);
if (is_valid(elem)) {
if (IsCat(feature_types_, ii)) {
categories_[ii].emplace(elem.value);
} else {
sketches_[ii].Push(elem.value, w);
for (size_t ii = wl.columns.front(); ii <= wl.columns.back(); ++ii) {
if (!threads_wl.is_column_splited[ii]) {
auto const &elem = line.GetElement(ii);
/* elem.column_idx == ii */
if (is_valid(elem)) {
PushElement(elem, &categories_[ii], &sketches_[ii], w);
}
}
}
} else {
for (size_t i = 0; i < line.Size(); ++i) {
// number of columns with idx >= wl.columns.front()
size_t n_columns_with_high_idx = n_features - wl.columns.front();
size_t col_begin = line.Size() < n_columns_with_high_idx
? 0 : line.Size() - n_columns_with_high_idx;
size_t col_end = std::min(wl.columns.back() + 1, line.Size());
for (size_t i = col_begin; i < col_end; ++i) {
auto const &elem = line.GetElement(i);
if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
if (IsCat(feature_types_, elem.column_idx)) {
categories_[elem.column_idx].emplace(elem.value);
} else {
sketches_[elem.column_idx].Push(elem.value, w);
if (is_valid(elem)) {
if (!threads_wl.is_column_splited[elem.column_idx] &&
(elem.column_idx >= wl.columns.front()) &&
(elem.column_idx <= wl.columns.back())) {
PushElement(elem, &categories_[elem.column_idx],
&sketches_[elem.column_idx], w);
}
}
}
}
}
}

#pragma omp barrier
if (wl.n_splits > 1 && wl.split_idx == 0) {
/* The thread being responsible for the first block in split
* collect info from the other ones.
*/
size_t column_idx = wl.columns.front();

typename WQSketch::SummaryContainer main_summary;
main_summary.Reserve(sketches_[column_idx].limit_size);
typename WQSketch::SummaryContainer split_summary;
split_summary.Reserve(2 * sketches_[column_idx].limit_size);
typename WQSketch::SummaryContainer comb_summary;
comb_summary.Reserve(3 * sketches_[column_idx].limit_size);

for (size_t th = tid + 0; th < tid + wl.n_splits; ++th) {
CHECK_LT(th, threads_wl.baskets.size());
// Make shure some work was done by thread
if (buff_was_used[th] > 0) {
if (IsCat(feature_types_, column_idx)) {
categories_[column_idx].merge(categories_buff[th]);
} else {
sketches_buff[th].GetSummary(&split_summary);

comb_summary.SetCombine(main_summary, split_summary);
main_summary.SetPrune(comb_summary, sketches_[column_idx].limit_size);
}
}
}
sketches_[column_idx].PushSummary(main_summary);
}
});
}
exc.Rethrow();
Expand All @@ -893,6 +1015,18 @@ class SketchContainerImpl {
private:
// Merge all categories from other workers.
void AllreduceCategories(Context const* ctx, MetaInfo const& info);

template <class ElemType>
void PushElement(const ElemType& elem,
std::set<float>* categorie,
WQSketch* sketch,
float w) {
if (IsCat(feature_types_, elem.column_idx)) {
categorie->emplace(elem.value);
} else {
sketch->Push(elem.value, w);
}
}
};

class HostSketchContainer : public SketchContainerImpl<WQuantileSketch<float, float>> {
Expand Down
23 changes: 19 additions & 4 deletions tests/cpp/common/test_quantile.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ namespace xgboost::common {
TEST(Quantile, LoadBalance) {
size_t constexpr kRows = 1000, kCols = 100;
auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix();
std::vector<bst_feature_t> cols_ptr;
WLBalance threads_wl(kCols);
Context ctx;
for (auto const& page : m->GetBatches<SparsePage>(&ctx)) {
data::SparsePageAdapterBatch adapter{page.GetView()};
cols_ptr = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
threads_wl = LoadBalance(adapter, page.data.Size(), kCols, 13, [](auto) { return true; });
}
size_t n_cols = 0;
for (size_t i = 1; i < cols_ptr.size(); ++i) {
n_cols += cols_ptr[i] - cols_ptr[i - 1];
for (const auto& basket : threads_wl.baskets) {
n_cols += basket.columns.size();
for (size_t column : basket.columns) {
CHECK_LT(column, kCols);
}
}
CHECK_EQ(n_cols, kCols);
}
Expand Down Expand Up @@ -160,6 +163,12 @@ TEST(Quantile, DistributedBasic) {
TestDistributedQuantile<false>(kRows, kCols);
}

TEST(Quantile, DistributedRowWise) {
size_t kRows = (1u << 16) * common::OmpGetNumThreads(0);
size_t kCols = 2;
TestDistributedQuantile<false>(kRows, kCols);
}

TEST(Quantile, Distributed) {
constexpr size_t kRows = 4000, kCols = 200;
TestDistributedQuantile<false>(kRows, kCols);
Expand Down Expand Up @@ -288,6 +297,12 @@ TEST(Quantile, ColumnSplitBasic) {
TestColSplitQuantile<false>(kRows, kCols);
}

TEST(Quantile, ColumnSplitRowWise) {
size_t kRows = (1u << 16) * common::OmpGetNumThreads(0);
size_t kCols = 2;
TestColSplitQuantile<false>(kRows, kCols);
}

TEST(Quantile, ColumnSplit) {
constexpr size_t kRows = 4000, kCols = 200;
TestColSplitQuantile<false>(kRows, kCols);
Expand Down
Loading