From 683bf3fd3ac06a6865ecfe77b9dbacb77393dd91 Mon Sep 17 00:00:00 2001 From: Joe Abraham Date: Fri, 7 Jun 2024 14:23:39 -0700 Subject: [PATCH] feat(functions): Add support for REST based remote functions Co-authored-by: Wills Feng --- CMakeLists.txt | 3 + velox/functions/remote/client/CMakeLists.txt | 6 + velox/functions/remote/client/Remote.cpp | 78 +++++- velox/functions/remote/client/Remote.h | 25 +- velox/functions/remote/client/RestClient.cpp | 61 +++++ velox/functions/remote/client/RestClient.h | 68 +++++ .../remote/client/tests/CMakeLists.txt | 17 ++ .../client/tests/RemoteFunctionRestTest.cpp | 225 +++++++++++++++ velox/functions/remote/server/CMakeLists.txt | 17 +- .../server/RemoteFunctionBaseService.cpp | 102 +++++++ .../remote/server/RemoteFunctionBaseService.h | 58 ++++ .../server/RemoteFunctionRestService.cpp | 257 ++++++++++++++++++ .../remote/server/RemoteFunctionRestService.h | 90 ++++++ .../server/RemoteFunctionRestServiceMain.cpp | 66 +++++ .../remote/server/RemoteFunctionService.cpp | 94 +------ .../remote/server/RemoteFunctionService.h | 15 +- 16 files changed, 1072 insertions(+), 110 deletions(-) create mode 100644 velox/functions/remote/client/RestClient.cpp create mode 100644 velox/functions/remote/client/RestClient.h create mode 100644 velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp create mode 100644 velox/functions/remote/server/RemoteFunctionBaseService.cpp create mode 100644 velox/functions/remote/server/RemoteFunctionBaseService.h create mode 100644 velox/functions/remote/server/RemoteFunctionRestService.cpp create mode 100644 velox/functions/remote/server/RemoteFunctionRestService.h create mode 100644 velox/functions/remote/server/RemoteFunctionRestServiceMain.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 339d1ed0dda2..486fbe20e791 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -523,6 +523,9 @@ if(VELOX_ENABLE_REMOTE_FUNCTIONS) find_package(fizz CONFIG REQUIRED) find_package(wangle CONFIG REQUIRED) find_package(FBThrift CONFIG REQUIRED) + set(cpr_SOURCE BUNDLED) + velox_resolve_dependency(cpr) + FetchContent_MakeAvailable(cpr) endif() if(VELOX_ENABLE_GCS) diff --git a/velox/functions/remote/client/CMakeLists.txt b/velox/functions/remote/client/CMakeLists.txt index 56663a29d04b..0c15a3d783e7 100644 --- a/velox/functions/remote/client/CMakeLists.txt +++ b/velox/functions/remote/client/CMakeLists.txt @@ -16,11 +16,17 @@ velox_add_library(velox_functions_remote_thrift_client ThriftClient.cpp) velox_link_libraries(velox_functions_remote_thrift_client PUBLIC remote_function_thrift FBThrift::thriftcpp2) +velox_add_library(velox_functions_remote_rest_client RestClient.cpp) +velox_link_libraries(velox_functions_remote_rest_client Folly::folly cpr::cpr) + velox_add_library(velox_functions_remote Remote.cpp) velox_link_libraries( velox_functions_remote PUBLIC velox_expression + velox_exec + velox_presto_serializer velox_functions_remote_thrift_client + velox_functions_remote_rest_client velox_functions_remote_get_serde velox_type_fbhive Folly::folly) diff --git a/velox/functions/remote/client/Remote.cpp b/velox/functions/remote/client/Remote.cpp index 8458b84baaef..40c89debd5de 100644 --- a/velox/functions/remote/client/Remote.cpp +++ b/velox/functions/remote/client/Remote.cpp @@ -16,12 +16,19 @@ #include "velox/functions/remote/client/Remote.h" +#include #include +#include +#include + +#include "velox/common/memory/ByteStream.h" #include "velox/expression/Expr.h" #include "velox/expression/VectorFunction.h" +#include "velox/functions/remote/client/RestClient.h" #include "velox/functions/remote/client/ThriftClient.h" #include "velox/functions/remote/if/GetSerde.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionServiceAsyncClient.h" +#include "velox/serializers/PrestoSerializer.h" #include "velox/type/fbhive/HiveTypeSerializer.h" #include "velox/vector/VectorStream.h" @@ -29,7 +36,6 @@ namespace facebook::velox::functions { namespace { std::string serializeType(const TypePtr& type) { - // Use hive type serializer. return type::fbhive::HiveTypeSerializer::serialize(type); } @@ -40,10 +46,11 @@ class RemoteFunction : public exec::VectorFunction { const std::vector& inputArgs, const RemoteVectorFunctionMetadata& metadata) : functionName_(functionName), - location_(metadata.location), - thriftClient_(getThriftClient(location_, &eventBase_)), + metadata_(metadata), serdeFormat_(metadata.serdeFormat), serde_(getSerde(serdeFormat_)) { + boost::apply_visitor(*this, metadata_.location); + std::vector types; types.reserve(inputArgs.size()); serializedInputTypes_.reserve(inputArgs.size()); @@ -55,6 +62,14 @@ class RemoteFunction : public exec::VectorFunction { remoteInputType_ = ROW(std::move(types)); } + void operator()(const folly::SocketAddress& address) { + thriftClient_ = getThriftClient(address, &eventBase_); + } + + void operator()(const std::string& url) { + restClient_ = getRestClient(); + } + void apply( const SelectivityVector& rows, std::vector& args, @@ -62,7 +77,11 @@ class RemoteFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const override { try { - applyRemote(rows, args, outputType, context, result); + if ((metadata_.location.type() == typeid(folly::SocketAddress))) { + applyThriftRemote(rows, args, outputType, context, result); + } else if (metadata_.location.type() == typeid(std::string)) { + applyRestRemote(rows, args, outputType, context, result); + } } catch (const VeloxRuntimeError&) { throw; } catch (const std::exception&) { @@ -71,7 +90,41 @@ class RemoteFunction : public exec::VectorFunction { } private: - void applyRemote( + void applyRestRemote( + const SelectivityVector& rows, + std::vector& args, + const TypePtr& outputType, + exec::EvalCtx& context, + VectorPtr& result) const { + try { + serializer::presto::PrestoVectorSerde serde; + auto remoteRowVector = std::make_shared( + context.pool(), + remoteInputType_, + BufferPtr{}, + rows.end(), + std::move(args)); + + std::unique_ptr requestBody = + std::make_unique(rowVectorToIOBuf( + remoteRowVector, rows.end(), *context.pool(), &serde)); + + std::unique_ptr responseBody = restClient_->invokeFunction( + boost::get(metadata_.location), std::move(requestBody)); + + auto outputRowVector = IOBufToRowVector( + *responseBody, ROW({outputType}), *context.pool(), &serde); + + result = outputRowVector->childAt(0); + } catch (const std::exception& e) { + VELOX_FAIL( + "Error while executing remote function '{}': {}", + functionName_, + e.what()); + } + } + + void applyThriftRemote( const SelectivityVector& rows, std::vector& args, const TypePtr& outputType, @@ -109,7 +162,7 @@ class RemoteFunction : public exec::VectorFunction { VELOX_FAIL( "Error while executing remote function '{}' at '{}': {}", functionName_, - location_.describe(), + boost::get(metadata_.location).describe(), e.what()); } @@ -142,10 +195,13 @@ class RemoteFunction : public exec::VectorFunction { } const std::string functionName_; - folly::SocketAddress location_; - + const RemoteVectorFunctionMetadata metadata_; folly::EventBase eventBase_; - std::unique_ptr thriftClient_; + + // Depending on the location, one of these is initialized by the visitor. + std::unique_ptr thriftClient_{nullptr}; + std::unique_ptr restClient_{nullptr}; + remote::PageFormat serdeFormat_; std::unique_ptr serde_; @@ -159,7 +215,7 @@ std::shared_ptr createRemoteFunction( const std::vector& inputArgs, const core::QueryConfig& /*config*/, const RemoteVectorFunctionMetadata& metadata) { - return std::make_unique(name, inputArgs, metadata); + return std::make_shared(name, inputArgs, metadata); } } // namespace @@ -169,7 +225,7 @@ void registerRemoteFunction( std::vector signatures, const RemoteVectorFunctionMetadata& metadata, bool overwrite) { - exec::registerStatefulVectorFunction( + registerStatefulVectorFunction( name, signatures, std::bind( diff --git a/velox/functions/remote/client/Remote.h b/velox/functions/remote/client/Remote.h index a6a1e773dc81..16fa1db37ae9 100644 --- a/velox/functions/remote/client/Remote.h +++ b/velox/functions/remote/client/Remote.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include "velox/expression/VectorFunction.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunction_types.h" @@ -23,13 +24,29 @@ namespace facebook::velox::functions { struct RemoteVectorFunctionMetadata : public exec::VectorFunctionMetadata { - /// Network address of the servr to communicate with. Note that this can hold - /// a network location (ip/port pair) or a unix domain socket path (see + /// URL of the HTTP/REST server for remote function. + /// Or Network address of the server to communicate with. Note that this can + /// hold a network location (ip/port pair) or a unix domain socket path (see /// SocketAddress::makeFromPath()). - folly::SocketAddress location; + boost::variant location; - /// The serialization format to be used + /// The serialization format to be used when sending data to the remote. remote::PageFormat serdeFormat{remote::PageFormat::PRESTO_PAGE}; + + /// Optional schema defining the structure of the data or input/output types + /// involved in the remote function. This may include details such as column + /// names and data types. + std::optional schema; + + /// Optional identifier for the specific remote function to be invoked. + /// This can be useful when the same server hosts multiple functions, + /// and the client needs to specify which function to call. + std::optional functionId; + + /// Optional version information to be used when calling the remote function. + /// This can help in ensuring compatibility with a particular version of the + /// function if multiple versions are available on the server. + std::optional version; }; /// Registers a new remote function. It will use the meatadata defined in diff --git a/velox/functions/remote/client/RestClient.cpp b/velox/functions/remote/client/RestClient.cpp new file mode 100644 index 000000000000..e05b287bd9ca --- /dev/null +++ b/velox/functions/remote/client/RestClient.cpp @@ -0,0 +1,61 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/remote/client/RestClient.h" + +#include +#include + +#include "velox/common/base/Exceptions.h" + +using namespace folly; +namespace facebook::velox::functions { + +std::unique_ptr RestClient::invokeFunction( + const std::string& fullUrl, + std::unique_ptr requestPayload) { + IOBufQueue inputBufQueue(IOBufQueue::cacheChainLength()); + inputBufQueue.append(std::move(requestPayload)); + + std::string requestBody; + for (auto range : *inputBufQueue.front()) { + requestBody.append( + reinterpret_cast(range.data()), range.size()); + } + + cpr::Response response = cpr::Post( + cpr::Url{fullUrl}, + cpr::Header{ + {"Content-Type", "application/X-presto-pages"}, + {"Accept", "application/X-presto-pages"}}, + cpr::Body{requestBody}); + + if (response.error) { + VELOX_FAIL(fmt::format( + "Error communicating with server: {} URL: {}", + response.error.message, + fullUrl)); + } + + auto outputBuf = IOBuf::copyBuffer(response.text); + return outputBuf; +} + +std::unique_ptr getRestClient() { + return std::make_unique(); +} + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/RestClient.h b/velox/functions/remote/client/RestClient.h new file mode 100644 index 000000000000..b439cece53b8 --- /dev/null +++ b/velox/functions/remote/client/RestClient.h @@ -0,0 +1,68 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace facebook::velox::functions { + +/// @brief Abstract interface for an HTTP client. +/// Provides a method to invoke a function by sending an HTTP request +/// and receiving a response, both in Presto's serialized wire format. +class HttpClient { + public: + virtual ~HttpClient() = default; + + /// @brief Invokes a function over HTTP. + /// @param url The endpoint URL to send the request to. + /// @param requestPayload The request payload in Presto's serialized wire + /// format. + /// @return A unique pointer to the response payload in Presto's serialized + /// wire format. + virtual std::unique_ptr invokeFunction( + const std::string& url, + std::unique_ptr requestPayload) = 0; +}; + +/// @brief Concrete implementation of HttpClient using REST. +/// Handles HTTP communication by sending requests and receiving responses +/// using RESTful APIs with payloads in Presto's serialized wire format. +class RestClient : public HttpClient { + public: + /// @brief Invokes a function over HTTP using cpr. + /// Sends an HTTP POST request to the specified URL with the request payload + /// and receives the response payload. Both payloads are in Presto's + /// serialized wire format. + /// @param url The endpoint URL to send the request to. + /// @param requestPayload The request payload in Presto's serialized wire + /// format. + /// @return A unique pointer to the response payload in Presto's serialized + /// wire format. + /// @throws VeloxException if there is an error initializing cpr or during + /// the request. + std::unique_ptr invokeFunction( + const std::string& url, + std::unique_ptr requestPayload) override; +}; + +/// @brief Factory function to create an instance of RestClient. +/// @return A unique pointer to an HttpClient implementation. +std::unique_ptr getRestClient(); + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/client/tests/CMakeLists.txt b/velox/functions/remote/client/tests/CMakeLists.txt index 15c8c6e00ebd..a7a70ff0402b 100644 --- a/velox/functions/remote/client/tests/CMakeLists.txt +++ b/velox/functions/remote/client/tests/CMakeLists.txt @@ -28,3 +28,20 @@ target_link_libraries( GTest::gmock GTest::gtest GTest::gtest_main) + +add_executable(velox_functions_remote_client_rest_test + RemoteFunctionRestTest.cpp) + +add_test(velox_functions_remote_client_rest_test + velox_functions_remote_client_rest_test) + +target_link_libraries( + velox_functions_remote_client_rest_test + velox_functions_remote_rest_client + velox_functions_remote_server_rest + velox_functions_remote + velox_functions_test_lib + velox_exec_test_lib + GTest::gmock + GTest::gtest + GTest::gtest_main) diff --git a/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp new file mode 100644 index 000000000000..f69e5cda03c1 --- /dev/null +++ b/velox/functions/remote/client/tests/RemoteFunctionRestTest.cpp @@ -0,0 +1,225 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "velox/common/base/Exceptions.h" +#include "velox/common/base/tests/GTestUtils.h" +#include "velox/exec/tests/utils/PortUtil.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/lib/CheckedArithmetic.h" +#include "velox/functions/prestosql/Arithmetic.h" +#include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" +#include "velox/functions/remote/client/Remote.h" +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +using ::facebook::velox::test::assertEqualVectors; + +namespace facebook::velox::functions { +namespace { + +class RemoteFunctionRestTest + : public test::FunctionBaseTest, + public testing::WithParamInterface { + public: + void SetUp() override { + auto servicePort = facebook::velox::exec::test::getFreePort(); + location_ = fmt::format("http://127.0.0.1:{}", servicePort); + initializeServer(servicePort); + registerRemoteFunctions(); + + auto wrongServicePort = facebook::velox::exec::test::getFreePort(); + wrongLocation_ = fmt::format("http://127.0.0.1:{}", wrongServicePort); + } + + // Registers a few remote functions to be used in this test. + void registerRemoteFunctions() const { + RemoteVectorFunctionMetadata metadata; + metadata.serdeFormat = remote::PageFormat::PRESTO_PAGE; + + auto functionName = "remote_abs"; + metadata.location = location_ + '/' + functionName; + auto absSignature = {exec::FunctionSignatureBuilder() + .returnType("integer") + .argumentType("integer") + .build()}; + registerRemoteFunction(functionName, absSignature, metadata); + + functionName = "remote_plus"; + metadata.location = location_ + '/' + functionName; + auto plusSignatures = {exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .build()}; + registerRemoteFunction("remote_plus", plusSignatures, metadata); + + functionName = "remote_wrong_port"; + RemoteVectorFunctionMetadata wrongMetadata = metadata; + wrongMetadata.serdeFormat = remote::PageFormat::PRESTO_PAGE; + wrongMetadata.location = wrongLocation_ + functionName; + registerRemoteFunction(functionName, plusSignatures, wrongMetadata); + + functionName = "remote_divide"; + metadata.location = location_ + '/' + functionName; + auto divSignatures = {exec::FunctionSignatureBuilder() + .returnType("double") + .argumentType("double") + .argumentType("double") + .build()}; + registerRemoteFunction(functionName, divSignatures, metadata); + + functionName = "remote_substr"; + metadata.location = location_ + '/' + functionName; + auto substrSignatures = {exec::FunctionSignatureBuilder() + .returnType("varchar") + .argumentType("varchar") + .argumentType("integer") + .build()}; + registerRemoteFunction(functionName, substrSignatures, metadata); + + // Registers the actual function under a different prefix. This is only + // needed for tests since the HTTP service runs in the same process. + registerFunction( + {remotePrefix_ + ".remote_abs"}); + registerFunction( + {remotePrefix_ + ".remote_plus"}); + registerFunction( + {remotePrefix_ + ".remote_divide"}); + registerFunction( + {remotePrefix_ + ".remote_substr"}); + } + + void initializeServer(uint16_t servicePort) { + // Adjusted for Boost.Beast server; the server is started in the main + // thread. + + // Start the server in a separate thread + serverThread_ = std::make_unique([this, servicePort]() { + std::string serviceHost = "127.0.0.1"; + std::string functionPrefix = remotePrefix_; + std::make_shared( + ioc_, + boost::asio::ip::tcp::endpoint( + boost::asio::ip::make_address(serviceHost), servicePort), + functionPrefix) + ->run(); + + ioc_.run(); + }); + + VELOX_CHECK( + waitForRunning(servicePort), "Unable to initialize HTTP server."); + } + + ~RemoteFunctionRestTest() override { + if (serverThread_ && serverThread_->joinable()) { + ioc_.stop(); + serverThread_->join(); + } + } + + private: + bool waitForRunning(uint16_t servicePort) const { + for (size_t i = 0; i < 100; ++i) { + using boost::asio::ip::tcp; + boost::asio::io_context io_context; + + tcp::socket socket(io_context); + tcp::resolver resolver(io_context); + + try { + boost::asio::connect( + socket, resolver.resolve("127.0.0.1", std::to_string(servicePort))); + return true; + } catch (std::exception& e) { + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + } + } + return false; + } + + std::unique_ptr serverThread_; + boost::asio::io_context ioc_{1}; + + std::string location_; + std::string wrongLocation_; + + const std::string remotePrefix_{"remote"}; +}; + +TEST_F(RemoteFunctionRestTest, absolute) { + auto inputVector = makeFlatVector({-10, -20}); + auto results = evaluate>( + "remote_abs(c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({10, 20}); + assertEqualVectors(expected, results); +} + +TEST_F(RemoteFunctionRestTest, simple) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto results = evaluate>( + "remote_plus(c0, c0)", makeRowVector({inputVector})); + + auto expected = makeFlatVector({2, 4, 6, 8, 10}); + assertEqualVectors(expected, results); +} + +TEST_F(RemoteFunctionRestTest, string) { + auto inputVector = + makeFlatVector({"hello", "my", "remote", "world"}); + auto inputVector1 = makeFlatVector({2, 1, 3, 5}); + auto results = evaluate>( + "remote_substr(c0, c1)", makeRowVector({inputVector, inputVector1})); + + auto expected = makeFlatVector({"ello", "my", "mote", "d"}); + assertEqualVectors(expected, results); +} + +TEST_F(RemoteFunctionRestTest, connectionError) { + auto inputVector = makeFlatVector({1, 2, 3, 4, 5}); + auto func = [&]() { + evaluate>( + "remote_wrong_port(c0, c0)", makeRowVector({inputVector})); + }; + + // Check it throws and that the exception has the "connection refused" + // substring. + EXPECT_THROW(func(), VeloxRuntimeError); + try { + func(); + } catch (const VeloxRuntimeError& e) { + EXPECT_THAT( + e.message(), + testing::HasSubstr("Reason: Error communicating with server: ")); + } +} + +} // namespace +} // namespace facebook::velox::functions + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::Init init{&argc, &argv, false}; + return RUN_ALL_TESTS(); +} diff --git a/velox/functions/remote/server/CMakeLists.txt b/velox/functions/remote/server/CMakeLists.txt index ff2afa0fed6a..f096f766937e 100644 --- a/velox/functions/remote/server/CMakeLists.txt +++ b/velox/functions/remote/server/CMakeLists.txt @@ -12,15 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_library(velox_functions_remote_server_base RemoteFunctionBaseService.cpp) +target_link_libraries( + velox_functions_remote_server_base velox_expression velox_type_fbhive) + add_library(velox_functions_remote_server RemoteFunctionService.cpp) target_link_libraries( velox_functions_remote_server PUBLIC remote_function_thrift velox_functions_remote_get_serde - velox_type_fbhive velox_memory) + velox_functions_remote_server_base) add_executable(velox_functions_remote_server_main RemoteFunctionServiceMain.cpp) target_link_libraries( velox_functions_remote_server_main velox_functions_remote_server velox_functions_prestosql) + +add_library(velox_functions_remote_server_rest RemoteFunctionRestService.cpp) +target_link_libraries( + velox_functions_remote_server_rest velox_presto_serializer + velox_functions_remote_server_base) + +add_executable(velox_functions_remote_server_rest_main + RemoteFunctionRestServiceMain.cpp) +target_link_libraries( + velox_functions_remote_server_rest_main velox_functions_remote_server_rest + velox_functions_prestosql velox_exec_test_lib) diff --git a/velox/functions/remote/server/RemoteFunctionBaseService.cpp b/velox/functions/remote/server/RemoteFunctionBaseService.cpp new file mode 100644 index 000000000000..69fb5e4a9bbd --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionBaseService.cpp @@ -0,0 +1,102 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/remote/server/RemoteFunctionBaseService.h" + +#include "velox/type/fbhive/HiveTypeParser.h" + +namespace facebook::velox::functions { +namespace { + +inline std::string getFunctionName( + const std::string& prefix, + const std::string& functionName) { + return prefix.empty() ? functionName + : fmt::format("{}.{}", prefix, functionName); +} + +inline TypePtr deserializeType(const std::string& input) { + // Use hive type parser/serializer. + return type::fbhive::HiveTypeParser().parse(input); +} + +inline RowTypePtr deserializeArgTypes( + const std::vector& argTypes) { + const size_t argCount = argTypes.size(); + + std::vector argumentTypes; + std::vector typeNames; + argumentTypes.reserve(argCount); + typeNames.reserve(argCount); + + for (size_t i = 0; i < argCount; ++i) { + argumentTypes.emplace_back(deserializeType(argTypes[i])); + typeNames.emplace_back(fmt::format("c{}", i)); + } + return ROW(std::move(typeNames), std::move(argumentTypes)); +} + +inline std::vector getExpressions( + const RowTypePtr& inputType, + const TypePtr& returnType, + const std::string& functionName) { + std::vector inputs; + for (size_t i = 0; i < inputType->size(); ++i) { + inputs.push_back(std::make_shared( + inputType->childAt(i), inputType->nameOf(i))); + } + + return {std::make_shared( + returnType, std::move(inputs), functionName)}; +} + +} // namespace + +RowVectorPtr RemoteFunctionBaseService::invokeFunctionInternal( + const folly::IOBuf& payload, + const std::vector& argTypeNames, + const std::string& returnTypeName, + const std::string& functionName, + bool throwOnError, + VectorSerde* serde) { + auto inputType = deserializeArgTypes(argTypeNames); + auto outputType = deserializeType(returnTypeName); + + auto inputVector = IOBufToRowVector(payload, inputType, *pool_, serde); + + const vector_size_t numRows = inputVector->size(); + SelectivityVector rows{numRows}; + + queryCtx_ = core::QueryCtx::create(); + execCtx_ = std::make_unique(pool_.get(), queryCtx_.get()); + exec::ExprSet exprSet{ + getExpressions( + inputType, + outputType, + getFunctionName(functionPrefix_, functionName)), + execCtx_.get()}; + evalCtx_ = std::make_unique( + execCtx_.get(), &exprSet, inputVector.get()); + *evalCtx_->mutableThrowOnError() = throwOnError; + + std::vector expressionResult; + exprSet.eval(rows, *evalCtx_, expressionResult); + + return std::make_shared( + pool_.get(), ROW({outputType}), BufferPtr(), numRows, expressionResult); +} + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionBaseService.h b/velox/functions/remote/server/RemoteFunctionBaseService.h new file mode 100644 index 000000000000..a4a126ebd3e8 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionBaseService.h @@ -0,0 +1,58 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include "velox/common/memory/Memory.h" +#include "velox/expression/Expr.h" +#include "velox/vector/VectorStream.h" + +namespace facebook::velox::functions { + +class RemoteFunctionBaseService { + public: + virtual ~RemoteFunctionBaseService() = default; + + protected: + RemoteFunctionBaseService( + const std::string& functionPrefix, + std::shared_ptr pool) + : functionPrefix_(functionPrefix), pool_(std::move(pool)) { + if (!pool_) { + pool_ = memory::memoryManager()->addLeafPool(); + } + } + + RowVectorPtr invokeFunctionInternal( + const folly::IOBuf& payload, + const std::vector& argTypeNames, + const std::string& returnTypeName, + const std::string& functionName, + bool throwOnError, + VectorSerde* serde); + + exec::EvalErrors* getEvalErrors_() { + return evalCtx_ ? evalCtx_->errors() : nullptr; + } + + std::string functionPrefix_; + std::shared_ptr pool_; + std::shared_ptr queryCtx_; + std::unique_ptr execCtx_; + std::unique_ptr evalCtx_; +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.cpp b/velox/functions/remote/server/RemoteFunctionRestService.cpp new file mode 100644 index 000000000000..e7f4c392d790 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.cpp @@ -0,0 +1,257 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "velox/functions/remote/server/RemoteFunctionRestService.h" + +#include +#include "velox/serializers/PrestoSerializer.h" + +namespace facebook::velox::functions { + +namespace { + +struct InternalFunctionSignature { + std::vector argumentTypes; + std::string returnType; +}; + +std::map internalFunctionSignatureMap = + { + {"remote_abs", {{"integer"}, "integer"}}, + {"remote_plus", {{"bigint", "bigint"}, "bigint"}}, + {"remote_divide", {{"double", "double"}, "double"}}, + {"remote_substr", {{"varchar", "integer"}, "varchar"}}, + // Add more functions here as needed, registerRemoteFunction should be + // called to use the functions mentioned in this map +}; + +} // namespace + +RestSession::RestSession( + boost::asio::ip::tcp::socket socket, + std::string functionPrefix) + : RemoteFunctionBaseService(std::move(functionPrefix), nullptr), + socket_(std::move(socket)) {} + +void RestSession::run() { + doRead(); +} + +void RestSession::doRead() { + auto self = shared_from_this(); + boost::beast::http::async_read( + socket_, + buffer_, + req_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->onRead(ec, bytes_transferred); + }); +} + +void RestSession::onRead( + boost::beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + + if (ec == boost::beast::http::error::end_of_stream) { + return doClose(); + } + + if (ec) { + LOG(ERROR) << "Read error: " << ec.message(); + return; + } + + handleRequest(std::move(req_)); +} + +void RestSession::handleRequest( + boost::beast::http::request req) { + res_.version(req.version()); + res_.set(boost::beast::http::field::server, BOOST_BEAST_VERSION_STRING); + + if (req.method() != boost::beast::http::verb::post) { + res_.result(boost::beast::http::status::method_not_allowed); + res_.set(boost::beast::http::field::content_type, "text/plain"); + res_.body() = "Only POST method is allowed"; + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->onWrite(true, ec, bytes_transferred); + }); + return; + } + + std::string path = req.target(); + + // Expected path format: + // /{functionName} + std::vector pathComponents; + folly::split('/', path, pathComponents); + + std::string functionName; + if (pathComponents.size() <= 2) { + functionName = pathComponents[1]; + } else { + res_.result(boost::beast::http::status::bad_request); + res_.set(boost::beast::http::field::content_type, "text/plain"); + res_.body() = "Invalid request path"; + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->onWrite(true, ec, bytes_transferred); + }); + return; + } + + try { + const auto& functionSignature = + internalFunctionSignatureMap.at(functionName); + + serializer::presto::PrestoVectorSerde serde; + auto inputBuffer = folly::IOBuf::copyBuffer(req.body()); + + auto outputRowVector = invokeFunctionInternal( + *inputBuffer, + functionSignature.argumentTypes, + functionSignature.returnType, + functionName, + true, + &serde); + + auto payload = rowVectorToIOBuf( + outputRowVector, outputRowVector->size(), *pool_, &serde); + + res_.result(boost::beast::http::status::ok); + res_.set( + boost::beast::http::field::content_type, "application/octet-stream"); + res_.body() = payload.moveToFbString().toStdString(); + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->onWrite(false, ec, bytes_transferred); + }); + + } catch (const std::exception& ex) { + LOG(ERROR) << ex.what(); + res_.result(boost::beast::http::status::internal_server_error); + res_.set(boost::beast::http::field::content_type, "text/plain"); + res_.body() = ex.what(); + res_.prepare_payload(); + + auto self = shared_from_this(); + boost::beast::http::async_write( + socket_, + res_, + [self](boost::beast::error_code ec, std::size_t bytes_transferred) { + self->onWrite(true, ec, bytes_transferred); + }); + } +} + +void RestSession::onWrite( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred) { + boost::ignore_unused(bytes_transferred); + + if (ec) { + LOG(ERROR) << "Write error: " << ec.message(); + return; + } + + if (close) { + return doClose(); + } + + req_ = {}; + + doRead(); +} + +void RestSession::doClose() { + boost::beast::error_code ec; + socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_send, ec); +} + +RestListener::RestListener( + boost::asio::io_context& ioc, + boost::asio::ip::tcp::endpoint endpoint, + std::string functionPrefix) + : ioc_(ioc), acceptor_(ioc), functionPrefix_(std::move(functionPrefix)) { + boost::beast::error_code ec; + + acceptor_.open(endpoint.protocol(), ec); + if (ec) { + LOG(ERROR) << "Open error: " << ec.message(); + return; + } + + acceptor_.set_option(boost::asio::socket_base::reuse_address(true), ec); + if (ec) { + LOG(ERROR) << "Set_option error: " << ec.message(); + return; + } + + acceptor_.bind(endpoint, ec); + if (ec) { + LOG(ERROR) << "Bind error: " << ec.message(); + return; + } + + acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); + if (ec) { + LOG(ERROR) << "Listen error: " << ec.message(); + return; + } +} + +void RestListener::run() { + doAccept(); +} + +void RestListener::doAccept() { + acceptor_.async_accept( + [self = shared_from_this()]( + boost::beast::error_code ec, boost::asio::ip::tcp::socket socket) { + self->onAccept(ec, std::move(socket)); + }); +} + +void RestListener::onAccept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket) { + if (ec) { + LOG(ERROR) << "Accept error: " << ec.message(); + } else { + std::make_shared(std::move(socket), functionPrefix_)->run(); + } + doAccept(); +} + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestService.h b/velox/functions/remote/server/RemoteFunctionRestService.h new file mode 100644 index 000000000000..14a344b1ff72 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestService.h @@ -0,0 +1,90 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include "velox/functions/remote/server/RemoteFunctionBaseService.h" + +namespace facebook::velox::functions { + +/// @brief Manages an individual HTTP session. +/// Handles reading HTTP requests, processing them, and sending responses. +/// This class re-hosts Velox functions and allows testing their functionality. +class RestSession : public std::enable_shared_from_this, + public RemoteFunctionBaseService { + public: + RestSession(boost::asio::ip::tcp::socket socket, std::string functionPrefix); + + /// Starts the session by initiating a read operation. + void run(); + + private: + // Initiates an asynchronous read operation. + void doRead(); + + // Called when a read operation completes. + void onRead(boost::beast::error_code ec, std::size_t bytes_transferred); + + // Processes the HTTP request and prepares a response. + void handleRequest( + boost::beast::http::request req); + + // Called when a write operation completes. + void onWrite( + bool close, + boost::beast::error_code ec, + std::size_t bytes_transferred); + + // Closes the socket connection. + void doClose(); + + boost::asio::ip::tcp::socket socket_; + boost::beast::flat_buffer buffer_; + boost::beast::http::request req_; + boost::beast::http::response res_; +}; + +/// @brief Listens for incoming TCP connections and creates sessions. +/// Sets up a TCP acceptor to listen for client connections, +/// creating a new session for each accepted connection. +class RestListener : public std::enable_shared_from_this { + public: + RestListener( + boost::asio::io_context& ioc, + boost::asio::ip::tcp::endpoint endpoint, + std::string functionPrefix); + + /// Starts accepting incoming connections. + void run(); + + private: + // Initiates an asynchronous accept operation. + void doAccept(); + + // Called when an accept operation completes. + void onAccept( + boost::beast::error_code ec, + boost::asio::ip::tcp::socket socket); + + boost::asio::io_context& ioc_; + boost::asio::ip::tcp::acceptor acceptor_; + std::string functionPrefix_; +}; + +} // namespace facebook::velox::functions diff --git a/velox/functions/remote/server/RemoteFunctionRestServiceMain.cpp b/velox/functions/remote/server/RemoteFunctionRestServiceMain.cpp new file mode 100644 index 000000000000..03a623c717f7 --- /dev/null +++ b/velox/functions/remote/server/RemoteFunctionRestServiceMain.cpp @@ -0,0 +1,66 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "RemoteFunctionRestService.h" +#include "velox/common/memory/Memory.h" +#include "velox/exec/tests/utils/PortUtil.h" +#include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/registration/RegistrationFunctions.h" + +/// This executable is meant for testing. It instantiates a lightweight +/// server that can handle remote function requests, hosting all Presto scalar +/// functions. Clients can connect to this server to invoke these functions +/// remotely for testing and validation purposes. +/// +/// The server binds to a TCP endpoint specified by the --service_host and +/// --service_port flags, and each function is registered with a prefix defined +/// by the --function_prefix flag. +/// +/// NOTE: This server runs on a single-threaded boost::asio::io_context for +/// simplicity; it is not optimized for high throughput or production use. + +DEFINE_string(service_host, "127.0.0.1", "Host to bind the service to"); +DEFINE_string( + function_prefix, + "remote.schema", + "Prefix to be added to the functions being registered"); + +using namespace ::facebook::velox; + +int main(int argc, char* argv[]) { + folly::Init init(&argc, &argv); + FLAGS_logtostderr = true; + memory::initializeMemoryManager({}); + + LOG(INFO) << "Registering Presto functions"; + functions::prestosql::registerAllScalarFunctions(FLAGS_function_prefix); + boost::asio::io_context ioc{1}; + + auto servicePort = facebook::velox::exec::test::getFreePort(); + LOG(INFO) << "Initializing rest server at 127.0.0.1:" << servicePort; + std::make_shared( + ioc, + boost::asio::ip::tcp::endpoint( + boost::asio::ip::make_address(FLAGS_service_host), servicePort), + FLAGS_function_prefix) + ->run(); + + ioc.run(); + + return 0; +} diff --git a/velox/functions/remote/server/RemoteFunctionService.cpp b/velox/functions/remote/server/RemoteFunctionService.cpp index e378f5815ef4..ba5d3cebdbe7 100644 --- a/velox/functions/remote/server/RemoteFunctionService.cpp +++ b/velox/functions/remote/server/RemoteFunctionService.cpp @@ -16,56 +16,9 @@ #include "velox/functions/remote/server/RemoteFunctionService.h" #include "velox/common/base/Exceptions.h" -#include "velox/expression/Expr.h" #include "velox/functions/remote/if/GetSerde.h" -#include "velox/type/fbhive/HiveTypeParser.h" -#include "velox/vector/VectorStream.h" namespace facebook::velox::functions { -namespace { - -std::string getFunctionName( - const std::string& prefix, - const std::string& functionName) { - return prefix.empty() ? functionName - : fmt::format("{}.{}", prefix, functionName); -} - -TypePtr deserializeType(const std::string& input) { - // Use hive type parser/serializer. - return type::fbhive::HiveTypeParser().parse(input); -} - -RowTypePtr deserializeArgTypes(const std::vector& argTypes) { - const size_t argCount = argTypes.size(); - - std::vector argumentTypes; - std::vector typeNames; - argumentTypes.reserve(argCount); - typeNames.reserve(argCount); - - for (size_t i = 0; i < argCount; ++i) { - argumentTypes.emplace_back(deserializeType(argTypes[i])); - typeNames.emplace_back(fmt::format("c{}", i)); - } - return ROW(std::move(typeNames), std::move(argumentTypes)); -} - -} // namespace - -std::vector getExpressions( - const RowTypePtr& inputType, - const TypePtr& returnType, - const std::string& functionName) { - std::vector inputs; - for (size_t i = 0; i < inputType->size(); ++i) { - inputs.push_back(std::make_shared( - inputType->childAt(i), inputType->nameOf(i))); - } - - return {std::make_shared( - returnType, std::move(inputs), functionName)}; -} void RemoteFunctionServiceHandler::handleErrors( apache::thrift::field_ref result, @@ -112,50 +65,25 @@ void RemoteFunctionServiceHandler::invokeFunction( const auto& functionHandle = request->get_remoteFunctionHandle(); const auto& inputs = request->get_inputs(); - // Deserialize types and data. - auto inputType = deserializeArgTypes(functionHandle.get_argumentTypes()); - auto outputType = deserializeType(functionHandle.get_returnType()); - auto serdeFormat = inputs.get_pageFormat(); auto serde = getSerde(serdeFormat); - auto inputVector = - IOBufToRowVector(inputs.get_payload(), inputType, *pool_, serde.get()); - - // Execute the expression. - const vector_size_t numRows = inputVector->size(); - SelectivityVector rows{numRows}; - - // Expression boilerplate. - auto queryCtx = core::QueryCtx::create(); - core::ExecCtx execCtx{pool_.get(), queryCtx.get()}; - exec::ExprSet exprSet{ - getExpressions( - inputType, - outputType, - getFunctionName(functionPrefix_, functionHandle.get_name())), - &execCtx}; - - exec::EvalCtx evalCtx(&execCtx, &exprSet, inputVector.get()); - if (!request->get_throwOnError()) { - *evalCtx.mutableThrowOnError() = false; - } - - std::vector expressionResult; - exprSet.eval(rows, evalCtx, expressionResult); - - // Create output vector. - auto outputRowVector = std::make_shared( - pool_.get(), ROW({outputType}), BufferPtr(), numRows, expressionResult); + auto outputRowVector = invokeFunctionInternal( + inputs.get_payload(), + functionHandle.get_argumentTypes(), + functionHandle.get_returnType(), + functionHandle.get_name(), + request->get_throwOnError(), + serde.get()); auto result = response.result_ref(); result->rowCount_ref() = outputRowVector->size(); result->pageFormat_ref() = serdeFormat; - result->payload_ref() = - rowVectorToIOBuf(outputRowVector, rows.end(), *pool_, serde.get()); + result->payload_ref() = rowVectorToIOBuf( + outputRowVector, outputRowVector->size(), *pool_, serde.get()); - auto evalErrors = evalCtx.errors(); - if (evalErrors != nullptr && evalErrors->hasError()) { + auto evalErrors = getEvalErrors_(); + if (evalErrors && evalErrors->hasError()) { handleErrors(result, evalErrors, serde); } } diff --git a/velox/functions/remote/server/RemoteFunctionService.h b/velox/functions/remote/server/RemoteFunctionService.h index 3004f1576916..937c45287333 100644 --- a/velox/functions/remote/server/RemoteFunctionService.h +++ b/velox/functions/remote/server/RemoteFunctionService.h @@ -17,9 +17,8 @@ #pragma once #include -#include "velox/common/memory/Memory.h" #include "velox/functions/remote/if/gen-cpp2/RemoteFunctionService.h" -#include "velox/vector/VectorStream.h" +#include "velox/functions/remote/server/RemoteFunctionBaseService.h" namespace facebook::velox::exec { class EvalErrors; @@ -29,17 +28,14 @@ namespace facebook::velox::functions { // Simple implementation of the thrift server handler. class RemoteFunctionServiceHandler - : virtual public apache::thrift::ServiceHandler< + : public RemoteFunctionBaseService, + virtual public apache::thrift::ServiceHandler< remote::RemoteFunctionService> { public: RemoteFunctionServiceHandler( const std::string& functionPrefix = "", std::shared_ptr pool = nullptr) - : functionPrefix_(functionPrefix), pool_(std::move(pool)) { - if (pool_ == nullptr) { - pool_ = memory::memoryManager()->addLeafPool(); - } - } + : RemoteFunctionBaseService(functionPrefix, std::move(pool)) {} void invokeFunction( remote::RemoteFunctionResponse& response, @@ -52,9 +48,6 @@ class RemoteFunctionServiceHandler apache::thrift::field_ref result, exec::EvalErrors* evalErrors, const std::unique_ptr& serde) const; - - const std::string functionPrefix_; - std::shared_ptr pool_; }; } // namespace facebook::velox::functions