diff --git a/benchmarks/benchmark_lora_throughput.py b/benchmarks/benchmark_lora_throughput.py new file mode 100644 index 0000000000000..19323c39acaa9 --- /dev/null +++ b/benchmarks/benchmark_lora_throughput.py @@ -0,0 +1,506 @@ +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import random +import time +from typing import List, Optional + + +import torch +import uvloop +from PIL import Image +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.inputs import TextPrompt +from vllm.multimodal import MultiModalDataDict +from vllm.sampling_params import BeamSearchParams +from vllm.utils import FlexibleArgumentParser, merge_async_iterators + +from functools import cache +from huggingface_hub import snapshot_download +from vllm.lora.request import LoRARequest +from vllm.outputs import RequestOutput +from typing import Generator, Tuple, Optional +import pickle +import random + +SAMPLING_TEMPERATURE=0.0 +SAMPLING_TOP_P=1.0 + +@cache +def lora_path() -> str: + return snapshot_download('yard1/llama-2-7b-sql-lora-test') + +def num_loras() -> int: + # Tied to model in lora_path + return 2 + +def random_lora_request() -> Generator[LoRARequest, None, None]: + # Create a random LoRA Request + while True: + lora_id = random.randint(0, num_loras()) + if lora_id == 0: + yield None + else: + yield LoRARequest(lora_name = str(lora_id), lora_int_id = lora_id, lora_path = lora_path()) + +@dataclasses.dataclass +class SampleRequest: + """A class representing a single inference request for benchmarking. + + Attributes: + prompt: The input text prompt for the model. + multi_modal_data: Optional dictionary containing multi-modal data (e.g. + images). + prompt_len: The length of the prompt in tokens. + expected_output_len: The expected length of the output in tokens. + """ + prompt: str + prompt_len: int + expected_output_len: int + multi_modal_data: Optional[MultiModalDataDict] = None + lora_request: Optional[LoRARequest] = None + + +def _get_prompt_for_image_model(question: str, *, model: str) -> str: + """Prepend and append special tokens around the question to form a prompt. + + Args: + question: The input question text to wrap with special tokens + model: The name of the model being used, to determine which special + tokens to add + + Returns: + The formatted prompt string with appropriate special tokens for the + model + + Raises: + ValueError: If an unsupported model name is provided + """ + model = model.lower() + if "pixtral" in model: + return f"[INST]{question}\n[IMG][/INST]" + raise ValueError(f"Unsupported model {model}") + + +def sample_requests(tokenizer: PreTrainedTokenizerBase, + args: argparse.Namespace) -> List[SampleRequest]: + + random_lora_request_gen = None + if args.enable_lora: + assert args.model == 'meta-llama/Llama-2-7b-hf', "hardcoded lora adapter only works for this model" + random_lora_request_gen = random_lora_request() + + dataset_path: str = args.dataset + num_requests: int = args.num_prompts + fixed_output_len: Optional[int] = args.output_len + model: str = args.model + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. 
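+    # (Each dataset entry is expected to be a dict with a "conversations"
+    # list of {..., "value": ...} turns; see the --dataset help text below.)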
+ dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[SampleRequest] = [] + for data in dataset: + if len(filtered_dataset) == num_requests: + break + + # Only keep the first two turns of each conversation. + prompt = data["conversations"][0]["value"] + completion = data["conversations"][1]["value"] + + multi_modal_data: Optional[MultiModalDataDict] = None + if "image" in data: + multi_modal_data = multi_modal_data or {} + image_path = data["image"] + # TODO(vllm-project/vllm/issues/9778): Support multiple images. + assert isinstance(image_path, + str), "Only support single image input" + try: + multi_modal_data["image"] = Image.open(image_path).convert( + "RGB") + except FileNotFoundError: + # Ignore datapoint where asset is missing + continue + prompt = _get_prompt_for_image_model(question=prompt, model=model) + + # Tokenize the prompts and completions. + prompt_token_ids = tokenizer(prompt).input_ids + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append( + SampleRequest(prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + multi_modal_data=multi_modal_data, + lora_request= next(random_lora_request_gen) if args.enable_lora else None)) + + return filtered_dataset + + +def run_vllm( + requests: List[SampleRequest], + n: int, + engine_args: EngineArgs, +) -> Tuple[float, Optional[List[RequestOutput]]]: + from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) + + # Add the requests to the engine. + prompts: List[TextPrompt] = [] + sampling_params: List[SamplingParams] = [] + lora_requests: List[LoRARequest] = [] + for request in requests: + prompts.append( + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, + ignore_eos=True, + max_tokens=request.expected_output_len, + )) + lora_requests.append(request.lora_request) + + use_beam_search = False + + outputs = None + if not use_beam_search: + start = time.perf_counter() + outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True) + end = time.perf_counter() + else: + assert all([x is None for x in lora_requests]), "Lora requests not supported in beam-search API" + prompts = [request.prompt for request in requests] + # output_len should be the same for all requests. 
+ output_len = requests[0][2] + for request in requests: + assert request.expected_output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() + return end - start, outputs + + +async def run_vllm_async( + requests: List[SampleRequest], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> Tuple[float, Optional[List[RequestOutput]]]: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + + # Add the requests to the engine. + prompts: List[TextPrompt] = [] + sampling_params: List[SamplingParams] = [] + lora_requests: List[LoRARequest] = [] + for request in requests: + prompts.append( + TextPrompt(prompt=request.prompt, + multi_modal_data=request.multi_modal_data)) + sampling_params.append( + SamplingParams( + n=n, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, + ignore_eos=True, + max_tokens=request.expected_output_len, + )) + lora_requests.append(request.lora_request) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp, lora_request) in enumerate(zip(prompts, sampling_params, lora_requests)): + generator = llm.generate(prompt, sp, request_id=f"test{i}", lora_request=lora_request) + generators.append(generator) + all_gens = merge_async_iterators(*generators) + outputs_dict = {} + async for i, res in all_gens: + outputs_dict[i] = res + end = time.perf_counter() + + num_prompts = len(prompts) + outputs = [] + for i in range(num_prompts): + outputs.append(outputs_dict[i]) + + return end - start, outputs + + +def run_hf( + requests: List[SampleRequest], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + max_batch_size: int, + trust_remote_code: bool, +) -> float: + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: List[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt, prompt_len, output_len = requests[i] + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + _, next_prompt_len, next_output_len = requests[i + 1] + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=True, + num_return_sequences=n, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, + use_cache=True, + max_new_tokens=max_output_len, + ) + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. 
+ batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def run_mii( + requests: List[SampleRequest], + model: str, + tensor_parallel_size: int, + output_len: int, +) -> float: + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) + prompts = [request.prompt for request in requests] + + start = time.perf_counter() + llm.generate(prompts, max_new_tokens=output_len) + end = time.perf_counter() + client = client(model) + client.terminate_server() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. + for i in range(-10, 10): + prompt = "hi " * (args.input_len + i) + tokenized_prompt = tokenizer(prompt).input_ids + if len(tokenized_prompt) == args.input_len: + break + else: + raise ValueError( + f"Failed to synthesize a prompt with {args.input_len} tokens.") + requests = [ + SampleRequest(prompt=prompt, + prompt_len=args.input_len, + expected_output_len=args.output_len) + for _ in range(args.num_prompts) + ] + else: + requests = sample_requests(tokenizer, args) + + is_multi_modal = any(request.multi_modal_data is not None + for request in requests) + if args.backend == "vllm": + if args.async_engine: + elapsed_time, outputs = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + )) + else: + elapsed_time, outputs = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) + + if args.pickle_outputs: + print("Pickling request outputs : ") + with open("outputs.pkl", "wb+") as f: + pickle.dump(outputs, f) + + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.hf_max_batch_size, args.trust_remote_code) + elif args.backend == "mii": + elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, + args.output_len) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(request.prompt_len + request.expected_output_len + for request in requests) + total_output_tokens = sum(request.expected_output_len + for request in requests) + if is_multi_modal: + print("\033[91mWARNING\033[0m: Multi-modal request detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details.") + # TODO(vllm-project/vllm/issues/9778): Count molti-modal token length. 
+ print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset. The dataset is expected to " + "be a json in form of List[Dict[..., conversations: " + "List[Dict[..., value: ]]]]") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument("--pickle-outputs", + action="store_true", + default=False, + help="Pickle outputs got from benchmark") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + if args.enable_lora: + print (f" --model : {args.model}") + print (f" --max-loras : {args.max_loras}") + print (f" --max-lora-rank : {args.max_lora_rank}") + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + elif args.backend == "hf": + if args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") + main(args) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ba50a9786d805..d076f342047bc 100644 
--- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.v1.engine import EngineCoreOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +from vllm.lora.request import LoRARequest if TYPE_CHECKING: from vllm.multimodal import MultiModalKwargs @@ -30,8 +31,6 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config - # TODO: Support LoRA. - assert lora_config is None, "V1 does not support LoRA yet." num_gpu_blocks = cache_config.num_gpu_blocks assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0 @@ -170,6 +169,15 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Record the LoRAs in scheduled_running_reqs + requested_loras: set[int] = set() + if self.lora_config: + requested_loras = \ + set(req.lora_request.lora_int_id \ + for req in scheduled_running_reqs \ + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(requested_loras) <= self.lora_config.max_loras + # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: @@ -181,6 +189,17 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request: + req_lora_id = request.lora_request.lora_int_id + if len(requested_loras) == self.lora_config.max_loras and \ + req_lora_id not in requested_loras: + # cannot schedule + break + requested_loras.add(req_lora_id) + # Get already-cached tokens. computed_blocks = self.kv_cache_manager.get_computed_blocks( request) @@ -513,6 +532,7 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int + lora_request: Optional[LoRARequest] @classmethod def from_request( @@ -530,6 +550,7 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, ) diff --git a/vllm/v1/worker/cached_request_state.py b/vllm/v1/worker/cached_request_state.py new file mode 100644 index 0000000000000..142137c1fd1a1 --- /dev/null +++ b/vllm/v1/worker/cached_request_state.py @@ -0,0 +1,43 @@ + +from dataclasses import dataclass +from typing import List, Optional, TYPE_CHECKING + +from vllm.multimodal import MultiModalKwargs +from vllm.sampling_params import SamplingParams +import torch +from vllm.v1.core.scheduler import RunningRequestData, ResumedRequestData +from vllm.lora.request import LoRARequest + + +if TYPE_CHECKING: + from vllm.multimodal.inputs import PlaceholderRange + +@dataclass +class CachedRequestState: + + req_id: str + prompt_token_ids: List[int] + prompt: Optional[str] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] + sampling_params: SamplingParams + generator: Optional[torch.Generator] + + block_ids: List[int] + num_computed_tokens: int + output_token_ids: List[int] + + lora_request: Optional[LoRARequest] + + @property + def num_tokens(self) -> int: + return len(self.prompt_token_ids) + len(self.output_token_ids) + + def update_from_running_request_data(self, req_data: RunningRequestData) -> None: + self.num_computed_tokens = req_data.num_computed_tokens + if len(req_data.new_block_ids): + self.block_ids.extend(req_data.new_block_ids) + + def update_from_resumed_request_data(self, req_data: ResumedRequestData) -> None: + 
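+        # Resumed (previously preempted) requests get a fresh block table from
+        # the scheduler, so block_ids is overwritten rather than extended.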
self.num_computed_tokens = req_data.num_computed_tokens + self.block_ids = req_data.block_ids \ No newline at end of file diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index eebd1de96537f..cb207c58d8891 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,6 +2,7 @@ import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple +import pickle as pkl import numpy as np import torch @@ -26,6 +27,12 @@ FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.cached_request_state import CachedRequestState +from vllm.v1.worker.request_batch import RequestBatch +from vllm.v1.worker.model_runner_device_tensors import ModelRunnerDeviceTensors +from vllm.v1.worker.lora_request_batch import LoRARequestBatch + +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.multimodal.inputs import PlaceholderRange @@ -34,7 +41,9 @@ logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(LoRAModelRunnerMixin): + + STEP = 0 def __init__( self, @@ -90,27 +99,28 @@ def __init__( # Request states. self.requests: Dict[str, CachedRequestState] = {} - # Persistent batch. - self.input_batch = InputBatch( + # Persistent batch. All tensors maintained by the batch are on the CPU. + request_batch_kwargs = {'max_num_reqs' : self.scheduler_config.max_num_seqs, + 'max_model_len' : self.max_model_len, + 'max_num_blocks_per_req' : self.max_num_blocks_per_req, + 'pin_memory' : self.pin_memory} + self.request_batch = LoRARequestBatch(**request_batch_kwargs) \ + if self.lora_config else RequestBatch(**request_batch_kwargs) + + # Device Tensors + self.device_tensors = ModelRunnerDeviceTensors.make( max_num_reqs=self.scheduler_config.max_num_seqs, - max_model_len=self.max_model_len, + max_num_tokens = self.max_num_tokens, max_num_blocks_per_req=self.max_num_blocks_per_req, - device=self.device, - pin_memory=self.pin_memory, - ) + input_embeds_hidden_size=self.hidden_size, + input_embeds_dtype=self.dtype, + device=self.device) self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)] - self.positions = torch.zeros(self.max_num_tokens, - dtype=torch.int64, - device=self.device) - self.inputs_embeds = torch.zeros( - (self.max_num_tokens, self.hidden_size), - dtype=self.dtype, - device=self.device) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. @@ -132,32 +142,19 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: scheduler_output.preempted_req_ids, scheduler_output.finished_req_ids, ) - removed_req_indices: List[int] = [] - for req_id in stopped_req_ids: - req_index = self.input_batch.remove_request(req_id) - if req_index is not None: - removed_req_indices.append(req_index) + self.request_batch.remove_requests(stopped_req_ids) # Update the states of the running requests. for req_data in scheduler_output.scheduled_running_reqs: req_id = req_data.req_id req_state = self.requests[req_id] - req_index = self.input_batch.req_id_to_index[req_id] - - # Update the num_computed_tokens. 
-            req_state.num_computed_tokens = req_data.num_computed_tokens
-            self.input_batch.num_computed_tokens_cpu[req_index] = (
-                req_data.num_computed_tokens)
-
-            # Update the block table.
-            num_new_blocks = len(req_data.new_block_ids)
-            if num_new_blocks == 0:
-                continue
-            start_index = len(req_state.block_ids)
-            end_index = start_index + num_new_blocks
-            req_state.block_ids.extend(req_data.new_block_ids)
-            self.input_batch.block_table_cpu[
-                req_index, start_index:end_index] = req_data.new_block_ids
+
+            # Grab num_existing_block_ids from the request state before
+            # updating it.
+            num_existing_block_ids = len(req_state.block_ids)
+
+            req_state.update_from_running_request_data(req_data)
+            self.request_batch.update_states(req_id, req_data, num_existing_block_ids)

         req_ids_to_add: List[str] = []
         # Add new requests to the cached states.
@@ -181,7 +178,9 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 block_ids=req_data.block_ids,
                 num_computed_tokens=req_data.num_computed_tokens,
                 output_token_ids=[],
+                lora_request=req_data.lora_request,
             )
+
             req_ids_to_add.append(req_id)

         # Update the cached states of the resumed requests.
@@ -189,135 +188,90 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_id = req_data.req_id
             req_state = self.requests[req_id]
-            req_state.block_ids = req_data.block_ids
-            req_state.num_computed_tokens = req_data.num_computed_tokens
+            req_state.update_from_resumed_request_data(req_data)
             req_ids_to_add.append(req_id)

-        # Add the new or resumed requests to the persistent batch.
-        # The smaller empty indices are filled first.
-        removed_req_indices = sorted(removed_req_indices, reverse=True)
         for req_id in req_ids_to_add:
             req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
-            self.input_batch.add_request(req_state, req_index)
+            self.request_batch.add_request(req_state)

         # Condense the batched states if there are empty indices.
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
+        self.request_batch.condense()
+
+    def _prepare_inputs(self, num_scheduled_tokens: np.array,
+                        total_num_scheduled_tokens: Optional[int] = None) \
+            -> Tuple[torch.Tensor, FlashAttentionMetadata, torch.Tensor]:
+        """
+        Prepare model inputs such as input_token_ids, attention metadata, etc.
+        This function triggers async CPU-to-GPU transfers of some device
+        tensors, such as block_table, positions and slot_mapping.
+
+        Args:
+            num_scheduled_tokens (np.array): Numpy array containing the number
+                of tokens scheduled to be processed for every request in
+                self.request_batch. Note that num_scheduled_tokens[i] must
+                correspond to the ith request in the request batch.
+
+            total_num_scheduled_tokens (Optional[int]): Total number of tokens
+                scheduled to be computed. This must be equal to
+                np.sum(num_scheduled_tokens) and is an optional parameter
+                provided as an optimization.
+
+        Returns:
+            input_token_ids: Token ids from the scheduled requests.
+                input_token_ids contains as many values as
+                total_num_scheduled_tokens.
+            attention_metadata: FlashAttentionMetadata
+            logits_indices: logits indices for model output readout.
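+                These are computed as query_start_loc[1:] - 1, i.e. the index
+                of the last scheduled token of every request in the batch.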
+ """ + + if not total_num_scheduled_tokens: + total_num_scheduled_tokens = np.sum(num_scheduled_tokens) - def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): - total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens assert total_num_scheduled_tokens > 0 - num_reqs = self.input_batch.num_reqs + num_reqs: int = self.request_batch.num_reqs() assert num_reqs > 0 - # OPTIMIZATION: Start copying the block table first. - # This way, we can overlap the copy with the following CPU operations. - self.input_batch.block_table[:num_reqs].copy_( - self.input_batch.block_table_cpu_tensor[:num_reqs], - non_blocking=True) - - # Get the number of scheduled tokens for each request. - # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] - max_num_scheduled_tokens = 0 - for req_id in self.input_batch.req_ids[:num_reqs]: - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) - max_num_scheduled_tokens = max(max_num_scheduled_tokens, - num_tokens) - num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) - assert max_num_scheduled_tokens > 0 - - # Get request indices. - # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] - indices = np.arange(num_reqs) - req_indices = np.repeat(indices, num_scheduled_tokens) - - # Get batched arange. - # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), - (num_reqs, 1)) - mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] - arange = arange_matrix[mask] - - # Get positions. - positions = torch.empty((total_num_scheduled_tokens, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - positions_np = positions.numpy() - np.add(self.input_batch.num_computed_tokens_cpu[req_indices], - arange, - out=positions_np) - - # Get token indices. - # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] - # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] - # where M is the max_model_len. - token_indices = positions_np + req_indices * self.max_model_len - token_indices = torch.from_numpy(token_indices) - input_ids = torch.empty((total_num_scheduled_tokens, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - torch.index_select(torch.from_numpy( - self.input_batch.token_ids_cpu).flatten(), - 0, - token_indices, - out=input_ids) - - # Calculate the slot mapping. - block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ - token_indices // self.block_size] - block_offsets = token_indices % self.block_size - slot_mapping = torch.empty((total_num_scheduled_tokens, ), - dtype=torch.int32, - device="cpu", - pin_memory=self.pin_memory) - torch.add(block_numbers * self.block_size, - block_offsets, - out=slot_mapping) - - # Prepare the attention metadata. 
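+        # The persistent batch (self.request_batch) owns the CPU-side buffers;
+        # it fills in token ids, positions and the slot mapping and kicks off
+        # non-blocking copies into the device tensors passed below.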
+ # Prepare gpu tensors + self.request_batch.prepare_inputs(num_scheduled_tokens, self.block_size, + block_table_device_tensor=self.device_tensors.block_table, + input_tokens_device_tensor=self.device_tensors.input_tokens, + input_positions_device_tensor=self.device_tensors.input_positions, + slot_mapping_device_tensor=self.device_tensors.slot_mapping) + + ## Prepare attention meta + seq_lens_np: np.array = self.request_batch.make_seq_lens_tensor(num_scheduled_tokens) + + ## Query start loc query_start_loc = torch.empty((num_reqs + 1, ), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) query_start_loc_np = query_start_loc.numpy() query_start_loc_np[0] = 0 + # make a numpy array np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) - seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + - num_scheduled_tokens) - max_seq_len = seq_lens.max() + ## Seq start loc seq_start_loc = torch.empty((num_reqs + 1, ), dtype=torch.int32, device="cpu", pin_memory=self.pin_memory) seq_start_loc_np = seq_start_loc.numpy() seq_start_loc_np[0] = 0 - np.cumsum(seq_lens, out=seq_start_loc_np[1:]) + np.cumsum(seq_lens_np, out=seq_start_loc_np[1:]) + + max_seq_len = np.max(seq_lens_np) + max_num_scheduled_tokens = np.max(num_scheduled_tokens) - input_ids = input_ids.to(self.device, non_blocking=True) - self.positions[:total_num_scheduled_tokens].copy_(positions, - non_blocking=True) query_start_loc = query_start_loc.to(self.device, non_blocking=True) seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) - slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() attn_metadata = FlashAttentionMetadata( num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, query_start_loc=query_start_loc, max_seq_len=max_seq_len, seq_start_loc=seq_start_loc, - block_table=self.input_batch.block_table[:num_reqs], - slot_mapping=slot_mapping, + block_table=self.device_tensors.block_table[:num_reqs], + slot_mapping=self.device_tensors.slot_mapping[:total_num_scheduled_tokens], ) # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this @@ -325,7 +279,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # token from the partial request. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 - return input_ids, attn_metadata, logits_indices + return self.device_tensors.input_tokens[:total_num_scheduled_tokens], attn_metadata, logits_indices def _prepare_sampling( self, @@ -339,7 +293,7 @@ def _prepare_sampling( or scheduler_output.scheduled_resumed_reqs): skip_copy = False # Create the sampling metadata. 
- sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) + sampling_metadata = self.request_batch.make_sampling_metadata(self.device_tensors.sampling_tensors, skip_copy) return sampling_metadata def _execute_encoder(self, scheduler_output: "SchedulerOutput"): @@ -379,8 +333,7 @@ def _gather_encoder_outputs( scheduler_output: "SchedulerOutput", ) -> List[torch.Tensor]: encoder_outputs: List[torch.Tensor] = [] - num_reqs = self.input_batch.num_reqs - for req_id in self.input_batch.req_ids[:num_reqs]: + for req_id in self.request_batch.request_ids(): num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ req_id] req_state = self.requests[req_id] @@ -413,11 +366,64 @@ def _gather_encoder_outputs( encoder_outputs.append(encoder_output[start_idx:end_idx]) return encoder_outputs + def dump_data(self, positions: torch.Tensor, input_ids: torch.Tensor, attn_metadata: FlashAttentionMetadata): + + print ("data dump : \n") + print (f" - input ids : {input_ids.shape} {input_ids.dtype} {input_ids}") + print (f" - positions : {positions.shape} {positions.dtype} {positions}") + print (f" - Flashattn : \n") + print (f" - num_actual_tokens : {attn_metadata.num_actual_tokens}") + print (f" - max_query_len : {attn_metadata.max_query_len}") + print (f" - query_start_loc : {attn_metadata.query_start_loc}") + print (f" - max_seq_len : {attn_metadata.max_seq_len}") + print (f" - seq_start_loc : {attn_metadata.seq_start_loc}") + print (f" - block_table : {attn_metadata.block_table}") + print (f" - slot mapping : {attn_metadata.slot_mapping}") + return + + # data tuple + #data = (input_ids.cpu().numpy(), positions.cpu().numpy(), + # attn_metadata.num_actual_tokens, + # attn_metadata.max_query_len, + # attn_metadata.query_start_loc.cpu().numpy(), + # attn_metadata.max_seq_len, + # attn_metadata.seq_start_loc.cpu().numpy(), + # attn_metadata.block_table.cpu().numpy(), + # attn_metadata.slot_mapping.cpu().numpy(), + # self.input_batch.num_reqs) + + main_data = None + fname = f"./dump/main_{self.STEP}.pkl" + with open(fname, "rb") as f: + main_data = pkl.load(f) + + main_input_ids, main_positions, main_actual_num_tokens, \ + main_max_query_len, main_query_start_loc, main_max_seq_len, \ + main_seq_start_loc, main_block_table, main_slot_mapping, main_num_reqs = main_data + + ## Tests + assert main_num_reqs == self.request_batch.num_reqs() + assert main_actual_num_tokens == attn_metadata.num_actual_tokens + assert main_max_query_len == attn_metadata.max_query_len + assert main_max_seq_len == attn_metadata.max_seq_len + + assert np.allclose(main_query_start_loc, attn_metadata.query_start_loc.cpu().numpy()) + assert np.allclose(main_seq_start_loc, attn_metadata.seq_start_loc.cpu().numpy()) + assert np.allclose(main_positions[:main_actual_num_tokens], positions.cpu().numpy()[:main_actual_num_tokens]) + assert np.allclose(main_input_ids, input_ids.cpu().numpy()) + assert np.allclose(main_slot_mapping, attn_metadata.slot_mapping.cpu().numpy()) + + assert np.allclose(main_block_table[:main_num_reqs], attn_metadata.block_table.cpu().numpy()[:main_num_reqs]) + + + self.STEP += 1 + @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: + self._update_states(scheduler_output) # Run the encoder. @@ -425,8 +431,18 @@ def execute_model( encoder_outputs = self._gather_encoder_outputs(scheduler_output) # Prepare the decoder inputs. 
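+        # Gather per-request token counts in request_batch order, so that
+        # num_scheduled_tokens[i] corresponds to the i-th request in the
+        # persistent batch, as _prepare_inputs expects.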
+ num_scheduled_tokens: List[int] = [] + for req_id in self.request_batch.request_ids(): + num_scheduled_tokens.append(scheduler_output.num_scheduled_tokens[req_id]) + num_scheduled_tokens : np.array = np.array(num_scheduled_tokens) + input_ids, attn_metadata, logits_indices = self._prepare_inputs( - scheduler_output) + num_scheduled_tokens, scheduler_output.total_num_scheduled_tokens) + + # hot-swap lora model + if self.lora_config: + self.set_activte_loras(self.request_batch, num_scheduled_tokens) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -447,17 +463,19 @@ def execute_model( # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings), # always use embeddings (rather than token ids) as input to the model. # TODO(woosuk): Avoid the copy. Optimize. - self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + self.device_tensors.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + + #self.dump_data(self.device_tensors.input_positions, input_ids, attn_metadata) # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata): hidden_states = self.model( input_ids=None, - positions=self.positions[:num_input_tokens], + positions=self.device_tensors.input_positions[:num_input_tokens], kv_caches=self.kv_caches, attn_metadata=None, - inputs_embeds=self.inputs_embeds[:num_input_tokens], + inputs_embeds=self.device_tensors.inputs_embeds[:num_input_tokens], ) hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[logits_indices] @@ -475,24 +493,20 @@ def execute_model( sampled_token_ids_list = sampled_token_ids.tolist() # TODO(woosuk): The following loop can be slow since it iterates over # the requests one by one. Optimize. - num_reqs = self.input_batch.num_reqs - for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]): + for req_idx, req_id in enumerate(self.request_batch.request_ids()): req_state = self.requests[req_id] seq_len = (req_state.num_computed_tokens + scheduler_output.num_scheduled_tokens[req_id]) assert seq_len <= req_state.num_tokens if seq_len == req_state.num_tokens: # Append the sampled token to the output token ids. - token_id = sampled_token_ids_list[i] - self.input_batch.token_ids_cpu[i, seq_len] = token_id + token_id = sampled_token_ids_list[req_idx] + self.request_batch.append_token_id(req_id, token_id, seq_len) req_state.output_token_ids.append(token_id) else: # Ignore the sampled token from the partial request. # Rewind the generator state as if the token was not sampled. 
- generator = self.input_batch.generators.get(i) - if generator is not None: - # This relies on cuda-specific torch-internal impl details - generator.set_offset(generator.get_offset() - 4) + self.request_batch.rewind_generator(req_id) if sampler_output.logprob_token_ids is None: logprob_token_ids = None @@ -502,9 +516,11 @@ def execute_model( logprobs = None else: logprobs = sampler_output.logprobs.cpu() + + req_id_to_index: Dict[str, int] = {req_id : idx for idx, req_id in enumerate(self.request_batch.request_ids())} model_runner_output = ModelRunnerOutput( - req_ids=self.input_batch.req_ids[:num_reqs], - req_id_to_index=self.input_batch.req_id_to_index, + req_ids=self.request_batch.request_ids(), + req_id_to_index=req_id_to_index, sampled_token_ids_cpu=sampled_token_ids, logprob_token_ids_cpu=logprob_token_ids, logprobs_cpu=logprobs, @@ -529,6 +545,8 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model() self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -550,10 +568,10 @@ def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: with set_compile_context(self.cudagraph_batch_sizes): # Trigger compilation for general shape. model(input_ids=None, - positions=self.positions, + positions=self.device_tensors.input_positions, kv_caches=dummy_kv_caches, attn_metadata=None, - inputs_embeds=self.inputs_embeds) + inputs_embeds=self.device_tensors.inputs_embeds) @torch.inference_mode() def profile_run(self) -> None: @@ -581,10 +599,10 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): self.model( input_ids=None, - positions=self.positions[:num_tokens], + positions=self.device_tensors.input_positions[:num_tokens], kv_caches=self.kv_caches, attn_metadata=None, - inputs_embeds=self.inputs_embeds[:num_tokens], + inputs_embeds=self.device_tensors.inputs_embeds[:num_tokens], ) end_time = time.perf_counter() @@ -611,269 +629,3 @@ def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: if batch_size <= size: return size return None - - -@dataclass -class CachedRequestState: - - req_id: str - prompt_token_ids: List[int] - prompt: Optional[str] - mm_inputs: List[MultiModalKwargs] - mm_positions: List["PlaceholderRange"] - sampling_params: SamplingParams - generator: Optional[torch.Generator] - - block_ids: List[int] - num_computed_tokens: int - output_token_ids: List[int] - - @property - def num_tokens(self) -> int: - return len(self.prompt_token_ids) + len(self.output_token_ids) - - -class InputBatch: - - def __init__( - self, - max_num_reqs: int, - max_model_len: int, - max_num_blocks_per_req: int, - device: torch.device, - pin_memory: bool, - ): - self.max_num_reqs = max_num_reqs - self.max_model_len = max_model_len - self.max_num_blocks_per_req = max_num_blocks_per_req - self.device = device - self.pin_memory = pin_memory - - self.req_ids: List[Optional[str]] = [None] * max_num_reqs - self.req_id_to_index: Dict[str, int] = {} - - self.token_ids_cpu = np.empty((max_num_reqs, max_model_len), - dtype=np.int32) - self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) - - # Attention-related. 
- self.block_table = torch.zeros((max_num_reqs, max_num_blocks_per_req), - device=self.device, - dtype=torch.int32) - self.block_table_cpu_tensor = torch.zeros( - (max_num_reqs, max_num_blocks_per_req), - device="cpu", - dtype=torch.int32, - pin_memory=pin_memory, - ) - self.block_table_cpu = self.block_table_cpu_tensor.numpy() - - # Sampling-related. - self.temperature = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.temperature_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.temperature_cpu = self.temperature_cpu_tensor.numpy() - self.greedy_reqs: Set[str] = set() - self.random_reqs: Set[str] = set() - - self.top_p = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) - self.top_p_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=pin_memory) - self.top_p_cpu = self.top_p_cpu_tensor.numpy() - self.top_p_reqs: Set[str] = set() - - self.top_k = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device=device) - self.top_k_cpu_tensor = torch.empty((max_num_reqs, ), - dtype=torch.int32, - device="cpu", - pin_memory=pin_memory) - self.top_k_cpu = self.top_k_cpu_tensor.numpy() - self.top_k_reqs: Set[str] = set() - - # req_index -> generator - self.generators: Dict[int, torch.Generator] = {} - - self.num_logprobs: Dict[str, int] = {} - self.prompt_logprob_reqs: Set[str] = set() - - def add_request( - self, - request: "CachedRequestState", - req_index: Optional[int] = None, - ) -> None: - if req_index is None: - req_index = self.num_reqs - assert req_index < self.max_num_reqs - - req_id = request.req_id - self.req_ids[req_index] = req_id - self.req_id_to_index[req_id] = req_index - - # Copy the prompt token ids and output token ids. 
- num_prompt_tokens = len(request.prompt_token_ids) - self.token_ids_cpu[ - req_index, :num_prompt_tokens] = request.prompt_token_ids - start_idx = num_prompt_tokens - end_idx = start_idx + len(request.output_token_ids) - self.token_ids_cpu[req_index, - start_idx:end_idx] = request.output_token_ids - - self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens - num_blocks = len(request.block_ids) - self.block_table_cpu[req_index, :num_blocks] = request.block_ids - - sampling_params = request.sampling_params - self.temperature_cpu[req_index] = sampling_params.temperature - if sampling_params.sampling_type == SamplingType.GREEDY: - self.greedy_reqs.add(req_id) - else: - self.random_reqs.add(req_id) - - self.top_p_cpu[req_index] = sampling_params.top_p - if sampling_params.top_p < 1: - self.top_p_reqs.add(req_id) - self.top_k_cpu[req_index] = sampling_params.top_k - if sampling_params.top_k > 0: - self.top_k_reqs.add(req_id) - - self.generators[req_index] = request.generator - - num_logprobs = sampling_params.logprobs - if num_logprobs is not None and num_logprobs > 0: - self.num_logprobs[req_id] = num_logprobs - if sampling_params.prompt_logprobs: - self.prompt_logprob_reqs.add(req_id) - - def remove_request(self, req_id: str) -> Optional[int]: - req_index = self.req_id_to_index.pop(req_id, None) - if req_index is None: - return None - self.req_ids[req_index] = None - - self.greedy_reqs.discard(req_id) - self.random_reqs.discard(req_id) - self.top_p_reqs.discard(req_id) - self.top_k_reqs.discard(req_id) - self.generators.pop(req_index, None) - self.num_logprobs.pop(req_id, None) - self.prompt_logprob_reqs.discard(req_id) - return req_index - - def clear(self) -> None: - self.req_ids = [None] * self.max_num_reqs - self.req_id_to_index.clear() - self.greedy_reqs.clear() - self.random_reqs.clear() - self.top_p_reqs.clear() - self.top_k_reqs.clear() - self.generators.clear() - self.num_logprobs.clear() - self.prompt_logprob_reqs.clear() - - def condense(self, empty_req_indices: List[int]) -> None: - if self.num_reqs == 0: - # The batched states are empty. - return - - # NOTE(woosuk): This function assumes that the empty_req_indices - # is sorted in descending order. - last_req_index = self.num_reqs + len(empty_req_indices) - 1 - while empty_req_indices: - # Find the largest non-empty index. - while last_req_index in empty_req_indices: - last_req_index -= 1 - - # Find the smallest empty index. - empty_index = empty_req_indices.pop() - if empty_index >= last_req_index: - break - - # Swap the states. - req_id = self.req_ids[last_req_index] - self.req_ids[empty_index] = req_id - self.req_ids[last_req_index] = None - self.req_id_to_index[req_id] = empty_index - - # TODO(woosuk): Optimize the copy of token_ids_cpu and - # block_table_cpu. - self.token_ids_cpu[empty_index] = self.token_ids_cpu[ - last_req_index] - self.num_computed_tokens_cpu[ - empty_index] = self.num_computed_tokens_cpu[last_req_index] - self.block_table_cpu[empty_index] = self.block_table_cpu[ - last_req_index] - self.temperature_cpu[empty_index] = self.temperature_cpu[ - last_req_index] - self.top_p_cpu[empty_index] = self.top_p_cpu[last_req_index] - self.top_k_cpu[empty_index] = self.top_k_cpu[last_req_index] - generator = self.generators.pop(last_req_index, None) - if generator is not None: - self.generators[empty_index] = generator - - # Decrement last_req_index since it is now empty. 
- last_req_index -= 1 - - def make_sampling_metadata( - self, - skip_copy: bool = False, - ) -> SamplingMetadata: - if not skip_copy: - self.temperature[:self.num_reqs].copy_( - self.temperature_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_p[:self.num_reqs].copy_( - self.top_p_cpu_tensor[:self.num_reqs], non_blocking=True) - self.top_k[:self.num_reqs].copy_( - self.top_k_cpu_tensor[:self.num_reqs], non_blocking=True) - return SamplingMetadata( - temperature=self.temperature[:self.num_reqs], - all_greedy=self.all_greedy, - all_random=self.all_random, - top_p=self.top_p[:self.num_reqs], - top_k=self.top_k[:self.num_reqs], - no_top_p=self.no_top_p, - no_top_k=self.no_top_k, - generators=self.generators, - max_num_logprobs=self.max_num_logprobs, - ) - - @property - def num_reqs(self) -> int: - return len(self.req_id_to_index) - - @property - def all_greedy(self) -> bool: - return len(self.random_reqs) == 0 - - @property - def all_random(self) -> bool: - return len(self.greedy_reqs) == 0 - - @property - def no_top_p(self) -> bool: - return len(self.top_p_reqs) == 0 - - @property - def no_top_k(self) -> bool: - return len(self.top_k_reqs) == 0 - - @property - def max_num_logprobs(self) -> int: - return max(self.num_logprobs.values()) if self.num_logprobs else 0 - - @property - def no_logprob(self) -> bool: - return len(self.num_logprobs) == 0 - - @property - def no_prompt_logprob(self) -> bool: - return len(self.prompt_logprob_reqs) == 0 diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 0000000000000..6d4266d172017 --- /dev/null +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,65 @@ +""" +Define LoRA adapter for model runner. +""" + +from typing import List + +from vllm.v1.worker.lora_request_batch import LoRARequestBatch + +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.logger import init_logger + +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.v1.core.scheduler import SchedulerOutput + +import torch +import numpy as np +import torch.nn as nn + +logger = init_logger(__name__) + +# Defined as a mixin for GPUModelRunner +class LoRAModelRunnerMixin: + + # TODO (varun) : self is untyped. This has ide code completion issues and + # could potentially lead to bugs. + def load_lora_model(self) -> nn.Module: + + assert supports_lora( + self.model + ), f"{self.model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(self.model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. 
+ if hasattr(self.model.config, "max_position_embeddings"): + max_pos_embeddings = self.model.config.max_position_embeddings + else: + max_pos_embeddings = ( + self.model.config.text_config.max_position_embeddings) + + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, + self.model_config.get_vocab_size(), + self.lora_config, + self.device, + self.model.embedding_modules, + self.model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + return self.lora_manager.create_lora_manager(self.model) + + def set_activte_loras(self, + request_batch: LoRARequestBatch, + num_scheduled_tokens: np.array) -> None: + + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + + lora_mapping, lora_requests = request_batch.prepare_lora_inputs(num_scheduled_tokens) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) diff --git a/vllm/v1/worker/lora_request_batch.py b/vllm/v1/worker/lora_request_batch.py new file mode 100644 index 0000000000000..18525d0185645 --- /dev/null +++ b/vllm/v1/worker/lora_request_batch.py @@ -0,0 +1,304 @@ +import torch +import numpy as np + +from typing import Dict, TypeAlias, Optional, List, Tuple + +from vllm.v1.worker.request_batch_base import RequestBatchAbstract +from vllm.v1.worker.request_batch import RequestBatch +from vllm.v1.worker.cached_request_state import CachedRequestState +from vllm.v1.core.scheduler import RunningRequestData +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.model_runner_device_tensors import ModelRunnerDeviceSamplingTensors +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest + +LoRAID: TypeAlias = int + +class LoRARequestBatch(RequestBatchAbstract): + """ + LoRARequestBatch maintains one RequestBatch object for all possible + LoRA IDs. + The Create, Update, Delete methods dispatch to the RequestBatch corresponding + to the request LoRA ID. + The Read methods, combine / collate information from difference RequestBatch + objects. + """ + + # Assume LoRA IDs are greater than 0. + NO_LORA_ID: LoRAID = 0 + + def _make_request_batch(self) -> RequestBatch: + return RequestBatch(self.max_num_reqs, + self.max_model_len, + self.max_num_blocks_per_req, + self.pin_memory) + + def _get_lora_id_from_request(self, request: CachedRequestState) -> LoRAID: + + if request.lora_request is None: + return self.NO_LORA_ID + + lora_id: LoRAID = request.lora_request.lora_int_id + + assert lora_id != self.NO_LORA_ID, \ + (f"LoRA request ID cannot be equal to NO_LORA_ID" + f"({self.NO_LORA_ID})") + + # Each lora_id gets it own batch + return lora_id + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + pin_memory: bool, + ): + super().__init__(max_num_reqs, max_model_len, max_num_blocks_per_req, pin_memory) + + self.req_id_to_lora_id: Dict[str, LoRAID] = {} + self.lora_id_to_batch: Dict[LoRAID, RequestBatch] = \ + {self.NO_LORA_ID: self._make_request_batch()} + self.lora_id_to_lora_request: Dict[LoRAID, LoRARequest] = {} + + def remove_requests(self, req_ids: List[str]) -> None: + for req_id in req_ids: + lora_id: LoRAID = self.req_id_to_lora_id[req_id] + self.lora_id_to_batch[lora_id].remove_requests([req_id]) + self.req_id_to_lora_id.pop(req_id) + + def add_request(self, + request: "CachedRequestState") -> None: + """ + Add the new or resumed requests to the persistent batch. 
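+        The request is routed to the per-LoRA-ID sub-batch, which is created
+        lazily the first time that LoRA ID is seen.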
+ """ + lora_id: LoRAID = self._get_lora_id_from_request(request) + if lora_id not in self.lora_id_to_batch: + self.lora_id_to_batch[lora_id] = self._make_request_batch() + self.lora_id_to_lora_request[lora_id] = request.lora_request + + # Requests with the same LoRA ID must have the same LoRA request + assert self.lora_id_to_lora_request[lora_id] == request.lora_request, \ + ("Encountered 2 different LoRA requests with the same LoRA ID" + f"LoRA request A : {self.lora_id_to_lora_request[lora_id]}" + f"LoRA request B : {request.lora_request}") + + self.lora_id_to_batch[lora_id].add_request(request) + self.req_id_to_lora_id[request.req_id] = lora_id + + def clear(self) -> None: + for batch in self.lora_id_to_batch.values(): + batch.clear() + + def is_condensed(self) -> bool: + return all([batch.is_condensed() for batch in self.lora_id_to_batch.values()]) + + def sanity_check(self, expected_num_reqs: int) -> None: + assert self.is_condensed() + assert self.num_reqs() == expected_num_reqs + req_ids = set(self.request_ids()) + assert len(req_ids) == self.num_reqs() + req_id_to_index_map = self.request_id_to_index() + assert len(req_id_to_index_map) == expected_num_reqs + assert len(set(req_id_to_index_map.values())) == expected_num_reqs + + def condense(self) -> None: + for batch in self.lora_id_to_batch.values(): + batch.condense() + + def update_states(self, + request_id: str, + request_data: RunningRequestData, + num_existing_block_ids: int) -> None: + lora_id: LoRAID = self.req_id_to_lora_id[request_id] + self.lora_id_to_batch[lora_id].update_states(request_id, request_data, num_existing_block_ids) + + def append_token_id(self, request_id: str, token_id: np.int32, token_idx: int): + lora_id: LoRAID = self.req_id_to_lora_id[request_id] + self.lora_id_to_batch[lora_id].append_token_id(request_id, token_id, token_idx) + + def rewind_generator(self, request_id: str): + lora_id: LoRAID = self.req_id_to_lora_id[request_id] + self.lora_id_to_batch[lora_id].rewind_generator(request_id) + + def request_ids(self) -> List[str]: + return sum([batch.request_ids() for batch in self.lora_id_to_batch.values()], []) + + def num_reqs(self) -> int: + return sum([batch.num_reqs() for batch in self.lora_id_to_batch.values()]) + + def all_greedy(self) -> bool: + return all([batch.all_greedy() for batch in self.lora_id_to_batch.values()]) + + def all_random(self) -> bool: + return all([batch.all_random() for batch in self.lora_id_to_batch.values()]) + + def no_top_p(self) -> bool: + return all([batch.no_top_p() for batch in self.lora_id_to_batch.values()]) + + def no_top_k(self) -> bool: + return all([batch.no_top_k() for batch in self.lora_id_to_batch.values()]) + + def max_num_logprobs(self) -> int: + return max([batch.max_num_logprobs() for batch in self.lora_id_to_batch.values()]) + + def no_logprob(self) -> bool: + return all([batch.no_logprob() for batch in self.lora_id_to_batch.values()]) + + def no_prompt_logprob(self) -> bool: + return all([batch.no_prompt_logprob() for batch in self.lora_id_to_batch.values()]) + + + def make_seq_lens_tensor(self, + num_scheduled_tokens: np.array) -> np.array: + + assert len(num_scheduled_tokens) == self.num_reqs() + + seq_lens_list : List[np.array] = [] + req_offset: int = 0 + for batch in self.lora_id_to_batch.values(): + batch_num_reqs = batch.num_reqs() + if batch_num_reqs == 0: + continue + seq_lens_list.append(batch.make_seq_lens_tensor(num_scheduled_tokens[req_offset : req_offset + batch_num_reqs])) + req_offset += batch_num_reqs + return 
np.concatenate(tuple(seq_lens_list)) + + def prepare_inputs(self, + num_scheduled_tokens: np.array, + block_size: int, + block_table_device_tensor: torch.Tensor, + input_tokens_device_tensor: torch.Tensor, + input_positions_device_tensor: torch.Tensor, + slot_mapping_device_tensor: torch.Tensor) -> None: + + total_num_reqs: int = self.num_reqs() + assert len(num_scheduled_tokens) == total_num_reqs, "" + total_num_scheduled_tokens: int = np.sum(num_scheduled_tokens) + + start_req_offset: int = 0 + start_token_offset: int = 0 + for batch in self.lora_id_to_batch.values(): + """ + Collate BatchInputs from all batches + """ + if batch.num_reqs() == 0: + continue + + end_req_offset = start_req_offset + batch.num_reqs() + end_token_offset = start_token_offset + np.sum(num_scheduled_tokens[start_req_offset : end_req_offset]) + + batch.prepare_inputs(num_scheduled_tokens[start_req_offset : end_req_offset], + block_size, + block_table_device_tensor[start_req_offset : end_req_offset], + input_tokens_device_tensor[start_token_offset : end_token_offset], + input_positions_device_tensor[start_token_offset : end_token_offset], + slot_mapping_device_tensor[start_token_offset : end_token_offset]) + + start_req_offset = end_req_offset + start_token_offset = end_token_offset + assert start_req_offset == total_num_reqs + assert start_token_offset == total_num_scheduled_tokens + + def make_sampling_metadata(self, + device_tensors: ModelRunnerDeviceSamplingTensors, + skip_copy: bool = False) -> SamplingMetadata: + + num_reqs: int = self.num_reqs() + + # Collate generators from batches + request_generator_map: Dict[int, torch.Generator] = {} + + start_req_offset: int = 0 + + for batch in self.lora_id_to_batch.values(): + if batch.num_reqs() == 0: + continue + + end_req_offset = start_req_offset + batch.num_reqs() + + if not skip_copy: + device_tensors.temperature[start_req_offset:end_req_offset].copy_( + batch.cpu_tensors.temperature.tensor[:batch.num_reqs()], non_blocking=True) + device_tensors.top_p[start_req_offset:end_req_offset].copy_( + batch.cpu_tensors.top_p.tensor[:batch.num_reqs()], non_blocking=True) + device_tensors.top_k[start_req_offset:end_req_offset].copy_( + batch.cpu_tensors.top_k.tensor[:batch.num_reqs()], non_blocking=True) + + batch_request_generator_map = {idx + start_req_offset : generator for idx, generator in batch.generators.items()} + request_generator_map.update(batch_request_generator_map) + + start_req_offset = end_req_offset + assert start_req_offset == num_reqs + + return SamplingMetadata( + temperature=device_tensors.temperature[:num_reqs], + all_greedy=self.all_greedy(), + all_random=self.all_random(), + top_p=device_tensors.top_p[:num_reqs], + top_k=device_tensors.top_k[:num_reqs], + no_top_p=self.no_top_p(), + no_top_k=self.no_top_k(), + generators=request_generator_map, + max_num_logprobs=self.max_num_logprobs(), + ) + + def prepare_lora_inputs(self, num_scheduled_tokens: np.array) \ + -> Tuple[LoRAMapping, set[LoRARequest]]: + """ + Construct and return LoRAMapping and the set of all LoRA Requests. + """ + def batch_num_prompt_mapping(batch: RequestBatch, + batch_req_offset: int): + if batch.no_prompt_logprob(): + return batch.num_reqs() + + batch_req_id_to_index: Dict[str, int] = batch.request_id_to_index() + + # select request indices that require prompt logprobs. Offset those + # indices with batch_req_offset so it can be used to index into + # num_scheduled_tokens. 
+            prompt_logprobs_req_indices: List[int] = [
+                batch_req_id_to_index[req_id] + batch_req_offset
+                for req_id in batch.request_ids()
+                if req_id in batch.prompt_logprob_reqs
+            ]
+
+            num_prompt_mapping: int = np.sum(num_scheduled_tokens[prompt_logprobs_req_indices])
+            num_prompt_mapping += batch.num_reqs() - len(prompt_logprobs_req_indices)
+
+            return num_prompt_mapping
+
+        num_tokens: int = np.sum(num_scheduled_tokens)
+        index_mapping: np.array = np.empty((num_tokens,), dtype=np.int32)
+        # prompt_mapping could be as big as num_tokens, depending on how many
+        # requests request prompt_logprobs
+        prompt_mapping: np.array = np.empty((num_tokens,), dtype=np.int32)
+        lora_requests: set[LoRARequest] = set()
+
+        token_offset: int = 0
+        req_offset: int = 0
+        prompt_mapping_offset: int = 0
+        for lora_id, batch in self.lora_id_to_batch.items():
+            batch_num_reqs = batch.num_reqs()
+            if batch_num_reqs == 0:
+                continue
+
+            if lora_id != self.NO_LORA_ID:
+                lora_requests.add(self.lora_id_to_lora_request[lora_id])
+
+            batch_num_tokens = np.sum(num_scheduled_tokens[req_offset:req_offset + batch_num_reqs])
+            index_mapping[token_offset : token_offset + batch_num_tokens] = lora_id
+
+            num_prompt_mapping = batch_num_prompt_mapping(batch, req_offset)
+            prompt_mapping[prompt_mapping_offset : prompt_mapping_offset + num_prompt_mapping] = lora_id
+
+            token_offset += batch_num_tokens
+            req_offset += batch_num_reqs
+            prompt_mapping_offset += num_prompt_mapping
+
+        # TODO (varun) : Is there a way to remove the cast to tuple?
+        # TODO (varun) : Not differentiating between prefill and decode for
+        #   now; needs some investigation.
+        return LoRAMapping(index_mapping = tuple(index_mapping[:token_offset]),
+                           prompt_mapping = tuple(prompt_mapping[:prompt_mapping_offset]),
+                           is_prefill=True), lora_requests
diff --git a/vllm/v1/worker/model_runner_device_tensors.py b/vllm/v1/worker/model_runner_device_tensors.py
new file mode 100644
index 0000000000000..b3a651ee49999
--- /dev/null
+++ b/vllm/v1/worker/model_runner_device_tensors.py
@@ -0,0 +1,79 @@
+"""
+Defines the objects that pack all the device tensors required by the model
+runner.
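+
+These tensors are allocated once, at their maximum sizes, and live on the
+model-runner device; the per-step CPU tensors are copied into them with
+non-blocking host-to-device transfers.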
+""" + +from dataclasses import dataclass + +import torch + +@dataclass +class ModelRunnerDeviceSamplingTensors: + """ + Device tensors related to model sampling + """ + temperature: torch.Tensor + top_p: torch.Tensor + top_k: torch.Tensor + + @staticmethod + def make(max_num_reqs: int, + device: torch.device) -> "ModelRunnerDeviceSamplingTensors": + temperature = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + top_p = torch.empty((max_num_reqs, ), + dtype=torch.float32, + device=device) + top_k = torch.empty((max_num_reqs, ), + dtype=torch.int32, + device=device) + return ModelRunnerDeviceSamplingTensors(temperature, top_p, top_k) + +@dataclass +class ModelRunnerDeviceTensors: + """ + Device tensors to be maintained by the ModelRunner + """ + block_table: torch.Tensor + input_positions: torch.Tensor + input_tokens: torch.Tensor + inputs_embeds: torch.Tensor + slot_mapping: torch.Tensor + sampling_tensors: ModelRunnerDeviceSamplingTensors + + @staticmethod + def make(max_num_reqs: int, + max_num_tokens: int, + max_num_blocks_per_req: int, + input_embeds_hidden_size: int, + input_embeds_dtype: torch.dtype, + device: torch.device) -> "ModelRunnerDeviceTensors": + block_table = torch.zeros((max_num_reqs, + max_num_blocks_per_req), + device=device, + dtype=torch.int32) + + input_positions = torch.zeros(max_num_tokens, + dtype=torch.int64, + device=device) + + input_tokens = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + slot_mapping = torch.empty(max_num_tokens, + dtype=torch.long, + device=device) + + inputs_embeds = torch.zeros( + (max_num_tokens, input_embeds_hidden_size), + dtype=input_embeds_dtype, + device=device) + + return ModelRunnerDeviceTensors( + block_table = block_table, + input_positions = input_positions, + input_tokens = input_tokens, + inputs_embeds = inputs_embeds, + slot_mapping = slot_mapping, + sampling_tensors = ModelRunnerDeviceSamplingTensors.make(max_num_reqs, device)) \ No newline at end of file diff --git a/vllm/v1/worker/request_batch.py b/vllm/v1/worker/request_batch.py new file mode 100644 index 0000000000000..755d06f9db5b4 --- /dev/null +++ b/vllm/v1/worker/request_batch.py @@ -0,0 +1,112 @@ +""" +ModelRunner input batch +""" + +from typing import Optional, List, Dict + +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.request_batch_base import RequestBatchBase +from vllm.v1.core.scheduler import RunningRequestData +from vllm.v1.worker.model_runner_device_tensors import ModelRunnerDeviceSamplingTensors + +import numpy as np +import torch + +class RequestBatch(RequestBatchBase): + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + pin_memory: bool): + super().__init__(max_num_reqs, max_model_len, max_num_blocks_per_req, pin_memory) + + def prepare_inputs(self, + num_scheduled_tokens: np.array, + block_size: int, + block_table_device_tensor: torch.Tensor, + input_tokens_device_tensor: torch.Tensor, + input_positions_device_tensor: torch.Tensor, + slot_mapping_device_tensor: torch.Tensor) -> None: + """ + Translate batch into numpy arrays for model execute. + When device_tensors are available, kickoff a non-blocking cpu-to-device transfer as soon as + the cpu tensors are prepared. 
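+
+        Writes block_table_device_tensor[:num_reqs] and the first
+        np.sum(num_scheduled_tokens) entries of input_tokens_device_tensor,
+        input_positions_device_tensor and slot_mapping_device_tensor.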
+ """ + num_reqs = self.num_reqs() + + if num_reqs == 0: + # Empty batch + return None + + # Trigger block table copy + block_table_device_tensor[:num_reqs].copy_( + self.cpu_tensors.block_table.tensor[:num_reqs], + non_blocking=True) + + assert len(num_scheduled_tokens) == num_reqs + indices = np.arange(num_reqs) + req_indices = np.repeat(indices, num_scheduled_tokens) + # TODO (varun) : get this directly from scheduler outputs + total_num_scheduled_tokens = np.sum(num_scheduled_tokens) + + # Input positions + token_positions: torch.Tensor = self.make_token_positions(num_scheduled_tokens, req_indices) + assert len(token_positions) == total_num_scheduled_tokens + # Trigger input positions copy + input_positions_device_tensor[:total_num_scheduled_tokens].copy_( + token_positions, + non_blocking=True) + + # Token indices + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = token_positions + req_indices * self.max_model_len + + # Input tokens + token_ids: torch.Tensor = self.make_token_ids(token_indices) + assert len(token_ids) == total_num_scheduled_tokens + # Trigger token indices copy + input_tokens_device_tensor[:total_num_scheduled_tokens].copy_( + token_ids, + non_blocking=True) + + # Slot mapping + slot_mapping: torch.Tensor = self.make_slot_mapping(token_indices, block_size) + assert len(slot_mapping) == total_num_scheduled_tokens + # Trigger slot mapping copy + slot_mapping_device_tensor[:total_num_scheduled_tokens].copy_( + slot_mapping, + non_blocking=True) + + def make_sampling_metadata(self, + device_tensors: ModelRunnerDeviceSamplingTensors, + skip_copy: bool = False) -> SamplingMetadata: + """ + Transfer cpu sampling to device, if a copy is required, and + translate the batch into SamplingMetadata for model sampling. 
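+
+        When skip_copy is True, the device-side sampling tensors are assumed
+        to already hold the current values and no copy is issued.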
+ """ + + num_reqs: int = self.num_reqs() + + if not skip_copy: + device_tensors.temperature[:num_reqs].copy_( + self.cpu_tensors.temperature.tensor[:num_reqs], non_blocking=True) + device_tensors.top_p[:num_reqs].copy_( + self.cpu_tensors.top_p.tensor[:num_reqs], non_blocking=True) + device_tensors.top_k[:num_reqs].copy_( + self.cpu_tensors.top_k.tensor[:num_reqs], non_blocking=True) + + return SamplingMetadata( + temperature=device_tensors.temperature[:num_reqs], + all_greedy=self.all_greedy(), + all_random=self.all_random(), + top_p=device_tensors.top_p[:num_reqs], + top_k=device_tensors.top_k[:num_reqs], + no_top_p=self.no_top_p(), + no_top_k=self.no_top_k(), + generators=self.generators, + max_num_logprobs=self.max_num_logprobs(), + ) diff --git a/vllm/v1/worker/request_batch_base.py b/vllm/v1/worker/request_batch_base.py new file mode 100644 index 0000000000000..f4984a8028648 --- /dev/null +++ b/vllm/v1/worker/request_batch_base.py @@ -0,0 +1,596 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Dict, Set +from vllm.sampling_params import SamplingType +from vllm.v1.core.scheduler import RunningRequestData +from vllm.v1.sample.metadata import SamplingMetadata +from vllm.v1.worker.model_runner_device_tensors import ( + ModelRunnerDeviceTensors, + ModelRunnerDeviceSamplingTensors) + +from vllm.v1.worker.cached_request_state import CachedRequestState +from abc import ABC, abstractmethod + +import torch +import numpy as np + +@dataclass(slots=True) +class CPUTensor: + tensor: torch.Tensor + np_tensor: np.array + + @staticmethod + def build(tensor: torch.Tensor) -> "CPUTensor": + return CPUTensor(tensor, tensor.numpy()) + +@dataclass +class BatchCPUSamplingTensors: + temperature: CPUTensor + top_p: CPUTensor + top_k: CPUTensor + + @staticmethod + def build(max_num_reqs: int, + pin_memory: bool) -> "BatchCPUSamplingTensors": + def make_tensor(dtype: torch.dtype) -> CPUTensor: + tensor = torch.empty((max_num_reqs, ), + dtype = dtype, + pin_memory=pin_memory, + device="cpu") + return CPUTensor.build(tensor) + + return BatchCPUSamplingTensors( + temperature = make_tensor(torch.float32), + top_p = make_tensor(torch.float32), + top_k = make_tensor(torch.int32)) + +@dataclass(slots=True) +class BatchCPUTensors: + """ + Batched CPU tensors maintained by RequestBatch. + Note that all the tensors have the request index as their major dimension. 
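+
+    Shapes:
+        token_ids_np            : (max_num_reqs, max_model_len)
+        num_computed_tokens_np  : (max_num_reqs, )
+        block_table             : (max_num_reqs, max_num_blocks_per_req)
+        temperature/top_p/top_k : (max_num_reqs, )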
+    """
+    token_ids_np: np.array
+    num_computed_tokens_np: np.array
+    block_table: CPUTensor
+    # Sampling tensors
+    temperature: CPUTensor
+    top_p: CPUTensor
+    top_k: CPUTensor
+
+    @staticmethod
+    def make(max_num_reqs : int,
+             max_model_len : int,
+             max_num_blocks_per_req: int,
+             pin_memory: bool) -> "BatchCPUTensors":
+
+        token_ids_np = np.empty((max_num_reqs, max_model_len), dtype=np.int32)
+        num_computed_tokens_np = np.empty(max_num_reqs, dtype=np.int32)
+
+        block_table = CPUTensor.build(
+            tensor = torch.zeros((max_num_reqs, max_num_blocks_per_req),
+                                 dtype=torch.int32,
+                                 pin_memory=pin_memory,
+                                 device="cpu"))
+
+        temperature = CPUTensor.build(
+            tensor = torch.empty((max_num_reqs, ),
+                                 dtype = torch.float32,
+                                 pin_memory=pin_memory,
+                                 device="cpu"))
+
+        top_p = CPUTensor.build(
+            tensor = torch.empty((max_num_reqs, ),
+                                 dtype = torch.float32,
+                                 pin_memory=pin_memory,
+                                 device="cpu"))
+
+        top_k = CPUTensor.build(
+            tensor = torch.empty((max_num_reqs, ),
+                                 dtype = torch.int32,
+                                 pin_memory=pin_memory,
+                                 device="cpu"))
+
+        return BatchCPUTensors(token_ids_np = token_ids_np,
+                               num_computed_tokens_np = num_computed_tokens_np,
+                               block_table = block_table,
+                               temperature = temperature,
+                               top_p = top_p,
+                               top_k = top_k)
+
+    def add_request_data(self, req_index: int, request: CachedRequestState) -> None:
+        """
+        Populate the CPU tensors at `req_index` with the data carried by
+        `request`.
+        """
+        # Copy the prompt token ids and output token ids.
+        num_prompt_tokens = len(request.prompt_token_ids)
+        self.token_ids_np[
+            req_index, :num_prompt_tokens] = request.prompt_token_ids
+        start_idx = num_prompt_tokens
+        end_idx = start_idx + len(request.output_token_ids)
+        self.token_ids_np[req_index, start_idx:end_idx] = request.output_token_ids
+
+        self.num_computed_tokens_np[req_index] = request.num_computed_tokens
+
+        num_blocks = len(request.block_ids)
+        self.block_table.np_tensor[req_index, :num_blocks] = request.block_ids
+
+        sampling_params = request.sampling_params
+        self.temperature.np_tensor[req_index] = sampling_params.temperature
+        self.top_p.np_tensor[req_index] = sampling_params.top_p
+        self.top_k.np_tensor[req_index] = sampling_params.top_k
+
+    # `to_req_index` is expected to be an empty (unused) slot.
+    def transfer(self, from_req_index: int, to_req_index: int) -> None:
+        """
+        Copy data from `from_req_index` to `to_req_index`.
+        """
+        # TODO(woosuk): Optimize the copy of token_ids_np and block_table.
+        self.token_ids_np[to_req_index] = self.token_ids_np[from_req_index]
+        self.num_computed_tokens_np[to_req_index] = self.num_computed_tokens_np[from_req_index]
+        self.block_table.np_tensor[to_req_index] = self.block_table.np_tensor[from_req_index]
+        self.temperature.np_tensor[to_req_index] = self.temperature.np_tensor[from_req_index]
+        self.top_p.np_tensor[to_req_index] = self.top_p.np_tensor[from_req_index]
+        self.top_k.np_tensor[to_req_index] = self.top_k.np_tensor[from_req_index]
+
+    def update_request_state(self, req_index: int,
+                             request_data: RunningRequestData,
+                             num_existing_block_ids: int) -> None:
+        """
+        Given a request index and a RunningRequestData delta, update the state
+        held in the CPU tensors.
+        num_existing_block_ids is needed to update the block table, because we
+        do not track how many block ids of a request are currently valid in
+        the block table.
+        """
+        # Update the num_computed_tokens.
+        self.num_computed_tokens_np[req_index] = request_data.num_computed_tokens
+
+        # Update the block table.
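+        # new_block_ids is a delta: the newly allocated blocks are appended
+        # after the num_existing_block_ids entries already populated in this
+        # request's block-table row.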
+ num_new_blocks = len(request_data.new_block_ids) + if num_new_blocks == 0: + return + start_index = num_existing_block_ids + end_index = start_index + num_new_blocks + self.block_table.np_tensor[req_index, start_index:end_index] = \ + request_data.new_block_ids + +class RequestBatchAbstract(ABC): + + def __init__(self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + pin_memory: bool): + self.max_num_reqs = max_num_reqs + self.max_model_len = max_model_len + self.max_num_blocks_per_req = max_num_blocks_per_req + self.pin_memory = pin_memory + + @abstractmethod + def remove_requests(self, req_ids: List[str]) -> None: + raise NotImplementedError + + + @abstractmethod + def add_request(self, + request: "CachedRequestState") -> None: + """ + Add the new or resumed requests to the persistent batch. + """ + raise NotImplementedError + + @abstractmethod + def clear(self) -> None: + raise NotImplementedError + + @abstractmethod + def condense(self) -> None: + raise NotImplementedError + + @abstractmethod + def update_states(self, + request_id: str, + request_data: RunningRequestData, + num_existing_block_ids: int) -> None: + """ + Update state of the request in batch. + """ + raise NotImplementedError + + @abstractmethod + def append_token_id(self, request_id: str, token_id: np.int32, token_idx: int) -> None: + raise NotImplementedError + + @abstractmethod + def rewind_generator(self, request_id: str) -> None: + raise NotImplementedError + + @abstractmethod + def make_seq_lens_tensor(self, + num_scheduled_tokens: np.array) -> np.array: + """ + Given the number of tokens scheduled per request, return the sequence lengths + of the requests based on cpu_tensors.num_computed_tokens_np + """ + raise NotImplementedError + + @abstractmethod + def prepare_inputs(self, + num_scheduled_tokens: np.array, + block_size: int, + block_table_device_tensor: torch.Tensor, + input_tokens_device_tensor: torch.Tensor, + input_positions_device_tensor: torch.Tensor, + slot_mapping_device_tensor: torch.Tensor) -> None: + + """ + Translate batch into model-input numpy arrays and kickoff a non-blocking + cpu-to-device transfer as soon as a particular tensor is prepared. + """ + raise NotImplementedError + + @abstractmethod + def make_sampling_metadata(self, + sampling_device_tensors: ModelRunnerDeviceSamplingTensors, + skip_copy: bool = False) -> SamplingMetadata: + """ + Transfer cpu sampling to device, if a copy is required, and + translate the batch into SamplingMetadata for model sampling. + """ + raise NotImplementedError + + @abstractmethod + def request_ids(self) -> List[str]: + # Return request ids in order that they appear in the batch. 
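+        # i.e. index i of the returned list is the request at batch index i.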
+ raise NotImplementedError + + @abstractmethod + def num_reqs(self) -> int: + raise NotImplementedError + + @abstractmethod + def all_greedy(self) -> bool: + raise NotImplementedError + + @abstractmethod + def all_random(self) -> bool: + raise NotImplementedError + + @abstractmethod + def no_top_p(self) -> bool: + raise NotImplementedError + + @abstractmethod + def no_top_k(self) -> bool: + raise NotImplementedError + + @abstractmethod + def max_num_logprobs(self) -> int: + raise NotImplementedError + + @abstractmethod + def no_logprob(self) -> bool: + raise NotImplementedError + + @abstractmethod + def no_prompt_logprob(self) -> bool: + raise NotImplementedError + +class RequestBatchBase(RequestBatchAbstract): + + def __init__( + self, + max_num_reqs: int, + max_model_len: int, + max_num_blocks_per_req: int, + pin_memory: bool, + ): + super().__init__(max_num_reqs, max_model_len, max_num_blocks_per_req, pin_memory) + + self.req_ids: List[Optional[str]] = [None] * max_num_reqs + self.req_id_to_index: Dict[str, int] = {} + # Track fragmentation due to request-removal so the empty slots + # can be re-used for new requests. + self.empty_req_indices: List[int] = [] + self.is_empty_req_indices_sorted: bool = True + + # Batch CPU Tensors + self.cpu_tensors = \ + BatchCPUTensors.make(self.max_num_reqs, self.max_model_len, self.max_num_blocks_per_req, self.pin_memory) + + # Batch Request info + self.greedy_reqs: Set[str] = set() + self.random_reqs: Set[str] = set() + self.top_p_reqs: Set[str] = set() + self.top_k_reqs: Set[str] = set() + # req_index -> generator + self.generators: Dict[int, torch.Generator] = {} + self.num_logprobs: Dict[str, int] = {} + self.prompt_logprob_reqs: Set[str] = set() + + def _add_request( + self, + request: "CachedRequestState", + req_index: int, + ) -> None: + assert req_index < self.max_num_reqs + + req_id = request.req_id + self.req_ids[req_index] = req_id + self.req_id_to_index[req_id] = req_index + + self.cpu_tensors.add_request_data(req_index, request) + + sampling_params = request.sampling_params + if sampling_params.sampling_type == SamplingType.GREEDY: + self.greedy_reqs.add(req_id) + else: + self.random_reqs.add(req_id) + if sampling_params.top_p < 1: + self.top_p_reqs.add(req_id) + if sampling_params.top_k > 0: + self.top_k_reqs.add(req_id) + + self.generators[req_index] = request.generator + + num_logprobs = sampling_params.logprobs + if num_logprobs is not None and num_logprobs > 0: + self.num_logprobs[req_id] = num_logprobs + if sampling_params.prompt_logprobs: + self.prompt_logprob_reqs.add(req_id) + + def _remove_request(self, req_id: str) -> None: + req_index = self.req_id_to_index.pop(req_id, None) + if req_index is None: + return + self.req_ids[req_index] = None + + self.greedy_reqs.discard(req_id) + self.random_reqs.discard(req_id) + self.top_p_reqs.discard(req_id) + self.top_k_reqs.discard(req_id) + self.generators.pop(req_index, None) + self.num_logprobs.pop(req_id, None) + self.prompt_logprob_reqs.discard(req_id) + + self.empty_req_indices.append(req_index) + self.is_empty_req_indices_sorted = False + return req_index + + def remove_requests(self, req_ids: List[str]) -> None: + for req_id in req_ids: + self._remove_request(req_id) + + def add_request(self, + request: "CachedRequestState") -> None: + """ + Add the new or resumed requests to the persistent batch. + """ + # sort empty_req_indices so the smaller ones can be filled first. 
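+        # (sorted in descending order so that pop() yields the smallest index;
+        # filling low indices first keeps the batch densely packed and reduces
+        # the work condense() has to do later.)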
+ if not self.is_empty_req_indices_sorted: + self.empty_req_indices = sorted(self.empty_req_indices, reverse = True) + self.is_empty_req_indices_sorted = True + + # The smaller empty indices are filled first. + if self.empty_req_indices: + # Fill the empty index. + req_index = self.empty_req_indices.pop() + else: + # Append to the end. + req_index = self.num_reqs() + + self._add_request(request, req_index) + + def clear(self) -> None: + self.req_ids = [None] * self.max_num_reqs + self.req_id_to_index.clear() + self.greedy_reqs.clear() + self.random_reqs.clear() + self.top_p_reqs.clear() + self.top_k_reqs.clear() + self.generators.clear() + self.num_logprobs.clear() + self.prompt_logprob_reqs.clear() + + def is_condensed(self) -> bool: + val = all([x is not None for x in self.req_ids[:self.num_reqs()]]) and \ + all([x is None for x in self.req_ids[self.num_reqs():]]) + if not val: + print (f"num reqs : {self.num_reqs()}") + valid_reqs = self.req_ids[:self.num_reqs()] + invalid_reqs = self.req_ids[self.num_reqs():] + + print (f" - empy in valid reqs {valid_reqs.index(None)}") + print (f" - valid reqs {len(valid_reqs)} : {valid_reqs}") + print (f" - invalid reqs {len(invalid_reqs)} : {invalid_reqs} ") + print (f" - valid correct : {all([x is not None for x in self.req_ids[:self.num_reqs()]])}") + print (f" - invalid correct : {all([x is None for x in self.req_ids[self.num_reqs():]])}") + print (f" empty indices : {self.empty_req_indices} ") + + return val + + def condense(self) -> None: + if self.num_reqs() == 0: + # The batched states are empty. + return + if not self.empty_req_indices: + # The batch is packed already. + return + + # sort empty_req_indices. + if not self.is_empty_req_indices_sorted: + self.empty_req_indices = sorted(self.empty_req_indices, reverse = True) + self.is_empty_req_indices_sorted = True + + last_req_index = self.num_reqs() + len(self.empty_req_indices) - 1 + while self.empty_req_indices: + # Find the largest non-empty index. + while last_req_index in self.empty_req_indices: + last_req_index -= 1 + + # Find the smallest empty index. + empty_index = self.empty_req_indices.pop() + if empty_index >= last_req_index: + self.empty_req_indices.clear() + break + + assert self.req_ids[last_req_index] is not None, \ + (f"Invalid last_req_index {last_req_index}, " + f" num_reqs {self.num_reqs()}" + f" empty_indices {self.empty_req_indices}") + + # Swap the states. + req_id = self.req_ids[last_req_index] + self.req_ids[empty_index] = req_id + self.req_ids[last_req_index] = None + self.req_id_to_index[req_id] = empty_index + + self.cpu_tensors.transfer(to_req_index=empty_index, + from_req_index=last_req_index) + + generator = self.generators.pop(last_req_index, None) + if generator is not None: + self.generators[empty_index] = generator + + # Decrement last_req_index since it is now empty. + last_req_index -= 1 + + def update_states(self, + request_id: str, + request_data: RunningRequestData, + num_existing_block_ids: int) -> None: + """ + Update state of the request in batch. 
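+        request_data carries the per-step delta from the scheduler: the new
+        num_computed_tokens and any newly allocated block ids.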
+ """ + req_index = self.req_id_to_index[request_id] + self.cpu_tensors.update_request_state(req_index, request_data, num_existing_block_ids) + + def append_token_id(self, request_id: str, token_id: np.int32, token_idx: int) -> None: + req_idx: int = self.req_id_to_index[request_id] + self.cpu_tensors.token_ids_np[req_idx, token_idx] = token_id + + def rewind_generator(self, request_id: str) -> None: + req_idx: int = self.req_id_to_index[request_id] + generator = self.generators.get(req_idx) + if generator is not None: + # This relies on cuda-specific torch-internal impl details + generator.set_offset(generator.get_offset() - 4) + + def make_token_positions(self, + num_scheduled_tokens: np.array, + token_req_indices: Optional[np.array]) -> torch.Tensor: + """ + Given the number of scheduled tokens for per request, translate the + batch into token positions. + + E.g. If there are 3 requests in batch, where, + self.cpu_tensors.num_computed_tokens => [4, 10, 3] + num_scheduled_tokens => [3, 4, 2] + then, + return [4, 5, 6, 10, 11, 12, 13, 3, 4] + """ + + max_num_scheduled_tokens = num_scheduled_tokens.max() + assert max_num_scheduled_tokens > 0 + + if token_req_indices is None: + indices = np.arange(self.num_reqs()) + token_req_indices = np.repeat(indices, num_scheduled_tokens) + + num_tokens : int = len(token_req_indices) + positions = torch.empty(num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + # Get batched arange + # e.g. num_schedule_tokens [3, 4, 2] + # arange => [0, 1, 2, 0, 1, 2, 3, 0, 1] + arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), + (self.num_reqs(), 1)) + mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] + arange = arange_matrix[mask] + + ## Input Positions + np.add(self.cpu_tensors.num_computed_tokens_np[token_req_indices], + arange, + out = positions.numpy()) + + return positions + + + def make_token_ids(self, + token_indices: torch.Tensor) -> torch.Tensor: + """ + Given the token indices of the requests, that is flattened to match + cpu_tensors.token_ids_np, select the tokens and return as numpy array. + """ + num_tokens : int = len(token_indices) + token_ids = torch.empty(num_tokens, + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + + torch.index_select( + torch.from_numpy(self.cpu_tensors.token_ids_np).flatten(), + 0, + token_indices, + out = token_ids) + return token_ids + + def make_slot_mapping(self, + token_indices: torch.Tensor, + block_size: int) -> torch.Tensor: + """ + Given the token indices of the requests, that is flattened to match + cpu_tensors.token_ids_np, return the slot mapping for the tokens. 
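+
+        For a token at position p of request r, the slot is
+        block_table[r, p // block_size] * block_size + p % block_size.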
+ """ + num_tokens : int = len(token_indices) + slot_mapping = torch.empty(num_tokens, + dtype=torch.int64, + device="cpu", + pin_memory=self.pin_memory) + + block_numbers = self.cpu_tensors.block_table.tensor.flatten()[ + token_indices // block_size] + block_offsets = token_indices % block_size + torch.add(block_numbers * block_size, + block_offsets, + out=slot_mapping) + return slot_mapping + + def make_seq_lens_tensor(self, + num_scheduled_tokens: np.array) -> np.array: + """ + Given the number of tokens scheduled per request, return the sequence lengths + of the requests based on cpu_tensors.num_computed_tokens_np + """ + return self.cpu_tensors.num_computed_tokens_np[:self.num_reqs()] + num_scheduled_tokens + + def request_ids(self) -> List[str]: + return self.req_ids[:self.num_reqs()] + + def num_reqs(self) -> int: + return len(self.req_id_to_index) + + def all_greedy(self) -> bool: + return len(self.random_reqs) == 0 + + def all_random(self) -> bool: + return len(self.greedy_reqs) == 0 + + def no_top_p(self) -> bool: + return len(self.top_p_reqs) == 0 + + def no_top_k(self) -> bool: + return len(self.top_k_reqs) == 0 + + def max_num_logprobs(self) -> int: + return max(self.num_logprobs.values()) if self.num_logprobs else 0 + + def no_logprob(self) -> bool: + return len(self.num_logprobs) == 0 + + def no_prompt_logprob(self) -> bool: + return len(self.prompt_logprob_reqs) == 0 \ No newline at end of file
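A minimal, self-contained sketch of the index arithmetic that `prepare_inputs` and `make_slot_mapping` implement above. It is illustrative only (not part of the diff) and uses toy values; like the batch code, it assumes max_model_len is a multiple of block_size so that flattened token indices map onto whole block-table rows.

import numpy as np

# Toy batch: 3 requests, block_size=4, max_model_len=16 (4 blocks per request).
block_size = 4
max_model_len = 16
num_computed_tokens = np.array([4, 10, 3], dtype=np.int32)
num_scheduled_tokens = np.array([3, 4, 2], dtype=np.int32)
block_table = np.array([[7, 3, 0, 0],
                        [2, 9, 5, 6],
                        [1, 8, 0, 0]], dtype=np.int32)  # physical block ids

# One entry per scheduled token, naming the request it belongs to.
req_indices = np.repeat(np.arange(3), num_scheduled_tokens)   # [0 0 0 1 1 1 1 2 2]

# Token positions continue each request from its num_computed_tokens
# (this mirrors the example in make_token_positions).
arange = np.concatenate([np.arange(n) for n in num_scheduled_tokens])
positions = num_computed_tokens[req_indices] + arange
print(positions)      # [ 4  5  6 10 11 12 13  3  4]

# Flattened indices into the (num_reqs, max_model_len) token-id table.
token_indices = positions + req_indices * max_model_len

# Slot mapping: physical block id * block_size + offset within the block.
block_numbers = block_table.flatten()[token_indices // block_size]
slot_mapping = block_numbers * block_size + token_indices % block_size
print(slot_mapping)   # [12 13 14 22 23 24 25  7 32]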