Skip to content

Commit

Permalink
cleanup and refactor *again*
Browse files Browse the repository at this point in the history
  • Loading branch information
damian0815 committed Oct 15, 2023
1 parent e3261ff commit d64891b
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 65 deletions.
39 changes: 19 additions & 20 deletions examples/llava/llava-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,46 +20,47 @@ static void show_additional_info(int /*argc*/, char ** argv) {
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
}

static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) {
static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {

// load and preprocess the image
clip_image_u8 * img = make_clip_image_u8();
llava_image_embed * embed = NULL;
auto prompt = params->prompt;
if (prompt_contains_image(prompt)) {
if (!params->image.empty()) {
printf("using base64 encoded image instead of command line image path\n");
}
if (!clip_image_load_from_prompt(prompt, img)) {
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
if (!embed) {
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
return false;
return NULL;
}
params->prompt = remove_image_from_prompt(prompt);
} else {
if (!clip_image_load_from_file(params->image.c_str(), img)) {
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
if (!embed) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
return false;
return NULL;
}
}
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, img, image_embd, n_img_pos);
if (!image_embed_result) {
clip_image_u8_free(img);
fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
return false;
}

return true;
return embed;
}

static void process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const char * prompt) {
int n_past = 0;

const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;

// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
// GG: are we sure that the should be a trailing whitespace at the end of this string?
printf("evaluating system prompt\n");
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
printf("evaluating image embed\n");
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
printf("evaluating prompt\n");
eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
printf("awaiting response\n");

// generate the response

Expand Down Expand Up @@ -153,16 +154,14 @@ int main(int argc, char ** argv) {
return 1;
}

float * image_embd;
int n_image_pos;
load_image(ctx_llava, &params, &image_embd, &n_image_pos);
auto image_embed = load_image(ctx_llava, &params);

// process the prompt
process_prompt(ctx_llava, image_embd, n_image_pos, &params, params.prompt.c_str());
process_prompt(ctx_llava, image_embed, &params, params.prompt.c_str());

llama_print_timings(ctx_llava->ctx_llama);

free(image_embd);
llava_image_embed_free(image_embed);
llava_free(ctx_llava);
return 0;
}
24 changes: 13 additions & 11 deletions llava/clip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
return new_clip;
}

clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); }
clip_image_u8 * make_clip_image_u8() {
auto img = new clip_image_u8();
return img;
}
clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); }

void clip_image_u8_free(clip_image_u8 * img) { if (img->data) { delete[] img->data; } delete img; }
Expand All @@ -692,31 +695,30 @@ static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_
memcpy(img->data, data, img->size);
}

bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img) {
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
int nx, ny, nc;
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
if (!data) {
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
stbi_image_free(data);
return true;
}

bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) {
int nx, ny, nc;
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
if (!data) {
fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname);
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
stbi_image_free(data);
return true;
}


// normalize: x = (x - mean) / std
// TODO: implement bicubic interpolation instead of linear.
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {
Expand Down Expand Up @@ -1065,16 +1067,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
return true;
}

int clip_n_mmproj_embd(struct clip_ctx * ctx) {
int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
return ctx->vision_model.mm_2_b->ne[0];
}

int clip_n_patches(struct clip_ctx * ctx) {
int clip_n_patches(const struct clip_ctx * ctx) {
auto & params = ctx->vision_model.hparams;

return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
}

size_t clip_embd_nbytes(struct clip_ctx * ctx) {
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
}
8 changes: 4 additions & 4 deletions llava/clip.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);

void clip_free(struct clip_ctx * ctx);

size_t clip_embd_nbytes(struct clip_ctx * ctx);
int clip_n_patches(struct clip_ctx * ctx);
int clip_n_mmproj_embd(struct clip_ctx * ctx);
size_t clip_embd_nbytes(const struct clip_ctx * ctx);
int clip_n_patches(const struct clip_ctx * ctx);
int clip_n_mmproj_embd(const struct clip_ctx * ctx);

// RGB uint8 image
struct clip_image_u8 {
Expand Down Expand Up @@ -62,7 +62,7 @@ LLAMA_API void clip_image_u8_free(clip_image_u8 * img);
LLAMA_API void clip_image_f32_free(clip_image_f32 * img);
LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img);
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);

bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
Expand Down
16 changes: 8 additions & 8 deletions llava/llava-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "common.h"
#include "llama.h"
#include "llava.h"

#include "base64.hpp"

Expand Down Expand Up @@ -143,12 +144,12 @@ inline bool prompt_contains_image(const std::string& prompt) {
}

// replaces the base64 image tag in the prompt with `replacement`
inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8 * img) {
inline llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
size_t img_base64_str_start, img_base64_str_end;
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
return false;
return NULL;
}

auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
Expand All @@ -157,16 +158,15 @@ inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8

auto required_bytes = base64::required_encode_size(base64_str.size());
auto img_bytes = std::vector<unsigned char>(required_bytes);
auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
size_t img_bytes_len = img_bytes_end - img_bytes.begin();
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());

auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
if (!img_loaded_ok) {
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
if (!embed) {
fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
return false;
return NULL;
}

return true;
return embed;
}

inline std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
Expand Down
111 changes: 94 additions & 17 deletions llava/llava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include "base64.hpp"

static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) {
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
clip_image_f32 * img_res = make_clip_image_f32();
if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
Expand All @@ -19,7 +19,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
}

*n_img_pos = clip_n_patches(ctx_clip);
*n_image_embd = clip_n_mmproj_embd(ctx_clip);

const int64_t t_img_enc_start_us = ggml_time_us();
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
Expand All @@ -39,7 +38,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
return true;
}

bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
// make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
if (n_image_embd != n_llama_embd) {
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
return false;
}
return true;
}

static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {

float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) {
Expand All @@ -49,20 +59,11 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
}

int n_img_pos;
int n_image_embd;
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_image_embd, &n_img_pos)) {
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
free(image_embd);
return false;
}
// make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
if (n_image_embd != n_llama_embd) {
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
free(image_embd);
return false;
}

*image_embd_out = image_embd;
*n_img_pos_out = n_img_pos;

Expand All @@ -71,15 +72,15 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,



bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) {
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));

for (int i = 0; i < n_image_pos; i += n_batch) {
int n_eval = n_image_pos - i;
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
int n_eval = image_embed->n_image_pos - i;
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
Expand All @@ -88,3 +89,79 @@ bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_
}
return true;
}


LLAMA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length)
{
clip_image_u8 * img = make_clip_image_u8();
if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
clip_image_u8_free(img);
fprintf(stderr, "%s: can't load image from bytes, is it a valid image?", __func__);
return NULL;
}

float* image_embed = NULL;
int n_image_pos = 0;
bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
if (!image_embed_result) {
clip_image_u8_free(img);
fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
return NULL;
}

clip_image_u8_free(img);
auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
result->embed = image_embed;
result->n_image_pos = n_image_pos;
return result;
}

static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut)
{
auto file = fopen(path, "rb");
if (file == NULL) {
fprintf(stderr, "%s: can't read file %s\n", __func__, path);
return false;
}

fseek(file, 0, SEEK_END);
auto fileSize = ftell(file);
fseek(file, 0, SEEK_SET);

auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
if (buffer == NULL) {
fprintf(stderr, "%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
perror("Memory allocation error");
fclose(file);
return false;
}
fread(buffer, 1, fileSize, file); // Read the file into the buffer
fclose(file); // Close the file

*bytesOut = buffer;
*sizeOut = fileSize;
return true;

}

LLAMA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path)
{
unsigned char* image_bytes;
long image_bytes_length;
auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
if (!loaded) {
fprintf(stderr, "%s: failed to load %s\n", __func__, image_path);
return NULL;
}

auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length);
free(image_bytes);

return embed;
}


LLAMA_API void llava_image_embed_free(struct llava_image_embed * embed) {
free(embed->embed);
free(embed);
}
Loading

0 comments on commit d64891b

Please sign in to comment.