diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml
index 57776be64b..e444443ea7 100644
--- a/.github/workflows/mac.yml
+++ b/.github/workflows/mac.yml
@@ -17,7 +17,7 @@ concurrency:
env:
PYTHON_VERSION: '3.10'
- OV_BRANCH: 'master'
+ OV_BRANCH: 7f56fcd4658c6a427111ac835e809ddd87f0cad2
OV_TARBALL: ''
jobs:
diff --git a/SUPPORTED_MODELS.md b/SUPPORTED_MODELS.md
index f79234489d..3064fb58c1 100644
--- a/SUPPORTED_MODELS.md
+++ b/SUPPORTED_MODELS.md
@@ -312,6 +312,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
Models |
LoRA support |
Example HuggingFace Models |
+ Notes |
InternVL2 |
@@ -329,6 +330,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
OpenGVLab/InternVL2_5-8B
+ |
LLaVA |
@@ -339,6 +341,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
llava-hf/llava-1.5-7b-hf
+ |
LLaVA-NeXT |
@@ -351,6 +354,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
llava-hf/llama3-llava-next-8b-hf
+ |
MiniCPMV |
@@ -361,6 +365,22 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
openbmb/MiniCPM-V-2_6
+ |
+
+
+      Phi3VForCausalLM |
+      phi3_v |
+      Not supported |
+       |
+      GPU isn't supported.
+      These models' configs aren't consistent. It's required to override the default eos_token_id with the one from the tokenizer: generation_config.set_eos_token_id(pipe.get_tokenizer().get_eos_token_id()). |
Qwen2-VL |
@@ -372,6 +392,7 @@ In addition to image generation models, `InpaintingPipeline` supports specialize
Qwen/Qwen2-VL-7B-Instruct
+ |
diff --git a/src/cpp/src/visual_language/clip.cpp b/src/cpp/src/visual_language/clip.cpp
index 30a6dff5ae..9347f63074 100644
--- a/src/cpp/src/visual_language/clip.cpp
+++ b/src/cpp/src/visual_language/clip.cpp
@@ -12,7 +12,7 @@ static float clip_lerp(float s, float e, float t) {
}
// Bilinear resize function
-static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
dst.nx = target_width;
dst.ny = target_height;
dst.buf.resize(3 * target_width * target_height);
diff --git a/src/cpp/src/visual_language/clip.hpp b/src/cpp/src/visual_language/clip.hpp
index 4bdb4542d0..e00ac2fc40 100644
--- a/src/cpp/src/visual_language/clip.hpp
+++ b/src/cpp/src/visual_language/clip.hpp
@@ -31,6 +31,7 @@ struct clip_image_f32 {
};
void bicubic_resize(const clip_image_u8& img, clip_image_u8& dst, int target_width, int target_height);
+void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height);
/** preprocess img and store the result in res_imgs, pad_to_square may be overridden to false depending on model configuration */
clip_image_f32 clip_image_preprocess(struct clip_ctx& ctx, const clip_image_u8& img);
diff --git a/src/cpp/src/visual_language/inputs_embedder.cpp b/src/cpp/src/visual_language/inputs_embedder.cpp
index 4f3812862c..66b17e5804 100644
--- a/src/cpp/src/visual_language/inputs_embedder.cpp
+++ b/src/cpp/src/visual_language/inputs_embedder.cpp
@@ -7,15 +7,10 @@
#include "visual_language/clip.hpp"
#include "visual_language/vision_encoder.hpp"
#include "visual_language/embedding_model.hpp"
+#include "openvino/opsets/opset13.hpp"
#include "utils.hpp"
-
-
-namespace {
-
-constexpr size_t BATCH_SIZE = 1;
-
-} // namespace
+#include <regex>
namespace ov::genai {
@@ -155,17 +150,8 @@ class InputsEmbedder::IInputsEmbedder {
),
m_tokenizer(tokenizer) { }
- ov::Tensor get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics, const std::string& chat_template_fallback = {}) {
- ov::Tensor encoded_input_ids;
+ std::pair<ov::Tensor, ov::Tensor> apply_chat_template_tokenize(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics, const std::string& chat_template_fallback = {}) {
if (m_is_chat_conversation) {
- // KV cache in model already contains prompts and answers from previous iterations.
- // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
- // token_ids = {<bos token>, ...}. So if tokenizer applies only to the new prompt,
- // <bos token> will be inserted on every iteration.
- // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt
- // and takes only the difference between them.
- // The chat history cannot be saved as already encoded tokens because generate call doesn't return <eos> token, but
- // KV cache contains it. So we have to add it manually or get it by tokenization all chat history.
m_history.push_back({{"role", "user"}, {"content", prompt}});
constexpr bool add_generation_prompt = true;
std::string new_templated_chat_history;
@@ -177,9 +163,31 @@ class InputsEmbedder::IInputsEmbedder {
}
auto start_tokenizer_time = std::chrono::steady_clock::now();
ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history, ov::genai::add_special_tokens(false)).input_ids;
- TokenizedInputs prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false));
+ ov::Tensor prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(false)).input_ids;
+ auto end_tokenizer_time = std::chrono::steady_clock::now();
+ metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
+ m_templated_chat_history = std::move(new_templated_chat_history);
+ return {new_chat_tokens, prev_chat_tokens};
+ } else {
+ auto start_tokenizer_time = std::chrono::steady_clock::now();
+ ov::Tensor encoded_input_ids = m_tokenizer.encode(prompt).input_ids;
auto end_tokenizer_time = std::chrono::steady_clock::now();
metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
+ return {encoded_input_ids, ov::Tensor()};
+ }
+ }
+
+ ov::Tensor update_history(const ov::Tensor& new_chat_tokens, const ov::Tensor& prev_chat_tokens) {
+ if (m_is_chat_conversation) {
+ ov::Tensor encoded_input_ids;
+ // KV cache in model already contains prompts and answers from previous iterations.
+ // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns
+ // token_ids = {<bos token>, ...}. So if tokenizer applies only to the new prompt,
+ // <bos token> will be inserted on every iteration.
+ // So actual pipeline calculates input_ids for whole chat history + for whole chat history without the new prompt
+ // and takes only the difference between them.
+ // The chat history cannot be saved as already encoded tokens because generate call doesn't return <eos> token, but
+ // KV cache contains it. So we have to add it manually or get it by tokenization all chat history.
// some combinations of symbols can be encoded by the tokenizer in different ways;
// if we meet a sequence with such a combination of symbols, we cannot correctly subtract the new history from the old history
@@ -187,7 +195,7 @@ class InputsEmbedder::IInputsEmbedder {
size_t trusted_history_length = 0;
if (!m_tokenized_history.empty()) {
std::set<int64_t> stop_tokens = {m_tokenizer.get_eos_token_id()};
- trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens.input_ids, m_tokenized_history, stop_tokens);
+ trusted_history_length = ov::genai::utils::get_first_history_difference(prev_chat_tokens, m_tokenized_history, stop_tokens);
}
if (m_tokenized_history.empty()) {
@@ -213,25 +221,25 @@ class InputsEmbedder::IInputsEmbedder {
new_tensor.copy_to(encoded_input_ids);
} else {
encoded_input_ids = utils::subtract_chat_tokenized_inputs(
- {new_chat_tokens}, prev_chat_tokens
+ {new_chat_tokens}, {prev_chat_tokens}
).input_ids;
if (m_last_disappeared_token.has_value())
encoded_input_ids = ov::genai::utils::push_front_inputs(encoded_input_ids, *m_last_disappeared_token);
}
- m_templated_chat_history = std::move(new_templated_chat_history);
m_tokenized_history.clear();
std::copy_n(new_chat_tokens.data<int64_t>(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history));
+ return encoded_input_ids;
} else {
- auto start_tokenizer_time = std::chrono::steady_clock::now();
- encoded_input_ids = m_tokenizer.encode(prompt).input_ids;
- auto end_tokenizer_time = std::chrono::steady_clock::now();
- metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
m_tokenized_history.clear();
- std::copy_n(encoded_input_ids.data<int64_t>(), encoded_input_ids.get_size(), std::back_inserter(m_tokenized_history));
+ std::copy_n(new_chat_tokens.data<int64_t>(), new_chat_tokens.get_size(), std::back_inserter(m_tokenized_history));
+ return new_chat_tokens;
}
+ }
- return encoded_input_ids;
+ ov::Tensor get_encoded_input_ids(const std::string& prompt, ov::genai::VLMPerfMetrics& metrics, const std::string& chat_template_fallback = "") {
+ const auto [new_chat_tokens, prev_chat_tokens] = apply_chat_template_tokenize(prompt, metrics, chat_template_fallback);
+ return update_history(new_chat_tokens, prev_chat_tokens);
}
/**
@@ -687,6 +695,7 @@ class InputsEmbedderLLaVA : public InputsEmbedder::IInputsEmbedder {
}
size_t merged_seq_length = text_embeds_seq_length + total_image_seq_length - num_image_tokens;
+ constexpr size_t BATCH_SIZE = 1;
ov::Tensor merged_embeds(text_embeds.get_element_type(), {BATCH_SIZE, merged_seq_length, hidden_size});
float* merged_data = merged_embeds.data<float>();
@@ -1163,6 +1172,400 @@ class InputsEmbedderInternVLChat : public InputsEmbedder::IInputsEmbedder {
}
};
+namespace {
+namespace phi3_v {
+// Reimplementation of the following Python code:
+// N, L, C = image_features.shape
+// assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
+// num_images = N // (h_crop * w_crop)
+// H = int(L**0.5)
+// print(L, H)
+// image_features_hd = (
+// image_features.reshape(N, H, H, C) # N, 24, 24, 1024
+// .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
+// .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
+// .reshape(N, -1, 4 * C) # N, 144, 4096
+// .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096
+// .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
+// .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096
+// )
+// Obtained in the following way:
+// 1. Generate reshape_hd_patches_2x2merge.xml with the Python script below.
+// import torch
+// import openvino as ov
+// import numpy as np
+// class Model(torch.nn.Module):
+// def forward(self, image_features, h_crop, w_crop):
+// """
+// image_features: (num_images*num_crops, 24*24, 1024)
+// output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
+// """
+// N, L, C = image_features.shape
+// num_images = N // (h_crop * w_crop)
+// H = (torch.tensor(L, dtype=torch.float32)**0.5).int()
+// image_features_hd = (
+// image_features.reshape(N, H, H, C) # N, 24, 24, 1024
+// .reshape(N, H // 2, 2, H // 2, 2, C) # N, 12, 2, 12, 2, 1024
+// .permute(0, 1, 3, 2, 4, 5) # N, 12, 12, 2, 2, 1024
+// .reshape(N, -1, 4 * C) # N, 144, 4096
+// .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1) # n_img, h_crop, w_crop, 12, 12, 4096
+// .permute(0, 1, 3, 2, 4, 5) # n_img, h_crop, 12, w_crop, 12, 4096
+// .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C) # n_img, h_crop*12, w_crop*12, 4096
+// return {"o": image_features_hd}
+// model = Model()
+// example_input = {"image_features": torch.rand((4, 576, 1024), dtype=torch.float32), "h_crop": torch.tensor(2, dtype=torch.int32), "w_crop": torch.tensor(2, dtype=torch.int32)}
+// ov_model = ov.convert_model(model, example_input=example_input, input=ov.PartialShape([-1, 576, 1024]))
+// # ov_model.outputs[0].get_tensor().set_names({"out"})
+// ov.save_model(ov_model, "reshape_hd_patches_2x2merge.xml")
+// inp = np.arange(4 * 576 * 1024).reshape([4, 576, 1024])
+// test = ov.Core().compile_model(ov_model, "CPU")
+// print(ov_model)
+// print(test([inp, 2, 2])["o"].flatten())
+// 2. Run https://github.com/slyalin/openvino_devtools/blob/bcd4a51b1354b24b2316ac3e1c77b2f87ae7a497/openvino_devtools/ov2py.py with the IR.
+// 3. Translate the printed Python implementation to C++.
+ov::InferRequest create_hd_feature_transformer() {
+ using namespace ov;
+ using namespace element;
+ using namespace opset13;
+ using namespace std;
+ auto t0 = make_shared<Parameter>(f32, PartialShape{-1, 576, 1024});
+ auto t1 = make_shared<Parameter>(i32, PartialShape{});
+ auto t2 = make_shared<Parameter>(i32, PartialShape{});
+ auto t3 = make_shared<ShapeOf>(t0);
+ auto t4 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{0});
+ auto t5 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{0});
+ auto t6 = make_shared<Gather>(t3, t4, t5);
+ auto t7 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{1});
+ auto t8 = make_shared<Reshape>(t6, t7, false);
+ auto t9 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{1});
+ auto t10 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{0});
+ auto t11 = make_shared<Gather>(t3, t9, t10);
+ auto t12 = make_shared<Convert>(t11, element::f32);
+ auto t13 = make_shared<Constant>(f32, Shape{}, vector<float>{0.5f});
+ auto t14 = make_shared<Power>(t12, t13, "numpy");
+ auto t15 = make_shared<Convert>(t14, element::i32);
+ auto t16 = make_shared<Convert>(t15, element::i64);
+ auto t17 = make_shared<Constant>(i32, Shape{}, vector<int32_t>{0});
+ auto t18 = make_shared<Unsqueeze>(t16, t17);
+ auto t19 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{2});
+ auto t20 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{0});
+ auto t21 = make_shared<Gather>(t3, t19, t20);
+ auto t22 = make_shared<Concat>(NodeVector{t8, t18, t18, t21}, 0);
+ auto t23 = make_shared<Reshape>(t0, t22, false);
+ auto t24 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{2});
+ auto t25 = make_shared<Divide>(t16, t24, "numpy");
+ auto t26 = make_shared<Floor>(t25);
+ auto t27 = make_shared<Constant>(i32, Shape{}, vector<int32_t>{0});
+ auto t28 = make_shared<Unsqueeze>(t26, t27);
+ auto t29 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{2});
+ auto t30 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{2});
+ auto t31 = make_shared<Concat>(NodeVector{t8, t28, t29, t28, t30, t21}, 0);
+ auto t32 = make_shared<Reshape>(t23, t31, false);
+ auto t33 = make_shared<Constant>(i64, Shape{6}, vector<int64_t>{0, 1, 3, 2, 4, 5});
+ auto t34 = make_shared<Transpose>(t32, t33);
+ auto t35 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{-1});
+ auto t36 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{4});
+ auto t37 = make_shared<Multiply>(t21, t36, "numpy");
+ auto t38 = make_shared<Concat>(NodeVector{t8, t35, t37}, 0);
+ auto t39 = make_shared<Reshape>(t34, t38, false);
+ auto t40 = make_shared<Multiply>(t1, t2, "numpy");
+ auto t41 = make_shared<Convert>(t40, element::i64);
+ auto t42 = make_shared<Divide>(t6, t41, "numpy");
+ auto t43 = make_shared<Floor>(t42);
+ auto t44 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{0});
+ auto t45 = make_shared<Unsqueeze>(t43, t44);
+ auto t46 = make_shared<Convert>(t1, element::i64);
+ auto t47 = make_shared<Unsqueeze>(t46, t44);
+ auto t48 = make_shared<Convert>(t2, element::i64);
+ auto t49 = make_shared<Unsqueeze>(t48, t44);
+ auto t50 = make_shared<Constant>(i64, Shape{1}, vector<int64_t>{-1});
+ auto t51 = make_shared<Concat>(NodeVector{t45, t47, t49, t28, t28, t50}, 0);
+ auto t52 = make_shared<Reshape>(t39, t51, false);
+ auto t53 = make_shared<Constant>(i64, Shape{6}, vector<int64_t>{0, 1, 3, 2, 4, 5});
+ auto t54 = make_shared<Transpose>(t52, t53);
+ auto t55 = make_shared<Multiply>(t1, t15, "numpy");
+ auto t56 = make_shared<Convert>(t55, element::i64);
+ auto t57 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{2});
+ auto t58 = make_shared<Divide>(t56, t57, "numpy");
+ auto t59 = make_shared<Floor>(t58);
+ auto t60 = make_shared<Constant>(i32, Shape{}, vector<int32_t>{0});
+ auto t61 = make_shared<Unsqueeze>(t59, t60);
+ auto t62 = make_shared<Multiply>(t2, t15, "numpy");
+ auto t63 = make_shared<Convert>(t62, element::i64);
+ auto t64 = make_shared<Constant>(i64, Shape{}, vector<int64_t>{2});
+ auto t65 = make_shared<Divide>(t63, t64, "numpy");
+ auto t66 = make_shared<Floor>(t65);
+ auto t67 = make_shared<Unsqueeze>(t66, t60);
+ auto t68 = make_shared<Concat>(NodeVector{t45, t61, t67, t37}, 0);
+ auto t69 = make_shared<Reshape>(t54, t68, false);
+ shared_ptr<Model> model = make_shared<Model>(make_shared<Result>(t69), ParameterVector{t0, t1, t2});
+ return utils::singleton_core().compile_model(
+ model, "CPU"
+ ).create_infer_request();
+}
+
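+// Runs the graph built above: merges every 2x2 neighbourhood of the 24x24 ViT feature grid into a single 4096-dim vector and regroups the crops per image, producing [num_images, h_crop*12, w_crop*12, 4096].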
+ov::Tensor reshape_hd_patches_2x2merge(const ov::Tensor& image_features, size_t h_crop, size_t w_crop, InferRequest& hd_feature_transformer) {
+ ov::Shape shape = image_features.get_shape();
+ OPENVINO_ASSERT(3 == shape.size());
+ OPENVINO_ASSERT(24 * 24 == shape.at(1));
+ OPENVINO_ASSERT(1024 == shape.at(2));
+ hd_feature_transformer.set_input_tensor(0, image_features);
+ ov::Tensor height{ov::element::i32, {}, &h_crop};
+ hd_feature_transformer.set_input_tensor(1, height);
+ ov::Tensor width{ov::element::i32, {}, &w_crop};
+ hd_feature_transformer.set_input_tensor(2, width);
+ hd_feature_transformer.infer();
+ return hd_feature_transformer.get_output_tensor();
+}
+
+// image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
+// output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
+ov::Tensor add_image_newline(const ov::Tensor& image_features_hd, const std::vector<float>& sub_GN) {
+ const ov::Shape& nhwc = image_features_hd.get_shape(); // [N, 12*h_crop, 12*w_crop, 4096]
+ const float* in = image_features_hd.data<float>();
+ ov::Tensor image_features_hd_new_line{ov::element::f32, {nhwc.at(0), nhwc.at(1) * (nhwc.at(2) + 1), nhwc.at(3)}};
+ float* out = image_features_hd_new_line.data<float>();
+ for (size_t batch_id = 0; batch_id < nhwc.at(0); ++batch_id) {
+ for (size_t row_id = 0; row_id < nhwc.at(1); ++row_id) {
+ for (size_t col_id = 0; col_id < nhwc.at(2); ++col_id) {
+ std::copy_n(
+ in + batch_id * nhwc.at(1) * nhwc.at(2) * nhwc.at(3) + row_id * nhwc.at(2) * nhwc.at(3) + col_id * nhwc.at(3),
+ nhwc.at(3),
+ out + batch_id * nhwc.at(1) * (nhwc.at(2) + 1) * nhwc.at(3) + row_id * (nhwc.at(2) + 1) * nhwc.at(3) + col_id * nhwc.at(3)
+ );
+ }
+ std::copy(
+ sub_GN.begin(),
+ sub_GN.end(),
+ out + batch_id * nhwc.at(1) * (nhwc.at(2) + 1) * nhwc.at(3) + row_id * (nhwc.at(2) + 1) * nhwc.at(3) + nhwc.at(2) * nhwc.at(3)
+ );
+ }
+ }
+ return image_features_hd_new_line;
+}
+
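+// Concatenates [1, L1, F] features, a single F-dim separator embedding and [1, L2, F] features into one [1, L1 + 1 + L2, F] tensor.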
+ov::Tensor concatenate_2d(const ov::Tensor& first_1lf, const std::vector<float>& second_f, const ov::Tensor& third_1lf) {
+ size_t first_l = first_1lf.get_shape().at(1);
+ constexpr size_t second_l = 1;
+ size_t third_l = third_1lf.get_shape().at(1);
+ size_t features = first_1lf.get_shape().at(2);
+ OPENVINO_ASSERT(second_f.size() == features);
+ ov::Tensor out_1lf{ov::element::f32, {1, first_l + second_l + third_l, features}};
+ float* out = out_1lf.data<float>();
+ std::copy_n(first_1lf.data<float>(), first_l * features, out);
+ std::copy(second_f.begin(), second_f.end(), out + first_l * features);
+ std::copy_n(third_1lf.data<float>(), third_l * features, out + (first_l + second_l) * features);
+ return out_1lf;
+}
+
+// image_features.resized_source: (num_crops+1, 24*24, 1024)
+ov::Tensor hd_feature_transform(const EncodedImage& image_features, InferRequest& hd_feature_transformer, const std::vector<float>& sub_GN, const std::vector<float>& glb_GN, ov::InferRequest& vision_projection) {
+ const ov::Shape& image_features_shape = image_features.resized_source.get_shape();
+ ov::Tensor global_image_features{ov::element::f32, {1, image_features_shape.at(1), image_features_shape.at(2)}, image_features.resized_source.data()};
+ // global feature can be viewed as a special HD case with num_crops 1x1
+ ov::Tensor global_image_features_hd = reshape_hd_patches_2x2merge(global_image_features, 1, 1, hd_feature_transformer);
+ ov::Tensor global_image_features_hd_newline = add_image_newline(global_image_features_hd, sub_GN); // [1,12*(12+1),4096]
+ constexpr size_t INPUT_IMAGE_SIZE = 336;
+ size_t h_crop = image_features.resized_source_size.height / INPUT_IMAGE_SIZE;
+ size_t w_crop = image_features.resized_source_size.width / INPUT_IMAGE_SIZE;
+ size_t num_crops = h_crop * w_crop;
+
+ // NOTE: real num_crops is padded
+ // (num_crops, 24*24, 1024)
+ ov::Tensor sub_image_features{ov::element::f32, {
+ num_crops,
+ image_features_shape.at(1),
+ image_features_shape.at(2)
+ }, image_features.resized_source.data<float>() + image_features_shape.at(1) * image_features_shape.at(2)};
+ ov::Tensor sub_image_features_hd = reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop, hd_feature_transformer); // [1, 24, 24, 4096]
+ ov::Tensor sub_image_features_hd_newline = add_image_newline(sub_image_features_hd, sub_GN); // [1,h_crop*12*(w_crop*12+1), 4096]
+ ov::Tensor image_embeddings = concatenate_2d(sub_image_features_hd_newline, glb_GN, global_image_features_hd_newline); // [1,l,4096]
+ vision_projection.set_input_tensor(image_embeddings);
+ vision_projection.infer();
+ ov::Tensor out = vision_projection.get_output_tensor();
+ ov::Tensor res{out.get_element_type(), out.get_shape()};
+ out.copy_to(res);
+ return res;
+}
+
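+// Splits the prompt at <|image_N|> tags and tokenizes each non-empty text chunk separately so that image features can later be inserted between the chunks.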
+std::vector<ov::Tensor> split_tokenize(const std::string& text, ov::genai::Tokenizer& tokenizer) {
+ constexpr int make_suffix_iterator = -1;
+ std::regex rgx{R"(<\|image_\d+\|>)"};
+ std::sregex_token_iterator iter{
+ text.begin(),
+ text.end(),
+ rgx,
+ make_suffix_iterator
+ };
+ std::vector<ov::Tensor> tokenized;
+ for ( ; iter != std::sregex_token_iterator{}; ++iter) {
+ if (iter->str().empty()) {
+ continue;
+ }
+ std::string substr = *iter;
+ tokenized.push_back(tokenizer.encode(substr, ov::genai::add_special_tokens(true)).input_ids);
+ }
+ return tokenized;
+}
+
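+// Joins the tokenized text chunks into a single [1, len] tensor, filling the gap between consecutive chunks with tokens_per_image negative ids (-1, -2, ...) that mark where each image's embeddings will go.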
+ov::Tensor insert_image_placeholders(const std::vector<ov::Tensor>& chunks, size_t tokens_per_image) {
+ size_t merged_length = 0;
+ for (const ov::Tensor& chunk : chunks) {
+ merged_length += chunk.get_shape().at(1);
+ }
+ merged_length += chunks.empty() ? 0 : (chunks.size() - 1) * tokens_per_image;
+ ov::Tensor merged{ov::element::i64, {1, merged_length}};
+ size_t offset = 0;
+ int64_t image_id = -1;
+ for (const ov::Tensor& chunk : chunks) {
+ size_t length = chunk.get_shape().at(1);
+ std::copy_n(
+ chunk.data<int64_t>(),
+ length,
+ merged.data<int64_t>() + offset
+ );
+ offset += length;
+ if (offset < merged_length) {
+ std::fill_n(
+ merged.data<int64_t>() + offset,
+ tokens_per_image,
+ image_id
+ );
+ offset += tokens_per_image;
+ --image_id;
+ }
+ }
+ return merged;
+}
+
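+// Splits a merged token sequence back into text-only chunks by cutting at the runs of negative placeholder ids inserted above.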
+std::vector<ov::Tensor> drop_image_placeholders(const ov::Tensor& tokens) {
+ std::vector<ov::Tensor> chunks;
+ size_t offset = 0;
+ while (offset < tokens.get_shape().at(1)) {
+ size_t length = 0;
+ while (offset + length < tokens.get_shape().at(1) && tokens.data<int64_t>()[offset + length] >= 0) {
+ ++length;
+ }
+ chunks.emplace_back(ov::element::i64, ov::Shape{1, length}, tokens.data<int64_t>() + offset);
+ offset += length;
+ while (offset < tokens.get_shape().at(1) && tokens.data<int64_t>()[offset] < 0) {
+ ++offset;
+ }
+ }
+ return chunks;
+}
+} // namespace phi3_v
+} // anonymous namespace
+
+class InputsEmbedderPhi3V : public InputsEmbedder::IInputsEmbedder {
+public:
+ ov::InferRequest m_hd_feature_transformer;
+ ov::InferRequest m_vision_projection;
+ // Used to insert <|image_i|>\n per image (not a slice).
+ size_t m_image_id = 1;
+ size_t m_tokens_per_image = 0;
+
+ InputsEmbedderPhi3V(
+ const VLMConfig& vlm_config,
+ const std::filesystem::path& model_dir,
+ const std::string& device,
+ const ov::AnyMap device_config
+ ):
+ IInputsEmbedder(vlm_config, model_dir, device, device_config), m_image_id{0},
+ m_hd_feature_transformer{phi3_v::create_hd_feature_transformer()},
+ m_vision_projection{utils::singleton_core().compile_model(model_dir / "openvino_vision_projection_model.xml", device, {}).create_infer_request()} {}
+
+ ov::Tensor get_inputs_embeds(const std::string& prompt, const std::vector<ov::Tensor>& images, ov::genai::VLMPerfMetrics& metrics) override {
+ OPENVINO_ASSERT(images.empty() || m_history.empty(), "Images can only be provided for initial prompt");
+ std::vector<ov::Tensor> images_features_proj;
+ std::stringstream images_prompt;
+ for (const ov::Tensor& image : to_single_image_tensors(images)) {
+ EncodedImage encoded_image = m_vision_encoder.encode(image);
+ images_features_proj.push_back(phi3_v::hd_feature_transform(encoded_image, m_hd_feature_transformer, m_vlm_config.sub_GN, m_vlm_config.glb_GN, m_vision_projection));
+ images_prompt << "<|image_" << m_image_id << "|>\n";
+ ++m_image_id;
+ }
+ images_prompt << prompt;
+ std::vector<ov::Tensor> new_chat_tokens;
+ std::vector<ov::Tensor> prev_chat_tokens;
+ if (m_is_chat_conversation) {
+ m_history.push_back({{"role", "user"}, {"content", images_prompt.str()}});
+ constexpr bool add_generation_prompt = true;
+ std::string new_templated_chat_history;
+ new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt);
+ auto start_tokenizer_time = std::chrono::steady_clock::now();
+ new_chat_tokens = phi3_v::split_tokenize(new_templated_chat_history, m_tokenizer);
+ prev_chat_tokens = phi3_v::split_tokenize(m_templated_chat_history, m_tokenizer);
+ auto end_tokenizer_time = std::chrono::steady_clock::now();
+ metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
+ m_templated_chat_history = std::move(new_templated_chat_history);
+ } else {
+ auto start_tokenizer_time = std::chrono::steady_clock::now();
+ new_chat_tokens = phi3_v::split_tokenize(images_prompt.str(), m_tokenizer);
+ auto end_tokenizer_time = std::chrono::steady_clock::now();
+ metrics.raw_metrics.tokenization_durations.emplace_back(PerfMetrics::get_microsec(end_tokenizer_time - start_tokenizer_time));
+ }
+ if (0 == m_tokens_per_image && !images_features_proj.empty()) {
+ m_tokens_per_image = images_features_proj.at(0).get_shape().at(1);
+ }
+ ov::Tensor new_merged_tokens = phi3_v::insert_image_placeholders(new_chat_tokens, m_tokens_per_image);
+ ov::Tensor prev_merged_tokens = phi3_v::insert_image_placeholders(prev_chat_tokens, m_tokens_per_image);
+ ov::Tensor new_tokens = update_history(new_merged_tokens, prev_merged_tokens);
+ std::vector<ov::Tensor> tokens = phi3_v::drop_image_placeholders(new_tokens);
+ OPENVINO_ASSERT(tokens.size() == images_features_proj.size() + 1);
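+ // First pass: compute the total sequence length of the interleaved text chunks and projected image features.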
+ size_t features_length = 0;
+ for (size_t im_id = 0; im_id < images_features_proj.size(); ++im_id) {
+ size_t text_length = tokens.at(im_id).get_shape().at(1);
+ size_t im_length = images_features_proj.at(im_id).get_shape().at(1);
+ OPENVINO_ASSERT(im_length == m_tokens_per_image);
+ features_length += text_length + im_length;
+ }
+ features_length += tokens.back().get_shape().at(1);
+ ov::Tensor inputs_embeds{ov::element::f32, {1, features_length, m_vlm_config.hidden_size}};
+ size_t offset = 0;
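+ // Second pass: copy the embeddings of text chunk i followed by projected image i; the trailing text chunk is appended after the loop.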
+ for (size_t im_id = 0; im_id < images_features_proj.size(); ++im_id) {
+ const ov::Tensor& text_embeds = m_embedding.infer(tokens.at(im_id));
+ const ov::Tensor& image_embeds = images_features_proj.at(im_id);
+ size_t text_length = text_embeds.get_shape().at(1);
+ size_t im_length = image_embeds.get_shape().at(1);
+ std::copy_n(
+ text_embeds.data<float>(),
+ text_embeds.get_size(),
+ inputs_embeds.data<float>() + offset * m_vlm_config.hidden_size
+ );
+ offset += text_length;
+ std::copy_n(
+ image_embeds.data<float>(),
+ image_embeds.get_size(),
+ inputs_embeds.data<float>() + offset * m_vlm_config.hidden_size
+ );
+ offset += im_length;
+ }
+ const ov::Tensor& text_embeds = m_embedding.infer(tokens.back());
+ size_t text_length = text_embeds.get_shape().at(1);
+ std::copy_n(
+ text_embeds.data<float>(),
+ text_embeds.get_size(),
+ inputs_embeds.data<float>() + offset * m_vlm_config.hidden_size
+ );
+
+ if (!m_is_chat_conversation) {
+ m_image_id = 0;
+ }
+
+ return inputs_embeds;
+ }
+
+ virtual void start_chat(const std::string& system_message) override {
+ IInputsEmbedder::start_chat(system_message);
+ m_image_id = 0;
+ }
+
+ virtual void finish_chat() override {
+ IInputsEmbedder::finish_chat();
+ m_image_id = 0;
+ }
+};
+
class InputsEmbedderQwen2VL : public InputsEmbedder::IInputsEmbedder {
// A model for merging image embeddings (hidden states), rotary_pos_emb and attention_mask.
// Inputs:
@@ -1577,6 +1980,8 @@ InputsEmbedder::InputsEmbedder(const VLMConfig& vlm_config,
m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
} else if (vlm_config.model_type == VLMModelType::INTERNVL_CHAT) {
m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
+ } else if (vlm_config.model_type == VLMModelType::PHI3_V) {
+ m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
} else if (vlm_config.model_type == VLMModelType::QWEN2_VL) {
m_impl = std::make_shared(vlm_config, model_dir, device, device_config);
} else {
diff --git a/src/cpp/src/visual_language/inputs_embedder.hpp b/src/cpp/src/visual_language/inputs_embedder.hpp
index 223d090b22..4462c58185 100644
--- a/src/cpp/src/visual_language/inputs_embedder.hpp
+++ b/src/cpp/src/visual_language/inputs_embedder.hpp
@@ -68,6 +68,7 @@ class InputsEmbedder {
friend class InputsEmbedderLLaVA;
friend class InputsEmbedderLLaVANext;
friend class InputsEmbedderInternVLChat;
+ friend class InputsEmbedderPhi3V;
friend class InputsEmbedderQwen2VL;
};
diff --git a/src/cpp/src/visual_language/processor_config.cpp b/src/cpp/src/visual_language/processor_config.cpp
index f790c58912..527557061e 100644
--- a/src/cpp/src/visual_language/processor_config.cpp
+++ b/src/cpp/src/visual_language/processor_config.cpp
@@ -41,6 +41,10 @@ ov::genai::ProcessorConfig::ProcessorConfig(const std::filesystem::path& json_pa
if (parsed.contains("image_grid_pinpoints")) {
image_grid_pinpoints = parsed.at("image_grid_pinpoints").get<std::vector<std::vector<size_t>>>();
}
+ read_json_param(parsed, "num_crops", phi3_v.num_crops);
+ if (parsed.contains("img_processor")) {
+ phi3_v.num_img_tokens = parsed.at("img_processor").at("num_img_tokens");
+ }
// Setting qwen2vl config params
read_json_param(parsed, "min_pixels", min_pixels);
diff --git a/src/cpp/src/visual_language/processor_config.hpp b/src/cpp/src/visual_language/processor_config.hpp
index 1d40e091a9..1c4db59fd9 100644
--- a/src/cpp/src/visual_language/processor_config.hpp
+++ b/src/cpp/src/visual_language/processor_config.hpp
@@ -35,9 +35,10 @@ class ProcessorConfig {
/// llava calls it image_std.
std::array<float, 3> norm_std{1.0f, 1.0f, 1.0f};
- // llava specific config params
+ // A renamed version of norm_mean.
std::array<float, 3> image_mean{0.0f, 0.0f, 0.0f};
std::array<float, 3> image_std{1.0f, 1.0f, 1.0f};
+ // llava specific config params
size_t crop_size_height = 336;
size_t crop_size_width = 336;
size_t size_shortest_edge = 336;
@@ -45,6 +46,10 @@ class ProcessorConfig {
// llava-next specific config params
std::vector<std::vector<size_t>> image_grid_pinpoints{{336, 672}, {672, 336}, {672, 672}, {1008, 336}, {336, 1008}};
+ struct {
+ size_t num_crops = 4;
+ size_t num_img_tokens = 144;
+ } phi3_v;
// qwen2vl specific params
size_t min_pixels = 3136;
size_t max_pixels = 12845056;
diff --git a/src/cpp/src/visual_language/vision_encoder.cpp b/src/cpp/src/visual_language/vision_encoder.cpp
index 4a5179fdd0..04ddd63145 100644
--- a/src/cpp/src/visual_language/vision_encoder.cpp
+++ b/src/cpp/src/visual_language/vision_encoder.cpp
@@ -645,6 +645,202 @@ ov::Tensor get_pixel_values_internvl(const ov::Tensor& image, const ProcessorCon
return output_tensor;
}
+namespace phi3_v {
+constexpr size_t INPUT_IMAGE_SIZE = 336;
+
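+// Pads the shorter side of an HWC u8 image with white (255) up to the next multiple of 336, keeping the content centered.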
+ov::Tensor padding_336(const ov::Tensor& unpadded) {
+ ov::Shape _1ss3 = unpadded.get_shape();
+ size_t s1 = _1ss3.at(1), s2 = _1ss3.at(2);
+ if (s1 < s2) {
+ size_t tar = size_t(std::ceil(float(s1) / INPUT_IMAGE_SIZE) * INPUT_IMAGE_SIZE);
+ size_t top_padding = (tar - s1) / 2;
+ ov::Tensor padded{ov::element::u8, {1, tar, s2, 3}};
+ uint8_t* padded_data = padded.data<uint8_t>();
+ std::fill_n(padded_data, padded.get_size(), 255);
+ std::copy_n(unpadded.data<uint8_t>(), unpadded.get_size(), padded_data + top_padding * s2 * 3);
+ return padded;
+ }
+ size_t tar = size_t(std::ceil(float(s2) / INPUT_IMAGE_SIZE) * INPUT_IMAGE_SIZE);
+ size_t left_padding = (tar - s2) / 2;
+ ov::Tensor padded{ov::element::u8, {1, s1, tar, 3}};
+ uint8_t* padded_data = padded.data<uint8_t>();
+ std::fill_n(padded_data, padded.get_size(), 255);
+ uint8_t* unpadded_data = unpadded.data<uint8_t>();
+ for (size_t row = 0; row < s1; ++row) {
+ std::copy_n(unpadded_data + row * s2 * 3, s2 * 3, padded_data + row * tar * 3 + left_padding * 3);
+ }
+ return padded;
+}
+
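+// Chooses the largest tiling of 336-px crops that still fits into num_crops, resizes the image to that size and pads the shorter side via padding_336; portrait images are handled by temporarily swapping width and height.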
+ov::Tensor HD_transform(const ov::Tensor& uint8, size_t num_crops) {
+ ov::Shape _1hwc = uint8.get_shape();
+ size_t height = _1hwc.at(1), width = _1hwc.at(2);
+ bool trans = false;
+ if (width < height) {
+ std::swap(height, width);
+ trans = true;
+ }
+ float ratio = float(width) / height;
+ unsigned scale = 1;
+ while (scale * std::ceil(scale / ratio) <= num_crops) {
+ ++scale;
+ }
+ --scale;
+ size_t new_w = scale * INPUT_IMAGE_SIZE;
+ size_t new_h = new_w / ratio;
+ clip_image_u8 src{}, dst{};
+ uint8_t* uint8_data = uint8.data<uint8_t>();
+ if (trans) {
+ src = clip_image_u8{int(height), int(width), {uint8_data, uint8_data + uint8.get_size()}};
+ bilinear_resize(src, dst, new_h, new_w);
+ return padding_336(ov::Tensor{ov::element::u8, {1, new_w, new_h, 3}, dst.buf.data()});
+ }
+ src = clip_image_u8{int(width), int(height), {uint8_data, uint8_data + uint8.get_size()}};
+ bilinear_resize(src, dst, new_w, new_h);
+ return padding_336(ov::Tensor{ov::element::u8, {1, new_h, new_w, 3}, dst.buf.data()});
+}
+
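+// Converts interleaved u8 RGB values to float and applies per-channel (x/255 - mean) / std normalization from the processor config.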
+ov::Tensor mean_scale(const ov::Tensor& uint8, const ProcessorConfig& config) {
+ uint8_t* uint_8_data = uint8.data<uint8_t>();
+ ov::Tensor float_normalized{ov::element::f32, uint8.get_shape()};
+ float* float_data = float_normalized.data<float>();
+ OPENVINO_ASSERT(0 == uint8.get_size() % 3, "RGB");
+ for (size_t idx = 0; idx < uint8.get_size(); idx += 3) {
+ float_data[idx] = (float(uint_8_data[idx]) / 255.0f - config.image_mean[0]) / config.image_std[0];
+ float_data[idx + 1] = (float(uint_8_data[idx + 1]) / 255.0f - config.image_mean[1]) / config.image_std[1];
+ float_data[idx + 2] = (float(uint_8_data[idx + 2]) / 255.0f - config.image_mean[2]) / config.image_std[2];
+ }
+ return float_normalized;
+}
+
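+// Repacks a [1, H, W, 3] float tensor into [1, 3, H, W] (NHWC to NCHW).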
+ov::Tensor channels_first(const ov::Tensor& _1hw3) {
+ ov::Shape shape = _1hw3.get_shape();
+ ov::Tensor _13hw = ov::Tensor{ov::element::f32, {1, 3, shape.at(1), shape.at(2)}};
+ float* _1hw3_data = _1hw3.data<float>();
+ float* _13hw_data = _13hw.data<float>();
+ for (size_t plane = 0; plane < 3; ++plane) {
+ for (size_t row = 0; row < shape.at(1); ++row) {
+ for (size_t col = 0; col < shape.at(2); ++col) {
+ _13hw_data[plane * shape.at(1) * shape.at(2) + row * shape.at(2) + col] = _1hw3_data[row * shape.at(2) * 3 + col * 3 + plane];
+ }
+ }
+ }
+ return _13hw;
+}
+
+// Reimplementation of Python im.reshape(1, 3, h//336, 336, w//336, 336).permute(0,2,4,1,3,5).reshape(-1, 3, 336, 336)
+ov::Tensor slice_image(const ov::Tensor& image) {
+ ov::Shape shape = image.get_shape();
+ size_t N = shape[0];
+ size_t C = shape[1];
+ size_t H = shape[2];
+ size_t W = shape[3];
+
+ size_t num_h_slices = H / INPUT_IMAGE_SIZE;
+ size_t num_w_slices = W / INPUT_IMAGE_SIZE;
+
+ // Step 1: Define and populate the reshaped tensor in the correct shape order
+ ov::Tensor reshaped{ov::element::f32, {N, num_h_slices, num_w_slices, C, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE}};
+ float* reshaped_data = reshaped.data<float>();
+ float* image_data = image.data<float>();
+
+ // Populate the reshaped tensor
+ for (size_t n = 0; n < N; ++n) {
+ for (size_t h = 0; h < num_h_slices; ++h) {
+ for (size_t w = 0; w < num_w_slices; ++w) {
+ for (size_t c = 0; c < C; ++c) {
+ for (size_t i = 0; i < INPUT_IMAGE_SIZE; ++i) {
+ for (size_t j = 0; j < INPUT_IMAGE_SIZE; ++j) {
+ size_t src_idx = n * C * H * W + c * H * W + (h * INPUT_IMAGE_SIZE + i) * W + (w * INPUT_IMAGE_SIZE + j);
+ size_t dst_idx = n * num_h_slices * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ h * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ w * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ i * INPUT_IMAGE_SIZE + j;
+ reshaped_data[dst_idx] = image_data[src_idx];
+ }
+ }
+ }
+ }
+ }
+ }
+
+ // Step 2: Define the permuted tensor in the final shape
+ ov::Tensor permuted{ov::element::f32, {N * num_h_slices * num_w_slices, C, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE}};
+ float* permuted_data = permuted.data<float>();
+
+ // Perform permutation by flattening N, num_h_slices, and num_w_slices
+ for (size_t n = 0; n < N; ++n) {
+ for (size_t h = 0; h < num_h_slices; ++h) {
+ for (size_t w = 0; w < num_w_slices; ++w) {
+ for (size_t c = 0; c < C; ++c) {
+ for (size_t i = 0; i < INPUT_IMAGE_SIZE; ++i) {
+ for (size_t j = 0; j < INPUT_IMAGE_SIZE; ++j) {
+ size_t src_idx = n * num_h_slices * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ h * num_w_slices * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ w * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ i * INPUT_IMAGE_SIZE + j;
+ size_t dst_idx = (n * num_h_slices * num_w_slices + h * num_w_slices + w) * C * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ c * INPUT_IMAGE_SIZE * INPUT_IMAGE_SIZE +
+ i * INPUT_IMAGE_SIZE + j;
+ permuted_data[dst_idx] = reshaped_data[src_idx];
+ }
+ }
+ }
+ }
+ }
+ }
+
+ return permuted;
+}
+
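+// Stacks two NCHW float tensors along the batch dimension; channels, height and width must match.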
+ov::Tensor concatenate_batch(const ov::Tensor& float_first, const ov::Tensor& float_second) {
+ ov::Shape shape_first = float_first.get_shape();
+ ov::Shape shape_second = float_second.get_shape();
+ OPENVINO_ASSERT(shape_first.at(1) == shape_second.at(1), "Channels must be the same");
+ OPENVINO_ASSERT(shape_first.at(2) == shape_second.at(2), "Height must be the same");
+ OPENVINO_ASSERT(shape_first.at(3) == shape_second.at(3), "Width must be the same");
+ ov::Tensor concatenated{ov::element::f32, {shape_first.at(0) + shape_second.at(0), shape_first.at(1), shape_first.at(2), shape_first.at(3)}};
+ float* concatenated_data = concatenated.data<float>();
+ float* first_data = float_first.data<float>();
+ float* second_data = float_second.data<float>();
+ std::copy(first_data, first_data + float_first.get_size(), concatenated_data);
+ std::copy(second_data, second_data + float_second.get_size(), concatenated_data + float_first.get_size());
+ return concatenated;
+}
+
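+// Pads the batch of crops up to max_crops so the vision encoder always receives a fixed batch size; the padding entries are left unfilled.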
+ov::Tensor pad_to_max_num_crops_tensor(const ov::Tensor& nchw, size_t max_crops) {
+ ov::Shape shape = nchw.get_shape();
+ size_t num_crops = shape[0];
+ if (num_crops >= max_crops) {
+ return nchw;
+ }
+ ov::Tensor padded{ov::element::f32, {max_crops, shape[1], shape[2], shape[3]}};
+ float* padded_data = padded.data<float>();
+ float* nchw_data = nchw.data<float>();
+ std::copy_n(nchw_data, nchw.get_size(), padded_data);
+ return padded;
+}
+
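+// Full phi3_v preprocessing: HD transform, a bicubic-resized 336x336 global view, normalization, NHWC to NCHW conversion, slicing into 336x336 crops and padding to num_crops. Returns the pixel values and the HD-transformed image size.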
+std::tuple<ov::Tensor, ImageSize> get_pixel_values_phi3_v(const ov::Tensor& image, const ProcessorConfig& config) {
+ ov::Tensor hd_image = HD_transform(image, config.phi3_v.num_crops);
+ ImageSize image_size{hd_image.get_shape().at(2), hd_image.get_shape().at(1)};
+ clip_image_u8 img{int(hd_image.get_shape().at(2)), int(hd_image.get_shape().at(1)), {hd_image.data<uint8_t>(), hd_image.data<uint8_t>() + hd_image.get_size()}};
+ clip_image_u8 dst;
+ bicubic_resize(img, dst, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE);
+ ov::Tensor global_image{ov::element::u8, {1, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE, 3}, dst.buf.data()};
+ global_image = mean_scale(global_image, config);
+ hd_image = mean_scale(hd_image, config);
+ global_image = channels_first(global_image);
+ hd_image = channels_first(hd_image);
+ ov::Tensor slices = slice_image(hd_image);
+ ov::Tensor concatenated = concatenate_batch(global_image, slices);
+ ov::Tensor pixel_values = pad_to_max_num_crops_tensor(concatenated, config.phi3_v.num_crops);
+ return {std::move(pixel_values), image_size};
+}
+} // namespace phi3_v
+
ImageSize smart_resize_qwen2vl(size_t height, size_t width, size_t factor, size_t min_pixels, size_t max_pixels) {
if (height < factor || width < factor) {
OPENVINO_THROW("Height or width must be larger than factor");
@@ -832,6 +1028,8 @@ EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ProcessorConfi
return encode_llava_next(image, config);
} else if (model_type == VLMModelType::INTERNVL_CHAT) {
return encode_internvl(image, config);
+ } else if (model_type == VLMModelType::PHI3_V) {
+ return encode_phi3_v(image, config);
} else if (model_type == VLMModelType::QWEN2_VL) {
return encode_qwen2vl(image, config);
} else {
@@ -908,6 +1106,13 @@ EncodedImage VisionEncoder::encode_internvl(const ov::Tensor& image, const Proce
return {std::move(image_features), resized_source_size};
}
+EncodedImage VisionEncoder::encode_phi3_v(const ov::Tensor& image, const ProcessorConfig& config) {
+ const auto& [pixel_values, image_size] = phi3_v::get_pixel_values_phi3_v(image, config);
+ m_vision_encoder.set_input_tensor(pixel_values);
+ m_vision_encoder.infer();
+ return {m_vision_encoder.get_output_tensor(), image_size};
+}
+
EncodedImage VisionEncoder::encode_qwen2vl(const ov::Tensor& image, const ProcessorConfig& config) {
ov::Shape image_shape = image.get_shape();
auto original_height = image_shape.at(1);
diff --git a/src/cpp/src/visual_language/vision_encoder.hpp b/src/cpp/src/visual_language/vision_encoder.hpp
index e725c06bf4..8bec971894 100644
--- a/src/cpp/src/visual_language/vision_encoder.hpp
+++ b/src/cpp/src/visual_language/vision_encoder.hpp
@@ -159,6 +159,10 @@ class VisionEncoder {
const ov::Tensor& image, const ProcessorConfig& config
);
+ EncodedImage encode_phi3_v(
+ const ov::Tensor& image, const ProcessorConfig& config
+ );
+
EncodedImage encode_qwen2vl(
const ov::Tensor& image, const ProcessorConfig& config
);
diff --git a/src/cpp/src/visual_language/vlm_config.cpp b/src/cpp/src/visual_language/vlm_config.cpp
index 6eab781fc0..5609c886c4 100644
--- a/src/cpp/src/visual_language/vlm_config.cpp
+++ b/src/cpp/src/visual_language/vlm_config.cpp
@@ -19,4 +19,13 @@ ov::genai::VLMConfig::VLMConfig(const std::filesystem::path& json_path) {
// Setting llava_next specific config params
read_json_param(parsed, "image_newline", image_newline);
+ // phi3_v
+ if (parsed.contains("sub_GN")) {
+ sub_GN = parsed.at("sub_GN").get<std::vector<std::vector<std::vector<std::vector<float>>>>>().at(0).at(0).at(0);
+ }
+ OPENVINO_ASSERT(sub_GN.size() == 4096);
+ if (parsed.contains("glb_GN")) {
+ glb_GN = parsed.at("glb_GN").get<std::vector<std::vector<std::vector<float>>>>().at(0).at(0);
+ }
+ OPENVINO_ASSERT(glb_GN.size() == 4096);
}
diff --git a/src/cpp/src/visual_language/vlm_config.hpp b/src/cpp/src/visual_language/vlm_config.hpp
index c70c757707..7a052b8537 100644
--- a/src/cpp/src/visual_language/vlm_config.hpp
+++ b/src/cpp/src/visual_language/vlm_config.hpp
@@ -54,6 +54,9 @@ class VLMConfig {
std::string image_context_token = "<IMG_CONTEXT>";
/// @brief A string token denoting end of image embeddings for InternVL2 model.
std::string image_end_token = "</img>";
+ /// @brief phi3_v new line token embedding to separate images.
+ std::vector<float> sub_GN = std::vector<float>(4096, 0.0f);
+ std::vector<float> glb_GN = std::vector<float>(4096, 0.0f);
/// @brief A string token denoting start of vision embeddings for Qwen2VL model.
std::string vision_start_token = "<|vision_start|>";
diff --git a/src/cpp/src/visual_language/vlm_model_type.hpp b/src/cpp/src/visual_language/vlm_model_type.hpp
index 6f554fbf98..93387cacbc 100644
--- a/src/cpp/src/visual_language/vlm_model_type.hpp
+++ b/src/cpp/src/visual_language/vlm_model_type.hpp
@@ -16,6 +16,7 @@ enum class VLMModelType {
LLAVA,
LLAVA_NEXT,
INTERNVL_CHAT,
+ PHI3_V,
QWEN2_VL,
};
@@ -25,6 +26,7 @@ inline VLMModelType to_vlm_model_type(const std::string& value) {
{"llava", VLMModelType::LLAVA},
{"llava_next", VLMModelType::LLAVA_NEXT},
{"internvl_chat", VLMModelType::INTERNVL_CHAT},
+ {"phi3_v", VLMModelType::PHI3_V},
{"qwen2_vl", VLMModelType::QWEN2_VL}
};
diff --git a/tests/python_tests/test_vlm_pipeline.py b/tests/python_tests/test_vlm_pipeline.py
index b413b6cf1d..0f9358b961 100644
--- a/tests/python_tests/test_vlm_pipeline.py
+++ b/tests/python_tests/test_vlm_pipeline.py
@@ -9,17 +9,17 @@
from openvino_genai import VLMPipeline, GenerationConfig
from common import get_image_by_link, get_beam_search, get_multinomial_all_parameters, get_default_properties
-def get_ov_model(cache):
- model_dir = cache.mkdir("tiny-random-minicpmv-2_6")
+def get_ov_model(model_id, cache):
+ model_dir = cache.mkdir(model_id.split('/')[-1])
if (model_dir / "openvino_language_model.xml").exists():
return model_dir
- model_id = "katuni4ka/tiny-random-minicpmv-2_6"
processor = transformers.AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
processor.tokenizer.save_pretrained(model_dir)
ov_tokenizer, ov_detokenizer = openvino_tokenizers.convert_tokenizer(processor.tokenizer, with_detokenizer=True)
openvino.save_model(ov_tokenizer, model_dir / "openvino_tokenizer.xml")
openvino.save_model(ov_detokenizer, model_dir / "openvino_detokenizer.xml")
model = OVModelForVisualCausalLM.from_pretrained(model_id, compile=False, device="CPU", export=True, load_in_8bit=False, trust_remote_code=True, ov_config=get_default_properties())
+ processor.chat_template = processor.tokenizer.chat_template # It seems that tiny-random-phi3-vision is saved incorrectly. This line works around that.
processor.save_pretrained(model_dir)
model.save_pretrained(model_dir)
return model_dir
@@ -44,13 +44,17 @@ def get_ov_model(cache):
@pytest.mark.precommit
@pytest.mark.nightly
-def test_vlm_pipeline(cache):
+@pytest.mark.parametrize("model_id", [
+ "katuni4ka/tiny-random-minicpmv-2_6",
+ "katuni4ka/tiny-random-phi3-vision",
+])
+def test_vlm_pipeline(model_id, cache):
def streamer(word: str) -> bool:
nonlocal result_from_streamer
result_from_streamer.append(word)
return False
- models_path = get_ov_model(cache)
+ models_path = get_ov_model(model_id, cache)
generation_config = GenerationConfig(max_new_tokens=30)
for links in image_links_for_testing:
@@ -76,7 +80,7 @@ def streamer(word: str) -> bool:
@pytest.mark.precommit
@pytest.mark.nightly
def test_vlm_get_tokenizer(cache):
- models_path = get_ov_model(cache)
+ models_path = get_ov_model("katuni4ka/tiny-random-minicpmv-2_6", cache)
pipe = VLMPipeline(models_path, "CPU")
tokenizer = pipe.get_tokenizer()
tokenizer.encode("")
@@ -89,15 +93,16 @@ def test_vlm_get_tokenizer(cache):
get_multinomial_all_parameters(),
])
def test_sampling(config, cache):
- models_path = get_ov_model(cache)
+ models_path = get_ov_model("katuni4ka/tiny-random-minicpmv-2_6", cache)
image = get_image_by_link(image_links[0])
pipe = VLMPipeline(models_path, "CPU")
pipe.generate(prompts[0], image=image, generation_config=config)
@pytest.mark.precommit
+@pytest.mark.nightly
def test_perf_metrics(cache):
import numpy as np
- models_path = get_ov_model(cache)
+ models_path = get_ov_model("katuni4ka/tiny-random-minicpmv-2_6", cache)
images = [get_image_by_link(image_links[0])]