
Commit

Merging PR vllm-project#12536
Merged via CLI script
heheda12345 authored and kerthcet committed Feb 21, 2025
1 parent 15236dc · commit 64543d1
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions vllm/attention/layer.py
@@ -156,9 +156,13 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.calculate_kv_scales and \
-            attn_metadata.enable_kv_scales_calculation:
-            self.calc_kv_scales(key, value)
+        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
+        # directly, use `self.kv_cache` and
+        # `get_forward_context().attn_metadata` instead.
+        if self.calculate_kv_scales:
+            ctx_attn_metadata = get_forward_context().attn_metadata
+            if ctx_attn_metadata.enable_kv_scales_calculation:
+                self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)
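The NOTE in this hunk states the intent of the change: attention layers should stop reading the `kv_cache` and `attn_metadata` arguments and instead resolve both through vLLM's per-step forward context. A minimal sketch of that lookup pattern, assuming the `vllm.forward_context.get_forward_context` helper used in the diff; the standalone function and its `layer` parameter are illustrative, not part of this commit:

    # Sketch: resolve attention metadata from the forward context rather than
    # from the forward() arguments, mirroring the new kv-scales branch above.
    from vllm.forward_context import get_forward_context

    def maybe_calc_kv_scales(layer, key, value):
        # Hypothetical free-function version of the logic in the diff.
        if not layer.calculate_kv_scales:
            return
        ctx_attn_metadata = get_forward_context().attn_metadata
        if ctx_attn_metadata.enable_kv_scales_calculation:
            layer.calc_kv_scales(key, value)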
@@ -172,15 +176,27 @@ def forward(
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             if self.use_direct_call:
-                unified_attention_with_output(query, key, value, output,
-                                              self.layer_name)
+                forward_context: ForwardContext = get_forward_context()
+                ctx_attn_metadata = forward_context.attn_metadata
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                self.impl.forward(self,
+                                  query,
+                                  key,
+                                  value,
+                                  self_kv_cache,
+                                  ctx_attn_metadata,
+                                  output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
                     query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
-                return unified_attention(query, key, value, self.layer_name)
+                forward_context = get_forward_context()
+                ctx_attn_metadata = forward_context.attn_metadata
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                return self.impl.forward(self, query, key, value,
+                                         self_kv_cache, ctx_attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
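On the new direct-call path the layer resolves everything from the forward context: the live attention metadata, plus its own KV-cache tensor indexed by `forward_context.virtual_engine`. A sketch of that resolution with names taken from the diff; the free function is illustrative, since in the commit this logic lives inside `Attention.forward()`:

    # Sketch: direct-call dispatch to the backend impl, assuming the
    # ForwardContext / get_forward_context API that the diff references.
    from vllm.forward_context import ForwardContext, get_forward_context

    def run_direct_attention(layer, query, key, value, output=None):
        forward_context: ForwardContext = get_forward_context()
        ctx_attn_metadata = forward_context.attn_metadata
        # One KV-cache entry per virtual engine, as in the diff.
        self_kv_cache = layer.kv_cache[forward_context.virtual_engine]
        if output is not None:
            # Mirrors the `use_output` branch: write into a preallocated buffer.
            layer.impl.forward(layer, query, key, value, self_kv_cache,
                               ctx_attn_metadata, output=output)
            return output
        return layer.impl.forward(layer, query, key, value, self_kv_cache,
                                  ctx_attn_metadata)

The diff suggests one motivation for this indirection: both the direct call and the wrapped custom op (`torch.ops.vllm.unified_attention*`) now read the same per-step state from the forward context, rather than depending on whatever arguments happen to be passed in.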
