From 2fcbb3fd76064a9245ee569bc37ce1dbb01b8c9d Mon Sep 17 00:00:00 2001 From: Xiaoxuan Meng Date: Fri, 20 Dec 2024 09:41:20 -0800 Subject: [PATCH] [native] Fix unsafe row exchange source with compression support --- .../operators/UnsafeRowExchangeSource.cpp | 22 ++++++++++++++----- .../main/operators/tests/BroadcastTest.cpp | 3 ++- .../presto_cpp/main/tests/TaskManagerTest.cpp | 3 ++- presto-native-execution/velox | 2 +- 4 files changed, 22 insertions(+), 8 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp index e952a62970e44..1c1453d393c29 100644 --- a/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp +++ b/presto-native-execution/presto_cpp/main/operators/UnsafeRowExchangeSource.cpp @@ -16,6 +16,7 @@ #include "presto_cpp/main/common/Configs.h" #include "presto_cpp/main/operators/UnsafeRowExchangeSource.h" +#include "velox/serializers/RowSerializer.h" namespace facebook::presto::operators { @@ -36,7 +37,7 @@ UnsafeRowExchangeSource::request( return std::move(shuffle_->next()) .deferValue([this](velox::BufferPtr buffer) { std::vector promises; - int64_t totalBytes = 0; + int64_t totalBytes{0}; { std::lock_guard l(queue_->mutex()); @@ -45,14 +46,25 @@ UnsafeRowExchangeSource::request( queue_->enqueueLocked(nullptr, promises); } else { totalBytes = buffer->size(); + VELOX_CHECK_LE(totalBytes, std::numeric_limits::max()); ++numBatches_; - - auto ioBuf = - folly::IOBuf::wrapBuffer(buffer->as(), buffer->size()); + velox::serializer::detail::RowGroupHeader rowHeader{ + .uncompressedSize = static_cast(totalBytes), + .compressedSize = static_cast(totalBytes), + .compressed = false}; + auto headBuffer = std::make_shared( + velox::serializer::detail::RowGroupHeader::size(), '0'); + rowHeader.write(const_cast(headBuffer->data())); + + auto ioBuf = folly::IOBuf::wrapBuffer( + headBuffer->data(), headBuffer->size()); + ioBuf->appendToChain( + folly::IOBuf::wrapBuffer(buffer->as(), buffer->size())); queue_->enqueueLocked( std::make_unique( - std::move(ioBuf), [buffer](auto& /*unused*/) {}), + std::move(ioBuf), + [buffer, headBuffer](auto& /*unused*/) {}), promises); } } diff --git a/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp b/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp index 0973a8359c0b2..f206ac6e1c71b 100644 --- a/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp +++ b/presto-native-execution/presto_cpp/main/operators/tests/BroadcastTest.cpp @@ -216,7 +216,8 @@ class BroadcastTest : public exec::test::OperatorTestBase { pool(), dataType, velox::getNamedVectorSerde(velox::VectorSerde::Kind::kPresto), - &result); + &result, + nullptr); return result; } }; diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 43190f2e5c039..8d92af63c56b0 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -165,7 +165,8 @@ class Cursor { std::vector vectors; while (!input->atEnd()) { RowVectorPtr vector; - VectorStreamGroup::read(input.get(), pool_, rowType_, serde, &vector); + VectorStreamGroup::read( + input.get(), pool_, rowType_, serde, &vector, nullptr); vectors.emplace_back(vector); } return vectors; diff --git a/presto-native-execution/velox b/presto-native-execution/velox index 12942c1eb76de..9265fbfd9a071 160000 --- a/presto-native-execution/velox +++ b/presto-native-execution/velox @@ -1 +1 @@ -Subproject commit 12942c1eb76de019b775f4b207bc4595d8ace5c0 +Subproject commit 9265fbfd9a0716136456ef8b6e455ff110a0f7da