speculative : add --model-draft CLI arg
ggerganov committed Aug 31, 2023
1 parent b7fa7e7 commit ec5086f
Showing 3 changed files with 22 additions and 11 deletions.
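The new -md / --model-draft flag lets the speculative example take the draft model path from the command line instead of the path previously hardcoded in speculative.cpp, and the example now errors out if the flag is left unset. A hypothetical invocation (assuming the example binary is built as ./speculative; the target model path and prompt are placeholders, only the codellama-7b draft path comes from the code this commit removes):

./speculative -m models/codellama-34b/ggml-model-q4_0.gguf -md models/codellama-7b/ggml-model-q4_1.gguf -p "// quick-sort implementation in C++" -n 128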
common/common.cpp (9 additions, 0 deletions)
@@ -315,6 +315,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.model = argv[i];
+        } else if (arg == "-md" || arg == "--model-draft") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.model_draft = argv[i];
         } else if (arg == "-a" || arg == "--alias") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -667,6 +673,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
     fprintf(stdout, " -m FNAME, --model FNAME\n");
     fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
+    fprintf(stdout, " -md FNAME, --model-draft FNAME\n");
+    fprintf(stdout, " draft model for speculative sampling (default: %s)\n", params.model.c_str());
     fprintf(stdout, " -ld LOGDIR, --logdir LOGDIR\n");
     fprintf(stdout, " path under which to save YAML logs (no logging if unset)\n");
     fprintf(stdout, "\n");
@@ -1060,6 +1068,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
     fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
     fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
+    fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
     fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
     fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
     fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
common/common.h (1 addition, 0 deletions)
@@ -63,6 +63,7 @@ struct gpt_params {
     float cfg_scale = 1.f; // How strong is guidance
 
     std::string model = "models/7B/ggml-model-f16.gguf"; // model path
+    std::string model_draft = ""; // draft model for speculative sampling
     std::string model_alias = "unknown"; // model alias
     std::string prompt = "";
     std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
examples/speculative/speculative.cpp (12 additions, 11 deletions)
@@ -19,17 +19,18 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.model_draft.empty()) {
+        fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
+        return 1;
+    }
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("speculative", "log"));
     LOG_TEE("Log start\n");
     log_dump_cmdline(argc, argv);
 #endif // LOG_DISABLE_LOGS
 
-    // TODO: tmp hardcoded
-    const std::string fname_draft = "../models/codellama-7b/ggml-model-q4_1.gguf";
-
-    // init LLM
-
+    // init llama.cpp
     llama_backend_init(params.numa);
 
     llama_model * model_tgt = NULL;
@@ -43,11 +44,10 @@ int main(int argc, char ** argv) {
     std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
 
     // load the draft model
-    params.model = fname_draft;
+    params.model = params.model_draft;
     std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
 
     // tokenize the prompt
-
     std::vector<llama_token> inp;
     inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
 
@@ -77,7 +77,8 @@ int main(int argc, char ** argv) {
     const int n_vocab = llama_n_vocab(ctx_tgt);
     //GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
 
-    const int n_draft = 12;
+    // how many tokens to draft each time
+    const int n_draft = 8;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -237,8 +238,8 @@ int main(int argc, char ** argv) {
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);
 
-            int best_id = -1;
-            float best_logit = -1e30f;
+            int   best_id     = -1;
+            float best_logit  = -1e30f;
             float best_logit2 = -1e30f;
             for (int j = 0; j < n_vocab; ++j) {
                 if (logits[j] > best_logit) {
@@ -275,7 +276,7 @@ int main(int argc, char ** argv) {
     LOG_TEE("generated %d tokens in %.3f seconds, speed: %.3f t/s\n", n_predict, (t_gen_end - t_gen_start) / 1e6f, n_predict / ((t_gen_end - t_gen_start) / 1e6f));
 
     // TODO: make sure these numbers are computed correctly
-    LOG_TEE("\n\n");
+    LOG_TEE("\n");
     LOG_TEE("n_draft = %d\n", n_draft);
     LOG_TEE("n_predict = %d\n", n_predict);
     LOG_TEE("n_drafted = %d\n", n_drafted);
