Skip to content

Commit

Permalink
[native] Change PartitionAndSerialize to have pass in partition function
Browse files Browse the repository at this point in the history
  • Loading branch information
tanjialiang committed Dec 7, 2022
1 parent 73ad442 commit 51c51c4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,9 @@ class PartitionAndSerializeOperator : public Operator {
planNode->outputType(),
operatorId,
planNode->id(),
"PartitionAndSerialize") {
auto inputType = planNode->sources()[0]->outputType();
auto keyChannels = toChannels(inputType, planNode->keys());

// Initialize the hive partition function.
const auto numPartitions = planNode->numPartitions();
std::vector<int> bucketToPartition(numPartitions);
std::iota(bucketToPartition.begin(), bucketToPartition.end(), 0);
partitionFunction_ =
std::make_unique<connector::hive::HivePartitionFunction>(
planNode->numPartitions(),
std::move(bucketToPartition),
keyChannels);
}
"PartitionAndSerialize"),
partitionFunction_(
planNode->partitionFunctionFactory()(planNode->numPartitions())) {}

bool needsInput() const override {
return !input_;
Expand Down Expand Up @@ -155,7 +144,7 @@ class PartitionAndSerializeOperator : public Operator {
}
}

std::unique_ptr<connector::hive::HivePartitionFunction> partitionFunction_;
std::unique_ptr<core::PartitionFunction> partitionFunction_;
std::vector<uint32_t> partitions_;
std::vector<size_t> rowSizes_;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,24 @@ class PartitionAndSerializeNode : public velox::core::PlanNode {
std::vector<velox::core::TypedExprPtr> keys,
uint32_t numPartitions,
velox::RowTypePtr outputType,
velox::core::PlanNodePtr source)
velox::core::PlanNodePtr source,
velox::core::PartitionFunctionFactory partitionFunctionFactory)
: velox::core::PlanNode(id),
keys_{std::move(keys)},
numPartitions_{numPartitions},
keys_(std::move(keys)),
numPartitions_(numPartitions),
outputType_{std::move(outputType)},
sources_{std::move(source)} {
sources_({std::move(source)}),
partitionFunctionFactory_(std::move(partitionFunctionFactory)) {
// Only verify output types are correct. Note column names are not enforced
// in the following check.
VELOX_USER_CHECK(
velox::ROW(
{"partition", "data"}, {velox::INTEGER(), velox::VARBINARY()})
->equivalent(*outputType_));
VELOX_USER_CHECK(!keys_.empty(), "Empty keys for hive partition");
VELOX_USER_CHECK(!keys_.empty(), "Empty partition keys");
VELOX_USER_CHECK_NOT_NULL(
partitionFunctionFactory_,
"Partition function factory cannot be null.");
}

const velox::RowTypePtr& outputType() const override {
Expand All @@ -56,6 +63,11 @@ class PartitionAndSerializeNode : public velox::core::PlanNode {
return numPartitions_;
}

const velox::core::PartitionFunctionFactory& partitionFunctionFactory()
const {
return partitionFunctionFactory_;
}

std::string_view name() const override {
return "PartitionAndSerialize";
}
Expand All @@ -67,6 +79,7 @@ class PartitionAndSerializeNode : public velox::core::PlanNode {
const uint32_t numPartitions_;
const velox::RowTypePtr outputType_;
const std::vector<velox::core::PlanNodePtr> sources_;
const velox::core::PartitionFunctionFactory partitionFunctionFactory_;
};

class PartitionAndSerializeTranslator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "presto_cpp/main/operators/PartitionAndSerialize.h"
#include "presto_cpp/main/operators/ShuffleWrite.h"
#include "presto_cpp/main/operators/UnsafeRowExchangeSource.h"
#include "velox/connectors/hive/HivePartitionFunction.h"
#include "velox/exec/Exchange.h"
#include "velox/exec/tests/utils/OperatorTestBase.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
Expand Down Expand Up @@ -147,14 +148,21 @@ auto addPartitionAndSerializeNode(uint32_t numPartitions) {
return [numPartitions](
core::PlanNodeId nodeId,
core::PlanNodePtr source) -> core::PlanNodePtr {
auto outputType = ROW({"p", "d"}, {INTEGER(), VARBINARY()});

std::vector<core::TypedExprPtr> keys;
keys.push_back(
std::make_shared<core::FieldAccessTypedExpr>(INTEGER(), "c0"));

auto outputType = source->outputType();
return std::make_shared<PartitionAndSerializeNode>(
nodeId, keys, numPartitions, outputType, std::move(source));
nodeId,
keys,
numPartitions,
ROW({"p", "d"}, {INTEGER(), VARBINARY()}),
std::move(source),
[outputType, keys](int numPartitions) {
auto keyChannels = exec::toChannels(outputType, keys);
return std::make_unique<connector::hive::HivePartitionFunction>(
numPartitions, std::vector<int>(numPartitions), keyChannels);
});
};
}

Expand Down

0 comments on commit 51c51c4

Please sign in to comment.