Skip to content

Commit

Permalink
[native] Shuffle related entities should have shared ownership to Shu…
Browse files Browse the repository at this point in the history
…ffleInterface
  • Loading branch information
tanjialiang committed Dec 7, 2022
1 parent 51c51c4 commit 670ae5b
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class ShuffleWriteOperator : public Operator {
}

private:
ShuffleInterface* const FOLLY_NONNULL shuffle_;
const std::shared_ptr<ShuffleInterface> shuffle_;
};
} // namespace

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class ShuffleWriteNode : public velox::core::PlanNode {
public:
ShuffleWriteNode(
const velox::core::PlanNodeId& id,
ShuffleInterface* shuffle,
const std::shared_ptr<ShuffleInterface>& shuffle,
velox::core::PlanNodePtr source)
: velox::core::PlanNode(id),
shuffle_{shuffle},
Expand All @@ -37,7 +37,7 @@ class ShuffleWriteNode : public velox::core::PlanNode {
return sources_;
}

ShuffleInterface* shuffle() const {
const std::shared_ptr<ShuffleInterface>& shuffle() const {
return shuffle_;
}

Expand All @@ -48,7 +48,7 @@ class ShuffleWriteNode : public velox::core::PlanNode {
private:
void addDetails(std::stringstream& stream) const override {}

ShuffleInterface* shuffle_;
const std::shared_ptr<ShuffleInterface> shuffle_;
const std::vector<velox::core::PlanNodePtr> sources_;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class UnsafeRowExchangeSource : public velox::exec::ExchangeSource {
const std::string& taskId,
int destination,
std::shared_ptr<velox::exec::ExchangeQueue> queue,
ShuffleInterface* shuffle,
const std::shared_ptr<ShuffleInterface>& shuffle,
velox::memory::MemoryPool* pool)
: ExchangeSource(taskId, destination, queue, pool), shuffle_(shuffle) {}

Expand All @@ -39,6 +39,6 @@ class UnsafeRowExchangeSource : public velox::exec::ExchangeSource {
void close() override {}

private:
ShuffleInterface* shuffle_;
const std::shared_ptr<ShuffleInterface> shuffle_;
};
} // namespace facebook::presto::operators
} // namespace facebook::presto::operators
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ class TestShuffle : public ShuffleInterface {
std::vector<std::vector<BufferPtr>> readyPartitions_;
};

void registerExchangeSource(ShuffleInterface* shuffle) {
void registerExchangeSource(const std::shared_ptr<ShuffleInterface>& shuffle) {
exec::ExchangeSource::registerFactory(
[shuffle](
const std::string& taskId,
Expand Down Expand Up @@ -166,7 +166,7 @@ auto addPartitionAndSerializeNode(uint32_t numPartitions) {
};
}

auto addShuffleWriteNode(ShuffleInterface* shuffle) {
auto addShuffleWriteNode(const std::shared_ptr<ShuffleInterface>& shuffle) {
return [shuffle](
core::PlanNodeId nodeId,
core::PlanNodePtr source) -> core::PlanNodePtr {
Expand Down Expand Up @@ -267,7 +267,7 @@ class UnsafeRowShuffleTest : public exec::test::OperatorTestBase {
}

void runShuffleTest(
ShuffleInterface* shuffle,
const std::shared_ptr<ShuffleInterface>& shuffle,
size_t numPartitions,
size_t numMapDrivers,
const std::vector<RowVectorPtr>& data) {
Expand Down Expand Up @@ -348,7 +348,7 @@ TEST_F(UnsafeRowShuffleTest, operators) {
std::make_unique<PartitionAndSerializeTranslator>());
exec::Operator::registerOperator(std::make_unique<ShuffleWriteTranslator>());

TestShuffle shuffle(pool(), 4, 1 << 20 /* 1MB */);
auto shuffle = std::make_shared<TestShuffle>(pool(), 4, 1 << 20 /* 1MB */);

auto data = makeRowVector({
makeFlatVector<int32_t>({1, 2, 3, 4}),
Expand All @@ -359,7 +359,7 @@ TEST_F(UnsafeRowShuffleTest, operators) {
.values({data}, true)
.addNode(addPartitionAndSerializeNode(4))
.localPartition({})
.addNode(addShuffleWriteNode(&shuffle))
.addNode(addShuffleWriteNode(shuffle))
.planNode();

exec::test::CursorParameters params;
Expand All @@ -375,12 +375,13 @@ TEST_F(UnsafeRowShuffleTest, endToEnd) {
size_t numPartitions = 5;
size_t numMapDrivers = 2;

TestShuffle shuffle(pool(), numPartitions, 1 << 20 /* 1MB */);
auto shuffle =
std::make_shared<TestShuffle>(pool(), numPartitions, 1 << 20 /* 1MB */);
auto data = vectorMaker_.rowVector({
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}),
makeFlatVector<int64_t>({10, 20, 30, 40, 50, 60}),
});
runShuffleTest(&shuffle, numPartitions, numMapDrivers, {data});
runShuffleTest(shuffle, numPartitions, numMapDrivers, {data});
}

TEST_F(UnsafeRowShuffleTest, persistentShuffleDeser) {
Expand Down Expand Up @@ -432,14 +433,14 @@ TEST_F(UnsafeRowShuffleTest, persistentShuffle) {
auto rootPath = rootDirectory->path;

// Initialize persistent shuffle.
LocalPersistentShuffle shuffle(
auto shuffle = std::make_shared<LocalPersistentShuffle>(
rootPath, pool(), numPartitions, 1 << 20 /* 1MB */);

auto data = vectorMaker_.rowVector({
makeFlatVector<int32_t>({1, 2, 3, 4, 5, 6}),
makeFlatVector<int64_t>({10, 20, 30, 40, 50, 60}),
});
runShuffleTest(&shuffle, numPartitions, numMapDrivers, {data});
runShuffleTest(shuffle, numPartitions, numMapDrivers, {data});
}

TEST_F(UnsafeRowShuffleTest, persistentShuffleFuzz) {
Expand Down Expand Up @@ -492,7 +493,7 @@ TEST_F(UnsafeRowShuffleTest, persistentShuffleFuzz) {
velox::filesystems::registerLocalFileSystem();
auto rootDirectory = velox::exec::test::TempDirectoryPath::create();
auto rootPath = rootDirectory->path;
auto shuffle = std::make_unique<LocalPersistentShuffle>(
auto shuffle = std::make_shared<LocalPersistentShuffle>(
rootPath, pool(), numPartitions, 1 << 15);
for (int it = 0; it < numIterations; it++) {
shuffle->reset(pool(), numPartitions, rootPath);
Expand All @@ -505,7 +506,7 @@ TEST_F(UnsafeRowShuffleTest, persistentShuffleFuzz) {
auto input = fuzzer.fuzzRow(rowType);
inputVectors.push_back(input);
}
runShuffleTest(shuffle.get(), numPartitions, numMapDrivers, inputVectors);
runShuffleTest(shuffle, numPartitions, numMapDrivers, inputVectors);
}
}

Expand Down

0 comments on commit 670ae5b

Please sign in to comment.