changing dimensions of batch size, kv cache and num_input_heads #793

Merged
Changes from 16 of 20 commits shown.

Commits
006c942
changing dimensions of batch size, kv cache and num_input_heads
lamiayous Aug 23, 2024
2f88d30
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 26, 2024
7d5cb93
Support for Qwen and isolating changes to functions
lamiayous Aug 27, 2024
0969051
indentation fix
lamiayous Aug 27, 2024
a5abd8f
Merge branch 'master' into ly/handling_tensors_access_in_genai
lamiayous Aug 27, 2024
73181e3
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 28, 2024
d223867
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 28, 2024
e929a48
Merge branch 'ly/handling_tensors_access_in_genai' of https://github.…
lamiayous Aug 28, 2024
323e1ce
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 28, 2024
bc84a24
fix
lamiayous Aug 28, 2024
c208f1f
passing KVAxesPosition to reshape_to_static
lamiayous Sep 3, 2024
be7a41c
typo fix
lamiayous Sep 3, 2024
a119bef
remove debug print
lamiayous Sep 3, 2024
311c358
changed to strict model_type comparison
lamiayous Sep 3, 2024
2896476
fix typo
lamiayous Sep 3, 2024
7019379
Merge branch 'master' into ly/handling_tensors_access_in_genai
lamiayous Sep 3, 2024
9bcaebd
Update src/cpp/src/llm_pipeline_static.cpp
lamiayous Sep 3, 2024
6359a73
Merge branch 'master' into ly/handling_tensors_access_in_genai
ilya-lavrenov Sep 3, 2024
db7c28b
fix typo
lamiayous Sep 4, 2024
88e6b0f
Merge branch 'ly/handling_tensors_access_in_genai' of https://github.…
lamiayous Sep 4, 2024
51 changes: 42 additions & 9 deletions src/cpp/src/llm_pipeline_static.cpp
@@ -9,6 +9,8 @@
#include "utils.hpp"

#include <openvino/pass/stateful_to_stateless.hpp>
#include <jinja2cpp/user_callable.h>
#include <fstream>

namespace {

@@ -83,9 +85,41 @@ std::shared_ptr<ov::Model> add_slices_to_kvcache_inputs(const std::shared_ptr<ov
return std::make_shared<ov::Model>(model->get_results(), ov::SinkVector{}, new_params);
}

struct KVAxesPosition {
uint32_t batch;
uint32_t seq_len;
};

KVAxesPosition get_kv_axes(const std::string& model_type) {
KVAxesPosition axes;
if (model_type == "chatglm") {
axes.batch = 1u;
axes.seq_len = 0u;
Review thread on this hunk:

Collaborator: @ilya-lavrenov Yes, you're right. The only drawback is that this function cannot be used for the .blob case, as those models will be stateless + static.

Contributor: How is it supposed to pass a compiled blob via the current LLMPipeline API?

Collaborator: It's on review so far: #811

Collaborator (TolyaTalamanov, Sep 4, 2024): The current solution aims to handle the qwen / chatglm cases that are crucial for now. But in general, I'd prefer using your approach somewhere inside the StatefulToStateless transformation, so that it could save the necessary metadata and make it available from both the xml and blob formats.

} else if (model_type == "qwen") {
// Note: qwen2 does not fall into this category and conforms to the default layout
axes.batch = 0u;
axes.seq_len = 1u;
} else {
axes.batch = 0u;
axes.seq_len = 2u;
}
return axes;
}

std::string get_model_type_from_json(const std::filesystem::path& filepath) {
std::ifstream file(filepath);
if (!file.is_open()) {
throw std::runtime_error("Could not open file: " + filepath.string());
}
nlohmann::json config_data = nlohmann::json::parse(file);
std::string model_type = config_data["model_type"].get<std::string>();
return model_type;
}

void reshape_to_static(std::shared_ptr<ov::Model> model,
const uint32_t input_size,
const uint32_t kvcache_size) {
const uint32_t kvcache_size,
const KVAxesPosition& kv_axes_position) {
std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
@@ -98,10 +132,9 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
new_shape = ov::PartialShape({1, input_size});
} else {
const auto& partial_shape = input.get_partial_shape();
new_shape = ov::PartialShape({1,
partial_shape[1].get_length(),
kvcache_size-input_size,
partial_shape[3].get_length()});
new_shape = partial_shape;
new_shape[kv_axes_position.batch] = 1;
new_shape[kv_axes_position.seq_len] = kvcache_size - input_size;
}
new_shapes.emplace(input_name, new_shape);
}
@@ -222,10 +255,10 @@ StaticLLMPipeline::StaticLLMPipeline(
// (6) Reshape both models to static shape
const auto kMaxPromptLen = pop_or_default(pipeline_config, "MAX_PROMPT_LEN", 1024u);
const auto kMinResponseLen = pop_or_default(pipeline_config, "MIN_RESPONSE_LEN", 150u);
// FIXME For some models KV-cache dim != 2u
m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, 2u };
reshape_to_static(m_prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size);
reshape_to_static(m_kvcache_model, 1u, m_kvcache_desc.total_size);
KVAxesPosition axes = get_kv_axes(get_model_type_from_json(path / "config.json"));
m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, axes.seq_len };
reshape_to_static(m_prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size, axes);
reshape_to_static(m_kvcache_model, 1u, m_kvcache_desc.total_size, axes);
// (7) Compile both model
auto prefill_config = pop_or_default(pipeline_config, "PREFILL_CONFIG", get_default_prefill_config());
auto generate_config = pop_or_default(pipeline_config, "GENERATE_CONFIG", get_default_generate_config());