make test pass
sbalandi committed Jan 29, 2025
1 parent a736d0b commit ec6eee5
Showing 11 changed files with 78 additions and 18 deletions.
3 changes: 3 additions & 0 deletions src/cpp/include/openvino/genai/generation_config.hpp
@@ -77,6 +77,8 @@ enum class StopCriteria { EARLY, HEURISTIC, NEVER };
* @param assistant_confidence_threshold the lower token probability of candidate to be validated by main model in case of dynamic strategy candidates number update.
* @param num_assistant_tokens the defined candidates number to be generated by draft model/prompt lookup in case of static strategy candidates number update.
* @param max_ngram_size is maximum ngram to use when looking for matches in the prompt.
*
* @param apply_chat_template whether to apply chat_template for non-chat scenarios.
*/

class OPENVINO_GENAI_EXPORTS GenerationConfig {
@@ -128,6 +130,7 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {

std::optional<AdapterConfig> adapters;

// set to true if chat template should be applied for non-chat scenarios, set to false otherwise
bool apply_chat_template = true;

/** @brief sets eos_token_id to tokenizer_eos_token_id if eos_token_id is less than 0.
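For reference, a minimal usage sketch of the new flag from the Python API (the model path and prompts are placeholders; the keyword form matches the updated tests further down):

import openvino_genai

pipe = openvino_genai.LLMPipeline("path/to/exported/model", "CPU")

# Default behaviour: the tokenizer's chat template, if present, is applied to the plain prompt.
print(pipe.generate("Why is the Sun yellow?", max_new_tokens=30))

# Opt out to keep raw, completion-style prompts untouched.
print(pipe.generate("Alan Turing was a", max_new_tokens=30, apply_chat_template=False))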
4 changes: 1 addition & 3 deletions src/cpp/include/openvino/genai/whisper_generation_config.hpp
@@ -18,7 +18,7 @@ namespace genai {
*/
class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
public:
WhisperGenerationConfig() = default;
WhisperGenerationConfig();
explicit WhisperGenerationConfig(const std::filesystem::path& json_path);

// Corresponds to the ”<|startoftranscript|>” token.
@@ -97,8 +97,6 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
// A list containing the non-speech tokens that will be suppressed during generation.
std::vector<int64_t> suppress_tokens;

bool apply_chat_template = false;

void update_generation_config(const ov::AnyMap& config_map = {});

template <typename... Properties>
3 changes: 2 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -64,7 +64,8 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
encoded_inputs = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids;
} else {
// in case chat_template was not found in tokenizer_config.json or was not set
encoded_inputs = m_tokenizer.encode(input_str).input_ids;
std::string input_str(prompt);
encoded_inputs = m_tokenizer.encode(input_str, ov::genai::add_special_tokens(true)).input_ids;
}
input_ids.push_back(encoded_inputs);
tokenization_durations.emplace_back(PerfMetrics::get_microsec(std::chrono::steady_clock::now() - encode_start));
26 changes: 23 additions & 3 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -88,7 +88,18 @@ DecodedResults StatefulLLMPipeline::generate(

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
std::vector<std::string> templated_input_vector;
for (auto& input : *input_vector) {
ChatHistory history({{{"role", "user"}, {"content", input}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
templated_input_vector.push_back(templated_prompt);
}
encoded_input = m_tokenizer.encode(templated_input_vector, ov::genai::add_special_tokens(false));
} else {
encoded_input = m_tokenizer.encode(*input_vector, ov::genai::add_special_tokens(true));
}
} else if (auto input_prompt = std::get_if<std::string>(&inputs)) {
std::string& prompt = *input_prompt;

@@ -104,7 +115,7 @@

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
@@ -157,7 +168,16 @@ DecodedResults StatefulLLMPipeline::generate(

// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
} else {
encoded_input = m_tokenizer.encode(prompt);
std::string& prompt = *input_prompt;
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
encoded_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// in case chat_template was not found in tokenizer_config.json or was not set
encoded_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true));
}
}
}

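The new non-chat branch mirrors the Hugging Face flow that the updated tests in tests/python_tests/common.py use as reference: wrap the prompt as a single user turn, render the chat template with a generation prompt, and tokenize without special tokens. A sketch with transformers (the model id is a placeholder):

from transformers import AutoTokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("some/chat-model")  # placeholder model id

prompt = "Why is the Sun yellow?"
if hf_tokenizer.chat_template:
    # Same path as config.apply_chat_template == true with a non-empty chat template.
    templated = hf_tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = hf_tokenizer(templated, return_tensors="pt", add_special_tokens=False)
else:
    # Fallback: no chat template available, keep special tokens.
    inputs = hf_tokenizer(prompt, return_tensors="pt")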
20 changes: 18 additions & 2 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -827,7 +827,15 @@ DecodedResults StatefulLLMPipeline::generate(
// for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
} else {
tokenized_input = m_tokenizer.encode(prompt);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// in case chat_template was not found in tokenizer_config.json or was not set
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true));
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
@@ -1294,7 +1302,15 @@ DecodedResults StatelessLLMPipeline::generate(
// for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
} else {
tokenized_input = m_tokenizer.encode(prompt);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// in case chat_template was not found in tokenizer_config.json or was not set
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true));
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
6 changes: 6 additions & 0 deletions src/cpp/src/whisper_generation_config.cpp
@@ -14,6 +14,10 @@
namespace ov {
namespace genai {

WhisperGenerationConfig::WhisperGenerationConfig() {
apply_chat_template = false;
}

WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& json_path)
: GenerationConfig::GenerationConfig(json_path) {
using ov::genai::utils::read_json_param;
@@ -38,6 +42,8 @@ WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& js
}

read_json_param(data, "lang_to_id", lang_to_id);

apply_chat_template = false;
}

void WhisperGenerationConfig::update_generation_config(const ov::AnyMap& config_map) {
3 changes: 3 additions & 0 deletions src/python/openvino_genai/py_openvino_genai.pyi
@@ -550,6 +550,7 @@ class GenerationConfig:
echo: if set to true, the model will echo the prompt in the output.
logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
apply_chat_template: whether to apply chat_template for non-chat scenarios
repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
presence_penalty: reduces absolute log prob if the token was generated at least once.
@@ -997,6 +998,7 @@ class LLMPipeline:
echo: if set to true, the model will echo the prompt in the output.
logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
apply_chat_template: whether to apply chat_template for non-chat scenarios
repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
presence_penalty: reduces absolute log prob if the token was generated at least once.
@@ -1082,6 +1084,7 @@ class LLMPipeline:
echo: if set to true, the model will echo the prompt in the output.
logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
apply_chat_template: whether to apply chat_template for non-chat scenarios
repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
presence_penalty: reduces absolute log prob if the token was generated at least once.
1 change: 1 addition & 0 deletions src/python/py_generation_config.cpp
@@ -47,6 +47,7 @@ char generation_config_docstring[] = R"(
echo: if set to true, the model will echo the prompt in the output.
logprobs: number of top logprobs computed for each position, if set to 0, logprobs are not computed and value 0.0 is returned.
Currently only single top logprob can be returned, so any logprobs > 1 is treated as logprobs == 1. (default: 0).
apply_chat_template: whether to apply chat_template for non-chat scenarios
repetition_penalty: the parameter for repetition penalty. 1.0 means no penalty.
presence_penalty: reduces absolute log prob if the token was generated at least once.
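Besides the keyword form used in the tests, the flag is documented as a regular GenerationConfig field, so it can presumably also be set on a config object passed to generate. A sketch under that assumption (model path and prompt are placeholders):

import openvino_genai

config = openvino_genai.GenerationConfig()
config.max_new_tokens = 30
config.apply_chat_template = False  # field documented in the docstring above; assumed exposed via the binding

pipe = openvino_genai.LLMPipeline("path/to/exported/model", "CPU")
print(pipe.generate("Alan Turing was a", config))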
18 changes: 15 additions & 3 deletions tests/python_tests/common.py
@@ -252,7 +252,12 @@ def run_hugging_face(
# process prompt by prompt as we have multiple generation configs
for prompt, generation_config in zip(prompts, generation_configs):
hf_generation_config = convert_to_hf(opt_model.generation_config, generation_config)
inputs = hf_tokenizer(prompt, return_tensors="pt")
inputs = {}
if hf_tokenizer.chat_template and generation_config.apply_chat_template:
prompt = hf_tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
inputs = hf_tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
else:
inputs = hf_tokenizer(prompt, return_tensors="pt")
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
prompt_len = 0 if generation_config.echo else input_ids.numel()

@@ -266,8 +271,15 @@
generation_result.m_scores = [score for score in generate_outputs.sequences_scores]
generation_results.append(generation_result)
else:
# process all prompts as a single batch as we have a single generation config for all prompts
inputs = hf_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True, padding_side='left')
inputs = {}
if hf_tokenizer.chat_template and generation_configs.apply_chat_template:
processed_prompts = []
for prompt in prompts:
processed_prompts.append(hf_tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True))
# process all prompts as a single batch as we have a single generation config for all prompts
inputs = hf_tokenizer(processed_prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=False, padding_side='left')
else:
inputs = hf_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, padding_side='left')
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
hf_generation_config = convert_to_hf(opt_model.generation_config, generation_configs)
hf_encoded_outputs = opt_model.generate(input_ids, attention_mask=attention_mask, generation_config=hf_generation_config, tokenizer=hf_tokenizer)
8 changes: 4 additions & 4 deletions tests/python_tests/test_llm_pipeline.py
@@ -26,7 +26,7 @@

test_cases = [
(dict(max_new_tokens=20), '你好! 你好嗎?'),
(dict(max_new_tokens=30, num_beams=15, num_beam_groups=3, num_return_sequences=15, diversity_penalty=1.0), 'Alan Turing was a'),
(dict(max_new_tokens=30, num_beams=15, num_beam_groups=3, num_return_sequences=15, diversity_penalty=1.0), 'Why is the Sun yellow?'),
]
@pytest.mark.parametrize("generation_config_dict,prompt", test_cases)
@pytest.mark.parametrize("model_descr", get_models_list())
@@ -339,7 +339,7 @@ def test_unicode_pybind_decoding_one_string():
# Test that pybind will not fail.
model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')
ov_pipe = read_model((model_id, path))[4]
res_str = ov_pipe.generate(',', max_new_tokens=4)
res_str = ov_pipe.generate(',', max_new_tokens=4, apply_chat_template=False)
assert '�' == res_str[-1]


@@ -350,7 +350,7 @@ def test_unicode_pybind_decoding_batched():
# Test that pybind will not fail.
model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')
ov_pipe = read_model((model_id, path))[4]
res_str = ov_pipe.generate([","], max_new_tokens=4)
res_str = ov_pipe.generate([","], max_new_tokens=4, apply_chat_template=False)
assert '�' == res_str.texts[0][-1]


@@ -362,7 +362,7 @@ def test_unicode_pybind_decoding_one_string_streamer():
model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')
ov_pipe = read_model((model_id, path))[4]
res_str = []
ov_pipe.generate(",", max_new_tokens=4, streamer=lambda x: res_str.append(x))
ov_pipe.generate(",", max_new_tokens=4, apply_chat_template=False, streamer=lambda x: res_str.append(x))
assert '�' == ''.join(res_str)[-1]

#
4 changes: 2 additions & 2 deletions tests/python_tests/test_sampling.py
@@ -18,7 +18,7 @@
(dict(max_new_tokens=30, min_new_tokens=30), '你好! 你好嗎?'),
(dict(max_new_tokens=30, ignore_eos=True), 'Alan Turing was a'),
# (dict(max_length=40), 'table is made of'),
(dict(stop_token_ids={28998}), 'The Sun is yellow because'), # since a test does not hang, it means stop token is met
(dict(stop_token_ids={28998}, apply_chat_template=False), 'The Sun is yellow because'), # since the test does not hang, the stop token was met; the chat template is skipped to generate a long answer
# (dict(max_new_tokens=1, min_new_tokens=0, echo=True), 'What is OpenVINO?')
],
ids=["max_new_tokens",
@@ -59,7 +59,7 @@ def test_stop_strings(tmp_path, generation_config):
@pytest.mark.parametrize("generation_config",
[dict(max_new_tokens=30),
dict(max_new_tokens=30, repetition_penalty=2.0),
dict(max_new_tokens=300)],
dict(max_new_tokens=300, apply_chat_template=False)],
ids=["basic", "repetition_penalty", "long_max_new_tokens"])
@pytest.mark.parametrize("prompt", [
'What is OpenVINO?',
