llama : offload to RPC in addition to other backends
rgerganov committed May 30, 2024
1 parent 7846540 commit a0e5fa4
Showing 3 changed files with 87 additions and 44 deletions.
31 changes: 27 additions & 4 deletions ggml-rpc.cpp
@@ -106,6 +106,7 @@ struct ggml_backend_rpc_buffer_type_context {
 };
 
 struct ggml_backend_rpc_context {
+    int device;
     std::string endpoint;
     std::string name;
 };
@@ -117,6 +118,9 @@ struct ggml_backend_rpc_buffer_context {
     std::string name;
 };
 
+// device -> endpoint mapping
+static std::unordered_map<int, std::string> endpoints;
+
 // RPC helper functions
 
 static std::shared_ptr<socket_t> make_socket(sockfd_t fd) {
@@ -573,7 +577,7 @@ GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
 
 GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
     ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
-    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+    return ggml_backend_rpc_buffer_type(ctx->device);
 }
 
 GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@@ -659,9 +663,13 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .event_synchronize = */ NULL,
 };
 
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(int device) {
     static std::mutex mutex;
     std::lock_guard<std::mutex> lock(mutex);
+    if (endpoints.find(device) == endpoints.end()) {
+        return nullptr;
+    }
+    auto endpoint = endpoints[device];
     // NOTE: buffer types are allocated and never freed; this is by design
     static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
     auto it = buft_map.find(endpoint);
@@ -689,8 +697,17 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const
     return buft;
 }
 
-GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
+GGML_API GGML_CALL void ggml_backend_rpc_setdevice(const char * endpoint, int device) {
+    endpoints[device] = endpoint;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_rpc_init(int device) {
+    if (endpoints.find(device) == endpoints.end()) {
+        return nullptr;
+    }
+    auto endpoint = endpoints[device];
     ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
+        /* .device   = */ device,
         /* .endpoint = */ endpoint,
         /* .name     = */ "RPC",
     };
@@ -723,7 +740,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
     *total = total_mem;
 }
 
-GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
+GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(int device, size_t * free, size_t * total) {
+    if (endpoints.find(device) == endpoints.end()) {
+        *free = 0;
+        *total = 0;
+        return;
+    }
+    auto endpoint = endpoints[device];
     auto sock = get_socket(endpoint);
     if (sock == nullptr) {
         *free = 0;
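Note on the ggml-rpc.cpp changes above: every public entry point now resolves the integer device index through the static endpoints map and fails gracefully on a miss, so callers can probe indices without crashing. A minimal sketch of that behavior; the endpoint address and device indices are hypothetical:

// Sketch only: endpoint address and device indices are made up for illustration.
#include "ggml-rpc.h"
#include <cstdio>

int main() {
    // Register a hypothetical RPC server under device index 1.
    ggml_backend_rpc_setdevice("192.168.1.10:50052", 1);

    // Registered index: the endpoint lookup succeeds and init tries to connect
    // (it may still return nullptr if the server is unreachable).
    ggml_backend_t reg = ggml_backend_rpc_init(1);

    // Unregistered index: the lookup misses, so init returns nullptr and
    // get_device_memory reports zero memory instead of touching the network.
    ggml_backend_t unreg = ggml_backend_rpc_init(7);
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory(7, &free_mem, &total_mem);

    printf("registered: %p, unregistered: %p, mem: %zu/%zu\n",
           (void *) reg, (void *) unreg, free_mem, total_mem);
    return 0;
}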
8 changes: 5 additions & 3 deletions ggml-rpc.h
@@ -10,12 +10,14 @@ extern "C" {
 #define GGML_RPC_MAX_SERVERS 16
 
 // backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_API GGML_CALL void ggml_backend_rpc_setdevice(const char * endpoint, int device);
+
+GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(int device);
 GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
 
-GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(int device);
 
-GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(int device, size_t * free, size_t * total);
 
 GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
 
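The header change above replaces endpoint strings with integer device indices in the public API: an endpoint is bound to an index once via ggml_backend_rpc_setdevice, and every other call then takes the index. A minimal sketch of the intended call sequence, assuming an RPC server is already listening on the hypothetical endpoint below:

#include "ggml-rpc.h"
#include <cstdio>

int main() {
    const char * endpoint = "localhost:50052"; // hypothetical server address
    const int    device   = 0;                 // index chosen by the caller

    // 1. Bind the endpoint to a device index.
    ggml_backend_rpc_setdevice(endpoint, device);

    // 2. Query memory and pick a buffer type by index rather than by endpoint.
    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory(device, &free_mem, &total_mem);
    ggml_backend_buffer_type_t buft = ggml_backend_rpc_buffer_type(device);

    // 3. Create the backend for that index.
    ggml_backend_t backend = ggml_backend_rpc_init(device);
    if (backend != nullptr && ggml_backend_is_rpc(backend)) {
        printf("RPC device %d: %zu free / %zu total bytes (buft %p)\n",
               device, free_mem, total_mem, (void *) buft);
    }
    return 0;
}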
92 changes: 55 additions & 37 deletions llama.cpp
@@ -2369,13 +2369,39 @@ struct llama_context {
     struct llama_control_vector cvec;
 };
 
+static size_t llama_get_device_count(const llama_model & model) {
+    size_t count = 1;
+#if defined(GGML_USE_CUDA)
+    count = ggml_backend_cuda_get_device_count();
+#elif defined(GGML_USE_SYCL)
+    count = ggml_backend_sycl_get_device_count();
+#elif defined(GGML_USE_VULKAN)
+    count = ggml_backend_vk_get_device_count();
+#endif
+#if defined(GGML_USE_RPC)
+    int rpc_count = (int)model.rpc_servers.size();
+    for (int i = 0; i < rpc_count; i++) {
+        int device = count + i;
+        const char * endpoint = model.rpc_servers[i].c_str();
+        ggml_backend_rpc_setdevice(endpoint, device);
+    }
+    count += rpc_count;
+#endif
+    return count;
+    GGML_UNUSED(model);
+}
+
 static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) {
     ggml_backend_buffer_type_t buft = nullptr;
 
-#ifdef GGML_USE_RPC
-    std::string endpoint = model.rpc_servers[gpu];
-    buft = ggml_backend_rpc_buffer_type(endpoint.c_str());
-#elif defined(GGML_USE_METAL)
+#if defined(GGML_USE_RPC)
+    int dev_count = (int)llama_get_device_count(model);
+    int rpc_count = (int)model.rpc_servers.size();
+    if (gpu >= dev_count - rpc_count) {
+        return ggml_backend_rpc_buffer_type(gpu);
+    }
+#endif
+#if defined(GGML_USE_METAL)
     buft = ggml_backend_metal_buffer_type();
 #elif defined(GGML_USE_CUDA)
     buft = ggml_backend_cuda_buffer_type(gpu);
@@ -2423,29 +2449,18 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
     GGML_UNUSED(tensor_split);
 }
 
-static size_t llama_get_device_count(const llama_model & model) {
-#if defined(GGML_USE_RPC)
-    return model.rpc_servers.size();
-#elif defined(GGML_USE_CUDA)
-    return ggml_backend_cuda_get_device_count();
-#elif defined(GGML_USE_SYCL)
-    return ggml_backend_sycl_get_device_count();
-#elif defined(GGML_USE_VULKAN)
-    return ggml_backend_vk_get_device_count();
-#else
-    return 1;
-#endif
-    GGML_UNUSED(model);
-}
-
 static size_t llama_get_device_memory(const llama_model & model, int device) {
 #if defined(GGML_USE_RPC)
-    size_t total;
-    size_t free;
-    std::string endpoint = model.rpc_servers[device];
-    ggml_backend_rpc_get_device_memory(endpoint.c_str(), &free, &total);
-    return free;
-#elif defined(GGML_USE_CUDA)
+    int dev_count = (int)llama_get_device_count(model);
+    int rpc_count = (int)model.rpc_servers.size();
+    if (device >= dev_count - rpc_count) {
+        size_t total;
+        size_t free;
+        ggml_backend_rpc_get_device_memory(device, &free, &total);
+        return free;
+    }
+#endif
+#if defined(GGML_USE_CUDA)
     size_t total;
     size_t free;
     ggml_backend_cuda_get_device_memory(device, &free, &total);
@@ -16146,7 +16161,7 @@ struct llama_model * llama_load_model_from_file(
             return true;
         };
     }
-    if (params.rpc_servers != nullptr) {
+    if (params.rpc_servers != nullptr && params.rpc_servers[0] != '\0') {
         // split the servers set them into model->rpc_servers
         std::string servers(params.rpc_servers);
         size_t pos = 0;
@@ -16304,17 +16319,7 @@ struct llama_context * llama_new_context_with_model(
 
     if (!hparams.vocab_only) {
         // initialize backends
-#if defined(GGML_USE_RPC)
-        for (auto & server : model->rpc_servers) {
-            ggml_backend_t backend = ggml_backend_rpc_init(server.c_str());
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        }
-#elif defined(GGML_USE_METAL)
+#if defined(GGML_USE_METAL)
         if (model->n_gpu_layers > 0) {
             ctx->backend_metal = ggml_backend_metal_init();
             if (ctx->backend_metal == nullptr) {
@@ -16406,6 +16411,19 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(backend);
         }
 #endif
+#if defined(GGML_USE_RPC)
+        int dev_count = (int)llama_get_device_count(*model);
+        int rpc_count = (int)model->rpc_servers.size();
+        for (int i = dev_count - rpc_count; i < dev_count; i++) {
+            ggml_backend_t backend = ggml_backend_rpc_init(i);
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to initialize RPC #%d\n", __func__, i);
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+#endif
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
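In llama.cpp, RPC servers now occupy the highest device indices: llama_get_device_count() first counts the local devices for the active backend, then registers each RPC server at index count + i, so any index at or above dev_count - rpc_count belongs to an RPC server and the offload and memory helpers route it accordingly. A standalone sketch of that index layout, with made-up counts (two local devices, two RPC servers) purely for illustration:

#include <cstdio>
#include <string>
#include <vector>

int main() {
    // Illustrative numbers only: 2 local (e.g. CUDA) devices, 2 RPC servers.
    const int local_count = 2;
    const std::vector<std::string> rpc_servers = { "host1:50052", "host2:50052" };

    const int rpc_count = (int) rpc_servers.size();
    const int dev_count = local_count + rpc_count; // what llama_get_device_count() reports

    for (int device = 0; device < dev_count; device++) {
        if (device >= dev_count - rpc_count) {
            // RPC device: maps back to rpc_servers[device - local_count]
            printf("device %d -> RPC server %s\n",
                   device, rpc_servers[device - (dev_count - rpc_count)].c_str());
        } else {
            printf("device %d -> local backend\n", device);
        }
    }
    return 0;
}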
