diff --git a/common/common.cpp b/common/common.cpp
index ed09fc27df711..d0f9679a2d84a 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -315,6 +315,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -667,6 +673,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, "  -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, "                        draft model for speculative sampling (default: %s)\n", params.model_draft.c_str());
     fprintf(stdout, "  -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, "                        path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -1060,6 +1068,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
diff --git a/common/common.h b/common/common.h
index 5a379688ee529..a7e9d3b6d2a0d 100644
--- a/common/common.h
+++ b/common/common.h
@@ -63,6 +63,7 @@ struct gpt_params {
     float   cfg_scale     = 1.f; // How strong is guidance
 
     std::string model             = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft       = "";                              // draft model for speculative sampling
     std::string model_alias       = "unknown"; // model alias
     std::string prompt            = "";
     std::string path_prompt_cache = "";  // path to file for saving/loading prompt eval state
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index d1f547e0b337c..608a61ed3af2c 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -19,17 +19,18 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.model_draft.empty()) {
+        fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
+        return 1;
+    }
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
     LOG_TEE("Log start\n");
     log_dump_cmdline(argc, argv);
 #endif // LOG_DISABLE_LOGS
 
-    // TODO: tmp hardcoded
-    const std::string fname_draft = "../models/codellama-7b/ggml-model-q4_1.gguf";
-
-    // init LLM
-
+    // init llama.cpp
     llama_backend_init(params.numa);
 
     llama_model * model_tgt = NULL;
@@ -43,11 +44,10 @@ int main(int argc, char ** argv) {
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
 
     // load the draft model
-    params.model = fname_draft;
+    params.model = params.model_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
     // tokenize the prompt
-
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
 
@@ -77,7 +77,8 @@ int main(int argc, char ** argv) {
     const int n_vocab = llama_n_vocab(ctx_tgt);
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
 
-    const int n_draft = 12;
+    // how many tokens to draft each time
+    const int n_draft = 8;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -237,8 +238,8 @@ int main(int argc, char ** argv) {
         for (int i = 0; i < n_draft; ++i) {
            float * logits = llama_get_logits(ctx_dft);
 
-            int   best_id    = -1;
-            float best_logit = -1e30f;
+            int   best_id     = -1;
+            float best_logit  = -1e30f;
             float best_logit2 = -1e30f;
             for (int j = 0; j < n_vocab; ++j) {
                 if (logits[j] > best_logit) {
@@ -275,7 +276,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("generated %d tokens in %.3f seconds, speed: %.3f t/s\n", n_predict, (t_gen_end - t_gen_start) / 1e6f, n_predict / ((t_gen_end - t_gen_start) / 1e6f));
 
     // TODO: make sure these numbers are computed correctly
-    LOG_TEE("\n\n");
+    LOG_TEE("\n");
     LOG_TEE("n_draft   = %d\n", n_draft);
     LOG_TEE("n_predict = %d\n", n_predict);
     LOG_TEE("n_drafted = %d\n", n_drafted);
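
Example invocation with the new flag (illustrative only; the binary location and model filenames are placeholders and not taken from this patch -- any compatible target/draft GGUF pair should work):

    ./speculative -m models/target-model.gguf -md models/draft-model.gguf -p "Once upon a time" -n 128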