From 8d8ebb17985022714ac216e0ebc233d986265e5d Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Wed, 15 Nov 2023 05:36:07 +0000 Subject: [PATCH 01/37] add the prefix prefill triton kernel --- .../layers/triton_kernel/__init__.py | 0 .../layers/triton_kernel/prefix_prefill.py | 466 ++++++++++++++++++ 2 files changed, 466 insertions(+) create mode 100644 vllm/model_executor/layers/triton_kernel/__init__.py create mode 100644 vllm/model_executor/layers/triton_kernel/prefix_prefill.py diff --git a/vllm/model_executor/layers/triton_kernel/__init__.py b/vllm/model_executor/layers/triton_kernel/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py new file mode 100644 index 0000000000000..7e8cdc003c5fe --- /dev/null +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -0,0 +1,466 @@ +import torch +import time +import itertools +import triton +import triton.language as tl +import math +import torch.nn.functional as F +import matplotlib.pyplot as plt + +from benchmark_utils import bench, gc_torch + +if triton.__version__ >= "2.1.0": + @triton.jit + def _fwd_kernel( + Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, + Out, + stride_b_loc_b, stride_b_loc_s, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d, stride_k_cache_bl, stride_k_cache_x, + stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) + off_k = bn[None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None,:]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x + off_v = bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[None,:] * stride_v_cache_d + (start_n + offs_n[:,None]) % block_size * stride_v_cache_bl + k = tl.load(K_cache + off_k, mask=(start_n + offs_n[None,:]) < cur_batch_ctx_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < 
cur_batch_ctx_len, qk, float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @torch.inference_mode() + def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, b_seq_len, b_ctx_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) # 计算scale系数 + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 8 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, k_cache, v_cache, b_loc, sm_scale, b_start_loc, b_seq_len, b_ctx_len, v_cache.shape[3], 8, + o, + b_loc.stride(0), b_loc.stride(1), + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), 
k_cache.stride(4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + +@torch.inference_mode() +def test_contexted_kv_attention( + num_heads: int, + head_size: int, + dtype: torch.dtype, +) -> None: + import random + random.seed(0) + torch.manual_seed(0) + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + + num_tokens = sum(subquery_lens) + query = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + + + kv = torch.empty(sum(seq_lens), + 2, + num_heads, + head_size, + dtype=dtype, + device='cuda') + kv.uniform_(-1e-3, 1e-3) + key,value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') + v_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') + k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') + v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') + values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view(BS, max_block_per_request) + b_loc = torch.zeros(BS, MAX_CTX_LEN, dtype=torch.long, device='cuda') + b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') + b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], dtype=torch.long, device='cuda'), dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], dtype=torch.long, device='cuda'), dim=0) + for i in range(BS): + for j in range(subquery_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i,block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(key[start_loc:end_loc]) + v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_heads, head_size//8, 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = 
v_cache.view(-1, block_size, num_heads, head_size).permute(0, 2, 3, 1).contiguous() + + + context_attention_fwd(query, k, v, output, + k_cache, v_cache, block_table, + b_start_loc, b_seq_len, + b_ctx_len, max_input_len) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, k, v, output, + k_cache, v_cache, block_table, + b_start_loc, b_seq_len, + b_ctx_len, max_input_len) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + + scale = float(1.0 / (head_size**0.5)) + + attn_op = xops.fmha.cutlass.FwOp() + + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(subquery_lens, seq_lens) + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + start_time = time.time() + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + output_ref = output_ref.squeeze(0) + print(output_ref.shape) + print("max ", torch.max(torch.abs(output_ref - output))) + print("mean ", torch.mean(torch.abs(output_ref - output))) + print(output[0,0,:10]) + print(output_ref[0,0,:10]) + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + +@torch.inference_mode() +def bench_contexted_kv_attention( + num_heads: int, + head_size: int, + dtype: torch.dtype, +) -> None: + import random + random.seed(0) + torch.manual_seed(0) + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask + + # seq_len = [16, 64, 128, 256, 512, 1024] + # ctx_len = [16, 64, 128, 256, 512, 1024, 2048] + seq_len = [16, 64, 128, 256, 512, 1024] + ctx_len = [256, 512, 1024, 2048] + BS = 20 + timings_triton = {} + timings_xformer = {} + for MAX_SEQ_LEN, MAX_CTX_LEN in itertools.product(seq_len, ctx_len): + gc_torch() + # MAX_SEQ_LEN = 1024 + # MAX_CTX_LEN = 2048 + outputs = [ + f"seq_len={MAX_SEQ_LEN}", + f"ctx_len={MAX_CTX_LEN}", + f"bs={BS}" + ] + cache_size = 40960 + block_size = 1 + max_block_per_request = 2048 + subquery_lens = [random.randint(MAX_SEQ_LEN, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(MAX_CTX_LEN, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + + num_tokens = sum(subquery_lens) + query = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + + + kv = torch.empty(sum(seq_lens), + 2, + num_heads, + head_size, + dtype=dtype, + device='cuda') + kv.uniform_(-1e-3, 1e-3) + key,value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') + v_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') + k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') + v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') + values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view(BS, 
max_block_per_request) + b_loc = torch.zeros(BS, MAX_CTX_LEN, dtype=torch.long, device='cuda') + b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') + b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], dtype=torch.long, device='cuda'), dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], dtype=torch.long, device='cuda'), dim=0) + for i in range(BS): + for j in range(subquery_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i,block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(key[start_loc:end_loc]) + v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_heads, head_size//8, 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_heads, head_size).permute(0, 2, 3, 1).contiguous() + + + context_attention_fwd(query, k, v, output, + k_cache, v_cache, block_table, + b_start_loc, b_seq_len, + b_ctx_len, max_input_len) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, k, v, output, + k_cache, v_cache, block_table, + b_start_loc, b_seq_len, + b_ctx_len, max_input_len) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + + scale = float(1.0 / (head_size**0.5)) + + attn_op = xops.fmha.cutlass.FwOp() + + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(subquery_lens, seq_lens) + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + start_time = time.time() + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + output_ref = output_ref.squeeze(0) + # print(output_ref.shape) + # print("max ", torch.max(torch.abs(output_ref - output))) + # print("mean ", torch.mean(torch.abs(output_ref - output))) + # print(output[0,0,:10]) + # print(output_ref[0,0,:10]) + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + result = bench(lambda: xops.memory_efficient_attention_forward(query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op)) + outputs.append(f"\n xformer: {result.avg()*1e6:3.0f}us±{result.std()*1e6:3.0f}us") + timings_xformer[(MAX_SEQ_LEN, MAX_CTX_LEN)] = result.avg()*1e6 + result2 = 
bench(lambda: context_attention_fwd(query, k, v, output, + k_cache, v_cache, block_table, + b_start_loc, b_seq_len, + b_ctx_len, max_input_len)) + outputs.append(f"\n triton: {result2.avg()*1e6:3.0f}us±{result2.std()*1e6:3.0f}us") + timings_triton[(MAX_SEQ_LEN, MAX_CTX_LEN)] = result2.avg()*1e6 + print(" | ".join(outputs)) + + +test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file From 49aaf421cdece7fbfc42f46ea2b1fd09cd252cc7 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Wed, 15 Nov 2023 05:41:41 +0000 Subject: [PATCH 02/37] add the prefix class --- vllm/engine/llm_engine.py | 16 ++++++- vllm/prefix.py | 89 +++++++++++++++++++++++++++++++++++++++ vllm/sequence.py | 8 ++++ 3 files changed, 112 insertions(+), 1 deletion(-) create mode 100644 vllm/prefix.py diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c3752b11f5660..b1c7814f58741 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -241,6 +241,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + prefix_pos: Optional[int] = None, ) -> None: """Add a request to the engine's request pool. @@ -269,9 +270,22 @@ def add_request( seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + # check prefix + if prefix_pos is not None: + # a temp workaround + prefix_pos = prefix_pos // block_size + truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] + prefix = self.scheduler.prefix_pool.fixed_search(hash(truncated_prefix_token_ids)) + if prefix is not None: + seq.prefix = prefix + # prefix.update_freq(1.0) + else: + # create a new prefix + seq.prefix = self.scheduler.prefix_pool.add_prefix(truncated_prefix_token_ids) + # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, seq.prefix) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) diff --git a/vllm/prefix.py b/vllm/prefix.py new file mode 100644 index 0000000000000..ee775ba8b2826 --- /dev/null +++ b/vllm/prefix.py @@ -0,0 +1,89 @@ +from typing import Dict, List, Optional, Union + +# Define the prefix class, which is a collection of prefix (a sequence of tokens). +# The class contains the following main methods: +# 1. A match method that checks if a prefix matches a given sequence of tokens. +# 2. A swapping method that can load or offload the prefix to or from GPU +# 3. An update_frequency method that updates the frequency of the prefix. +# 4. A get_status method that tells if the prefix is on GPU or not. 
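A minimal sketch of how `LLMEngine.add_request` (patched above) is meant to drive this pool, with simplified names, and assuming the behavior the later patches in this series settle on (hashing a tuple of token ids, and `fixed_search` returning `None` on a pool miss):

```python
# Hypothetical, simplified lookup flow for LLMEngine.add_request:
# align prefix_pos to a block boundary, hash the prefix tokens, and either
# reuse an existing Prefix from the pool or register a new one.
def resolve_prefix(pool, prompt_token_ids, prefix_pos, block_size):
    prefix_pos = (prefix_pos // block_size) * block_size
    if prefix_pos <= 0:
        return None
    prefix_tokens = prompt_token_ids[:prefix_pos]
    prefix = pool.fixed_search(hash(tuple(prefix_tokens)))
    if prefix is None:
        # the first request with this prefix pays for computing its KV cache
        prefix = pool.add_prefix(prefix_tokens)
    return prefix
```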
+ + +class Prefix: + def __init__(self, prefix_id, token_ids, block_size): + self.prefix_id = prefix_id + self.token_ids = token_ids + self.length = len(token_ids) + assert self.length % block_size == 0 + self.on_gpu = False + self.on_cpu = False + self.block_table = None + # a lock to prevent multiple sequence from calculating the same prefix + self.swap_to_gpu = False + + # freq-related + self.freq = 1 + self.alpha = 0.8 + self.beta = 0.5 + + def get_block_table_num(self) -> List[int]: + return [block.block_number for block in self.block_table] + + def match(self, tokens): + return tokens[:self.length] == self.token_ids + + # should be called if the prefix is hit for this iteration + def update_freq(self, new_hit_rate): + self.freq = self.alpha * self.freq + (1 - self.alpha) * new_hit_rate + self.alpha = 0.8 + + # should be called if the prefix is not hit for this iteration + def punish_freq(self): + self.alpha = self.beta * self.alpha if self.alpha > 0.1 else 0.1 + + # whether the prefix is on GPU or not + def get_status(self): + return self.on_gpu + + def get_length(self): + return self.length + + +# Define the prefix pool class, which is a collection of prefixes. +# The class contains the following main methods: +# 1. add a prefix to the pool, with a computed hash +# 2. TODO: create subprefix, if one is a prefix of the other: they can share some memory blocks +# 3. efficient_search: given a sequence of tokens, find the longest prefix in the pool that matches the sequence +# 4. fixed_search: given the prefix's hash, find the prefix in the pool +# 5. TODO: approximate_search: given a sequence of tokens, find the similar prefixes in the pool + + +class PrefixPool: + def __init__(self, block_size): + self.prefixes = [] + self.prefixes_hash = {} + self.block_size = block_size + + def add_prefix(self, token_ids: List[int]): + # generate prefix_id + prefix_id = len(self.prefixes) + # create a new prefix + prefix = Prefix(prefix_id, token_ids, self.block_size) + self.prefixes.append(prefix) + # @TODO: compute the hash of the prefix + prefix_hash = hash(prefix.token_ids) + self.prefixes_hash[prefix.prefix_id] = prefix_hash + return prefix + + # @TODO: this one should also come with a method to identify the prefix + def efficient_search(self, token_ids: List[int]): + # improve this search + for prefix in self.prefixes: + if prefix.match(token_ids): + return prefix + return None + + # use this first, if we already know from the application which part of the tokens are prefix. 
+ def fixed_search(self, prefix_hash): + prefix_id = self.prefixes_hash[prefix_hash] + return self.prefixes[prefix_id] + diff --git a/vllm/sequence.py b/vllm/sequence.py index ecfaee6e8c3d6..be60ca7f71828 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union from vllm.block import LogicalTokenBlock +from vllm.prefix import Prefix from vllm.sampling_params import SamplingParams PromptLogprobs = List[Optional[Dict[int, float]]] @@ -113,9 +114,11 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + prefix: Optional[Prefix] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt + self.prefix = prefix self.block_size = block_size self.data = SequenceData(prompt_token_ids) @@ -236,12 +239,14 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time self.prompt_logprobs: Optional[PromptLogprobs] = None + self.prefix = prefix @property def prompt(self) -> str: @@ -335,6 +340,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + prefix: The prefix of the prompt of the sequence group. """ def __init__( @@ -344,12 +350,14 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.prefix = prefix class SequenceOutputs: From ac96e7f32e550b4ddecb254b78da09d1fbebb58a Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Wed, 15 Nov 2023 06:06:18 +0000 Subject: [PATCH 03/37] modify block manager & scheduler --- vllm/core/block_manager.py | 100 ++++++++++++++++++++++++++++++++++++- vllm/core/scheduler.py | 40 +++++++++++++++ 2 files changed, 139 insertions(+), 1 deletion(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 57349e7fe7f92..4a62a0fe84f06 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -4,6 +4,7 @@ from vllm.block import PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device +from vllm.prefix import PrefixPool, Prefix class BlockAllocator: @@ -91,6 +92,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> bool: # the same prompt. This may not be true for preempted sequences. seq = seq_group.get_seqs()[0] num_required_blocks = len(seq.logical_token_blocks) + + if seq_group.prefix is not None and seq_group.prefix.on_gpu: + num_required_blocks -= seq_group.prefix.get_length() // self.block_size + if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.block_sliding_window) @@ -105,8 +110,26 @@ def allocate(self, seq_group: SequenceGroup) -> None: seq = seq_group.get_seqs()[0] # Allocate new physical token blocks that will store the prompt tokens. 
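The `can_allocate()` change above subtracts the blocks already held by a resident prefix before checking against the watermark; a worked example with illustrative numbers (not real profiling data):

```python
block_size = 16
prompt_len = 512                                   # -> 32 logical blocks
prefix_len = 256                                   # resident prefix -> 16 blocks already allocated
num_required_blocks = prompt_len // block_size
num_required_blocks -= prefix_len // block_size    # only the suffix needs new blocks
assert num_required_blocks == 16
```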
+ num_prompt_blocks = len(seq.logical_token_blocks) + block_table: BlockTable = [] - for logical_idx in range(len(seq.logical_token_blocks)): + prefix_block_table: BlockTable = [] + num_prefix_blocks = 0 + if seq_group.prefix is not None: + # prefix is already on gpu or will be swapped in before the actual computation + if seq_group.prefix.on_gpu: + num_prompt_blocks -= seq_group.prefix.get_length() // self.block_size + for block in seq_group.prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) + # TODO: will need to perform the copy-on-write if prefix length is not a multiple of block size + + # allocate blocks for the prefix, we need to calculate the prefix's kv in this run + elif not seq_group.prefix.swap_to_gpu: + num_prefix_blocks = seq_group.prefix.get_length() // self.block_size + seq_group.prefix.swap_to_gpu = True + + for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): block = block_table[logical_idx % self.block_sliding_window] @@ -115,10 +138,16 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() block_table.append(block) + if logical_idx < num_prefix_blocks: + block.ref_count += 1 + prefix_block_table.append(block) # Assign the block table for each sequence. for seq in seq_group.get_seqs(): self.block_tables[seq.seq_id] = block_table.copy() + + if num_prefix_blocks > 0: + seq_group.prefix.block_table = prefix_block_table.copy() def can_append_slot(self, seq_group: SequenceGroup) -> bool: # Simple heuristic: If there is at least one free block @@ -188,12 +217,29 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: num_required_blocks = len(blocks) + num_swapped_seqs return num_free_blocks - num_required_blocks >= self.watermark_blocks + def can_swap_in_prefix(self, prefix: Prefix) -> bool: + blocks = prefix.block_table + num_free_blocks = self.gpu_allocator.get_num_free_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate + # at least one free block right after the swap-in. + # NOTE: This should match the logic in can_append_slot(). + num_required_blocks = len(blocks) + return num_free_blocks - num_required_blocks >= self.watermark_blocks + def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. + if seq_group.prefix is not None: + # make sure to swap in the prefix first + assert seq_group.prefix.on_gpu == True + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): new_block_table: BlockTable = [] block_table = self.block_tables[seq.seq_id] + if seq_group.prefix is not None: + for block in seq_group.prefix.block_table: + new_block_table.append(block) + block.ref_count += 1 for cpu_block in block_table: if cpu_block in mapping: @@ -213,10 +259,35 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: } return block_number_mapping + def swap_in_prefix(self, prefix: Prefix) -> Dict[int, int]: + # CPU block -> GPU block. + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + new_block_table = [] + block_table = prefix.block_table + + for cpu_block in enumerate(block_table): + # ref_count = 1 + gpu_block = self.gpu_allocator.allocate() + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. 
+ self.cpu_allocator.free(cpu_block) + prefix.block_table = new_block_table + + block_number_mapping = { + cpu_block.block_number: gpu_block.block_number + for cpu_block, gpu_block in mapping.items() + } + return block_number_mapping + def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + def can_swap_out_prefix(self, prefix: Prefix) -> bool: + blocks = prefix.block_table + return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} @@ -225,6 +296,11 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: + # do not swap out the prefix + if seq_group.prefix is not None and gpu_block in seq_group.prefix.block_table: + self.gpu_allocator.free(gpu_block) + continue + if gpu_block in mapping: cpu_block = mapping[gpu_block] cpu_block.ref_count += 1 @@ -241,6 +317,28 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: for gpu_block, cpu_block in mapping.items() } return block_number_mapping + + def swap_out_prefix(self, prefix: Prefix) -> Dict[int, int]: + # GPU block -> CPU block. + # make sure all the reference seq are finished or swapped out before swapping out the prefix + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + new_block_table = [] + block_table = prefix.block_table + + for gpu_block in block_table: + cpu_block = self.cpu_allocator.allocate() + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + assert gpu_block.ref_count == 1 + self.gpu_allocator.free(gpu_block) + prefix.block_table = new_block_table + + block_number_mapping = { + gpu_block.block_number: cpu_block.block_number + for gpu_block, cpu_block in mapping.items() + } + return block_number_mapping def _free_block_table(self, block_table: BlockTable) -> None: for block in set(block_table): diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0c98c063c8694..0b2693f1042af 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) +from vllm.prefix import Prefix, PrefixPool logger = init_logger(__name__) @@ -74,6 +75,8 @@ def __init__( num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, sliding_window=self.cache_config.sliding_window) + + self.prefix_pool = PrefixPool(self.cache_config.block_size) # TODO(zhuohan): Use deque instead of list for better performance. # Sequence groups in the WAITING state. 
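The `allocate()` path above shares a resident prefix's physical blocks across every sequence in the group by bumping their reference counts, so the prefix KV is written once and then only read. A minimal sketch of that idea, with simplified names that are not the block manager's real API:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Block:
    number: int
    ref_count: int = 0

def allocate_with_prefix(prefix_blocks: List[Block],
                         num_prompt_blocks: int,
                         num_seqs: int,
                         free_blocks: List[Block]) -> List[Block]:
    table: List[Block] = []
    for block in prefix_blocks:                 # reuse cached prefix blocks
        block.ref_count += num_seqs
        table.append(block)
    for _ in range(num_prompt_blocks - len(prefix_blocks)):
        block = free_blocks.pop()               # fresh blocks for the suffix only
        block.ref_count = num_seqs
        table.append(block)
    return table

free = [Block(n) for n in range(10)]
prefix_blocks = [Block(100, ref_count=1), Block(101, ref_count=1)]
table = allocate_with_prefix(prefix_blocks, num_prompt_blocks=4, num_seqs=2, free_blocks=free)
assert [b.ref_count for b in table] == [3, 3, 2, 2]
```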
@@ -177,11 +180,22 @@ def _schedule(self) -> SchedulerOutputs: seq_lens = new_seq_lens seq_group = self.waiting.pop(0) + # swap in the prefix if it is on CPU + if seq_group.prefix is not None and seq_group.prefix.on_cpu: + # prefix.on_gpu will be set inside this function + self._swap_in_prefix(seq_group.prefix, blocks_to_swap_in) + # if the prefix hasn't been compuated, allocate blocks for it and set prefix.swap_to_gpu to True self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) + # set the prefix state + for seq_group in scheduled: + if seq_group.prefix is not None: + seq_group.prefix.on_gpu = True + seq_group.prefix.swap_to_gpu = False + if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, @@ -288,6 +302,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + prefix=seq_group.prefix, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs @@ -398,3 +413,28 @@ def _swap_out( blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED + + def _swap_in_prefix( + self, + prefix: Prefix, + blocks_to_swap_in: Dict[int, int], + ) -> None: + mapping = self.block_manager.swap_in_prefix(prefix) + blocks_to_swap_in.update(mapping) + prefix.on_gpu = True + + def _swap_out_prefix( + self, + prefix: Prefix, + blocks_to_swap_out: Dict[int, int], + ) -> None: + if not self.block_manager.can_swap_out_prefix(prefix): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. 
Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out_prefix(prefix) + blocks_to_swap_out.update(mapping) + prefix.on_cpu = True + prefix.on_gpu = False From bf3faa26387b2d76a7e5181913946b4fd9784e3e Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Wed, 15 Nov 2023 06:50:23 +0000 Subject: [PATCH 04/37] prepare inputs to the kernel --- vllm/model_executor/input_metadata.py | 9 ++- vllm/model_executor/layers/attention.py | 73 +++++++++++++++++++++---- vllm/worker/worker.py | 26 +++++++-- 3 files changed, 92 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index b3b5852e48769..68f7d5a485648 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -25,7 +25,9 @@ def __init__( seq_groups: List[Tuple[List[int], SamplingParams]], seq_data: Dict[int, SequenceData], prompt_lens: List[int], + max_seq_len: int, slot_mapping: torch.Tensor, + start_loc:torch.Tensor, context_lens: torch.Tensor, max_context_len: int, block_tables: torch.Tensor, @@ -36,7 +38,12 @@ def __init__( self.seq_groups = seq_groups self.seq_data = seq_data self.prompt_lens = prompt_lens + self.prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.int, + device=slot_mapping.device) + self.max_seq_len = max_seq_len self.slot_mapping = slot_mapping + self.start_loc = start_loc self.context_lens = context_lens self.max_context_len = max_context_len self.block_tables = block_tables @@ -69,7 +76,7 @@ def __init__( self.max_num_blocks_per_seq = block_tables.shape[1] else: self.max_num_blocks_per_seq = 0 - assert block_tables.shape[0] == self.num_generation_tokens + # assert block_tables.shape[0] == self.num_generation_tokens # Set during the execution of the first attention op. self.attn_bias: Optional[AttentionBias] = None diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b94c82e132583..db466ce3b7d25 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -11,6 +11,7 @@ from vllm import cache_ops from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.triton_kernel.prefix_prefill import context_attention_fwd _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -114,6 +115,48 @@ def multi_query_kv_attention( output.copy_(out.squeeze(0)) return output + def multi_query_cached_kv_attention( + self, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + """Normal attention for the prompt tokens. + + Args: + output: shape = [num_prompt_tokens, num_heads, head_size] + query: shape = [num_prompt_tokens, num_heads, head_size] + key: shape = [num_prompt_tokens, num_kv_heads, head_size] + value: shape = [num_prompt_tokens, num_kv_heads, head_size] + input_metadata: metadata for prefix-enabled prefill attention. + """ + + if self.num_kv_heads != self.num_heads: + # Project the key and value tensors to the desired number of heads. 
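The `repeat_interleave` projection below expands the KV heads so they pair one-to-one with the query heads before the kernel call; a quick shape check (the head counts here are illustrative, not tied to any particular model):

```python
import torch

num_heads, num_kv_heads, head_size, tokens = 32, 8, 128, 5
key = torch.randn(tokens, num_kv_heads, head_size)
queries_per_kv = num_heads // num_kv_heads            # 4 query heads share each KV head
key_expanded = torch.repeat_interleave(key, queries_per_kv, dim=1)
assert key_expanded.shape == (tokens, num_heads, head_size)
```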
+ key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, self.num_queries_per_kv,dim=1) + + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens_tensor, + input_metadata.context_lens, + input_metadata.max_seq_len + ) + + + return output + def get_alibi_slopes(self) -> Optional[torch.Tensor]: """Returns the slopes for the alibi attention bias. @@ -241,23 +284,33 @@ def forward( # Pre-allocate the output tensor. output = torch.empty_like(query) + # Wait until the cache op is done. + if cache_event is not None: + cache_event.wait() + # Compute the attention op for prompts. num_prompt_tokens = input_metadata.num_prompt_tokens if num_prompt_tokens > 0: # Prompt run. assert input_metadata.num_generation_tokens == 0 - self.set_attn_bias(input_metadata, dtype=query.dtype) - self.multi_query_kv_attention( - output, - query, - key, - value, + # self.set_attn_bias(input_metadata, dtype=query.dtype) + # self.multi_query_kv_attention( + # output, + # query, + # key, + # value, + # input_metadata, + # ) + self.multi_query_cached_kv_attention( + output[:num_prompt_tokens], + query[:num_prompt_tokens], + key[:num_prompt_tokens], + value[:num_prompt_tokens], + key_cache, + value_cache, input_metadata, ) - - # Wait until the cache op is done. - if cache_event is not None: - cache_event.wait() + # TODO(shiyi): perform multi_query_cached_kv_attention after the cache op for better kernel performance # Reshape the keys and values and store them in the cache. # When key_cache and value_cache are not provided, the new key diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index bbbc2e7f45a6e..0f597bc457c8d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -161,6 +161,9 @@ def _prepare_inputs( # Add prompt tokens. prompt_lens: List[int] = [] + context_lens: List[int] = [] + subquery_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] for seq_group_metadata in seq_group_metadata_list: if not seq_group_metadata.is_prompt: continue @@ -176,6 +179,15 @@ def _prepare_inputs( prompt_tokens = seq_data.get_token_ids() prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) + prefix_len = 0 + if seq_group_metadata.prefix is not None and seq_group_metadata.prefix.on_gpu: + prefix_len = seq_group_metadata.prefix.get_length() + assert prefix_len % self.block_size == 0 + prompt_tokens = prompt_tokens[prefix_len:] + prefix_block_tables.append(seq_group_metadata.prefix.get_block_table_num()) + # actual prompt lens + context_lens.append(prefix_len) + subquery_lens.append(prompt_len-prefix_len) if sampling_params.prompt_logprobs is not None: # NOTE: prompt token positions do not need sample, skip @@ -188,7 +200,7 @@ def _prepare_inputs( input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.append(list(range(prompt_len))) + input_positions.extend(range(prefix_len,prefix_len+len(prompt_tokens))) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -199,7 +211,7 @@ def _prepare_inputs( # Compute the slot mapping. 
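The slot-mapping loop below converts each non-prefix token position into a flat slot index inside the paged KV cache; a worked example with made-up numbers:

```python
# With block_size=16, prefix_len=16 and a block_table of [7, 3], token i=17
# lives in logical block 1 -> physical block 3, offset 1 -> slot 3*16 + 1 = 49.
block_size = 16
block_table = [7, 3]
prefix_len, prompt_len = 16, 20
slots = []
for i in range(prefix_len, prompt_len):
    block_number = block_table[i // block_size]
    block_offset = i % block_size
    slots.append(block_number * block_size + block_offset)
assert slots == [48, 49, 50, 51]
```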
slot_mapping.append([]) block_table = seq_group_metadata.block_tables[seq_id] - for i in range(prompt_len): + for i in range(prefix_len, prompt_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset @@ -208,9 +220,8 @@ def _prepare_inputs( # Add generation tokens. max_context_len = 0 max_num_blocks_per_seq = 0 - context_lens: List[int] = [] generation_block_tables: List[List[int]] = [] - max_seq_len = max(prompt_lens) if prompt_lens else 1 + max_seq_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): if seq_group_metadata.is_prompt: # We need to do this in this loop as we need to know max_seq_len @@ -283,9 +294,10 @@ def _prepare_inputs( _pad_to_max(mapping, max_seq_len, pad=-1) for mapping in slot_mapping ] + block_tables = generation_block_tables if prefix_block_tables == [] else prefix_block_tables padded_block_tables = [ _pad_to_max(block_table, max_num_blocks_per_seq, pad=0) - for block_table in generation_block_tables + for block_table in block_tables ] # Convert to tensors. @@ -315,12 +327,16 @@ def _prepare_inputs( seq_data: Dict[int, SequenceData] = {} for seq_group_metadata in seq_group_metadata_list: seq_data.update(seq_group_metadata.seq_data) + + start_loc_tensor = torch.arange(0, len(prompt_lens)*max_seq_len, max_seq_len, dtype=torch.long, device='cuda') input_metadata = InputMetadata( seq_groups=seq_groups, seq_data=seq_data, prompt_lens=prompt_lens, + max_seq_len=max_seq_len, slot_mapping=slot_mapping_tensor, + start_loc=start_loc_tensor, context_lens=context_lens_tensor, max_context_len=max_context_len, block_tables=block_tables_tensor, From 074a90f4642134a387f883f692df706593faa969 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Thu, 16 Nov 2023 07:53:34 +0000 Subject: [PATCH 05/37] fix --- vllm/model_executor/input_metadata.py | 2 +- vllm/model_executor/layers/attention.py | 37 ++++--- .../layers/triton_kernel/benchmark_utils.py | 102 ++++++++++++++++++ .../layers/triton_kernel/prefix_prefill.py | 4 +- vllm/worker/worker.py | 2 +- 5 files changed, 126 insertions(+), 21 deletions(-) create mode 100644 vllm/model_executor/layers/triton_kernel/benchmark_utils.py diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 68f7d5a485648..24b49eb3db104 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -71,7 +71,7 @@ def __init__( self.num_prompts = len(prompt_lens) self.num_prompt_tokens = self.num_prompts * self.max_prompt_len - self.num_generation_tokens = context_lens.shape[0] + self.num_generation_tokens = context_lens.shape[0] if not prompt_lens else 0 if block_tables.numel() > 0: self.max_num_blocks_per_seq = block_tables.shape[1] else: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index db466ce3b7d25..12fc52b1874d4 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -293,23 +293,26 @@ def forward( if num_prompt_tokens > 0: # Prompt run. 
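The prompt branch below routes prefix-enabled prefills to the Triton kernel from the first patch. The kernel's running-max/normalizer bookkeeping (`m_i`, `l_i`, and the rescaled accumulator) is standard streaming softmax; an equivalent formulation that normalizes once at the end can be checked against a direct softmax in a few lines:

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 256)              # 4 query rows, 256 keys
values = torch.randn(256, 8)

m = torch.full((4,), float("-inf"))       # running max
l = torch.zeros(4)                        # running normalizer
acc = torch.zeros(4, 8)                   # unnormalized output accumulator
for start in range(0, 256, 64):           # stream over key blocks, as in the kernel
    s = scores[:, start:start + 64]
    m_new = torch.maximum(m, s.max(dim=1).values)
    alpha = torch.exp(m - m_new)          # rescale previously accumulated stats
    p = torch.exp(s - m_new[:, None])
    l = alpha * l + p.sum(dim=1)
    acc = acc * alpha[:, None] + p @ values[start:start + 64]
    m = m_new
out_streaming = acc / l[:, None]

out_direct = torch.softmax(scores, dim=1) @ values
assert torch.allclose(out_streaming, out_direct, atol=1e-5)
```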
assert input_metadata.num_generation_tokens == 0 - # self.set_attn_bias(input_metadata, dtype=query.dtype) - # self.multi_query_kv_attention( - # output, - # query, - # key, - # value, - # input_metadata, - # ) - self.multi_query_cached_kv_attention( - output[:num_prompt_tokens], - query[:num_prompt_tokens], - key[:num_prompt_tokens], - value[:num_prompt_tokens], - key_cache, - value_cache, - input_metadata, - ) + if key_cache is None or value_cache is None: + # No cache provided. Perform normal attention. + self.set_attn_bias(input_metadata, dtype=query.dtype) + self.multi_query_kv_attention( + output, + query, + key, + value, + input_metadata, + ) + else: + self.multi_query_cached_kv_attention( + output[:num_prompt_tokens], + query[:num_prompt_tokens], + key[:num_prompt_tokens], + value[:num_prompt_tokens], + key_cache, + value_cache, + input_metadata, + ) # TODO(shiyi): perform multi_query_cached_kv_attention after the cache op for better kernel performance # Reshape the keys and values and store them in the cache. diff --git a/vllm/model_executor/layers/triton_kernel/benchmark_utils.py b/vllm/model_executor/layers/triton_kernel/benchmark_utils.py new file mode 100644 index 0000000000000..f22df2a89769d --- /dev/null +++ b/vllm/model_executor/layers/triton_kernel/benchmark_utils.py @@ -0,0 +1,102 @@ +import abc +import dataclasses +import gc +import itertools +import time +from typing import Callable + +import numpy as np +import torch + + +class Benchmark(abc.ABC): + + def setup(self): + pass + + def before_run(self): + pass + + @abc.abstractmethod + def run(self): + pass + + def after_run(self): + pass + + def teardown(self): + pass + + +class wrap_benchmark(Benchmark): + + def __init__(self, fn_run: Callable[[], None]): + self.fn_run = fn_run + + def run(self): + self.fn_run() + + +@dataclasses.dataclass +class BenchResult: + warmup: int + repeat: int + latency: np.ndarray + + def avg(self) -> np.ndarray: + return np.mean(self.latency) + + def std(self) -> np.ndarray: + return np.std(self.latency) + + def avg_std(self) -> np.ndarray: + return self.avg(), self.std() + + +def bench( + f: Callable[[], None], + warmup: int = 100, + repeat: int = 500, +) -> BenchResult: + if isinstance(f, Benchmark): + b = f + else: + b = wrap_benchmark(f) + + cache = torch.empty(256 * 2**20, dtype=torch.int8, device="cuda:0") + b.setup() + + latency = np.zeros(repeat, dtype=np.float64) + for i in range(-warmup, repeat): + b.before_run() + cache.zero_() + + torch.cuda.synchronize() + t0 = time.perf_counter() + b.run() + torch.cuda.synchronize() + t1 = time.perf_counter() + + b.after_run() + + if i >= 0: + latency[i] = t1 - t0 + + b.teardown() + del cache + return BenchResult(warmup, repeat, latency) + + +def gc_torch(): + gc.collect() + torch.cuda.empty_cache() + + +def batched(iterable, n): + "Batch data into tuples of length n. The last batch may be shorter." 
+ # batched('ABCDEFG', 3) --> ABC DEF G + if n < 1: + raise ValueError('n must be at least one') + it = iter(iterable) + while batch := list(itertools.islice(it, n)): + yield batch diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index 7e8cdc003c5fe..9c3de3e2b6a1d 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import matplotlib.pyplot as plt -from benchmark_utils import bench, gc_torch +from .benchmark_utils import bench, gc_torch if triton.__version__ >= "2.1.0": @triton.jit @@ -463,4 +463,4 @@ def bench_contexted_kv_attention( print(" | ".join(outputs)) -test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file +# test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0f597bc457c8d..92e387122216e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -200,7 +200,7 @@ def _prepare_inputs( input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. - input_positions.extend(range(prefix_len,prefix_len+len(prompt_tokens))) + input_positions.append(list(range(prefix_len,prefix_len+len(prompt_tokens)))) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized From de69ca4a49a5df8687882da20760893b046d91fe Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Thu, 16 Nov 2023 22:06:15 +0000 Subject: [PATCH 06/37] add prefix_pos --- examples/api_client.py | 6 ++++-- vllm/engine/async_llm_engine.py | 9 +++++++-- vllm/engine/llm_engine.py | 4 ++-- vllm/entrypoints/api_server.py | 3 ++- vllm/prefix.py | 6 +++++- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/examples/api_client.py b/examples/api_client.py index 70ec8c5492124..5fecd739af4f1 100644 --- a/examples/api_client.py +++ b/examples/api_client.py @@ -15,12 +15,14 @@ def clear_line(n: int = 1) -> None: def post_http_request(prompt: str, + prefix_pos: int, api_url: str, n: int = 1, stream: bool = False) -> requests.Response: headers = {"User-Agent": "Test Client"} pload = { "prompt": prompt, + "prefix_pos": prefix_pos, "n": n, "use_beam_search": True, "temperature": 0.0, @@ -52,7 +54,7 @@ def get_response(response: requests.Response) -> List[str]: parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--n", type=int, default=4) - parser.add_argument("--prompt", type=str, default="San Francisco is a") + parser.add_argument("--prompt", type=str, default="San Francisco is a "*32) parser.add_argument("--stream", action="store_true") args = parser.parse_args() prompt = args.prompt @@ -61,7 +63,7 @@ def get_response(response: requests.Response) -> List[str]: stream = args.stream print(f"Prompt: {prompt!r}\n", flush=True) - response = post_http_request(prompt, api_url, n, stream) + response = post_http_request(prompt, 32, api_url, n, stream) if stream: num_printed_lines = 0 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index bc36b64f7df0f..f17f6d88943d2 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -357,6 +357,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, 
arrival_time: Optional[float] = None, + prefix_pos: Optional[int] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -369,6 +370,7 @@ async def add_request( max_log_len] logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " + f"prefix_pos: {prefix_pos}," f"sampling params: {sampling_params}, " f"prompt token ids: {shortened_token_ids}.") @@ -387,13 +389,15 @@ async def add_request( prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + prefix_pos=prefix_pos) return stream async def generate( self, prompt: Optional[str], + prefix_pos: Optional[int], sampling_params: SamplingParams, request_id: str, prompt_token_ids: Optional[List[int]] = None) -> RequestOutput: @@ -424,7 +428,8 @@ async def generate( prompt, sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + prefix_pos=prefix_pos) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b1c7814f58741..0482815bb1883 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -273,9 +273,9 @@ def add_request( # check prefix if prefix_pos is not None: # a temp workaround - prefix_pos = prefix_pos // block_size + prefix_pos = (prefix_pos // block_size) * block_size truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] - prefix = self.scheduler.prefix_pool.fixed_search(hash(truncated_prefix_token_ids)) + prefix = self.scheduler.prefix_pool.fixed_search(hash(tuple(truncated_prefix_token_ids))) if prefix is not None: seq.prefix = prefix # prefix.update_freq(1.0) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index fb29837da8cf0..8ff2367409efd 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -34,11 +34,12 @@ async def generate(request: Request) -> Response: """ request_dict = await request.json() prompt = request_dict.pop("prompt") + prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id) + results_generator = engine.generate(prompt, prefix_pos, sampling_params, request_id) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/prefix.py b/vllm/prefix.py index ee775ba8b2826..06c0a95936acd 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -13,6 +13,8 @@ def __init__(self, prefix_id, token_ids, block_size): self.prefix_id = prefix_id self.token_ids = token_ids self.length = len(token_ids) + print("prefix length: ", self.length) + print("block size: ", block_size) assert self.length % block_size == 0 self.on_gpu = False self.on_cpu = False @@ -70,7 +72,7 @@ def add_prefix(self, token_ids: List[int]): prefix = Prefix(prefix_id, token_ids, self.block_size) self.prefixes.append(prefix) # @TODO: compute the hash of the prefix - prefix_hash = hash(prefix.token_ids) + prefix_hash = hash(tuple(prefix.token_ids)) self.prefixes_hash[prefix.prefix_id] = prefix_hash return prefix @@ -84,6 +86,8 @@ def efficient_search(self, token_ids: List[int]): # use this first, if we already know from the application which part of the tokens are prefix. 
def fixed_search(self, prefix_hash): + if prefix_hash not in self.prefixes_hash: + return None prefix_id = self.prefixes_hash[prefix_hash] return self.prefixes[prefix_id] From f5bf25ab43dd78c49328742d626db696d02e5d2d Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Fri, 17 Nov 2023 01:52:34 +0000 Subject: [PATCH 07/37] fix prefix state transition --- vllm/core/scheduler.py | 6 ------ vllm/worker/worker.py | 8 +++++++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 0b2693f1042af..91662688bf824 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -190,12 +190,6 @@ def _schedule(self) -> SchedulerOutputs: num_curr_seqs += num_new_seqs scheduled.append(seq_group) - # set the prefix state - for seq_group in scheduled: - if seq_group.prefix is not None: - seq_group.prefix.on_gpu = True - seq_group.prefix.swap_to_gpu = False - if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 92e387122216e..101278c290add 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -229,7 +229,8 @@ def _prepare_inputs( seq_ids) == 1, "Prompt input should have only one seq." sampling_params = seq_group_metadata.sampling_params assert len(prompt_lens) == len(seq_group_metadata_list) - prompt_len = prompt_lens[i] + # here is the subprompt len + prompt_len = subquery_lens[i] if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( range(selected_token_start_idx, @@ -237,6 +238,11 @@ def _prepare_inputs( selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len + + # set the prefix state + if seq_group_metadata.prefix is not None and seq_group_metadata.prefix.swap_to_gpu: + seq_group_metadata.prefix.on_gpu = True + seq_group_metadata.prefix.swap_to_gpu = False continue seq_ids = list(seq_group_metadata.seq_data.keys()) From 49b19bba3fb066887a4df15d7f250d83cead003f Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Fri, 17 Nov 2023 07:37:39 +0000 Subject: [PATCH 08/37] tested on single request --- vllm/engine/llm_engine.py | 1 + vllm/model_executor/layers/attention.py | 5 ++++- vllm/prefix.py | 4 +++- vllm/worker/worker.py | 1 + 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0482815bb1883..b01286aac3424 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -278,6 +278,7 @@ def add_request( prefix = self.scheduler.prefix_pool.fixed_search(hash(tuple(truncated_prefix_token_ids))) if prefix is not None: seq.prefix = prefix + print("prefix status: ", "on gpu" if prefix.get_status() else "on cpu") # prefix.update_freq(1.0) else: # create a new prefix diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 12fc52b1874d4..9d978a064507d 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -293,7 +293,7 @@ def forward( if num_prompt_tokens > 0: # Prompt run. assert input_metadata.num_generation_tokens == 0 - if key_cache is None or value_cache is None: + if key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0: # No cache provided. Perform normal attention. 
self.set_attn_bias(input_metadata, dtype=query.dtype) self.multi_query_kv_attention( @@ -304,6 +304,8 @@ def forward( input_metadata, ) else: + print("Using prefix-enabled prefill attention") + print("num_prompt_tokens: ", num_prompt_tokens) self.multi_query_cached_kv_attention( output[:num_prompt_tokens], query[:num_prompt_tokens], @@ -334,6 +336,7 @@ def forward( value_cache, slot_mapping, ) + torch.cuda.synchronize() if input_metadata.num_generation_tokens > 0: # Decoding run. diff --git a/vllm/prefix.py b/vllm/prefix.py index 06c0a95936acd..508ee2fdf4da9 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -73,7 +73,8 @@ def add_prefix(self, token_ids: List[int]): self.prefixes.append(prefix) # @TODO: compute the hash of the prefix prefix_hash = hash(tuple(prefix.token_ids)) - self.prefixes_hash[prefix.prefix_id] = prefix_hash + # self.prefixes_hash[prefix.prefix_id] = prefix_hash + self.prefixes_hash[prefix_hash] = prefix.prefix_id return prefix # @TODO: this one should also come with a method to identify the prefix @@ -88,6 +89,7 @@ def efficient_search(self, token_ids: List[int]): def fixed_search(self, prefix_hash): if prefix_hash not in self.prefixes_hash: return None + print("Found prefix in the pool.") prefix_id = self.prefixes_hash[prefix_hash] return self.prefixes[prefix_id] diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 101278c290add..3d7a31bda8973 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -301,6 +301,7 @@ def _prepare_inputs( for mapping in slot_mapping ] block_tables = generation_block_tables if prefix_block_tables == [] else prefix_block_tables + print("block_tables", block_tables) padded_block_tables = [ _pad_to_max(block_table, max_num_blocks_per_seq, pad=0) for block_table in block_tables From 1619d336e5a8c64058d0598c1054544052fb44f1 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sat, 18 Nov 2023 07:21:02 +0000 Subject: [PATCH 09/37] add prefix_pos for offline inference --- vllm/engine/llm_engine.py | 19 ++++++++++--------- vllm/entrypoints/llm.py | 7 +++++-- vllm/model_executor/layers/attention.py | 13 ++++++------- vllm/worker/worker.py | 4 +++- 4 files changed, 24 insertions(+), 19 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index b01286aac3424..a788adb306712 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -274,15 +274,16 @@ def add_request( if prefix_pos is not None: # a temp workaround prefix_pos = (prefix_pos // block_size) * block_size - truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] - prefix = self.scheduler.prefix_pool.fixed_search(hash(tuple(truncated_prefix_token_ids))) - if prefix is not None: - seq.prefix = prefix - print("prefix status: ", "on gpu" if prefix.get_status() else "on cpu") - # prefix.update_freq(1.0) - else: - # create a new prefix - seq.prefix = self.scheduler.prefix_pool.add_prefix(truncated_prefix_token_ids) + if prefix_pos > 0: + truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] + prefix = self.scheduler.prefix_pool.fixed_search(hash(tuple(truncated_prefix_token_ids))) + if prefix is not None: + seq.prefix = prefix + # print("prefix status: ", "on gpu" if prefix.get_status() else "on cpu") + # prefix.update_freq(1.0) + else: + # create a new prefix + seq.prefix = self.scheduler.prefix_pool.add_prefix(truncated_prefix_token_ids) # Create the sequence group. 
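# With prefixes_hash now keyed by the token-tuple hash (mapping to a prefix_id), lookups and
# misses behave as in this small illustrative session, assuming the PrefixPool/Prefix classes
# from this patch series:
pool = PrefixPool(block_size=16)
tokens = list(range(16))                                   # length must be a multiple of block_size
p = pool.add_prefix(tokens)
assert pool.fixed_search(hash(tuple(tokens))) is p         # hit: the same Prefix object is returned
assert pool.fixed_search(hash(tuple(range(8)))) is None    # miss: hash not in prefixes_hash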
seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9dddfc1acd9cc..4accc8fd28c95 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -108,6 +108,7 @@ def generate( prompts: Optional[Union[str, List[str]]] = None, sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, + prefix_pos: Optional[Union[int, List[int]]] = None, use_tqdm: bool = True, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -149,11 +150,12 @@ def generate( num_requests = len(prompt_token_ids) for i in range(num_requests): prompt = prompts[i] if prompts is not None else None + prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None if prompt_token_ids is None: token_ids = None else: token_ids = prompt_token_ids[i] - self._add_request(prompt, sampling_params, token_ids) + self._add_request(prompt, sampling_params, token_ids, prefix_pos_i) return self._run_engine(use_tqdm) def _add_request( @@ -161,10 +163,11 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + prefix_pos: Optional[int], ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + prompt_token_ids, prefix_pos=prefix_pos) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 9d978a064507d..2ea0d73b1a4d6 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -304,13 +304,13 @@ def forward( input_metadata, ) else: - print("Using prefix-enabled prefill attention") - print("num_prompt_tokens: ", num_prompt_tokens) + # print("Using prefix-enabled prefill attention") + # print("block tables: ", input_metadata.block_tables) self.multi_query_cached_kv_attention( - output[:num_prompt_tokens], - query[:num_prompt_tokens], - key[:num_prompt_tokens], - value[:num_prompt_tokens], + output, + query, + key, + value, key_cache, value_cache, input_metadata, @@ -336,7 +336,6 @@ def forward( value_cache, slot_mapping, ) - torch.cuda.synchronize() if input_metadata.num_generation_tokens > 0: # Decoding run. 
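# The entrypoints change above threads prefix_pos through LLM.generate(). A hypothetical
# offline-inference usage (model name, prompt text and the use of get_tokenizer() are
# illustrative assumptions, not part of the patch); prefix_pos is given in tokens, one value
# per prompt, and the engine rounds it down to a block boundary itself:
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
shared = "You are a helpful assistant. Answer concisely. "          # common system prompt
prompts = [shared + q for q in ["What is a KV cache?", "What is paged attention?"]]
prefix_len = len(llm.get_tokenizer().encode(shared))                 # token length of the shared part
outputs = llm.generate(prompts,
                       SamplingParams(temperature=0.0, max_tokens=32),
                       prefix_pos=[prefix_len] * len(prompts))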
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 3d7a31bda8973..f22a67844d070 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -185,6 +185,8 @@ def _prepare_inputs( assert prefix_len % self.block_size == 0 prompt_tokens = prompt_tokens[prefix_len:] prefix_block_tables.append(seq_group_metadata.prefix.get_block_table_num()) + else: + prefix_block_tables.append([]) # actual prompt lens context_lens.append(prefix_len) subquery_lens.append(prompt_len-prefix_len) @@ -301,7 +303,7 @@ def _prepare_inputs( for mapping in slot_mapping ] block_tables = generation_block_tables if prefix_block_tables == [] else prefix_block_tables - print("block_tables", block_tables) + # print("block_tables", block_tables) padded_block_tables = [ _pad_to_max(block_table, max_num_blocks_per_seq, pad=0) for block_table in block_tables From 2309330ffe03ddbc515f0ad4876d6481b572dcbd Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sat, 18 Nov 2023 07:21:33 +0000 Subject: [PATCH 10/37] minor --- vllm/prefix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/prefix.py b/vllm/prefix.py index 508ee2fdf4da9..e80acdad0b8e1 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -89,7 +89,7 @@ def efficient_search(self, token_ids: List[int]): def fixed_search(self, prefix_hash): if prefix_hash not in self.prefixes_hash: return None - print("Found prefix in the pool.") + # print("Found prefix in the pool.") prefix_id = self.prefixes_hash[prefix_hash] return self.prefixes[prefix_id] From d8b2809788590ce62c835d49afcc3f11a8181e16 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Tue, 21 Nov 2023 09:28:14 +0000 Subject: [PATCH 11/37] fix blocktable padding --- vllm/worker/worker.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f22a67844d070..56d0f415424de 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -164,6 +164,7 @@ def _prepare_inputs( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] + max_num_blocks_per_seq_prompt = 0 for seq_group_metadata in seq_group_metadata_list: if not seq_group_metadata.is_prompt: continue @@ -184,7 +185,9 @@ def _prepare_inputs( prefix_len = seq_group_metadata.prefix.get_length() assert prefix_len % self.block_size == 0 prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_tables.append(seq_group_metadata.prefix.get_block_table_num()) + prefix_block_table = seq_group_metadata.prefix.get_block_table_num() + prefix_block_tables.append(prefix_block_table) + max_num_blocks_per_seq_prompt = max(max_num_blocks_per_seq_prompt, len(prefix_block_table)) else: prefix_block_tables.append([]) # actual prompt lens @@ -303,7 +306,7 @@ def _prepare_inputs( for mapping in slot_mapping ] block_tables = generation_block_tables if prefix_block_tables == [] else prefix_block_tables - # print("block_tables", block_tables) + max_num_blocks_per_seq = max_num_blocks_per_seq if prefix_block_tables == [] else max_num_blocks_per_seq_prompt padded_block_tables = [ _pad_to_max(block_table, max_num_blocks_per_seq, pad=0) for block_table in block_tables From 50e74972ae534639b305e4f47efbce033ee17bbf Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Mon, 1 Jan 2024 09:04:59 +0000 Subject: [PATCH 12/37] fix multi-gpu state transition --- vllm/engine/llm_engine.py | 6 ++++++ vllm/worker/worker.py | 6 +----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 
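# Worked illustration of the block-table padding fix above: during a prompt run the prefix
# block tables must be padded to the longest *prefix* table in the batch, not to the
# decode-time max_num_blocks_per_seq. Hypothetical values; _pad_to_max is redefined here
# only for clarity.
def _pad_to_max(x, max_len, pad):
    return x + [pad] * (max_len - len(x))

prefix_block_tables = [[7, 8], [3], []]                                    # one table per prompt sequence
max_num_blocks_per_seq_prompt = max(len(t) for t in prefix_block_tables)   # -> 2
padded_block_tables = [
    _pad_to_max(t, max_num_blocks_per_seq_prompt, pad=0)
    for t in prefix_block_tables
]                                                                          # [[7, 8], [3, 0], [0, 0]]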
a788adb306712..c889354c77657 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -555,6 +555,12 @@ def _process_model_outputs( request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) + # update prefix state + for seq_group in scheduled_seq_groups: + if seq_group.prefix is not None and seq_group.prefix.swap_to_gpu: + seq_group.prefix.on_gpu = True + seq_group.prefix.swap_to_gpu = False + if self.log_stats: # Log the system stats. self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 56d0f415424de..d25ac149b1f5c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -243,11 +243,7 @@ def _prepare_inputs( selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len - - # set the prefix state - if seq_group_metadata.prefix is not None and seq_group_metadata.prefix.swap_to_gpu: - seq_group_metadata.prefix.on_gpu = True - seq_group_metadata.prefix.swap_to_gpu = False + continue seq_ids = list(seq_group_metadata.seq_data.keys()) From 33bfcffb741dba817a723325debeecc8817ab8f3 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Tue, 2 Jan 2024 02:25:39 +0000 Subject: [PATCH 13/37] clean --- vllm/core/block_manager.py | 3 +- vllm/core/scheduler.py | 2 +- .../layers/triton_kernel/benchmark_utils.py | 102 ----------- .../layers/triton_kernel/prefix_prefill.py | 170 ------------------ vllm/prefix.py | 84 ++++----- 5 files changed, 42 insertions(+), 319 deletions(-) delete mode 100644 vllm/model_executor/layers/triton_kernel/benchmark_utils.py diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 633c84dc1e068..3a410aadadb58 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -5,7 +5,7 @@ from vllm.block import PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from vllm.prefix import PrefixPool, Prefix +from vllm.prefix import Prefix # Mapping: logical block number -> physical block. BlockTable = List[PhysicalTokenBlock] @@ -279,6 +279,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: } return block_number_mapping + # currently not used def swap_in_prefix(self, prefix: Prefix) -> Dict[int, int]: # CPU block -> GPU block. 
mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 843601fd7dadd..12b9740a71682 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -196,7 +196,7 @@ def _schedule(self) -> SchedulerOutputs: if seq_group.prefix is not None and seq_group.prefix.on_cpu: # prefix.on_gpu will be set inside this function self._swap_in_prefix(seq_group.prefix, blocks_to_swap_in) - # if the prefix hasn't been compuated, allocate blocks for it and set prefix.swap_to_gpu to True + self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs diff --git a/vllm/model_executor/layers/triton_kernel/benchmark_utils.py b/vllm/model_executor/layers/triton_kernel/benchmark_utils.py deleted file mode 100644 index f22df2a89769d..0000000000000 --- a/vllm/model_executor/layers/triton_kernel/benchmark_utils.py +++ /dev/null @@ -1,102 +0,0 @@ -import abc -import dataclasses -import gc -import itertools -import time -from typing import Callable - -import numpy as np -import torch - - -class Benchmark(abc.ABC): - - def setup(self): - pass - - def before_run(self): - pass - - @abc.abstractmethod - def run(self): - pass - - def after_run(self): - pass - - def teardown(self): - pass - - -class wrap_benchmark(Benchmark): - - def __init__(self, fn_run: Callable[[], None]): - self.fn_run = fn_run - - def run(self): - self.fn_run() - - -@dataclasses.dataclass -class BenchResult: - warmup: int - repeat: int - latency: np.ndarray - - def avg(self) -> np.ndarray: - return np.mean(self.latency) - - def std(self) -> np.ndarray: - return np.std(self.latency) - - def avg_std(self) -> np.ndarray: - return self.avg(), self.std() - - -def bench( - f: Callable[[], None], - warmup: int = 100, - repeat: int = 500, -) -> BenchResult: - if isinstance(f, Benchmark): - b = f - else: - b = wrap_benchmark(f) - - cache = torch.empty(256 * 2**20, dtype=torch.int8, device="cuda:0") - b.setup() - - latency = np.zeros(repeat, dtype=np.float64) - for i in range(-warmup, repeat): - b.before_run() - cache.zero_() - - torch.cuda.synchronize() - t0 = time.perf_counter() - b.run() - torch.cuda.synchronize() - t1 = time.perf_counter() - - b.after_run() - - if i >= 0: - latency[i] = t1 - t0 - - b.teardown() - del cache - return BenchResult(warmup, repeat, latency) - - -def gc_torch(): - gc.collect() - torch.cuda.empty_cache() - - -def batched(iterable, n): - "Batch data into tuples of length n. The last batch may be shorter." 
- # batched('ABCDEFG', 3) --> ABC DEF G - if n < 1: - raise ValueError('n must be at least one') - it = iter(iterable) - while batch := list(itertools.islice(it, n)): - yield batch diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index 6e2ea15623332..d4586a96f936e 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -1,13 +1,7 @@ import torch import time -import itertools import triton import triton.language as tl -import math -import torch.nn.functional as F -import matplotlib.pyplot as plt - -from .benchmark_utils import bench, gc_torch if triton.__version__ >= "2.1.0": @triton.jit @@ -299,168 +293,4 @@ def test_contexted_kv_attention( print(output_ref[0,0,:10]) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) -@torch.inference_mode() -def bench_contexted_kv_attention( - num_heads: int, - head_size: int, - dtype: torch.dtype, -) -> None: - import random - random.seed(0) - torch.manual_seed(0) - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask - - # seq_len = [16, 64, 128, 256, 512, 1024] - # ctx_len = [16, 64, 128, 256, 512, 1024, 2048] - seq_len = [16, 64, 128, 256, 512, 1024] - ctx_len = [256, 512, 1024, 2048] - BS = 20 - timings_triton = {} - timings_xformer = {} - for MAX_SEQ_LEN, MAX_CTX_LEN in itertools.product(seq_len, ctx_len): - gc_torch() - # MAX_SEQ_LEN = 1024 - # MAX_CTX_LEN = 2048 - outputs = [ - f"seq_len={MAX_SEQ_LEN}", - f"ctx_len={MAX_CTX_LEN}", - f"bs={BS}" - ] - cache_size = 40960 - block_size = 1 - max_block_per_request = 2048 - subquery_lens = [random.randint(MAX_SEQ_LEN, MAX_SEQ_LEN) for _ in range(BS)] - ctx_lens = [random.randint(MAX_CTX_LEN, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] - - num_tokens = sum(subquery_lens) - query = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') - query.uniform_(-1e-3, 1e-3) - output = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') - - - kv = torch.empty(sum(seq_lens), - 2, - num_heads, - head_size, - dtype=dtype, - device='cuda') - kv.uniform_(-1e-3, 1e-3) - key,value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') - v_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') - k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') - v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') - values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') - values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view(BS, max_block_per_request) - b_loc = torch.zeros(BS, MAX_CTX_LEN, dtype=torch.long, device='cuda') - b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], dtype=torch.long, device='cuda'), dim=0) - max_input_len = MAX_SEQ_LEN - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], dtype=torch.long, device='cuda'), dim=0) - for i in range(BS): - for j in range(subquery_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) - v[b_start_loc[i] 
+ j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i,block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(key[start_loc:end_loc]) - v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_heads, head_size//8, 8).permute(0, 2, 3, 1, 4).contiguous() - # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_heads, head_size).permute(0, 2, 3, 1).contiguous() - - - context_attention_fwd(query, k, v, output, - k_cache, v_cache, block_table, - b_start_loc, b_seq_len, - b_ctx_len, max_input_len) - torch.cuda.synchronize() - start_time = time.time() - context_attention_fwd(query, k, v, output, - k_cache, v_cache, block_table, - b_start_loc, b_seq_len, - b_ctx_len, max_input_len) - torch.cuda.synchronize() - end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") - - scale = float(1.0 / (head_size**0.5)) - - attn_op = xops.fmha.cutlass.FwOp() - - attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(subquery_lens, seq_lens) - output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - op=attn_op, - ) - torch.cuda.synchronize() - start_time = time.time() - output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - op=attn_op, - ) - torch.cuda.synchronize() - end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - output_ref = output_ref.squeeze(0) - # print(output_ref.shape) - # print("max ", torch.max(torch.abs(output_ref - output))) - # print("mean ", torch.mean(torch.abs(output_ref - output))) - # print(output[0,0,:10]) - # print(output_ref[0,0,:10]) - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) - - result = bench(lambda: xops.memory_efficient_attention_forward(query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - op=attn_op)) - outputs.append(f"\n xformer: {result.avg()*1e6:3.0f}us±{result.std()*1e6:3.0f}us") - timings_xformer[(MAX_SEQ_LEN, MAX_CTX_LEN)] = result.avg()*1e6 - result2 = bench(lambda: context_attention_fwd(query, k, v, output, - k_cache, v_cache, block_table, - b_start_loc, b_seq_len, - b_ctx_len, max_input_len)) - outputs.append(f"\n triton: {result2.avg()*1e6:3.0f}us±{result2.std()*1e6:3.0f}us") - timings_triton[(MAX_SEQ_LEN, MAX_CTX_LEN)] = result2.avg()*1e6 - print(" | ".join(outputs)) - - # test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file diff --git a/vllm/prefix.py b/vllm/prefix.py index e80acdad0b8e1..c07d1580bffe9 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,84 +1,78 @@ -from typing import Dict, List, Optional, Union - -# Define the prefix class, which is a 
collection of prefix (a sequence of tokens). -# The class contains the following main methods: -# 1. A match method that checks if a prefix matches a given sequence of tokens. -# 2. A swapping method that can load or offload the prefix to or from GPU -# 3. An update_frequency method that updates the frequency of the prefix. -# 4. A get_status method that tells if the prefix is on GPU or not. - +from typing import List, Optional class Prefix: - def __init__(self, prefix_id, token_ids, block_size): + """Data and states associated with a prefix of prompt tokens for multiple sequence groups. + + Args: + prefix_id: The id of the prefix in the prefix pool. + token_ids: The token ids of the prefix. + block_size: The block size of the executed model. + + Attributes: + on_gpu: True if the prefix will be on GPU before the execution of the model. + on_cpu: True if the prefix is on CPU. + swap_to_gpu: True when the prefix will be computed during the execution of the model. + """ + def __init__( + self, + prefix_id: int, + token_ids: List[int], + block_size: int, + ) -> None: self.prefix_id = prefix_id self.token_ids = token_ids self.length = len(token_ids) - print("prefix length: ", self.length) - print("block size: ", block_size) assert self.length % block_size == 0 self.on_gpu = False self.on_cpu = False self.block_table = None # a lock to prevent multiple sequence from calculating the same prefix self.swap_to_gpu = False - - # freq-related - self.freq = 1 - self.alpha = 0.8 - self.beta = 0.5 def get_block_table_num(self) -> List[int]: return [block.block_number for block in self.block_table] - def match(self, tokens): + def match(self, tokens: List[int]) -> bool: return tokens[:self.length] == self.token_ids - - # should be called if the prefix is hit for this iteration - def update_freq(self, new_hit_rate): - self.freq = self.alpha * self.freq + (1 - self.alpha) * new_hit_rate - self.alpha = 0.8 - - # should be called if the prefix is not hit for this iteration - def punish_freq(self): - self.alpha = self.beta * self.alpha if self.alpha > 0.1 else 0.1 # whether the prefix is on GPU or not - def get_status(self): + def get_status(self) -> bool: return self.on_gpu - def get_length(self): + def get_length(self) -> int: return self.length - -# Define the prefix pool class, which is a collection of prefixes. -# The class contains the following main methods: -# 1. add a prefix to the pool, with a computed hash -# 2. TODO: create subprefix, if one is a prefix of the other: they can share some memory blocks -# 3. efficient_search: given a sequence of tokens, find the longest prefix in the pool that matches the sequence -# 4. fixed_search: given the prefix's hash, find the prefix in the pool -# 5. TODO: approximate_search: given a sequence of tokens, find the similar prefixes in the pool - - class PrefixPool: - def __init__(self, block_size): + """Manages all the prompt prefixes. + + Args: + block_size: The block size of the executed model. + + Attributes: + prefixes: A list of all the prefixes. + prefixes_hash: Mapping from the hash of the prefix to the prefix id. + block_size: The block size of the executed model. 
+ """ + def __init__( + self, + block_size: int, + ) -> None: self.prefixes = [] self.prefixes_hash = {} self.block_size = block_size - def add_prefix(self, token_ids: List[int]): + def add_prefix(self, token_ids: List[int]) -> Prefix: # generate prefix_id prefix_id = len(self.prefixes) # create a new prefix prefix = Prefix(prefix_id, token_ids, self.block_size) self.prefixes.append(prefix) - # @TODO: compute the hash of the prefix prefix_hash = hash(tuple(prefix.token_ids)) - # self.prefixes_hash[prefix.prefix_id] = prefix_hash self.prefixes_hash[prefix_hash] = prefix.prefix_id return prefix # @TODO: this one should also come with a method to identify the prefix - def efficient_search(self, token_ids: List[int]): + def efficient_search(self, token_ids: List[int]) -> Optional[Prefix]: # improve this search for prefix in self.prefixes: if prefix.match(token_ids): @@ -86,7 +80,7 @@ def efficient_search(self, token_ids: List[int]): return None # use this first, if we already know from the application which part of the tokens are prefix. - def fixed_search(self, prefix_hash): + def fixed_search(self, prefix_hash : int) -> Optional[Prefix]: if prefix_hash not in self.prefixes_hash: return None # print("Found prefix in the pool.") From 1cef69fe6b89366ca525f20f3506efc2e2bbdd36 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Tue, 2 Jan 2024 02:32:58 +0000 Subject: [PATCH 14/37] format --- examples/api_client.py | 4 +- vllm/core/block_manager.py | 17 +- vllm/core/scheduler.py | 6 +- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/api_server.py | 3 +- vllm/entrypoints/llm.py | 7 +- vllm/model_executor/layers/attention.py | 8 +- .../layers/triton_kernel/prefix_prefill.py | 289 +++++++++++++----- vllm/prefix.py | 23 +- vllm/worker/model_runner.py | 21 +- 11 files changed, 265 insertions(+), 121 deletions(-) diff --git a/examples/api_client.py b/examples/api_client.py index 5fecd739af4f1..f554448ddbd1e 100644 --- a/examples/api_client.py +++ b/examples/api_client.py @@ -54,7 +54,9 @@ def get_response(response: requests.Response) -> List[str]: parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--n", type=int, default=4) - parser.add_argument("--prompt", type=str, default="San Francisco is a "*32) + parser.add_argument("--prompt", + type=str, + default="San Francisco is a " * 32) parser.add_argument("--stream", action="store_true") args = parser.parse_args() prompt = args.prompt diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3a410aadadb58..69f03140b0e4f 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -108,7 +108,8 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: num_required_blocks = len(seq.logical_token_blocks) if seq_group.prefix is not None and seq_group.prefix.on_gpu: - num_required_blocks -= seq_group.prefix.get_length() // self.block_size + num_required_blocks -= seq_group.prefix.get_length( + ) // self.block_size if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -138,15 +139,17 @@ def allocate(self, seq_group: SequenceGroup) -> None: if seq_group.prefix is not None: # prefix is already on gpu or will be swapped in before the actual computation if seq_group.prefix.on_gpu: - num_prompt_blocks -= seq_group.prefix.get_length() // self.block_size + num_prompt_blocks -= seq_group.prefix.get_length( + ) // self.block_size for block in seq_group.prefix.block_table: 
block.ref_count += seq_group.num_seqs() block_table.append(block) # TODO: will need to perform the copy-on-write if prefix length is not a multiple of block size - + # allocate blocks for the prefix, we need to calculate the prefix's kv in this run elif not seq_group.prefix.swap_to_gpu: - num_prefix_blocks = seq_group.prefix.get_length() // self.block_size + num_prefix_blocks = seq_group.prefix.get_length( + ) // self.block_size seq_group.prefix.swap_to_gpu = True for logical_idx in range(num_prompt_blocks): @@ -165,7 +168,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Assign the block table for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - + if num_prefix_blocks > 0: seq_group.prefix.block_table = prefix_block_table.copy() @@ -250,7 +253,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. if seq_group.prefix is not None: # make sure to swap in the prefix first - assert seq_group.prefix.on_gpu == True + assert seq_group.prefix.on_gpu is True mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): @@ -338,7 +341,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: for gpu_block, cpu_block in mapping.items() } return block_number_mapping - + def swap_out_prefix(self, prefix: Prefix) -> Dict[int, int]: # GPU block -> CPU block. # make sure all the reference seq are finished or swapped out before swapping out the prefix diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 12b9740a71682..b794f529dd039 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -75,7 +75,7 @@ def __init__( num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, sliding_window=self.cache_config.sliding_window) - + self.prefix_pool = PrefixPool(self.cache_config.block_size) # TODO(zhuohan): Use deque instead of list for better performance. 
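# Worked example (hypothetical numbers) of the block accounting above: when a prefix is
# already on the GPU, its physical blocks are shared (ref_count bumped per sequence) and
# only the remaining prompt blocks need to be newly allocated.
block_size = 16
prompt_len = 48                                   # -> 3 logical token blocks
prefix_len = 32                                   # cached prefix covers the first 2 blocks
num_prompt_blocks = prompt_len // block_size      # 3
num_prompt_blocks -= prefix_len // block_size     # only 1 block left to allocate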
@@ -196,7 +196,7 @@ def _schedule(self) -> SchedulerOutputs: if seq_group.prefix is not None and seq_group.prefix.on_cpu: # prefix.on_gpu will be set inside this function self._swap_in_prefix(seq_group.prefix, blocks_to_swap_in) - + self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs @@ -420,7 +420,7 @@ def _swap_out( blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED - + def _swap_in_prefix( self, prefix: Prefix, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index caa13490fdf9e..94e7cd4898341 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -404,7 +404,7 @@ async def add_request( async def generate( self, prompt: Optional[str], - prefix_pos: Optional[int], + prefix_pos: Optional[int], sampling_params: SamplingParams, request_id: str, prompt_token_ids: Optional[List[int]] = None diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4cf7c28ed0fb3..5a6955b359b50 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -307,14 +307,16 @@ def add_request( prefix_pos = (prefix_pos // block_size) * block_size if prefix_pos > 0: truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] - prefix = self.scheduler.prefix_pool.fixed_search(hash(tuple(truncated_prefix_token_ids))) + prefix = self.scheduler.prefix_pool.fixed_search( + hash(tuple(truncated_prefix_token_ids))) if prefix is not None: seq.prefix = prefix # print("prefix status: ", "on gpu" if prefix.get_status() else "on cpu") # prefix.update_freq(1.0) else: # create a new prefix - seq.prefix = self.scheduler.prefix_pool.add_prefix(truncated_prefix_token_ids) + seq.prefix = self.scheduler.prefix_pool.add_prefix( + truncated_prefix_token_ids) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 35068bb10d6a8..0d7679968cf3a 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -39,7 +39,8 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, prefix_pos, sampling_params, request_id) + results_generator = engine.generate(prompt, prefix_pos, sampling_params, + request_id) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 09d638eb4c2b0..9f5fbbd876123 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -174,8 +174,11 @@ def _add_request( prefix_pos: Optional[int], ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids, prefix_pos=prefix_pos) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + prefix_pos=prefix_pos) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
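# With the api_server change above, a client can pass prefix_pos alongside the usual sampling
# fields. A hypothetical request (host, port and values are placeholders; the remaining JSON
# fields are forwarded to SamplingParams):
import json
import requests

payload = {
    "prompt": "San Francisco is a " * 32,
    "prefix_pos": 100,        # leading tokens to treat as a reusable prefix
    "n": 1,
    "temperature": 0.0,
    "max_tokens": 16,
}
resp = requests.post("http://localhost:8000/generate", json=payload)
print(json.loads(resp.content)["text"])   # parsed the same way as get_response() in examples/api_client.py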
diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b81bef01128a5..d8e46e3834547 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -117,7 +117,8 @@ def forward( self.num_queries_per_kv, value.shape[-1]) # normal attention - if key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0: + if key_cache is None or value_cache is None or input_metadata.block_tables.numel( + ) == 0: # Set attention bias if not provided. This typically happens at the # very attention layer of every iteration. # FIXME(woosuk): This is a hack. @@ -166,12 +167,11 @@ def forward( output, key_cache, value_cache, - input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.block_tables, # [BS, max_block_per_request] input_metadata.start_loc, input_metadata.prompt_lens_tensor, input_metadata.context_lens, - input_metadata.max_seq_len - ) + input_metadata.max_seq_len) # TODO: add support for Alibi bias else: diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index d4586a96f936e..34afd3a4ff448 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -4,18 +4,47 @@ import triton.language as tl if triton.__version__ >= "2.1.0": + @triton.jit def _fwd_kernel( - Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen, B_Ctxlen, block_size, x, + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, Out, - stride_b_loc_b, stride_b_loc_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d, stride_k_cache_bl, stride_k_cache_x, - stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) @@ -32,9 +61,14 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_q = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ + None, :] * stride_qd - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) # # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -44,14 +78,29 @@ def _fwd_kernel( for start_n in range(0, cur_batch_ctx_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < 
cur_batch_ctx_len, other=0) - off_k = bn[None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + (offs_d[:, None] // x) * stride_k_cache_d + ((start_n + offs_n[None,:]) % block_size) * stride_k_cache_bl + (offs_d[:, None] % x) * stride_k_cache_x - off_v = bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[None,:] * stride_v_cache_d + (start_n + offs_n[:,None]) % block_size * stride_v_cache_bl - k = tl.load(K_cache + off_k, mask=(start_n + offs_n[None,:]) < cur_batch_ctx_len, other=0.0) + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = bn[ + None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( + offs_d[:, None] // x) * stride_k_cache_d + ( + (start_n + offs_n[None, :]) % + block_size) * stride_k_cache_bl + ( + offs_d[:, None] % x) * stride_k_cache_x + off_v = bn[:, + None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ + None, :] * stride_v_cache_d + ( + start_n + offs_n[:, None] + ) % block_size * stride_v_cache_bl + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) - qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) qk *= sm_scale # -- compute m_ij, p, l_ij @@ -71,32 +120,42 @@ def _fwd_kernel( acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(V_cache + off_v, mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, other=0.0) + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) p = p.to(v.dtype) acc += tl.dot(p, v) # # update m_i and l_i l_i = l_i_new m_i = m_i_new - - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + + off_k = offs_n[ + None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, + None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ + None, :] * stride_vd k_ptrs = K + off_k v_ptrs = V + off_v - block_mask = tl.where(block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) @@ -115,8 +174,11 @@ def _fwd_kernel( acc_scale = l_i / l_i_new * alpha acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - 
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) p = p.to(v.dtype) acc += tl.dot(p, v) @@ -124,13 +186,18 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ + None, :] * stride_od out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) return @torch.inference_mode() - def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, b_seq_len, b_ctx_len, max_input_len): + def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, + b_seq_len, b_ctx_len, max_input_len): BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -144,15 +211,44 @@ def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, b_se num_warps = 8 if Lk <= 64 else 8 _fwd_kernel[grid]( - q, k, v, k_cache, v_cache, b_loc, sm_scale, b_start_loc, b_seq_len, b_ctx_len, v_cache.shape[3], 8, + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + 8, o, - b_loc.stride(0), b_loc.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - k_cache.stride(0), k_cache.stride(1), k_cache.stride(2), k_cache.stride(3), k_cache.stride(4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] - v_cache.stride(0), v_cache.stride(1), v_cache.stride(2), v_cache.stride(3), #[num_blocks, num_kv_heads, head_size, block_size] + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK, @@ -161,6 +257,7 @@ def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, b_se ) return + @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, @@ -181,48 +278,73 @@ def test_contexted_kv_attention( subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] - + num_tokens = sum(subquery_lens) query = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') + num_heads, + head_size, + dtype=dtype, + device='cuda') query.uniform_(-1e-3, 1e-3) output = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + + kv = torch.empty(sum(seq_lens), + 2, + num_heads, + head_size, + dtype=dtype, + device='cuda') + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = 
torch.zeros(cache_size, + block_size, + num_heads, + head_size, + dtype=dtype, + device='cuda') + v_cache = torch.zeros(cache_size, + block_size, + num_heads, + head_size, + dtype=dtype, + device='cuda') + k = torch.zeros(sum(subquery_lens), + num_heads, + head_size, + dtype=dtype, + device='cuda') + v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') - - - kv = torch.empty(sum(seq_lens), - 2, - num_heads, - head_size, - dtype=dtype, - device='cuda') - kv.uniform_(-1e-3, 1e-3) - key,value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') - v_cache = torch.zeros(cache_size, block_size, num_heads, head_size, dtype=dtype, device='cuda') - k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') - v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype, device='cuda') values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view(BS, max_block_per_request) - b_loc = torch.zeros(BS, MAX_CTX_LEN, dtype=torch.long, device='cuda') + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], dtype=torch.long, device='cuda'), dim=0) + b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + dtype=torch.long, + device='cuda'), + dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], dtype=torch.long, device='cuda'), dim=0) + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long, + device='cuda'), + dim=0) for i in range(BS): for j in range(subquery_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + b_ctx_len[i] + j]) + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) cur_ctx = 0 block_id = 0 while cur_ctx < b_ctx_len[i]: @@ -231,28 +353,27 @@ def test_contexted_kv_attention( end_loc = b_seq_start_loc[i] + b_ctx_len[i] else: end_loc = start_loc + block_size - start_slot = block_table[i,block_id] * block_size + start_slot = block_table[i, block_id] * block_size end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(key[start_loc:end_loc]) - v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_(value[start_loc:end_loc]) + k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) cur_ctx += block_size block_id += 1 # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_heads, head_size//8, 8).permute(0, 2, 3, 1, 4).contiguous() + k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_heads, 
head_size).permute(0, 2, 3, 1).contiguous() - - - context_attention_fwd(query, k, v, output, - k_cache, v_cache, block_table, - b_start_loc, b_seq_len, - b_ctx_len, max_input_len) + v_cache = v_cache.view(-1, block_size, num_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, + b_start_loc, b_seq_len, b_ctx_len, max_input_len) torch.cuda.synchronize() start_time = time.time() - context_attention_fwd(query, k, v, output, - k_cache, v_cache, block_table, - b_start_loc, b_seq_len, - b_ctx_len, max_input_len) + context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, + b_start_loc, b_seq_len, b_ctx_len, max_input_len) torch.cuda.synchronize() end_time = time.time() print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") @@ -260,8 +381,9 @@ def test_contexted_kv_attention( scale = float(1.0 / (head_size**0.5)) attn_op = xops.fmha.cutlass.FwOp() - - attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens(subquery_lens, seq_lens) + + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( + subquery_lens, seq_lens) output_ref = xops.memory_efficient_attention_forward( query.unsqueeze(0), key.unsqueeze(0), @@ -289,8 +411,9 @@ def test_contexted_kv_attention( print(output_ref.shape) print("max ", torch.max(torch.abs(output_ref - output))) print("mean ", torch.mean(torch.abs(output_ref - output))) - print(output[0,0,:10]) - print(output_ref[0,0,:10]) + print(output[0, 0, :10]) + print(output_ref[0, 0, :10]) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) -# test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file + +# test_contexted_kv_attention(12, 128, torch.float16) diff --git a/vllm/prefix.py b/vllm/prefix.py index c07d1580bffe9..a16426a9d561f 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,5 +1,6 @@ from typing import List, Optional + class Prefix: """Data and states associated with a prefix of prompt tokens for multiple sequence groups. @@ -13,6 +14,7 @@ class Prefix: on_cpu: True if the prefix is on CPU. swap_to_gpu: True when the prefix will be computed during the execution of the model. """ + def __init__( self, prefix_id: int, @@ -28,20 +30,21 @@ def __init__( self.block_table = None # a lock to prevent multiple sequence from calculating the same prefix self.swap_to_gpu = False - + def get_block_table_num(self) -> List[int]: return [block.block_number for block in self.block_table] - + def match(self, tokens: List[int]) -> bool: return tokens[:self.length] == self.token_ids - + # whether the prefix is on GPU or not def get_status(self) -> bool: return self.on_gpu - + def get_length(self) -> int: return self.length + class PrefixPool: """Manages all the prompt prefixes. @@ -53,14 +56,15 @@ class PrefixPool: prefixes_hash: Mapping from the hash of the prefix to the prefix id. block_size: The block size of the executed model. 
""" + def __init__( - self, + self, block_size: int, ) -> None: self.prefixes = [] self.prefixes_hash = {} self.block_size = block_size - + def add_prefix(self, token_ids: List[int]) -> Prefix: # generate prefix_id prefix_id = len(self.prefixes) @@ -70,7 +74,7 @@ def add_prefix(self, token_ids: List[int]) -> Prefix: prefix_hash = hash(tuple(prefix.token_ids)) self.prefixes_hash[prefix_hash] = prefix.prefix_id return prefix - + # @TODO: this one should also come with a method to identify the prefix def efficient_search(self, token_ids: List[int]) -> Optional[Prefix]: # improve this search @@ -78,12 +82,11 @@ def efficient_search(self, token_ids: List[int]) -> Optional[Prefix]: if prefix.match(token_ids): return prefix return None - + # use this first, if we already know from the application which part of the tokens are prefix. - def fixed_search(self, prefix_hash : int) -> Optional[Prefix]: + def fixed_search(self, prefix_hash: int) -> Optional[Prefix]: if prefix_hash not in self.prefixes_hash: return None # print("Found prefix in the pool.") prefix_id = self.prefixes_hash[prefix_hash] return self.prefixes[prefix_id] - diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 11ba929dcc8d6..c2397b47f01d2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -96,19 +96,22 @@ def _prepare_prompt( prefix_len = seq_group_metadata.prefix.get_length() assert prefix_len % self.block_size == 0 prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_table = seq_group_metadata.prefix.get_block_table_num() + prefix_block_table = seq_group_metadata.prefix.get_block_table_num( + ) prefix_block_tables.append(prefix_block_table) - max_num_blocks_per_seq_prompt = max(max_num_blocks_per_seq_prompt, len(prefix_block_table)) + max_num_blocks_per_seq_prompt = max( + max_num_blocks_per_seq_prompt, len(prefix_block_table)) else: prefix_block_tables.append([]) # actual prompt lens context_lens.append(prefix_len) - subquery_lens.append(prompt_len-prefix_len) + subquery_lens.append(prompt_len - prefix_len) input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. 
- input_positions.append(list(range(prefix_len,prefix_len+len(prompt_tokens)))) + input_positions.append( + list(range(prefix_len, prefix_len + len(prompt_tokens)))) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized @@ -152,8 +155,8 @@ def _prepare_prompt( pad=_PAD_SLOT_ID, dtype=torch.long) context_lens_tensor = torch.tensor(context_lens, - dtype=torch.int, - device='cuda') + dtype=torch.int, + device='cuda') # prefix block tables block_tables = _make_tensor_with_pad( @@ -163,7 +166,11 @@ def _prepare_prompt( dtype=torch.int, ) - start_loc_tensor = torch.arange(0, len(prompt_lens)*max_prompt_len, max_prompt_len, dtype=torch.long, device='cuda') + start_loc_tensor = torch.arange(0, + len(prompt_lens) * max_prompt_len, + max_prompt_len, + dtype=torch.long, + device='cuda') input_metadata = InputMetadata( prompt_lens=prompt_lens, From 98ee509178149118957e8141e09829463ed381df Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Tue, 2 Jan 2024 03:18:40 +0000 Subject: [PATCH 15/37] add support for alibi bias Co-authored-by: DouHappy <2278958187@qq.com> --- vllm/model_executor/layers/attention.py | 5 +- .../layers/triton_kernel/prefix_prefill.py | 495 +++++++++++++++++- 2 files changed, 496 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index d8e46e3834547..d86645bac8c9b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -171,8 +171,9 @@ def forward( input_metadata.start_loc, input_metadata.prompt_lens_tensor, input_metadata.context_lens, - input_metadata.max_seq_len) - # TODO: add support for Alibi bias + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) else: # Decoding run. 
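# The strides passed to the kernel assume the paged-cache layouts produced by the transposes
# in the test above. A standalone sketch with small illustrative sizes (x = 8 matches the
# value hard-coded at the kernel launch):
import torch

num_blocks, block_size, num_heads, head_size, x = 4, 16, 2, 64, 8
k = torch.randn(num_blocks, block_size, num_heads, head_size)
v = torch.randn(num_blocks, block_size, num_heads, head_size)
# key cache:   [num_blocks, num_kv_heads, head_size // x, block_size, x]
k_cache = k.view(num_blocks, block_size, num_heads, head_size // x, x) \
           .permute(0, 2, 3, 1, 4).contiguous()
# value cache: [num_blocks, num_kv_heads, head_size, block_size]
v_cache = v.view(num_blocks, block_size, num_heads, head_size) \
           .permute(0, 2, 3, 1).contiguous()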
diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index 34afd3a4ff448..b369b53817827 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -195,9 +195,450 @@ def _fwd_kernel( mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) return + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ + None, :] * stride_qd + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = bn[ + None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( + offs_d[:, None] // x) * stride_k_cache_d + ( + (start_n + offs_n[None, :]) % + block_size) * stride_k_cache_bl + ( + offs_d[:, None] % x) * stride_k_cache_x + off_v = bn[:, + None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ + None, :] * stride_v_cache_d + ( + start_n + offs_n[:, None] + ) % block_size * stride_v_cache_bl + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + 
off_k = offs_n[ + None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, + None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ + None, :] * stride_vd + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ + None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + # debuger, + # stride_db_head, stride_db_q, stride_db_k, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ + None, :] * stride_qd + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) 
- float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = bn[ + None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( + offs_d[:, None] // x) * stride_k_cache_d + ( + (start_n + offs_n[None, :]) % + block_size) * stride_k_cache_bl + ( + offs_d[:, None] % x) * stride_k_cache_x + off_v = bn[:, + None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ + None, :] * stride_v_cache_d + ( + start_n + offs_n[:, None] + ) % block_size * stride_v_cache_bl + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # # debuger alibi + # offset_db = stride_db_head * cur_head + offset_db_q[:, None] * stride_db_q + offset_db_k[None, :] * stride_db_k + # mask_db = (offset_db_q < cur_batch_seq_len - cur_batch_ctx_len)[:, None] & (offset_db_k < cur_batch_seq_len)[None, :] + # tl.store(debuger + offset_db, alibi, mask=mask_db) + # offset_db_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = offs_n[ + None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, + None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ + None, :] * stride_vd + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debuger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + # mask = tl.load(mask_ptrs + start_n, 
mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # # debuger alibi + # offset_db = stride_db_head * cur_head + offset_db_q[:, None] * stride_db_q + offset_db_k[None, :] * stride_db_k + # mask_db = (offset_db_q < cur_batch_seq_len - cur_batch_ctx_len)[:, None] & (offset_db_k < cur_batch_seq_len)[None, :] + # tl.store(debuger + offset_db, alibi, mask=mask_db) + # offset_db_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ + None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + @torch.inference_mode() - def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, - b_seq_len, b_ctx_len, max_input_len): + def context_attention_fwd(q, + k, + v, + o, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=None): BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -210,6 +651,56 @@ def context_attention_fwd(q, k, v, o, k_cache, v_cache, b_loc, b_start_loc, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, num_warps = 8 if Lk <= 64 else 8 + if alibi_slopes is not None: + _fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + _fwd_kernel[grid]( q, k, From a948cd35d7580ec4c4b47228f8082192b11d8847 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Fri, 5 Jan 2024 08:10:06 +0000 Subject: [PATCH 16/37] format --- vllm/worker/model_runner.py | 23 ++++++++++++----------- 1 file changed, 
12 insertions(+), 11 deletions(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 02d6c146bbf3c..2365abb19b482 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -176,8 +176,8 @@ def _prepare_prompt( dtype=torch.long, device='cuda') prompt_lens_tensor = torch.tensor(prompt_lens, - dtype=torch.long, - device='cuda') + dtype=torch.long, + device='cuda') input_metadata = InputMetadata( is_prompt=True, @@ -381,14 +381,15 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, input_metadata, - prompt_lens, subquery_lens) = self._prepare_prompt(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata, subquery_lens - ) = self._prepare_decode(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, + subquery_lens) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens, subquery_lens) + prompt_lens, + subquery_lens) def get_size_or_none(x: Optional[torch.Tensor]): return x.size() if x is not None else None @@ -458,15 +459,15 @@ def get_size_or_none(x: Optional[torch.Tensor]): slot_mapping = None if py_data["prompt_lens_size"] is not None: prompt_lens = torch.empty(*py_data["prompt_lens_size"], - dtype=torch.long, - device="cuda") + dtype=torch.long, + device="cuda") broadcast(prompt_lens, src=0) else: prompt_lens = None if py_data["start_loc_size"] is not None: start_loc = torch.empty(*py_data["start_loc_size"], - dtype=torch.long, - device="cuda") + dtype=torch.long, + device="cuda") broadcast(start_loc, src=0) else: start_loc = None From ead42a2adb847fddc761c5a87195c919c79daeca Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sat, 6 Jan 2024 01:57:49 +0000 Subject: [PATCH 17/37] clean --- examples/api_client.py | 8 +-- vllm/core/block_manager.py | 55 ++++--------------- vllm/core/scheduler.py | 1 + vllm/engine/async_llm_engine.py | 4 +- vllm/model_executor/layers/attention.py | 4 +- .../layers/triton_kernel/prefix_prefill.py | 4 ++ vllm/prefix.py | 17 ++---- vllm/worker/model_runner.py | 5 +- 8 files changed, 32 insertions(+), 66 deletions(-) diff --git a/examples/api_client.py b/examples/api_client.py index f554448ddbd1e..70ec8c5492124 100644 --- a/examples/api_client.py +++ b/examples/api_client.py @@ -15,14 +15,12 @@ def clear_line(n: int = 1) -> None: def post_http_request(prompt: str, - prefix_pos: int, api_url: str, n: int = 1, stream: bool = False) -> requests.Response: headers = {"User-Agent": "Test Client"} pload = { "prompt": prompt, - "prefix_pos": prefix_pos, "n": n, "use_beam_search": True, "temperature": 0.0, @@ -54,9 +52,7 @@ def get_response(response: requests.Response) -> List[str]: parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) parser.add_argument("--n", type=int, default=4) - parser.add_argument("--prompt", - type=str, - default="San Francisco is a " * 32) + parser.add_argument("--prompt", type=str, default="San Francisco is a") parser.add_argument("--stream", action="store_true") args = parser.parse_args() prompt = args.prompt @@ -65,7 +61,7 @@ def get_response(response: requests.Response) -> List[str]: stream = args.stream print(f"Prompt: {prompt!r}\n", flush=True) - response = post_http_request(prompt, 
32, api_url, n, stream) + response = post_http_request(prompt, api_url, n, stream) if stream: num_printed_lines = 0 diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 69f03140b0e4f..74630081d4866 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -3,9 +3,9 @@ from typing import Dict, List, Optional, Set, Tuple from vllm.block import PhysicalTokenBlock +from vllm.prefix import Prefix from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -from vllm.prefix import Prefix # Mapping: logical block number -> physical block. BlockTable = List[PhysicalTokenBlock] @@ -137,19 +137,18 @@ def allocate(self, seq_group: SequenceGroup) -> None: prefix_block_table: BlockTable = [] num_prefix_blocks = 0 if seq_group.prefix is not None: - # prefix is already on gpu or will be swapped in before the actual computation + # prefix is already on gpu or + # will be swapped in before the actual computation if seq_group.prefix.on_gpu: - num_prompt_blocks -= seq_group.prefix.get_length( - ) // self.block_size + num_prompt_blocks -= seq_group.prefix.get_num_blocks() for block in seq_group.prefix.block_table: block.ref_count += seq_group.num_seqs() block_table.append(block) - # TODO: will need to perform the copy-on-write if prefix length is not a multiple of block size - # allocate blocks for the prefix, we need to calculate the prefix's kv in this run + # allocate blocks for the prefix, + # we need to calculate the prefix's kv in this run elif not seq_group.prefix.swap_to_gpu: - num_prefix_blocks = seq_group.prefix.get_length( - ) // self.block_size + num_prefix_blocks = seq_group.prefix.get_num_blocks() seq_group.prefix.swap_to_gpu = True for logical_idx in range(num_prompt_blocks): @@ -161,6 +160,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() block_table.append(block) + # Store the blocks computed by + # the first seq group using this prefix + # the other seq groups in the same batch will also compute the prefix + # but those blocks won't be stored if logical_idx < num_prefix_blocks: block.ref_count += 1 prefix_block_table.append(block) @@ -169,6 +172,7 @@ def allocate(self, seq_group: SequenceGroup) -> None: for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() + # Record the prefix block table for the prefix if num_prefix_blocks > 0: seq_group.prefix.block_table = prefix_block_table.copy() @@ -240,15 +244,6 @@ def can_swap_in(self, seq_group: SequenceGroup) -> bool: num_required_blocks = len(blocks) + num_swapped_seqs return num_free_blocks - num_required_blocks >= self.watermark_blocks - def can_swap_in_prefix(self, prefix: Prefix) -> bool: - blocks = prefix.block_table - num_free_blocks = self.gpu_allocator.get_num_free_blocks() - # NOTE: Conservatively, we assume that every sequence will allocate - # at least one free block right after the swap-in. - # NOTE: This should match the logic in can_append_slot(). - num_required_blocks = len(blocks) - return num_free_blocks - num_required_blocks >= self.watermark_blocks - def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. if seq_group.prefix is not None: @@ -282,36 +277,10 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: } return block_number_mapping - # currently not used - def swap_in_prefix(self, prefix: Prefix) -> Dict[int, int]: - # CPU block -> GPU block. 
- mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - new_block_table = [] - block_table = prefix.block_table - - for cpu_block in enumerate(block_table): - # ref_count = 1 - gpu_block = self.gpu_allocator.allocate() - mapping[cpu_block] = gpu_block - new_block_table.append(gpu_block) - # Free the CPU block swapped in to GPU. - self.cpu_allocator.free(cpu_block) - prefix.block_table = new_block_table - - block_number_mapping = { - cpu_block.block_number: gpu_block.block_number - for cpu_block, gpu_block in mapping.items() - } - return block_number_mapping - def can_swap_out(self, seq_group: SequenceGroup) -> bool: blocks = self._get_physical_blocks(seq_group) return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def can_swap_out_prefix(self, prefix: Prefix) -> bool: - blocks = prefix.block_table - return len(blocks) <= self.cpu_allocator.get_num_free_blocks() - def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: # GPU block -> CPU block. mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b794f529dd039..5af64d3249df1 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -76,6 +76,7 @@ def __init__( num_cpu_blocks=self.cache_config.num_cpu_blocks, sliding_window=self.cache_config.sliding_window) + # Create the prefix pool to cache the prefixes. self.prefix_pool = PrefixPool(self.cache_config.block_size) # TODO(zhuohan): Use deque instead of list for better performance. diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1e01c539b46d8..f339ddb63c0c1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -410,10 +410,10 @@ async def add_request( async def generate( self, prompt: Optional[str], - prefix_pos: Optional[int], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None + prompt_token_ids: Optional[List[int]] = None, + prefix_pos: Optional[int] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index be10f9d9b58d5..799d1f60851a5 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -117,8 +117,8 @@ def forward( self.num_queries_per_kv, value.shape[-1]) # normal attention - if key_cache is None or value_cache is None or input_metadata.block_tables.numel( - ) == 0: + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): # Set attention bias if not provided. This typically happens at the # very attention layer of every iteration. # FIXME(woosuk): This is a hack. 
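With `prefix_pos` now a keyword argument of `AsyncLLMEngine.generate` (see the signature change above), a caller can request prefix reuse as sketched below. This is an illustrative helper, not part of the patch; it assumes an already-constructed engine and that `prefix_pos` counts the leading prompt tokens that form the shared prefix, which per the `Prefix` assertions elsewhere in this series should line up with a KV-cache block boundary.

from vllm import SamplingParams

async def run_with_shared_prefix(engine, prompt: str, prefix_pos: int,
                                 request_id: str) -> str:
    # Hypothetical caller: stream outputs and return the final text.
    params = SamplingParams(temperature=0.0, max_tokens=64)
    final = None
    async for request_output in engine.generate(prompt,
                                                params,
                                                request_id,
                                                prefix_pos=prefix_pos):
        final = request_output
    return final.outputs[0].text

# From synchronous code, e.g.:
# asyncio.run(run_with_shared_prefix(engine, prompt, 32, "req-0"))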
diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index b369b53817827..a1d2b7fd4cc5a 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -1,3 +1,6 @@ +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + import torch import time import triton @@ -749,6 +752,7 @@ def context_attention_fwd(q, return +# TODO move to a test file @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, diff --git a/vllm/prefix.py b/vllm/prefix.py index a16426a9d561f..18217f12a6897 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -23,15 +23,19 @@ def __init__( ) -> None: self.prefix_id = prefix_id self.token_ids = token_ids + self.block_size = block_size self.length = len(token_ids) assert self.length % block_size == 0 self.on_gpu = False self.on_cpu = False - self.block_table = None + self.block_table: Optional[List[int]] = None # a lock to prevent multiple sequence from calculating the same prefix self.swap_to_gpu = False - def get_block_table_num(self) -> List[int]: + def get_num_blocks(self) -> int: + return self.length // self.block_size + + def get_block_numbers(self) -> List[int]: return [block.block_number for block in self.block_table] def match(self, tokens: List[int]) -> bool: @@ -75,18 +79,9 @@ def add_prefix(self, token_ids: List[int]) -> Prefix: self.prefixes_hash[prefix_hash] = prefix.prefix_id return prefix - # @TODO: this one should also come with a method to identify the prefix - def efficient_search(self, token_ids: List[int]) -> Optional[Prefix]: - # improve this search - for prefix in self.prefixes: - if prefix.match(token_ids): - return prefix - return None - # use this first, if we already know from the application which part of the tokens are prefix. 
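The `PrefixPool` changes above key prefixes by `hash(tuple(token_ids))`, so repeated requests whose leading tokens match resolve to the same `Prefix` (and hence the same cached KV blocks); the `fixed_search` method that follows performs that pool-side lookup. A stripped-down sketch of the idea is shown here; the class and method names are illustrative rather than the ones in this diff, and it rounds the usable prefix length down to whole blocks because the `Prefix` constructor asserts the length is block-aligned.

from typing import Dict, List, Optional

class TinyPrefixPool:
    # Illustrative only: maps hash(tuple(token_ids)) -> cached token lists,
    # mirroring how the pool reuses an existing Prefix when the same leading
    # tokens show up again.
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.prefixes: Dict[int, List[int]] = {}

    def add_or_get(self, token_ids: List[int]) -> Optional[List[int]]:
        # Only whole blocks of tokens can be cached.
        usable_len = (len(token_ids) // self.block_size) * self.block_size
        if usable_len == 0:
            return None
        key = hash(tuple(token_ids[:usable_len]))
        return self.prefixes.setdefault(key, token_ids[:usable_len])

# pool = TinyPrefixPool(block_size=16)
# a = pool.add_or_get(prompt_token_ids)   # first request creates the entry
# b = pool.add_or_get(prompt_token_ids)   # identical prefixes reuse it
# assert a is b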
def fixed_search(self, prefix_hash: int) -> Optional[Prefix]: if prefix_hash not in self.prefixes_hash: return None - # print("Found prefix in the pool.") prefix_id = self.prefixes_hash[prefix_hash] return self.prefixes[prefix_id] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2365abb19b482..7692cfbfb4d9a 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -96,11 +96,12 @@ def _prepare_prompt( prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) prefix_len = 0 - if seq_group_metadata.prefix is not None and seq_group_metadata.prefix.on_gpu: + if (seq_group_metadata.prefix is not None + and seq_group_metadata.prefix.on_gpu): prefix_len = seq_group_metadata.prefix.get_length() assert prefix_len % self.block_size == 0 prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_table = seq_group_metadata.prefix.get_block_table_num( + prefix_block_table = seq_group_metadata.prefix.get_block_numbers( ) prefix_block_tables.append(prefix_block_table) max_num_blocks_per_seq_prompt = max( From 7bcb509c520395bec51c4d7ef0e310f8db36a25e Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sat, 6 Jan 2024 03:07:46 +0000 Subject: [PATCH 18/37] clean --- vllm/engine/llm_engine.py | 2 -- vllm/entrypoints/api_server.py | 4 ++-- .../layers/triton_kernel/prefix_prefill.py | 17 ----------------- vllm/prefix.py | 3 ++- 4 files changed, 4 insertions(+), 22 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 834e553d212d2..bd38bd5db72f1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -357,8 +357,6 @@ def add_request( hash(tuple(truncated_prefix_token_ids))) if prefix is not None: seq.prefix = prefix - # print("prefix status: ", "on gpu" if prefix.get_status() else "on cpu") - # prefix.update_freq(1.0) else: # create a new prefix seq.prefix = self.scheduler.prefix_pool.add_prefix( diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 665b36fcbeb5b..0f6a50ebbe229 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -38,8 +38,8 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, prefix_pos, sampling_params, - request_id) + results_generator = engine.generate(prompt, sampling_params, request_id, + prefix_pos) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index a1d2b7fd4cc5a..d01bc92b15fe8 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -152,7 +152,6 @@ def _fwd_kernel( mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -341,7 +340,6 @@ def _fwd_kernel_flash_attn_v2( mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -427,8 +425,6 @@ def _fwd_kernel_alibi( stride_v_cache_h, stride_v_cache_d, stride_v_cache_bl, - # debuger, - # stride_db_head, stride_db_q, stride_db_k, BLOCK_M: 
tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -506,12 +502,6 @@ def _fwd_kernel_alibi( qk += alibi alibi_start_k += BLOCK_N - # # debuger alibi - # offset_db = stride_db_head * cur_head + offset_db_q[:, None] * stride_db_q + offset_db_k[None, :] * stride_db_k - # mask_db = (offset_db_q < cur_batch_seq_len - cur_batch_ctx_len)[:, None] & (offset_db_k < cur_batch_seq_len)[None, :] - # tl.store(debuger + offset_db, alibi, mask=mask_db) - # offset_db_k += BLOCK_N - # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) m_i_new = tl.maximum(m_i, m_ij) @@ -566,7 +556,6 @@ def _fwd_kernel_alibi( mask=(start_n + offs_n[None, :]) < cur_batch_seq_len - cur_batch_ctx_len, other=0.0) - # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k, allow_tf32=False) @@ -583,12 +572,6 @@ def _fwd_kernel_alibi( qk += alibi alibi_start_k += BLOCK_N - # # debuger alibi - # offset_db = stride_db_head * cur_head + offset_db_q[:, None] * stride_db_q + offset_db_k[None, :] * stride_db_k - # mask_db = (offset_db_q < cur_batch_seq_len - cur_batch_ctx_len)[:, None] & (offset_db_k < cur_batch_seq_len)[None, :] - # tl.store(debuger + offset_db, alibi, mask=mask_db) - # offset_db_k += BLOCK_N - # -- compute m_ij, p, l_ij m_ij = tl.max(qk, 1) m_i_new = tl.maximum(m_i, m_ij) diff --git a/vllm/prefix.py b/vllm/prefix.py index 18217f12a6897..bc539e03f1e40 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -70,12 +70,13 @@ def __init__( self.block_size = block_size def add_prefix(self, token_ids: List[int]) -> Prefix: + prefix_hash = hash(tuple(token_ids)) + assert prefix_hash not in self.prefixes_hash # generate prefix_id prefix_id = len(self.prefixes) # create a new prefix prefix = Prefix(prefix_id, token_ids, self.block_size) self.prefixes.append(prefix) - prefix_hash = hash(tuple(prefix.token_ids)) self.prefixes_hash[prefix_hash] = prefix.prefix_id return prefix From 8bc52ca7b9f0bb81b7a03b606fb5a684c67f214c Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sun, 7 Jan 2024 21:14:22 +0000 Subject: [PATCH 19/37] clean & minor --- vllm/core/block_manager.py | 22 ---------------------- vllm/core/scheduler.py | 31 +------------------------------ vllm/worker/model_runner.py | 1 + 3 files changed, 2 insertions(+), 52 deletions(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 74630081d4866..5e1389ca1f771 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -311,28 +311,6 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: } return block_number_mapping - def swap_out_prefix(self, prefix: Prefix) -> Dict[int, int]: - # GPU block -> CPU block. - # make sure all the reference seq are finished or swapped out before swapping out the prefix - mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} - new_block_table = [] - block_table = prefix.block_table - - for gpu_block in block_table: - cpu_block = self.cpu_allocator.allocate() - mapping[gpu_block] = cpu_block - new_block_table.append(cpu_block) - # Free the GPU block swapped out to CPU. 
- assert gpu_block.ref_count == 1 - self.gpu_allocator.free(gpu_block) - prefix.block_table = new_block_table - - block_number_mapping = { - gpu_block.block_number: cpu_block.block_number - for gpu_block, cpu_block in mapping.items() - } - return block_number_mapping - def _free_block_table(self, block_table: BlockTable) -> None: for block in set(block_table): if block.device == Device.GPU: diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5af64d3249df1..6b402ab2c9b38 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -8,7 +8,7 @@ from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) -from vllm.prefix import Prefix, PrefixPool +from vllm.prefix import PrefixPool logger = init_logger(__name__) @@ -193,10 +193,6 @@ def _schedule(self) -> SchedulerOutputs: seq_lens = new_seq_lens seq_group = self.waiting.pop(0) - # swap in the prefix if it is on CPU - if seq_group.prefix is not None and seq_group.prefix.on_cpu: - # prefix.on_gpu will be set inside this function - self._swap_in_prefix(seq_group.prefix, blocks_to_swap_in) self._allocate(seq_group) self.running.append(seq_group) @@ -421,28 +417,3 @@ def _swap_out( blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq.status = SequenceStatus.SWAPPED - - def _swap_in_prefix( - self, - prefix: Prefix, - blocks_to_swap_in: Dict[int, int], - ) -> None: - mapping = self.block_manager.swap_in_prefix(prefix) - blocks_to_swap_in.update(mapping) - prefix.on_gpu = True - - def _swap_out_prefix( - self, - prefix: Prefix, - blocks_to_swap_out: Dict[int, int], - ) -> None: - if not self.block_manager.can_swap_out_prefix(prefix): - # FIXME(woosuk): Abort the sequence group instead of aborting the - # entire engine. - raise RuntimeError( - "Aborted due to the lack of CPU swap space. 
Please increase " - "the swap space to avoid this error.") - mapping = self.block_manager.swap_out_prefix(prefix) - blocks_to_swap_out.update(mapping) - prefix.on_cpu = True - prefix.on_gpu = False diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 7692cfbfb4d9a..cba4e3c1488b2 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -601,6 +601,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping[:batch_size], + prompt_lens=None, max_seq_len=None, start_loc=None, max_context_len=self.max_context_len_to_capture, From abb843b051a37dbb8273b30e1ba33d140df22560 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sun, 7 Jan 2024 21:17:08 +0000 Subject: [PATCH 20/37] format --- vllm/core/block_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 5e1389ca1f771..3eeffa5ece7fd 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -3,7 +3,6 @@ from typing import Dict, List, Optional, Set, Tuple from vllm.block import PhysicalTokenBlock -from vllm.prefix import Prefix from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device From dc08e14a23ba8ef88ec5a431822f73606a4fdca0 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sun, 7 Jan 2024 21:24:49 +0000 Subject: [PATCH 21/37] move prefix prefill kernel test to a separate file --- tests/kernels/test_prefix_prefill.py | 164 ++++++++++++++++++ .../layers/triton_kernel/prefix_prefill.py | 163 ----------------- 2 files changed, 164 insertions(+), 163 deletions(-) create mode 100644 tests/kernels/test_prefix_prefill.py diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py new file mode 100644 index 0000000000000..2bf5dbd6a9bbe --- /dev/null +++ b/tests/kernels/test_prefix_prefill.py @@ -0,0 +1,164 @@ +import time +import torch + +from vllm.model_executor.layers.triton_kernel.prefix_prefill import context_attention_fwd + +@torch.inference_mode() +def test_contexted_kv_attention( + num_heads: int, + head_size: int, + dtype: torch.dtype, +) -> None: + import random + random.seed(0) + torch.manual_seed(0) + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask + MAX_SEQ_LEN = 1024 + MAX_CTX_LEN = 1024 + BS = 10 + cache_size = 640 + block_size = 32 + max_block_per_request = 64 + subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] + ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] + seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] + + num_tokens = sum(subquery_lens) + query = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + query.uniform_(-1e-3, 1e-3) + output = torch.empty(num_tokens, + num_heads, + head_size, + dtype=dtype, + device='cuda') + + kv = torch.empty(sum(seq_lens), + 2, + num_heads, + head_size, + dtype=dtype, + device='cuda') + kv.uniform_(-1e-3, 1e-3) + key, value = kv.unbind(dim=1) + + k_cache = torch.zeros(cache_size, + block_size, + num_heads, + head_size, + dtype=dtype, + device='cuda') + v_cache = torch.zeros(cache_size, + block_size, + num_heads, + head_size, + dtype=dtype, + device='cuda') + k = torch.zeros(sum(subquery_lens), + num_heads, + head_size, + dtype=dtype, + device='cuda') + v = torch.zeros(sum(subquery_lens), + num_heads, + head_size, + dtype=dtype, + device='cuda') + values = torch.arange(0, cache_size, 
dtype=torch.long, device='cuda') + values = values[torch.randperm(cache_size)] + block_table = values[:BS * max_block_per_request].view( + BS, max_block_per_request) + b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') + b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') + b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], + dtype=torch.long, + device='cuda'), + dim=0) + max_input_len = MAX_SEQ_LEN + # copy kv to cache + b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], + dtype=torch.long, + device='cuda'), + dim=0) + for i in range(BS): + for j in range(subquery_lens[i]): + k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + + j]) + v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + + b_ctx_len[i] + j]) + cur_ctx = 0 + block_id = 0 + while cur_ctx < b_ctx_len[i]: + start_loc = b_seq_start_loc[i] + cur_ctx + if cur_ctx + block_size > b_ctx_len[i]: + end_loc = b_seq_start_loc[i] + b_ctx_len[i] + else: + end_loc = start_loc + block_size + start_slot = block_table[i, block_id] * block_size + end_slot = start_slot + end_loc - start_loc + k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( + key[start_loc:end_loc]) + v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( + value[start_loc:end_loc]) + cur_ctx += block_size + block_id += 1 + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, + 8).permute(0, 2, 3, 1, 4).contiguous() + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] + v_cache = v_cache.view(-1, block_size, num_heads, + head_size).permute(0, 2, 3, 1).contiguous() + + context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, + b_start_loc, b_seq_len, b_ctx_len, max_input_len) + torch.cuda.synchronize() + start_time = time.time() + context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, + b_start_loc, b_seq_len, b_ctx_len, max_input_len) + torch.cuda.synchronize() + end_time = time.time() + print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") + + scale = float(1.0 / (head_size**0.5)) + + attn_op = xops.fmha.cutlass.FwOp() + + attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( + subquery_lens, seq_lens) + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + start_time = time.time() + output_ref = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + op=attn_op, + ) + torch.cuda.synchronize() + end_time = time.time() + print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") + output_ref = output_ref.squeeze(0) + print(output_ref.shape) + print("max ", torch.max(torch.abs(output_ref - output))) + print("mean ", torch.mean(torch.abs(output_ref - output))) + print(output[0, 0, :10]) + print(output_ref[0, 0, :10]) + assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) + + +test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index d01bc92b15fe8..f0686cc923ca7 100644 --- 
a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -2,7 +2,6 @@ # https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py import torch -import time import triton import triton.language as tl @@ -733,165 +732,3 @@ def context_attention_fwd(q, num_stages=1, ) return - - -# TODO move to a test file -@torch.inference_mode() -def test_contexted_kv_attention( - num_heads: int, - head_size: int, - dtype: torch.dtype, -) -> None: - import random - random.seed(0) - torch.manual_seed(0) - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask - MAX_SEQ_LEN = 1024 - MAX_CTX_LEN = 1024 - BS = 10 - cache_size = 640 - block_size = 32 - max_block_per_request = 64 - subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] - ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] - seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] - - num_tokens = sum(subquery_lens) - query = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') - query.uniform_(-1e-3, 1e-3) - output = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') - - kv = torch.empty(sum(seq_lens), - 2, - num_heads, - head_size, - dtype=dtype, - device='cuda') - kv.uniform_(-1e-3, 1e-3) - key, value = kv.unbind(dim=1) - - k_cache = torch.zeros(cache_size, - block_size, - num_heads, - head_size, - dtype=dtype, - device='cuda') - v_cache = torch.zeros(cache_size, - block_size, - num_heads, - head_size, - dtype=dtype, - device='cuda') - k = torch.zeros(sum(subquery_lens), - num_heads, - head_size, - dtype=dtype, - device='cuda') - v = torch.zeros(sum(subquery_lens), - num_heads, - head_size, - dtype=dtype, - device='cuda') - values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') - values = values[torch.randperm(cache_size)] - block_table = values[:BS * max_block_per_request].view( - BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') - b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], - dtype=torch.long, - device='cuda'), - dim=0) - max_input_len = MAX_SEQ_LEN - # copy kv to cache - b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long, - device='cuda'), - dim=0) - for i in range(BS): - for j in range(subquery_lens[i]): - k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + - j]) - v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + - b_ctx_len[i] + j]) - cur_ctx = 0 - block_id = 0 - while cur_ctx < b_ctx_len[i]: - start_loc = b_seq_start_loc[i] + cur_ctx - if cur_ctx + block_size > b_ctx_len[i]: - end_loc = b_seq_start_loc[i] + b_ctx_len[i] - else: - end_loc = start_loc + block_size - start_slot = block_table[i, block_id] * block_size - end_slot = start_slot + end_loc - start_loc - k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( - key[start_loc:end_loc]) - v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( - value[start_loc:end_loc]) - cur_ctx += block_size - block_id += 1 - # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] - k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, - 8).permute(0, 2, 3, 1, 4).contiguous() - # transpose V_cache[num_blocks, 
block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] - v_cache = v_cache.view(-1, block_size, num_heads, - head_size).permute(0, 2, 3, 1).contiguous() - - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) - torch.cuda.synchronize() - start_time = time.time() - context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, - b_start_loc, b_seq_len, b_ctx_len, max_input_len) - torch.cuda.synchronize() - end_time = time.time() - print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") - - scale = float(1.0 / (head_size**0.5)) - - attn_op = xops.fmha.cutlass.FwOp() - - attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( - subquery_lens, seq_lens) - output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - op=attn_op, - ) - torch.cuda.synchronize() - start_time = time.time() - output_ref = xops.memory_efficient_attention_forward( - query.unsqueeze(0), - key.unsqueeze(0), - value.unsqueeze(0), - attn_bias=attn_bias, - p=0.0, - scale=scale, - op=attn_op, - ) - torch.cuda.synchronize() - end_time = time.time() - print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") - output_ref = output_ref.squeeze(0) - print(output_ref.shape) - print("max ", torch.max(torch.abs(output_ref - output))) - print("mean ", torch.mean(torch.abs(output_ref - output))) - print(output[0, 0, :10]) - print(output_ref[0, 0, :10]) - assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) - - -# test_contexted_kv_attention(12, 128, torch.float16) From f0f8f669245f63be664d4384ffb23a5c771bb1f6 Mon Sep 17 00:00:00 2001 From: caoshiyi Date: Sun, 7 Jan 2024 21:29:09 +0000 Subject: [PATCH 22/37] format --- tests/kernels/test_prefix_prefill.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 2bf5dbd6a9bbe..4b663eb1c887a 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,6 +3,7 @@ from vllm.model_executor.layers.triton_kernel.prefix_prefill import context_attention_fwd + @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, @@ -161,4 +162,4 @@ def test_contexted_kv_attention( assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) -test_contexted_kv_attention(12, 128, torch.float16) \ No newline at end of file +test_contexted_kv_attention(12, 128, torch.float16) From 3678af69aa5dc560a212d0e7b09527d48f655e07 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 13 Jan 2024 01:00:11 +0000 Subject: [PATCH 23/37] fix test --- tests/kernels/test_prefix_prefill.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 4b663eb1c887a..553227b975e26 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -1,20 +1,27 @@ +import random +import pytest import time -import torch -from vllm.model_executor.layers.triton_kernel.prefix_prefill import context_attention_fwd +import torch +from vllm.model_executor.layers.triton_kernel.prefix_prefill import (context_attention_fwd) +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask +NUM_HEADS = [12] +HEAD_SIZES = [128] +DTYPES = [torch.float16] 
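The kernels index the key cache with `(offs_d // x)` and `(offs_d % x)` because, as the comments in `context_attention_fwd` note, `k_cache` is laid out as `[num_blocks, num_kv_heads, head_size/x, block_size, x]` (with `x = 8` here) and `v_cache` as `[num_blocks, num_kv_heads, head_size, block_size]`. A small host-side sketch of that addressing follows; the helper name is illustrative and it covers only the key cache.

import torch

def gather_key_from_paged_cache(k_cache: torch.Tensor, block_number: int,
                                head: int, pos_in_block: int) -> torch.Tensor:
    # k_cache: [num_blocks, num_kv_heads, head_size // x, block_size, x];
    # recover the full head_size-dim key vector for one cached token using
    # the same (d // x, d % x) split the kernel uses.
    x = k_cache.shape[-1]
    head_size = k_cache.shape[2] * x
    d = torch.arange(head_size)
    return k_cache[block_number, head, d // x, pos_in_block, d % x]

# Round-trip against the flat [tokens, heads, head_size] view built in the
# test below (x = 8):
# flat = torch.randn(4 * 32, 12, 128)
# paged = flat.view(4, 32, 12, 128 // 8, 8).permute(0, 2, 3, 1, 4).contiguous()
# assert torch.equal(gather_key_from_paged_cache(paged, 2, 5, 7),
#                    flat[2 * 32 + 7, 5])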
+@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, head_size: int, dtype: torch.dtype, ) -> None: - import random random.seed(0) torch.manual_seed(0) - from xformers import ops as xops - from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask MAX_SEQ_LEN = 1024 MAX_CTX_LEN = 1024 BS = 10 @@ -107,10 +114,12 @@ def test_contexted_kv_attention( value[start_loc:end_loc]) cur_ctx += block_size block_id += 1 - # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] + # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] + # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, 8).permute(0, 2, 3, 1, 4).contiguous() - # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] to V_cache[num_blocks, num_kv_heads, head_size, block_size] + # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] + # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_heads, head_size).permute(0, 2, 3, 1).contiguous() @@ -154,11 +163,6 @@ def test_contexted_kv_attention( end_time = time.time() print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.squeeze(0) - print(output_ref.shape) - print("max ", torch.max(torch.abs(output_ref - output))) - print("mean ", torch.mean(torch.abs(output_ref - output))) - print(output[0, 0, :10]) - print(output_ref[0, 0, :10]) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) From 037950b055d5c2b985d56f3dbed0467a5a49fb7b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 13 Jan 2024 01:12:01 +0000 Subject: [PATCH 24/37] format kernel --- tests/kernels/test_prefix_prefill.py | 4 +- .../layers/triton_kernel/prefix_prefill.py | 126 +++++++++--------- 2 files changed, 63 insertions(+), 67 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 553227b975e26..8a0847d14adc1 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -3,7 +3,8 @@ import time import torch -from vllm.model_executor.layers.triton_kernel.prefix_prefill import (context_attention_fwd) +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( + context_attention_fwd) from xformers import ops as xops from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask @@ -11,6 +12,7 @@ HEAD_SIZES = [128] DTYPES = [torch.float16] + @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py index f0686cc923ca7..8fa70054f02ca 100644 --- a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -63,9 +63,9 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + - offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ - None, :] * stride_qd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * 
stride_qh + offs_d[None, :] * stride_qd) q = tl.load( Q + off_q, @@ -84,17 +84,16 @@ def _fwd_kernel( ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) - off_k = bn[ - None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( - offs_d[:, None] // x) * stride_k_cache_d + ( - (start_n + offs_n[None, :]) % - block_size) * stride_k_cache_bl + ( - offs_d[:, None] % x) * stride_k_cache_x - off_v = bn[:, - None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ - None, :] * stride_v_cache_d + ( - start_n + offs_n[:, None] - ) % block_size * stride_v_cache_bl + off_k = (bn[None, :] * stride_k_cache_bs + + cur_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, other=0.0) @@ -132,11 +131,10 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new - off_k = offs_n[ - None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, - None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ - None, :] * stride_vd + off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -187,9 +185,9 @@ def _fwd_kernel( l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = (cur_batch_in_all_start_index + - offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ - None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o tl.store(out_ptrs, acc, @@ -252,9 +250,9 @@ def _fwd_kernel_flash_attn_v2( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + - offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ - None, :] * stride_qd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) q = tl.load( Q + off_q, @@ -273,17 +271,16 @@ def _fwd_kernel_flash_attn_v2( ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) - off_k = bn[ - None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( - offs_d[:, None] // x) * stride_k_cache_d + ( - (start_n + offs_n[None, :]) % - block_size) * stride_k_cache_bl + ( - offs_d[:, None] % x) * stride_k_cache_x - off_v = bn[:, - None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ - None, :] * stride_v_cache_d + ( - start_n + offs_n[:, None] - ) % block_size * stride_v_cache_bl + off_k = (bn[None, :] * stride_k_cache_bs + + cur_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, mask=(start_n + offs_n[None, :]) < 
cur_batch_ctx_len, other=0.0) @@ -320,11 +317,10 @@ def _fwd_kernel_flash_attn_v2( l_i = l_i_new m_i = m_i_new - off_k = offs_n[ - None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, - None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ - None, :] * stride_vd + off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -376,9 +372,9 @@ def _fwd_kernel_flash_attn_v2( # acc /= l_i[:, None] # initialize pointers to output - off_o = (cur_batch_in_all_start_index + - offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ - None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o tl.store(out_ptrs, acc, @@ -446,9 +442,9 @@ def _fwd_kernel_alibi( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + - offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[ - None, :] * stride_qd + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) q = tl.load( Q + off_q, @@ -471,17 +467,16 @@ def _fwd_kernel_alibi( ((start_n + offs_n) // block_size) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_ctx_len, other=0) - off_k = bn[ - None, :] * stride_k_cache_bs + cur_head * stride_k_cache_h + ( - offs_d[:, None] // x) * stride_k_cache_d + ( - (start_n + offs_n[None, :]) % - block_size) * stride_k_cache_bl + ( - offs_d[:, None] % x) * stride_k_cache_x - off_v = bn[:, - None] * stride_v_cache_bs + cur_head * stride_v_cache_h + offs_d[ - None, :] * stride_v_cache_d + ( - start_n + offs_n[:, None] - ) % block_size * stride_v_cache_bl + off_k = (bn[None, :] * stride_k_cache_bs + + cur_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) k = tl.load(K_cache + off_k, mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, other=0.0) @@ -527,11 +522,10 @@ def _fwd_kernel_alibi( l_i = l_i_new m_i = m_i_new - off_k = offs_n[ - None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, - None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[ - None, :] * stride_vd + off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + + offs_d[None, :] * stride_vd) k_ptrs = K + off_k v_ptrs = V + off_v @@ -602,9 +596,9 @@ def _fwd_kernel_alibi( acc = acc / l_i[:, None] # initialize pointers to output - off_o = (cur_batch_in_all_start_index + - offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[ - None, :] * stride_od + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) out_ptrs = Out + off_o tl.store(out_ptrs, acc, From b414c770d7e1b23cbe02c6e1a70de8cc654a4e37 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 13 Jan 2024 01:22:36 +0000 Subject: [PATCH 25/37] fix format --- 
vllm/model_executor/layers/attention.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index dafc2897e9a79..8b5c6ab30d7b7 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -10,7 +10,8 @@ from vllm._C import ops from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata -from vllm.model_executor.layers.triton_kernel.prefix_prefill import context_attention_fwd +from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( + context_attention_fwd) from vllm.utils import is_hip _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] @@ -119,8 +120,8 @@ def forward( # normal attention if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): - # Set attention bias if not provided. This typically happens at the - # very attention layer of every iteration. + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. # FIXME(woosuk): This is a hack. if input_metadata.attn_bias is None: if self.alibi_slopes is None: @@ -135,8 +136,8 @@ def forward( self.alibi_slopes, self.num_kv_heads, batch_size, seq_len, query.dtype) - # TODO(woosuk): Too many view operations. Let's try to reduce them - # in the future for code readability. + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. if self.alibi_slopes is None: query = query.unsqueeze(0) key = key.unsqueeze(0) From bb4ca734af897a8cd45bbe3d540b12b53b4ca38a Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 15 Jan 2024 07:03:41 +0000 Subject: [PATCH 26/37] [WIP] Refactor --- vllm/block.py | 4 ++ vllm/core/block_manager.py | 48 +++++++++------------ vllm/engine/llm_engine.py | 25 ++--------- vllm/entrypoints/api_server.py | 6 ++- vllm/prefix.py | 77 +++++++++++++++------------------- vllm/sequence.py | 5 +-- vllm/worker/model_runner.py | 47 ++++++++++----------- 7 files changed, 88 insertions(+), 124 deletions(-) diff --git a/vllm/block.py b/vllm/block.py index 435aa50ca22ea..5fe39ed47b2ff 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -66,3 +66,7 @@ def __repr__(self) -> str: return (f'PhysicalTokenBlock(device={self.device}, ' f'block_number={self.block_number}, ' f'ref_count={self.ref_count})') + + +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3eeffa5ece7fd..16060f9145d8c 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -2,13 +2,10 @@ import enum from typing import Dict, List, Optional, Set, Tuple -from vllm.block import PhysicalTokenBlock +from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -# Mapping: logical block number -> physical block. -BlockTable = List[PhysicalTokenBlock] - class BlockAllocator: """Manages free physical token blocks for a device. 
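`BlockAllocator` (whose docstring appears just above) hands out `PhysicalTokenBlock`s and tracks them with reference counts, which is what lets several sequences, and a shared prefix, point at the same block table. A stripped-down, self-contained sketch of that bookkeeping follows; the class names are illustrative rather than the ones in `vllm/core/block_manager.py`.

from typing import List

class SimplePhysicalBlock:
    def __init__(self, block_number: int) -> None:
        self.block_number = block_number
        self.ref_count = 0

class SimpleBlockAllocator:
    # Minimal free-list allocator: allocate() hands out a free block with
    # ref_count = 1, free() decrements and returns the block to the pool
    # once nothing references it -- the bookkeeping that prefix block
    # sharing in allocate() relies on.
    def __init__(self, num_blocks: int) -> None:
        self.free_blocks: List[SimplePhysicalBlock] = [
            SimplePhysicalBlock(i) for i in range(num_blocks)
        ]

    def allocate(self) -> SimplePhysicalBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory: no free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: SimplePhysicalBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free of block {block.block_number}.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)

# allocator = SimpleBlockAllocator(num_blocks=4)
# block = allocator.allocate()
# block.ref_count += 1          # a second sequence starts sharing the block
# allocator.free(block)         # first owner releases it; still referenced
# allocator.free(block)         # last owner releases it; block returns to pool
# assert allocator.get_num_free_blocks() == 4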
@@ -135,20 +132,14 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table: BlockTable = [] prefix_block_table: BlockTable = [] num_prefix_blocks = 0 - if seq_group.prefix is not None: - # prefix is already on gpu or - # will be swapped in before the actual computation - if seq_group.prefix.on_gpu: - num_prompt_blocks -= seq_group.prefix.get_num_blocks() - for block in seq_group.prefix.block_table: - block.ref_count += seq_group.num_seqs() - block_table.append(block) - # allocate blocks for the prefix, - # we need to calculate the prefix's kv in this run - elif not seq_group.prefix.swap_to_gpu: - num_prefix_blocks = seq_group.prefix.get_num_blocks() - seq_group.prefix.swap_to_gpu = True + prefix = seq_group.prefix + if prefix is not None and prefix.allocated: + # Prefix has already been allocated. Use the existing block table. + num_prompt_blocks -= prefix.get_num_blocks() + for block in prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None @@ -159,22 +150,20 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() block_table.append(block) - # Store the blocks computed by - # the first seq group using this prefix - # the other seq groups in the same batch will also compute the prefix - # but those blocks won't be stored - if logical_idx < num_prefix_blocks: + + if prefix is not None and not prefix.allocated: + # Allocate blocks for the prefix, we will compute the prefix's + # KV cache in this run. + num_prefix_blocks = prefix.get_num_blocks() + prefix_block_table = block_table[:num_prefix_blocks] + for block in prefix_block_table: block.ref_count += 1 - prefix_block_table.append(block) + prefix.set_block_table(prefix_block_table) # Assign the block table for each sequence. for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - # Record the prefix block table for the prefix - if num_prefix_blocks > 0: - seq_group.prefix.block_table = prefix_block_table.copy() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: # Simple heuristic: If there is at least one free block # for each sequence, we can append. @@ -288,8 +277,9 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: - # do not swap out the prefix - if seq_group.prefix is not None and gpu_block in seq_group.prefix.block_table: + if (seq_group.prefix is not None + and gpu_block in seq_group.prefix.block_table): + # We do not swap out the prefix blocks. 
self.gpu_allocator.free(gpu_block) continue diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index db32abbd2c22a..a6a1e3014cbab 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -390,24 +390,13 @@ def add_request( seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) - # check prefix - if prefix_pos is not None: - # a temp workaround - prefix_pos = (prefix_pos // block_size) * block_size - if prefix_pos > 0: - truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] - prefix = self.scheduler.prefix_pool.fixed_search( - hash(tuple(truncated_prefix_token_ids))) - if prefix is not None: - seq.prefix = prefix - else: - # create a new prefix - seq.prefix = self.scheduler.prefix_pool.add_prefix( - truncated_prefix_token_ids) + # Check whether the input specifies prefix + prefix = self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, seq.prefix) + arrival_time, prefix) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -678,12 +667,6 @@ def _process_model_outputs( request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - # update prefix state - for seq_group in scheduled_seq_groups: - if seq_group.prefix is not None and seq_group.prefix.swap_to_gpu: - seq_group.prefix.on_gpu = True - seq_group.prefix.swap_to_gpu = False - if self.log_stats: # Log the system stats. self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 67605df68cfa8..f7b8d258fae4c 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -38,8 +38,10 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id, - prefix_pos) + results_generator = engine.generate(prompt, + sampling_params, + request_id, + prefix_pos=prefix_pos) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/prefix.py b/vllm/prefix.py index bc539e03f1e40..5270e248cd8cc 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,36 +1,33 @@ -from typing import List, Optional +from typing import Dict, List, Sequence, Tuple, Optional + +from vllm.block import BlockTable class Prefix: - """Data and states associated with a prefix of prompt tokens for multiple sequence groups. + """Data and states associated with a prefix of prompt tokens for multiple + sequence groups. Args: prefix_id: The id of the prefix in the prefix pool. token_ids: The token ids of the prefix. block_size: The block size of the executed model. - - Attributes: - on_gpu: True if the prefix will be on GPU before the execution of the model. - on_cpu: True if the prefix is on CPU. - swap_to_gpu: True when the prefix will be computed during the execution of the model. 
""" def __init__( self, - prefix_id: int, - token_ids: List[int], + token_ids: Sequence[int], block_size: int, ) -> None: - self.prefix_id = prefix_id - self.token_ids = token_ids + self.token_ids = tuple(token_ids) self.block_size = block_size self.length = len(token_ids) + self.hash = hash(token_ids) assert self.length % block_size == 0 - self.on_gpu = False - self.on_cpu = False - self.block_table: Optional[List[int]] = None - # a lock to prevent multiple sequence from calculating the same prefix - self.swap_to_gpu = False + self.block_table: Optional[BlockTable] = None + + @property + def allocated(self) -> bool: + return self.block_table is not None def get_num_blocks(self) -> int: return self.length // self.block_size @@ -38,26 +35,24 @@ def get_num_blocks(self) -> int: def get_block_numbers(self) -> List[int]: return [block.block_number for block in self.block_table] - def match(self, tokens: List[int]) -> bool: - return tokens[:self.length] == self.token_ids - - # whether the prefix is on GPU or not - def get_status(self) -> bool: - return self.on_gpu - def get_length(self) -> int: return self.length + def get_hash(self) -> int: + return self.hash + + def set_block_table(self, block_table: BlockTable) -> None: + self.block_table = block_table.copy() + class PrefixPool: """Manages all the prompt prefixes. Args: block_size: The block size of the executed model. - + Attributes: prefixes: A list of all the prefixes. - prefixes_hash: Mapping from the hash of the prefix to the prefix id. block_size: The block size of the executed model. """ @@ -65,24 +60,20 @@ def __init__( self, block_size: int, ) -> None: - self.prefixes = [] - self.prefixes_hash = {} + self.prefixes: Dict[int, Prefix] = {} self.block_size = block_size - def add_prefix(self, token_ids: List[int]) -> Prefix: - prefix_hash = hash(tuple(token_ids)) - assert prefix_hash not in self.prefixes_hash - # generate prefix_id - prefix_id = len(self.prefixes) - # create a new prefix - prefix = Prefix(prefix_id, token_ids, self.block_size) - self.prefixes.append(prefix) - self.prefixes_hash[prefix_hash] = prefix.prefix_id - return prefix - - # use this first, if we already know from the application which part of the tokens are prefix. - def fixed_search(self, prefix_hash: int) -> Optional[Prefix]: - if prefix_hash not in self.prefixes_hash: + def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: + new_length = len(token_ids) // self.block_size * self.block_size + return tuple(token_ids[:new_length]) + + def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + token_ids = self._truncate_token_ids(token_ids) + if len(token_ids) == 0: + # Prefix is empty. 
return None - prefix_id = self.prefixes_hash[prefix_hash] - return self.prefixes[prefix_id] + prefix = Prefix(token_ids, self.block_size) + prefix_hash = prefix.get_hash() + if prefix_hash not in self.prefixes_hash: + self.prefixes[prefix_hash] = prefix + return self.prefixes[prefix_hash] diff --git a/vllm/sequence.py b/vllm/sequence.py index d8fe6fdf453dd..00fa72dc56956 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -114,11 +114,9 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, - prefix: Optional[Prefix] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt - self.prefix = prefix self.block_size = block_size self.data = SequenceData(prompt_token_ids) @@ -245,8 +243,8 @@ def __init__( self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.prefix: Optional[Prefix] = None self.prompt_logprobs: Optional[PromptLogprobs] = None - self.prefix = prefix @property def prompt(self) -> str: @@ -332,7 +330,6 @@ def __repr__(self) -> str: class SequenceGroupMetadata: """Metadata for a sequence group. Used to create `InputMetadata`. - Args: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0a76c65b100f8..27dce5e5a9376 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -84,7 +84,6 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - max_num_blocks_per_seq_prompt = 0 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -96,16 +95,11 @@ def _prepare_prompt( prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) prefix_len = 0 - if (seq_group_metadata.prefix is not None - and seq_group_metadata.prefix.on_gpu): - prefix_len = seq_group_metadata.prefix.get_length() - assert prefix_len % self.block_size == 0 + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.allocated: + prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_table = seq_group_metadata.prefix.get_block_numbers( - ) - prefix_block_tables.append(prefix_block_table) - max_num_blocks_per_seq_prompt = max( - max_num_blocks_per_seq_prompt, len(prefix_block_table)) + prefix_block_tables.append(prefix.get_block_numbers()) else: prefix_block_tables.append([]) # actual prompt lens @@ -134,7 +128,9 @@ def _prepare_prompt( # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
start_idx = 0 if self.sliding_window is not None: - assert prefix_len == 0, "prefix caching is currently not supported when using sliding window" + assert prefix_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) for i in range(prefix_len, prompt_len): if i < start_idx: @@ -162,15 +158,14 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda') - - # prefix block tables + # Prepare prefix block tables + max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) block_tables = _make_tensor_with_pad( prefix_block_tables, - max_len=max_num_blocks_per_seq_prompt, + max_len=max_prompt_block_table_len, pad=0, dtype=torch.int, ) - start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, max_prompt_len, @@ -191,7 +186,8 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, ) - return input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens + return (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens) def _prepare_decode( self, @@ -301,7 +297,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return input_tokens, input_positions, input_metadata, None + return input_tokens, input_positions, input_metadata def _prepare_sample( self, @@ -315,7 +311,7 @@ def _prepare_sample( categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 - max_prompt_len = max(subquery_lens) if subquery_lens else 1 + max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -323,10 +319,10 @@ def _prepare_sample( if seq_group_metadata.is_prompt: assert len(seq_ids) == 1 - prompt_len = subquery_lens[i] + subquery_len = subquery_lens[i] if sampling_params.prompt_logprobs is not None: # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += prompt_len - 1 + categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ sampling_params.sampling_type].append( @@ -336,10 +332,10 @@ def _prepare_sample( if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) + selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += max_prompt_len + subquery_len - 1) + selected_token_start_idx += max_subquery_len else: num_seqs = len(seq_ids) selected_token_indices.extend( @@ -387,8 +383,9 @@ def prepare_input_tensors( (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata, - subquery_lens) = self._prepare_decode(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata + ) = self._prepare_decode(seq_group_metadata_list) + subquery_lens = None prompt_lens = [] sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, From 37cd3fc1960eed04e4c933e25610c4f27138a0fc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Mon, 15 Jan 2024 07:03:41 +0000 Subject: [PATCH 27/37] fix comments --- vllm/block.py | 4 ++ vllm/core/block_manager.py | 48 +++++++++------------ 
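A hedged plain-Python sketch of what the _prepare_prompt changes above do when a prefix is already cached: the prompt is split into the cached part, served from the prefix's block table, and the remaining sub-query that still needs prefill. The helper and all values below are illustrative stand-ins, not vLLM code.

from typing import List, Tuple

def split_prompt_for_prefix(prompt_tokens: List[int], prefix_len: int,
                            prefix_block_numbers: List[int]
                            ) -> Tuple[List[int], List[int], int]:
    # Only the tokens after the cached prefix are run through the model...
    new_tokens = prompt_tokens[prefix_len:]
    # ...while the prefix's physical block numbers let the prefill kernel
    # attend to the cached context.
    subquery_len = len(new_tokens)
    return new_tokens, prefix_block_numbers, subquery_len

tokens = list(range(12))
new_tokens, block_table, subquery_len = split_prompt_for_prefix(
    tokens, prefix_len=8, prefix_block_numbers=[3, 7])
assert new_tokens == list(range(8, 12)) and subquery_len == 4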
vllm/engine/llm_engine.py | 25 ++--------- vllm/entrypoints/api_server.py | 6 ++- vllm/prefix.py | 77 +++++++++++++++------------------- vllm/sequence.py | 5 +-- vllm/worker/model_runner.py | 47 ++++++++++----------- 7 files changed, 88 insertions(+), 124 deletions(-) diff --git a/vllm/block.py b/vllm/block.py index 435aa50ca22ea..5fe39ed47b2ff 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -66,3 +66,7 @@ def __repr__(self) -> str: return (f'PhysicalTokenBlock(device={self.device}, ' f'block_number={self.block_number}, ' f'ref_count={self.ref_count})') + + +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 3eeffa5ece7fd..16060f9145d8c 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -2,13 +2,10 @@ import enum from typing import Dict, List, Optional, Set, Tuple -from vllm.block import PhysicalTokenBlock +from vllm.block import BlockTable, PhysicalTokenBlock from vllm.sequence import Sequence, SequenceGroup, SequenceStatus from vllm.utils import Device -# Mapping: logical block number -> physical block. -BlockTable = List[PhysicalTokenBlock] - class BlockAllocator: """Manages free physical token blocks for a device. @@ -135,20 +132,14 @@ def allocate(self, seq_group: SequenceGroup) -> None: block_table: BlockTable = [] prefix_block_table: BlockTable = [] num_prefix_blocks = 0 - if seq_group.prefix is not None: - # prefix is already on gpu or - # will be swapped in before the actual computation - if seq_group.prefix.on_gpu: - num_prompt_blocks -= seq_group.prefix.get_num_blocks() - for block in seq_group.prefix.block_table: - block.ref_count += seq_group.num_seqs() - block_table.append(block) - # allocate blocks for the prefix, - # we need to calculate the prefix's kv in this run - elif not seq_group.prefix.swap_to_gpu: - num_prefix_blocks = seq_group.prefix.get_num_blocks() - seq_group.prefix.swap_to_gpu = True + prefix = seq_group.prefix + if prefix is not None and prefix.allocated: + # Prefix has already been allocated. Use the existing block table. + num_prompt_blocks -= prefix.get_num_blocks() + for block in prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None @@ -159,22 +150,20 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Set the reference counts of the token blocks. block.ref_count = seq_group.num_seqs() block_table.append(block) - # Store the blocks computed by - # the first seq group using this prefix - # the other seq groups in the same batch will also compute the prefix - # but those blocks won't be stored - if logical_idx < num_prefix_blocks: + + if prefix is not None and not prefix.allocated: + # Allocate blocks for the prefix, we will compute the prefix's + # KV cache in this run. + num_prefix_blocks = prefix.get_num_blocks() + prefix_block_table = block_table[:num_prefix_blocks] + for block in prefix_block_table: block.ref_count += 1 - prefix_block_table.append(block) + prefix.set_block_table(prefix_block_table) # Assign the block table for each sequence. 
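To make the sharing scheme in allocate() above concrete: a cached prefix's blocks are not re-allocated for each sequence, only their reference counts are bumped, while the rest of the prompt gets fresh blocks. The sketch below uses minimal stand-in classes under that assumption; it is not the vLLM block manager.

from dataclasses import dataclass
from typing import List

@dataclass
class _Block:
    block_number: int
    ref_count: int = 0

def allocate_with_prefix(prefix_blocks: List[_Block], num_prompt_blocks: int,
                         num_seqs: int,
                         free_blocks: List[_Block]) -> List[_Block]:
    block_table: List[_Block] = []
    # Reuse the already-allocated prefix blocks: bump ref counts only.
    for block in prefix_blocks:
        block.ref_count += num_seqs
        block_table.append(block)
    # Allocate fresh blocks for the remaining (non-prefix) prompt tokens.
    for _ in range(num_prompt_blocks - len(prefix_blocks)):
        block = free_blocks.pop()
        block.ref_count = num_seqs
        block_table.append(block)
    return block_table

free = [_Block(i) for i in range(8)]
prefix = [free.pop(), free.pop()]
for b in prefix:
    b.ref_count = 1  # owned by the prefix itself
table = allocate_with_prefix(prefix, num_prompt_blocks=4, num_seqs=2,
                             free_blocks=free)
assert [b.ref_count for b in table] == [3, 3, 2, 2]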
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): self.block_tables[seq.seq_id] = block_table.copy() - # Record the prefix block table for the prefix - if num_prefix_blocks > 0: - seq_group.prefix.block_table = prefix_block_table.copy() - def can_append_slot(self, seq_group: SequenceGroup) -> bool: # Simple heuristic: If there is at least one free block # for each sequence, we can append. @@ -288,8 +277,9 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: block_table = self.block_tables[seq.seq_id] for gpu_block in block_table: - # do not swap out the prefix - if seq_group.prefix is not None and gpu_block in seq_group.prefix.block_table: + if (seq_group.prefix is not None + and gpu_block in seq_group.prefix.block_table): + # We do not swap out the prefix blocks. self.gpu_allocator.free(gpu_block) continue diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index db32abbd2c22a..a6a1e3014cbab 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -390,24 +390,13 @@ def add_request( seq_id = next(self.seq_counter) seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) - # check prefix - if prefix_pos is not None: - # a temp workaround - prefix_pos = (prefix_pos // block_size) * block_size - if prefix_pos > 0: - truncated_prefix_token_ids = prompt_token_ids[:prefix_pos] - prefix = self.scheduler.prefix_pool.fixed_search( - hash(tuple(truncated_prefix_token_ids))) - if prefix is not None: - seq.prefix = prefix - else: - # create a new prefix - seq.prefix = self.scheduler.prefix_pool.add_prefix( - truncated_prefix_token_ids) + # Check whether the input specifies prefix + prefix = self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, seq.prefix) + arrival_time, prefix) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -678,12 +667,6 @@ def _process_model_outputs( request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) - # update prefix state - for seq_group in scheduled_seq_groups: - if seq_group.prefix is not None and seq_group.prefix.swap_to_gpu: - seq_group.prefix.on_gpu = True - seq_group.prefix.swap_to_gpu = False - if self.log_stats: # Log the system stats. self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 67605df68cfa8..f7b8d258fae4c 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -38,8 +38,10 @@ async def generate(request: Request) -> Response: sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = engine.generate(prompt, sampling_params, request_id, - prefix_pos) + results_generator = engine.generate(prompt, + sampling_params, + request_id, + prefix_pos=prefix_pos) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: diff --git a/vllm/prefix.py b/vllm/prefix.py index bc539e03f1e40..cd040df568474 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -1,36 +1,33 @@ -from typing import List, Optional +from typing import Dict, List, Sequence, Tuple, Optional + +from vllm.block import BlockTable class Prefix: - """Data and states associated with a prefix of prompt tokens for multiple sequence groups. + """Data and states associated with a prefix of prompt tokens for multiple + sequence groups. 
Args: prefix_id: The id of the prefix in the prefix pool. token_ids: The token ids of the prefix. block_size: The block size of the executed model. - - Attributes: - on_gpu: True if the prefix will be on GPU before the execution of the model. - on_cpu: True if the prefix is on CPU. - swap_to_gpu: True when the prefix will be computed during the execution of the model. """ def __init__( self, - prefix_id: int, - token_ids: List[int], + token_ids: Sequence[int], block_size: int, ) -> None: - self.prefix_id = prefix_id - self.token_ids = token_ids + self.token_ids = tuple(token_ids) self.block_size = block_size self.length = len(token_ids) + self.hash = hash(token_ids) assert self.length % block_size == 0 - self.on_gpu = False - self.on_cpu = False - self.block_table: Optional[List[int]] = None - # a lock to prevent multiple sequence from calculating the same prefix - self.swap_to_gpu = False + self.block_table: Optional[BlockTable] = None + + @property + def allocated(self) -> bool: + return self.block_table is not None def get_num_blocks(self) -> int: return self.length // self.block_size @@ -38,26 +35,24 @@ def get_num_blocks(self) -> int: def get_block_numbers(self) -> List[int]: return [block.block_number for block in self.block_table] - def match(self, tokens: List[int]) -> bool: - return tokens[:self.length] == self.token_ids - - # whether the prefix is on GPU or not - def get_status(self) -> bool: - return self.on_gpu - def get_length(self) -> int: return self.length + def __hash__(self) -> int: + return self.hash + + def set_block_table(self, block_table: BlockTable) -> None: + self.block_table = block_table.copy() + class PrefixPool: """Manages all the prompt prefixes. Args: block_size: The block size of the executed model. - + Attributes: prefixes: A list of all the prefixes. - prefixes_hash: Mapping from the hash of the prefix to the prefix id. block_size: The block size of the executed model. """ @@ -65,24 +60,20 @@ def __init__( self, block_size: int, ) -> None: - self.prefixes = [] - self.prefixes_hash = {} + self.prefixes: Dict[int, Prefix] = {} self.block_size = block_size - def add_prefix(self, token_ids: List[int]) -> Prefix: - prefix_hash = hash(tuple(token_ids)) - assert prefix_hash not in self.prefixes_hash - # generate prefix_id - prefix_id = len(self.prefixes) - # create a new prefix - prefix = Prefix(prefix_id, token_ids, self.block_size) - self.prefixes.append(prefix) - self.prefixes_hash[prefix_hash] = prefix.prefix_id - return prefix - - # use this first, if we already know from the application which part of the tokens are prefix. - def fixed_search(self, prefix_hash: int) -> Optional[Prefix]: - if prefix_hash not in self.prefixes_hash: + def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: + new_length = len(token_ids) // self.block_size * self.block_size + return tuple(token_ids[:new_length]) + + def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + token_ids = self._truncate_token_ids(token_ids) + if len(token_ids) == 0: + # Prefix is empty. 
return None - prefix_id = self.prefixes_hash[prefix_hash] - return self.prefixes[prefix_id] + prefix = Prefix(token_ids, self.block_size) + prefix_hash = hash(prefix) + if prefix_hash not in self.prefixes: + self.prefixes[prefix_hash] = prefix + return self.prefixes[prefix_hash] diff --git a/vllm/sequence.py b/vllm/sequence.py index d8fe6fdf453dd..00fa72dc56956 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -114,11 +114,9 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, - prefix: Optional[Prefix] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt - self.prefix = prefix self.block_size = block_size self.data = SequenceData(prompt_token_ids) @@ -245,8 +243,8 @@ def __init__( self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.prefix: Optional[Prefix] = None self.prompt_logprobs: Optional[PromptLogprobs] = None - self.prefix = prefix @property def prompt(self) -> str: @@ -332,7 +330,6 @@ def __repr__(self) -> str: class SequenceGroupMetadata: """Metadata for a sequence group. Used to create `InputMetadata`. - Args: request_id: The ID of the request. is_prompt: Whether the request is at prompt stage. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0a76c65b100f8..27dce5e5a9376 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -84,7 +84,6 @@ def _prepare_prompt( context_lens: List[int] = [] subquery_lens: List[int] = [] prefix_block_tables: List[List[int]] = [] - max_num_blocks_per_seq_prompt = 0 for seq_group_metadata in seq_group_metadata_list: assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) @@ -96,16 +95,11 @@ def _prepare_prompt( prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) prefix_len = 0 - if (seq_group_metadata.prefix is not None - and seq_group_metadata.prefix.on_gpu): - prefix_len = seq_group_metadata.prefix.get_length() - assert prefix_len % self.block_size == 0 + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.allocated: + prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] - prefix_block_table = seq_group_metadata.prefix.get_block_numbers( - ) - prefix_block_tables.append(prefix_block_table) - max_num_blocks_per_seq_prompt = max( - max_num_blocks_per_seq_prompt, len(prefix_block_table)) + prefix_block_tables.append(prefix.get_block_numbers()) else: prefix_block_tables.append([]) # actual prompt lens @@ -134,7 +128,9 @@ def _prepare_prompt( # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. 
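The final PrefixPool above dedupes prefixes by the hash of their block-aligned token ids. A minimal standalone sketch of that behaviour follows; the stand-in classes, block_size=4, and the token ids are example values only, not the vLLM implementation.

from typing import Dict, Optional, Sequence, Tuple

class _Prefix:
    def __init__(self, token_ids: Tuple[int, ...], block_size: int) -> None:
        assert len(token_ids) % block_size == 0
        self.token_ids = token_ids
        self.block_size = block_size

    def __hash__(self) -> int:
        return hash(self.token_ids)

class _PrefixPool:
    def __init__(self, block_size: int) -> None:
        self.block_size = block_size
        self.prefixes: Dict[int, _Prefix] = {}

    def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[_Prefix]:
        # Keep only whole blocks; an empty prefix caches nothing.
        n = len(token_ids) // self.block_size * self.block_size
        if n == 0:
            return None
        prefix = _Prefix(tuple(token_ids[:n]), self.block_size)
        return self.prefixes.setdefault(hash(prefix), prefix)

pool = _PrefixPool(block_size=4)
a = pool.add_or_get_prefix(list(range(10)))  # truncated to 8 tokens
b = pool.add_or_get_prefix(list(range(9)))   # same 8-token prefix, same object
assert a is b and len(pool.prefixes) == 1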
start_idx = 0 if self.sliding_window is not None: - assert prefix_len == 0, "prefix caching is currently not supported when using sliding window" + assert prefix_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") start_idx = max(0, prompt_len - self.sliding_window) for i in range(prefix_len, prompt_len): if i < start_idx: @@ -162,15 +158,14 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda') - - # prefix block tables + # Prepare prefix block tables + max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) block_tables = _make_tensor_with_pad( prefix_block_tables, - max_len=max_num_blocks_per_seq_prompt, + max_len=max_prompt_block_table_len, pad=0, dtype=torch.int, ) - start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, max_prompt_len, @@ -191,7 +186,8 @@ def _prepare_prompt( block_tables=block_tables, use_cuda_graph=False, ) - return input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens + return (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens) def _prepare_decode( self, @@ -301,7 +297,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return input_tokens, input_positions, input_metadata, None + return input_tokens, input_positions, input_metadata def _prepare_sample( self, @@ -315,7 +311,7 @@ def _prepare_sample( categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 - max_prompt_len = max(subquery_lens) if subquery_lens else 1 + max_subquery_len = max(subquery_lens) if subquery_lens else 1 for i, seq_group_metadata in enumerate(seq_group_metadata_list): seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params @@ -323,10 +319,10 @@ def _prepare_sample( if seq_group_metadata.is_prompt: assert len(seq_ids) == 1 - prompt_len = subquery_lens[i] + subquery_len = subquery_lens[i] if sampling_params.prompt_logprobs is not None: # NOTE: prompt token positions do not need sample, skip - categorized_sample_indices_start_idx += prompt_len - 1 + categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ sampling_params.sampling_type].append( @@ -336,10 +332,10 @@ def _prepare_sample( if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) + selected_token_start_idx + subquery_len - 1)) selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) - selected_token_start_idx += max_prompt_len + subquery_len - 1) + selected_token_start_idx += max_subquery_len else: num_seqs = len(seq_ids) selected_token_indices.extend( @@ -387,8 +383,9 @@ def prepare_input_tensors( (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata, - subquery_lens) = self._prepare_decode(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata + ) = self._prepare_decode(seq_group_metadata_list) + subquery_lens = None prompt_lens = [] sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, From 58d28391ff22dda86d0744d1818847bc950758ac Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 16 Jan 2024 04:36:10 +0000 Subject: [PATCH 28/37] add example and test --- examples/offline_inference_with_prefix.py | 51 +++++++++++++++++++++ 
tests/prefix_caching/test_prefix_caching.py | 41 +++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 examples/offline_inference_with_prefix.py create mode 100644 tests/prefix_caching/test_prefix_caching.py diff --git a/examples/offline_inference_with_prefix.py b/examples/offline_inference_with_prefix.py new file mode 100644 index 0000000000000..df9f1364ee514 --- /dev/null +++ b/examples/offline_inference_with_prefix.py @@ -0,0 +1,51 @@ +from vllm import LLM, SamplingParams + +prefix = ( + "You are an expert school principal, skilled in effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. Based on these information, fulfill " + "the following paragraph: ") + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.0) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") + +generating_prompts = [prefix + prompt for prompt in prompts] + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(generating_prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +print("-" * 80) + +# -1 since the last token can change when concatenating prompts. +prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 + +# Generate with prefix +outputs = llm.generate(generating_prompts, sampling_params, + prefix_pos=[prefix_pos] * len(generating_prompts)) + +# Print the outputs. You should see the same outputs as before +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py new file mode 100644 index 0000000000000..1e301bedfc21e --- /dev/null +++ b/tests/prefix_caching/test_prefix_caching.py @@ -0,0 +1,41 @@ +"""Compare the with and without prefix caching. + +Run `pytest tests/prefix_caching/test_prefix_caching.py`. +""" +import pytest + +from vllm import LLM, SamplingParams + +prefix = ( + "You are an expert school principal, skilled in effectively managing " + "faculty and staff. Draft 10-15 questions for a potential first grade " + "Head Teacher for my K-12, all-girls', independent school that emphasizes " + "community, joyful discovery, and life-long learning. The candidate is " + "coming in for a first-round panel interview for a 8th grade Math " + "teaching role. They have 5 years of previous teaching experience " + "as an assistant teacher at a co-ed, public school with experience " + "in middle school math teaching. 
Based on these information, fulfill " + "the following paragraph: ") + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("max_tokens", [16]) +def test_prefix_caching( + example_prompts, + model: str, + max_tokens: int, +): + llm = LLM(model=model) + # -1 since the last token can change when concatenating prompts. + prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 + prompts = [prefix + prompt for prompt in example_prompts] + sampling_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) + outputs_without_prefix = llm.generate(prompts, sampling_params) + outputs_with_prefix = llm.generate(prompts, + sampling_params, + prefix_pos=[prefix_pos] * len(prompts)) + for output_without_prefix, output_with_prefix in zip( + outputs_without_prefix, outputs_with_prefix): + assert (output_without_prefix.outputs[0].token_ids == + output_with_prefix.outputs[0].token_ids) + assert len(llm.llm_engine.scheduler.prefix_pool.prefixes) == 1 From 58c5cff463a7b7ad27f9fac248cd81fb689ce50e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 16 Jan 2024 04:38:50 +0000 Subject: [PATCH 29/37] add prefix caching test to ci --- .buildkite/test-pipeline.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 3cd1bed0e50a2..05d6258ddafc7 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -28,6 +28,10 @@ steps: - pytest -v -s models --forked soft_fail: true +- label: Prefix Caching Test + commands: + - pytest -v -s prefix_caching + - label: Samplers Test command: pytest -v -s samplers --forked From dc6e95904d5af664362cfc1441da221927fe1fec Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Tue, 16 Jan 2024 05:29:10 +0000 Subject: [PATCH 30/37] fix ci --- tests/samplers/test_sampler.py | 18 ++++++++++++------ tests/worker/test_model_runner.py | 5 +++-- vllm/entrypoints/llm.py | 2 +- vllm/worker/model_runner.py | 4 +++- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 996aa8e0a8d9a..bcd0cd60bfc52 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -66,7 +66,8 @@ def test_sampler_all_greedy(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -105,7 +106,8 @@ def test_sampler_all_random(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -140,7 +142,8 @@ def test_sampler_all_beam(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -193,7 +196,8 @@ def test_sampler_mixed(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + 
subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -234,7 +238,8 @@ def pick_ith(token_ids, logits): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sampler_output = sampler(embedding=None, hidden_states=input_tensor, sampling_metadata=sampling_metadata) @@ -288,7 +293,8 @@ def test_sampler_top_k_top_p(seed: int): prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len()) sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) sample_probs = None diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 250d84caf56d4..edbe10684741f 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -33,11 +33,12 @@ def test_prepare_prompt(): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens = ( + input_tokens, input_positions, _, return_prompt_lens, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, - prompt_lens) + prompt_lens, + subquery_lens=prompt_lens) assert input_tokens.shape == (batch_size, max_seq_len) assert input_positions.shape == (batch_size, max_seq_len) torch.testing.assert_close(input_tokens, input_positions) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9f5fbbd876123..fb50df64df766 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -171,7 +171,7 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], - prefix_pos: Optional[int], + prefix_pos: Optional[int] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 1fad6e33a170e..0f0ab5244b8ae 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -74,7 +74,8 @@ def set_block_size(self, block_size: int) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], + List[int]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -319,6 +320,7 @@ def _prepare_sample( if seq_group_metadata.is_prompt: assert len(seq_ids) == 1 + assert subquery_lens is not None subquery_len = subquery_lens[i] if sampling_params.prompt_logprobs is not None: # NOTE: prompt token positions do not need sample, skip From 7dc2e87945c082b56e8f72d42d93faf101f98f7b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 17 Jan 2024 02:41:01 +0000 Subject: [PATCH 31/37] add comment --- vllm/engine/async_llm_engine.py | 5 +++++ vllm/engine/llm_engine.py | 5 +++++ vllm/entrypoints/llm.py | 5 +++++ 3 files changed, 15 insertions(+) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3dd0434369d04..cbf2978c01c2a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -429,6 +429,11 @@ async def generate( 
request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. Yields: The output `RequestOutput` objects from the LLMEngine for the diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a6a1e3014cbab..fd5a231a902ef 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -354,6 +354,11 @@ def add_request( use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use the current monotonic time. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. Details: - Set arrival_time to the current time if it is None. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index fb50df64df766..b819e233c06b2 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -135,6 +135,11 @@ def generate( None, we use the default sampling parameters. prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. use_tqdm: Whether to use tqdm to display the progress bar. Returns: From 49b2684298fd53c75cba44328d09919fdbcbb27c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 17 Jan 2024 02:41:46 +0000 Subject: [PATCH 32/37] add TODO --- vllm/prefix.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/prefix.py b/vllm/prefix.py index cd040df568474..14266288f9092 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -60,6 +60,7 @@ def __init__( self, block_size: int, ) -> None: + # TODO(zhuohan): Add a capacity limit to the prefix pool. 
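The prefix_pos docstrings above describe an experimental, caller-driven API; usage mirrors the offline example added earlier in this series. A hedged sketch, where the model name and prompts are placeholders:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
shared_prefix = "You are a helpful assistant. "
questions = ["What is paged attention?", "What is a KV cache?"]
prompts = [shared_prefix + q for q in questions]

# prefix_pos is a token position; everything before it (rounded down to
# whole blocks) is cached once and reused by later requests with the same
# prefix. The -1 follows the example in this series, since the last prefix
# token can change after concatenation.
prefix_pos = len(llm.llm_engine.tokenizer.encode(shared_prefix)) - 1
outputs = llm.generate(prompts,
                       SamplingParams(temperature=0.0),
                       prefix_pos=[prefix_pos] * len(prompts))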
self.prefixes: Dict[int, Prefix] = {} self.block_size = block_size From c9050d3e1adf6c0e7b9d7c4ba6555ed14d428d41 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 17 Jan 2024 02:45:02 +0000 Subject: [PATCH 33/37] fix --- tests/kernels/test_prefix_prefill.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 8a0847d14adc1..8fa6358d3ec71 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -166,6 +166,3 @@ def test_contexted_kv_attention( print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") output_ref = output_ref.squeeze(0) assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) - - -test_contexted_kv_attention(12, 128, torch.float16) From cfe1444ae9fbeed33e8f04c3274c4602faf28c20 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 17 Jan 2024 22:18:13 +0000 Subject: [PATCH 34/37] fix swapping logic --- vllm/core/block_manager.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 16060f9145d8c..6fa7374125a52 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -103,9 +103,8 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) - if seq_group.prefix is not None and seq_group.prefix.on_gpu: - num_required_blocks -= seq_group.prefix.get_length( - ) // self.block_size + if seq_group.prefix is not None and seq_group.prefix.allocated: + num_required_blocks -= seq_group.prefix.get_num_blocks() if self.block_sliding_window is not None: num_required_blocks = min(num_required_blocks, @@ -236,7 +235,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. if seq_group.prefix is not None: # make sure to swap in the prefix first - assert seq_group.prefix.on_gpu is True + assert seq_group.prefix.allocated mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): From 29f4f96dfebba54e0776317045d20974efc928cc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 17 Jan 2024 23:24:59 +0000 Subject: [PATCH 35/37] fix bug --- vllm/sequence.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/sequence.py b/vllm/sequence.py index 00fa72dc56956..fd10bc9b5b8cc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -243,7 +243,7 @@ def __init__( self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time - self.prefix: Optional[Prefix] = None + self.prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None @property From bd56a6938a9f6b873b7de9c5ae683ae49d0f5287 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 17 Jan 2024 23:44:42 +0000 Subject: [PATCH 36/37] fix correctness --- vllm/engine/llm_engine.py | 6 ++++++ vllm/prefix.py | 1 + vllm/worker/model_runner.py | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index fd5a231a902ef..7072a8bbc5b3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -672,6 +672,12 @@ def _process_model_outputs( request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) + # Update prefix state, now all the uncomputed prefixes are computed. 
+ for seq_group in scheduled_seq_groups: + if (seq_group.prefix is not None and seq_group.prefix.allocated + and not seq_group.prefix.computed): + seq_group.prefix.computed = True + if self.log_stats: # Log the system stats. self._log_system_stats(scheduler_outputs.prompt_run, diff --git a/vllm/prefix.py b/vllm/prefix.py index 14266288f9092..985f8aa95a69f 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -24,6 +24,7 @@ def __init__( self.hash = hash(token_ids) assert self.length % block_size == 0 self.block_table: Optional[BlockTable] = None + self.computed = False @property def allocated(self) -> bool: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 0f0ab5244b8ae..d290886506507 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -97,7 +97,7 @@ def _prepare_prompt( prompt_lens.append(prompt_len) prefix_len = 0 prefix = seq_group_metadata.prefix - if prefix is not None and prefix.allocated: + if prefix is not None and prefix.computed: prefix_len = prefix.get_length() prompt_tokens = prompt_tokens[prefix_len:] prefix_block_tables.append(prefix.get_block_numbers()) From 6b002831eea52fad4c41255c18b875bad539d0d6 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 18 Jan 2024 00:24:00 +0000 Subject: [PATCH 37/37] add notes and small fix --- vllm/core/block_manager.py | 4 ++-- vllm/prefix.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 6fa7374125a52..7f91051f03ac1 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -235,7 +235,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: # CPU block -> GPU block. if seq_group.prefix is not None: # make sure to swap in the prefix first - assert seq_group.prefix.allocated + assert seq_group.prefix.allocated and seq_group.prefix.computed mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): @@ -278,7 +278,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: for gpu_block in block_table: if (seq_group.prefix is not None and gpu_block in seq_group.prefix.block_table): - # We do not swap out the prefix blocks. + # NOTE: We do not swap out the prefix blocks for now. self.gpu_allocator.free(gpu_block) continue diff --git a/vllm/prefix.py b/vllm/prefix.py index 985f8aa95a69f..415da1fc6d2bf 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -7,6 +7,9 @@ class Prefix: """Data and states associated with a prefix of prompt tokens for multiple sequence groups. + NOTE: This feature is experimental and may be replaced with automatic + prefix caching in the future. + Args: prefix_id: The id of the prefix in the prefix pool. token_ids: The token ids of the prefix. @@ -49,6 +52,9 @@ def set_block_table(self, block_table: BlockTable) -> None: class PrefixPool: """Manages all the prompt prefixes. + NOTE: This feature is experimental and may be replaced with automatic + prefix caching in the future. + Args: block_size: The block size of the executed model.
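The last few patches above separate two states of a prefix: it is first allocated (blocks reserved, KV cache scheduled to be filled) and only marked computed after the first model step that ran it; later requests reuse the cached blocks once computed is True. A hedged stand-in sketch of that lifecycle, not the vLLM classes themselves:

from typing import List, Optional

class _Prefix:
    def __init__(self) -> None:
        self.block_table: Optional[List[int]] = None
        self.computed = False

    @property
    def allocated(self) -> bool:
        return self.block_table is not None

prefix = _Prefix()
assert not prefix.allocated and not prefix.computed

prefix.block_table = [0, 1]       # block manager allocates prefix blocks
# ... model step runs, filling the prefix's KV cache ...
if prefix.allocated and not prefix.computed:
    prefix.computed = True        # engine flips the flag after the step

assert prefix.allocated and prefix.computed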