diff --git a/ggml-rpc.cpp b/ggml-rpc.cpp index cc1d3ace1ddac0..f8a49b1c6df248 100644 --- a/ggml-rpc.cpp +++ b/ggml-rpc.cpp @@ -96,27 +96,37 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } -struct ggml_backend_rpc_buffer_type_context { +struct rpc_backend { + int ref_count; + std::string endpoint; std::shared_ptr sock; + ggml_backend_t backend; +}; + +using rpc_backend_ptr = std::shared_ptr; + +struct ggml_backend_rpc_buffer_type_context { + std::shared_ptr back; std::string name; size_t alignment; size_t max_size; }; struct ggml_backend_rpc_context { - std::string endpoint; std::string name; - std::shared_ptr sock; + std::shared_ptr back; ggml_backend_buffer_type_t buft; }; struct ggml_backend_rpc_buffer_context { - std::shared_ptr sock; + std::shared_ptr back; std::unordered_map base_cache; uint64_t remote_ptr; std::string name; }; +static std::unordered_map instances; + // RPC helper functions static std::shared_ptr make_socket(sockfd_t fd) { @@ -231,14 +241,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) { return true; } -static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { - std::string str(endpoint); - size_t pos = str.find(':'); +static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { + size_t pos = endpoint.find(':'); if (pos == std::string::npos) { return false; } - host = str.substr(0, pos); - port = std::stoi(str.substr(pos + 1)); + host = endpoint.substr(0, pos); + port = std::stoi(endpoint.substr(pos + 1)); return true; } @@ -273,6 +282,22 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation +static void free_rpc_backend(rpc_backend_ptr rpc_back) { + ggml_backend_t backend = rpc_back->backend; + ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + std::string endpoint = rpc_back->endpoint; + ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; + GGML_PRINT_DEBUG("[%s] closing connection to %s\n", __func__, endpoint.c_str()); + delete buft_ctx; + delete rpc_ctx->buft; + delete rpc_ctx; + delete backend; + instances.erase(endpoint); +#ifdef _WIN32 + WSACleanup(); +#endif +} + GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; return ctx->name.c_str(); @@ -285,9 +310,13 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t uint64_t remote_ptr = ctx->remote_ptr; memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output); + bool status = send_rpc_cmd(ctx->back->sock, FREE_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.empty()); + ctx->back->ref_count--; + if (ctx->back->ref_count == 0) { + free_rpc_backend(ctx->back); + } delete ctx; } @@ -301,7 +330,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b uint64_t remote_ptr = ctx->remote_ptr; memcpy(input.data(), &remote_ptr, sizeof(remote_ptr)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output); + bool status = send_rpc_cmd(ctx->back->sock, BUFFER_GET_BASE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == sizeof(uint64_t)); // output serialization format: | base_ptr (8 bytes) | @@ -360,7 +389,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size); std::vector output; - bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->back->sock, SET_TENSOR, input, output); GGML_ASSERT(status); } @@ -374,7 +403,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset)); memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->back->sock, GET_TENSOR, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == size); // output serialization format: | data (size bytes) | @@ -387,7 +416,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context; ggml_backend_buffer_t dst_buffer = dst->buffer; ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context; - if (src_ctx->sock != dst_ctx->sock) { + if (src_ctx->back != dst_ctx->back) { return false; } ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; @@ -399,7 +428,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b memcpy(input.data(), &rpc_src, sizeof(rpc_src)); memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output); + bool status = send_rpc_cmd(ctx->back->sock, COPY_TENSOR, input, output); GGML_ASSERT(status); // output serialization format: | result (1 byte) | GGML_ASSERT(output.size() == 1); @@ -414,7 +443,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr)); memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value)); std::vector output; - bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output); + bool status = send_rpc_cmd(ctx->back->sock, BUFFER_CLEAR, input, output); GGML_ASSERT(status); } @@ -442,7 +471,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer std::vector input(input_size, 0); memcpy(input.data(), &size, sizeof(size)); std::vector output; - bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); + bool status = send_rpc_cmd(buft_ctx->back->sock, ALLOC_BUFFER, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | @@ -453,8 +482,9 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer if (remote_ptr != 0) { ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_rpc_buffer_interface, - new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, + new ggml_backend_rpc_buffer_context{buft_ctx->back, {}, remote_ptr, "RPC"}, remote_size); + buft_ctx->back->ref_count++; return buffer; } else { return nullptr; @@ -508,7 +538,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend } ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - return buft_ctx->sock == rpc_ctx->sock; + return buft_ctx->back == rpc_ctx->back; } static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { @@ -521,7 +551,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { /* .is_host = */ NULL, }; - GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; @@ -530,11 +559,10 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; - ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context; - delete buft_ctx; - delete rpc_ctx->buft; - delete rpc_ctx; - delete backend; + rpc_ctx->back->ref_count--; + if (rpc_ctx->back->ref_count == 0) { + free_rpc_backend(rpc_ctx->back); + } } GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { @@ -590,7 +618,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t std::vector input; serialize_graph(cgraph, input); std::vector output; - bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); + bool status = send_rpc_cmd(rpc_ctx->back->sock, GRAPH_COMPUTE, input, output); GGML_ASSERT(status); GGML_ASSERT(output.size() == 1); return (enum ggml_status)output[0]; @@ -624,17 +652,9 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .event_synchronize = */ NULL, }; -static std::unordered_map instances; - -GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { - ggml_backend_t backend = ggml_backend_rpc_init(endpoint); - return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; -} - -GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { - std::string endpoint_str(endpoint); - if (instances.find(endpoint_str) != instances.end()) { - return instances[endpoint_str]; +static rpc_backend_ptr create_rpc_backend(const std::string & endpoint) { + if (instances.find(endpoint) != instances.end()) { + return instances[endpoint]; } #ifdef _WIN32 { @@ -645,7 +665,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { } } #endif - fprintf(stderr, "Connecting to %s\n", endpoint); + fprintf(stderr, "Connecting to %s\n", endpoint.c_str()); std::string host; int port; if (!parse_endpoint(endpoint, host, port)) { @@ -657,11 +677,12 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { } size_t alignment = get_alignment(sock); size_t max_size = get_max_size(sock); + auto rpc_back = std::make_shared(); ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { - /* .sock = */ sock, - /* .name = */ "RPC" + std::to_string(sock->fd), + /* .back = */ rpc_back, + /* .name = */ "RPC" + std::to_string(sock->fd), /* .alignment = */ alignment, - /* .max_size = */ max_size + /* .max_size = */ max_size }; ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type { @@ -670,19 +691,37 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { }; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, /* .name = */ "RPC" + std::to_string(sock->fd), - /* .sock = */ sock, + /* .back = */ rpc_back, /* .buft = */ buft }; - instances[endpoint] = new ggml_backend { + ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_rpc_guid(), /* .interface = */ ggml_backend_rpc_interface, /* .context = */ ctx }; + rpc_back->sock = sock; + rpc_back->endpoint = endpoint; + rpc_back->backend = backend; + rpc_back->ref_count = 0; + instances[endpoint] = rpc_back; + return rpc_back; +} - return instances[endpoint]; +GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { + auto rpc_back = create_rpc_backend(endpoint); + return rpc_back != nullptr ? ggml_backend_rpc_get_default_buffer_type(rpc_back->backend) : nullptr; +} + +GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) { + std::string endpoint_str(endpoint); + auto rpc_back = create_rpc_backend(endpoint_str); + if (rpc_back == nullptr) { + return nullptr; + } + rpc_back->ref_count++; + return rpc_back->backend; } GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { @@ -706,14 +745,13 @@ static void get_device_memory(const std::shared_ptr & sock, size_t * f } GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { - ggml_backend_t backend = ggml_backend_rpc_init(endpoint); - if (backend == nullptr) { + auto rpc_back = create_rpc_backend(endpoint); + if (rpc_back == nullptr) { *free = 0; *total = 0; return; } - ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; - get_device_memory(ctx->sock, free, total); + get_device_memory(rpc_back->sock, free, total); } // RPC server-side implementation