Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

changing dimensions of batch size, kv cache and num_input_heads #793

Merged
Changes from 5 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
006c942
changing dimensions of batch size, kv cache and num_input_heads
lamiayous Aug 23, 2024
2f88d30
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 26, 2024
7d5cb93
Support for Qwen and isolating changes to functions
lamiayous Aug 27, 2024
0969051
indentation fix
lamiayous Aug 27, 2024
a5abd8f
Merge branch 'master' into ly/handling_tensors_access_in_genai
lamiayous Aug 27, 2024
73181e3
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 28, 2024
d223867
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 28, 2024
e929a48
Merge branch 'ly/handling_tensors_access_in_genai' of https://github.…
lamiayous Aug 28, 2024
323e1ce
Merge branch 'master' of https://github.com/openvinotoolkit/openvino.…
lamiayous Aug 28, 2024
bc84a24
fix
lamiayous Aug 28, 2024
c208f1f
passing KVAxesPosition to reshape_to_static
lamiayous Sep 3, 2024
be7a41c
typo fix
lamiayous Sep 3, 2024
a119bef
remove debug print
lamiayous Sep 3, 2024
311c358
changed to strict model_type comparison
lamiayous Sep 3, 2024
2896476
fix typo
lamiayous Sep 3, 2024
7019379
Merge branch 'master' into ly/handling_tensors_access_in_genai
lamiayous Sep 3, 2024
9bcaebd
Update src/cpp/src/llm_pipeline_static.cpp
lamiayous Sep 3, 2024
6359a73
Merge branch 'master' into ly/handling_tensors_access_in_genai
ilya-lavrenov Sep 3, 2024
db7c28b
fix typo
lamiayous Sep 4, 2024
88e6b0f
Merge branch 'ly/handling_tensors_access_in_genai' of https://github.…
lamiayous Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions src/cpp/src/llm_pipeline_static.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "utils.hpp"

#include <openvino/pass/stateful_to_stateless.hpp>
#include <jinja2cpp/user_callable.h>
#include <fstream>

namespace {

Expand Down Expand Up @@ -83,9 +85,35 @@ std::shared_ptr<ov::Model> add_slices_to_kvcache_inputs(const std::shared_ptr<ov
return std::make_shared<ov::Model>(model->get_results(), ov::SinkVector{}, new_params);
}

// Positions (tensor axes) of the batch and sequence-length dimensions
// inside a model's KV-cache input shapes. Different architectures lay
// out their past_key_values tensors differently, so these indices are
// looked up per model type before reshaping to static shapes.
struct KVAxesPosition {
    uint32_t batch;    // axis index of the batch dimension
    uint32_t seq_len;  // axis index of the sequence-length (cache) dimension
};

/// @brief Map a HuggingFace-style `model_type` string to the KV-cache axes layout.
/// @param model_type Value of the "model_type" field from the model's config.json.
/// @return Axis positions of {batch, seq_len} in the KV-cache tensors.
/// @note Unrecognized model types fall back to the common {batch=0, seq_len=2}
///       layout used by most decoder-only LLMs.
/// NOTE(review): comparison is case-sensitive; config.json values are usually
/// lower-case (e.g. "qwen2"), so "Qwen" may never match — confirm against the
/// exporter that produces these configs.
KVAxesPosition get_kv_axes(const std::string& model_type) {
    if (model_type == "chatglm") {
        // chatglm KV layout: [seq_len, batch, ...]
        return {1u, 0u};
    }
    if (model_type == "Qwen") {
        // Qwen KV layout: [batch, seq_len, ...]
        return {0u, 1u};
    }
    // Default layout: [batch, num_heads, seq_len, head_dim]
    return {0u, 2u};
}

/// @brief Read the "model_type" field from a model's config.json.
/// @param filepath Path to the JSON config file (typically <model_dir>/config.json).
/// @return The model type string, e.g. "chatglm" or "llama".
/// @throws std::runtime_error if the file cannot be opened or lacks "model_type".
/// @throws nlohmann::json::parse_error if the file is not valid JSON.
std::string get_model_type(const std::filesystem::path& filepath) {
    std::ifstream file(filepath);
    if (!file.is_open()) {
        throw std::runtime_error("Could not open file: " + filepath.string());
    }
    nlohmann::json config_data = nlohmann::json::parse(file);
    // operator[] on a missing key would yield a null node and a confusing
    // type error on get<>(); check explicitly for a clear diagnostic.
    if (!config_data.contains("model_type")) {
        throw std::runtime_error("Missing \"model_type\" field in: " + filepath.string());
    }
    return config_data["model_type"].get<std::string>();
}

void reshape_to_static(std::shared_ptr<ov::Model> model,
const uint32_t input_size,
const uint32_t kvcache_size) {
const uint32_t kvcache_size,
const std::filesystem::path& path) {
auto config_file_path = path / "config.json";
std::map<std::string, ov::PartialShape> new_shapes;
for (auto input : model->inputs()) {
const auto& input_name = input.get_any_name();
Expand All @@ -98,10 +126,11 @@ void reshape_to_static(std::shared_ptr<ov::Model> model,
new_shape = ov::PartialShape({1, input_size});
} else {
const auto& partial_shape = input.get_partial_shape();
new_shape = ov::PartialShape({1,
partial_shape[1].get_length(),
kvcache_size-input_size,
partial_shape[3].get_length()});
new_shape = partial_shape;
std::string model_type = get_model_type(config_file_path.string());
KVAxesPosition kv_axes_position = get_kv_axes(model_type);
new_shape[kv_axes_position.batch] = 1;
new_shape[kv_axes_position.seq_len] = kvcache_size - input_size;
}
new_shapes.emplace(input_name, new_shape);
}
Expand Down Expand Up @@ -219,13 +248,21 @@ StaticLLMPipeline::StaticLLMPipeline(
// (5) Clone the model - this will be prefill
m_prefill_model = m_kvcache_model->clone();
m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill");
std::string model_type = get_model_type(path / "config.json");
uint32_t kv_dims;
if (model_type == "chatglm")
kv_dims = 0u;
else if (model_type == "Qwen")
kv_dims = 1u;
else
kv_dims = 2u;
m_kvcache_desc = KVCacheDesc { 1024u, 0u, kv_dims };
// (6) Reshape both models to static shape
const auto kMaxPromptLen = pop_or_default(pipeline_config, "MAX_PROMPT_LEN", 1024u);
const auto kMinResponseLen = pop_or_default(pipeline_config, "MIN_RESPONSE_LEN", 150u);
// FIXME For some models KV-cache dim != 2u
m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, 2u };
reshape_to_static(m_prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size);
reshape_to_static(m_kvcache_model, 1u, m_kvcache_desc.total_size);
const uint32_t max_prompt_size = m_kvcache_desc.total_size;
const uint32_t max_kvcache_size = m_kvcache_desc.total_size;
reshape_to_static(m_prefill_model, max_prompt_size, max_kvcache_size, path);
reshape_to_static(m_kvcache_model, 1u, max_kvcache_size, path);

// (7) Compile both model
auto prefill_config = pop_or_default(pipeline_config, "PREFILL_CONFIG", get_default_prefill_config());
auto generate_config = pop_or_default(pipeline_config, "GENERATE_CONFIG", get_default_generate_config());
Expand Down
Loading