make test pass
sbalandi committed Jan 29, 2025
1 parent a736d0b commit faf1043
Showing 9 changed files with 72 additions and 18 deletions.
1 change: 1 addition & 0 deletions src/cpp/include/openvino/genai/generation_config.hpp
@@ -128,6 +128,7 @@ class OPENVINO_GENAI_EXPORTS GenerationConfig {

std::optional<AdapterConfig> adapters;

// set to true to apply the chat template in non-chat scenarios, false otherwise
bool apply_chat_template = true;

/** @brief sets eos_token_id to tokenizer_eos_token_id if eos_token_id is less than 0.
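For context, a minimal user-side sketch of what this flag controls (not part of the commit; it assumes the usual ov::genai::LLMPipeline entry point, and the model path comes from the command line): with the default apply_chat_template = true, plain generate() calls now go through the tokenizer's chat template, and setting the flag to false restores raw-prompt encoding.

```cpp
#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main(int argc, char* argv[]) {
    std::string models_path = argv[1];  // directory with an exported OpenVINO GenAI model
    ov::genai::LLMPipeline pipe(models_path, "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 30;

    // With the default apply_chat_template = true, the raw prompt is wrapped
    // into a single-turn chat history and passed through the chat template.
    std::cout << pipe.generate("Why is the Sun yellow?", config) << "\n";

    // Opt out to encode the prompt as-is (the previous behaviour).
    config.apply_chat_template = false;
    std::cout << pipe.generate("Why is the Sun yellow?", config) << "\n";
    return 0;
}
```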
4 changes: 1 addition & 3 deletions src/cpp/include/openvino/genai/whisper_generation_config.hpp
@@ -18,7 +18,7 @@ namespace genai {
*/
class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
public:
WhisperGenerationConfig() = default;
WhisperGenerationConfig();
explicit WhisperGenerationConfig(const std::filesystem::path& json_path);

// Corresponds to the ”<|startoftranscript|>” token.
@@ -97,8 +97,6 @@ class OPENVINO_GENAI_EXPORTS WhisperGenerationConfig : public GenerationConfig {
// A list containing the non-speech tokens that will be suppressed during generation.
std::vector<int64_t> suppress_tokens;

bool apply_chat_template = false;

void update_generation_config(const ov::AnyMap& config_map = {});

template <typename... Properties>
3 changes: 2 additions & 1 deletion src/cpp/src/icontinuous_batching.cpp
@@ -64,7 +64,8 @@ ContinuousBatchingPipeline::IContinuousBatchingPipeline::generate(
encoded_inputs = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false)).input_ids;
} else {
// used when no chat_template was found in tokenizer_config.json or it was not set
encoded_inputs = m_tokenizer.encode(input_str).input_ids;
std::string input_str(prompt);
encoded_inputs = m_tokenizer.encode(input_str, ov::genai::add_special_tokens(true)).input_ids;
}
input_ids.push_back(encoded_inputs);
tokenization_durations.emplace_back(PerfMetrics::get_microsec(std::chrono::steady_clock::now() - encode_start));
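The same template-or-fallback pattern recurs in the pipelines below. As a condensed illustration distilled from the hunks above (not code from the commit; it assumes an already constructed ov::genai::Tokenizer), the single-prompt path looks roughly like this:

```cpp
#include <string>
#include "openvino/genai/tokenizer.hpp"

ov::Tensor encode_single_prompt(ov::genai::Tokenizer& tokenizer,
                                const std::string& prompt,
                                bool apply_chat_template) {
    if (apply_chat_template && !tokenizer.get_chat_template().empty()) {
        // Wrap the raw prompt into a single-turn chat history and let the
        // model's chat template add role markers and the generation prompt.
        ov::genai::ChatHistory history({{{"role", "user"}, {"content", prompt}}});
        constexpr bool add_generation_prompt = true;
        auto templated = tokenizer.apply_chat_template(history, add_generation_prompt);
        // Special tokens are skipped here: the template already inserts them,
        // which keeps the result aligned with Hugging Face behaviour.
        return tokenizer.encode(templated, ov::genai::add_special_tokens(false)).input_ids;
    }
    // Fallback: no chat template available (or templating disabled), so encode
    // the raw prompt and let the tokenizer add its own special tokens.
    return tokenizer.encode(prompt, ov::genai::add_special_tokens(true)).input_ids;
}
```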
26 changes: 23 additions & 3 deletions src/cpp/src/llm_pipeline_stateful.cpp
@@ -88,7 +88,18 @@ DecodedResults StatefulLLMPipeline::generate(

if (auto input_vector = std::get_if<std::vector<std::string>>(&inputs)) {
OPENVINO_ASSERT(!is_chat_conversation, "Can't chat with multiple prompts");
encoded_input = m_tokenizer.encode(*input_vector);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
std::vector<std::string> templated_input_vector;
for (auto& input : *input_vector) {
ChatHistory history({{{"role", "user"}, {"content", input}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
templated_input_vector.push_back(templated_prompt);
}
encoded_input = m_tokenizer.encode(templated_input_vector, ov::genai::add_special_tokens(false));
} else {
encoded_input = m_tokenizer.encode(*input_vector, ov::genai::add_special_tokens(true));
}
} else if (auto input_prompt = std::get_if<std::string>(&inputs)) {
std::string& prompt = *input_prompt;

@@ -104,7 +115,7 @@

m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
auto new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
// Do not add special tokens in chat scenario to be aligned with HF.
auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false));
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
@@ -157,7 +168,16 @@ DecodedResults StatefulLLMPipeline::generate(

// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
} else {
encoded_input = m_tokenizer.encode(prompt);
std::string& prompt = *input_prompt;
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
encoded_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// used when no chat_template was found in tokenizer_config.json or it was not set
encoded_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true));
}
}
}

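For the batched std::vector<std::string> input added above, each prompt is wrapped into its own single-turn history before the whole batch is encoded at once. A rough sketch under the same assumptions (a constructed ov::genai::Tokenizer; the helper name encode_batch is hypothetical):

```cpp
#include <string>
#include <vector>
#include "openvino/genai/tokenizer.hpp"

ov::genai::TokenizedInputs encode_batch(ov::genai::Tokenizer& tokenizer,
                                        std::vector<std::string> prompts,
                                        bool apply_chat_template) {
    if (apply_chat_template && !tokenizer.get_chat_template().empty()) {
        std::vector<std::string> templated;
        templated.reserve(prompts.size());
        for (const auto& prompt : prompts) {
            // Each prompt becomes its own single-turn "chat".
            ov::genai::ChatHistory history({{{"role", "user"}, {"content", prompt}}});
            constexpr bool add_generation_prompt = true;
            templated.push_back(tokenizer.apply_chat_template(history, add_generation_prompt));
        }
        // The template already provides the special tokens, so none are added here.
        return tokenizer.encode(templated, ov::genai::add_special_tokens(false));
    }
    return tokenizer.encode(prompts, ov::genai::add_special_tokens(true));
}
```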
20 changes: 18 additions & 2 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -827,7 +827,15 @@ DecodedResults StatefulLLMPipeline::generate(
// for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
} else {
tokenized_input = m_tokenizer.encode(prompt);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// used when no chat_template was found in tokenizer_config.json or it was not set
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true));
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
@@ -1294,7 +1302,15 @@ DecodedResults StatelessLLMPipeline::generate(
// for chat ov::genai::add_special_tokens(false) is aligned with stateful pipeline and HF
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(false));
} else {
tokenized_input = m_tokenizer.encode(prompt);
if (config.apply_chat_template && !m_tokenizer.get_chat_template().empty()) {
ChatHistory history({{{"role", "user"}, {"content", prompt}}});
constexpr bool add_generation_prompt = true;
auto templated_prompt = m_tokenizer.apply_chat_template(history, add_generation_prompt);
tokenized_input = m_tokenizer.encode(templated_prompt, ov::genai::add_special_tokens(false));
} else {
// used when no chat_template was found in tokenizer_config.json or it was not set
tokenized_input = m_tokenizer.encode(prompt, ov::genai::add_special_tokens(true));
}
}

auto encode_stop_time = std::chrono::steady_clock::now();
6 changes: 6 additions & 0 deletions src/cpp/src/whisper_generation_config.cpp
@@ -14,6 +14,10 @@
namespace ov {
namespace genai {

WhisperGenerationConfig::WhisperGenerationConfig() {
apply_chat_template = false;
}

WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& json_path)
: GenerationConfig::GenerationConfig(json_path) {
using ov::genai::utils::read_json_param;
@@ -38,6 +42,8 @@ WhisperGenerationConfig::WhisperGenerationConfig(const std::filesystem::path& js
}

read_json_param(data, "lang_to_id", lang_to_id);

apply_chat_template = false;
}

void WhisperGenerationConfig::update_generation_config(const ov::AnyMap& config_map) {
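Because the base-class default is apply_chat_template = true, the Whisper config replaces its "= default" constructor with a defined one so that both constructors can flip the flag back to false. A tiny illustrative check of the resulting defaults (a sketch, not part of the commit):

```cpp
#include <cassert>
#include "openvino/genai/whisper_generation_config.hpp"

int main() {
    // Default-constructed Whisper config: the defined constructor overrides
    // the inherited default and disables chat-template application.
    ov::genai::WhisperGenerationConfig whisper_config;
    assert(!whisper_config.apply_chat_template);

    // The base GenerationConfig keeps apply_chat_template = true, which is the
    // new default for the LLM pipelines.
    ov::genai::GenerationConfig base_config;
    assert(base_config.apply_chat_template);
    return 0;
}
```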
18 changes: 15 additions & 3 deletions tests/python_tests/common.py
@@ -252,7 +252,12 @@ def run_hugging_face(
# process prompt by prompt, as we have multiple generation configs
for prompt, generation_config in zip(prompts, generation_configs):
hf_generation_config = convert_to_hf(opt_model.generation_config, generation_config)
inputs = hf_tokenizer(prompt, return_tensors="pt")
inputs = {}
if hf_tokenizer.chat_template and generation_config.apply_chat_template:
prompt = hf_tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True)
inputs = hf_tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
else:
inputs = hf_tokenizer(prompt, return_tensors="pt")
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
prompt_len = 0 if generation_config.echo else input_ids.numel()

@@ -266,8 +271,15 @@ def run_hugging_face(
generation_result.m_scores = [score for score in generate_outputs.sequences_scores]
generation_results.append(generation_result)
else:
# process all prompts as a single batch as we have a single generation config for all prompts
inputs = hf_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=True, padding_side='left')
inputs = {}
if hf_tokenizer.chat_template and generation_configs.apply_chat_template:
processed_prompts = []
for prompt in prompts:
processed_prompts.append(hf_tokenizer.apply_chat_template([{'role': 'user', 'content': prompt}], tokenize=False, add_generation_prompt=True))
# process all prompts as a single batch as we have a single generation config for all prompts
inputs = hf_tokenizer(processed_prompts, return_tensors='pt', padding=True, truncation=True, add_special_tokens=False, padding_side='left')
else:
inputs = hf_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True, padding_side='left')
input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
hf_generation_config = convert_to_hf(opt_model.generation_config, generation_configs)
hf_encoded_outputs = opt_model.generate(input_ids, attention_mask=attention_mask, generation_config=hf_generation_config, tokenizer=hf_tokenizer)
8 changes: 4 additions & 4 deletions tests/python_tests/test_llm_pipeline.py
@@ -26,7 +26,7 @@

test_cases = [
(dict(max_new_tokens=20), '你好! 你好嗎?'),
(dict(max_new_tokens=30, num_beams=15, num_beam_groups=3, num_return_sequences=15, diversity_penalty=1.0), 'Alan Turing was a'),
(dict(max_new_tokens=30, num_beams=15, num_beam_groups=3, num_return_sequences=15, diversity_penalty=1.0), 'Why is the Sun yellow?'),
]
@pytest.mark.parametrize("generation_config_dict,prompt", test_cases)
@pytest.mark.parametrize("model_descr", get_models_list())
@@ -339,7 +339,7 @@ def test_unicode_pybind_decoding_one_string():
# Test that pybind will not fail.
model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')
ov_pipe = read_model((model_id, path))[4]
res_str = ov_pipe.generate(',', max_new_tokens=4)
res_str = ov_pipe.generate(',', max_new_tokens=4, apply_chat_template=False)
assert '�' == res_str[-1]


@@ -350,7 +350,7 @@ def test_unicode_pybind_decoding_batched():
# Test that pybind will not fail.
model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')
ov_pipe = read_model((model_id, path))[4]
res_str = ov_pipe.generate([","], max_new_tokens=4)
res_str = ov_pipe.generate([","], max_new_tokens=4, apply_chat_template=False)
assert '�' == res_str.texts[0][-1]


@@ -362,7 +362,7 @@ def test_unicode_pybind_decoding_one_string_streamer():
model_id, path = 'katuni4ka/tiny-random-phi3', Path('tiny-random-phi3')
ov_pipe = read_model((model_id, path))[4]
res_str = []
ov_pipe.generate(",", max_new_tokens=4, streamer=lambda x: res_str.append(x))
ov_pipe.generate(",", max_new_tokens=4, apply_chat_template=False, streamer=lambda x: res_str.append(x))
assert '�' == ''.join(res_str)[-1]

#
4 changes: 2 additions & 2 deletions tests/python_tests/test_sampling.py
@@ -18,7 +18,7 @@
(dict(max_new_tokens=30, min_new_tokens=30), '你好! 你好嗎?'),
(dict(max_new_tokens=30, ignore_eos=True), 'Alan Turing was a'),
# (dict(max_length=40), 'table is made of'),
(dict(stop_token_ids={28998}), 'The Sun is yellow because'), # since a test does not hang, it means stop token is met
(dict(stop_token_ids={28998}, apply_chat_template=False), 'The Sun is yellow because'), # the test does not hang, so the stop token was met; the chat template is skipped so the answer is long enough
# (dict(max_new_tokens=1, min_new_tokens=0, echo=True), 'What is OpenVINO?')
],
ids=["max_new_tokens",
@@ -59,7 +59,7 @@ def test_stop_strings(tmp_path, generation_config):
@pytest.mark.parametrize("generation_config",
[dict(max_new_tokens=30),
dict(max_new_tokens=30, repetition_penalty=2.0),
dict(max_new_tokens=300)],
dict(max_new_tokens=300, apply_chat_template=False)],
ids=["basic", "repetition_penalty", "long_max_new_tokens"])
@pytest.mark.parametrize("prompt", [
'What is OpenVINO?',
