From c1ac1c6ac3aa92c1f108af9c7e9bdea0ca4b79e3 Mon Sep 17 00:00:00 2001
From: Irina Efode
Date: Fri, 31 Jan 2025 11:39:57 +0400
Subject: [PATCH] [ CB ] Fix streaming in case of empty outputs (#1647)

Ticket:
* CVS-161111
---
 src/cpp/src/continuous_batching_impl.cpp         | 13 ++++--
 src/cpp/src/lm_encoding.cpp                      | 13 ++++--
 .../src/prompt_lookup/prompt_lookup_impl.cpp     | 13 ++++--
 src/cpp/src/sequence_group.hpp                   |  4 ++
 .../speculative_decoding_impl.cpp                | 17 ++++---
 .../python_tests/test_continuous_batching.py     | 46 ++++++++++++++++---
 6 files changed, 77 insertions(+), 29 deletions(-)

diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 095d7dc4e2..be1eba04f9 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -472,11 +472,14 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<
         if (streamer_ptr && generation->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (streamer_ptr->put(gen_token)) {
-                    generation->drop();
-                    break;
+            std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
+            OPENVINO_ASSERT(generation_outputs.size() <= 1);
+            if (!generation_outputs.empty()) {
+                for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
+                    if (streamer_ptr->put(generated_token_id)) {
+                        generation->drop();
+                        break;
+                    }
                 }
             }
         }
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 0a51a4ed4d..e97700cabe 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -94,11 +94,14 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
     auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() {
         GenerationHandle& handle = generations.at(0);
         if (streamer_ptr && handle->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (streamer_ptr->put(gen_token)) {
-                    handle->drop();
-                    break;
+            std::unordered_map<uint64_t, GenerationOutput> generation_outputs = handle->read();
+            OPENVINO_ASSERT(generation_outputs.size() <= 1);
+            if (!generation_outputs.empty()) {
+                for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
+                    if (streamer_ptr->put(generated_token_id)) {
+                        handle->drop();
+                        break;
+                    }
                 }
             }
         }
diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
index 539680c819..6e58662b33 100644
--- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
+++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
@@ -151,11 +151,14 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<
         if (streamer_ptr && generation->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (streamer_ptr->put(gen_token)) {
-                    generation->drop();
-                    break;
+            std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
+            OPENVINO_ASSERT(generation_outputs.size() <= 1);
+            if (!generation_outputs.empty()) {
+                for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
+                    if (streamer_ptr->put(generated_token_id)) {
+                        generation->drop();
+                        break;
+                    }
                 }
             }
         }
diff --git a/src/cpp/src/sequence_group.hpp b/src/cpp/src/sequence_group.hpp
index 19d29c92ac..72dbb64df8 100644
--- a/src/cpp/src/sequence_group.hpp
+++ b/src/cpp/src/sequence_group.hpp
@@ -649,7 +649,11 @@ class SequenceGroup : public std::enable_shared_from_this<SequenceGroup> {
         if (has_finished()) {
             m_stream_window_size = 0;
         }
+        // push empty output in case we won't stream generation res
         if (generated_len <= (m_num_streamed_tokens + m_stream_window_size)) {
+            if (has_finished()) {
+                push_empty_outputs();
+            }
             return;
         }
         // speculative decoding draft handling
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 51490945e7..6fb4e8ac53 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -261,11 +261,11 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     auto all_requests = get_awaiting_requests();
 
     std::atomic<bool> has_active_requests = has_non_finished_requests();
-    auto& generation = main_generations.at(0);
+    GenerationHandle& generation = main_generations.at(0);
 
     // create variables to make optimal thread-safe streaming
     std::mutex mutex;
-    std::unique_lock lock(mutex);
+    std::unique_lock<std::mutex> lock(mutex);
     std::condition_variable cv;
     std::shared_ptr<std::thread> t_stream_ptr = nullptr;
 
@@ -279,11 +279,14 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
                 });
 
                 if (generation->can_read()) {
-                    std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-                    for (const auto& gen_token : token.begin()->second.generated_ids) {
-                        if (streamer_ptr->put(gen_token)) {
-                            generation->drop();
-                            break;
+                    std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
+                    OPENVINO_ASSERT(generation_outputs.size() <= 1);
+                    if (!generation_outputs.empty()) {
+                        for (const auto& generated_token_id : generation_outputs.begin()->second.generated_ids) {
+                            if (streamer_ptr->put(generated_token_id)) {
+                                generation->drop();
+                                break;
+                            }
                         }
                     }
                 }
diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py
index 7cf5bc5355..8afcc8061c 100644
--- a/tests/python_tests/test_continuous_batching.py
+++ b/tests/python_tests/test_continuous_batching.py
@@ -5,6 +5,7 @@
 import pytest
 import math
 from typing import Dict
+from functools import partial
 from pathlib import Path
 from openvino_genai import ContinuousBatchingPipeline, LLMPipeline, GenerationConfig, SchedulerConfig, Tokenizer, draft_model
 
@@ -340,12 +341,10 @@ def test_preemption_with_multinomial_n_seq(tmp_path, dynamic_split_fuse):
     scheduler_config = get_scheduler_config({"num_kv_blocks": 8, "dynamic_split_fuse": dynamic_split_fuse, "max_num_batched_tokens": 256, "max_num_seqs": 256})
     generate_and_compare_with_reference_text(models_path, multinomial_params_n_seq.prompts, multinomial_params_n_seq.ref_texts, multinomial_params_n_seq.generation_config, scheduler_config)
 
-def get_data_by_pipeline_type(model_path: Path, pipeline_type: str):
+def get_data_by_pipeline_type(model_path: Path, pipeline_type: str, generation_config: GenerationConfig):
     device = "CPU"
-    prompt = "Prompt example is"
-    generation_config = GenerationConfig()
+    prompt = "Prompt example is"
     generation_config.max_new_tokens = 10
-    generation_config.do_sample = True
     pipe = None
     if pipeline_type == "continuous_batching":
         scheduler_config = SchedulerConfig()
@@ -374,9 +373,42 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
     models_path : Path = tmp_path / "t_streaming" / model_id
     convert_models(opt_model, hf_tokenizer, models_path)
 
-    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type)
-    py_streamer = lambda x: False
-    _ = pipe.generate(input, generation_config=gen_config, streamer=py_streamer)
+    generation_config = GenerationConfig()
+    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
+
+    def py_streamer(py_str: str):
+        return False
+
+    try:
+        _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
+    except Exception:
+        assert True
+
+    del pipe
+    rmtree(models_path)
+
+@pytest.mark.parametrize("pipeline_type", ["continuous_batching", "speculative_decoding", "prompt_lookup_decoding", "llm_pipeline"])
+@pytest.mark.precommit
+def test_pipelines_generate_with_streaming_empty_output(tmp_path, pipeline_type):
+    model_id : str = "facebook/opt-125m"
+    opt_model, hf_tokenizer = get_hugging_face_models(model_id)
+
+    models_path : Path = tmp_path / "t_streaming" / model_id
+    convert_models(opt_model, hf_tokenizer, models_path)
+
+    generation_config = GenerationConfig()
+    generation_config.stop_strings = {" the "}
+    generation_config.include_stop_str_in_output = False
+
+    pipe, input, generation_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
+
+    def py_streamer(py_str: str):
+        raise Exception("Streamer was called")
+
+    try:
+        _ = pipe.generate(input, generation_config=generation_config, streamer=py_streamer)
+    except Exception:
+        assert False
 
     del pipe
     rmtree(models_path)