Drop DeviceConfig
ilya-lavrenov committed Jan 28, 2025
1 parent 007c29c commit d59bd18
Showing 13 changed files with 129 additions and 206 deletions.
1 change: 0 additions & 1 deletion .github/labeler.yml
@@ -99,7 +99,6 @@
- 'src/cpp/src/continuous_batching_impl.cpp'
- 'src/cpp/src/continuous_batching_pipeline.cpp'
- 'src/cpp/src/debug_utils.hpp'
- 'src/cpp/src/device_config.hpp'
- 'src/cpp/src/generation_handle.cpp'
- 'src/cpp/src/generation_stream.hpp'
- 'src/cpp/src/model_runner.hpp'
69 changes: 56 additions & 13 deletions src/cpp/src/cache_manager.hpp
@@ -5,12 +5,12 @@

#include <vector>
#include <list>

#include "openvino/runtime/tensor.hpp"
#include "device_config.hpp"
#include "paged_attention_transformations.hpp"

#ifndef _WIN32
#include <sys/mman.h>
#include "openvino/core/shape.hpp"


class TensorMmapAllocator {
@@ -49,6 +49,7 @@ namespace ov::genai {
class CacheManager {
size_t m_num_decoder_layers = 0;
std::string m_device;
size_t m_block_size = 0; // block size is per inference device
std::vector<ov::element::Type> m_key_precisions, m_value_precisions;
std::vector<ov::PartialShape> m_key_shapes, m_value_shapes;
std::vector<ov::Tensor> m_key_cache, m_value_cache;
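
The new `m_block_size` member is filled per inference device in the constructor further down (32 for CPU, 16 for GPU). As a minimal sketch of that selection, here is a hypothetical standalone helper; the function name is illustrative only and not part of the commit:

```cpp
// Hypothetical helper mirroring the per-device block-size choice made in the
// CacheManager constructor below; only the 32/16 values come from the commit.
#include <cstddef>
#include <iostream>
#include <string>

std::size_t block_size_for(const std::string& device) {
    const std::size_t cpu_block_size = 32, gpu_block_size = 16;
    const bool is_gpu = device.find("GPU") != std::string::npos;
    return is_gpu ? gpu_block_size : cpu_block_size;
}

int main() {
    std::cout << block_size_for("CPU") << "\n";   // prints 32
    std::cout << block_size_for("GPU.0") << "\n"; // prints 16
    return 0;
}
```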
@@ -65,29 +66,63 @@ class CacheManager {
m_request.set_tensor(std::string("value_cache.") + std::to_string(decoder_layer_id), m_value_cache[decoder_layer_id]);
}

ov::PartialShape patch_shape(ov::PartialShape pshape, ov::element::Type cache_type) {
ov::PartialShape to_shape(const KVHeadConfig& config, ov::element::Type cache_type, bool key_param) {
OPENVINO_ASSERT(!m_device.empty(), "Internal error: device is not set");
ov::PartialShape pshape;

if (m_device.find("CPU") != std::string::npos) {
if (key_param) {
pshape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(config.num_k_heads),
ov::Dimension(m_block_size),
ov::Dimension(config.k_head_size)};
} else {
pshape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(config.num_v_heads),
ov::Dimension(m_block_size),
ov::Dimension(config.v_head_size)};
}

if (m_device.find("CPU") != std::string::npos && cache_type == ov::element::u8) {
// Scale, zero point and quantized data will be stored together.
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)|
// so, we have to extend head_size by 8, which is sizeof(float)
// for scale and sizeof(float) for zeropoint
pshape[3] += 2 * sizeof(float);
if (cache_type == ov::element::u8) {
// Scale, zero point and quantized data will be stored together.
// The layout for per token per head:
// |scale(f32)|zeropoint(f32)|quantized data(u8,idx_1)|quantized data(u8,idx_2)|...|quantized data(u8,idx_head_size)|
// so, we have to extend head_size by 8, which is sizeof(float)
// for scale and sizeof(float) for zeropoint
pshape[3] += 2 * sizeof(float);
}
} else if (m_device.find("GPU") != std::string::npos) {
if (key_param) {
pshape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(config.num_k_heads),
ov::Dimension(config.k_head_size),
ov::Dimension(m_block_size)};
} else {
pshape = ov::PartialShape{ov::Dimension::dynamic(),
ov::Dimension(config.num_v_heads),
ov::Dimension(m_block_size),
ov::Dimension(config.v_head_size)};
}
} else {
OPENVINO_THROW("Internal error: unsupported device ", m_device);
}

return pshape;
}

public:
CacheManager(ov::InferRequest request, const DeviceConfig& device_config) :
CacheManager(ov::InferRequest request, const std::vector<KVHeadConfig>& kv_cache_config) :
m_request(request) {
// extract information about inference device
ov::CompiledModel compiled_model = request.get_compiled_model();
std::vector<std::string> execution_devices = compiled_model.get_property(ov::execution_devices);
OPENVINO_ASSERT(execution_devices.size() == 1, "Continuous batching: execution device is expected to be CPU or GPU, but got ", execution_devices.size(), " devices");
m_device = execution_devices[0];

// set block_size depending on device
const size_t cpu_block_size = 32, gpu_block_size = 16;
const bool is_gpu = m_device.find("GPU") != std::string::npos;
m_block_size = is_gpu ? gpu_block_size : cpu_block_size;

// extract information about KV cache precisions and shapes
size_t kv_input_index = 0;
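
The `to_shape()` logic above encodes two device-specific layouts: CPU blocks are `[num_blocks, num_heads, block_size, head_size]`, with `head_size` padded by `2 * sizeof(float)` when the cache precision is u8 so each token/head row can also hold a scale and a zero point, while GPU key blocks swap the last two dimensions. A minimal sketch of the CPU u8 padding, using hypothetical sample sizes:

```cpp
// Minimal sketch of the CPU u8 padding described above; the head/block sizes are
// hypothetical sample values, not taken from the commit.
#include <cstddef>
#include <iostream>

int main() {
    const std::size_t num_k_heads = 32, k_head_size = 128; // assumed model dimensions
    const std::size_t cpu_block_size = 32;                 // CPU block size from this commit
    const bool cache_is_u8 = true;

    // Each token/head row stores |scale(f32)|zeropoint(f32)|quantized data(u8 * head_size)|,
    // so the last dimension grows by 2 * sizeof(float) = 8.
    const std::size_t padded_head_size = k_head_size + (cache_is_u8 ? 2 * sizeof(float) : 0);

    // CPU key-cache block shape: [num_blocks (dynamic), num_k_heads, block_size, padded_head_size]
    std::cout << "[-1, " << num_k_heads << ", " << cpu_block_size << ", "
              << padded_head_size << "]\n";                 // prints [-1, 32, 32, 136]
    return 0;
}
```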
@@ -96,13 +131,13 @@ class CacheManager {
auto cache_precision = input.get_element_type();

if (name.find("key_cache.") == 0) {
auto pshape = patch_shape(device_config.get_key_cache_shape(kv_input_index), cache_precision);
auto pshape = to_shape(kv_cache_config[kv_input_index], cache_precision, true);
m_key_shapes.push_back(pshape);
m_key_precisions.push_back(cache_precision);
m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
break;
} else if (name.find("value_cache.") == 0) {
auto pshape = patch_shape(device_config.get_value_cache_shape(kv_input_index), cache_precision);
auto pshape = to_shape(kv_cache_config[kv_input_index], cache_precision, false);
m_value_shapes.push_back(pshape);
m_value_precisions.push_back(cache_precision);
m_block_size_in_bytes += pshape[1].get_length() * pshape[2].get_length() * pshape[3].get_length() * cache_precision.size();
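
`m_block_size_in_bytes` above accumulates, per decoder layer, `num_heads * block_size * head_size * element_size` for both the key and the value tensors. A worked example with hypothetical sample sizes:

```cpp
// Worked example of the per-layer block-size accounting above; all sizes are
// hypothetical sample values.
#include <cstddef>
#include <iostream>

int main() {
    const std::size_t block_size = 32;      // CPU block size from this commit
    const std::size_t num_heads = 8;        // assumed KV heads per layer
    const std::size_t head_size = 128;      // assumed head size (no u8 padding here)
    const std::size_t element_size = 2;     // e.g. an f16 cache precision

    const std::size_t key_bytes   = num_heads * block_size * head_size * element_size; // 65536
    const std::size_t value_bytes = num_heads * block_size * head_size * element_size; // 65536

    std::cout << "per-layer bytes per KV block: " << key_bytes + value_bytes << "\n";   // 131072
    return 0;
}
```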
@@ -124,6 +159,10 @@ class CacheManager {
return m_device;
}

size_t get_block_size() const {
return m_block_size;
}

ov::element::Type get_key_cache_precision(size_t decoder_layer_id) const {
OPENVINO_ASSERT(decoder_layer_id < m_key_precisions.size());
return m_key_precisions[decoder_layer_id];
@@ -251,6 +290,10 @@ class CacheManager {
return m_value_cache[decoder_layer_id];
}

size_t get_k_head_size(size_t layer_id) const {
return m_key_shapes[layer_id][1].get_length();
}

void copy_blocks(const std::map<size_t, std::list<size_t>>& block_copy_map) {
for (const auto & blocks_pair : block_copy_map) {
size_t src_block_id = blocks_pair.first;
46 changes: 22 additions & 24 deletions src/cpp/src/continuous_batching_impl.cpp
@@ -113,14 +113,12 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::ContinuousBatchingImpl(
m_generation_config = generation_config;
m_is_validation_mode_enabled = is_validation_mode_enabled;

DeviceConfig device_config(device);

bool is_need_per_layer_cache_control = scheduler_config.use_cache_eviction;
bool allow_cache_rotation = scheduler_config.cache_eviction_config.apply_rotation;
utils::apply_paged_attention_transformations(model, device_config, is_need_per_layer_cache_control, allow_cache_rotation);
auto kv_cache_config = utils::apply_paged_attention_transformations(model, is_need_per_layer_cache_control, allow_cache_rotation);
utils::apply_gather_before_matmul_transformation(model);

initialize_pipeline(model, scheduler_config, properties, device_config);
initialize_pipeline(model, scheduler_config, device, properties, kv_cache_config);
}

ContinuousBatchingPipeline::ContinuousBatchingImpl::~ContinuousBatchingImpl() {
@@ -139,29 +137,31 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_pull_awaiting_requests
void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
std::shared_ptr<ov::Model> model,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& properties,
const DeviceConfig& device_config) {
const std::vector<KVHeadConfig>& kv_cache_config) {
ov::Core core = utils::singleton_core();
ov::CompiledModel compiled_model;

// TODO: remove once plugin automatically set KV cache precisions
apply_kv_cache_precision(model, device_config.get_device(), properties);
apply_kv_cache_precision(model, device, properties);

// apply LoRA
if (auto filtered_properties = extract_adapters_from_properties(properties, &m_generation_config.adapters)) {
m_generation_config.adapters->set_tensor_name_prefix("base_model.model.model.");
m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device_config.get_device()); // TODO: Make the prefix name configurable
compiled_model = core.compile_model(model, device_config.get_device(), *filtered_properties);
m_adapter_controller = AdapterController(model, *m_generation_config.adapters, device); // TODO: Make the prefix name configurable
compiled_model = core.compile_model(model, device, *filtered_properties);
} else {
compiled_model = core.compile_model(model, device_config.get_device(), properties);
compiled_model = core.compile_model(model, device, properties);
}

ov::genai::utils::print_compiled_model_properties(compiled_model, "LLM with Paged Attention");
ov::InferRequest infer_request = compiled_model.create_infer_request();

// Cache manager
std::shared_ptr<CacheManager> cache_manager = std::make_shared<CacheManager>(infer_request, device_config);
std::shared_ptr<CacheManager> cache_manager = std::make_shared<CacheManager>(infer_request, kv_cache_config);
m_num_decoder_layers = cache_manager->get_num_decoder_layers();
m_block_size = cache_manager->get_block_size();

// Scheduler
SchedulerConfig normalized_config = scheduler_config;
@@ -171,21 +171,21 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
}

bool can_use_partial_preemption = true;
if (device_config.get_device().find("GPU") != std::string::npos && !normalized_config.dynamic_split_fuse) {
if (device.find("GPU") != std::string::npos && !normalized_config.dynamic_split_fuse) {
// in case of executing a `vLLM-like` pipeline, it's better not to use partial eviction on the GPU,
// as it may lead to performance slowdown
can_use_partial_preemption = false;
}

m_scheduler = std::make_shared<Scheduler>(device_config.get_block_size(), cache_manager, normalized_config, m_num_decoder_layers, can_use_partial_preemption);
m_scheduler = std::make_shared<Scheduler>(cache_manager->get_block_size(), cache_manager, normalized_config, m_num_decoder_layers, can_use_partial_preemption);

// Model Runner
bool is_use_cache_eviction = m_scheduler->get_config().use_cache_eviction;
if (is_use_cache_eviction) {
const auto& eviction_config = m_scheduler->get_config().cache_eviction_config;
bool is_apply_rotation = eviction_config.apply_rotation;
m_model_runner = std::make_shared<ModelRunner>(infer_request,
m_scheduler->get_block_size(),
m_block_size,
m_num_decoder_layers,
/* collect_attention_scores = */ true,
/* is_use_per_layer_cache_control = */ true,
@@ -199,10 +199,10 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
m_rotation_deltas_stores.push_back(store);
}

size_t max_sequence_cache_occupation_length_in_blocks = normalized_config.max_num_batched_tokens / m_scheduler->get_block_size() + 1;
size_t embedding_size = device_config.get_k_head_size(0);
size_t max_sequence_cache_occupation_length_in_blocks = normalized_config.max_num_batched_tokens / m_block_size + 1;
size_t embedding_size = cache_manager->get_k_head_size(0);
m_cache_rotation_calculator = std::make_shared<CacheRotationCalculator>(
m_scheduler->get_block_size(),
m_block_size,
max_sequence_cache_occupation_length_in_blocks,
embedding_size);
auto rotation_trig_lut = ov::Tensor(ov::element::f32, ov::Shape{max_sequence_cache_occupation_length_in_blocks, embedding_size});
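
The rotation setup above sizes its trig lookup table as `max_num_batched_tokens / block_size + 1` rows by `k_head_size` columns. A quick check of that arithmetic with hypothetical sample numbers:

```cpp
// Hypothetical sample values, purely to illustrate the sizing arithmetic above.
#include <cstddef>
#include <iostream>

int main() {
    const std::size_t max_num_batched_tokens = 256; // assumed scheduler setting
    const std::size_t block_size = 32;              // CPU block size from this commit
    const std::size_t k_head_size = 128;            // assumed per-head embedding size

    const std::size_t max_blocks = max_num_batched_tokens / block_size + 1;              // 9
    std::cout << "rotation_trig_lut shape: {" << max_blocks << ", " << k_head_size << "}\n"; // {9, 128}
    return 0;
}
```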
@@ -224,7 +224,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::initialize_pipeline(
}
} else {
m_model_runner =
std::make_shared<ModelRunner>(infer_request, m_scheduler->get_block_size(), m_num_decoder_layers);
std::make_shared<ModelRunner>(infer_request, m_block_size, m_num_decoder_layers);
}

m_sampler = std::make_shared<Sampler>(m_tokenizer);
@@ -245,9 +245,7 @@ ContinuousBatchingPipeline::ContinuousBatchingImpl::add_request(uint64_t request
sampling_params.set_eos_token_id(m_generation_config.eos_token_id);
sampling_params.validate();

SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids,
sampling_params,
m_scheduler->get_block_size());
SequenceGroup::Ptr sequence_group = std::make_shared<SequenceGroup>(request_id, input_ids, sampling_params, m_block_size);

if (m_scheduler->get_config().enable_prefix_caching) {
m_scheduler->restore_cached_blocks(sequence_group);
@@ -662,8 +660,8 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_compute_cache_rotation
size_t block_offset = num_blocks_to_rotate_for_each_layer[layer_idx];
auto rotation_deltas_tensor_data =
m_rotation_deltas_stores[layer_idx].data<int32_t>() + block_offset;
for (size_t tok_idx = 0; tok_idx < m_scheduler->get_block_size(); tok_idx++) {
rotation_deltas_tensor_data[tok_idx] = block_rotation_data.rotation_delta / m_scheduler->get_block_size();
for (size_t tok_idx = 0; tok_idx < m_block_size; tok_idx++) {
rotation_deltas_tensor_data[tok_idx] = block_rotation_data.rotation_delta / m_block_size;
}
num_blocks_to_rotate_for_each_layer[layer_idx] += 1;
}
@@ -693,7 +691,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_bloc
auto seq_id = seq_id_and_attention_scores.first;
const auto& attention_scores_for_all_decoder_layers = seq_id_and_attention_scores.second;
if (m_seq_group_id_to_cache_eviction_algo_map.find(seq_id) == m_seq_group_id_to_cache_eviction_algo_map.end()) {
m_seq_group_id_to_cache_eviction_algo_map[seq_id] = CacheEvictionAlgorithm(sched_config.cache_eviction_config, m_scheduler->get_block_size(), num_decoder_layers);
m_seq_group_id_to_cache_eviction_algo_map[seq_id] = CacheEvictionAlgorithm(sched_config.cache_eviction_config, m_block_size, num_decoder_layers);
}
auto& cache_eviction_algo = m_seq_group_id_to_cache_eviction_algo_map[seq_id];
cache_eviction_algo.register_new_token_scores(attention_scores_for_all_decoder_layers);
@@ -728,7 +726,7 @@ void ContinuousBatchingPipeline::ContinuousBatchingImpl::_maybe_evict_cache_bloc
// Assuming that the evicted blocks are always full (since they by design are only selected from intermediate-age blocks)
auto seq_group_ptr = seq_group_ptr_and_num_blocks_evicted.first;
auto num_blocks_evicted = seq_group_ptr_and_num_blocks_evicted.second;
seq_group_ptr->register_token_eviction(num_blocks_evicted * m_scheduler->get_block_size());
seq_group_ptr->register_token_eviction(num_blocks_evicted * m_block_size);
}
}

4 changes: 3 additions & 1 deletion src/cpp/src/continuous_batching_impl.hpp
@@ -37,6 +37,7 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc
bool m_is_validation_mode_enabled = false;

size_t m_num_decoder_layers = 0;
size_t m_block_size = 0;

// Pre-allocated per-layer storages for the per-token cache re-rotation deltas used in cache eviction case
std::vector<ov::Tensor> m_rotation_deltas_stores;
@@ -58,8 +59,9 @@ class ContinuousBatchingPipeline::ContinuousBatchingImpl : public ContinuousBatc

void initialize_pipeline(std::shared_ptr<ov::Model> model,
const SchedulerConfig& scheduler_config,
const std::string& device,
const ov::AnyMap& plugin_config,
const DeviceConfig& device_config);
const std::vector<KVHeadConfig>& kv_cache_config);

/**
* Pulls requests from awaiting queue to running queue
(9 more changed files not shown)
