From f5d82569f75aa2456d8063e7b272aa3e52782542 Mon Sep 17 00:00:00 2001
From: Irina Efode
Date: Thu, 30 Jan 2025 12:46:15 +0400
Subject: [PATCH] fix tests

Stream tokens by iterating over the generation outputs returned by the
generation handle (asserting that at most one output is present) instead
of dereferencing the first entry of a possibly empty map, and pass an
explicit GenerationConfig to the streaming test.
---
 src/cpp/src/continuous_batching_impl.cpp      | 17 +++++++-------
 src/cpp/src/lm_encoding.cpp                   | 17 +++++++-------
 .../src/prompt_lookup/prompt_lookup_impl.cpp  | 17 +++++++-------
 .../speculative_decoding_impl.cpp             | 23 +++++++++++--------
 .../python_tests/test_continuous_batching.py  |  3 ++-
 5 files changed, 42 insertions(+), 35 deletions(-)

diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp
index 32bfe02ada..cd6203a1b0 100644
--- a/src/cpp/src/continuous_batching_impl.cpp
+++ b/src/cpp/src/continuous_batching_impl.cpp
@@ -465,15 +465,16 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::generate(const std::vector<
         // define stream token lambda to use in `t_stream_ptr`
         auto stream_tokens = [&generation, &streamer_ptr]() {
             while (streamer_ptr && generation->can_read()) {
-                std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-                if (token.empty()) {
-                    continue;
-                }
-                for (const auto& gen_token : token.begin()->second.generated_ids) {
-                    if (streamer_ptr->put(gen_token)) {
-                        generation->drop();
-                        break;
+                std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
+                OPENVINO_ASSERT(generation_outputs.size() <= 1);
+                for (const auto& generation_output : generation_outputs) {
+                    for (const auto& generated_token_id : generation_output.second.generated_ids) {
+                        if (streamer_ptr->put(generated_token_id)) {
+                            generation->drop();
+                            break;
+                        }
                     }
+                    break;
                 }
             }
         };
diff --git a/src/cpp/src/lm_encoding.cpp b/src/cpp/src/lm_encoding.cpp
index 80848a12fb..48d03eac8f 100644
--- a/src/cpp/src/lm_encoding.cpp
+++ b/src/cpp/src/lm_encoding.cpp
@@ -94,15 +94,16 @@ std::pair<EncodedResults, std::optional<int64_t>> get_lm_encoded_results(
     auto stream_generated_tokens = [&streamer_ptr, &generations, &active_sequence_groups]() {
         GenerationHandle& handle = generations.at(0);
         if (streamer_ptr && handle->can_read()) {
-            std::unordered_map<uint64_t, GenerationOutput> token = handle->back();
-            if (token.empty()) {
-                return;
-            }
-            for (const auto& gen_token : token.begin()->second.generated_ids) {
-                if (streamer_ptr->put(gen_token)) {
-                    handle->drop();
-                    break;
+            std::unordered_map<uint64_t, GenerationOutput> generation_outputs = handle->read();
+            OPENVINO_ASSERT(generation_outputs.size() <= 1);
+            for (const auto& generation_output : generation_outputs) {
+                for (const auto& generated_token_id : generation_output.second.generated_ids) {
+                    if (streamer_ptr->put(generated_token_id)) {
+                        handle->drop();
+                        break;
+                    }
                 }
+                break;
             }
         }
     };
diff --git a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
index 89d7edfd86..f09604de0d 100644
--- a/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
+++ b/src/cpp/src/prompt_lookup/prompt_lookup_impl.cpp
@@ -151,15 +151,16 @@ ContinuousBatchingPipeline::PromptLookupImpl::generate(const std::vector<
         // define stream token lambda to use in `t_stream_ptr`
         auto stream_tokens = [&generation, &streamer_ptr]() {
             while (streamer_ptr && generation->can_read()) {
-                std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-                if (token.empty()) {
-                    continue;
-                }
-                for (const auto& gen_token : token.begin()->second.generated_ids) {
-                    if (streamer_ptr->put(gen_token)) {
-                        generation->drop();
-                        break;
+                std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
+                OPENVINO_ASSERT(generation_outputs.size() <= 1);
+                for (const auto& generation_output : generation_outputs) {
+                    for (const auto& generated_token_id : generation_output.second.generated_ids) {
+                        if (streamer_ptr->put(generated_token_id)) {
+                            generation->drop();
+                            break;
+                        }
                     }
+                    break;
                 }
             }
         };
diff --git a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
index 8b04928452..2aad50f465 100644
--- a/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
+++ b/src/cpp/src/speculative_decoding/speculative_decoding_impl.cpp
@@ -258,13 +258,15 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
     auto all_requests = get_awaiting_requests();
 
     std::atomic<bool> has_active_requests = has_non_finished_requests();
-    auto& generation = main_generations.at(0);
+    GenerationHandle& generation = main_generations.at(0);
 
     // create variables to make optimal thread-safe streaming
     std::mutex mutex;
-    std::unique_lock lock(mutex);
+    std::unique_lock<std::mutex> lock(mutex);
     std::condition_variable cv;
 
+    // auto t_stream_ptr = create_streaming_thread(streamer_ptr, lock, cv, main_generations.at(0), has_active_requests);
+
     std::shared_ptr<std::thread> t_stream_ptr = nullptr;
     if (streamer_ptr) {
         // define stream token lambda to use in `t_stream_ptr`
@@ -276,15 +278,16 @@ ContinuousBatchingPipeline::SpeculativeDecodingImpl::generate(const std::vector<
             });
 
             if (generation->can_read()) {
-                std::unordered_map<uint64_t, GenerationOutput> token = generation->read();
-                if (token.empty()) {
-                    continue;
-                }
-                for (const auto& gen_token : token.begin()->second.generated_ids) {
-                    if (streamer_ptr->put(gen_token)) {
-                        generation->drop();
-                        break;
+                std::unordered_map<uint64_t, GenerationOutput> generation_outputs = generation->read();
+                OPENVINO_ASSERT(generation_outputs.size() <= 1);
+                for (const auto& generation_output : generation_outputs) {
+                    for (const auto& generated_token_id : generation_output.second.generated_ids) {
+                        if (streamer_ptr->put(generated_token_id)) {
+                            generation->drop();
+                            break;
+                        }
                     }
+                    break;
                 }
             }
         };
diff --git a/tests/python_tests/test_continuous_batching.py b/tests/python_tests/test_continuous_batching.py
index ad38ea9819..823fd6260b 100644
--- a/tests/python_tests/test_continuous_batching.py
+++ b/tests/python_tests/test_continuous_batching.py
@@ -373,7 +373,8 @@ def test_pipelines_generate_with_streaming(tmp_path, pipeline_type):
     models_path : Path = tmp_path / "t_streaming" / model_id
     convert_models(opt_model, hf_tokenizer, models_path)
 
-    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type)
+    generation_config = GenerationConfig()
+    pipe, input, gen_config = get_data_by_pipeline_type(models_path, pipeline_type, generation_config)
 
     py_streamer = lambda x: False
     _ = pipe.generate(input, generation_config=gen_config, streamer=py_streamer)
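Note: all four C++ hunks above install the same streaming loop: read the
pending outputs from the generation handle, stream the token ids of the
single expected output, and drop the generation once the streamer asks to
stop. Below is a minimal, self-contained sketch of that control flow only;
MockGenerationOutput, MockGenerationOutputs and put_token are hypothetical
stand-ins for the OpenVINO GenAI internals, and plain assert stands in for
OPENVINO_ASSERT.

// Minimal sketch of the streaming pattern shared by the four C++ hunks.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <vector>

struct MockGenerationOutput {
    std::vector<int64_t> generated_ids;  // token ids produced since the last read
};

// Stand-in for the handle's read(): at most one entry in this single-request setup.
using MockGenerationOutputs = std::unordered_map<uint64_t, MockGenerationOutput>;

// Stand-in for streamer_ptr->put(): returns true to request that generation stop.
static bool put_token(int64_t token_id) {
    std::cout << "token " << token_id << '\n';
    return false;
}

// Returns true when the streamer asked to stop (generation->drop() in the real code).
static bool stream_generation_outputs(const MockGenerationOutputs& generation_outputs) {
    assert(generation_outputs.size() <= 1);  // OPENVINO_ASSERT in the patch
    for (const auto& generation_output : generation_outputs) {
        for (const auto& generated_token_id : generation_output.second.generated_ids) {
            if (put_token(generated_token_id)) {
                return true;  // streamer requested early stop
            }
        }
        break;  // only the single expected output is streamed
    }
    return false;
}

int main() {
    MockGenerationOutputs outputs;
    outputs[0] = MockGenerationOutput{{101, 102, 103}};
    stream_generation_outputs(outputs);  // streams the three token ids

    // An empty map (nothing readable yet) simply falls through the range-for,
    // which is why the old `if (token.empty()) { continue; }` guard is gone.
    stream_generation_outputs(MockGenerationOutputs{});
    return 0;
}

The range-for over the returned map replaces the old `token.begin()->second`
access, which dereferenced the first entry of a map that could be empty; the
trailing break plus the assert make the at-most-one-output assumption explicit.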