From b3eaaaced09ff873f2693d844d353723535e6bab Mon Sep 17 00:00:00 2001 From: Jimmy Lu Date: Fri, 31 Jan 2025 08:34:32 -0800 Subject: [PATCH] fix: Support dictionary with nulls in RowVector::pushDictionaryToRowVectorLeaves (#12220) Summary: `RowVector::pushDictionaryToRowVectorLeaves` used to stop combining dictionaries if a dictionary introduced nulls over the underlying `RowVector`. We add support for this case as well in this change. Differential Revision: D68919058 --- velox/vector/ComplexVector.cpp | 82 ++++++++++++++++++++++--------- velox/vector/tests/VectorTest.cpp | 13 ++--- 2 files changed, 63 insertions(+), 32 deletions(-) diff --git a/velox/vector/ComplexVector.cpp b/velox/vector/ComplexVector.cpp index bc2295325fdc..50ece9f00721 100644 --- a/velox/vector/ComplexVector.cpp +++ b/velox/vector/ComplexVector.cpp @@ -687,24 +687,17 @@ struct Wrapper { BufferPtr indices; }; -void combineWrappers( - std::vector& wrappers, +template +void forEachCombinedIndex( + const std::vector& wrappers, vector_size_t size, - memory::MemoryPool* pool) { - std::vector wrapInfos(wrappers.size()); + F&& f) { std::vector sourceIndices(wrappers.size()); - uint64_t* rawNulls = nullptr; for (int i = 0; i < wrappers.size(); ++i) { - wrapInfos[i] = wrappers[i].dictionary->wrapInfo(); - VELOX_CHECK_NOT_NULL(wrapInfos[i]); - sourceIndices[i] = wrapInfos[i]->as(); - if (!rawNulls && wrappers[i].dictionary->nulls()) { - wrappers.back().nulls = allocateNulls(size, pool); - rawNulls = wrappers.back().nulls->asMutable(); - } + auto& wrapInfo = wrappers[i].dictionary->wrapInfo(); + VELOX_CHECK_NOT_NULL(wrapInfo); + sourceIndices[i] = wrapInfo->as(); } - wrappers.back().indices = allocateIndices(size, pool); - auto* rawIndices = wrappers.back().indices->asMutable(); for (vector_size_t j = 0; j < size; ++j) { auto index = j; bool isNull = false; @@ -715,12 +708,55 @@ void combineWrappers( } index = sourceIndices[i][index]; } - if (isNull) { - bits::setNull(rawNulls, j); - } else { - 
rawIndices[j] = index; + f(j, index, isNull); + } +} + +void combineWrappers( + std::vector& wrappers, + vector_size_t size, + memory::MemoryPool* pool) { + uint64_t* rawNulls = nullptr; + for (int i = 0; i < wrappers.size(); ++i) { + if (!rawNulls && wrappers[i].dictionary->nulls()) { + wrappers.back().nulls = allocateNulls(size, pool); + rawNulls = wrappers.back().nulls->asMutable(); + break; } } + wrappers.back().indices = allocateIndices(size, pool); + auto* rawIndices = wrappers.back().indices->asMutable(); + forEachCombinedIndex( + wrappers, + size, + [&](vector_size_t outer, vector_size_t inner, bool isNull) { + if (isNull) { + bits::setNull(rawNulls, outer); + } else { + rawIndices[outer] = inner; + } + }); +} + +BufferPtr combineNulls( + const std::vector& wrappers, + vector_size_t size, + const uint64_t* valueNulls, + memory::MemoryPool* pool) { + if (wrappers.size() == 1 && !valueNulls) { + return wrappers[0].dictionary->nulls(); + } + auto nulls = allocateNulls(size, pool); + auto* rawNulls = nulls->asMutable(); + forEachCombinedIndex( + wrappers, + size, + [&](vector_size_t outer, vector_size_t inner, bool isNull) { + if (isNull || (valueNulls && bits::isBitNull(valueNulls, inner))) { + bits::setNull(rawNulls, outer); + } + }); + return nulls; } VectorPtr wrapInDictionary( @@ -765,9 +801,11 @@ VectorPtr pushDictionaryToRowVectorLeavesImpl( } case VectorEncoding::Simple::ROW: { VELOX_CHECK_EQ(values->typeKind(), TypeKind::ROW); + auto nulls = values->nulls(); for (auto& wrapper : wrappers) { if (wrapper.dictionary->nulls()) { - return wrapInDictionary(wrappers, size, values, pool); + nulls = combineNulls(wrappers, size, values->rawNulls(), pool); + break; } } auto children = values->asUnchecked()->children(); @@ -778,11 +816,7 @@ VectorPtr pushDictionaryToRowVectorLeavesImpl( } } return std::make_shared( - pool, - values->type(), - values->nulls(), - values->size(), - std::move(children)); + pool, values->type(), std::move(nulls), size, 
std::move(children)); } case VectorEncoding::Simple::DICTIONARY: { Wrapper wrapper{values, nullptr, nullptr}; diff --git a/velox/vector/tests/VectorTest.cpp b/velox/vector/tests/VectorTest.cpp index 29dd57485000..b826d2694c62 100644 --- a/velox/vector/tests/VectorTest.cpp +++ b/velox/vector/tests/VectorTest.cpp @@ -3867,14 +3867,11 @@ TEST_F(VectorTest, pushDictionaryToRowVectorLeaves) { auto& c4c2 = c4Row->childAt(2); ASSERT_EQ(c4c2->encoding(), VectorEncoding::Simple::DICTIONARY); ASSERT_EQ(c4c0->wrapInfo().get(), c4c2->wrapInfo().get()); - auto& c5 = outputRow->childAt(5); - ASSERT_EQ(c5->encoding(), VectorEncoding::Simple::DICTIONARY); - ASSERT_EQ(c5->valueVector()->encoding(), VectorEncoding::Simple::ROW); - auto* c5Row = c5->valueVector()->asUnchecked(); - auto& c5c0 = c5Row->childAt(0); - ASSERT_EQ(c5c0->encoding(), VectorEncoding::Simple::FLAT); - auto& c5c1 = c5Row->childAt(1); - ASSERT_EQ(c5c1->encoding(), VectorEncoding::Simple::FLAT); + auto* c5 = outputRow->childAt(5)->asChecked(); + auto& c5c0 = c5->childAt(0); + ASSERT_EQ(c5c0->encoding(), VectorEncoding::Simple::DICTIONARY); + auto& c5c1 = c5->childAt(1); + ASSERT_EQ(c5c1->encoding(), VectorEncoding::Simple::DICTIONARY); } }