From 39dcaa292bdd20f3da65bb90cc76741a39865991 Mon Sep 17 00:00:00 2001 From: Yaroslav Tarkan Date: Thu, 30 Jan 2025 18:05:33 +0300 Subject: [PATCH 1/2] Fix Qwen2VL generation without images (#1645) Ticket: CVS-156940, CVS-161487 --------- Co-authored-by: Vladimir Zlobin --- .../src/visual_language/inputs_embedder.cpp | 51 ++++++++----------- .../src/visual_language/vision_encoder.cpp | 2 +- tests/python_tests/common.py | 2 +- tests/python_tests/test_vlm_pipeline.py | 2 + 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp index e912570f20..d67c18817f 100644 --- a/src/cpp/src/visual_language/inputs_embedder.cpp +++ b/src/cpp/src/visual_language/inputs_embedder.cpp @@ -1428,15 +1428,15 @@ std::vector split_tokenize(const std::string& text, ov::genai::Token return tokenized; } -ov::Tensor insert_image_placeholders(const std::vector& chunks, size_t tokens_per_image) { +ov::Tensor insert_image_placeholders(const std::vector& chunks, const std::vector& tokens_per_images) { size_t merged_length = 0; for (const ov::Tensor& chunk : chunks) { merged_length += chunk.get_shape().at(1); } - merged_length += chunks.empty() ? 0 : (chunks.size() - 1) * tokens_per_image; + merged_length += std::accumulate(tokens_per_images.begin(), tokens_per_images.end(), 0); ov::Tensor merged{ov::element::i64, {1, merged_length}}; size_t offset = 0; - int64_t image_id = -1; + int64_t image_id = 0; for (const ov::Tensor& chunk : chunks) { size_t length = chunk.get_shape().at(1); std::copy_n( @@ -1448,11 +1448,11 @@ ov::Tensor insert_image_placeholders(const std::vector& chunks, size if (offset < merged_length) { std::fill_n( merged.data() + offset, - tokens_per_image, - image_id + tokens_per_images.at(image_id), + -image_id - 1 // It could be just -image_id. -1 is for consistency with the original implementation. ); - offset += tokens_per_image; - --image_id; + offset += tokens_per_images.at(image_id); + ++image_id; } } return merged; @@ -1481,9 +1481,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { public: ov::InferRequest m_hd_feature_transformer; ov::InferRequest m_vision_projection; - // Used to insert <|image_i|>\n per image (not a slice). 
- size_t m_image_id = 1; - size_t m_tokens_per_image = 0; + std::vector m_tokens_per_images; InputsEmbedderPhi3V( const VLMConfig& vlm_config, @@ -1491,7 +1489,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { const std::string& device, const ov::AnyMap device_config ): - IInputsEmbedder(vlm_config, model_dir, device, device_config), m_image_id{0}, + IInputsEmbedder(vlm_config, model_dir, device, device_config), m_hd_feature_transformer{phi3_v::create_hd_feature_transformer()}, m_vision_projection{utils::singleton_core().compile_model(model_dir / "openvino_vision_projection_model.xml", device, {}).create_infer_request()} {} @@ -1502,8 +1500,8 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { for (const ov::Tensor& image : to_single_image_tensors(images)) { EncodedImage encoded_image = m_vision_encoder.encode(image); images_features_proj.push_back(phi3_v::hd_feature_transform(encoded_image, m_hd_feature_transformer, m_vlm_config.sub_GN, m_vlm_config.glb_GN, m_vision_projection)); - images_prompt << "<|image_" << m_image_id << "|>\n"; - ++m_image_id; + m_tokens_per_images.push_back(images_features_proj.back().get_shape().at(1)); + images_prompt << "<|image_" << m_tokens_per_images.size() << "|>\n"; } images_prompt << prompt; std::vector new_chat_tokens; @@ -1511,8 +1509,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { if (m_is_chat_conversation) { m_history.push_back({{"role", "user"}, {"content", images_prompt.str()}}); constexpr bool add_generation_prompt = true; - std::string new_templated_chat_history; - new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); auto start_tokenizer_time = std::chrono::steady_clock::now(); new_chat_tokens = phi3_v::split_tokenize(new_templated_chat_history, m_tokenizer); prev_chat_tokens = phi3_v::split_tokenize(m_templated_chat_history, m_tokenizer); @@ -1525,11 +1522,8 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { auto end_tokenizer_time = std::chrono::steady_clock::now(); metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time)); } - if (0 == m_tokens_per_image && !images_features_proj.empty()) { - m_tokens_per_image = images_features_proj.at(0).get_shape().at(1); - } - ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_image); - ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_image); + ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_images); + ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_images); ov::Tensor new_tokens = update_history(new_merged_tokens, prev_merged_tokens); std::vector tokens = phi3_v::drop_image_placeholders(new_tokens); OPENVINO_ASSERT(tokens.size() == images_features_proj.size() + 1); @@ -1537,7 +1531,6 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { for (size_t im_id = 0; im_id < images_features_proj.size(); ++im_id) { size_t text_length = tokens.at(im_id).get_shape().at(1); size_t im_length = images_features_proj.at(im_id).get_shape().at(1); - OPENVINO_ASSERT(im_length == m_tokens_per_image); features_length += text_length + im_length; } features_length += tokens.back().get_shape().at(1); @@ -1570,7 +1563,7 @@ 
class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { ); if (!m_is_chat_conversation) { - m_image_id = 0; + m_tokens_per_images.clear(); } return inputs_embeds; @@ -1578,12 +1571,12 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder { virtual void start_chat(const std::string& system_message) override { IInputsEmbedder::start_chat(system_message); - m_image_id = 0; + m_tokens_per_images.clear(); } virtual void finish_chat() override { IInputsEmbedder::finish_chat(); - m_image_id = 0; + m_tokens_per_images.clear(); } }; @@ -1662,10 +1655,6 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder { ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, metrics, chat_template_fallback); ov::Tensor text_embeds = m_embedding.infer(input_ids); - if (images.empty()) { - return text_embeds; - } - auto start_tokenizer_time = std::chrono::steady_clock::now(); ov::Tensor encoded_vision_start_token = m_tokenizer.encode(m_vlm_config.vision_start_token, ov::genai::add_special_tokens(false)).input_ids; ov::Tensor encoded_image_pad_token = m_tokenizer.encode(m_vlm_config.image_pad_token, ov::genai::add_special_tokens(false)).input_ids; @@ -1680,6 +1669,10 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder { int64_t position_ids_max_element = *std::max_element(m_position_ids.data(), m_position_ids.data() + m_position_ids.get_size()); m_rope_delta = position_ids_max_element + 1 - static_cast(input_ids.get_shape().at(1)); + if (images.empty()) { + return text_embeds; + } + return merge_text_and_image_embeddings_qwen2vl(input_ids, text_embeds, image_embeds, images_grid_thw, image_pad_token_id); } @@ -1874,7 +1867,7 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder { } // Calculate rotary embeddings for max_grid_size - const size_t dim = 1280 / 16 / 2; // config.vision_config.embed_dim / self.config.vision_config.num_heads / 2 + const size_t dim = m_vision_embeddings_merger.get_tensor("rotary_pos_emb").get_shape().at(1); const float theta = 10000.0f; std::vector inv_freq(dim / 2); diff --git a/src/cpp/src/visual_language/vision_encoder.cpp b/src/cpp/src/visual_language/vision_encoder.cpp index 04ddd63145..e8edd40890 100644 --- a/src/cpp/src/visual_language/vision_encoder.cpp +++ b/src/cpp/src/visual_language/vision_encoder.cpp @@ -843,7 +843,7 @@ std::tuple get_pixel_values_phi3_v(const ov::Tensor& imag ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) { if (height < factor || width < factor) { - OPENVINO_THROW("Height or width must be larger than factor"); + OPENVINO_THROW("Height (" + std::to_string(height) + ") and width (" + std::to_string(width) + ") must be greater than factor (" + std::to_string(factor) + ")"); } if (std::max(height, width) / std::min(height, width) > 200) { OPENVINO_THROW("Absolute aspect ratio must be smaller than 200"); diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py index 320f1e1a6a..88690e872a 100644 --- a/tests/python_tests/common.py +++ b/tests/python_tests/common.py @@ -535,7 +535,7 @@ def get_image_by_link(link): image = Image.open(requests.get(link, stream=True).raw) if image.mode != 'RGB': image = image.convert('RGB') - image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, 3, image.size[1], image.size[0]) + image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, image.size[1], image.size[0], 3) return Tensor(image_data) diff --git 
a/tests/python_tests/test_vlm_pipeline.py b/tests/python_tests/test_vlm_pipeline.py
index 0f9358b961..3c188b26b2 100644
--- a/tests/python_tests/test_vlm_pipeline.py
+++ b/tests/python_tests/test_vlm_pipeline.py
@@ -47,6 +47,8 @@ def get_ov_model(model_id, cache):
 @pytest.mark.parametrize("model_id", [
     "katuni4ka/tiny-random-minicpmv-2_6",
     "katuni4ka/tiny-random-phi3-vision",
+    "katuni4ka/tiny-random-llava",
+    "katuni4ka/tiny-random-qwen2vl",
 ])
 def test_vlm_pipeline(model_id, cache):
     def streamer(word: str) -> bool:

From 38ab05522fa661de3df32fd95296fe54f9252f4e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Mi=C5=82osz=20=C5=BBeglarski?=
Date: Thu, 30 Jan 2025 18:03:35 +0100
Subject: [PATCH 2/2] Parallel sampling with threadpool (#1252)

**This PR implements the same functionality as
https://github.com/openvinotoolkit/openvino.genai/pull/1233, but in a
different manner. Only one of them should be merged.**

Since the pipeline logic is executed on a single thread, there are periods of
low CPU usage while the pipeline is not running inference but executing other
logic, such as sampling, which can take a large fraction of the time.
Currently, after inference is done, we sample from each sequence group in a
loop on a single thread, which becomes an issue with sampling parameters that
significantly extend the sampling time for a single sequence group.

This PR extracts the sampling logic for a single sequence group into a
separate method that can be executed independently of any other sequence
group. It includes a generic thread pool implementation that spawns a
configurable number of threads used to run the sampling logic for different
sequence groups in parallel.

Performance measurements confirm the improvement, especially for non-greedy
sampling and at high concurrency (the more sequence groups are scheduled for
inference, the greater the benefit from parallel sampling).

CVS-157230
---
 .github/workflows/causal_lm_cpp.yml         |  14 +-
 src/cpp/src/continuous_batching_adapter.hpp |  57 ++--
 src/cpp/src/continuous_batching_impl.cpp    |  19 +-
 src/cpp/src/sampler.cpp                     | 313 +++++++++++---------
 src/cpp/src/sampler.hpp                     |  33 ++-
 src/cpp/src/threadpool.hpp                  |  70 +++++
 src/cpp/src/visual_language/pipeline.cpp    |   4 +-
 7 files changed, 336 insertions(+), 174 deletions(-)
 create mode 100644 src/cpp/src/threadpool.hpp

diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml
index 5dff0a58d3..a3e3e56312 100644
--- a/.github/workflows/causal_lm_cpp.yml
+++ b/.github/workflows/causal_lm_cpp.yml
@@ -466,9 +466,13 @@ jobs:
       - name: run and compare
         run: |
           source ./ov/setupvars.sh
+          echo Running speculative_decoding_lm C++ sample...
           ./build/samples/cpp/text_generation/speculative_decoding_lm ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_speculative.txt
+          echo Running greedy_causal_lm C++ sample...
           ./build/samples/cpp/text_generation/greedy_causal_lm ./dolly-v2-7b/ "Alan Turing was a" > predictions_greedy.txt
+          echo Running speculative_decoding_lm Python sample...
           python ./samples/python/text_generation/speculative_decoding_lm.py ./dolly-v2-7b/ ./dolly-v2-3b/ "Alan Turing was a" > predictions_py.txt
+          echo All samples executed, checking result correctness...
python -c " with open('predictions_greedy.txt', 'r') as f: predicted_greedy = f.readline() @@ -476,6 +480,8 @@ jobs: predicted_speculative = f.readline() with open('predictions_py.txt', 'r') as f: predicted_py = f.readline() + print(f'Predicted greedy: {predicted_greedy}') + print(f'Predicted speculative: {predicted_speculative}') assert predicted_greedy == predicted_speculative assert predicted_greedy == predicted_py assert predicted_speculative == predicted_py @@ -523,10 +529,13 @@ jobs: ``` Question: Can you please add 2 and 3 A:' > ./prompt.txt - + echo Running prompt_lookup_decoding_lm C++ sample... ./build/samples/cpp/text_generation/prompt_lookup_decoding_lm ./TinyLlama-1.1B-Chat-v1.0/ "$( predictions_prompt_lookup.txt + echo Running greedy_causal_lm C++ sample... ./build/samples/cpp/text_generation/greedy_causal_lm ./TinyLlama-1.1B-Chat-v1.0/ "$( predictions_greedy.txt + echo Running prompt_lookup_decoding_lm Python sample... python ./samples/python/text_generation/prompt_lookup_decoding_lm.py ./TinyLlama-1.1B-Chat-v1.0/ "$( predictions_py.txt + echo All samples executed, checking result correctness... python -c " with open('predictions_greedy.txt', 'r') as f: predicted_greedy = f.readline() @@ -534,6 +543,9 @@ jobs: predicted_prompt_lookup = f.readline() with open('predictions_py.txt', 'r') as f: predicted_prompt_lookup_py = f.readline() + + print(f'Predicted greedy: {predicted_greedy}') + print(f'Predicted prompt lookup: {predicted_prompt_lookup}') assert predicted_greedy == predicted_prompt_lookup assert predicted_greedy == predicted_prompt_lookup_py assert predicted_prompt_lookup == predicted_prompt_lookup_py diff --git a/src/cpp/src/continuous_batching_adapter.hpp b/src/cpp/src/continuous_batching_adapter.hpp index 00928b342d..c1ab881371 100644 --- a/src/cpp/src/continuous_batching_adapter.hpp +++ b/src/cpp/src/continuous_batching_adapter.hpp @@ -1,10 +1,10 @@ - // Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 #include "llm_pipeline_base.hpp" #include "openvino/genai/continuous_batching_pipeline.hpp" +#include namespace ov::genai { @@ -17,29 +17,27 @@ template struct overloaded : Ts... {using Ts::operator()...;}; template overloaded(Ts...) 
-> overloaded; class ContinuousBatchingAdapter final : public LLMPipelineImplBase { - ContinuousBatchingPipeline m_impl; + std::unique_ptr m_impl; public: ContinuousBatchingAdapter( const ov::InferRequest& request, const Tokenizer& tokenizer, OptionalGenerationConfig generation_config - ): LLMPipelineImplBase{dont_construct(), GenerationConfig{}}, - m_impl{std::filesystem::path{}, SchedulerConfig{}, std::string{}} { } - + ): LLMPipelineImplBase{dont_construct(), GenerationConfig{}}, + m_impl{std::make_unique(std::filesystem::path{}, SchedulerConfig{}, std::string{})} { } + ContinuousBatchingAdapter( const std::filesystem::path& models_path, const Tokenizer& tokenizer, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config - ): LLMPipelineImplBase{tokenizer, GenerationConfig()}, m_impl{ - models_path, - tokenizer, - scheduler_config, - device, - plugin_config} { - m_generation_config = m_impl.get_config(); - } + ): LLMPipelineImplBase{tokenizer, GenerationConfig()} { + auto mutable_plugin_config = plugin_config; + mutable_plugin_config["sampler_num_threads"] = 1; + m_impl = std::make_unique(models_path, tokenizer, scheduler_config, device, mutable_plugin_config); + m_generation_config = m_impl->get_config(); + } ContinuousBatchingAdapter( const std::string& model_str, @@ -49,27 +47,22 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { const std::string& device, const ov::AnyMap& plugin_config, const ov::genai::GenerationConfig& generation_config - ): LLMPipelineImplBase{tokenizer, GenerationConfig()}, m_impl{ - model_str, - weights_tensor, - tokenizer, - scheduler_config, - device, - plugin_config, - generation_config} {} + ): LLMPipelineImplBase{tokenizer, GenerationConfig()} { + auto mutable_plugin_config = plugin_config; + mutable_plugin_config["sampler_num_threads"] = 1; + m_impl = std::make_unique(model_str, weights_tensor, tokenizer, scheduler_config, device, mutable_plugin_config, generation_config); + } ContinuousBatchingAdapter( const std::filesystem::path& models_path, const SchedulerConfig& scheduler_config, const std::string& device, const ov::AnyMap& plugin_config - ): LLMPipelineImplBase{Tokenizer(models_path), GenerationConfig()}, m_impl{ - models_path, - m_tokenizer, - scheduler_config, - device, - plugin_config} { - m_generation_config = m_impl.get_config(); + ): LLMPipelineImplBase{Tokenizer(models_path), GenerationConfig()} { + auto mutable_plugin_config = plugin_config; + mutable_plugin_config["sampler_num_threads"] = 1; + m_impl = std::make_unique(models_path, m_tokenizer, scheduler_config, device, mutable_plugin_config); + m_generation_config = m_impl->get_config(); } DecodedResults generate( @@ -90,7 +83,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { }, inputs); const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config; // -1 == config.eos_token_id and config.validate() are handled in m_impl. - std::vector generated = m_impl.generate(prompts, + std::vector generated = m_impl->generate(prompts, std::vector{prompts.size(), config}, streamer ); @@ -181,7 +174,7 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { const GenerationConfig& config = generation_config.has_value() ? *generation_config : m_generation_config; // -1 == config.eos_token_id and config.validate() are handled in m_impl. 
- std::vector generated = m_impl.generate(input_ids, + std::vector generated = m_impl->generate(input_ids, std::vector{input_ids.size(), config}, streamer ); @@ -210,11 +203,11 @@ class ContinuousBatchingAdapter final : public LLMPipelineImplBase { } void start_chat(const std::string& system_message) override { - m_impl.start_chat(); + m_impl->start_chat(); }; void finish_chat() override { - m_impl.finish_chat(); + m_impl->finish_chat(); }; }; diff --git a/src/cpp/src/continuous_batching_impl.cpp b/src/cpp/src/continuous_batching_impl.cpp index f95cd3b9c6..095d7dc4e2 100644 --- a/src/cpp/src/continuous_batching_impl.cpp +++ b/src/cpp/src/continuous_batching_impl.cpp @@ -142,17 +142,25 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline( const std::vector& kv_cache_config) { ov::Core core = utils::singleton_core(); ov::CompiledModel compiled_model; + ov::AnyMap mutable_properties = properties; + // Extract sampler_num_threads property if exists and remove it from properties + size_t sampler_num_threads = std::thread::hardware_concurrency(); + auto sampler_num_threads_it = mutable_properties.find("sampler_num_threads"); + if (sampler_num_threads_it != mutable_properties.end()) { + sampler_num_threads = sampler_num_threads_it->second.as(); + mutable_properties.erase(sampler_num_threads_it); + } // TODO: remove once plugin automatically set KV cache precisions - apply_kv_cache_precision(model, device, properties); + apply_kv_cache_precision(model, device, mutable_properties); // apply LoRA - if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) { + if (auto filtered_properties = extract_adapters_from_properties(mutable_properties, &m_generation_config.adapters)) { m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model."); m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device); // TODO: Make the prefix name configurable compiled_model = core.compile_model(model, device, *filtered_properties); } else { - compiled_model = core.compile_model(model, device, properties); + compiled_model = core.compile_model(model, device, mutable_properties); } ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention"); @@ -227,7 +235,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline( std::make_shared(infer_request, m_block_size, m_num_decoder_layers); } - m_sampler = std::make_shared(m_tokenizer); + m_sampler = std::make_shared(m_tokenizer, sampler_num_threads); m_sampler->set_seed(m_generation_config.rng_seed); // If eos_token_id was not provided, take value @@ -282,8 +290,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { _pull_awaiting_requests(); - Scheduler::Output scheduler_output; + { static ManualTimer scheduling_timer("scheduling"); scheduling_timer.start(); @@ -318,6 +326,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::step() { return; } ov::Tensor logits; + { static ManualTimer timer("forward"); timer.start(); diff --git a/src/cpp/src/sampler.cpp b/src/cpp/src/sampler.cpp index 7a1e079746..fe3cc8239a 100644 --- a/src/cpp/src/sampler.cpp +++ b/src/cpp/src/sampler.cpp @@ -1,6 +1,7 @@ // Copyright (C) 2023-2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 +#include #include "sampler.hpp" namespace ov::genai { @@ -744,6 +745,144 @@ process_stop_strings(const std::set& stop_strings, Tokenizer& token return result; } +SequenceGroupSamplingInfo 
Sampler::sample_from_sequence_group(SequenceGroup::Ptr sequence_group, ov::Tensor sequence_group_logits, + LogitProcessor& logit_processor, const std::pair>& stop_strings, + bool is_validation_mode_enabled) { + SequenceGroupSamplingInfo sg_sampling_info; + // Assistant pipeline info is relevant for speculative and prompt lookup decoding + AssistingPipelineInfo& assisting_pipeline_info = sg_sampling_info.get_assisting_pipeline_info(); + const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters(); + const size_t output_seq_len = sequence_group->get_output_seq_len(); + // get number of tokens to be validated + size_t num_tokens_to_process = sequence_group->get_num_tokens_to_validate(); + + if (num_tokens_to_process > output_seq_len - 1) { + auto delta = num_tokens_to_process - (output_seq_len - 1); + assisting_pipeline_info.updated_validation_len = std::max(assisting_pipeline_info.updated_validation_len, delta); + num_tokens_to_process -= delta; + } + + if (sampling_params.is_greedy_decoding() || sampling_params.is_multinomial()) { + std::vector running_sequences = sequence_group->get_running_sequences(); + size_t num_running_sequences = sequence_group->num_running_seqs(); + if (sampling_params.is_greedy_decoding()) { + OPENVINO_ASSERT(num_running_sequences == 1); + } + for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) { + auto& running_sequence = running_sequences[running_sequence_id]; + bool is_validation_passed = true; + // make `num_tokens_to_process` iteration to validate a candidate generated by `draft_model` + 1 iteration to generate one more token by `main_model` + for (size_t i = 0; i <= num_tokens_to_process; ++i) { + sg_sampling_info.sampler_output.num_generated_tokens++; + // calculate token offset from the end of logit + size_t token_offset = num_tokens_to_process - i; + // max counter of needed to be sampled tokens + OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset); + size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset; + OPENVINO_ASSERT(sequence_group->get_max_new_tokens() >= generated_and_verified_len); + size_t max_num_sampled_token = sequence_group->get_max_new_tokens() - generated_and_verified_len; + if (max_num_sampled_token == 0) { + stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, assisting_pipeline_info.max_removed_tokens_per_request); + break; + } + + // do sampling only for token validation/generation. + // continue in case of extending draft model sequences by main model generated tokens which + // should be taken to KV cache without validation + if (!is_validation_mode_enabled && token_offset > 0) { + continue; + } + + auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id, token_offset); + logit_processor.apply(logit_vector); + + Token sampled_token; + bool is_generate_n_tokens = false; + if (sampling_params.is_greedy_decoding()) { + sampled_token = { _greedy_sample(logit_vector, sampling_params.logprobs) }; + } else { + // is_multinomial() + is_generate_n_tokens = sequence_group->num_total_seqs() == 1; + const size_t num_tokens_per_sequence = is_generate_n_tokens ? 
sampling_params.num_return_sequences : 1; + is_generate_n_tokens &= (num_tokens_per_sequence > 1); + auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence); + OPENVINO_ASSERT(sampled_token_ids.size(), num_tokens_per_sequence); + // to create n sequence just in case of `sequence_group->num_total_seqs() == 1` and `sampling_params.num_return_sequences > 1` + if (is_generate_n_tokens) { + const auto forked_seq_ids = create_n_forked_sequences(sequence_group, logit_processor, sampled_token_ids); + sg_sampling_info.sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids}); + } + sampled_token = sampled_token_ids.front(); + // make `_speculative_sampling` in case of previous token was not accepted in speculative decoding + if (!is_validation_passed) { + float p_prime = get_p_prime(running_sequence, sampled_token, token_offset + 1); + assisting_pipeline_info.max_removed_tokens_per_request = std::max(assisting_pipeline_info.max_removed_tokens_per_request, token_offset); + // update prob only in case candidate prob > sampled token prob + if (p_prime > 0.f) { + auto prob = std::exp(sampled_token.m_log_prob); + prob /= p_prime; + sampled_token.m_log_prob = std::log(prob); + } + } + } + // flag to add sampled token to generated sequence or extend logit processors only + bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens || !is_validation_passed; + if (is_validation_mode_enabled && !is_extend_sequence) { + is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token, + is_extend_sequence, assisting_pipeline_info.max_removed_tokens_per_request, sampling_params.do_sample); + // doing resample in case of non accepted tokens in specualtive sampling + if (!is_validation_passed && sampling_params.do_sample) { + continue; + } + // update log prob just while validation process + if (!is_extend_sequence) { + OPENVINO_ASSERT(generated_and_verified_len < running_sequences[running_sequence_id]->get_generated_len()); + running_sequence->update_generated_log_prob(generated_and_verified_len, sampled_token.m_log_prob); + } + } + register_new_token(sampled_token, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_validation_mode_enabled); + // to exit from sampling in case of failed token validation + if (!is_validation_passed) { + break; + } + } + assisting_pipeline_info.min_generated_len = std::min(assisting_pipeline_info.min_generated_len, running_sequence->get_generated_len()); + } + align_all_sequence_len(sequence_group, assisting_pipeline_info.min_generated_len, logit_processor); + for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) { + sg_sampling_info.sampler_output.m_dropped_sequences.push_back(dropped_seq_id); + } + } else if (sampling_params.is_beam_search()) { + uint64_t request_id = sequence_group->get_request_id(); + + // create beam search info if we are on the first generate + GroupBeamSearcher* beam_searcher; + { + std::lock_guard lock(m_beam_search_info_mutex); + if (m_beam_search_info.find(request_id) == m_beam_search_info.end()) { + m_beam_search_info.emplace(request_id, GroupBeamSearcher(sequence_group, m_tokenizer)); + } + beam_searcher = &m_beam_search_info.at(request_id); + } + + // current algorithm already adds new tokens to running sequences and + beam_searcher->select_next_tokens(sequence_group_logits, sg_sampling_info.sampler_output, stop_strings); + + // check max length stop criteria + std::vector running_sequences 
= sequence_group->get_running_sequences(); + if (!sequence_group->has_finished() && + running_sequences[0]->get_generated_len() == sequence_group->get_max_new_tokens()) { + // stop sequence by max_new_tokens + beam_searcher->finalize(sg_sampling_info.sampler_output); + } + } + // Notify handle after sampling is done. + // For non-streaming this is effective only when the generation is finished. + OPENVINO_ASSERT(num_tokens_to_process >= assisting_pipeline_info.max_removed_tokens_per_request); + sequence_group->notify_handle(); + return sg_sampling_info; +} + SamplerOutput Sampler::sample(const std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled) { @@ -753,13 +892,14 @@ SamplerOutput Sampler::sample(const std::vector & sequence_g size_t vocab_size = logits_shape[2]; SamplerOutput sampler_output; + std::unordered_map> sg_sampling_future_map; for (size_t sequence_group_id = 0, currently_processed_tokens = 0; sequence_group_id < sequence_groups.size(); ++sequence_group_id) { SequenceGroup::Ptr sequence_group = sequence_groups[sequence_group_id]; if (!sequence_group->is_scheduled()) continue; - size_t num_running_sequences = sequence_group->num_running_seqs(); - size_t output_seq_len = sequence_group->get_output_seq_len(); + const size_t num_running_sequences = sequence_group->num_running_seqs(); + const size_t output_seq_len = sequence_group->get_output_seq_len(); const ov::genai::GenerationConfig& sampling_params = sequence_group->get_sampling_parameters(); const auto request_id = sequence_group->get_request_id(); @@ -771,153 +911,62 @@ SamplerOutput Sampler::sample(const std::vector & sequence_g m_stop_strings.insert({request_id, processed_stop_string}); sequence_group->set_stream_window_size(processed_stop_string.first); } - auto& stop_strings = m_stop_strings.at(request_id); + const auto& stop_strings = m_stop_strings.at(request_id); auto& logit_processor = m_logit_processors.at(request_id); const void * sequence_group_logits_data = logits_data + vocab_size * currently_processed_tokens; ov::Tensor sequence_group_logits(ov::element::f32, ov::Shape{num_running_sequences, output_seq_len, vocab_size}, (void *)sequence_group_logits_data); - size_t max_removed_tokens_per_request = 0, min_generated_len = std::numeric_limits::max(), updated_validation_len = 0; if (sequence_group->requires_sampling()) { - // get number of token to be validated - auto num_tokens_to_process = sequence_group->get_num_tokens_to_validate(); - if (num_tokens_to_process > output_seq_len - 1) { - auto delta = num_tokens_to_process - (output_seq_len - 1); - updated_validation_len = std::max(updated_validation_len, delta); - num_tokens_to_process -= delta; - } - if (sampling_params.is_greedy_decoding() || sampling_params.is_multinomial()) { - std::vector running_sequences = sequence_group->get_running_sequences(); - if (sampling_params.is_greedy_decoding()) { - OPENVINO_ASSERT(num_running_sequences == 1); - } - for (size_t running_sequence_id = 0; running_sequence_id < num_running_sequences; ++running_sequence_id) { - auto& running_sequence = running_sequences[running_sequence_id]; - bool is_validation_passed = true; - // make `num_tokens_to_process` iteration to validate a candidate generated by `draft_model` + 1 iteration to generate one more token by `main_model` - for (size_t i = 0; i <= num_tokens_to_process; ++i) { - sampler_output.num_generated_tokens++; - // calculate token offset from the end of logit - size_t token_offset = num_tokens_to_process - i; - // max counter of needed to be 
sampled tokens - OPENVINO_ASSERT(running_sequence->get_generated_len() >= token_offset); - size_t generated_and_verified_len = running_sequence->get_generated_len() - token_offset; - OPENVINO_ASSERT(sequence_group->get_max_new_tokens() >= generated_and_verified_len); - size_t max_num_sampled_token = sequence_group->get_max_new_tokens() - generated_and_verified_len; - if (max_num_sampled_token == 0) { - stop_sample_tokens(running_sequence, token_offset, max_num_sampled_token, max_removed_tokens_per_request); - break; - } - - // do sampling only for token validation/generation. - // continue in case of extending draft model sequences by main model generated tokens which - // should be taken to KV cache without validation - if (!is_validation_mode_enabled && token_offset > 0) { - continue; - } - - auto logit_vector = _get_logit_vector(sequence_group_logits, running_sequence_id, token_offset); - logit_processor.apply(logit_vector); - - Token sampled_token; - bool is_generate_n_tokens = false; - if (sampling_params.is_greedy_decoding()) { - sampled_token = { _greedy_sample(logit_vector, sampling_params.logprobs) }; - } else { - // is_multinomial() - is_generate_n_tokens = sequence_group->num_total_seqs() == 1; - const size_t num_tokens_per_sequence = is_generate_n_tokens ? sampling_params.num_return_sequences : 1; - is_generate_n_tokens &= (num_tokens_per_sequence > 1); - auto sampled_token_ids = _multinomial_sample(logit_vector, num_tokens_per_sequence); - OPENVINO_ASSERT(sampled_token_ids.size(), num_tokens_per_sequence); - // to create n sequence just in case of `sequence_group->num_total_seqs() == 1` and `sampling_params.num_return_sequences > 1` - if (is_generate_n_tokens) { - const auto forked_seq_ids = create_n_forked_sequences(sequence_group, logit_processor, sampled_token_ids); - sampler_output.m_forked_sequences.insert({running_sequences[0]->get_id(), forked_seq_ids}); - } - sampled_token = sampled_token_ids.front(); - // make `_speculative_sampling` in case of previous token was not accepted in speculative decoding - if (!is_validation_passed) { - float p_prime = get_p_prime(running_sequence, sampled_token, token_offset + 1); - max_removed_tokens_per_request = std::max(max_removed_tokens_per_request, token_offset); - // update prob only in case candidate prob > sampled token prob - if (p_prime > 0.f) { - auto prob = std::exp(sampled_token.m_log_prob); - prob /= p_prime; - sampled_token.m_log_prob = std::log(prob); - } - } - } - // flag to add sampled token to generated sequence or extend logit processors only - bool is_extend_sequence = token_offset == 0 || is_generate_n_tokens || !is_validation_passed; - if (is_validation_mode_enabled && !is_extend_sequence) { - is_validation_passed = validate_candidate(running_sequences[running_sequence_id], token_offset, sampled_token, - is_extend_sequence, max_removed_tokens_per_request, sampling_params.do_sample); - // doing resample in case of non accepted tokens in specualtive sampling - if (!is_validation_passed && sampling_params.do_sample) { - continue; - } - // update log prob just while validation process - if (!is_extend_sequence) { - OPENVINO_ASSERT(generated_and_verified_len < running_sequences[running_sequence_id]->get_generated_len()); - running_sequence->update_generated_log_prob(generated_and_verified_len, sampled_token.m_log_prob); - } - } - register_new_token(sampled_token, running_sequences[running_sequence_id], logit_processor, is_extend_sequence, is_validation_mode_enabled); - // to exit from sampling in case of failed token 
validation - if (!is_validation_passed) { - break; - } - } - min_generated_len = std::min(min_generated_len, running_sequence->get_generated_len()); - } - align_all_sequence_len(sequence_group, min_generated_len, logit_processor); - for (const auto& dropped_seq_id : _try_finish_generation(sequence_group)) { - sampler_output.m_dropped_sequences.push_back(dropped_seq_id); - } - } else if (sampling_params.is_beam_search()) { - uint64_t request_id = sequence_group->get_request_id(); - - // create beam search info if we are on the first generate - if (m_beam_search_info.find(request_id) == m_beam_search_info.end()) { - m_beam_search_info.emplace(request_id, GroupBeamSearcher(sequence_group, m_tokenizer)); - } - - // current algorithm already adds new tokens to running sequences and - m_beam_search_info.at(request_id).select_next_tokens(sequence_group_logits, sampler_output, stop_strings); - - // check max length stop criteria - std::vector running_sequences = sequence_group->get_running_sequences(); - if (!sequence_group->has_finished() && - running_sequences[0]->get_generated_len() == sequence_group->get_max_new_tokens()) { - // stop sequence by max_new_tokens - m_beam_search_info.at(request_id).finalize(sampler_output); - } - } - // Notify handle after sampling is done. - // For non-streaming this is effective only when the generation is finished. - OPENVINO_ASSERT(num_tokens_to_process >= max_removed_tokens_per_request); - sequence_group->notify_handle(); + // Call sample_from_sequence_group asynchronously + sg_sampling_future_map[request_id] = m_thread_pool.submit(&Sampler::sample_from_sequence_group, this, sequence_group, sequence_group_logits, + logit_processor, stop_strings, is_validation_mode_enabled); } else { // we are in prompt processing phase when prompt is split into chunks and processed step by step } + // accumulate a number of processed tokens + currently_processed_tokens += output_seq_len * num_running_sequences; + } + // Update sequence groups internal states after sampling is done + for (auto& sequence_group : sequence_groups) { + if (!sequence_group->is_scheduled()) + continue; + SequenceGroupSamplingInfo sg_sampling_info; + const auto request_id = sequence_group->get_request_id(); + if (sg_sampling_future_map.find(request_id) != sg_sampling_future_map.end()) { + // If there is a future assigned to a sequence group we read it's result (blocking if results not available yet) + sg_sampling_info = sg_sampling_future_map[request_id].get(); + sampler_output.num_generated_tokens += sg_sampling_info.sampler_output.num_generated_tokens; + + // Merge sampler output from sequence group to the main one + sampler_output.m_dropped_sequences.insert( + sampler_output.m_dropped_sequences.end(), + sg_sampling_info.sampler_output.m_dropped_sequences.begin(), + sg_sampling_info.sampler_output.m_dropped_sequences.end() + ); + + for (const auto& forked_seq : sg_sampling_info.sampler_output.m_forked_sequences) { + sampler_output.m_forked_sequences[forked_seq.first].insert( + sampler_output.m_forked_sequences[forked_seq.first].end(), + forked_seq.second.begin(), + forked_seq.second.end() + ); + } + } // NOTE: it should be before 'get_num_scheduled_tokens' is used // update internal state of sequence group to reset scheduler tokens and update currently processed ones - auto min_validated_tokens = sequence_group->get_num_tokens_to_validate() - max_removed_tokens_per_request; + const AssistingPipelineInfo& assisting_pipeline_info = std::as_const(sg_sampling_info.get_assisting_pipeline_info()); 
sequence_group->finish_iteration(); // decrease sequence_group context in case of candidates generated by draft_model were not accepted by main_model - if (max_removed_tokens_per_request) { - auto min_processed_tokens = sequence_group->get_prompt_len() + min_generated_len - 1; + if (assisting_pipeline_info.max_removed_tokens_per_request) { + auto min_processed_tokens = sequence_group->get_prompt_len() + assisting_pipeline_info.min_generated_len - 1; sequence_group->update_processed_tokens_num(min_processed_tokens); + auto& logit_processor = get_logit_processor(sequence_group->get_request_id()); logit_processor.update_generated_len(min_processed_tokens); } - if (updated_validation_len) { - sequence_group->set_num_validated_tokens(updated_validation_len); + if (assisting_pipeline_info.updated_validation_len) { + sequence_group->set_num_validated_tokens(assisting_pipeline_info.updated_validation_len); } - - // accumulate a number of processed tokens - currently_processed_tokens += output_seq_len * num_running_sequences; } - return sampler_output; } diff --git a/src/cpp/src/sampler.hpp b/src/cpp/src/sampler.hpp index 9768e0a7af..c53676d23c 100644 --- a/src/cpp/src/sampler.hpp +++ b/src/cpp/src/sampler.hpp @@ -19,6 +19,7 @@ #include "logit_processor.hpp" #include "scheduler.hpp" #include "sequence_group.hpp" +#include "threadpool.hpp" namespace ov::genai { // Handle stop_token_ids @@ -42,6 +43,21 @@ struct SamplerOutput { size_t num_generated_tokens = 0; }; +struct AssistingPipelineInfo { + size_t max_removed_tokens_per_request = 0; + size_t min_generated_len = std::numeric_limits::max(); + size_t updated_validation_len = 0; +}; + +struct SequenceGroupSamplingInfo { + SamplerOutput sampler_output; + AssistingPipelineInfo assisting_pipeline_info; + + AssistingPipelineInfo& get_assisting_pipeline_info() { + return assisting_pipeline_info; + } +}; + class Sampler { class GroupBeamSearcher; @@ -53,8 +69,13 @@ class Sampler { bool validate_candidate(Sequence::Ptr running_sequence, size_t& token_idx, Token& sampled_token, bool& is_extend_sequence, size_t& max_removed_tokens, bool do_sample); + SequenceGroupSamplingInfo sample_from_sequence_group(SequenceGroup::Ptr sequence_group, ov::Tensor sequence_group_logits, + LogitProcessor& logit_processor, const std::pair>& stop_strings, + bool is_validation_mode_enabled); + // request ID => beam search tracking information std::map m_beam_search_info; + std::mutex m_beam_search_info_mutex; std::mt19937 rng_engine; size_t seed = rng_engine.default_seed; @@ -65,9 +86,13 @@ class Sampler { Tokenizer m_tokenizer; + ThreadPool m_thread_pool; + public: - Sampler() = default; - Sampler(Tokenizer & tokenizer) : m_tokenizer(tokenizer) {}; + Sampler(const Sampler& rhs) = delete; + Sampler(Sampler&& rhs) = delete; + Sampler(size_t num_threads = 1): m_thread_pool(num_threads) {}; + explicit Sampler(const Tokenizer & tokenizer, size_t num_threads = 1) : m_tokenizer(tokenizer), m_thread_pool(num_threads) {}; SamplerOutput sample(const std::vector & sequence_groups, ov::Tensor logits, bool is_validation_mode_enabled = false); void set_seed(size_t new_seed) { @@ -76,6 +101,10 @@ class Sampler { } size_t get_seed() { return seed; } + void set_tokenizer(const Tokenizer& tokenizer) { + m_tokenizer = tokenizer; + } + void clear_request_info(uint64_t request_id); LogitProcessor& get_logit_processor(uint64_t request_id); diff --git a/src/cpp/src/threadpool.hpp b/src/cpp/src/threadpool.hpp new file mode 100644 index 0000000000..a5576c2d6c --- /dev/null +++ 
b/src/cpp/src/threadpool.hpp
@@ -0,0 +1,70 @@
+#include <atomic>
+#include <condition_variable>
+#include <functional>
+#include <future>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <thread>
+#include <vector>
+
+class ThreadPool {
+
+private:
+    std::vector<std::thread> threads;
+    std::queue<std::function<void()>> tasks;
+    std::mutex queue_mutex;
+    std::condition_variable cv;
+    std::atomic<bool> stop{false};
+
+public:
+    ThreadPool(const ThreadPool& rhs) = delete;
+    ThreadPool(ThreadPool&& rhs) = delete;
+    ThreadPool(size_t num_threads = std::thread::hardware_concurrency())
+    {
+        for (size_t i = 0; i < num_threads; ++i) {
+            threads.emplace_back([this] {
+                while (true) {
+                    std::function<void()> task;
+                    {
+                        std::unique_lock<std::mutex> lock(queue_mutex);
+                        cv.wait(lock, [this] {
+                            return !tasks.empty() || stop;
+                        });
+                        if (stop && tasks.empty()) {
+                            return;
+                        }
+                        task = move(tasks.front());
+                        tasks.pop();
+                    }
+                    task();
+                }
+            });
+        }
+    }
+
+    ~ThreadPool()
+    {
+        stop = true;
+        cv.notify_all();
+        for (auto& thread : threads) {
+            thread.join();
+        }
+    }
+
+    template <class F, class... Args>
+    auto submit(F&& f, Args&&... args) -> std::future<std::invoke_result_t<F, Args...>>
+    {
+        using return_type = std::invoke_result_t<F, Args...>;
+        auto task = std::make_shared<std::packaged_task<return_type()>>(
+            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
+        );
+        std::future<return_type> result = task->get_future();
+        {
+            std::unique_lock<std::mutex> lock(queue_mutex);
+            tasks.emplace([task]() { (*task)(); });
+        }
+        cv.notify_one();
+        return result;
+    }
+};
diff --git a/src/cpp/src/visual_language/pipeline.cpp b/src/cpp/src/visual_language/pipeline.cpp
index a3f9859384..aefb765874 100644
--- a/src/cpp/src/visual_language/pipeline.cpp
+++ b/src/cpp/src/visual_language/pipeline.cpp
@@ -108,7 +108,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
 
-        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_tokenizer(m_tokenizer);
         m_sampler.set_seed(m_generation_config.rng_seed);
     }
@@ -146,7 +146,7 @@ class ov::genai::VLMPipeline::VLMPipelineImpl {
             m_generation_config.set_eos_token_id(m_tokenizer.get_eos_token_id());
         }
 
-        m_sampler = Sampler(m_tokenizer);
+        m_sampler.set_tokenizer(m_tokenizer);
         m_sampler.set_seed(m_generation_config.rng_seed);
     }
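
Usage sketch (not part of the patch): the ThreadPool introduced in src/cpp/src/threadpool.hpp exposes a single submit() entry point that wraps a callable in a std::packaged_task and returns a std::future for its result; Sampler::sample uses it to run sample_from_sequence_group for each scheduled sequence group concurrently and then blocks on the futures while merging their outputs. A minimal standalone example follows, assuming only that threadpool.hpp is on the include path; the pool size and the square() task are illustrative, not taken from the patch.

// Hypothetical example exercising ThreadPool::submit from the patch above.
#include "threadpool.hpp"
#include <cstdio>

static int square(int x) { return x * x; }

int main() {
    ThreadPool pool(4);                    // spawn 4 worker threads
    auto result = pool.submit(square, 7);  // returns std::future<int>
    std::printf("%d\n", result.get());     // blocks until the task finished, prints 49
    return 0;
}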