Optimize speculative decoding PVC memory usage #10329

Merged 7 commits on Mar 6, 2024
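The PR targets speculative decoding on Intel PVC GPUs (device type `xpu`): rather than letting the KV cache be regrown piecemeal while draft tokens are generated, the cache is checked once per verification step and, if it cannot hold the next draft window, re-allocated in one larger block, after which `torch.xpu.empty_cache()` releases the superseded buffers. A minimal sketch of that pre-allocation idea, using illustrative names and a (batch, heads, seq, head_dim) layout that are not taken from the PR:

import torch

def preallocate_kv(cache_k, cache_v, extra_tokens):
    # Illustrative only: copy an existing KV pair into buffers that already have
    # room for `extra_tokens` more positions, so the draft loop can fill them in
    # place instead of re-allocating on every drafted token.
    bsz, heads, seq, head_dim = cache_k.shape
    new_k = torch.empty(bsz, heads, seq + extra_tokens, head_dim,
                        dtype=cache_k.dtype, device=cache_k.device)
    new_v = torch.empty_like(new_k)
    new_k[:, :, :seq] = cache_k
    new_v[:, :, :seq] = cache_v
    # Return views of the current length; the extra capacity sits behind them.
    return new_k[:, :, :seq], new_v[:, :, :seq]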
62 changes: 60 additions & 2 deletions python/llm/src/bigdl/llm/transformers/speculative.py
@@ -24,8 +24,8 @@
import os
import copy
import logging
import warnings
import inspect
import transformers
from packaging import version
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import top_k_top_p_filtering, GenerationConfig, \
    LogitsProcessorList, StoppingCriteriaList
@@ -367,6 +367,55 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
                delta_past_value.to(torch.float32)


def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256,
                               model_type="llama"):
    # Check whether the KV cache still has room for `max_step_draft` more tokens;
    # if not, re-allocate it once with `kv_alloc_block_len` extra positions of slack.
    from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
        extend_kv_cache
    enough_kv_room = True
    if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
                          "gptj", "opt"]:
        return past_key_values, False
    cache_k = past_key_values[0][0]
    # chatglm and qwen keep the cache in a different layout; normalize to
    # (batch, heads, seq, head_dim) before measuring the remaining room.
    if model_type == "chatglm":
        cache_k = cache_k.permute(1, 2, 0, 3)
    elif model_type == "qwen":
        cache_k = cache_k.transpose(1, 2)

    enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None),
                                                  seq_len=max_step_draft)
    bsz, num_heads, current_seq_len, head_dim = cache_k.shape
    device = past_key_values[0][0].device
    if not enough_kv_room:
        past_key_values = list(past_key_values)
        for i in range(len(past_key_values)):
            cache_k = past_key_values[i][0]
            cache_v = past_key_values[i][1]
            if model_type == "chatglm":
                cache_k = cache_k.permute(1, 2, 0, 3)
                cache_v = cache_v.permute(1, 2, 0, 3)
            elif model_type == "qwen":
                cache_k = cache_k.transpose(1, 2)
                cache_v = cache_v.transpose(1, 2)
            new_cache_k, new_cache_v = extend_kv_cache(
                bsz,
                num_heads,  # Support GQA
                head_dim,
                cache_k.size(2),
                current_seq_len + max_step_draft + kv_alloc_block_len,
                dtype=cache_v.dtype,
                device=device)
            new_cache_k[:] = cache_k
            new_cache_v[:] = cache_v
            # Restore the layout each architecture expects.
            if model_type == "chatglm":
                past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3),
                                      new_cache_v.permute(2, 0, 1, 3))
            elif model_type == "qwen":
                past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2))
            else:
                past_key_values[i] = (new_cache_k, new_cache_v)
    return past_key_values, not enough_kv_room
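A hedged usage sketch for the helper above, written as it would sit inside speculative_generate (so `self` is the model). The PR's actual call site, which passes `max_new_tokens - step + 40` as the slack argument, appears further down in this diff:

# Hypothetical call site: make sure the cache can absorb the next draft window.
past_key_values, extend_kv = _check_and_extend_kv_cache(
    past_key_values,                   # KV cache from the previous forward pass
    max_step_draft,                    # draft tokens that may be appended this step
    kv_alloc_block_len=256,            # extra slack so the cache is not regrown every step
    model_type=self.config.model_type)
if extend_kv and self.device.type == 'xpu':
    torch.xpu.empty_cache()            # the old, smaller KV buffers are no longer referenced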


@torch.no_grad()
def speculative_generate(self,
                         inputs: Optional[torch.Tensor] = None,
@@ -504,6 +553,9 @@ def speculative_generate(self,

    self.clear_benchmarks()

    if self.device.type == 'xpu':
        torch.xpu.empty_cache()

    # Example:
    # Target model forward for the first token
    # Step 1. target_model(prompt) -> a
@@ -562,6 +614,10 @@ def speculative_generate(self,
                                                       past_key_values_storage, _enable_ipex)
                original_draft_past_key_values = draft_past_key_values
            else:
                past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
                                                                        max_step_draft,
                                                                        max_new_tokens - step + 40,
                                                                        self.config.model_type)
                draft_past_key_values = past_key_values
            draft_generate_ids[:, 0] = current_input_ids
            draft_prob_list = []
@@ -742,6 +798,8 @@ def speculative_generate(self,
            output_ids = greedy(logits)
            if self.device.type == 'xpu':
                torch.xpu.synchronize()
                if extend_kv:
                    torch.xpu.empty_cache()
            toc = time.time()
            self.verify_time.append(toc - tic)
            self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])