CB: preparation for relying on KV cache precisions from plugins (#1634)
- Currently we have logic to detect KV cache precisions, and this logic
has become more and more complex.
- The idea is to rely on the plugin's logic instead and compile the PA
model with `ov::element::dynamic` precisions for the KV cache inputs.
- Then, take the `ov::CompiledModel` and extract the precisions from its
`inputs()`.
- Finally, create tensors based on the computed `num_kv_blocks`, which
depends on the KV cache precisions.

For now, the logic that mimics the plugin's choice of KV cache precisions
is still here, but it will be dropped once the plugins support
`ov::element::dynamic`.
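
Below is a minimal sketch of the extraction step, not part of this
commit: it assumes a PagedAttention-transformed model whose KV cache
inputs were left as `ov::element::dynamic`, so the compiled model
reports the plugin-chosen types; the model path `model_pa.xml` and the
`CPU` device are placeholders, while the `key_cache.<id>` /
`value_cache.<id>` input names match the ones used by this change.

```cpp
#include <openvino/openvino.hpp>

#include <iostream>
#include <string>
#include <vector>

int main() {
    ov::Core core;
    // Placeholder path/device; the real code receives an already
    // compiled model via ov::InferRequest::get_compiled_model().
    ov::CompiledModel compiled_model = core.compile_model("model_pa.xml", "CPU");

    std::vector<ov::element::Type> key_precisions, value_precisions;
    for (const auto& input : compiled_model.inputs()) {
        for (const auto& name : input.get_names()) {
            if (name.find("key_cache.") == 0) {
                key_precisions.push_back(input.get_element_type());
                break;
            } else if (name.find("value_cache.") == 0) {
                value_precisions.push_back(input.get_element_type());
                break;
            }
        }
    }

    // One (key, value) pair per decoder layer.
    for (size_t i = 0; i < key_precisions.size(); ++i) {
        std::cout << "layer " << i << ": K=" << key_precisions[i]
                  << ", V=" << value_precisions[i] << std::endl;
    }
    return 0;
}
```

The same pattern drives the new `CacheManager` constructor in this
commit, which additionally records the cache shapes and the per-block
byte size.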
ilya-lavrenov authored Jan 29, 2025
1 parent 4fb48de commit 5cbadd1
Showing 18 changed files with 352 additions and 340 deletions.
4 changes: 2 additions & 2 deletions .github/labeler.yml
@@ -103,8 +103,8 @@
- 'src/cpp/src/generation_handle.cpp'
- 'src/cpp/src/generation_stream.hpp'
- 'src/cpp/src/model_runner.hpp'
- 'src/cpp/src/utils/paged_attention_transformations.cpp'
- 'src/cpp/src/utils/paged_attention_transformations.hpp'
- 'src/cpp/src/paged_attention_transformations.cpp'
- 'src/cpp/src/paged_attention_transformations.hpp'
- 'src/cpp/src/scheduler.hpp'
- 'src/cpp/src/sequence_group.cpp'
- 'src/cpp/src/sequence_group.hpp'
169 changes: 113 additions & 56 deletions src/cpp/src/cache_manager.hpp
@@ -45,61 +45,126 @@ class TensorMmapAllocator {
#endif

namespace ov::genai {

class CacheManager {
DeviceConfig m_device_config;
std::vector<ov::Tensor> m_key_cache;
std::vector<ov::Tensor> m_value_cache;
size_t m_num_allocated_kv_blocks = 0;
size_t m_num_decoder_layers = 0;
std::string m_device;
std::vector<ov::element::Type> m_key_precisions, m_value_precisions;
std::vector<ov::PartialShape> m_key_shapes, m_value_shapes;
std::vector<ov::Tensor> m_key_cache, m_value_cache;
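// m_block_size_in_bytes covers one KV block across all decoder layers (key + value)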
size_t m_num_allocated_kv_blocks = 0, m_block_size_in_bytes = 0;
ov::InferRequest m_request;
ov::Core m_core;

ov::Shape set_first_dim_and_make_static(const ov::PartialShape& shape, size_t dim) {
ov::PartialShape res_shape = shape;
res_shape[0] = dim;
OPENVINO_ASSERT(res_shape.is_static());
return res_shape.to_shape();
static ov::Shape set_kv_blocks(ov::PartialShape pshape, size_t num_kv_blocks) {
pshape[0] = num_kv_blocks;
return pshape.get_shape();
}

void update_request_tensor(size_t decoder_layer_id) {
m_request.set_tensor(std::string("key_cache.") + std::to_string(decoder_layer_id), m_key_cache[decoder_layer_id]);
m_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_value_cache[decoder_layer_id]);
}

ov::PartialShape patch_shape(ov::PartialShape pshape, ov::element::Type cache_type) {
OPENVINO_ASSERT(!m_device.empty(), "Internal error: device is not set");

if (m_device.find("CPU") != std::string::npos && cache_type == ov::element::u8) {
// Scale, zero point and quantized data will be stored together.
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)|
// so we have to extend head_size by 2 * sizeof(float) bytes:
// sizeof(float) for the scale and sizeof(float) for the zero point
pshape[3] += 2 * sizeof(float);
}

return pshape;
}

public:
explicit CacheManager(const DeviceConfig &device_config, ov::InferRequest request, ov::Core core) :
m_device_config(device_config),
m_request(request),
m_core(core) {
m_key_cache.reserve(m_device_config.get_num_layers());
m_value_cache.reserve(m_device_config.get_num_layers());
CacheManager(ov::InferRequest request, const DeviceConfig& device_config) :
m_request(request) {
// extract information about inference device
ov::CompiledModel compiled_model = request.get_compiled_model();
std::vector<std::string> execution_devices = compiled_model.get_property(ov::execution_devices);
OPENVINO_ASSERT(execution_devices.size() == 1, "Continuous batching: execution device is expected to be CPU or GPU, but got ", execution_devices.size(), " devices");
m_device = execution_devices[0];

// extract information about KV cache precisions and shapes
size_t kv_input_index = 0;
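// inputs come in per-layer (key_cache.N, value_cache.N) pairs; the index
// advances once the value cache input of a pair has been processed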
for (const auto& input : compiled_model.inputs()) {
for (auto & name : input.get_names()) {
auto cache_precision = input.get_element_type();

if (name.find("key_cache.") == 0) {
auto pshape = patch_shape(device_config.get_key_cache_shape(kv_input_index), cache_precision);
m_key_shapes.push_back(pshape);
m_key_precisions.push_back(cache_precision);
m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
break;
} else if (name.find("value_cache.") == 0) {
auto pshape = patch_shape(device_config.get_value_cache_shape(kv_input_index), cache_precision);
m_value_shapes.push_back(pshape);
m_value_precisions.push_back(cache_precision);
m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
++kv_input_index;
break;
}
}
}

m_num_decoder_layers = m_value_precisions.size();
OPENVINO_ASSERT(m_num_decoder_layers == m_key_precisions.size(), "Invalid case: a different number of K and V caches in a LLM model");
}

size_t get_num_decoder_layers() const {
return m_num_decoder_layers;
}

std::string get_device() const {
return m_device;
}

ov::element::Type get_key_cache_precision(size_t decoder_layer_id) const {
OPENVINO_ASSERT(decoder_layer_id < m_key_precisions.size());
return m_key_precisions[decoder_layer_id];
}

ov::element::Type get_value_cache_precision(size_t decoder_layer_id) const {
OPENVINO_ASSERT(decoder_layer_id < m_value_precisions.size());
return m_value_precisions[decoder_layer_id];
}

size_t get_block_size_in_bytes() const {
return m_block_size_in_bytes;
}

void allocate_cache_if_needed(size_t num_kv_blocks) {
if (m_num_allocated_kv_blocks >= num_kv_blocks) {
return;
}
OPENVINO_ASSERT(m_key_cache.size() == m_value_cache.size());
m_num_allocated_kv_blocks = num_kv_blocks;

const std::string device_name = m_device_config.get_device();
m_num_allocated_kv_blocks = num_kv_blocks;

ov::Coordinate start_key{0,0,0,0};
ov::Coordinate start_value{0,0,0,0};

if (device_name.find("GPU") == std::string::npos) {// Allocate KV caches
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
if (m_device.find("GPU") == std::string::npos) {// Allocate KV caches
for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
ov::Shape value_cache_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], num_kv_blocks);
ov::Shape key_cache_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], num_kv_blocks);

ov::element::Type key_precision = get_key_cache_precision(decoder_layer_id);
ov::element::Type value_precision = get_value_cache_precision(decoder_layer_id);

#ifdef _WIN32
ov::Tensor key_cache(m_device_config.get_cache_precision(), key_cache_shape);
ov::Tensor value_cache(m_device_config.get_cache_precision(), value_cache_shape);
ov::Tensor key_cache(key_precision, key_cache_shape);
ov::Tensor value_cache(value_precision, value_cache_shape);
#else
auto key_size = ov::shape_size(key_cache_shape) * m_device_config.get_cache_precision().size();
auto value_size = ov::shape_size(value_cache_shape) * m_device_config.get_cache_precision().size();

ov::Tensor key_cache = ov::Tensor(m_device_config.get_cache_precision(), key_cache_shape, TensorMmapAllocator(key_size));
ov::Tensor value_cache = ov::Tensor(m_device_config.get_cache_precision(), value_cache_shape, TensorMmapAllocator(value_size));
auto key_size = ov::shape_size(key_cache_shape) * key_precision.size();
auto value_size = ov::shape_size(value_cache_shape) * value_precision.size();

ov::Tensor key_cache(key_precision, key_cache_shape, TensorMmapAllocator(key_size));
ov::Tensor value_cache(value_precision, value_cache_shape, TensorMmapAllocator(value_size));
#endif

auto key_cache_roi_end = static_cast<unsigned char*>(key_cache.data());
@@ -137,24 +202,23 @@ class CacheManager {
if (m_key_cache.size() > decoder_layer_id) {
m_key_cache[decoder_layer_id] = key_cache;
m_value_cache[decoder_layer_id] = value_cache;
}
else {
} else {
m_key_cache.emplace_back(key_cache);
m_value_cache.emplace_back(value_cache);
}

update_request_tensor(decoder_layer_id);
}
} else {
auto remote_context = m_core.get_default_context(device_name);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Shape value_cache_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Shape key_cache_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), num_kv_blocks);
ov::Tensor key_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
key_cache_shape);
ov::Tensor value_cache = remote_context.create_tensor(m_device_config.get_cache_precision(),
value_cache_shape);
auto remote_context = m_request.get_compiled_model().get_context();

for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
ov::Shape value_cache_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], num_kv_blocks);
ov::Shape key_cache_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], num_kv_blocks);

ov::Tensor key_cache = remote_context.create_tensor(get_key_cache_precision(decoder_layer_id), key_cache_shape);
ov::Tensor value_cache = remote_context.create_tensor(get_value_cache_precision(decoder_layer_id), value_cache_shape);

if (m_key_cache.size() > decoder_layer_id) {
ov::Coordinate end_key = m_key_cache[decoder_layer_id].get_shape();
ov::Coordinate end_value = m_value_cache[decoder_layer_id].get_shape();
@@ -167,23 +231,23 @@

m_key_cache[decoder_layer_id] = key_cache;
m_value_cache[decoder_layer_id] = value_cache;
}
else {
} else {
m_key_cache.emplace_back(key_cache);
m_value_cache.emplace_back(value_cache);
}

update_request_tensor(decoder_layer_id);
}
}
}

ov::Tensor get_key_cache(size_t decoder_layer_id) const {
OPENVINO_ASSERT(decoder_layer_id < m_key_cache.size());
OPENVINO_ASSERT(decoder_layer_id < m_key_cache.size(), "decoder_layer_id = ", decoder_layer_id, ", num_layers = ", m_key_cache.size());
return m_key_cache[decoder_layer_id];
}

ov::Tensor get_value_cache(size_t decoder_layer_id) const {
OPENVINO_ASSERT(decoder_layer_id < m_value_cache.size());
OPENVINO_ASSERT(decoder_layer_id < m_value_cache.size(), "decoder_layer_id = ", decoder_layer_id, ", num_layers = ", m_value_cache.size());
return m_value_cache[decoder_layer_id];
}

@@ -192,9 +256,9 @@
size_t src_block_id = blocks_pair.first;
const std::list<size_t>& dst_block_ids = blocks_pair.second;
for (size_t dst_block_id : dst_block_ids) {
for (size_t decoder_layer_id = 0; decoder_layer_id < m_device_config.get_num_layers(); ++decoder_layer_id) {
ov::Shape key_shape = set_first_dim_and_make_static(m_device_config.get_key_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
ov::Shape value_shape = set_first_dim_and_make_static(m_device_config.get_value_cache_shape(decoder_layer_id), m_num_allocated_kv_blocks);
for (size_t decoder_layer_id = 0; decoder_layer_id < m_num_decoder_layers; ++decoder_layer_id) {
ov::Shape key_shape = set_kv_blocks(m_key_shapes[decoder_layer_id], m_num_allocated_kv_blocks);
ov::Shape value_shape = set_kv_blocks(m_value_shapes[decoder_layer_id], m_num_allocated_kv_blocks);
ov::Coordinate key_src_start_roi(key_shape.size(), 0);
ov::Coordinate key_src_end_roi = key_shape;
ov::Coordinate key_dst_start_roi(key_shape.size(), 0);
Expand All @@ -221,13 +285,6 @@ class CacheManager {
}
}
}

std::shared_ptr<Core> get_core() {
return std::make_shared<Core>(m_core);
}

std::shared_ptr<DeviceConfig> get_device_config() {
return std::make_shared<DeviceConfig>(m_device_config);
}
};

}