Move hash table overlap bits check earlier (#10472)
Summary:
Pull Request resolved: #10472

Original commit changeset: f0f267dce460

Original Phabricator Diff: D59640854

Reviewed By: Yuhta

Differential Revision: D59782277

fbshipit-source-id: dfd4c9acef2b4c15380d47610ba924383172b26a
Jialiang Tan authored and facebook-github-bot committed Jul 16, 2024
1 parent 5eac2f4 commit 3598307
Showing 11 changed files with 180 additions and 128 deletions.
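
What the diff below does: spillInputStartPartitionBit is now threaded through groupProbe, checkSize, allocateTables, rehash, setHashMode and decideHashMode, and the hash-bits overlap check (checkHashBitsOverlap) runs inside allocateTables and prepareForGroupProbe, that is, every time the table size and sizeBits_ are recomputed, instead of only once at the tail of prepareJoinTable. The following sketch is a minimal, self-contained illustration of the invariant being asserted; it is not Velox code, and the sentinel value and the rationale given in the comments are assumptions inferred from the diff.

// Minimal standalone sketch of the invariant asserted by checkHashBitsOverlap.
// Simplified from the diff: the sentinel value and the rationale in the
// comments are assumptions, not verified Velox internals.
#include <cassert>
#include <cstdint>

// Assumed sentinel meaning "the probe input does not come from spilled data".
constexpr int8_t kNoSpillInputStartPartitionBit = -1;

// sizeBits is, roughly, the number of low hash bits used to index the table;
// spillInputStartPartitionBit is the first hash bit used to choose a spill
// partition. The check requires the two bit ranges to be disjoint: if they
// overlapped, rows restored from a single spill partition would share their
// bucket-index bits and crowd into a fraction of the buckets.
void checkHashBitsOverlap(
    int sizeBits,
    int8_t spillInputStartPartitionBit,
    bool arrayMode) {
  if (spillInputStartPartitionBit != kNoSpillInputStartPartitionBit &&
      !arrayMode) {
    assert(sizeBits - 1 < spillInputStartPartitionBit);
  }
}

int main() {
  checkHashBitsOverlap(20, 48, false);  // table bits 0..19, spill bits from 48.
  checkHashBitsOverlap(20, kNoSpillInputStartPartitionBit, false);  // no spill input.
  return 0;
}

Because allocateTables now performs the check right after sizeBits_ is updated, the explicit checkHashBitsOverlap call at the end of prepareJoinTable becomes redundant and is dropped.
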
8 changes: 4 additions & 4 deletions velox/exec/GroupingSet.cpp
@@ -228,15 +228,14 @@ void GroupingSet::addInputForActiveRows(
*lookup_,
input,
activeRows_,
ignoreNullKeys_,
BaseHashTable::kNoSpillInputStartPartitionBit);
if (lookup_->rows.empty()) {
// No rows to probe. Can happen when ignoreNullKeys_ is true and all rows
// have null keys.
return;
}

table_->groupProbe(*lookup_);
table_->groupProbe(*lookup_, BaseHashTable::kNoSpillInputStartPartitionBit);
masks_.addInput(input, activeRows_);

auto* groups = lookup_->hits.data();
@@ -400,7 +399,7 @@ void GroupingSet::createHashTable() {

lookup_ = std::make_unique<HashLookup>(table_->hashers());
if (!isAdaptive_ && table_->hashMode() != BaseHashTable::HashMode::kHash) {
table_->forceGenericHashMode();
table_->forceGenericHashMode(BaseHashTable::kNoSpillInputStartPartitionBit);
}
}

@@ -812,7 +811,8 @@ bool GroupingSet::isPartialFull(int64_t maxBytes) {
// per 32 buckets.
if (stats.capacity * sizeof(void*) > maxBytes / 16 &&
stats.numDistinct < stats.capacity / 32) {
table_->decideHashMode(0, true);
table_->decideHashMode(
0, BaseHashTable::kNoSpillInputStartPartitionBit, true);
}
return allocatedBytes() > maxBytes;
}
2 changes: 1 addition & 1 deletion velox/exec/GroupingSet.h
@@ -263,7 +263,7 @@ class GroupingSet {
// groups.
void extractSpillResult(const RowVectorPtr& result);

// Return a list of accumulators for 'aggregates_', plus one more accumulator
// Returns a list of accumulators for 'aggregates_', plus one more accumulator
// for 'sortedAggregations_', and one for each 'distinctAggregations_'. When
// 'excludeToIntermediate' is true, skip the functions that support
// 'toIntermediate'.
6 changes: 3 additions & 3 deletions velox/exec/HashBuild.cpp
@@ -752,10 +752,10 @@ bool HashBuild::finishHashBuild() {
CpuWallTimer cpuWallTimer{timing};
table_->prepareJoinTable(
std::move(otherTables),
allowParallelJoinBuild ? operatorCtx_->task()->queryCtx()->executor()
: nullptr,
isInputFromSpill() ? spillConfig()->startPartitionBit
: BaseHashTable::kNoSpillInputStartPartitionBit);
: BaseHashTable::kNoSpillInputStartPartitionBit,
allowParallelJoinBuild ? operatorCtx_->task()->queryCtx()->executor()
: nullptr);
}
stats_.wlock()->addRuntimeStat(
BaseHashTable::kBuildWallNanos,
90 changes: 55 additions & 35 deletions velox/exec/HashTable.cpp
@@ -450,7 +450,9 @@ void populateNormalizedKeys(HashLookup& lookup, int8_t sizeBits) {
} // namespace

template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::groupProbe(HashLookup& lookup) {
void HashTable<ignoreNullKeys>::groupProbe(
HashLookup& lookup,
int8_t spillInputStartPartitionBit) {
incrementProbes(lookup.rows.size());

if (hashMode_ == HashMode::kArray) {
@@ -459,7 +461,7 @@ void HashTable<ignoreNullKeys>::groupProbe(HashLookup& lookup) {
}
// Do size-based rehash before mixing hashes from normalized keys
// because the size of the table affects the mixing.
checkSize(lookup.rows.size(), false);
checkSize(lookup.rows.size(), false, spillInputStartPartitionBit);
if (hashMode_ == HashMode::kNormalizedKey) {
populateNormalizedKeys(lookup, sizeBits_);
groupNormalizedKeyProbe(lookup);
@@ -706,7 +708,9 @@ void HashTable<ignoreNullKeys>::joinNormalizedKeyProbe(HashLookup& lookup) {
}

template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::allocateTables(uint64_t size) {
void HashTable<ignoreNullKeys>::allocateTables(
uint64_t size,
int8_t spillInputStartPartitionBit) {
VELOX_CHECK(bits::isPowerOfTwo(size), "Size is not a power of two: {}", size);
VELOX_CHECK_GT(size, 0);
capacity_ = size;
@@ -716,6 +720,7 @@ void HashTable<ignoreNullKeys>::allocateTables(uint64_t size) {
sizeMask_ = byteSize - 1;
numBuckets_ = byteSize / kBucketSize;
sizeBits_ = __builtin_popcountll(sizeMask_);
checkHashBitsOverlap(spillInputStartPartitionBit);
bucketOffsetMask_ = sizeMask_ & ~(kBucketSize - 1);
// The total size is 8 bytes per slot, in groups of 16 slots with 16 bytes of
// tags and 16 * 6 bytes of pointers and a padding of 16 bytes to round up the
@@ -754,7 +759,8 @@ void HashTable<ignoreNullKeys>::clear(bool freeTable) {
template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::checkSize(
int32_t numNew,
bool initNormalizedKeys) {
bool initNormalizedKeys,
int8_t spillInputStartPartitionBit) {
// NOTE: the way we decide the table size and trigger rehash, guarantees the
// table should always have free slots after the insertion.
VELOX_CHECK(
@@ -768,9 +774,9 @@ void HashTable<ignoreNullKeys>::checkSize(
const int64_t newNumDistincts = numNew + numDistinct_;
if (table_ == nullptr || capacity_ == 0) {
const auto newSize = newHashTableEntries(numDistinct_, numNew);
allocateTables(newSize);
allocateTables(newSize, spillInputStartPartitionBit);
if (numDistinct_ > 0) {
rehash(initNormalizedKeys);
rehash(initNormalizedKeys, spillInputStartPartitionBit);
}
// We are not always able to reuse a tombstone slot as a free one for hash
// collision handling purpose. For example, if all the table slots are
@@ -782,8 +788,8 @@
// NOTE: we need to plus one here as number itself could be power of two.
const auto newCapacity = bits::nextPowerOfTwo(
std::max(newNumDistincts, capacity_ - numTombstones_) + 1);
allocateTables(newCapacity);
rehash(initNormalizedKeys);
allocateTables(newCapacity, spillInputStartPartitionBit);
rehash(initNormalizedKeys, spillInputStartPartitionBit);
}
}

@@ -1276,7 +1282,9 @@ void HashTable<ignoreNullKeys>::insertForJoin(
}

template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::rehash(bool initNormalizedKeys) {
void HashTable<ignoreNullKeys>::rehash(
bool initNormalizedKeys,
int8_t spillInputStartPartitionBit) {
++numRehashes_;
constexpr int32_t kHashBatchSize = 1024;
if (canApplyParallelJoinBuild()) {
@@ -1299,15 +1307,18 @@
if (!insertBatch(
groups, numGroups, hashes, initNormalizedKeys || i != 0)) {
VELOX_CHECK_NE(hashMode_, HashMode::kHash);
setHashMode(HashMode::kHash, 0);
setHashMode(HashMode::kHash, 0, spillInputStartPartitionBit);
return;
}
} while (numGroups > 0);
}
}

template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::setHashMode(HashMode mode, int32_t numNew) {
void HashTable<ignoreNullKeys>::setHashMode(
HashMode mode,
int32_t numNew,
int8_t spillInputStartPartitionBit) {
VELOX_CHECK_NE(hashMode_, HashMode::kHash);
TestValue::adjust("facebook::velox::exec::HashTable::setHashMode", &mode);
if (mode == HashMode::kArray) {
@@ -1317,7 +1328,7 @@ void HashTable<ignoreNullKeys>::setHashMode(HashMode mode, int32_t numNew) {
table_ = tableAllocation_.data<char*>();
memset(table_, 0, bytes);
hashMode_ = HashMode::kArray;
rehash(true);
rehash(true, spillInputStartPartitionBit);
} else if (mode == HashMode::kHash) {
hashMode_ = HashMode::kHash;
for (auto& hasher : hashers_) {
Expand All @@ -1326,12 +1337,12 @@ void HashTable<ignoreNullKeys>::setHashMode(HashMode mode, int32_t numNew) {
rows_->disableNormalizedKeys();
capacity_ = 0;
// Makes tables of the right size and rehashes.
checkSize(numNew, true);
checkSize(numNew, true, spillInputStartPartitionBit);
} else if (mode == HashMode::kNormalizedKey) {
hashMode_ = HashMode::kNormalizedKey;
capacity_ = 0;
// Makes tables of the right size and rehashes.
checkSize(numNew, true);
checkSize(numNew, true, spillInputStartPartitionBit);
}
}

@@ -1459,6 +1470,7 @@ void HashTable<ignoreNullKeys>::clearUseRange(std::vector<bool>& useRange) {
template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::decideHashMode(
int32_t numNew,
int8_t spillInputStartPartitionBit,
bool disableRangeArrayHash) {
std::vector<uint64_t> rangeSizes(hashers_.size());
std::vector<uint64_t> distinctSizes(hashers_.size());
@@ -1475,7 +1487,7 @@
disableRangeArrayHash_ |= disableRangeArrayHash;
if (numDistinct_ && !isJoinBuild_) {
if (!analyze()) {
setHashMode(HashMode::kHash, numNew);
setHashMode(HashMode::kHash, numNew, spillInputStartPartitionBit);
return;
}
}
@@ -1500,38 +1512,38 @@
if (rangesWithReserve < kArrayHashMaxSize && !disableRangeArrayHash_) {
std::fill(useRange.begin(), useRange.end(), true);
capacity_ = setHasherMode(hashers_, useRange, rangeSizes, distinctSizes);
setHashMode(HashMode::kArray, numNew);
setHashMode(HashMode::kArray, numNew, spillInputStartPartitionBit);
return;
}

if (bestWithReserve < kArrayHashMaxSize ||
(disableRangeArrayHash_ && bestWithReserve < numDistinct_ * 2)) {
capacity_ = setHasherMode(hashers_, useRange, rangeSizes, distinctSizes);
setHashMode(HashMode::kArray, numNew);
setHashMode(HashMode::kArray, numNew, spillInputStartPartitionBit);
return;
}
if (rangesWithReserve != VectorHasher::kRangeTooLarge) {
std::fill(useRange.begin(), useRange.end(), true);
setHasherMode(hashers_, useRange, rangeSizes, distinctSizes);
setHashMode(HashMode::kNormalizedKey, numNew);
setHashMode(HashMode::kNormalizedKey, numNew, spillInputStartPartitionBit);
return;
}
if (hashers_.size() == 1 && distinctsWithReserve > 10000) {
// A single part group by that does not go by range or become an array
// does not make sense as a normalized key unless it is very small.
setHashMode(HashMode::kHash, numNew);
setHashMode(HashMode::kHash, numNew, spillInputStartPartitionBit);
return;
}

if (distinctsWithReserve < kArrayHashMaxSize) {
clearUseRange(useRange);
capacity_ = setHasherMode(hashers_, useRange, rangeSizes, distinctSizes);
setHashMode(HashMode::kArray, numNew);
setHashMode(HashMode::kArray, numNew, spillInputStartPartitionBit);
return;
}
if (distinctsWithReserve == VectorHasher::kRangeTooLarge &&
rangesWithReserve == VectorHasher::kRangeTooLarge) {
setHashMode(HashMode::kHash, numNew);
setHashMode(HashMode::kHash, numNew, spillInputStartPartitionBit);
return;
}
// The key concatenation fits in 64 bits.
Expand All @@ -1541,7 +1553,7 @@ void HashTable<ignoreNullKeys>::decideHashMode(
clearUseRange(useRange);
}
setHasherMode(hashers_, useRange, rangeSizes, distinctSizes);
setHashMode(HashMode::kNormalizedKey, numNew);
setHashMode(HashMode::kNormalizedKey, numNew, spillInputStartPartitionBit);
}

template <bool ignoreNullKeys>
@@ -1555,6 +1567,15 @@ std::vector<RowContainer*> HashTable<ignoreNullKeys>::allRows() const {
return rowContainers;
}

template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::checkHashBitsOverlap(
int8_t spillInputStartPartitionBit) {
if (spillInputStartPartitionBit != kNoSpillInputStartPartitionBit &&
hashMode() != HashMode::kArray) {
VELOX_CHECK_LT(sizeBits_ - 1, spillInputStartPartitionBit);
}
}

template <bool ignoreNullKeys>
std::string HashTable<ignoreNullKeys>::toString() {
std::stringstream out;
@@ -1682,8 +1703,8 @@ bool mayUseValueIds(const BaseHashTable& table) {
template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::prepareJoinTable(
std::vector<std::unique_ptr<BaseHashTable>> tables,
folly::Executor* executor,
int8_t spillInputStartPartitionBit) {
int8_t spillInputStartPartitionBit,
folly::Executor* executor) {
buildExecutor_ = executor;
otherTables_.reserve(tables.size());
for (auto& table : tables) {
@@ -1719,14 +1740,13 @@ void HashTable<ignoreNullKeys>::prepareJoinTable(
}
if (!useValueIds) {
if (hashMode_ != HashMode::kHash) {
setHashMode(HashMode::kHash, 0);
setHashMode(HashMode::kHash, 0, spillInputStartPartitionBit);
} else {
checkSize(0, true);
checkSize(0, true, spillInputStartPartitionBit);
}
} else {
decideHashMode(0);
decideHashMode(0, spillInputStartPartitionBit);
}
checkHashBitsOverlap(spillInputStartPartitionBit);
}

template <bool ignoreNullKeys>
@@ -2081,11 +2101,11 @@ std::string BaseHashTable::RowsIterator::toString() const {
rowContainerIterator_.toString());
}

void BaseHashTable::prepareForGroupProbe(
template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::prepareForGroupProbe(
HashLookup& lookup,
const RowVectorPtr& input,
SelectivityVector& rows,
bool ignoreNullKeys,
int8_t spillInputStartPartitionBit) {
checkHashBitsOverlap(spillInputStartPartitionBit);
auto& hashers = lookup.hashers;
@@ -2095,7 +2115,7 @@ void BaseHashTable::prepareForGroupProbe(
hasher->decode(*key, rows);
}

if (ignoreNullKeys) {
if constexpr (ignoreNullKeys) {
// A null in any of the keys disables the row.
deselectRowsWithNulls(hashers, rows);
}
Expand All @@ -2117,19 +2137,19 @@ void BaseHashTable::prepareForGroupProbe(

if (rehash || capacity() == 0) {
if (mode != BaseHashTable::HashMode::kHash) {
decideHashMode(input->size());
decideHashMode(input->size(), spillInputStartPartitionBit);
// Do not forward 'ignoreNullKeys' to avoid redundant evaluation of
// deselectRowsWithNulls.
prepareForGroupProbe(
lookup, input, rows, false, spillInputStartPartitionBit);
prepareForGroupProbe(lookup, input, rows, spillInputStartPartitionBit);
return;
}
}

populateLookupRows(rows, lookup.rows);
}

void BaseHashTable::prepareForJoinProbe(
template <bool ignoreNullKeys>
void HashTable<ignoreNullKeys>::prepareForJoinProbe(
HashLookup& lookup,
const RowVectorPtr& input,
SelectivityVector& rows,
