Skip to content

Commit

Permalink
Optimise remote client
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe-Abraham committed Feb 11, 2025
1 parent e05d54a commit c31b29c
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 123 deletions.
2 changes: 0 additions & 2 deletions velox/functions/remote/client/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ velox_add_library(velox_functions_remote Remote.cpp)
velox_link_libraries(
velox_functions_remote
PUBLIC velox_expression
velox_memory
velox_exec
velox_vector
velox_presto_serializer
velox_functions_remote_thrift_client
velox_functions_remote_rest_client
Expand Down
227 changes: 106 additions & 121 deletions velox/functions/remote/client/Remote.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include "velox/type/fbhive/HiveTypeSerializer.h"
#include "velox/vector/VectorStream.h"

using namespace folly;
namespace facebook::velox::functions {
namespace {

Expand Down Expand Up @@ -64,7 +63,7 @@ class RemoteFunction : public exec::VectorFunction {
remoteInputType_ = ROW(std::move(types));
}

void operator()(const SocketAddress& address) {
void operator()(const folly::SocketAddress& address) {
thriftClient_ = getThriftClient(address, &eventBase_);
}

Expand All @@ -79,9 +78,11 @@ class RemoteFunction : public exec::VectorFunction {
exec::EvalCtx& context,
VectorPtr& result) const override {
try {
boost::apply_visitor(
ApplyRemote{rows, args, outputType, context, result, *this},
location_);
if ((metadata_.location.type() == typeid(folly::SocketAddress))) {
applyThriftRemote(rows, args, outputType, context, result);
} else if (metadata_.location.type() == typeid(std::string)) {
applyRestRemote(rows, args, outputType, context, result);
}
} catch (const VeloxRuntimeError&) {
throw;
} catch (const std::exception&) {
Expand All @@ -90,138 +91,122 @@ class RemoteFunction : public exec::VectorFunction {
}

private:
struct ApplyRemote : public boost::static_visitor<> {
const SelectivityVector& rows;
const std::vector<VectorPtr>& args;
const TypePtr& outputType;
exec::EvalCtx& context;
VectorPtr& result;
const RemoteFunction& parent;

ApplyRemote(
const SelectivityVector& r,
const std::vector<VectorPtr>& a,
const TypePtr& ot,
exec::EvalCtx& c,
VectorPtr& res,
const RemoteFunction& p)
: rows(r),
args(a),
outputType(ot),
context(c),
result(res),
parent(p) {}

void operator()(const std::string& url) const {
try {
serializer::presto::PrestoVectorSerde serde;
auto remoteRowVector = std::make_shared<RowVector>(
context.pool(),
parent.remoteInputType_,
BufferPtr{},
rows.end(),
std::move(args));

std::unique_ptr<IOBuf> requestBody =
std::make_unique<IOBuf>(rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), &serde));

std::unique_ptr<IOBuf> responseBody =
parent.restClient_->invokeFunction(
boost::get<std::string>(parent.location_),
std::move(requestBody));

auto outputRowVector = IOBufToRowVector(
*responseBody, ROW({outputType}), *context.pool(), &serde);

result = outputRowVector->childAt(0);
} catch (const std::exception& e) {
VELOX_FAIL(
"Error while executing remote function '{}': {}",
parent.functionName_,
e.what());
}
}

void operator()(const SocketAddress& address) const {
// Create type and row vector for serialization.
void applyRestRemote(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const {
try {
serializer::presto::PrestoVectorSerde serde;
auto remoteRowVector = std::make_shared<RowVector>(
context.pool(),
parent.remoteInputType_,
remoteInputType_,
BufferPtr{},
rows.end(),
std::move(args));

// Send to remote server.
remote::RemoteFunctionResponse remoteResponse;
remote::RemoteFunctionRequest request;
request.throwOnError_ref() = context.throwOnError();

auto functionHandle = request.remoteFunctionHandle_ref();
functionHandle->name_ref() = parent.functionName_;
functionHandle->returnType_ref() = serializeType(outputType);
functionHandle->argumentTypes_ref() = parent.serializedInputTypes_;

auto requestInputs = request.inputs_ref();
requestInputs->rowCount_ref() = remoteRowVector->size();
requestInputs->pageFormat_ref() = parent.serdeFormat_;

// TODO: serialize only active rows.
requestInputs->payload_ref() = rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), parent.serde_.get());

try {
parent.thriftClient_->sync_invokeFunction(remoteResponse, request);
} catch (const std::exception& e) {
VELOX_FAIL(
"Error while executing remote function '{}' at '{}': {}",
parent.functionName_,
boost::get<SocketAddress>(parent.location_).describe(),
e.what());
}
std::unique_ptr<folly::IOBuf> requestBody =
std::make_unique<folly::IOBuf>(rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), &serde));

std::unique_ptr<folly::IOBuf> responseBody = restClient_->invokeFunction(
boost::get<std::string>(location_), std::move(requestBody));

auto outputRowVector = IOBufToRowVector(
remoteResponse.get_result().get_payload(),
ROW({outputType}),
*context.pool(),
parent.serde_.get());
*responseBody, ROW({outputType}), *context.pool(), &serde);

result = outputRowVector->childAt(0);
} catch (const std::exception& e) {
VELOX_FAIL(
"Error while executing remote function '{}': {}",
functionName_,
e.what());
}
}

if (auto errorPayload = remoteResponse.get_result().errorPayload()) {
auto errorsRowVector = IOBufToRowVector(
*errorPayload,
ROW({VARCHAR()}),
*context.pool(),
parent.serde_.get());
auto errorsVector =
errorsRowVector->childAt(0)->asFlatVector<StringView>();
VELOX_CHECK(errorsVector, "Should be convertible to flat vector");

SelectivityVector selectedRows(errorsRowVector->size());
selectedRows.applyToSelected([&](vector_size_t i) {
if (errorsVector->isNullAt(i)) {
return;
}
try {
throw std::runtime_error(errorsVector->valueAt(i));
} catch (const std::exception& ex) {
context.setError(i, std::current_exception());
}
});
}
void applyThriftRemote(
const SelectivityVector& rows,
std::vector<VectorPtr>& args,
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const {
// Create type and row vector for serialization.
auto remoteRowVector = std::make_shared<RowVector>(
context.pool(),
remoteInputType_,
BufferPtr{},
rows.end(),
std::move(args));

// Send to remote server.
remote::RemoteFunctionResponse remoteResponse;
remote::RemoteFunctionRequest request;
request.throwOnError_ref() = context.throwOnError();

auto functionHandle = request.remoteFunctionHandle_ref();
functionHandle->name_ref() = functionName_;
functionHandle->returnType_ref() = serializeType(outputType);
functionHandle->argumentTypes_ref() = serializedInputTypes_;

auto requestInputs = request.inputs_ref();
requestInputs->rowCount_ref() = remoteRowVector->size();
requestInputs->pageFormat_ref() = serdeFormat_;

// TODO: serialize only active rows.
requestInputs->payload_ref() = rowVectorToIOBuf(
remoteRowVector, rows.end(), *context.pool(), serde_.get());

try {
thriftClient_->sync_invokeFunction(remoteResponse, request);
} catch (const std::exception& e) {
VELOX_FAIL(
"Error while executing remote function '{}' at '{}': {}",
functionName_,
boost::get<folly::SocketAddress>(location_).describe(),
e.what());
}
};

auto outputRowVector = IOBufToRowVector(
remoteResponse.get_result().get_payload(),
ROW({outputType}),
*context.pool(),
serde_.get());
result = outputRowVector->childAt(0);

if (auto errorPayload = remoteResponse.get_result().errorPayload()) {
auto errorsRowVector = IOBufToRowVector(
*errorPayload, ROW({VARCHAR()}), *context.pool(), serde_.get());
auto errorsVector =
errorsRowVector->childAt(0)->asFlatVector<StringView>();
VELOX_CHECK(errorsVector, "Should be convertible to flat vector");

SelectivityVector selectedRows(errorsRowVector->size());
selectedRows.applyToSelected([&](vector_size_t i) {
if (errorsVector->isNullAt(i)) {
return;
}
try {
throw std::runtime_error(errorsVector->valueAt(i));
} catch (const std::exception& ex) {
context.setError(i, std::current_exception());
}
});
}
}

const std::string functionName_;
EventBase eventBase_;
const RemoteVectorFunctionMetadata metadata_;

remote::PageFormat serdeFormat_;
std::unique_ptr<VectorSerde> serde_;

boost::variant<SocketAddress, std::string> location_;
std::unique_ptr<RemoteFunctionClient> thriftClient_;
std::unique_ptr<HttpClient> restClient_;
boost::variant<folly::SocketAddress, std::string> location_;

// Depending on the location, one of these is initialized by the visitor.
std::unique_ptr<RemoteFunctionClient> thriftClient_{nullptr};
std::unique_ptr<HttpClient> restClient_{nullptr};

folly::EventBase eventBase_;

RowTypePtr remoteInputType_;
std::vector<std::string> serializedInputTypes_;
Expand All @@ -232,7 +217,7 @@ std::shared_ptr<exec::VectorFunction> createRemoteFunction(
const std::vector<exec::VectorFunctionArg>& inputArgs,
const core::QueryConfig& /*config*/,
const RemoteVectorFunctionMetadata& metadata) {
return std::make_unique<RemoteFunction>(name, inputArgs, metadata);
return std::make_shared<RemoteFunction>(name, inputArgs, metadata);
}

} // namespace
Expand Down

0 comments on commit c31b29c

Please sign in to comment.