From 4631edc2b9bee2d7770dd6f0b7de265cab2c7b6c Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Thu, 17 Oct 2024 10:35:19 +0300 Subject: [PATCH] rpc : refactor backend Use structs for RPC request/response messages --- ggml/src/ggml-rpc.cpp | 58 ++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp index 13c7dd4364c331..ccaeda081edbed 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc.cpp @@ -58,7 +58,7 @@ struct socket_t { }; // ggml_tensor is serialized into rpc_tensor -#pragma pack(push, 1) +#pragma pack(1) struct rpc_tensor { uint64_t id; uint32_t type; @@ -96,6 +96,17 @@ enum rpc_cmd { RPC_CMD_COUNT, }; +#pragma pack(1) +struct request_alloc_buffer { + uint64_t size; +}; + +#pragma pack(1) +struct response_alloc_buffer { + uint64_t remote_ptr; + uint64_t remote_size; +}; + // RPC data structures static ggml_guid_t ggml_backend_rpc_guid() { @@ -252,30 +263,31 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const std::vector & input, std::vector & output) { +static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { uint8_t cmd_byte = cmd; if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { return false; } - uint64_t input_size = input.size(); if (!send_data(sock->fd, &input_size, sizeof(input_size))) { return false; } - if (!send_data(sock->fd, input.data(), input.size())) { + if (!send_data(sock->fd, input, input_size)) { return false; } - uint64_t output_size; - if (!recv_data(sock->fd, &output_size, sizeof(output_size))) { + // TODO: currently the output_size is always known, do we need support for commands with variable output size? + // even if we do, we can skip sending output_size from the server for commands with known output size + uint64_t out_size; + if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { return false; } - if (output_size == 0) { - output.clear(); - return true; - } - output.resize(output_size); - if (!recv_data(sock->fd, output.data(), output_size)) { + if (out_size != output_size) { return false; } + if (output_size > 0) { + if (!recv_data(sock->fd, output, output_size)) { + return false; + } + } return true; } @@ -484,25 +496,15 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; - // input serialization format: | size (8 bytes) | - int input_size = sizeof(uint64_t); - std::vector input(input_size, 0); - memcpy(input.data(), &size, sizeof(size)); - std::vector output; + request_alloc_buffer request = {size}; + response_alloc_buffer response; auto sock = get_socket(buft_ctx->endpoint); - bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output); - GGML_ASSERT(status); - GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); - // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | - uint64_t remote_ptr; - memcpy(&remote_ptr, output.data(), sizeof(remote_ptr)); - size_t remote_size; - memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size)); - if (remote_ptr != 0) { + bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response)); + if (response.remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"}, - remote_size); + new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"}, + response.remote_size); return buffer; } else { return nullptr;