ggml backends interface v1 #547

Merged: 23 commits, Oct 6, 2023
Changes from 9 commits
4 changes: 3 additions & 1 deletion .gitignore
@@ -6,6 +6,7 @@ build-sanitize-thread/
build-cov/
build-ci-debug/
build-ci-release/
build-cublas/
out/
tmp/
models/
@@ -15,6 +16,7 @@ compile_commands.json
CMakeSettings.json
.vs/
.vscode/
.clangd

.exrc
.cache
@@ -32,4 +34,4 @@ zig-cache/

*.sw?

__pycache__/
__pycache__/
10 changes: 10 additions & 0 deletions examples/gpt-2/CMakeLists.txt
@@ -11,3 +11,13 @@ target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)
set(TEST_TARGET gpt-2-quantize)
add_executable(${TEST_TARGET} quantize.cpp)
target_link_libraries(${TEST_TARGET} PRIVATE ggml common common-ggml)

#
# For GPU offloading

if (GGML_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS)
endif()
if (GGML_CLBLAST)
add_compile_definitions(GGML_USE_CLBLAST)
endif()
171 changes: 123 additions & 48 deletions examples/gpt-2/main.cpp
@@ -1,5 +1,10 @@
#include "ggml/ggml.h"
#include "ggml/ggml-alloc.h"
#include "ggml/ggml-backend.h"

#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#endif

#include "common.h"
#include "common-ggml.h"
@@ -70,11 +75,17 @@ struct gpt2_model {

//
struct ggml_context * ctx;

ggml_backend_t backend = NULL;

ggml_backend_buffer_t buffer_w;
ggml_backend_buffer_t buffer_kv;

std::map<std::string, struct ggml_tensor *> tensors;
};

// load the model's weights from a file
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab) {
bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab & vocab, int n_gpu_layers) {
printf("%s: loading model from '%s'\n", __func__, fname.c_str());

auto fin = std::ifstream(fname, std::ios::binary);
@@ -155,7 +166,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &

auto & ctx = model.ctx;

size_t ctx_size = 0;
size_t buffer_size = 0;

{
const auto & hparams = model.hparams;
@@ -165,46 +176,44 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;

ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b

ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
ctx_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
ctx_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head
buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_g
buffer_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_b

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b
buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype); // wte
buffer_size += n_ctx*n_embd*ggml_type_sizef(GGML_TYPE_F32); // wpe
buffer_size += n_vocab*n_embd*ggml_type_sizef(wtype); // lm_head

ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_g
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_1_b

ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
ctx_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_g
buffer_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ln_2_b

ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b
buffer_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_attn_w
buffer_size += n_layer*( 3*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_attn_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
ctx_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b
buffer_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // c_attn_proj_w
buffer_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_attn_proj_b

ctx_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
ctx_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b
buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_fc_w
buffer_size += n_layer*( 4*n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_fc_b

ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F32); // memory_v
buffer_size += n_layer*(4*n_embd*n_embd*ggml_type_sizef(wtype)); // c_mlp_proj_w
buffer_size += n_layer*( n_embd*ggml_type_sizef(GGML_TYPE_F32)); // c_mlp_proj_b

ctx_size += (6 + 12*n_layer)*512; // object overhead
buffer_size += (6 + 12*n_layer)*128; // alignment overhead

printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
printf("%s: ggml tensor size = %d bytes\n", __func__, (int) sizeof(ggml_tensor));
printf("%s: backend buffer size = %6.2f MB\n", __func__, buffer_size/(1024.0*1024.0));
}

// create the ggml context
{
size_t n_tensors = 2 + 6 + 12*model.hparams.n_layer;
struct ggml_init_params params = {
/*.mem_size =*/ ctx_size,
/*.mem_size =*/ ggml_tensor_overhead() * n_tensors,
/*.mem_buffer =*/ NULL,
/*.no_alloc =*/ false,
/*.no_alloc =*/ true,
};

model.ctx = ggml_init(params);
@@ -214,6 +223,31 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}
}
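
The important change in the hunk above is that the ggml_context no longer owns the weight data: it is sized only for tensor metadata (ggml_tensor_overhead() per tensor) and created with no_alloc = true, while the actual data will live in a backend buffer of roughly buffer_size bytes. A minimal sketch of that metadata-only context, with an illustrative helper name and a caller-supplied tensor count:

    #include "ggml/ggml.h"

    // create a context that holds only tensor metadata; the tensor data is
    // allocated later inside a backend buffer by a ggml_allocr
    static struct ggml_context * make_weights_ctx(size_t n_tensors) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead() * n_tensors,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,   // no data allocation inside the context
        };
        return ggml_init(params);
    }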

// initialize the backend
#ifdef GGML_USE_CUBLAS
if (n_gpu_layers > 0) {
fprintf(stderr, "%s: using CUDA backend\n", __func__);
model.backend = ggml_backend_cuda_init();
if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#endif

if (!model.backend) {
// fallback to CPU backend
fprintf(stderr, "%s: using CPU backend\n", __func__);
model.backend = ggml_backend_cpu_init();
}

if (!model.backend) {
fprintf(stderr, "%s: ggml_backend_cpu_init() failed\n", __func__);
return false;
}

// allocate weights buffer
model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);
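
The backend selection above follows a try-GPU-then-fall-back-to-CPU pattern, and only once a backend exists is the weight buffer reserved on it. A condensed sketch of that pattern, assuming the same convention as the example that n_gpu_layers > 0 means "use the CUDA backend when it is compiled in" (the helper name is illustrative):

    #include <stdio.h>

    #include "ggml/ggml-backend.h"

    #ifdef GGML_USE_CUBLAS
    #include "ggml-cuda.h"
    #endif

    // prefer the CUDA backend when requested and available, otherwise use the CPU backend
    static ggml_backend_t init_backend(int n_gpu_layers) {
        ggml_backend_t backend = NULL;

    #ifdef GGML_USE_CUBLAS
        if (n_gpu_layers > 0) {
            backend = ggml_backend_cuda_init();
            if (!backend) {
                fprintf(stderr, "%s: ggml_backend_cuda_init() failed, falling back to CPU\n", __func__);
            }
        }
    #endif

        if (!backend) {
            backend = ggml_backend_cpu_init();   // may still be NULL if initialization fails
        }
        return backend;
    }

    // usage (buffer_size computed as above):
    //   model.backend  = init_backend(n_gpu_layers);
    //   model.buffer_w = ggml_backend_alloc_buffer(model.backend, buffer_size);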

// prepare memory for the weights
{
const auto & hparams = model.hparams;
@@ -299,14 +333,34 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);

printf("%s: memory size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);

// create a backend buffer (can be in host or device memory)
model.buffer_kv = ggml_backend_alloc_buffer(model.backend, memory_size + 256);

// allocate the tensors into the backend buffer
// TODO: better API for this
{
ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_kv);

// this updates the pointers in the tensors to point to the correct location in the buffer
// this is necessary since the ggml_context is .no_alloc == true
ggml_allocr_alloc(alloc, model.memory_k);
ggml_allocr_alloc(alloc, model.memory_v);

ggml_allocr_free(alloc);
}
}
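
Because the context is no_alloc, the KV tensors have no data pointer until they are placed into a backend buffer; the ggml_allocr created from model.buffer_kv is what fills those pointers in. A small sketch of that placement step, assuming memory_k and memory_v were created earlier in the metadata-only context (the helper name is illustrative):

    #include "ggml/ggml.h"
    #include "ggml/ggml-alloc.h"
    #include "ggml/ggml-backend.h"

    // reserve backend memory for the KV cache and point the tensors into it
    static ggml_backend_buffer_t place_kv_cache(ggml_backend_t backend,
                                                struct ggml_tensor * memory_k,
                                                struct ggml_tensor * memory_v) {
        const size_t memory_size = ggml_nbytes(memory_k) + ggml_nbytes(memory_v);

        // extra headroom for alignment/padding, as in the example above
        ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, memory_size + 256);

        ggml_allocr * alloc = ggml_allocr_new_from_buffer(buf);
        ggml_allocr_alloc(alloc, memory_k);   // sets memory_k->data inside the buffer
        ggml_allocr_alloc(alloc, memory_v);
        ggml_allocr_free(alloc);              // the buffer itself stays alive

        return buf;
    }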

// load weights
{
ggml_allocr * alloc = ggml_allocr_new_from_buffer(model.buffer_w);

size_t total_size = 0;

bool has_lm_head = false;

std::vector<char> read_buf;

while (true) {
int32_t n_dims;
int32_t length;
@@ -336,6 +390,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
}

auto tensor = model.tensors[name];
ggml_set_name(tensor, name.c_str());
if (ggml_nelements(tensor) != nelements) {
fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.c_str());
return false;
@@ -360,11 +415,19 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
return false;
}

fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
// read into a temporary buffer first, then copy to the tensor
// TODO: read directly into the tensor if the backend is CPU
read_buf.resize(ggml_nbytes(tensor));
fin.read(read_buf.data(), ggml_nbytes(tensor));

ggml_allocr_alloc(alloc, tensor);
ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor));

// GPT-2 models share the WTE tensor as the LM head
if (name == "model/wte" && has_lm_head == false) {
memcpy(model.lm_head->data, tensor->data, ggml_nbytes(tensor));
//ggml_allocr_alloc(alloc, model.lm_head);
//ggml_backend_tensor_copy(tensor, model.lm_head);
model.lm_head = tensor;
}

if (name == "model/lm_head") {
@@ -374,6 +437,7 @@ bool gpt2_model_load(const std::string & fname, gpt2_model & model, gpt_vocab &
total_size += ggml_nbytes(tensor);
}

ggml_allocr_free(alloc);
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
}
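
Weight loading now goes through the backend rather than writing into tensor->data directly: each tensor's bytes are staged in a temporary host buffer and then uploaded with ggml_backend_tensor_set, which works whether the destination is host or device memory. A hedged sketch of the per-tensor step (stream handling simplified, helper name illustrative):

    #include <fstream>
    #include <vector>

    #include "ggml/ggml.h"
    #include "ggml/ggml-alloc.h"
    #include "ggml/ggml-backend.h"

    // read one tensor's data from the file and upload it through the backend
    static bool load_tensor_data(std::ifstream & fin, ggml_allocr * alloc,
                                 struct ggml_tensor * tensor,
                                 std::vector<char> & read_buf) {
        const size_t nbytes = ggml_nbytes(tensor);

        // stage the bytes in host memory first; the tensor may live in VRAM
        read_buf.resize(nbytes);
        if (!fin.read(read_buf.data(), nbytes)) {
            return false;
        }

        // reserve space for the tensor inside the weights buffer, then copy
        ggml_allocr_alloc(alloc, tensor);
        ggml_backend_tensor_set(tensor, read_buf.data(), 0, nbytes);

        return true;
    }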

@@ -416,21 +480,23 @@ struct ggml_cgraph * gpt2_graph(

// avoid writing to tensors if we are only measuring the memory usage
if (!ggml_allocr_is_measure(allocr)) {
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
ggml_backend_tensor_set(embd, embd_inp.data(), 0, N*ggml_element_size(embd));
}

struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(allocr, position);
if (!ggml_allocr_is_measure(allocr)) {
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
int32_t v = n_past + i;
ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));
}
}

struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
ggml_allocr_alloc(allocr, KQ_scale);
if (!ggml_allocr_is_measure(allocr)) {
ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
float s = 1.0f/sqrtf(float(n_embd)/n_head);
ggml_backend_tensor_set(KQ_scale, &s, 0, sizeof(s));
}
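
Graph inputs follow the same rule: during the measuring pass the allocator hands out no real memory, so input values are only written when ggml_allocr_is_measure() is false, and they are written with ggml_backend_tensor_set rather than through tensor->data. A small sketch for the position tensor, assuming it was just created in the graph context (the helper name is illustrative):

    #include "ggml/ggml.h"
    #include "ggml/ggml-alloc.h"
    #include "ggml/ggml-backend.h"

    // fill a 1-D I32 tensor with [n_past, n_past + 1, ..., n_past + N - 1]
    static void set_positions(ggml_allocr * allocr, struct ggml_tensor * position,
                              int n_past, int N) {
        ggml_allocr_alloc(allocr, position);

        // during the measure pass the tensor has no backing memory yet
        if (ggml_allocr_is_measure(allocr)) {
            return;
        }

        for (int i = 0; i < N; ++i) {
            const int32_t v = n_past + i;
            // the offset is in bytes, so scale by the element size
            ggml_backend_tensor_set(position, &v, i*sizeof(int32_t), sizeof(v));
        }
    }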

// wte + wpe
@@ -453,7 +519,8 @@ struct ggml_cgraph * gpt2_graph(
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_1_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
//ggml_repeat(ctx0, model.layers[il].ln_1_b, cur));
model.layers[il].ln_1_b);
}

// attn
@@ -599,7 +666,8 @@ struct ggml_cgraph * gpt2_graph(
ggml_mul(ctx0,
ggml_repeat(ctx0, model.layers[il].ln_2_g, cur),
cur),
ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
//ggml_repeat(ctx0, model.layers[il].ln_2_b, cur));
model.layers[il].ln_2_b);
}

// fully connected
@@ -654,7 +722,8 @@ struct ggml_cgraph * gpt2_graph(
ggml_mul(ctx0,
ggml_repeat(ctx0, model.ln_f_g, inpL),
inpL),
ggml_repeat(ctx0, model.ln_f_b, inpL));
//ggml_repeat(ctx0, model.ln_f_b, inpL));
model.ln_f_b);
}

// inpL = WTE * inpL
@@ -703,11 +772,10 @@ bool gpt2_eval(
ggml_allocr_alloc_graph(allocr, gf);

// run the computation
struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
static std::vector<uint8_t> work_buffer;
work_buffer.resize(plan.work_size);
plan.work_data = work_buffer.data();
ggml_graph_compute(gf, &plan);
if (strcmp(ggml_backend_name(model.backend), "CPU") == 0) {
ggml_backend_cpu_set_n_threads(model.backend, n_threads);
}
ggml_backend_graph_compute(model.backend, gf);
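
Running the graph is now backend-agnostic: the explicit ggml_graph_plan / work-buffer / ggml_graph_compute sequence is replaced by a single ggml_backend_graph_compute call, with the thread count applied only when the active backend is the CPU one (the "CPU" name check mirrors the code above). A condensed sketch (the helper name is illustrative):

    #include <string.h>

    #include "ggml/ggml.h"
    #include "ggml/ggml-alloc.h"
    #include "ggml/ggml-backend.h"

    // allocate the graph's intermediate tensors and run it on the chosen backend
    static void compute_graph(ggml_backend_t backend, ggml_allocr * allocr,
                              struct ggml_cgraph * gf, int n_threads) {
        // assign compute-buffer memory to every node in the graph
        ggml_allocr_alloc_graph(allocr, gf);

        // the thread count is a CPU-backend-specific setting
        if (strcmp(ggml_backend_name(backend), "CPU") == 0) {
            ggml_backend_cpu_set_n_threads(backend, n_threads);
        }

        ggml_backend_graph_compute(backend, gf);
    }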

//if (n_past%100 == 0) {
// ggml_graph_print (&gf);
@@ -718,11 +786,11 @@ bool gpt2_eval(
struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];

//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
//ggml_backend_tensor_get(inpL, embd_w.data(), 0, sizeof(float)*n_vocab*N);

// return result just for the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
ggml_backend_tensor_get(inpL, embd_w.data(), (n_vocab*(N-1))*sizeof(float), sizeof(float)*n_vocab);

return true;
}
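
Reading results is the mirror image of setting inputs: ggml_backend_tensor_get copies from the (possibly device-resident) output tensor into host memory, with a byte offset selecting just the logits of the last token. A short sketch under the same layout assumption as the example (logits stored as n_vocab floats per token, helper name illustrative):

    #include <vector>

    #include "ggml/ggml.h"
    #include "ggml/ggml-backend.h"

    // copy the last token's logits out of the graph's final node
    static void get_last_logits(struct ggml_tensor * logits,   // shape [n_vocab, N]
                                int n_vocab, int N,
                                std::vector<float> & out) {
        out.resize(n_vocab);
        ggml_backend_tensor_get(logits, out.data(),
                                (size_t)n_vocab*(N - 1)*sizeof(float),   // byte offset of the last row
                                sizeof(float)*n_vocab);
    }
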
@@ -759,7 +827,7 @@ int main(int argc, char ** argv) {
{
const int64_t t_start_us = ggml_time_us();

if (!gpt2_model_load(params.model, model, vocab)) {
if (!gpt2_model_load(params.model, model, vocab, params.n_gpu_layers)) {
fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
return 1;
}
@@ -770,25 +838,27 @@ int main(int argc, char ** argv) {
}

// keep this buffer alive while evaluating the model
std::vector<uint8_t> compute_buffer;
ggml_backend_buffer_t buf_compute;

struct ggml_allocr * allocr = NULL;
// allocate the compute buffer
{
allocr = ggml_allocr_new_measure(GGML_MEM_ALIGN);
// alignment required by the backend
size_t align = ggml_backend_get_alignment(model.backend);
allocr = ggml_allocr_new_measure(align);

// create the worst case graph for memory usage estimation
int n_tokens = std::min(model.hparams.n_ctx, params.n_batch);
int n_past = model.hparams.n_ctx - n_tokens;
struct ggml_cgraph * gf = gpt2_graph(model, allocr, n_past, std::vector<gpt_vocab::id>(n_tokens, 0));

// compute the required memory
size_t mem_size = ggml_allocr_alloc_graph(allocr, gf) + GGML_MEM_ALIGN;
size_t mem_size = ggml_allocr_alloc_graph(allocr, gf);

// recreate the allocator with the required memory
ggml_allocr_free(allocr);
compute_buffer.resize(mem_size);
allocr = ggml_allocr_new(compute_buffer.data(), mem_size, GGML_MEM_ALIGN);
buf_compute = ggml_backend_alloc_buffer(model.backend, mem_size);
allocr = ggml_allocr_new_from_buffer(buf_compute);

fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0/1024.0);
}
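
The compute buffer sizing keeps the measure-then-allocate pattern, but the scratch memory now comes from the backend (with the backend's required alignment) instead of a host std::vector. A condensed sketch of that two-pass setup, assuming a build_graph callback that stands in for gpt2_graph and builds the worst-case graph (helper name and callback shape are illustrative):

    #include <functional>

    #include "ggml/ggml.h"
    #include "ggml/ggml-alloc.h"
    #include "ggml/ggml-backend.h"

    // returns a ready-to-use allocator backed by backend memory; the caller must
    // keep (and later free) the returned buffer while evaluating graphs
    static ggml_allocr * create_compute_allocr(ggml_backend_t backend,
                                               const std::function<ggml_cgraph *(ggml_allocr *)> & build_graph,
                                               ggml_backend_buffer_t & buf_compute) {
        // pass 1: measure with the alignment the backend requires
        const size_t align = ggml_backend_get_alignment(backend);
        ggml_allocr * allocr = ggml_allocr_new_measure(align);

        struct ggml_cgraph * gf = build_graph(allocr);              // worst-case graph
        const size_t mem_size   = ggml_allocr_alloc_graph(allocr, gf);

        ggml_allocr_free(allocr);

        // pass 2: allocate the real buffer on the backend and wrap it in an allocator
        buf_compute = ggml_backend_alloc_buffer(backend, mem_size);
        return ggml_allocr_new_from_buffer(buf_compute);
    }
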
@@ -888,5 +958,10 @@ int main(int argc, char ** argv) {

ggml_free(model.ctx);

ggml_backend_buffer_free(model.buffer_w);
ggml_backend_buffer_free(model.buffer_kv);
ggml_backend_buffer_free(buf_compute);
ggml_backend_free(model.backend);

return 0;
}
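
Cleanup now has three layers to release: the metadata context, each backend buffer, and finally the backend itself. A short sketch of that order, with an illustrative helper name and the same buffer names as the example:

    #include "ggml/ggml.h"
    #include "ggml/ggml-backend.h"

    // free in dependency order: tensor metadata, then the buffers holding the
    // tensor data, then the backend that owns those buffers
    static void free_model(struct ggml_context * ctx,
                           ggml_backend_buffer_t buffer_w,
                           ggml_backend_buffer_t buffer_kv,
                           ggml_backend_buffer_t buf_compute,
                           ggml_backend_t backend) {
        ggml_free(ctx);

        ggml_backend_buffer_free(buffer_w);
        ggml_backend_buffer_free(buffer_kv);
        ggml_backend_buffer_free(buf_compute);

        ggml_backend_free(backend);
    }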