[GPU] Improve kv cache memory allocation efficiency (openvinotoolkit#25580)

### Details:
- Fixed two issues:
  1) The KV cache allocated redundant memory whenever it needed new memory.
  2) At the start of a new inference, the KV cache kept the padding value from the previous execution (the last token of the previous generation), which wasted memory.
- After fixing the issues above, memory is allocated more frequently in some cases:
  1) Switching shape 1024 => 32: the buffer is now reclaimed (previously, due to the wrong padding, it was not reclaimed).
  2) Switching shape 32 => 1024: a new allocation is needed at the first inference, but no shape history has been tracked yet, so new memory keeps being allocated during the first 3 iterations.
- Additional fixes to resolve the issues above (see the sketch after this list):
  1) For the initial allocation of the KV cache, enforce preallocation along the sequence axis with a custom prealloc count (a known value of 128 + id % 64).
  2) When reclaiming the KV cache, use the prealloc size as the required memory size.
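
The sketch below illustrates the intent of the two additional fixes with simplified stand-in types; it is not the plugin code. The real logic lives in `primitive_inst::realloc_if_needed()` and `get_prealloc_iter_num()`, and the mapping from the primitive id to the `128 + id % 64` count is shown here with a string hash purely as an assumption for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

using Shape = std::vector<size_t>;  // stand-in for ov::Shape

// Fix 1: on the first inference of a kv_cache primitive there is no shape history yet,
// so a fixed preallocation count is enforced along the sequence axis. Deriving it from
// the primitive id spreads reallocations of different kv_cache instances across iterations.
int32_t initial_kv_cache_prealloc_count(const std::string& id) {
    return 128 + static_cast<int32_t>(std::hash<std::string>{}(id) % 64);
}

// Fix 2: when deciding whether to shrink (reclaim) the kv_cache buffer, compare the
// preallocated size (current shape grown by the prealloc count along the sequence axis)
// against the allocated capacity, instead of the raw per-iteration shape.
bool should_reclaim_kv_cache(Shape shape, size_t seq_axis, int32_t prealloc_count, size_t allocated_count) {
    shape[seq_axis] += static_cast<size_t>(prealloc_count);
    const size_t required = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
    return required * 10 < allocated_count;  // same "10x smaller" heuristic as the generic reclaim path
}

int main() {
    const Shape kv = {1, 32, 8, 128};  // [batch, heads, seq_len, head_size], illustrative only
    const int32_t count = initial_kv_cache_prealloc_count("kv_cache:0");  // somewhere in [128, 191]
    // 8 valid tokens plus at least 128 preallocated tokens, times 10, exceeds the 1024-token
    // capacity, so the buffer is kept instead of being reclaimed and immediately regrown.
    return should_reclaim_kv_cache(kv, /*seq_axis=*/2, count, /*allocated_count=*/size_t(1) * 32 * 1024 * 128) ? 1 : 0;
}
```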

Memory allocation count with this PR:

![image](https://github.com/user-attachments/assets/c65b3335-c849-46f8-b9fe-140c3a0fbccb)

### Tickets:
 - 146930
yeonbok authored Jul 17, 2024
1 parent 4dbe733 commit e62d0fa
Showing 3 changed files with 94 additions and 34 deletions.
36 changes: 21 additions & 15 deletions src/plugins/intel_gpu/include/intel_gpu/runtime/shape_predictor.hpp
@@ -34,25 +34,31 @@ struct ShapePredictor {
static_assert(_max_deque_size >= 2, "[GPU] Deque is supposed to contain at least 2 elements for prediction");
}


/// \brief Predicts the next possible shapes sizes based on history collected by previous
/// predict_preallocation_shape() calls.
/// This function works in two modes: by default it tries to predict shape for the next
/// `_next_iters_preallocation_count` iterations, in case if per-iteration buffer size is less than
/// `_max_per_iter_size` and difference between shapes is less than `_max_per_dim_diff`; the second
/// operation mode is percentage preallocation - this mode can be configured with
/// ov::intel_gpu::buffers_preallocation_ratio property, it increases buffer size by
/// `_buffers_preallocation_ratio` value unconditionally.
/// \param id Primitive id.
/// \param layout Primitive's layout on current iteration.
/// \param can_reuse_buffer Specifies if current memory buffer is enough to store data.
/// \return The result of shape size prediction as std::pair<bool, ov::Shape>, where the first element
/// says if shape is successfully predicted and can be preallocated, and the second element is ov::Shape itself.
/// \brief Predicts the next possible shapes sizes based on history collected by previous
/// predict_preallocation_shape() calls.
/// This function works in two modes: by default it tries to predict shape for the next
/// `_next_iters_preallocation_count` iterations, in case if per-iteration buffer size is less than
/// `_max_per_iter_size` and difference between shapes is less than `_max_per_dim_diff`; the second
/// operation mode is percentage preallocation - this mode can be configured with
/// ov::intel_gpu::buffers_preallocation_ratio property, it increases buffer size by
/// `_buffers_preallocation_ratio` value unconditionally.
/// \param id Primitive id.
/// \param layout Primitive's layout on current iteration.
/// \param can_reuse_buffer Specifies if current memory buffer is enough to store data.
/// \param out_idx Output index among multiple outputs.
/// \param custom_next_iters_prealloc_count If specified, enforce the prealloc count as the specified value.
/// \param custom_prealloc_dim If both custom_next_iters_prealloc_count and custom_prealloc_dim are specified,
/// increase custom_prealloc_dim by custom_next_iters_prealloc_count without checking the shape history (e.g.,
/// used for the first inference of kv cache).
/// \return The result of shape size prediction as std::pair<bool, ov::Shape>, where
/// the first element says if shape is successfully predicted and can be preallocated, and the second
/// element is ov::Shape itself.
std::pair<bool, ov::Shape> predict_preallocation_shape(const std::string& id,
const cldnn::layout& layout,
bool can_reuse_buffer,
const size_t out_idx = 0,
int32_t next_iters_prealloc_count = -1);
int32_t custom_next_iters_prealloc_count = -1,
int32_t custom_prealloc_dim = -1);

bool can_preallocate(size_t desired_buffer_size);

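For illustration, a hedged sketch of how the extended `predict_preallocation_shape()` interface is expected to be used; the types and the fallback path are simplified stand-ins, and only the new custom-prealloc branch mirrors the actual change (see `shape_predictor.cpp` below).

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <utility>
#include <vector>

using Shape = std::vector<size_t>;  // stand-in for ov::Shape

// Simplified stand-in for ShapePredictor::predict_preallocation_shape(); only the
// behaviour added by this PR is sketched, the history-based prediction is omitted.
std::pair<bool, Shape> predict_preallocation_shape(const std::string& /*id*/,
                                                   Shape current_shape,
                                                   bool can_reuse_buffer,
                                                   size_t /*out_idx*/ = 0,
                                                   int32_t custom_next_iters_prealloc_count = -1,
                                                   int32_t custom_prealloc_dim = -1) {
    if (can_reuse_buffer)
        return {false, {}};
    // New: when both the custom dim and count are given, skip the shape-history
    // heuristic and grow the requested axis directly.
    if (custom_prealloc_dim >= 0 && custom_next_iters_prealloc_count > 0) {
        current_shape[static_cast<size_t>(custom_prealloc_dim)] +=
            static_cast<size_t>(custom_next_iters_prealloc_count);
        return {true, current_shape};
    }
    return {false, {}};  // history-based prediction not shown
}

int main() {
    // First inference of a kv_cache output: no shape history, so the caller passes the
    // sequence axis (2 here) and a custom count to force preallocation along it.
    Shape kv_shape = {1, 32, 1, 128};  // [batch, heads, seq_len, head_size], illustrative only
    auto prealloc = predict_preallocation_shape("kv_cache:0", kv_shape,
                                                /*can_reuse_buffer=*/false,
                                                /*out_idx=*/0,
                                                /*custom_next_iters_prealloc_count=*/128,
                                                /*custom_prealloc_dim=*/2);
    // prealloc.second is now {1, 32, 129, 128}: room for ~128 further tokens.
    return prealloc.first ? 0 : 1;
}
```
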
82 changes: 64 additions & 18 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -664,9 +664,32 @@ event::ptr primitive_inst::realloc_if_needed() {
updated_layouts[0] = layout(current_buf_shape, updated_layouts[0].data_type, updated_layouts[0].format);
}

int32_t tmp_prealloc_count = get_prealloc_iter_num();
GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
// If the debug config is set, respect it first
tmp_prealloc_count = -1;
}

// If we allocated too large memory, reclaim the memory.
for (size_t i = 0; i < updated_layouts.size(); ++i) {
if (updated_layouts[i].get_buffer_size().count() * 10 < _max_output_layout_count[i]) {
bool reclaim = false;
size_t required_buffer_size = 0;
if (_node->is_type<kv_cache>() && i == 0) {
// Relax reclaiming condition for kv cache
const auto& desc = _node->as<kv_cache>().get_primitive();
auto prealloc_shape = updated_layouts[i].get_shape();
const auto shape_rank = prealloc_shape.size();
auto seq_axis =
static_cast<int32_t>(desc->concat_axis >= 0 ? desc->concat_axis : shape_rank + desc->concat_axis);
prealloc_shape[seq_axis] += tmp_prealloc_count;
required_buffer_size = std::accumulate(prealloc_shape.begin(), prealloc_shape.end(), size_t(1), std::multiplies<size_t>());
} else {
required_buffer_size = (updated_layouts[i].get_buffer_size().count());
}
if (required_buffer_size * 10 < _max_output_layout_count[i]) {
reclaim = true;
}
if (reclaim) {
GPU_DEBUG_TRACE_DETAIL << id() << ": Updated output[" << i << "] size " << updated_layouts[i].get_buffer_size().count()
<< " is much smaller than current memory size! " << _max_output_layout_count[i]
<< "Reset memory of output " << i << std::endl;
@@ -681,31 +704,51 @@ event::ptr primitive_inst::realloc_if_needed() {
return ev;
}

int32_t tmp_prealloc_count = get_prealloc_iter_num();
GPU_DEBUG_IF(debug_config->mem_preallocation_params.is_initialized) {
// If the debug config is set, respect it first
tmp_prealloc_count = -1;
}

for (size_t i = 0; i < actual_layouts.size(); ++i) {
bool can_reuse_buffer = (_outputs[i] && updated_layouts[i].get_buffer_size().count() <= _max_output_layout_count[i]);
std::pair<bool, ov::Shape> prealloc_info =
sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
std::pair<bool, ov::Shape> prealloc_info;
if (_node->is_type<kv_cache>() && i == 0) {
const auto& desc = _node->as<kv_cache>().get_primitive();
auto shape_rank = updated_layouts[i].get_shape().size();
auto seq_axis =
static_cast<int32_t>(desc->concat_axis >= 0 ? desc->concat_axis : shape_rank + desc->concat_axis);
prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], false, i, tmp_prealloc_count, seq_axis);
} else {
prealloc_info = sp.predict_preallocation_shape(id(), updated_layouts[i], can_reuse_buffer, i, tmp_prealloc_count);
}
if (prealloc_info.first && sp.can_preallocate(ov::shape_size(prealloc_info.second) * (dt_sizes_in_B[i]))) {
auto new_layout = updated_layouts[i];
new_layout.set_partial_shape(prealloc_info.second);
updated_params.output_layouts[i] = new_layout;
}

if (updated_params.output_layouts[i].get_buffer_size().count() < updated_layouts[i].get_buffer_size().count()) {
updated_params.output_layouts[i] = updated_layouts[i];
}

if (can_reuse_buffer) {
GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer - "
GPU_DEBUG_TRACE_DETAIL << id() << ": reuse previously allocated output buffer[" << i << "] - "
<< actual_layouts[i].get_buffer_size().count() << "/" << _max_output_layout_count[i]
<< std::endl;
_outputs[i] = _network.get_engine().reinterpret_buffer(*_outputs[i], actual_layouts[i]);
if (_node->is_type<kv_cache>() && (i == 0)) {
// kv_cache memory is already assigned.
// No need to reinterpret the output memory, but the padding needs to be updated
const auto& desc = _node->as<kv_cache>().get_primitive();
auto& present_layout = _impl_params->output_layouts[i];
const auto present_layout_rank = present_layout.get_partial_shape().size();
const auto sequence_axis =
desc->concat_axis >= 0 ? desc->concat_axis : present_layout_rank + desc->concat_axis;

const auto sequence_axis_legacy = kv_cache_inst::get_sequence_axis_legacy(sequence_axis, present_layout_rank);
auto max_pad = kv_cache_inst::get_max_pad(present_layout,
_max_output_layout_count[i],
sequence_axis_legacy,
"present_layout");
kv_cache_inst::update_pad(present_layout, max_pad, sequence_axis_legacy);
GPU_DEBUG_TRACE_DETAIL << _impl_params->output_layouts[i].to_string() << std::endl;
set_shape_change();
} else {
_outputs[i] = _network.get_engine().reinterpret_buffer(*_outputs[i], actual_layouts[i]);
}
// TODO: check need_reset_output_memory per output
if (need_reset_output_memory() && !can_be_optimized()) {
GPU_DEBUG_TRACE_DETAIL << id() << " : Need reset output memory considering user" << std::endl;
@@ -740,7 +783,7 @@ event::ptr primitive_inst::realloc_if_needed() {

// Set variable memory same as output memory
if (_node->is_type<kv_cache>()) {
auto desc = _node->as<kv_cache>().get_primitive();
const auto& desc = _node->as<kv_cache>().get_primitive();
auto& variable = get_network().get_variable(desc->variable_info.variable_id);
auto present_layout = _impl_params->output_layouts[0];
auto present_layout_rank = present_layout.get_partial_shape().size();
@@ -760,7 +803,7 @@ event::ptr primitive_inst::realloc_if_needed() {
if (present_layout.data_padding.get_dynamic_pad_dims().sizes()[sequence_axis_legacy] == 1) {
// Apply padding of variable to make it be optimized in the next iteration
auto max_pad = kv_cache_inst::get_max_pad(present_layout,
updated_params.output_layouts[0].get_buffer_size().count(),
_max_output_layout_count[0],
sequence_axis_legacy,
"present_layout");
if (max_pad > 0) {
@@ -783,7 +826,7 @@ event::ptr primitive_inst::realloc_if_needed() {
GPU_DEBUG_TRACE_DETAIL << id() << ": Update variable " << variable.get_name()
<< "'s layout with allocated kv cache output: " << present_layout.to_short_string()
<< " (is_set = " << variable.is_set() << ") " << std::endl;
variable.set_layout(present_layout);
variable.set_memory(_outputs[0], present_layout);
}
} else {
GPU_DEBUG_TRACE_DETAIL << id() << ": Update variable " << variable.get_name()
@@ -1036,8 +1079,10 @@ void primitive_inst::update_paddings() {
auto reset_pad = [](kernel_impl_params& params, const program_node* node) {
params.output_layouts[0].data_padding = node->get_output_layout(0).data_padding;
};
if (_node->is_type<read_value>()) {
auto& variable = get_network().get_variable(_node->as<read_value>().get_primitive()->variable_id);
if (_node->is_type<read_value>() || _node->is_type<kv_cache>()) {
auto variable_id = _node->is_type<read_value>() ? (_node->as<read_value>().get_primitive()->variable_id)
: (_node->as<kv_cache>().get_primitive()->variable_info.variable_id);
auto& variable = get_network().get_variable(variable_id);
// Reset paddings for read_value and users with dynamic pad when variable is reset
// to avoid wrong pad used for some nodes due to pad propagation logic (which uses previous iter pad values)
if (!variable.is_set()) {
@@ -1054,6 +1099,7 @@ void primitive_inst::update_paddings() {
}
return;
}

if (_node->is_type<gather>() && _impl_params->output_layouts[0].data_padding.get_dynamic_pad_dims() != tensor(0)) {
if (can_be_optimized())
_impl_params->output_layouts[0] = _impl_params->input_layouts[0];
@@ -1141,7 +1187,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
if (_impl_params->get_input_layout(0).count() == 0) {
return;
}
auto desc = _node->as<kv_cache>().get_primitive();
const auto& desc = _node->as<kv_cache>().get_primitive();
auto& past_layout = _impl_params->input_layouts[0];
auto& present_layout = _impl_params->output_layouts[0];
const auto& sequence_axis = desc->concat_axis;
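
For context, a minimal sketch (stand-in types, hypothetical helpers in the spirit of `kv_cache_inst::get_max_pad()` and `kv_cache_inst::update_pad()`) of why the buffer-reuse path above updates padding instead of reinterpreting the output memory: the kv_cache keeps an oversized allocation and records the unused tail of the sequence axis as dynamic padding, so the next token can be appended in place.

```cpp
#include <cstddef>
#include <vector>

// Stand-in for cldnn::layout: logical shape plus upper padding per axis.
struct PaddedLayout {
    std::vector<size_t> shape;   // valid extents
    std::vector<size_t> pad_hi;  // allocated-but-unused extents
};

// Hypothetical helper in the spirit of kv_cache_inst::get_max_pad(): how many unused
// sequence positions fit between the valid data and the allocated capacity.
size_t max_pad_for(const PaddedLayout& l, size_t allocated_count, size_t seq_axis) {
    size_t per_token = 1;
    for (size_t i = 0; i < l.shape.size(); ++i)
        if (i != seq_axis)
            per_token *= l.shape[i];
    const size_t capacity_tokens = per_token ? allocated_count / per_token : 0;
    return capacity_tokens > l.shape[seq_axis] ? capacity_tokens - l.shape[seq_axis] : 0;
}

// In the spirit of kv_cache_inst::update_pad(): record the unused tail as padding so the
// concat on the next iteration can write into the same buffer in place.
void update_pad(PaddedLayout& l, size_t pad, size_t seq_axis) {
    l.pad_hi[seq_axis] = pad;
}

int main() {
    PaddedLayout present{{1, 32, 10, 128}, {0, 0, 0, 0}};       // 10 valid tokens
    const size_t allocated_count = size_t(1) * 32 * 138 * 128;  // capacity for 138 tokens
    update_pad(present, max_pad_for(present, allocated_count, /*seq_axis=*/2), 2);
    return present.pad_hi[2] == 128 ? 0 : 1;                    // 128 free positions recorded as padding
}
```
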
10 changes: 9 additions & 1 deletion src/plugins/intel_gpu/src/runtime/shape_predictor.cpp
@@ -60,7 +60,8 @@ std::pair<bool, ov::Shape> ShapePredictor::predict_preallocation_shape(const std
const cldnn::layout& layout,
bool can_reuse_buffer,
const size_t out_idx,
int32_t custom_next_iters_prealloc_count) {
int32_t custom_next_iters_prealloc_count,
int32_t custom_prealloc_dim) {
size_t next_iters_prealloc_count = custom_next_iters_prealloc_count > 0
? static_cast<size_t>(custom_next_iters_prealloc_count)
: _next_iters_preallocation_count;
@@ -79,6 +80,13 @@ std::pair<bool, ov::Shape> ShapePredictor::predict_preallocation_shape(const std
if (can_reuse_buffer)
return {false, {}};

// If both prealloc dim and prealloc count are specified, don't predict and just use the given info
if (custom_prealloc_dim >= 0 && custom_next_iters_prealloc_count > 0) {
auto new_shape = current_shape;
new_shape[custom_prealloc_dim] += custom_next_iters_prealloc_count;
return {true, new_shape};
}

// Check if there is enough data for prediction
const auto& shapes = _shapes_info[id_record];
const auto shapes_num = shapes.size();
