diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp
index e912570f20..d67c18817f 100644
--- a/src/cpp/src/visual_language/inputs_embedder.cpp
+++ b/src/cpp/src/visual_language/inputs_embedder.cpp
@@ -1428,15 +1428,15 @@ std::vector<ov::Tensor> split_tokenize(const std::string& text, ov::genai::Tokenizer& tokenizer) {
     return tokenized;
 }
 
-ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, size_t tokens_per_image) {
+ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, const std::vector<size_t>& tokens_per_images) {
     size_t merged_length = 0;
     for (const ov::Tensor& chunk : chunks) {
         merged_length += chunk.get_shape().at(1);
     }
-    merged_length += chunks.empty() ? 0 : (chunks.size() - 1) * tokens_per_image;
+    merged_length += std::accumulate(tokens_per_images.begin(), tokens_per_images.end(), 0);
     ov::Tensor merged{ov::element::i64, {1, merged_length}};
     size_t offset = 0;
-    int64_t image_id = -1;
+    int64_t image_id = 0;
     for (const ov::Tensor& chunk : chunks) {
         size_t length = chunk.get_shape().at(1);
         std::copy_n(
@@ -1448,11 +1448,11 @@ ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, size_t tokens_per_image) {
         if (offset < merged_length) {
             std::fill_n(
                 merged.data<int64_t>() + offset,
-                tokens_per_image,
-                image_id
+                tokens_per_images.at(image_id),
+                -image_id - 1 // It could be just -image_id. -1 is for consistency with the original implementation.
             );
-            offset += tokens_per_image;
-            --image_id;
+            offset += tokens_per_images.at(image_id);
+            ++image_id;
         }
     }
     return merged;
 }
@@ -1481,9 +1481,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
 public:
     ov::InferRequest m_hd_feature_transformer;
     ov::InferRequest m_vision_projection;
-    // Used to insert <|image_i|>\n per image (not a slice).
-    size_t m_image_id = 1;
-    size_t m_tokens_per_image = 0;
+    std::vector<size_t> m_tokens_per_images;
 
     InputsEmbedderPhi3V(
         const VLMConfig& vlm_config,
@@ -1491,7 +1489,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         const std::string& device,
         const ov::AnyMap device_config
     ):
-        IInputsEmbedder(vlm_config, model_dir, device, device_config), m_image_id{0},
+        IInputsEmbedder(vlm_config, model_dir, device, device_config),
         m_hd_feature_transformer{phi3_v::create_hd_feature_transformer()},
         m_vision_projection{utils::singleton_core().compile_model(model_dir / "openvino_vision_projection_model.xml", device, {}).create_infer_request()} {}
 
@@ -1502,8 +1500,8 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         for (const ov::Tensor& image : to_single_image_tensors(images)) {
             EncodedImage encoded_image = m_vision_encoder.encode(image);
             images_features_proj.push_back(phi3_v::hd_feature_transform(encoded_image, m_hd_feature_transformer, m_vlm_config.sub_GN, m_vlm_config.glb_GN, m_vision_projection));
-            images_prompt << "<|image_" << m_image_id << "|>\n";
-            ++m_image_id;
+            m_tokens_per_images.push_back(images_features_proj.back().get_shape().at(1));
+            images_prompt << "<|image_" << m_tokens_per_images.size() << "|>\n";
         }
         images_prompt << prompt;
         std::vector<ov::Tensor> new_chat_tokens;
         if (m_is_chat_conversation) {
             m_history.push_back({{"role", "user"}, {"content", images_prompt.str()}});
             constexpr bool add_generation_prompt = true;
-            std::string new_templated_chat_history;
-            new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
+            std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
             auto start_tokenizer_time = std::chrono::steady_clock::now();
             new_chat_tokens = phi3_v::split_tokenize(new_templated_chat_history, m_tokenizer);
             prev_chat_tokens = phi3_v::split_tokenize(m_templated_chat_history, m_tokenizer);
@@ -1525,11 +1522,8 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
             auto end_tokenizer_time = std::chrono::steady_clock::now();
             metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
         }
-        if (0 == m_tokens_per_image && !images_features_proj.empty()) {
-            m_tokens_per_image = images_features_proj.at(0).get_shape().at(1);
-        }
-        ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_image);
-        ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_image);
+        ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_images);
+        ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_images);
         ov::Tensor new_tokens = update_history(new_merged_tokens, prev_merged_tokens);
         std::vector<ov::Tensor> tokens = phi3_v::drop_image_placeholders(new_tokens);
         OPENVINO_ASSERT(tokens.size() == images_features_proj.size() + 1);
@@ -1537,7 +1531,6 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         for (size_t im_id = 0; im_id < images_features_proj.size(); ++im_id) {
             size_t text_length = tokens.at(im_id).get_shape().at(1);
             size_t im_length = images_features_proj.at(im_id).get_shape().at(1);
-            OPENVINO_ASSERT(im_length == m_tokens_per_image);
             features_length += text_length + im_length;
         }
         features_length += tokens.back().get_shape().at(1);
@@ -1570,7 +1563,7 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
         );
 
         if (!m_is_chat_conversation) {
-            m_image_id = 0;
+            m_tokens_per_images.clear();
         }
 
         return inputs_embeds;
@@ -1578,12 +1571,12 @@ class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
 
     virtual void start_chat(const std::string& system_message) override {
         IInputsEmbedder::start_chat(system_message);
-        m_image_id = 0;
+        m_tokens_per_images.clear();
     }
 
     virtual void finish_chat() override {
         IInputsEmbedder::finish_chat();
-        m_image_id = 0;
+        m_tokens_per_images.clear();
     }
 };
 
@@ -1662,10 +1655,6 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
         ov::Tensor input_ids = get_encoded_input_ids(formatted_prompt, metrics, chat_template_fallback);
         ov::Tensor text_embeds = m_embedding.infer(input_ids);
 
-        if (images.empty()) {
-            return text_embeds;
-        }
-
         auto start_tokenizer_time = std::chrono::steady_clock::now();
         ov::Tensor encoded_vision_start_token = m_tokenizer.encode(m_vlm_config.vision_start_token, ov::genai::add_special_tokens(false)).input_ids;
         ov::Tensor encoded_image_pad_token = m_tokenizer.encode(m_vlm_config.image_pad_token, ov::genai::add_special_tokens(false)).input_ids;
@@ -1680,6 +1669,10 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
         int64_t position_ids_max_element = *std::max_element(m_position_ids.data<int64_t>(), m_position_ids.data<int64_t>() + m_position_ids.get_size());
         m_rope_delta = position_ids_max_element + 1 - static_cast<int64_t>(input_ids.get_shape().at(1));
 
+        if (images.empty()) {
+            return text_embeds;
+        }
+
         return merge_text_and_image_embeddings_qwen2vl(input_ids, text_embeds, image_embeds, images_grid_thw, image_pad_token_id);
     }
 
@@ -1874,7 +1867,7 @@ class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
         }
 
         // Calculate rotary embeddings for max_grid_size
-        const size_t dim = 1280 / 16 / 2; // config.vision_config.embed_dim / self.config.vision_config.num_heads / 2
+        const size_t dim = m_vision_embeddings_merger.get_tensor("rotary_pos_emb").get_shape().at(1);
         const float theta = 10000.0f;
         std::vector<float> inv_freq(dim / 2);
diff --git a/src/cpp/src/visual_language/vision_encoder.cpp b/src/cpp/src/visual_language/vision_encoder.cpp
index 04ddd63145..e8edd40890 100644
--- a/src/cpp/src/visual_language/vision_encoder.cpp
+++ b/src/cpp/src/visual_language/vision_encoder.cpp
@@ -843,7 +843,7 @@ std::tuple<ov::Tensor, ImageSize> get_pixel_values_phi3_v(const ov::Tensor& image
 ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) {
     if (height < factor || width < factor) {
-        OPENVINO_THROW("Height or width must be larger than factor");
+        OPENVINO_THROW("Height (" + std::to_string(height) + ") and width (" + std::to_string(width) + ") must be greater than factor (" + std::to_string(factor) + ")");
     }
     if (std::max(height, width) / std::min(height, width) > 200) {
         OPENVINO_THROW("Absolute aspect ratio must be smaller than 200");
diff --git a/tests/python_tests/common.py b/tests/python_tests/common.py
index 320f1e1a6a..88690e872a 100644
--- a/tests/python_tests/common.py
+++ b/tests/python_tests/common.py
@@ -535,7 +535,7 @@ def get_image_by_link(link):
     image = Image.open(requests.get(link, stream=True).raw)
     if image.mode != 'RGB':
         image = image.convert('RGB')
-    image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, 3, image.size[1], image.size[0])
+    image_data = np.array((np.array(image.getdata()) - 128).astype(np.byte)).reshape(1, image.size[1], image.size[0], 3)
     return Tensor(image_data)
 
 
diff --git a/tests/python_tests/test_vlm_pipeline.py b/tests/python_tests/test_vlm_pipeline.py
index 0f9358b961..3c188b26b2 100644
--- a/tests/python_tests/test_vlm_pipeline.py
+++ b/tests/python_tests/test_vlm_pipeline.py
@@ -47,6 +47,8 @@ def get_ov_model(model_id, cache):
 @pytest.mark.parametrize("model_id", [
     "katuni4ka/tiny-random-minicpmv-2_6",
     "katuni4ka/tiny-random-phi3-vision",
+    "katuni4ka/tiny-random-llava",
+    "katuni4ka/tiny-random-qwen2vl",
 ])
 def test_vlm_pipeline(model_id, cache):
     def streamer(word: str) -> bool:
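
Reviewer note: the placeholder scheme changed by the inputs_embedder.cpp hunks is easiest to see outside of OpenVINO types. Below is a minimal, self-contained sketch of the same logic, assuming plain std::vector<int64_t> in place of the {1, n} i64 tensors; the *_sketch helper names are hypothetical and not part of the patch. Image i contributes tokens_per_images.at(i) copies of the marker -i - 1 between consecutive text chunks, and splitting on maximal runs of equal negative values recovers the text chunks, which is why the patched code can assert tokens.size() == images_features_proj.size() + 1.

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch only: the real insert_image_placeholders works on ov::Tensor.
// Text chunks hold non-negative token ids; image i is encoded as
// tokens_per_images[i] copies of -i - 1 (negative, so it can never
// collide with a real token id).
std::vector<int64_t> insert_image_placeholders_sketch(
    const std::vector<std::vector<int64_t>>& chunks,
    const std::vector<size_t>& tokens_per_images
) {
    std::vector<int64_t> merged;
    int64_t image_id = 0;
    for (size_t i = 0; i < chunks.size(); ++i) {
        merged.insert(merged.end(), chunks[i].begin(), chunks[i].end());
        if (i + 1 < chunks.size()) {  // one placeholder run between each pair of chunks
            merged.insert(merged.end(), tokens_per_images.at(image_id), -image_id - 1);
            ++image_id;
        }
    }
    return merged;
}

// Inverse (the role drop_image_placeholders plays in the patch): split the
// merged sequence back into text-only chunks by dropping each maximal run
// of equal negative markers.
std::vector<std::vector<int64_t>> drop_image_placeholders_sketch(const std::vector<int64_t>& merged) {
    std::vector<std::vector<int64_t>> chunks(1);
    for (size_t i = 0; i < merged.size(); ++i) {
        if (merged[i] >= 0) {
            chunks.back().push_back(merged[i]);
        } else if (i + 1 == merged.size() || merged[i + 1] != merged[i]) {
            chunks.emplace_back();  // end of a placeholder run: start the next text chunk
        }
    }
    return chunks;
}

int main() {
    // Two images with different feature lengths, three text chunks around them.
    std::vector<std::vector<int64_t>> chunks{{1, 2}, {3}, {4, 5, 6}};
    std::vector<size_t> tokens_per_images{3, 2};
    std::vector<int64_t> merged = insert_image_placeholders_sketch(chunks, tokens_per_images);
    assert((merged == std::vector<int64_t>{1, 2, -1, -1, -1, 3, -2, -2, 4, 5, 6}));
    assert(drop_image_placeholders_sketch(merged) == chunks);  // round-trips: chunks == images + 1
}

Giving each image a distinct marker value keeps adjacent placeholder runs separable even when neighboring images have different feature lengths; that is exactly what replacing the single tokens_per_image with the per-image tokens_per_images vector enables.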