Skip to content

Commit

Permalink
rpc : resource management rework
Browse files Browse the repository at this point in the history
  • Loading branch information
rgerganov committed May 27, 2024
1 parent 1d8fca7 commit 2ce50d1
Showing 1 changed file with 87 additions and 49 deletions.
136 changes: 87 additions & 49 deletions ggml-rpc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,27 +96,37 @@ static ggml_guid_t ggml_backend_rpc_guid() {
return &guid;
}

struct ggml_backend_rpc_buffer_type_context {
struct rpc_backend {
int ref_count;
std::string endpoint;
std::shared_ptr<socket_t> sock;
ggml_backend_t backend;
};

using rpc_backend_ptr = std::shared_ptr<rpc_backend>;

struct ggml_backend_rpc_buffer_type_context {
std::shared_ptr<rpc_backend> back;
std::string name;
size_t alignment;
size_t max_size;
};

struct ggml_backend_rpc_context {
std::string endpoint;
std::string name;
std::shared_ptr<socket_t> sock;
std::shared_ptr<rpc_backend> back;
ggml_backend_buffer_type_t buft;
};

struct ggml_backend_rpc_buffer_context {
std::shared_ptr<socket_t> sock;
std::shared_ptr<rpc_backend> back;
std::unordered_map<ggml_backend_buffer_t, void *> base_cache;
uint64_t remote_ptr;
std::string name;
};

static std::unordered_map<std::string, rpc_backend_ptr> instances;

// RPC helper functions

static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
Expand Down Expand Up @@ -231,14 +241,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true;
}

static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
std::string str(endpoint);
size_t pos = str.find(':');
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
size_t pos = endpoint.find(':');
if (pos == std::string::npos) {
return false;
}
host = str.substr(0, pos);
port = std::stoi(str.substr(pos + 1));
host = endpoint.substr(0, pos);
port = std::stoi(endpoint.substr(pos + 1));
return true;
}

Expand Down Expand Up @@ -273,6 +282,22 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm

// RPC client-side implementation

static void free_rpc_backend(rpc_backend_ptr rpc_back) {
ggml_backend_t backend = rpc_back->backend;
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::string endpoint = rpc_back->endpoint;
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
GGML_PRINT_DEBUG("[%s] closing connection to %s\n", __func__, endpoint.c_str());
delete buft_ctx;
delete rpc_ctx->buft;
delete rpc_ctx;
delete backend;
instances.erase(endpoint);
#ifdef _WIN32
WSACleanup();
#endif
}

GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
return ctx->name.c_str();
Expand All @@ -285,9 +310,13 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
bool status = send_rpc_cmd(ctx->back->sock, FREE_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.empty());
ctx->back->ref_count--;
if (ctx->back->ref_count == 0) {
free_rpc_backend(ctx->back);
}
delete ctx;
}

Expand All @@ -301,7 +330,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
uint64_t remote_ptr = ctx->remote_ptr;
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
bool status = send_rpc_cmd(ctx->back->sock, BUFFER_GET_BASE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == sizeof(uint64_t));
// output serialization format: | base_ptr (8 bytes) |
Expand Down Expand Up @@ -360,7 +389,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->back->sock, SET_TENSOR, input, output);
GGML_ASSERT(status);
}

Expand All @@ -374,7 +403,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->back->sock, GET_TENSOR, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == size);
// output serialization format: | data (size bytes) |
Expand All @@ -387,7 +416,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
ggml_backend_rpc_buffer_context * src_ctx = (ggml_backend_rpc_buffer_context *)src_buffer->context;
ggml_backend_buffer_t dst_buffer = dst->buffer;
ggml_backend_rpc_buffer_context * dst_ctx = (ggml_backend_rpc_buffer_context *)dst_buffer->context;
if (src_ctx->sock != dst_ctx->sock) {
if (src_ctx->back != dst_ctx->back) {
return false;
}
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
Expand All @@ -399,7 +428,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
bool status = send_rpc_cmd(ctx->back->sock, COPY_TENSOR, input, output);
GGML_ASSERT(status);
// output serialization format: | result (1 byte) |
GGML_ASSERT(output.size() == 1);
Expand All @@ -414,7 +443,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
bool status = send_rpc_cmd(ctx->back->sock, BUFFER_CLEAR, input, output);
GGML_ASSERT(status);
}

Expand Down Expand Up @@ -442,7 +471,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output;
bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
bool status = send_rpc_cmd(buft_ctx->back->sock, ALLOC_BUFFER, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
Expand All @@ -453,8 +482,9 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
if (remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
new ggml_backend_rpc_buffer_context{buft_ctx->back, {}, remote_ptr, "RPC"},
remote_size);
buft_ctx->back->ref_count++;
return buffer;
} else {
return nullptr;
Expand Down Expand Up @@ -508,7 +538,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
}
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
return buft_ctx->sock == rpc_ctx->sock;
return buft_ctx->back == rpc_ctx->back;
}

static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
Expand All @@ -521,7 +551,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
/* .is_host = */ NULL,
};


GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;

Expand All @@ -530,11 +559,10 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {

GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
delete buft_ctx;
delete rpc_ctx->buft;
delete rpc_ctx;
delete backend;
rpc_ctx->back->ref_count--;
if (rpc_ctx->back->ref_count == 0) {
free_rpc_backend(rpc_ctx->back);
}
}

GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
Expand Down Expand Up @@ -590,7 +618,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
std::vector<uint8_t> input;
serialize_graph(cgraph, input);
std::vector<uint8_t> output;
bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
bool status = send_rpc_cmd(rpc_ctx->back->sock, GRAPH_COMPUTE, input, output);
GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0];
Expand Down Expand Up @@ -624,17 +652,9 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .event_synchronize = */ NULL,
};

static std::unordered_map<std::string, ggml_backend_t> instances;

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
std::string endpoint_str(endpoint);
if (instances.find(endpoint_str) != instances.end()) {
return instances[endpoint_str];
static rpc_backend_ptr create_rpc_backend(const std::string & endpoint) {
if (instances.find(endpoint) != instances.end()) {
return instances[endpoint];
}
#ifdef _WIN32
{
Expand All @@ -645,7 +665,7 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
}
}
#endif
fprintf(stderr, "Connecting to %s\n", endpoint);
fprintf(stderr, "Connecting to %s\n", endpoint.c_str());
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
Expand All @@ -657,11 +677,12 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
}
size_t alignment = get_alignment(sock);
size_t max_size = get_max_size(sock);
auto rpc_back = std::make_shared<rpc_backend>();
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .sock = */ sock,
/* .name = */ "RPC" + std::to_string(sock->fd),
/* .back = */ rpc_back,
/* .name = */ "RPC" + std::to_string(sock->fd),
/* .alignment = */ alignment,
/* .max_size = */ max_size
/* .max_size = */ max_size
};

ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
Expand All @@ -670,19 +691,37 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
};

ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .name = */ "RPC" + std::to_string(sock->fd),
/* .sock = */ sock,
/* .back = */ rpc_back,
/* .buft = */ buft
};

instances[endpoint] = new ggml_backend {
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(),
/* .interface = */ ggml_backend_rpc_interface,
/* .context = */ ctx
};
rpc_back->sock = sock;
rpc_back->endpoint = endpoint;
rpc_back->backend = backend;
rpc_back->ref_count = 0;
instances[endpoint] = rpc_back;
return rpc_back;
}

return instances[endpoint];
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
auto rpc_back = create_rpc_backend(endpoint);
return rpc_back != nullptr ? ggml_backend_rpc_get_default_buffer_type(rpc_back->backend) : nullptr;
}

GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
std::string endpoint_str(endpoint);
auto rpc_back = create_rpc_backend(endpoint_str);
if (rpc_back == nullptr) {
return nullptr;
}
rpc_back->ref_count++;
return rpc_back->backend;
}

GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
Expand All @@ -706,14 +745,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
}

GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
if (backend == nullptr) {
auto rpc_back = create_rpc_backend(endpoint);
if (rpc_back == nullptr) {
*free = 0;
*total = 0;
return;
}
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
get_device_memory(ctx->sock, free, total);
get_device_memory(rpc_back->sock, free, total);
}

// RPC server-side implementation
Expand Down

0 comments on commit 2ce50d1

Please sign in to comment.