From 44681e4fde8f34823e6f2342d784f1aa0f53dc1b Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Fri, 26 Jul 2024 14:47:41 +0200 Subject: [PATCH] Initial support for AIU (#17) Initial support for AIU in vLLM. What is currently supported/tested: - Single AIU - Model: llama-7b-chat - Offline inference (batch size 1) - Online inference (with `max-num-seq=1`) --------- Signed-off-by: Nikolaos Papandreou Signed-off-by: Thomas Parnell Co-authored-by: Nikolaos Papandreou Co-authored-by: TRAVIS JOHNSON --- examples/offline_inference_sendnn.ipynb | 367 ++++++++++++++++++ examples/offline_inference_sendnn.py | 56 +++ examples/test_online_client.py | 85 ++++ examples/test_online_multi_client.py | 78 ++++ requirements-sendnn.txt | 6 + setup.py | 7 + vllm/config.py | 6 +- vllm/core/scheduler.py | 17 +- vllm/engine/arg_utils.py | 3 +- vllm/engine/async_llm_engine.py | 3 + vllm/engine/llm_engine.py | 5 +- vllm/executor/executor_base.py | 2 +- vllm/executor/sendnn_executor.py | 117 ++++++ .../model_executor/model_loader/aiu_common.py | 284 ++++++++++++++ vllm/model_executor/model_loader/sendnn.py | 213 ++++++++++ vllm/utils.py | 10 + vllm/worker/sendnn_model_runner.py | 202 ++++++++++ vllm/worker/sendnn_worker.py | 99 +++++ 18 files changed, 1549 insertions(+), 11 deletions(-) create mode 100644 examples/offline_inference_sendnn.ipynb create mode 100644 examples/offline_inference_sendnn.py create mode 100644 examples/test_online_client.py create mode 100644 examples/test_online_multi_client.py create mode 100644 requirements-sendnn.txt create mode 100644 vllm/executor/sendnn_executor.py create mode 100644 vllm/model_executor/model_loader/aiu_common.py create mode 100644 vllm/model_executor/model_loader/sendnn.py create mode 100644 vllm/worker/sendnn_model_runner.py create mode 100644 vllm/worker/sendnn_worker.py diff --git a/examples/offline_inference_sendnn.ipynb b/examples/offline_inference_sendnn.ipynb new file mode 100644 index 000000000..578c008c7 --- /dev/null +++ b/examples/offline_inference_sendnn.ipynb @@ -0,0 +1,367 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The wurlitzer extension is already loaded. 
To reload it, use:\n", + " %reload_ext wurlitzer\n" + ] + } + ], + "source": [ + "from vllm import LLM, SamplingParams\n", + "import time\n", + "%load_ext wurlitzer" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "with open(\"/etc/aiu/senlib_config.json\", 'rb') as f:\n", + " config = json.load(f)\n", + "\n", + "os.environ[\"AIU_CONFIG_FILE_0\"] = \"/etc/aiu/senlib_config.json\"\n", + "os.environ[\"FLEX_RDMA_PCI_BUS_ADDR_0\"] = config[\"GENERAL\"][\"sen_bus_id\"][0]\n", + "os.environ[\"AIU_WORLD_RANK_0\"] = \"0\"" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 06-27 12:28:18 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='/tmp/7B-F', speculative_config=None, tokenizer='/tmp/7B-F', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=/tmp/7B-F)\n", + ">> DEBUG SETUP\n", + "0 / 1 : Python Version : 3.9.18\n", + "0 / 1 : PyTorch Version: 2.2.2+cpu\n", + "0 / 1 : PCI Addr Rank 0 AIU_WORLD_RANK_0=0\n", + "0 / 1 : PCI Addr Rank 0 FLEX_RDMA_PCI_BUS_ADDR_0=0000:1d:00.0\n", + "0 / 1 : FLEX_COMPUTE=SENTIENT\n", + "0 / 1 : FLEX_DEVICE=VFIO\n", + "0 / 1 : DEEPRT_EXPORT_DIR=export/0\n", + "0 / 1 : DTCOMPILER_EXPORT_DIR=export/0\n", + "0 / 1 : AIU_CONFIG_FILE_0=/etc/aiu/senlib_config.json\n", + "0 / 1 : SENLIB_DEVEL_CONFIG_FILE=/etc/aiu/senlib_config.json\n", + "0 / 1 : FLEX_RDMA_PCI_BUS_ADDR_0=0000:1d:00.0\n", + "0 / 1 : FLEX_RDMA_LOCAL_RANK=0\n", + "0 / 1 : FLEX_RDMA_LOCAL_SIZE=1\n", + "0 / 1 : FLEX_RDMA_WORLD_RANK=0\n", + "0 / 1 : FLEX_RDMA_WORLD_SIZE=1\n", + "0 / 1 : Sentient AIU: Enabled (0) (offset=0)\n", + "0 / 1 : Dynamo Backend : sendnn_decoder\n", + "0 / 1 : CPU Cores : 56 x 2 HW threads\n", + "------------------------------------------------------------\n", + "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit from 64 to 160 to accomidate prompt size of 64 and decode tokens of 20\n", + "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from 8 to 160 to accomidate prompt size of 64 and decode tokens of 20\n" + ] + } + ], + "source": [ + "# Create an LLM.\n", + "llm = LLM(\n", + " model=\"/tmp/7B-F\",\n", + " tokenizer=\"/tmp/7B-F\",\n", + " max_model_len=2048,\n", + " block_size=2048,\n", + " device=\"sendnn\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\\n\\n### Instruction:\\nProvide a list of instructions for preparing chicken soup for a family of four.\\n\\n### Response:']\n" + ] + } + ], + "source": [ + "# Sample prompts.\n", + "template = \"Below is an instruction that describes a task. Write a response that appropriately completes the request. 
Be polite in your response to the user.\\n\\n### Instruction:\\n{}\\n\\n### Response:\"\n", + "prompt1 = template.format(\n", + " \"Provide a list of instructions for preparing chicken soup for a family of four.\"\n", + ")\n", + "prompts = [\n", + " prompt1,\n", + "]\n", + "print(prompts)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a sampling params object.\n", + "max_tokens = 10\n", + "sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=============== WARM UP 1\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processed prompts: 0%| | 0/1 [00:00 bool: return VLLM_TARGET_DEVICE == "cpu" +def _is_sendnn() -> bool: + return VLLM_TARGET_DEVICE == "sendnn" + def _is_openvino() -> bool: return VLLM_TARGET_DEVICE == "openvino" @@ -370,6 +373,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"+neuron{neuron_version_str}" + elif _is_sendnn(): + version += "+sendnn" elif _is_openvino(): version += "+openvino" elif _is_tpu(): @@ -423,6 +428,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): requirements = _read_requirements("requirements-neuron.txt") + elif _is_sendnn(): + requirements = _read_requirements("requirements-sendnn.txt") elif _is_openvino(): requirements = _read_requirements("requirements-openvino.txt") elif _is_tpu(): diff --git a/vllm/config.py b/vllm/config.py index 6403a53f8..3c03db047 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -12,7 +12,7 @@ from vllm.tracing import is_otel_installed from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu, - is_hip, is_neuron, is_openvino, is_tpu, is_xpu, + is_hip, is_neuron, is_sendnn, is_openvino, is_tpu, is_xpu, print_warning_once) if TYPE_CHECKING: @@ -851,6 +851,8 @@ def __init__(self, device: str = "auto") -> None: # Automated device type detection if is_neuron(): self.device_type = "neuron" + elif is_sendnn(): + self.device_type = "sendnn" elif is_openvino(): self.device_type = "openvino" elif is_tpu(): @@ -868,7 +870,7 @@ def __init__(self, device: str = "auto") -> None: self.device_type = device # Some device types require processing inputs on CPU - if self.device_type in ["neuron", "openvino"]: + if self.device_type in ["neuron", "sendnn", "openvino"]: self.device = torch.device("cpu") elif self.device_type in ["tpu"]: self.device = None diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 6e59c5e0f..c5cbe443b 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -6,7 +6,7 @@ from dataclasses import dataclass, field from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union -from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig, DeviceConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger @@ -266,11 +266,13 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + device_config: DeviceConfig, lora_config: 
Optional[LoRAConfig], pipeline_parallel_size: int = 1, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + self.device_config = device_config # Note for LoRA scheduling: the current policy is extremely # simple and NOT fair. It can lead to starvation of some # LoRAs. This should be improved in the future. @@ -966,10 +968,15 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: # Appending slots only occurs in decoding. is_prefill = False - return self.block_manager.can_append_slots( - seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), - ) + if self.device_config.device_type == "sendnn": + # heuristic below doesn't make sense when using very large + # blocks + return True + else: + return self.block_manager.can_append_slots( + seq_group=seq_group, + num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + ) def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # Schedule sequence groups. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cd64d3345..0b0cb8898 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -290,7 +290,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32], help='Token block size for contiguous chunks of ' 'tokens.') @@ -500,7 +499,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.device, choices=[ "auto", "cuda", "neuron", "cpu", "openvino", - "tpu", "xpu" + "tpu", "xpu", "sendnn" ], help='Device type for vLLM execution.') diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 16b7bc64a..7f58b0c5e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -406,6 +406,9 @@ def _get_executor_cls( elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync + if engine_config.device_config.device_type == "sendnn": + from vllm.executor.sendnn_executor import SENDNNExecutorAsync + executor_class = SENDNNExecutorAsync elif engine_config.device_config.device_type == "tpu": from vllm.executor.tpu_executor import TPUExecutorAsync executor_class = TPUExecutorAsync diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 48d530589..a13084176 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -310,7 +310,7 @@ def __init__( # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. 
self.scheduler = [ - Scheduler(scheduler_config, cache_config, lora_config, + Scheduler(scheduler_config, cache_config, device_config, lora_config, parallel_config.pipeline_parallel_size) for _ in range(parallel_config.pipeline_parallel_size) ] @@ -393,6 +393,9 @@ def _get_executor_cls(cls, elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor + elif engine_config.device_config.device_type == "sendnn": + from vllm.executor.sendnn_executor import SENDNNExecutor + executor_class = SENDNNExecutor elif engine_config.device_config.device_type == "tpu": from vllm.executor.tpu_executor import TPUExecutor executor_class = TPUExecutor diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index a848bc709..129186516 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -14,7 +14,7 @@ class ExecutorBase(ABC): """Base class for all executors. An executor is responsible for executing the model on a specific device - type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor + type (e.g., CPU, GPU, Neuron, SENDNN, etc.). Or it can be a distributed executor that can execute the model on multiple devices. """ diff --git a/vllm/executor/sendnn_executor.py b/vllm/executor/sendnn_executor.py new file mode 100644 index 000000000..eb1e5856c --- /dev/null +++ b/vllm/executor/sendnn_executor.py @@ -0,0 +1,117 @@ +from typing import List, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.utils import make_async + +logger = init_logger(__name__) + + +class SENDNNExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for SENDNN backend." + assert (not self.speculative_config + ), "Speculative decoding not yet supported for SENDNN backend." + + # Instantiate the worker and load the model to the device. + self._init_worker() + + def _init_worker(self): + from vllm.worker.sendnn_worker import SENDNNWorker + + self.driver_worker = SENDNNWorker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + self.cache_config, + ) + self.driver_worker.init_device() + self.driver_worker.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. 
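+        For the SENDNN backend this only records the block counts reported by
+        the worker (num_gpu_blocks == max_num_seqs, num_cpu_blocks == 0); no
+        device memory is allocated here.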
+ """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + # assert (execute_model_req.blocks_to_swap_in == {} + # and execute_model_req.blocks_to_swap_out == {} + # and execute_model_req.blocks_to_copy == {}), ( + # "Cache operations are not supported for SENDNN backend.") + + #assert (execute_model_req.blocks_to_swap_in == {}), ("assert 1") + #assert (execute_model_req.blocks_to_swap_out == {}), ("assert 2") + #assert (execute_model_req.blocks_to_copy == {}), ("assert 3") THIS FAILS + + assert execute_model_req.num_lookahead_slots == 0, ( + "lookahead not supported for SENDNN backend.") + + output = self.driver_worker.execute_model( + execute_model_req.seq_group_metadata_list) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the SENDNN backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the SENDNN backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the SENDNN backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the SENDNN backend.") + + def check_health(self) -> None: + # SENDNNExecutor will always be healthy as long as + # it's running. + return + + +class SENDNNExecutorAsync(SENDNNExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async( + self.driver_worker.execute_model + )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, ) + return output + + async def check_health_async(self) -> None: + # SENDNNExecutor will always be healthy as long as + # it's running. 
+ return diff --git a/vllm/model_executor/model_loader/aiu_common.py b/vllm/model_executor/model_loader/aiu_common.py new file mode 100644 index 000000000..f420fb883 --- /dev/null +++ b/vllm/model_executor/model_loader/aiu_common.py @@ -0,0 +1,284 @@ +import os +import sys +import re +import logging +import psutil +import json +import time + +import torch +from torch import nn + +from contextlib import redirect_stdout, redirect_stderr + +# ============================================================== +# Common setup +# ============================================================== +def setup(dynamo_backend, rank=0, world_size=1, local_rank=0, local_size=1): + + print(">> DEBUG SETUP") + verbose = 1 + + if verbose > 0 or 0 == rank: + print(f"{rank} / {world_size} : Python Version : {sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}") + print(f"{rank} / {world_size} : PyTorch Version: {torch.__version__}") + + # put __pycache__ in different directories + # https://docs.python.org/3/library/sys.html#sys.pycache_prefix + if world_size > 1: + sys.pycache_prefix="export/py-" + str(rank) + + # https://pytorch.org/docs/main/generated/torch._logging.set_logs.html#torch._logging.set_logs + # Defaults to: logging.WARNING + ##if pargs.verbose > 1: + ## if 2 == pargs.verbose: + ## torch._logging.set_logs(dynamo=logging.INFO) + ## elif 3 == pargs.verbose: + ## torch._logging.set_logs(dynamo=logging.INFO) + ## torch._dynamo.config.verbose = True + ## else: + ## torch._logging.set_logs(dynamo=logging.DEBUG) + + #------------- + # Envar setup for Sentient backend + #------------- + if "sendnn" in dynamo_backend: + # Environment variable created by the runtime to identify the specific AIU that is assigned to this rank + env_str="AIU_CONFIG_FILE_"+str(rank) + + if os.getenv("FLEX_COMPUTE") is None: + os.environ["FLEX_COMPUTE"] = "SENULATOR" + if os.getenv("FLEX_DEVICE") is None: + os.environ["FLEX_DEVICE"] = "MOCK" + if os.getenv("DEEPRT_EXPORT_DIR") is None: + os.environ["DEEPRT_EXPORT_DIR"] = "export/" + str(rank) + if os.getenv("DTCOMPILER_EXPORT_DIR") is None: + os.environ["DTCOMPILER_EXPORT_DIR"] = "export/" + str(rank) + if os.getenv("DTCOMPILER_KEEP_EXPORT") is None: + os.environ["DTCOMPILER_KEEP_EXPORT"] = "1" + if os.getenv("FLEX_COMPUTE") == "SENTIENT": + if os.getenv("SENLIB_DEVEL_CONFIG_FILE") is None: + if os.getenv(env_str) is not None: + os.environ["SENLIB_DEVEL_CONFIG_FILE"] = os.getenv(env_str) + else: + print(f"{rank} / {world_size} : WARNING: {env_str} was not set.") + + if os.getenv("FLEX_RDMA_WORLD_SIZE") is None: + os.environ["FLEX_RDMA_WORLD_SIZE"] = str(world_size) + if os.getenv("FLEX_RDMA_WORLD_RANK") is None: + os.environ["FLEX_RDMA_WORLD_RANK"] = str(rank) + # Use the 'world' version since the 'local' is not consistent from torchrun + if os.getenv("FLEX_RDMA_LOCAL_SIZE") is None: + os.environ["FLEX_RDMA_LOCAL_SIZE"] = str(world_size) + if os.getenv("FLEX_RDMA_LOCAL_RANK") is None: + os.environ["FLEX_RDMA_LOCAL_RANK"] = str(rank) + aiu_dev_offset=int(os.getenv("AIU_ASSIGNMENT_OFFSET", "0")) + for peer_rank in range(world_size): + pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank+aiu_dev_offset) + flex_env_str="FLEX_RDMA_PCI_BUS_ADDR_"+str(peer_rank) + if os.getenv("FLEX_COMPUTE") == "SENULATOR": + if os.getenv(pcie_env_str) is not None: + os.environ[flex_env_str] = os.getenv(pcie_env_str) + else: + os.environ[pcie_env_str] = f"0000:{rank:02x}:01.0" + os.environ[flex_env_str] = f"0000:{rank:02x}:01.0" + else: + if os.getenv(flex_env_str) is None: + if 
os.getenv(pcie_env_str) is not None: + os.environ[flex_env_str] = os.getenv(pcie_env_str) + else: + raise RuntimeError(f"{rank} / {world_size} : ERROR: {flex_env_str} and {pcie_env_str} were not set for peer {peer_rank}.") + if 0 == rank and verbose > 0: + print(f"{rank} / {world_size} : PCI Addr Rank {peer_rank} {pcie_env_str}={os.environ[pcie_env_str]}") + print(f"{rank} / {world_size} : PCI Addr Rank {peer_rank} {flex_env_str}={os.environ[flex_env_str]}") + + if 0 == rank and verbose > 0: + print(f"{rank} / {world_size} : FLEX_COMPUTE=" + os.getenv("FLEX_COMPUTE")) + print(f"{rank} / {world_size} : FLEX_DEVICE=" + os.getenv("FLEX_DEVICE")) + print(f"{rank} / {world_size} : DEEPRT_EXPORT_DIR=" + os.getenv("DEEPRT_EXPORT_DIR")) + print(f"{rank} / {world_size} : DTCOMPILER_EXPORT_DIR=" + os.getenv("DTCOMPILER_EXPORT_DIR")) + if os.getenv(env_str) is not None: + print(f"{rank} / {world_size} : {env_str}=" + os.environ[env_str]) + if os.getenv("SENLIB_DEVEL_CONFIG_FILE") is not None: + print(f"{rank} / {world_size} : SENLIB_DEVEL_CONFIG_FILE=" + os.environ["SENLIB_DEVEL_CONFIG_FILE"]) + if os.getenv(flex_env_str) is not None: + print(f"{rank} / {world_size} : {flex_env_str}=" + os.environ[flex_env_str]) + print(f"{rank} / {world_size} : FLEX_RDMA_LOCAL_RANK=" + os.getenv("FLEX_RDMA_LOCAL_RANK")) + print(f"{rank} / {world_size} : FLEX_RDMA_LOCAL_SIZE=" + os.getenv("FLEX_RDMA_LOCAL_SIZE")) + print(f"{rank} / {world_size} : FLEX_RDMA_WORLD_RANK=" + os.getenv("FLEX_RDMA_WORLD_RANK")) + print(f"{rank} / {world_size} : FLEX_RDMA_WORLD_SIZE=" + os.getenv("FLEX_RDMA_WORLD_SIZE")) + + if os.getenv("FLEX_COMPUTE") == "SENTIENT": + pcie_env_str="AIU_WORLD_RANK_"+str(rank+aiu_dev_offset) + if os.getenv(pcie_env_str) is not None: + device_id = os.getenv(pcie_env_str) + else: + with open(os.getenv(env_str)) as fd: + data = json.load(fd) + device_id = data['GENERAL']['sen_bus_id'] + print(f"{rank} / {world_size} : Sentient AIU: Enabled ({device_id}) (offset={aiu_dev_offset})") + else: + print(f"{rank} / {world_size} : Sentient AIU: Disabled (Senulator)") + else: + print(f"{rank} / {world_size} : CPU Execution (Not Sentient AIU)") + + if verbose > 0 or 0 == rank: + print(f"{rank} / {world_size} : Dynamo Backend : {dynamo_backend}") + + # Pin this process to 1 core + p = psutil.Process() + hw_thread_per_core = int(psutil.cpu_count(logical=True)/psutil.cpu_count(logical=False)) + if 0 == rank and verbose > 0: + print(f"{rank} / {world_size} : CPU Cores : {psutil.cpu_count(logical=False)} x {hw_thread_per_core} HW threads") + # p.cpu_affinity( [x+hw_thread_per_core*rank for x in range(hw_thread_per_core)] ) + # print(f"{rank} / {world_size} : CPU Affinity : {p.cpu_affinity()}") + + if 0 == rank and verbose > 0: + print("-"*60) + +# ============================================================== +# Distributed setup +# ============================================================== +def dist_setup(pargs, rank, world_size, local_rank=-0, local_size=-1): + if local_rank < 0: + local_rank = rank + if local_size < 0: + local_size = world_size + + if os.getenv("TORCHELASTIC_RUN_ID") is None: + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + elif pargs.verbose > 0 or 0 == rank: + print(f"{rank} / {world_size} : Detected running via torchrun") + + # initialize the process group + torch.distributed.init_process_group(backend=pargs.dist, rank=rank, world_size=world_size) + # Fix until PT 2.3 + torch._C._distributed_c10d._register_process_group("default", torch.distributed.group.WORLD) + + 
if pargs.verbose > 0 or 0 == rank: + print(f"{rank} / {world_size} : Parallel Backend: {torch.distributed.get_backend()}") + + setup(pargs, rank, world_size) + + +def dist_get_rank(): + return torch.distributed.get_rank() + +def dist_get_world_size(): + return torch.distributed.get_world_size() + +def dist_cleanup(): + # print(f"{dist_get_rank()} / {dist_get_world_size()} : Done") + torch.distributed.destroy_process_group() + +# ============================================================== +# +# ============================================================== +def parse_human_readable_bytes(human_bytes): + units = {"B": 1, "KB": 2**10, "MB": 2**20, "GB": 2**30, "TB": 2**40} + + human_bytes = human_bytes.upper() + + # Add a space if it not there already + if not re.match(r' ', human_bytes): + # Add the 'B' if it is not present + if "B" not in human_bytes: + # Is this just a raw number? + if not re.match(r'[KMGT]', human_bytes): + human_bytes = human_bytes + "B" + else: + human_bytes = re.sub(r'([KMGT])', r'\1B', human_bytes) + # Add the space + human_bytes = re.sub(r'([KMGT]?B)', r' \1', human_bytes) + + number, unit = [string.strip() for string in human_bytes.split()] + return int(float(number)*units[unit]) + +def size_human_readable(num, is_bytes=True): + if is_bytes: + suffix = "B" + units = ["", "K", "M", "G", "T"] + else: + suffix = "" + units = ["", "K", "M", "B", "T"] + + for unit in units: + if is_bytes: + if abs(num) < 1024.0: + return f"{num:5.1f} {unit}{suffix}" + num /= 1024.0 + else: + if abs(num) < 1000.0: + return f"{num:5.1f} {unit}{suffix}" + num /= 1000.0 + return f"{num:,5.1f} P{suffix}" + +def get_model_size_readable(bytes, params): + return size_human_readable(bytes) + ' (' + size_human_readable(params, False) + ' params)' + +def get_model_size(model: nn.Linear): + """Given a model return a human readable string of the size of the model, + both in terms of the memory footprint (bytes) and parameters. + """ + if model is None: + return "N/A" + + # Count the number of parameters + num_params = 0 + param_size = 0 + for param in model.parameters(): + num_params += param.nelement() + param_size += param.nelement() * param.element_size() + + # Buffer element size + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + return get_model_size_readable(param_size + buffer_size, num_params) + +def get_model_size_bytes(model: nn.Linear): + """Given a model return a human readable string of the size of the model, + both in terms of the memory footprint (bytes) and parameters. + """ + if model is None: + return 0 + + # Count the number of parameters + param_size = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + + # Buffer element size + buffer_size = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + + return param_size + buffer_size + +def get_model_num_params(model: nn.Linear): + """Given a model return a human readable string of the size of the model, + both in terms of the memory footprint (bytes) and parameters. 
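+    This variant returns the total parameter count as an integer rather than a
+    formatted string.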
+ """ + if model is None: + return 0 + + # Count the number of parameters + num_params = 0 + for param in model.parameters(): + num_params += param.nelement() + + return num_params + +def display_model(model): + if model is None: + return + + print("----- Model ----- (Start)") + for m in model.modules(): + if isinstance(m, nn.Linear): + print("---> Linear Layer") + print(f"Model Weight: {m.weight}") + print(f"Model Bias : {m.bias}") + print("----- Model ----- (End)") diff --git a/vllm/model_executor/model_loader/sendnn.py b/vllm/model_executor/model_loader/sendnn.py new file mode 100644 index 000000000..78ee88f26 --- /dev/null +++ b/vllm/model_executor/model_loader/sendnn.py @@ -0,0 +1,213 @@ +"""Utilities for selecting and loading SENDNN models.""" +import importlib +import os +from typing import Dict, Optional, Tuple, List + +import torch +import torch.nn as nn +import transformers +from transformers import PretrainedConfig + +from torch import distributed as dist #for get_model() +from fms.models import get_model +from fms.utils import generation, tokenizers +from fms.utils.fusion import apply_unfuse_weights +from fms.utils.generation import generate, _make_cache_contiguous + +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, DeviceConfig +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + +from vllm.model_executor.model_loader.aiu_common import setup as aiu_common_setup +import torch._inductor.config +try: + from torch_sendnn import torch_sendnn +except ImportError: + print("WARNING: Disabled: torch_sendnn") + pass +try: + import backends.dynamo_tracer +except ImportError: + print("WARNING: Disabled: dynamo_tracer") + pass + + +TORCH_DTYPE_TO_SENDNN_AMP = { + "auto": "f32", + "half": "f16", + "float16": "f16", + "bfloat16": "bf16", + "float": "f32", + "float32": "f32", + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", +} + +# Models supported by SENDNN. 
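+# NOTE: the table below is currently empty; the commented entries are carried
+# over from the Neuron backend as a reference. Model loading instead goes
+# through the FMS get_model() call in SENDNNCasualLM.load_weights().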
+_SENDNN_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { + # "LlamaForCausalLM": ("transformers_neuronx.llama.model", + # "LlamaForSampling", "LlamaForCausalLM"), + # "MistralForCausalLM": ("transformers_neuronx.mistral.model", + # "MistralForSampling", "MistralForCausalLM") +} + +DYN_BACKEND = "sendnn_decoder" + +# used as baseline the following code for llama 7b +# https://github.ibm.com/ai-chip-toolchain/multi-aiu-dev/blob/main/models/llama/inference.py + +class SENDNNCasualLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + self.sampler = Sampler() + self.past_key_value_states = None + self.warmup_flag = 0 + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + input_block_ids: torch.Tensor, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> torch.Tensor: + + # logits = self.model(input_ids, + # cache_ids=positions, + # start_ids=input_block_ids) + + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + self.past_key_value_states = None + + # BELOW IS DEBUG CODE TO TEST WARM-UP CONCEPT + dynamo_backend = DYN_BACKEND + if "sendnn" in dynamo_backend: + if is_prompt: + if (self.warmup_flag == 0): + self.warmup_flag = self.warmup_flag+1 + elif (self.warmup_flag == 1): + torch_sendnn.update_lazyhandle() + self.warmup_flag = self.warmup_flag+1 + else: + self.warmup_flag = self.warmup_flag + + #output = self.model(input_ids, **kwargs) + output = self.model(input_ids, past_key_value_states=self.past_key_value_states, use_cache=True) + logits, past_key_value_states = output + self.past_key_value_states = past_key_value_states + logits = logits[:, -1, :] + + return logits + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, model_name_or_path: str, device_type, **kwargs): + + # TODO(npo): move parameters in proper place + model_source = "meta" + variant = "7b" + distr_param = None + architecture = "llama" + prompt_len = 64 + compile_mode = "default" + dynamo_backend = DYN_BACKEND + max_decode_tokens = 20 + + # Load the weights from the cached or downloaded files. 
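+        # get_model() is IBM's foundation-model-stack (fms) loader imported at
+        # the top of this file; it materializes the eager llama model that is
+        # wrapped with torch.compile further below.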
+ self.model = get_model( + architecture=architecture,# + variant=variant,# + model_path=model_name_or_path,# + source=model_source,# + device_type=device_type,# + distributed_strategy=distr_param,# + group=dist.group.WORLD,# + ) + + self.model.eval() + torch.set_grad_enabled(False) + + _target_cache_size = max( int(max_decode_tokens * 2), int(prompt_len * 2.5)) + if hasattr(torch._dynamo.config, "accumulated_cache_size_limit"): + if _target_cache_size > torch._dynamo.config.accumulated_cache_size_limit: + _prev = torch._dynamo.config.accumulated_cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = _target_cache_size + print(f"NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit from {_prev} to {torch._dynamo.config.accumulated_cache_size_limit} to accomidate prompt size of {prompt_len} and decode tokens of {max_decode_tokens}") + + if _target_cache_size > torch._dynamo.config.cache_size_limit: + _prev = torch._dynamo.config.cache_size_limit + torch._dynamo.config.cache_size_limit = _target_cache_size + print(f"NOTICE: Adjusting torch._dynamo.config.cache_size_limit from {_prev} to {torch._dynamo.config.cache_size_limit} to accomidate prompt size of {prompt_len} and decode tokens of {max_decode_tokens}") + + if "sendnn" in dynamo_backend: + torch._dynamo.config.dynamic_shapes=False + torch._dynamo.config.automatic_dynamic_shapes=False + + # Bug in PT 2.1.2 + torch._inductor.config.split_cat_fx_passes = False + # Silence error from PT 2.1.2: + # AssertionError: expected size 32==32, stride 128==8192 at dim=1 + # TORCHINDUCTOR_SIZE_ASSERTS=0 + torch._inductor.config.size_asserts=False + # Bug with kv-cache in PT2.1 + torch._inductor.config.joint_graph_constant_folding = False + + self.model = torch.compile(self.model, mode=compile_mode, backend=dynamo_backend) + +def _get_model_architecture(config: PretrainedConfig) -> str: + architectures = getattr(config, "architectures", []) + for arch in architectures: + if arch in _SENDNN_SUPPORTED_MODELS: + return arch + raise ValueError( + f"Model architectures {architectures} are not supported on SENDNN " + f"for now. Supported architectures: " + f"{list(_SENDNN_SUPPORTED_MODELS.keys())}") + + +def get_sendnn_model(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig) -> nn.Module: + + # Create a model instance. + model = SENDNNCasualLM(model_config.hf_config) + + # Envar setup for Sentient backend + dynamo_backend = DYN_BACKEND + aiu_common_setup(dynamo_backend) + + # Load the weights from the cached or downloaded files. 
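+    # device_config.device is torch.device("cpu") for the sendnn device type
+    # (see DeviceConfig), so weights are materialized on the host and the
+    # sendnn_decoder dynamo backend only takes over when the compiled graph runs.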
+ model.load_weights( + model_config.model, + device_type=device_config.device) + + if dynamo_backend == "sendnn_decoder": + model = apply_unfuse_weights(model) + + #return model.eval() + return model diff --git a/vllm/utils.py b/vllm/utils.py index 876c3bf90..d71fd0d1b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -206,6 +206,13 @@ def is_neuron() -> bool: @lru_cache(maxsize=None) +def is_sendnn() -> bool: + try: + import torch_sendnn + except ImportError: + torch_sendnn = None + return torch_sendnn is not None + def is_tpu() -> bool: try: import libtpu @@ -589,6 +596,9 @@ def is_pin_memory_available() -> bool: elif is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False + elif is_sendnn(): + print_warning_once("Pin memory is not supported on SENDNN device.") + return False elif is_cpu() or is_openvino(): return False return True diff --git a/vllm/worker/sendnn_model_runner.py b/vllm/worker/sendnn_model_runner.py new file mode 100644 index 000000000..b76d4945f --- /dev/null +++ b/vllm/worker/sendnn_model_runner.py @@ -0,0 +1,202 @@ +from typing import List, Optional, Tuple + +import torch +from torch import nn +import time + +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.model_loader.sendnn import get_sendnn_model +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import is_pin_memory_available, make_tensor_with_pad + +logger = init_logger(__name__) + + +class SENDNNModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + if model_config is not None and model_config.get_sliding_window(): + logger.warning("Sliding window is not supported on SENDNN. " + "The model will run without sliding window.") + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + # Lazy initialization. + self.model: nn.Module # initialize after load_model. 
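+        # Note: no KV-cache tensors are managed by this runner; the compiled
+        # FMS model keeps its own past_key_value_states between decode steps.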
+ + def load_model(self) -> None: + self.model = get_sendnn_model(self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + input_block_ids: List[int] = [] + + seq_lens: List[int] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + seq_lens.append(seq_len) + + input_tokens.append(prompt_tokens) + input_positions.append(list(range(seq_len))) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == 1 + input_block_ids.append(block_table[0]) + + max_seq_len = max(seq_lens) + assert max_seq_len > 0 + input_tokens = make_tensor_with_pad(input_tokens, + pad=0, + dtype=torch.long, + max_len=max_seq_len, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + pad=0, + dtype=torch.long, + max_len=max_seq_len, + device=self.device) + input_block_ids = torch.tensor(input_block_ids, + dtype=torch.long, + device=self.device) + + return input_tokens, input_positions, input_block_ids, seq_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + input_block_ids: List[int] = [] + context_lens: List[int] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + context_lens.append(seq_len) + + assert seq_group_metadata.block_tables is not None + block_table = seq_group_metadata.block_tables[seq_id] + assert len(block_table) == 1 + input_block_ids.append(block_table[0]) + + input_tokens = make_tensor_with_pad(input_tokens, + pad=0, + dtype=torch.long, + max_len=1, + device=self.device) + input_positions = make_tensor_with_pad(input_positions, + pad=0, + dtype=torch.long, + max_len=1, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + input_block_ids = torch.tensor(input_block_ids, + dtype=torch.long, + device=self.device) + + return input_tokens, input_positions, input_block_ids + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
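+        # Prefill batches are padded to the longest prompt in the batch, while
+        # decode batches carry exactly one token per sequence.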
+ if is_prompt: + (input_tokens, input_positions, input_block_ids, + seq_lens) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, + input_block_ids) = self._prepare_decode(seq_group_metadata_list) + seq_lens = [] + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since SENDNN worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + self.pin_memory) + return (input_tokens, input_positions, input_block_ids, + sampling_metadata) + + #@torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, input_block_ids, sampling_metadata + ) = self.prepare_input_tensors(seq_group_metadata_list) + t0 = time.time() + hidden_states = self.model( + input_ids=input_tokens, + positions=input_positions, + input_block_ids=input_block_ids, + seq_group_metadata_list=seq_group_metadata_list, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Sample the next token. + output = self.model.sample( + logits=logits, + sampling_metadata=sampling_metadata, + ) + t1 = time.time()-t0 + print("[sendnn_model_runner:execute_model] t_token: %.2fms" % (t1*1000)) + + return output + + @property + def vocab_size(self) -> int: + return self.model_config.get_vocab_size() diff --git a/vllm/worker/sendnn_worker.py b/vllm/worker/sendnn_worker.py new file mode 100644 index 000000000..fdbcbccad --- /dev/null +++ b/vllm/worker/sendnn_worker.py @@ -0,0 +1,99 @@ +"""A SENDNN worker class.""" +from typing import List, Tuple + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig) +from vllm.model_executor import set_random_seed +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.sendnn_model_runner import SENDNNModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class SENDNNWorker(LoraNotSupportedWorkerBase): + """A worker class that executes the model on a group of SENDNN cores. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = SENDNNModelRunner(model_config, parallel_config, + scheduler_config, device_config) + + def init_device(self) -> None: + # Set random seed. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. 
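+        # (The examples set block_size == max_model_len, so each sequence
+        # occupies exactly one of these blocks.)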
+ num_gpu_blocks = self.scheduler_config.max_num_seqs + + # Swap not yet supported with SENDNN backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. + assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + #@torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> List[SamplerOutput]: + torch.set_grad_enabled(False) + num_seq_groups = len(seq_group_metadata_list) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(seq_group_metadata_list) + + # SENDNN worker only supports single-step output. Wrap the output in a + # list to conform to interface. + return [output] + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError
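
For reference, the offline-inference flow exercised by this patch (condensed from examples/offline_inference_sendnn.ipynb above) looks roughly like the sketch below. The model path, AIU config location, and prompt are taken from that notebook and are assumptions about the test environment rather than requirements of the backend; as noted in the commit message, only batch size 1 (max-num-seqs=1) has been tested so far.

    import json
    import os

    from vllm import LLM, SamplingParams

    # Point the runtime at the single AIU described by the senlib config.
    with open("/etc/aiu/senlib_config.json", "rb") as f:
        config = json.load(f)
    os.environ["AIU_CONFIG_FILE_0"] = "/etc/aiu/senlib_config.json"
    os.environ["FLEX_RDMA_PCI_BUS_ADDR_0"] = config["GENERAL"]["sen_bus_id"][0]
    os.environ["AIU_WORLD_RANK_0"] = "0"

    # block_size == max_model_len, so every sequence fits in a single KV block.
    llm = LLM(model="/tmp/7B-F",
              tokenizer="/tmp/7B-F",
              max_model_len=2048,
              block_size=2048,
              device="sendnn")

    sampling_params = SamplingParams(max_tokens=10, temperature=0.0)

    # SENDNNCasualLM calls torch_sendnn.update_lazyhandle() on its second prompt,
    # so the first couple of generate() calls double as warm-up.
    outputs = llm.generate(
        ["Provide a list of instructions for preparing chicken soup "
         "for a family of four."],
        sampling_params)
    for output in outputs:
        print(output.outputs[0].text)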