diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 962c45a65ae23..c210292865b29 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -148,9 +148,13 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
-        if self.calculate_kv_scales and \
-           attn_metadata.enable_kv_scales_calculation:
-            self.calc_kv_scales(key, value)
+        # NOTE: please avoid accessing `kv_cache` and `attn_metadata` arguments
+        # directly, use `self.kv_cache` and
+        # `get_forward_context().attn_metadata` instead.
+        if self.calculate_kv_scales:
+            ctx_attn_metadata = get_forward_context().attn_metadata
+            if ctx_attn_metadata.enable_kv_scales_calculation:
+                self.calc_kv_scales(key, value)
         if self.use_output:
             output = torch.empty_like(query)
             hidden_size = query.size(-1)
@@ -164,15 +168,27 @@ def forward(
             if value is not None:
                 value = value.view(-1, self.num_kv_heads, self.head_size)
             if self.use_direct_call:
-                unified_attention_with_output(query, key, value, output,
-                                              self.layer_name)
+                forward_context: ForwardContext = get_forward_context()
+                ctx_attn_metadata = forward_context.attn_metadata
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                self.impl.forward(self,
+                                  query,
+                                  key,
+                                  value,
+                                  self_kv_cache,
+                                  ctx_attn_metadata,
+                                  output=output)
             else:
                 torch.ops.vllm.unified_attention_with_output(
                     query, key, value, output, self.layer_name)
             return output.view(-1, hidden_size)
         else:
             if self.use_direct_call:
-                return unified_attention(query, key, value, self.layer_name)
+                forward_context = get_forward_context()
+                ctx_attn_metadata = forward_context.attn_metadata
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                return self.impl.forward(self, query, key, value,
+                                         self_kv_cache, ctx_attn_metadata)
             else:
                 return torch.ops.vllm.unified_attention(
                     query, key, value, self.layer_name)
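The diff above stops trusting the `kv_cache` and `attn_metadata` arguments passed into `Attention.forward` and instead reads per-step state from a forward context plus a per-virtual-engine `self.kv_cache` slot. Below is a minimal, self-contained sketch of that pattern, not vLLM's actual implementation: `ForwardContext`, `get_forward_context`, and `virtual_engine` mirror names in the diff, while `set_forward_context`, `DummyAttnMetadata`, and `ToyAttention` are hypothetical stand-ins for illustration.

```python
# Sketch of the forward-context pattern (assumed simplification, not vLLM code):
# the model runner publishes per-step metadata before calling the model, and
# each layer looks it up instead of receiving it as a forward() argument.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class DummyAttnMetadata:  # hypothetical stand-in for AttentionMetadata
    enable_kv_scales_calculation: bool = False


@dataclass
class ForwardContext:
    attn_metadata: Optional[DummyAttnMetadata]
    virtual_engine: int = 0


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    assert _forward_context is not None, (
        "forward context is only available inside set_forward_context()")
    return _forward_context


@contextmanager
def set_forward_context(attn_metadata: DummyAttnMetadata,
                        virtual_engine: int = 0):
    # The runner wraps each forward pass so every layer can read the same
    # metadata without it being threaded through call signatures.
    global _forward_context
    prev = _forward_context
    _forward_context = ForwardContext(attn_metadata, virtual_engine)
    try:
        yield
    finally:
        _forward_context = prev


class ToyAttention(torch.nn.Module):  # hypothetical layer, not vllm.Attention

    def __init__(self, num_virtual_engines: int = 1):
        super().__init__()
        # One KV-cache slot per virtual engine, filled in by the runner.
        self.kv_cache: List[torch.Tensor] = [
            torch.tensor([]) for _ in range(num_virtual_engines)
        ]

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Same access pattern as the diff: metadata and the per-engine KV
        # cache come from the forward context, not from arguments.
        forward_context = get_forward_context()
        ctx_attn_metadata = forward_context.attn_metadata
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        if ctx_attn_metadata.enable_kv_scales_calculation:
            pass  # the real layer would call self.calc_kv_scales(...) here
        _ = self_kv_cache  # a real impl would attend against the cache
        return query


layer = ToyAttention()
with set_forward_context(DummyAttnMetadata(), virtual_engine=0):
    out = layer(torch.randn(2, 4))
```

One consequence of this design, visible in the diff, is that the direct-call path no longer needs the `unified_attention*` wrappers: the layer can call `self.impl.forward` itself, since everything the wrapper used to supply is now reachable through the context and `self.kv_cache`.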