Skip to content

Commit

Permalink
[native] Add shuffle interface registration framework
Browse files Browse the repository at this point in the history
  • Loading branch information
tanjialiang committed Dec 7, 2022
1 parent 3d2ac5a commit 0271efd
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 0 deletions.
17 changes: 17 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
#include "presto_cpp/main/common/Counters.h"
#include "presto_cpp/main/connectors/hive/storage_adapters/FileSystems.h"
#include "presto_cpp/main/http/HttpServer.h"
#include "presto_cpp/main/operators/LocalPersistentShuffle.h"
#include "presto_cpp/main/operators/ShuffleInterface.h"
#include "presto_cpp/presto_protocol/Connectors.h"
#include "presto_cpp/presto_protocol/WriteProtocol.h"
#include "presto_cpp/presto_protocol/presto_protocol.h"
Expand Down Expand Up @@ -146,6 +148,7 @@ void PrestoServer::run() {
registerPrestoCppCounters();
registerFileSystems();
registerOptionalHiveStorageAdapters();
registerShuffleInterfaceFactories();
protocol::registerHiveConnectors();
protocol::registerTpchConnector();
protocol::HiveNoCommitWriteProtocol::registerProtocol();
Expand Down Expand Up @@ -467,6 +470,20 @@ std::vector<std::string> PrestoServer::registerConnectors(
return catalogNames;
}

void PrestoServer::registerShuffleInterfaceFactories() {
operators::ShuffleInterface::registerFactory(
operators::LocalPersistentShuffle::kShuffleName.toString(),
[](const std::string& /* serializedShuffleInfo */,
operators::ShuffleInterface::Type /* type */,
velox::memory::MemoryPool* /* pool */) {
// TODO: Any impl of ShuffleInterface should have a constructor,
// accepting a serialized shuffle info. Then the way of picking a
// shuffle interface factory is through config.properties. We can create
// an entry called shuffle.name
return std::make_shared<operators::LocalPersistentShuffle>(1 << 15);
});
}

/// Registers the file systems used by this server. This implementation
/// registers only Velox's local file system.
void PrestoServer::registerFileSystems() {
velox::filesystems::registerLocalFileSystem();
}
Expand Down
2 changes: 2 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class PrestoServer {
virtual std::vector<std::string> registerConnectors(
const fs::path& configDirectoryPath);

/// Registers the ShuffleInterface factories available to this server.
/// Declared virtual, so a subclass may override it.
virtual void registerShuffleInterfaceFactories();

virtual void registerFileSystems();

void initializeAsyncCache();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ namespace facebook::presto::operators {
/// this class (pointing to the same root path) to read and write shuffle data.
class LocalPersistentShuffle : public ShuffleInterface {
public:
/// Name identifying this shuffle implementation; PrestoServer uses it as the
/// key when registering the factory with ShuffleInterface::registerFactory().
static constexpr folly::StringPiece kShuffleName{"local"};

/// Stores 'maxBytesPerPartition' and records the id of the thread that
/// constructs the object.
/// NOTE(review): 'maxBytesPerPartition' is presumably an upper bound on the
/// bytes held per partition -- confirm against the rest of the class.
/// NOTE(review): this single-argument constructor is not 'explicit', so a
/// uint32_t converts implicitly to LocalPersistentShuffle; consider marking it
/// 'explicit' if no caller relies on that.
LocalPersistentShuffle(uint32_t maxBytesPerPartition)
: maxBytesPerPartition_(maxBytesPerPartition),
threadId_(std::this_thread::get_id()) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,27 @@
*/
#pragma once

#include <fmt/format.h>
#include <utility>
#include "velox/exec/Operator.h"

namespace facebook::presto::operators {

class ShuffleInterface {
public:
/// Indicates how a ShuffleInterface instance is going to be used. This common
/// interface could be used for both READ and WRITE; this enum is used to
/// differentiate the usage of the ShuffleInterface.
enum class Type {
/// The shuffle is used for reading.
kRead,
/// The shuffle is used for writing.
kWrite,
};

/// Signature of a factory that produces a ShuffleInterface. The factory
/// receives the serialized shuffle info, the intended usage (read or write)
/// and a memory pool, and returns a shared pointer to the shuffle it created.
using ShuffleInterfaceFactory =
std::function<std::shared_ptr<ShuffleInterface>(
const std::string& serializedShuffleInfo,
Type type,
velox::memory::MemoryPool* pool)>;

/// Write to the shuffle one row at a time.
virtual void collect(int32_t partition, std::string_view data) = 0;

Expand All @@ -38,6 +53,35 @@ class ShuffleInterface {
/// to be read while noMoreData signals the shuffle service that there
/// is no more data to be written.
virtual bool readyForRead() const = 0;

/// Registers 'shuffleInterfaceFactory' in the factory registry under 'name'.
///
/// @param name Key under which the factory is stored and later looked up.
/// @param shuffleInterfaceFactory Factory to store; taken by value and moved
/// into the registry on success.
/// @return true if the registration is successful, false if a factory with
/// the name already exists. In the latter case the previously registered
/// factory is kept unchanged.
///
/// This method is not thread safe.
static bool registerFactory(
    const std::string& name,
    ShuffleInterfaceFactory shuffleInterfaceFactory) {
  // The parameter is already a private copy, so move it into the map rather
  // than copying the std::function a second time. try_emplace only consumes
  // the argument when the key is actually inserted.
  return factories()
      .try_emplace(name, std::move(shuffleInterfaceFactory))
      .second;
}

/// Looks up the ShuffleInterfaceFactory registered under 'name' and returns a
/// reference to the entry stored in the registry. Throws (via VELOX_FAIL) if
/// no factory has been registered with that name.
/// This method is not thread safe.
static ShuffleInterfaceFactory& factory(const std::string& name) {
  auto& registry = factories();
  const auto entry = registry.find(name);
  if (entry == registry.end()) {
    VELOX_FAIL(fmt::format(
        "ShuffleInterface with name '{}' is not registered.", name));
  }
  return entry->second;
}

private:
static std::unordered_map<std::string, ShuffleInterfaceFactory>& factories() {
static std::unordered_map<std::string, ShuffleInterfaceFactory> factories;
return factories;
}
};

} // namespace facebook::presto::operators
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,19 @@ TEST_F(UnsafeRowShuffleTest, partitionAndSerializeToString) {
" -- Values[1000 rows in 1 vectors] -> c0:INTEGER, c1:BIGINT\n");
ASSERT_EQ(plan->toString(false, false), "-- PartitionAndSerialize\n");
}

// Verifies the ShuffleInterface factory registry: a new name can be
// registered, the registered factory can be looked up, and registering the
// same name a second time is rejected.
TEST_F(UnsafeRowShuffleTest, shuffleInterfaceRegistration) {
  // A factory that never builds anything; only its presence in the registry
  // matters here.
  auto nullFactory = [](const std::string& /* serializedShuffleInfo */,
                        ShuffleInterface::Type /* type */,
                        velox::memory::MemoryPool* /* pool */) {
    return nullptr;
  };
  const std::string shuffleName = "dummy-shuffle";

  // First registration under this name succeeds.
  const bool firstAttempt =
      ShuffleInterface::registerFactory(shuffleName, nullFactory);
  EXPECT_TRUE(firstAttempt);

  // The factory is now retrievable without throwing.
  EXPECT_NO_THROW(ShuffleInterface::factory(shuffleName));

  // A second registration under the same name is refused.
  const bool secondAttempt =
      ShuffleInterface::registerFactory(shuffleName, nullFactory);
  EXPECT_FALSE(secondAttempt);
}
} // namespace facebook::presto::operators::test

int main(int argc, char** argv) {
Expand Down

0 comments on commit 0271efd

Please sign in to comment.