diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 7896dd26e4ad2..acf79a89268fb 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -31,8 +31,6 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
-
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
@@ -267,6 +265,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
+    bool waiting_for_first_input = params.conversation_mode && params.enable_chat_template && params.system_prompt.empty();
     auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
         common_chat_msg new_msg;
         new_msg.role = role;
@@ -278,11 +277,20 @@ int main(int argc, char ** argv) {
     };
 
     {
-        auto prompt = (params.conversation_mode && params.enable_chat_template)
-            // format the system prompt in conversation mode (fallback to default if empty)
-            ? chat_add_and_format("system", params.system_prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.system_prompt)
+        std::string prompt;
+
+        if (params.conversation_mode && params.enable_chat_template) {
+            // format the system prompt in conversation mode (will use template default if empty)
+            prompt = params.system_prompt;
+
+            if (!prompt.empty()) {
+                prompt = chat_add_and_format("system", prompt);
+            }
+        } else {
             // otherwise use the prompt as is
-            : params.prompt;
+            prompt = params.prompt;
+        }
+
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
             embd_inp = common_tokenize(ctx, prompt, true, true);
@@ -296,7 +304,7 @@ int main(int argc, char ** argv) {
     }
 
     // Should not run without any tokens
-    if (embd_inp.empty()) {
+    if (!params.conversation_mode && embd_inp.empty()) {
         if (add_bos) {
             embd_inp.push_back(llama_vocab_bos(vocab));
             LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
@@ -777,7 +785,7 @@ int main(int argc, char ** argv) {
         }
 
         // deal with end of generation tokens in interactive mode
-        if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
+        if (!waiting_for_first_input && llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
            LOG_DBG("found an EOG token\n");
 
             if (params.interactive) {
@@ -797,12 +805,12 @@ int main(int argc, char ** argv) {
         }
 
         // if current token is not EOG, we add it to current assistant message
-        if (params.conversation_mode) {
+        if (params.conversation_mode && !waiting_for_first_input) {
             const auto id = common_sampler_last(smpl);
             assistant_ss << common_token_to_piece(ctx, id, false);
         }
 
-        if (n_past > 0 && is_interacting) {
+        if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
             LOG_DBG("waiting for user input\n");
 
             if (params.conversation_mode) {
@@ -892,11 +900,12 @@ int main(int argc, char ** argv) {
             input_echo = false; // do not echo this again
         }
 
-        if (n_past > 0) {
+        if (n_past > 0 || waiting_for_first_input) {
             if (is_interacting) {
                 common_sampler_reset(smpl);
             }
             is_interacting = false;
+            waiting_for_first_input = false;
         }
     }