From d694dacd791af360146e4326551a6f5fd8de52f4 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 15 Oct 2023 17:17:18 +0200 Subject: [PATCH] cleanup and refactor *again* --- examples/llava/llava-cli.cpp | 39 ++++++------ ggml-metal.m | 1 + llava/clip.cpp | 24 ++++---- llava/clip.h | 8 +-- llava/llava-utils.h | 16 ++--- llava/llava.cpp | 111 +++++++++++++++++++++++++++++------ llava/llava.h | 17 ++++-- 7 files changed, 151 insertions(+), 65 deletions(-) diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 336b674a743e68..8ae5c41247f5e8 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,46 +20,47 @@ static void show_additional_info(int /*argc*/, char ** argv) { printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); } -static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) { +static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) { + // load and preprocess the image - clip_image_u8 * img = make_clip_image_u8(); + llava_image_embed * embed = NULL; auto prompt = params->prompt; if (prompt_contains_image(prompt)) { if (!params->image.empty()) { printf("using base64 encoded image instead of command line image path\n"); } - if (!clip_image_load_from_prompt(prompt, img)) { + embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt); + if (!embed) { fprintf(stderr, "%s: can't load image from prompt\n", __func__); - return false; + return NULL; } params->prompt = remove_image_from_prompt(prompt); } else { - if (!clip_image_load_from_file(params->image.c_str(), img)) { + embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str()); + if (!embed) { fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str()); - return false; + return NULL; } } - bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, img, image_embd, n_img_pos); - if (!image_embed_result) { - clip_image_u8_free(img); - fprintf(stderr, "%s: coulnd't embed the image\n", __func__); - return false; - } - return true; + return embed; } -static void process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) { +static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const char * prompt) { int n_past = 0; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; // llava chat format is "USER: \n\nASSISTANT:" // GG: are we sure that the should be a trailing whitespace at the end of this string? + printf("evaluating system prompt\n"); eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past); - llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past); + printf("evaluating image embed\n"); + llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past); + printf("evaluating prompt\n"); eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past); + printf("awaiting response\n"); // generate the response @@ -153,16 +154,14 @@ int main(int argc, char ** argv) { return 1; } - float * image_embd; - int n_image_pos; - load_image(ctx_llava, ¶ms, &image_embd, &n_image_pos); + auto image_embed = load_image(ctx_llava, ¶ms); // process the prompt - process_prompt(ctx_llava, image_embd, n_image_pos, ¶ms, params.prompt.c_str()); + process_prompt(ctx_llava, image_embed, ¶ms, params.prompt.c_str()); llama_print_timings(ctx_llava->ctx_llama); - free(image_embd); + llava_image_embed_free(image_embed); llava_free(ctx_llava); return 0; } diff --git a/ggml-metal.m b/ggml-metal.m index 87fa172161405a..7090e79b59af1a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -189,6 +189,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){ { NSBundle * bundle = nil; #ifdef SWIFT_PACKAGE + print("would use SWIFTPM_MODULE_BUNDLE"); bundle = SWIFTPM_MODULE_BUNDLE; #else bundle = [NSBundle bundleForClass:[GGMLMetalClass class]]; diff --git a/llava/clip.cpp b/llava/clip.cpp index a2531de73d4a5d..ffb38b81cc8ca1 100644 --- a/llava/clip.cpp +++ b/llava/clip.cpp @@ -678,7 +678,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) { return new_clip; } -clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); } +clip_image_u8 * make_clip_image_u8() { + auto img = new clip_image_u8(); + return img; +} clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); } void clip_image_u8_free(clip_image_u8 * img) { if (img->data) { delete[] img->data; } delete img; } @@ -692,11 +695,11 @@ static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_ memcpy(img->data, data, img->size); } -bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img) { +bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { int nx, ny, nc; - auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); + auto data = stbi_load(fname, &nx, &ny, &nc, 3); if (!data) { - fprintf(stderr, "%s: failed to decode image bytes\n", __func__); + fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname); return false; } build_clip_img_from_data(data, nx, ny, img); @@ -704,11 +707,11 @@ bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length return true; } -bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { +bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) { int nx, ny, nc; - auto data = stbi_load(fname, &nx, &ny, &nc, 3); + auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3); if (!data) { - fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname); + fprintf(stderr, "%s: failed to decode image bytes\n", __func__); return false; } build_clip_img_from_data(data, nx, ny, img); @@ -716,7 +719,6 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) { return true; } - // normalize: x = (x - mean) / std // TODO: implement bicubic interpolation instead of linear. bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) { @@ -1065,16 +1067,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i return true; } -int clip_n_mmproj_embd(struct clip_ctx * ctx) { +int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->vision_model.mm_2_b->ne[0]; } -int clip_n_patches(struct clip_ctx * ctx) { +int clip_n_patches(const struct clip_ctx * ctx) { auto & params = ctx->vision_model.hparams; return (params.image_size / params.patch_size) * (params.image_size / params.patch_size); } -size_t clip_embd_nbytes(struct clip_ctx * ctx) { +size_t clip_embd_nbytes(const struct clip_ctx * ctx) { return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float); } diff --git a/llava/clip.h b/llava/clip.h index 2185f67b96ba9e..a8022c52453d24 100644 --- a/llava/clip.h +++ b/llava/clip.h @@ -25,9 +25,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity); void clip_free(struct clip_ctx * ctx); -size_t clip_embd_nbytes(struct clip_ctx * ctx); -int clip_n_patches(struct clip_ctx * ctx); -int clip_n_mmproj_embd(struct clip_ctx * ctx); +size_t clip_embd_nbytes(const struct clip_ctx * ctx); +int clip_n_patches(const struct clip_ctx * ctx); +int clip_n_mmproj_embd(const struct clip_ctx * ctx); // RGB uint8 image struct clip_image_u8 { @@ -62,7 +62,7 @@ LLAMA_API void clip_image_u8_free(clip_image_u8 * img); LLAMA_API void clip_image_f32_free(clip_image_f32 * img); LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); /** interpret bytes as an image file with length bytes_length, and use the result to populate img */ -LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img); +LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img); bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square); bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec); diff --git a/llava/llava-utils.h b/llava/llava-utils.h index 53beefd2614672..38bf5729642ed9 100644 --- a/llava/llava-utils.h +++ b/llava/llava-utils.h @@ -4,6 +4,7 @@ #include "common.h" #include "llama.h" +#include "llava.h" #include "base64.hpp" @@ -143,12 +144,12 @@ inline bool prompt_contains_image(const std::string& prompt) { } // replaces the base64 image tag in the prompt with `replacement` -inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8 * img) { +inline llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) { size_t img_base64_str_start, img_base64_str_end; find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end); if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) { fprintf(stderr, "%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); - return false; + return NULL; } auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN); @@ -157,16 +158,15 @@ inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8 auto required_bytes = base64::required_encode_size(base64_str.size()); auto img_bytes = std::vector(required_bytes); - auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin()); - size_t img_bytes_len = img_bytes_end - img_bytes.begin(); + base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin()); - auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img); - if (!img_loaded_ok) { + auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size()); + if (!embed) { fprintf(stderr, "%s: could not load image from base64 string.\n", __func__); - return false; + return NULL; } - return true; + return embed; } inline std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") { diff --git a/llava/llava.cpp b/llava/llava.cpp index 18cbc76aa496b4..6bafee5a0abcf0 100644 --- a/llava/llava.cpp +++ b/llava/llava.cpp @@ -10,7 +10,7 @@ #include "base64.hpp" -static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) { +static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) { clip_image_f32 * img_res = make_clip_image_f32(); if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) { fprintf(stderr, "%s: unable to preprocess image\n", __func__); @@ -19,7 +19,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli } *n_img_pos = clip_n_patches(ctx_clip); - *n_image_embd = clip_n_mmproj_embd(ctx_clip); const int64_t t_img_enc_start_us = ggml_time_us(); bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd); @@ -39,7 +38,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli return true; } -bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { +bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { + // make sure that the correct mmproj was used, i.e., compare apples to apples + int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + auto n_image_embd = clip_n_mmproj_embd(ctx_clip); + if (n_image_embd != n_llama_embd) { + printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); + return false; + } + return true; +} + +static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) { float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)); if (!image_embd) { @@ -49,20 +59,11 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, } int n_img_pos; - int n_image_embd; - if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_image_embd, &n_img_pos)) { + if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) { fprintf(stderr, "%s: cannot encode image, aborting\n", __func__); free(image_embd); return false; } - // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); - if (n_image_embd != n_llama_embd) { - printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); - free(image_embd); - return false; - } - *image_embd_out = image_embd; *n_img_pos_out = n_img_pos; @@ -71,15 +72,15 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, -bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) { +bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { int n_embd = llama_n_embd(llama_get_model(ctx_llama)); - for (int i = 0; i < n_image_pos; i += n_batch) { - int n_eval = n_image_pos - i; + for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { + int n_eval = image_embed->n_image_pos - i; if (n_eval > n_batch) { n_eval = n_batch; } - llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, }; + llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, }; if (llama_decode(ctx_llama, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; @@ -88,3 +89,79 @@ bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_ } return true; } + + +LLAMA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length) +{ + clip_image_u8 * img = make_clip_image_u8(); + if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) { + clip_image_u8_free(img); + fprintf(stderr, "%s: can't load image from bytes, is it a valid image?", __func__); + return NULL; + } + + float* image_embed = NULL; + int n_image_pos = 0; + bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos); + if (!image_embed_result) { + clip_image_u8_free(img); + fprintf(stderr, "%s: coulnd't embed the image\n", __func__); + return NULL; + } + + clip_image_u8_free(img); + auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed)); + result->embed = image_embed; + result->n_image_pos = n_image_pos; + return result; +} + +static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut) +{ + auto file = fopen(path, "rb"); + if (file == NULL) { + fprintf(stderr, "%s: can't read file %s\n", __func__, path); + return false; + } + + fseek(file, 0, SEEK_END); + auto fileSize = ftell(file); + fseek(file, 0, SEEK_SET); + + auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data + if (buffer == NULL) { + fprintf(stderr, "%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path); + perror("Memory allocation error"); + fclose(file); + return false; + } + fread(buffer, 1, fileSize, file); // Read the file into the buffer + fclose(file); // Close the file + + *bytesOut = buffer; + *sizeOut = fileSize; + return true; + +} + +LLAMA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path) +{ + unsigned char* image_bytes; + long image_bytes_length; + auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length); + if (!loaded) { + fprintf(stderr, "%s: failed to load %s\n", __func__, image_path); + return NULL; + } + + auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length); + free(image_bytes); + + return embed; +} + + +LLAMA_API void llava_image_embed_free(struct llava_image_embed * embed) { + free(embed->embed); + free(embed); +} diff --git a/llava/llava.h b/llava/llava.h index de3875e039f19c..aa9ea1a4f15975 100644 --- a/llava/llava.h +++ b/llava/llava.h @@ -10,13 +10,20 @@ struct clip_ctx; extern "C" { #endif -/** using ctx_clip, build a llava image embedding from the passed-in image `img` (see clip.h for methods to load img). - * result is returned as image_embd_out, size n_image_pos_out */ -LLAMA_API bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out); +struct llava_image_embed { + float * embed; + int n_image_pos; +}; -/** write the image represented by image_embd (size n_image_pos) into the llama context with batch size n_batch, +LLAMA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip); + +LLAMA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length); +LLAMA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path); +LLAMA_API void llava_image_embed_free(struct llava_image_embed * embed); + +/** write the image represented by embed into the llama context with batch size n_batch, * starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */ -LLAMA_API bool llava_eval_image_embd(struct llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past); +LLAMA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past); #ifdef __cplusplus