diff --git a/components/udf/mocks.h b/components/udf/mocks.h
index 52a5d916..e57c7a6a 100644
--- a/components/udf/mocks.h
+++ b/components/udf/mocks.h
@@ -41,6 +41,12 @@ class MockUdfClient : public UdfClient {
                const google::protobuf::RepeatedPtrField<UDFArgument>&,
                ExecutionMetadata& execution_metadata),
               (const, override));
+  MOCK_METHOD((absl::StatusOr<absl::flat_hash_map<int32_t, std::string>>),
+              BatchExecuteCode,
+              (const RequestContextFactory& request_context_factory,
+               (absl::flat_hash_map<int32_t, UDFInput> & udf_input_map),
+               ExecutionMetadata& metadata),
+              (const, override));
   MOCK_METHOD((absl::Status), Stop, (), (override));
   MOCK_METHOD((absl::Status), SetCodeObject,
               (CodeConfig, privacy_sandbox::server_common::log::PSLogContext&),
diff --git a/components/udf/noop_udf_client.cc b/components/udf/noop_udf_client.cc
index af45a4f1..480dd6c2 100644
--- a/components/udf/noop_udf_client.cc
+++ b/components/udf/noop_udf_client.cc
@@ -45,6 +45,19 @@ class NoopUdfClientImpl : public UdfClient {
     return "";
   }
 
+  // No-op counterpart of batch execution: runs nothing and maps every id in
+  // `udf_input_map` to an empty string.
+  absl::StatusOr<absl::flat_hash_map<int32_t, std::string>> BatchExecuteCode(
+      const RequestContextFactory& request_context_factory,
+      absl::flat_hash_map<int32_t, UDFInput>& udf_input_map,
+      ExecutionMetadata& metadata) const {
+    absl::flat_hash_map<int32_t, std::string> response;
+    for (auto&& [k, v] : udf_input_map) {
+      response[k] = "";
+    }
+    return response;
+  }
+
   absl::Status Stop() { return absl::OkStatus(); }
 
   absl::Status SetCodeObject(
diff --git a/components/udf/udf_client.cc b/components/udf/udf_client.cc
index d989173d..a0345ea6 100644
--- a/components/udf/udf_client.cc
+++ b/components/udf/udf_client.cc
@@ -15,6 +15,7 @@
 #include "components/udf/udf_client.h"
 
 #include <functional>
+#include <future>
 #include <memory>
 #include <string>
 #include <utility>
@@ -159,6 +160,74 @@ class UdfClientImpl : public UdfClient {
         absl::ToInt64Milliseconds(latency_recorder.GetLatency());
     return *result;
   }
+
+  // Runs the UDF once for every entry of `udf_input_map`, each on its own
+  // std::async task, and returns the outputs of the runs that succeeded,
+  // keyed by the entry's id. A failed run is logged and left out of the
+  // returned map; it does not fail the batch.
+  //
+  // Every entry of `udf_input_map` is moved from, so the caller must not
+  // reuse its values afterwards.
+  //
+  // On return, `metadata.custom_code_total_execution_time_micros` holds the
+  // longest single-run time in the batch. It is left untouched when
+  // `udf_input_map` is empty.
+  absl::StatusOr<absl::flat_hash_map<int32_t, std::string>> BatchExecuteCode(
+      const RequestContextFactory& request_context_factory,
+      absl::flat_hash_map<int32_t, UDFInput>& udf_input_map,
+      ExecutionMetadata& metadata) const {
+    absl::flat_hash_map<int32_t, std::string> results;
+    if (udf_input_map.empty()) {
+      PS_VLOG(5, request_context_factory.Get().GetPSLogContext())
+          << "UDF input map is empty. Not executing any UDFs.";
+      return results;
+    }
+
+    // Output and timing of a single run. The timing is handed back through
+    // the future instead of being written to `metadata` from the worker
+    // threads, which would be an unsynchronized read-modify-write.
+    struct SingleRun {
+      absl::StatusOr<std::string> output;
+      ExecutionMetadata metadata;
+    };
+    absl::flat_hash_map<int32_t, std::future<SingleRun>> pending_runs;
+    pending_runs.reserve(udf_input_map.size());
+    for (auto&& [id, udf_input] : udf_input_map) {
+      pending_runs[id] = std::async(
+          std::launch::async,
+          [this, &request_context_factory](UDFInput&& udf_input) {
+            ExecutionMetadata single_run_metadata;
+            auto output =
+                this->ExecuteCode(request_context_factory,
+                                  std::move(udf_input.execution_metadata),
+                                  udf_input.arguments, single_run_metadata);
+            return SingleRun{std::move(output),
+                             std::move(single_run_metadata)};
+          },
+          std::move(udf_input));
+    }
+
+    // Collect on the calling thread; get() blocks until that run is done.
+    metadata.custom_code_total_execution_time_micros = 0;
+    results.reserve(pending_runs.size());
+    for (auto&& [id, pending_run] : pending_runs) {
+      SingleRun run = pending_run.get();
+      // Record the longest UDF execution time across all parallel
+      // executions
+      metadata.custom_code_total_execution_time_micros = std::max(
+          metadata.custom_code_total_execution_time_micros,
+          run.metadata.custom_code_total_execution_time_micros);
+      if (run.output.ok()) {
+        results[id] = std::move(run.output.value());
+      } else {
+        PS_LOG(ERROR, request_context_factory.Get().GetPSLogContext())
+            << "UDF Execution failed for partition id " << id << ": "
+            << run.output.status();
+      }
+    }
+    return results;
+  }
+
   absl::Status Init() { return roma_service_.Init(); }
 
   absl::Status Stop() { return roma_service_.Stop(); }
@@ -256,13 +325,13 @@ class UdfClientImpl : public UdfClient {
   const absl::Duration udf_timeout_;
   const absl::Duration udf_update_timeout_;
   int udf_min_log_level_;
-  // Per b/299667930, RomaService has been extended to support metadata storage
-  // as a side effect of RomaService::Execute(), making it no longer const.
-  // However, UDFClient::ExecuteCode() remains logically const, so RomaService
-  // is marked as mutable to allow usage within UDFClient::ExecuteCode(). For
-  // concerns about mutable or go/totw/174, RomaService is thread-safe, so
-  // losing the thread-safety of usage within a const function is a lesser
-  // concern.
+  // Per b/299667930, RomaService has been extended to support metadata
+  // storage as a side effect of RomaService::Execute(), making it no longer
+  // const. However, UDFClient::ExecuteCode() remains logically const, so
+  // RomaService is marked as mutable to allow usage within
+  // UDFClient::ExecuteCode(). For concerns about mutable or go/totw/174,
+  // RomaService is thread-safe, so losing the thread-safety of usage within a
+  // const function is a lesser concern.
   mutable RomaService<std::weak_ptr<RequestContext>> roma_service_;
 };
 
diff --git a/components/udf/udf_client.h b/components/udf/udf_client.h
index cb4bbe85..d86dd9ca 100644
--- a/components/udf/udf_client.h
+++ b/components/udf/udf_client.h
@@ -34,7 +34,18 @@
 
 namespace kv_server {
 
+// Input of a single UDF execution within a BatchExecuteCode() call.
+struct UDFInput {
+  // Execution metadata passed to the UDF.
+  UDFExecutionMetadata execution_metadata;
+  // Arguments passed to the UDF.
+  google::protobuf::RepeatedPtrField<UDFArgument> arguments;
+};
+
 struct ExecutionMetadata {
+  // Total time for all custom code to execute.
+  // UdfClientImpl::BatchExecuteCode() stores the longest single execution of
+  // the batch here, not the sum.
   std::optional<int64_t> custom_code_total_execution_time_micros;
 };
 
@@ -60,6 +71,17 @@ class UdfClient {
       const google::protobuf::RepeatedPtrField<UDFArgument>& arguments,
       ExecutionMetadata& metadata) const = 0;
 
+  // Executes multiple UDFs in parallel, one per entry of `udf_input_map`, and
+  // returns their outputs keyed by the same ids. Code object must be set
+  // before making this call.
+  // NOTE(review): UdfClientImpl moves from the entries of `udf_input_map` and
+  // leaves failed executions out of the result instead of returning an error;
+  // confirm this is the contract intended for every implementation.
+  virtual absl::StatusOr<absl::flat_hash_map<int32_t, std::string>>
+  BatchExecuteCode(const RequestContextFactory& request_context_factory,
+                   absl::flat_hash_map<int32_t, UDFInput>& udf_input_map,
+                   ExecutionMetadata& metadata) const = 0;
+
   virtual absl::Status Stop() = 0;
 
   // Sets the code object that will be used for UDF execution
diff --git a/components/udf/udf_client_test.cc b/components/udf/udf_client_test.cc
index 5e1c8883..9f447ae6 100644
--- a/components/udf/udf_client_test.cc
+++ b/components/udf/udf_client_test.cc
@@ -51,6 +51,18 @@ using testing::Return;
 
 namespace kv_server {
 namespace {
+
+constexpr std::string_view kEmptyMetadata = R"(
+request_metadata {
+  fields {
+    key: "hostname"
+    value {
+      string_value: ""
+    }
+  }
+}
+  )";
+
 absl::StatusOr<std::unique_ptr<UdfClient>> CreateUdfClient() {
   Config<std::weak_ptr<RequestContext>> config;
   config.number_of_workers = 1;
@@ -1104,5 +1116,126 @@ TEST_F(UdfClientTest, JsCallsLogCustomMetricFailedToLogError) {
   EXPECT_THAT(metrics_logging_outcome, ContainsRegex("Failed to log metrics"));
 }
 
+// Two inputs, one without and one with request metadata: both outputs come
+// back under their own ids.
+TEST_F(UdfClientTest, BatchExecuteCodeSuccess) {
+  auto udf_client = CreateUdfClient();
+  ASSERT_TRUE(udf_client.ok());
+
+  absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{
+      .js = "hello = (metadata, data) => 'Hello world! ' + "
+            "JSON.stringify(metadata) + JSON.stringify(data);",
+      .udf_handler_name = "hello",
+      .logical_commit_time = 1,
+      .version = 1,
+  });
+  EXPECT_TRUE(code_obj_status.ok());
+
+  UDFExecutionMetadata udf_metadata;
+  ASSERT_TRUE(TextFormat::ParseFromString(kEmptyMetadata, &udf_metadata));
+
+  google::protobuf::RepeatedPtrField<UDFArgument> args1;
+  args1.Add([] {
+    UDFArgument arg;
+    arg.mutable_tags()->add_values()->set_string_value("tag1");
+    arg.mutable_data()->set_string_value("key1");
+    return arg;
+  }());
+  google::protobuf::RepeatedPtrField<UDFArgument> args2;
+  args2.Add([] {
+    UDFArgument arg;
+    arg.mutable_tags()->add_values()->set_string_value("tag2");
+    arg.mutable_data()->set_string_value("key2");
+    return arg;
+  }());
+
+  absl::flat_hash_map<int32_t, UDFInput> input;
+  input[1] = {.arguments = args1};
+  input[2] = {.execution_metadata = udf_metadata, .arguments = args2};
+  auto result = udf_client.value()->BatchExecuteCode(
+      *request_context_factory_, input, execution_metadata_);
+  ASSERT_TRUE(result.ok());
+  auto udf_outputs = std::move(result.value());
+  EXPECT_EQ(udf_outputs.size(), 2);
+  EXPECT_EQ(
+      udf_outputs[1],
+      R"("Hello world! {\"udfInterfaceVersion\":1}{\"tags\":[\"tag1\"],\"data\":\"key1\"}")");
+  EXPECT_EQ(
+      udf_outputs[2],
+      R"("Hello world! {\"udfInterfaceVersion\":1,\"requestMetadata\":{\"hostname\":\"\"}}{\"tags\":[\"tag2\"],\"data\":\"key2\"}")");
+
+  absl::Status stop = udf_client.value()->Stop();
+  EXPECT_TRUE(stop.ok());
+}
+
+// A UDF that throws for one input: the batch call still succeeds and only
+// the healthy partition shows up in the result.
+TEST_F(UdfClientTest, BatchExecuteCodeIgnoresFailedPartition) {
+  auto udf_client = CreateUdfClient();
+  ASSERT_TRUE(udf_client.ok());
+
+  absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{
+      .js =
+          R"js(function hello(metadata, data) {
+            if(data.data == "valid_key") {return 'Hello world!';}
+            throw new Error('Oh no!');
+          })js",
+      .udf_handler_name = "hello",
+      .logical_commit_time = 1,
+      .version = 1,
+  });
+  EXPECT_TRUE(code_obj_status.ok());
+
+  google::protobuf::RepeatedPtrField<UDFArgument> args1;
+  args1.Add([] {
+    UDFArgument arg;
+    arg.mutable_tags()->add_values()->set_string_value("some_tag");
+    arg.mutable_data()->set_string_value("valid_key");
+    return arg;
+  }());
+  google::protobuf::RepeatedPtrField<UDFArgument> args2;
+  args2.Add([] {
+    UDFArgument arg;
+    arg.mutable_data()->set_string_value("invalid key");
+    return arg;
+  }());
+
+  absl::flat_hash_map<int32_t, UDFInput> input;
+  input[1] = {.arguments = args1};
+  input[2] = {.arguments = args2};
+  auto result = udf_client.value()->BatchExecuteCode(
+      *request_context_factory_, input, execution_metadata_);
+  ASSERT_TRUE(result.ok());
+  auto udf_outputs = std::move(result.value());
+  EXPECT_EQ(udf_outputs.size(), 1);
+  EXPECT_FALSE(udf_outputs.contains(2));
+  EXPECT_EQ(udf_outputs[1], R"("Hello world!")");
+
+  absl::Status stop = udf_client.value()->Stop();
+  EXPECT_TRUE(stop.ok());
+}
+
+// An empty input map yields an OK, empty result.
+TEST_F(UdfClientTest, BatchExecuteCodeEmptyReturnsSuccess) {
+  auto udf_client = CreateUdfClient();
+  ASSERT_TRUE(udf_client.ok());
+
+  absl::Status code_obj_status = udf_client.value()->SetCodeObject(CodeConfig{
+      .js = "hello = (metadata, data) => 'Hello world! ' + "
+            "JSON.stringify(metadata) + JSON.stringify(data);",
+      .udf_handler_name = "hello",
+      .logical_commit_time = 1,
+      .version = 1,
+  });
+  EXPECT_TRUE(code_obj_status.ok());
+
+  absl::flat_hash_map<int32_t, UDFInput> input;
+  auto result = udf_client.value()->BatchExecuteCode(
+      *request_context_factory_, input, execution_metadata_);
+  ASSERT_TRUE(result.ok());
+  EXPECT_EQ(result->size(), 0);
+  absl::Status stop = udf_client.value()->Stop();
+  EXPECT_TRUE(stop.ok());
+}
 }  // namespace
 }  // namespace kv_server