Skip to content

Commit

Permalink
[TRANSFORMATIONS] Add support for 'inputs_embeds' input in SDPAToPA (o…
Browse files Browse the repository at this point in the history
…penvinotoolkit#27158)

[TRANSFORMATIONS] Add support for 'inputs_embeds' input in SDPAToPA

Add support for the 'inputs_embeds' input in the SDPAToPA transformation.
The input is used in VLM instead of 'input_ids' in text-only models.

The changes enable support of the SDPAToPA transformation for the
following models:
 * llava-hf/llava-1.5-7b-hf
 * llava-hf/llava-v1.6-mistral-7b-hf
 * llava-hf/llava-v1.6-vicuna-7b-hf
 * llava-hf/llama3-llava-next-8b-hf
 * openbmb/MiniCPM-V-2_6

- Ticket:
[CVS-156956](https://jira.devtools.intel.com/browse/CVS-156956)

Signed-off-by: Andrii Staikov <andrii.staikov@intel.com>

---------

Co-authored-by: Pawel Raasz <pawel.raasz@intel.com>
  • Loading branch information
CuriousPanCake and praasz authored Nov 18, 2024
1 parent 58672c3 commit b87f635
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 64 deletions.
82 changes: 44 additions & 38 deletions src/core/src/pass/sdpa_to_paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ static std::shared_ptr<v0::Parameter> setName(std::shared_ptr<v0::Parameter> nod
// Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a
// given single name)
node->set_friendly_name(name);
OPENVINO_ASSERT(node->get_output_size() == 1); // Should I use assert here?
OPENVINO_ASSERT(node->get_output_size() == 1);
node->get_output_tensor(0).set_names({name});
return node;
}
Expand All @@ -53,8 +53,33 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode

auto sliding_window = v0::Constant::create(element::i32, Shape{}, {0}); // sliding_window

std::shared_ptr<v0::Parameter> input_ids_node =
std::dynamic_pointer_cast<v0::Parameter>(model->input("input_ids").get_node_shared_ptr());
auto get_parameter = [=](const std::shared_ptr<ov::Model>& model,
const std::string& name) -> std::shared_ptr<v0::Parameter> {
for (const auto& param : model->inputs()) {
const auto& names = param.get_names();
if (names.count(name)) {
if (auto casted_param = ov::as_type_ptr<v0::Parameter>(param.get_node_shared_ptr())) {
return casted_param;
} else {
OPENVINO_THROW("The model is in the inconsistent state. Found input '",
name,
"', but couldn't cast it to v0::Parameter.");
}
}
}

return nullptr;
};

std::shared_ptr<v0::Parameter> input_ids_node;
for (const auto& name : {"input_ids", "inputs_embeds"}) {
if ((input_ids_node = get_parameter(model, name))) {
break;
}
}

OPENVINO_ASSERT(input_ids_node, "The model doesn't contain input_ids or input_embeds input. Aborting.");

input_ids_node->set_partial_shape(PartialShape{-1});
auto unsqueezed_input_ids =
std::make_shared<v0::Unsqueeze>(input_ids_node, v0::Constant::create(element::i32, Shape{}, {1}));
Expand All @@ -66,25 +91,14 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
auto prev_max_seq_len =
std::make_shared<v1::Subtract>(max_context_len, std::make_shared<v0::Convert>(cur_seq_len, element::i32));

auto has_parameter = [=](const std::shared_ptr<ov::Model>& model, const std::string& name) -> bool {
for (auto& t : model->inputs()) {
const auto& names = t.get_names();
if (names.find(name) != names.end()) {
return true;
}
}

return false;
};

ParameterVector kv_parameters;
ParameterVector parameters_to_remove;
ResultVector results_to_remove; // # used, but cannot really track all Results in stateless model
ParameterVector block_indices_inputs;
ResultVector score_results;

std::shared_ptr<v0::Parameter> position_ids;
if (!has_parameter(model, "position_ids")) {
if (!get_parameter(model, "position_ids")) {
position_ids = setName(std::make_shared<v0::Parameter>(element::i64, PartialShape{-1}), "position_ids");
model->add_parameters({position_ids});
} else {
Expand Down Expand Up @@ -136,30 +150,22 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptr<ov::Mode
}

for (auto& param_name : {"beam_idx", "attention_mask"}) {
if (has_parameter(model, param_name)) {
if (const auto& param =
std::dynamic_pointer_cast<v0::Parameter>(model->input(param_name).get_node_shared_ptr())) {
model->remove_parameter(param);

if (param->output(0).get_target_inputs().size() == 0) {
std::stringstream consumers;
consumers << std::endl;
for (auto& input : param->output(0).get_target_inputs()) {
consumers << *input.get_node() << std::endl;
}
OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
"PagedAttention transformation failed: couldn't remove ",
param->output(0).get_target_inputs().size(),
" inputs of ",
param_name,
" input: ",
consumers.str());
if (auto param = get_parameter(model, param_name)) {
model->remove_parameter(param);

if (param->output(0).get_target_inputs().size() == 0) {
std::stringstream consumers;
consumers << std::endl;
for (auto& input : param->output(0).get_target_inputs()) {
consumers << *input.get_node() << std::endl;
}
} else {
OPENVINO_THROW("The model is in the inconsistent state. Found input '",
param_name,
"', but couldn't cast it to v0::Parameter.");
return false;
OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0,
"PagedAttention transformation failed: couldn't remove ",
param->output(0).get_target_inputs().size(),
" inputs of ",
param_name,
" input: ",
consumers.str());
}
}
}
Expand Down
46 changes: 39 additions & 7 deletions tests/model_hub_tests/transformation_tests/generate_ref_diffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,36 @@
from openvino._offline_transformations import paged_attention_transformation
from openvino._pyopenvino.op import _PagedAttentionExtension, Parameter, Result
from optimum.intel import OVModelForCausalLM
from optimum.intel.openvino import OVModelForVisualCausalLM
from typing import Type, Union

nodes_to_compare = ("ScaledDotProductAttention", "PagedAttentionExtension", "Parameter", "ReadValue", "Assign")

def get_models_list_type(file_name: str, cls: Union[Type[OVModelForCausalLM], Type[OVModelForVisualCausalLM]]):
    """Parse a model-list file and pair every entry with the class used to export it.

    Every returned tuple has the uniform shape
    ``(model_name, model_link, mark, reason, cls)`` so the caller can always
    unpack five fields, no matter how many comma-separated fields the source
    line contained.

    :param file_name: path to the model list file; each line has 2, 4, or more
                      than 4 comma-separated fields
    :param cls: optimum-intel model class to load the model with
                (``OVModelForCausalLM`` or ``OVModelForVisualCausalLM``)
    :return: list of ``(model_name, model_link, mark, reason, cls)`` tuples
    """
    models = []
    for line_items in utils.parse_list_file(file_name):
        if len(line_items) == 2:
            model_name, model_link = line_items
            models.append((model_name, model_link, None, None, cls))
        elif len(line_items) == 4:
            model_name, model_link, mark, reason = line_items
            # Fix: 'cls' was missing here, producing a 4-tuple that breaks the
            # caller's 5-element unpack (model_id, _, _, _, cls) for xfail entries.
            models.append((model_name, model_link, mark, reason, cls))
        elif len(line_items) > 4:
            model_name, model_link, mark, reason, *_other = line_items
            # Empty strings mean "no mark/reason" in the list-file format.
            if not mark:
                mark = None
            if not reason:
                reason = None
            # Extra 'ts_name:'/'layer:' fields are not used by the ref-diff
            # generator; keep the tuple shape uniform (ending in 'cls') so the
            # caller's 5-element unpack works for every entry.
            models.append((model_name, model_link, mark, reason, cls))
        else:
            items = ','.join(line_items)
            assert False, \
                f'Incorrect model info fields {items}. It must contain either 2 or 4 or more than 4 fields.'
    return models

def main():
use_cache_eviction = False
if len(sys.argv) >= 2:
Expand All @@ -55,32 +82,37 @@ def main():

if OUTPUT_FILE.exists() and OUTPUT_FILE.is_file():
OUTPUT_FILE.unlink()

with open(OUTPUT_FILE, 'w') as file:
model_list = utils.get_models_list(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit"))
model_list = get_models_list_type(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-models-precommit"), OVModelForCausalLM)
model_list.extend(get_models_list_type(os.path.join(os.path.dirname(__file__), "models", "hf-tiny-random-vl-models-precommit"), OVModelForVisualCausalLM))
print(OUTPUT_FILE)
print('ref_diff_map_cache_eviction = {' if use_cache_eviction else 'ref_diff_map = {', file=file)

for model_id, _, _, _ in model_list:
for model_id, _, _, _, cls in model_list:
# wrapping in try/catch block to continue printing models even if one has failed
try:
model = OVModelForCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)
model = cls.from_pretrained(model_id, export=True, trust_remote_code=True)
except:
print(f"Couldn't read {model_id}.")
continue

ov_model = model.model if cls is OVModelForCausalLM else model.lm_model

before_map = {}
for op in model.model.get_ordered_ops():
for op in ov_model.get_ordered_ops():
if op.get_type_name() in nodes_to_compare:
before_map[op.get_type_name()] = before_map.get(op.get_type_name(), 0) + 1

# wrapping in try/catch block to continue printing models even if one has failed
try:
paged_attention_transformation(model.model, use_cache_eviction, use_cache_eviction)
paged_attention_transformation(ov_model, use_cache_eviction, use_cache_eviction)
except:
print(f"Couldn't run SDPAToPA transformation on {model_id} and generate diffs.")
continue

after_map = {}
for op in model.model.get_ordered_ops():
for op in ov_model.get_ordered_ops():
if op.get_type_name() in nodes_to_compare:
after_map[op.get_type_name()] = after_map.get(op.get_type_name(), 0) + 1

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
katuni4ka/tiny-random-llava-next,https://huggingface.co/katuni4ka/tiny-random-llava-next
katuni4ka/tiny-random-minicpmv-2_6,https://huggingface.co/katuni4ka/tiny-random-minicpmv-2_6
katuni4ka/tiny-random-llava,https://huggingface.co/katuni4ka/tiny-random-llava
katuni4ka/tiny-random-nanollava,https://huggingface.co/katuni4ka/tiny-random-nanollava,xfail,CVS-157416
72 changes: 65 additions & 7 deletions tests/model_hub_tests/transformation_tests/sdpa2pa_ref_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,35 @@
"ReadValue" : -12,
"Assign" : -12,
},
"katuni4ka/tiny-random-llava-next" : {
"PagedAttentionExtension" : 2,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"Assign" : -4,
},
"katuni4ka/tiny-random-minicpmv-2_6" : {
"PagedAttentionExtension" : 2,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"Assign" : -4,
},
"katuni4ka/tiny-random-llava" : {
"Assign" : -4,
"Parameter" : 7,
"ReadValue" : -4,
"ScaledDotProductAttention" : -2,
"PagedAttentionExtension" : 2,
},

# "katuni4ka/tiny-random-nanollava" : {
# "Assign" : -4,
# "Parameter" : 7,
# "ReadValue" : -4,
# "ScaledDotProductAttention" : -2,
# "PagedAttentionExtension" : 2,
# },
}

ref_diff_map_cache_eviction = {
Expand Down Expand Up @@ -532,13 +561,13 @@
"Parameter" : 14,
"Assign" : -8,
},
"katuni4ka/tiny-random-minicpm" : {
"ScaledDotProductAttention" : -4,
"Parameter" : 14,
"PagedAttentionExtension" : 4,
"ReadValue" : -8,
"Assign" : -8,
},
"katuni4ka/tiny-random-minicpm" : {
"ScaledDotProductAttention" : -4,
"Parameter" : 14,
"PagedAttentionExtension" : 4,
"ReadValue" : -8,
"Assign" : -8,
},
"katuni4ka/tiny-random-falcon-40b" : {
"ScaledDotProductAttention" : -2,
"ReadValue" : -4,
Expand Down Expand Up @@ -609,4 +638,33 @@
"Parameter" : 20,
"Assign" : -12,
},
"katuni4ka/tiny-random-llava-next" : {
"Parameter" : 8,
"Assign" : -4,
"ReadValue" : -4,
"PagedAttentionExtension" : 2,
"ScaledDotProductAttention" : -2,
},
"katuni4ka/tiny-random-minicpmv-2_6" : {
"Parameter" : 8,
"Assign" : -4,
"ReadValue" : -4,
"PagedAttentionExtension" : 2,
"ScaledDotProductAttention" : -2,
},
"katuni4ka/tiny-random-llava" : {
"ReadValue" : -4,
"Parameter" : 8,
"ScaledDotProductAttention" : -2,
"PagedAttentionExtension" : 2,
"Assign" : -4,
},

# "katuni4ka/tiny-random-nanollava" : {
# "ReadValue" : -4,
# "Parameter" : 8,
# "ScaledDotProductAttention" : -2,
# "PagedAttentionExtension" : 2,
# "Assign" : -4,
# },
}
Loading

0 comments on commit b87f635

Please sign in to comment.