gguf : deduplicate (#2629)
* gguf : better type names

* dedup : CPU + Metal is working

* ggml : fix warnings about unused results

* llama.cpp : fix line feed and compiler warning

* llama : fix strncpy warning + note token_to_str does not write null

* llama : restore the original load/save session implementation

Will migrate this to GGUF in the future

* convert-llama-h5-to-gguf.py : support alt ctx param name

* ggml : assert when using ggml_mul with non-F32 src1

* examples : dedup simple

---------

Co-authored-by: klosax <131523366+klosax@users.noreply.github.com>
ggerganov and klosax authored Aug 16, 2023
1 parent 758ff1b commit 88b5769
Showing 21 changed files with 1,776 additions and 7,544 deletions.
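One commit-message bullet above fixes a strncpy warning and notes that token_to_str does not write a null terminator. As background on that pitfall, here is a minimal standalone C++ sketch (illustrative only, not code from this commit): strncpy does not append '\0' when the source fills the destination buffer, so the caller has to terminate it explicitly.

// Illustrative only -- not part of this commit.
// strncpy() writes no terminating '\0' when the source fills the destination,
// so the caller must terminate the buffer before using it as a C string.
#include <cstdio>
#include <cstring>

int main() {
    const char * src = "hello world";   // 11 characters + '\0'
    char dst[8];

    strncpy(dst, src, sizeof(dst));     // copies 8 bytes, no '\0' written
    dst[sizeof(dst) - 1] = '\0';        // explicit termination by the caller

    printf("%s\n", dst);                // prints "hello w"
    return 0;
}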
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -529,7 +529,6 @@ endif()
add_library(llama
llama.cpp
llama.h
llama-util.h
)

target_include_directories(llama PUBLIC .)
12 changes: 3 additions & 9 deletions Makefile
@@ -1,5 +1,5 @@
# Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test gguf gguf-llama-simple gptneox-main
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test gguf gptneox-main

# Binaries only useful for tests
TEST_TARGETS = tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0
@@ -329,10 +329,7 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h

OBJS += ggml-alloc.o

llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
$(CXX) $(CXXFLAGS) -c $< -o $@

gguf-llama.o: gguf-llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h gguf-llama.h
llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@

common.o: examples/common.cpp examples/common.h
@@ -388,10 +385,7 @@ $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-in
embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput

gguf: examples/gguf/gguf.cpp build-info.h ggml.o gguf-llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

gguf-llama-simple: examples/gguf/gguf-llama-simple.cpp build-info.h ggml.o gguf-llama.o common.o $(OBJS)
gguf: examples/gguf/gguf.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)

gptneox-main: gptneox-main.cpp ggml.o $(OBJS)
7 changes: 3 additions & 4 deletions convert-llama-7b-pth-to-gguf.py
@@ -132,7 +132,7 @@ def count_model_parts(dir_model: str) -> int:
toktype = 1 # default to normal token type
if tokenizer.is_unknown(i): toktype = 2
if tokenizer.is_control(i): toktype = 3

# TODO: How to determinate if a token is user defined?
# ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
# if tokenizer.is_user_defined(i): toktype = 4
@@ -223,7 +223,7 @@ def count_model_parts(dir_model: str) -> int:
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if ftype == 0 and data.dtype == np.float16:
@@ -261,7 +261,6 @@ def count_model_parts(dir_model: str) -> int:
for name in model_part.keys():
data = model_part[name]


old_dtype = data.dtype

# we don't need these
@@ -284,7 +283,7 @@ def count_model_parts(dir_model: str) -> int:
sys.exit()

n_dims = len(data.shape)
data_dtype = data.dtype
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if ftype == 0 and data.dtype == np.float16:
13 changes: 11 additions & 2 deletions convert-llama-h5-to-gguf.py
@@ -95,12 +95,21 @@ def count_model_parts(dir_model: str) -> int:
else:
hf_repo=""

if "max_sequence_length" in hparams:
ctx_length = hparams["max_sequence_length"]
elif "max_position_embeddings" in hparams:
ctx_length = hparams["max_position_embeddings"]
else:
print("gguf: can not find ctx length parameter.")
sys.exit()


gguf_writer.add_architecture(llm_arch)
gguf_writer.add_name(last_dir)
gguf_writer.add_file_type("All tensors F32" if ftype == 0 else "Most tensors F16, some F32")
gguf_writer.add_source_hf_repo(hf_repo)
gguf_writer.add_tensor_data_layout(llm_arch, "Meta AI original pth")
gguf_writer.add_context_length(llm_arch, hparams["max_position_embeddings"])
gguf_writer.add_context_length(llm_arch, ctx_length)
gguf_writer.add_embedding_length(llm_arch, hparams["hidden_size"])
gguf_writer.add_block_count(llm_arch, block_count)
gguf_writer.add_feed_forward_length(llm_arch, hparams["intermediate_size"])
@@ -318,7 +327,7 @@ def count_model_parts(dir_model: str) -> int:
if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
data = data.astype(np.float16)

print( name + ", shape " + str(len(data.shape)) + ", " + str(old_dtype) + " --> " + str(data.dtype))
print(name + ", shape " + str(len(data.shape)) + ", " + str(old_dtype) + " --> " + str(data.dtype))

gguf_writer.write_tensor_to_file(data)

16 changes: 0 additions & 16 deletions examples/common.cpp
@@ -170,18 +170,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.n_ctx = std::stoi(argv[i]);
} else if (arg == "-gqa" || arg == "--gqa") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.n_gqa = std::stoi(argv[i]);
} else if (arg == "-eps" || arg == "--rms-norm-eps") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") {
if (++i >= argc) {
invalid_param = true;
@@ -546,8 +534,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict);
fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa);
fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps);
fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
@@ -638,8 +624,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

lparams.n_ctx = params.n_ctx;
lparams.n_batch = params.n_batch;
lparams.n_gqa = params.n_gqa;
lparams.rms_norm_eps = params.rms_norm_eps;
lparams.n_gpu_layers = params.n_gpu_layers;
lparams.main_gpu = params.main_gpu;
lparams.tensor_split = params.tensor_split;
2 changes: 0 additions & 2 deletions examples/common.h
@@ -23,14 +23,12 @@ struct gpt_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 512; // context size
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams)
int32_t n_keep = 0; // number of tokens to keep from initial prompt
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon
float rope_freq_base = 10000.0f; // RoPE base frequency
float rope_freq_scale = 1.0f; // RoPE frequency scaling factor

144 changes: 75 additions & 69 deletions examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp
@@ -1,5 +1,6 @@
#include "ggml.h"
#include "llama.h"

#include <unordered_map>
#include <vector>
#include <cassert>
@@ -502,7 +503,7 @@ bool is_ggml_file(const char *filename) {
return false;
}
uint32_t magic = file.read_u32();
return magic == LLAMA_FILE_MAGIC;
return magic == GGUF_MAGIC;
}

void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) {
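The hunk above switches is_ggml_file from the old LLAMA_FILE_MAGIC to GGUF_MAGIC. Below is a minimal standalone sketch of the same kind of check, assuming only that a GGUF file begins with the ASCII bytes "GGUF"; the helper name has_gguf_magic and the use of plain fread instead of the repository's file.read_u32() are illustrative, not taken from this commit.

// Illustrative only -- reports whether a file starts with the 4-byte GGUF magic.
#include <cstdio>
#include <cstring>

static bool has_gguf_magic(const char * filename) {
    FILE * f = fopen(filename, "rb");
    if (f == NULL) {
        return false;
    }
    char magic[4] = {0};
    const size_t n = fread(magic, 1, sizeof(magic), f);
    fclose(f);
    // comparing raw bytes sidesteps endianness questions about the numeric constant
    return n == sizeof(magic) && memcmp(magic, "GGUF", 4) == 0;
}

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <file>\n", argv[0]);
        return 1;
    }
    printf("%s: %s\n", argv[1], has_gguf_magic(argv[1]) ? "GGUF magic found" : "no GGUF magic");
    return 0;
}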
@@ -590,75 +591,80 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod
if (file.fp == NULL) {
return;
}
// write_magic
file.write_u32(LLAMA_FILE_MAGIC); // magic
file.write_u32(LLAMA_FILE_VERSION); // version
// write_hparams
file.write_u32(model->hparams.n_vocab);
file.write_u32(model->hparams.n_embd);
file.write_u32(model->hparams.n_mult);
file.write_u32(model->hparams.n_head);
file.write_u32(model->hparams.n_layer);
file.write_u32(model->hparams.n_rot);
file.write_u32(LLAMA_FTYPE_ALL_F32);

// write_vocab - for now we are just writing the existing BPE voc. assuming karpathy's vocabulary is the same. idk.
uint32_t n_vocab = model->hparams.n_vocab;
for (uint32_t i = 0; i < n_vocab; i++) {
const auto & token_score = vocab->id_to_token.at(i);
file.write_u32((uint32_t) token_score.tok.size());
file.write_raw(token_score.tok.data(), token_score.tok.size());
file.write_raw(&token_score.score, sizeof(token_score.score));
}

// stuff AK weights into GG weights one by one.
// w->token_embedding_table -> model->tok_embeddings
// float* -> struct ggml_tensor
stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table);

stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
//print_row(model->norm, 0);

// for rms-att-weight
int row_length = model->hparams.n_embd;
const auto & hparams = model->hparams;
//int n_ff = model->hparams.n_embd;
int n_ff = get_n_ff(&hparams);

for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
auto & layer = model->layers[i];
// 1d
stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]);
stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]);

// from 3d matrix layer x dim x dim to 2d matrix dim x dim
stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]);
stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]);
stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]);
stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]);

stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]);
}
// write tensors
write_tensor(&file, model->tok_embeddings);
write_tensor(&file, model->norm);
write_tensor(&file, model->output); // ?
for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
auto & layer = model->layers[i];

write_tensor(&file, layer.attention_norm);
write_tensor(&file, layer.wq);
write_tensor(&file, layer.wk);
write_tensor(&file, layer.wv);
write_tensor(&file, layer.wo);
write_tensor(&file, layer.ffn_norm);
write_tensor(&file, layer.w1);
write_tensor(&file, layer.w2);
write_tensor(&file, layer.w3);
}
#pragma message("TODO: implement file saving using gguf")
(void) vocab;
(void) model;
(void) w;
// // write_magic
// file.write_u32(LLAMA_FILE_MAGIC); // magic
// file.write_u32(LLAMA_FILE_VERSION); // version
// // write_hparams
// file.write_u32(model->hparams.n_vocab);
// file.write_u32(model->hparams.n_embd);
// file.write_u32(model->hparams.n_mult);
// file.write_u32(model->hparams.n_head);
// file.write_u32(model->hparams.n_layer);
// file.write_u32(model->hparams.n_rot);
// file.write_u32(LLAMA_FTYPE_ALL_F32);
//
// // write_vocab - for now we are just writing the existing BPE voc. assuming karpathy's vocabulary is the same. idk.
// uint32_t n_vocab = model->hparams.n_vocab;
// for (uint32_t i = 0; i < n_vocab; i++) {
// const auto & token_score = vocab->id_to_token.at(i);
// file.write_u32((uint32_t) token_score.tok.size());
// file.write_raw(token_score.tok.data(), token_score.tok.size());
// file.write_raw(&token_score.score, sizeof(token_score.score));
// }
//
// // stuff AK weights into GG weights one by one.
// // w->token_embedding_table -> model->tok_embeddings
// // float* -> struct ggml_tensor
// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table);
// stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table);
//
// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight);
// //print_row(model->norm, 0);
//
// // for rms-att-weight
// int row_length = model->hparams.n_embd;
// const auto & hparams = model->hparams;
// //int n_ff = model->hparams.n_embd;
// int n_ff = get_n_ff(&hparams);
//
// for (uint32_t i = 0; i < model->hparams.n_layer; ++i){
// auto & layer = model->layers[i];
// // 1d
// stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]);
// stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]);
//
// // from 3d matrix layer x dim x dim to 2d matrix dim x dim
// stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]);
// stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]);
// stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]);
// stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]);
//
// stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]);
// stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]);
// stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]);
// }
// // write tensors
// write_tensor(&file, model->tok_embeddings);
// write_tensor(&file, model->norm);
// write_tensor(&file, model->output); // ?
// for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
// auto & layer = model->layers[i];
//
// write_tensor(&file, layer.attention_norm);
// write_tensor(&file, layer.wq);
// write_tensor(&file, layer.wk);
// write_tensor(&file, layer.wv);
// write_tensor(&file, layer.wo);
// write_tensor(&file, layer.ffn_norm);
// write_tensor(&file, layer.w1);
// write_tensor(&file, layer.w2);
// write_tensor(&file, layer.w3);
// }
}

struct train_params get_default_train_params() {
(diffs for the remaining changed files are not shown)
