This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

Merge remote-tracking branch 'origin/main' into geon-dev
daniel-geon-park committed Mar 5, 2024
2 parents 6caa148 + 68d93b5 commit 7f2a7d8
Showing 3 changed files with 188 additions and 77 deletions.
219 changes: 146 additions & 73 deletions vllm/model_executor/layers/attention.py
@@ -19,8 +19,12 @@
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

from timber.models.timber_attention.attention1_block_gpu import paged_timber_attention

from timber.models.timber_attention.attention1_block_gpu import (
paged_timber_attention,
timber_attention
)
from vllm.transformers_utils import config as vllm_transformers_config
from timber.utils import get_bench
BENCHMARK_ITERATION = 0

class PagedAttention(nn.Module):
@@ -109,83 +113,152 @@ def forward(
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)

hip_k = int(os.environ.get('HIP_K', '1024'))

if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
BENCHMARK_PROMPT_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
is_normal_attention = (key_cache is None) or (value_cache is None) or (input_metadata.block_tables.numel() == 0)
if backend == 'vllm':
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(
query.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1],
)
key = key[:, :, None, :]\
.expand(
key.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1]
)
value = value[:, :, None, :]\
.expand(
value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1]
)
# normal attention
if is_normal_attention:
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if BENCHMARK_PROMPT_ATTENTION:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)

if BENCHMARK_PROMPT_ATTENTION:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
elif backend == 'timber':
# timber support MQA/GQA
warnings.warn('prompt attention backend is timber')

TDST, H, HID = query.shape
TSRC, H_KV, _HID = key.shape
assert key.shape[:-1] == value.shape[:-1]
assert HID == _HID

query = query.permute(1, 0, 2)
key = key.permute(1, 0, 2)
value = value.permute(1, 0, 2)

if BENCHMARK_PROMPT_ATTENTION:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

assert input_metadata.attn_bias is None
assert self.alibi_slopes is None

output, _ = timber_attention(
q=query * self.scale,
k=key,
v=value,
attention_mask=None,
mask_k=hip_k,
block_size_q=32,
block_size_k=2,
)
output = out.view_as(query)

output = output.permute(1, 0, 2)
output = output.view(
1,
TDST,
H,
HID,
).contiguous()

if BENCHMARK_PROMPT_ATTENTION:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)

raise Exception(backend)
else:
# Decoding run.
BENCHMARK_PAGED_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
@@ -211,7 +284,7 @@ def forward(
self.alibi_slopes,
)
elif backend == 'timber':
warnings.warn('backend is timber')
warnings.warn('paged attention backend is timber')

output, _ = paged_timber_attention(
q=query,
@@ -222,9 +295,9 @@
context_lens=input_metadata.context_lens,
max_context_len=input_metadata.max_context_len,
attention_mask=None,
mask_k=1024,
mask_k=hip_k,
block_size_q=32,
block_size_k=2,
block_size_q=16
)

N_H, _, HID = output.shape
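
Note: the prompt-path changes above select a backend from the PROMPT_ATTENTION_BACKEND environment variable and, for the 'timber' path, permute tensors from (tokens, heads, head_dim) into the head-major layout the sparse kernel expects. Below is a minimal sketch of that layout handling only; sparse_attention is a dense stand-in for timber_attention (not its real implementation), and single-head-group inputs without GQA are assumed.

import os
import torch

def sparse_attention(q, k, v, mask_k=1024):
    # Dense placeholder for the timber/HiP kernel, for illustration only.
    scores = torch.einsum('htd,hsd->hts', q, k)
    return torch.softmax(scores, dim=-1) @ v

def prompt_attention_timber_layout(query, key, value, scale):
    # query/key/value: (tokens, heads, head_dim), as in the prompt path above.
    TDST, H, HID = query.shape
    q = (query * scale).permute(1, 0, 2)  # (H, TDST, HID); scale folded into q as in the diff
    k = key.permute(1, 0, 2)
    v = value.permute(1, 0, 2)
    hip_k = int(os.environ.get('HIP_K', '1024'))  # number of keys kept by the sparse mask
    out = sparse_attention(q, k, v, mask_k=hip_k)  # (H, TDST, HID)
    # Back to (1, TDST, H, HID), the layout the surrounding code reshapes into.
    return out.permute(1, 0, 2).reshape(1, TDST, H, HID)
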
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -306,7 +306,9 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
dtype=torch.float,
device=pos_freqs.device
)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
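
Note: the change above only creates the ramp-mask tensor on pos_freqs.device, so the multiplication no longer mixes CPU and GPU tensors. For reference, a minimal sketch of what a YaRN-style linear ramp mask computes (illustrative helper under that assumption, not the vllm implementation):

import torch

def yarn_linear_ramp_mask(low: float, high: float, dim: int,
                          dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # 0 below `low`, 1 above `high`, linear in between; built directly on `device`.
    if low == high:
        high += 0.001  # avoid division by zero
    ramp = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
    return torch.clamp(ramp, 0, 1)

The surrounding hunk then uses (1 - mask) * extrapolation_factor to blend the interpolated and extrapolated inverse frequencies.
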
42 changes: 39 additions & 3 deletions vllm/worker/model_runner.py
@@ -1,3 +1,4 @@
import os
import time
from typing import Dict, List, Optional, Tuple, Set, Union

@@ -534,13 +535,35 @@ def execute_model(
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
(input_tokens, input_positions, input_metadata, sampling_metadata,
lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
BENCHMARK_RUNNER = os.environ.get('BENCHMARK_RUNNER', '0') == '1'

if BENCHMARK_RUNNER:
t_start = time.time()

start_prepare = torch.cuda.Event(enable_timing=True)
end_prepare = torch.cuda.Event(enable_timing=True)

start_model = torch.cuda.Event(enable_timing=True)
end_model = torch.cuda.Event(enable_timing=True)

start_sample = torch.cuda.Event(enable_timing=True)
end_sample = torch.cuda.Event(enable_timing=True)

if BENCHMARK_RUNNER: start_prepare.record()
(
input_tokens,
input_positions,
input_metadata,
sampling_metadata,
lora_requests,
lora_mapping
) = self.prepare_input_tensors(seq_group_metadata_list)
if BENCHMARK_RUNNER: end_prepare.record()

if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)

if BENCHMARK_RUNNER: start_model.record()
# Execute the model.
if input_metadata.use_cuda_graph:
graph_batch_size = input_tokens.shape[0]
@@ -553,12 +576,25 @@ def execute_model(
kv_caches=kv_caches,
input_metadata=input_metadata,
)
if BENCHMARK_RUNNER: end_model.record()

# Sample the next token.
if BENCHMARK_RUNNER: start_sample.record()
output = self.model.sample(
hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
if BENCHMARK_RUNNER: end_sample.record()

if BENCHMARK_RUNNER:
torch.cuda.synchronize()
elapsed_prepare = start_prepare.elapsed_time(end_prepare)
elapsed_model = start_model.elapsed_time(end_model)
elapsed_sample = start_sample.elapsed_time(end_sample)
elapsed_total = (time.time() - t_start) * 1000

print(f'[{time.time() * 1000:.3f}] prepare: {elapsed_prepare:.3f}, model: {elapsed_model:.3f}, sample: {elapsed_sample:.3f}, total: {elapsed_total:.3f}')

return output

@torch.inference_mode()
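
Note: the BENCHMARK_RUNNER instrumentation above wraps each stage (prepare, model, sample) in CUDA events and reads the elapsed times after a single synchronize at the end. A condensed per-stage sketch of the same pattern (simplified to synchronize per call; assumes a CUDA device is available):

import torch

def timed_stage(fn, *args, **kwargs):
    # Record CUDA events around one stage and return (result, elapsed milliseconds).
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()  # elapsed_time is only valid after synchronization
    return result, start.elapsed_time(end)

# Example (hypothetical call sites):
# hidden_states, elapsed_model = timed_stage(model_executable, input_ids=..., positions=...,
#                                            kv_caches=..., input_metadata=...)
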
