Optimization of PushRowPage for high number of cpu cores #11182

Open · wants to merge 21 commits into master · Changes from 3 commits
114 changes: 92 additions & 22 deletions src/common/quantile.h
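
In brief (per the hunk below): when `n_features` is smaller than the number of threads and each thread's row block exceeds 1024 rows, `PushRowPageImpl` switches to row-wise parallelisation. Every thread fills private per-feature `WQSketch` and category buffers over its own block of rows, and after an OpenMP barrier the per-thread buffers are merged feature-wise (`GetSummary`/`PushSummary` for numeric features, `std::set::merge` for categorical ones). All other cases keep the original column-wise path driven by `LoadBalance`.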
@@ -840,47 +840,117 @@ class SketchContainerImpl {
   template <typename Batch, typename IsValid>
   void PushRowPageImpl(Batch const &batch, size_t base_rowid, OptionalWeights weights, size_t nnz,
                        size_t n_features, bool is_dense, IsValid is_valid) {
-    auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
-
     dmlc::OMPException exc;
-#pragma omp parallel num_threads(n_threads_)
-    {
-      exc.Run([&]() {
-        auto tid = static_cast<uint32_t>(omp_get_thread_num());
-        auto const begin = thread_columns_ptr[tid];
-        auto const end = thread_columns_ptr[tid + 1];
-
-        // do not iterate if no columns are assigned to the thread
-        if (begin < end && end <= n_features) {
-          for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
-            auto const &line = batch.GetLine(ridx);
-            auto w = weights[ridx + base_rowid];
-            if (is_dense) {
-              for (size_t ii = begin; ii < end; ii++) {
-                auto elem = line.GetElement(ii);
-                if (is_valid(elem)) {
-                  if (IsCat(feature_types_, ii)) {
-                    categories_[ii].emplace(elem.value);
-                  } else {
-                    sketches_[ii].Push(elem.value, w);
-                  }
-                }
-              }
-            } else {
-              for (size_t i = 0; i < line.Size(); ++i) {
-                auto const &elem = line.GetElement(i);
-                if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
-                  if (IsCat(feature_types_, elem.column_idx)) {
-                    categories_[elem.column_idx].emplace(elem.value);
-                  } else {
-                    sketches_[elem.column_idx].Push(elem.value, w);
-                  }
-                }
-              }
-            }
-          }
-        }
-      });
-    }
+    size_t ridx_block_size = batch.Size() / n_threads_ + (batch.Size() % n_threads_ > 0);
+    size_t min_ridx_block_size = 1024;
+    if ((n_features < static_cast<size_t>(n_threads_)) &&
+        (ridx_block_size > min_ridx_block_size)) {
+      /* Row-wise parallelisation.
+       */
+      std::vector<std::set<float>> categories_buff(n_threads_ * n_features);
+      std::vector<WQSketch> sketches_buff(n_threads_ * n_features);
+
+#pragma omp parallel num_threads(n_threads_)
+      {
+        exc.Run([&]() {
+          auto tid = static_cast<uint32_t>(omp_get_thread_num());
+          WQSketch *sketches_th = sketches_buff.data() + tid * n_features;
+          std::set<float> *categories_th = categories_buff.data() + tid * n_features;
+
+          for (size_t ii = 0; ii < n_features; ii++) {
+            auto n_bins = std::min(static_cast<bst_idx_t>(max_bins_), columns_size_[ii]);
+            auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
+            sketches_th[ii].Init(columns_size_[ii], eps);
+          }
+
+          size_t ridx_begin = tid * ridx_block_size;
+          size_t ridx_end = std::min(ridx_begin + ridx_block_size, batch.Size());
+          for (size_t ridx = ridx_begin; ridx < ridx_end; ++ridx) {
+            auto const &line = batch.GetLine(ridx);
+            auto w = weights[ridx + base_rowid];
+            if (is_dense) {
+              for (size_t ii = 0; ii < n_features; ii++) {
+                auto elem = line.GetElement(ii);
+                if (is_valid(elem)) {
+                  if (IsCat(feature_types_, ii)) {
+                    categories_th[ii].emplace(elem.value);
+                  } else {
+                    sketches_th[ii].Push(elem.value, w);
+                  }
+                }
+              }
+            } else {
+              for (size_t ii = 0; ii < line.Size(); ++ii) {
+                auto elem = line.GetElement(ii);
+                if (is_valid(elem)) {
+                  if (IsCat(feature_types_, elem.column_idx)) {
+                    categories_th[elem.column_idx].emplace(elem.value);
+                  } else {
+                    sketches_th[elem.column_idx].Push(elem.value, w);
+                  }
+                }
+              }
+            }
+          }
+#pragma omp barrier
+
+          size_t fidx_block_size = n_features / n_threads_ + (n_features % n_threads_ > 0);
+          size_t fidx_begin = tid * fidx_block_size;
+          size_t fidx_end = std::min(fidx_begin + fidx_block_size, n_features);
+          for (size_t ii = fidx_begin; ii < fidx_end; ++ii) {
+            for (int th = 0; th < n_threads_; ++th) {
+              if (IsCat(feature_types_, ii)) {
+                categories_[ii].merge(categories_buff[th * n_features + ii]);
+              } else {
+                typename WQSketch::SummaryContainer summary;
+                sketches_buff[th * n_features + ii].GetSummary(&summary);
+                sketches_[ii].PushSummary(summary);
+              }
+            }
+          }
+        });
+      }
+    } else {
+      auto thread_columns_ptr = LoadBalance(batch, nnz, n_features, n_threads_, is_valid);
+#pragma omp parallel num_threads(n_threads_)
+      {
+        exc.Run([&]() {
+          auto tid = static_cast<uint32_t>(omp_get_thread_num());
+          auto const begin = thread_columns_ptr[tid];
+          auto const end = thread_columns_ptr[tid + 1];
+
+          // do not iterate if no columns are assigned to the thread
+          if (begin < end && end <= n_features) {
+            for (size_t ridx = 0; ridx < batch.Size(); ++ridx) {
+              auto const &line = batch.GetLine(ridx);
+              auto w = weights[ridx + base_rowid];
+              if (is_dense) {
+                for (size_t ii = begin; ii < end; ii++) {
+                  auto elem = line.GetElement(ii);
+                  if (is_valid(elem)) {
+                    if (IsCat(feature_types_, ii)) {
+                      categories_[ii].emplace(elem.value);
+                    } else {
+                      sketches_[ii].Push(elem.value, w);
+                    }
+                  }
+                }
+              } else {
+                for (size_t i = 0; i < line.Size(); ++i) {
+                  auto const &elem = line.GetElement(i);
+                  if (is_valid(elem) && elem.column_idx >= begin && elem.column_idx < end) {
+                    if (IsCat(feature_types_, elem.column_idx)) {
+                      categories_[elem.column_idx].emplace(elem.value);
+                    } else {
+                      sketches_[elem.column_idx].Push(elem.value, w);
+                    }
+                  }
+                }
+              }
+            }
+          }
+        });
+      }
+    }
     exc.Rethrow();
   }
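
For readers outside the XGBoost codebase, the program below illustrates the same two-phase pattern in a self-contained form: private per-(thread, feature) accumulators, an OpenMP barrier, then a lock-free feature-wise merge. This is a minimal sketch, not XGBoost code; `FeatureSum` is a hypothetical stand-in for `WQSketch`, and plain sums replace quantile summaries.

// two_phase_sketch.cc — illustration of the row-wise build + feature-wise merge pattern.
#include <omp.h>

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

struct FeatureSum {  // hypothetical stand-in for a per-feature WQSketch
  double total{0.0};
  void Push(double v) { total += v; }
  void Merge(FeatureSum const &other) { total += other.total; }
};

int main() {
  std::size_t const n_rows = 10000, n_features = 4;
  int const n_threads = omp_get_max_threads();
  // Dense row-major matrix filled with a simple repeating pattern.
  std::vector<double> data(n_rows * n_features);
  for (std::size_t i = 0; i < data.size(); ++i) data[i] = static_cast<double>(i % 7);

  std::vector<FeatureSum> result(n_features);
  // One private accumulator per (thread, feature) pair, as in the patch.
  std::vector<FeatureSum> buff(static_cast<std::size_t>(n_threads) * n_features);
  std::size_t const row_block = n_rows / n_threads + (n_rows % n_threads > 0);

#pragma omp parallel num_threads(n_threads)
  {
    auto tid = static_cast<std::size_t>(omp_get_thread_num());
    FeatureSum *local = buff.data() + tid * n_features;

    // Phase 1: row-wise pass over this thread's block, touching all features.
    std::size_t begin = tid * row_block;
    std::size_t end = std::min(begin + row_block, n_rows);
    for (std::size_t r = begin; r < end; ++r) {
      for (std::size_t f = 0; f < n_features; ++f) {
        local[f].Push(data[r * n_features + f]);
      }
    }
#pragma omp barrier

    // Phase 2: feature-wise merge; each thread owns a disjoint block of features.
    std::size_t const fblock = n_features / n_threads + (n_features % n_threads > 0);
    std::size_t fbegin = tid * fblock;
    std::size_t fend = std::min(fbegin + fblock, n_features);
    for (std::size_t f = fbegin; f < fend; ++f) {
      for (int th = 0; th < n_threads; ++th) {
        result[f].Merge(buff[static_cast<std::size_t>(th) * n_features + f]);
      }
    }
  }

  for (std::size_t f = 0; f < n_features; ++f) {
    std::printf("feature %zu: sum = %.1f\n", f, result[f].total);
  }
  return 0;
}

The reason the merge phase lives inside the same parallel region as the build phase is that it reuses the already-running threads and needs no locks: after the barrier every per-thread buffer is complete, and each output slot is written by exactly one thread because the feature blocks are disjoint.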