From 3e43bb2a7e31aa008cd71b49d28c8fe2aa73ca1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sophie=20du=20Cou=C3=A9dic?= Date: Thu, 19 Dec 2024 11:43:37 +0100 Subject: [PATCH] Prepare for open sourcing (#80) Initial code drop with Spyre support Signed-off-by: Thomas Parnell Signed-off-by: Nikolaos Papandreou Signed-off-by: Burkhard Ringlein Signed-off-by: Max de Bayser Co-authored-by: Nick Hill Co-authored-by: Thomas Parnell Co-authored-by: Nikolaos Papandreou Co-authored-by: TRAVIS JOHNSON Co-authored-by: Burkhard Ringlein Co-authored-by: Yannick Schnider Co-authored-by: Jan van Lunteren Co-authored-by: Maximilien Philippe Marie de Bayser --- .yapfignore | 2 + Dockerfile.spyre | 28 ++ README.md | 70 ++++ examples/offline_inference_multi_spyre.py | 60 +++ examples/offline_inference_spyre.ipynb | 313 ++++++++++++++++ examples/offline_inference_spyre.py | 46 +++ examples/online_inference_spyre.ipynb | 250 +++++++++++++ .../online_inference_spyre_multiple.ipynb | 256 +++++++++++++ examples/spyre_warmup_online_client.py | 82 ++++ format.sh | 4 +- pyproject.toml | 6 +- requirements-spyre.txt | 7 + setup.py | 21 +- tests/spyre/spyre_util.py | 277 ++++++++++++++ tests/spyre/test_spyre_basic.py | 73 ++++ tests/spyre/test_spyre_embeddings.py | 55 +++ tests/spyre/test_spyre_max_prompt_length.py | 101 +++++ tests/spyre/test_spyre_seed.py | 75 ++++ tests/spyre/test_spyre_tensor_parallel.py | 77 ++++ tests/spyre/test_spyre_warmup_shapes.py | 85 +++++ vllm/config.py | 45 ++- vllm/core/scheduler.py | 76 +++- vllm/engine/arg_utils.py | 2 +- vllm/engine/async_llm_engine.py | 8 + vllm/engine/llm_engine.py | 9 + vllm/envs.py | 38 ++ vllm/executor/executor_base.py | 4 +- vllm/executor/multiproc_spyre_executor.py | 267 ++++++++++++++ vllm/executor/spyre_executor.py | 180 +++++++++ .../layers/quantization/utils/marlin_utils.py | 10 +- vllm/model_executor/model_loader/spyre.py | 213 +++++++++++ .../model_loader/spyre_setup.py | 145 ++++++++ vllm/platforms/__init__.py | 10 + vllm/platforms/interface.py | 4 + vllm/platforms/spyre.py | 9 + vllm/utils.py | 3 + vllm/worker/spyre_embedding_model_runner.py | 171 +++++++++ vllm/worker/spyre_model_runner.py | 349 ++++++++++++++++++ vllm/worker/spyre_worker.py | 337 +++++++++++++++++ 39 files changed, 3748 insertions(+), 20 deletions(-) create mode 100644 Dockerfile.spyre create mode 100644 examples/offline_inference_multi_spyre.py create mode 100644 examples/offline_inference_spyre.ipynb create mode 100644 examples/offline_inference_spyre.py create mode 100644 examples/online_inference_spyre.ipynb create mode 100644 examples/online_inference_spyre_multiple.ipynb create mode 100644 examples/spyre_warmup_online_client.py create mode 100644 requirements-spyre.txt create mode 100644 tests/spyre/spyre_util.py create mode 100644 tests/spyre/test_spyre_basic.py create mode 100644 tests/spyre/test_spyre_embeddings.py create mode 100644 tests/spyre/test_spyre_max_prompt_length.py create mode 100644 tests/spyre/test_spyre_seed.py create mode 100644 tests/spyre/test_spyre_tensor_parallel.py create mode 100644 tests/spyre/test_spyre_warmup_shapes.py create mode 100644 vllm/executor/multiproc_spyre_executor.py create mode 100644 vllm/executor/spyre_executor.py create mode 100644 vllm/model_executor/model_loader/spyre.py create mode 100644 vllm/model_executor/model_loader/spyre_setup.py create mode 100644 vllm/platforms/spyre.py create mode 100644 vllm/worker/spyre_embedding_model_runner.py create mode 100644 vllm/worker/spyre_model_runner.py create mode 100644 vllm/worker/spyre_worker.py 
diff --git a/.yapfignore b/.yapfignore index 2d6dcf838..7a23d8913 100644 --- a/.yapfignore +++ b/.yapfignore @@ -1 +1,3 @@ collect_env.py + +vllm/model_executor/model_loader/spyre_setup.py diff --git a/Dockerfile.spyre b/Dockerfile.spyre new file mode 100644 index 000000000..c68dc5b4d --- /dev/null +++ b/Dockerfile.spyre @@ -0,0 +1,28 @@ +# Global Args ################################################################# +ARG BASE_UBI_IMAGE_TAG=9.4 +ARG PYTHON_VERSION=3.12 + +# Base Layer ################################################################## +FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base +ARG PYTHON_VERSION +ENV PYTHON_VERSION=${PYTHON_VERSION} +WORKDIR /workspace/vllm + +# Install some basic utilities ################################################################## +RUN microdnf update -y && microdnf install -y \ + python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel git vim gcc g++\ + && microdnf clean all + +# Install build dependencies ################################################################## +RUN --mount=type=bind,source=requirements-build.txt,target=requirements-build.txt \ + python3.12 -m pip install --upgrade pip && \ + pip install -r requirements-build.txt + +# Build vLLM ################################################################## +COPY . . + +ENV VLLM_TARGET_DEVICE=spyre +RUN --mount=type=bind,source=.git,target=.git \ + pip install --no-build-isolation -v -e . + +CMD ["/bin/bash"] diff --git a/README.md b/README.md index 0ef073210..cf24582d0 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,76 @@ Easy, fast, and cheap LLM serving for everyone --- +## What is the purpose of this fork? + +This is a private fork of vLLM that we are using to develop support for IBM Research's AI accelerator (Spyre). +The idea is that the main branch of this repo should not diverge significantly from upstream beyond changes required to enable Spyre. +We will try to rebase against upstream frequently and we plan to contribute these changes to the upstream repository in the future. + +--- +## Supported IBM Granite models on Spyre + +| Model | 3b | 7b | 8b | 13b | 20b | +|:------------:|:------------:|:------------:|:------------:|:------------:|:------------:| +| **llama** | NO1
[weights](https://huggingface.co/ibm-granite/granite-3b-code-base) | YES2 [weights](https://huggingface.co/ibm-granite/granite-7b-base) | YES3 [weights](https://huggingface.co/ibm-granite/granite-8b-code-base) | X | X |
+| **gpt big code** | YES4 [-](tom) | X | X | YES5 [-](tom) | YES6 [weights](https://huggingface.co/ibm-granite/granite-20b-code-base) |
+
+
+
+YES  =  working on Spyre
+NO   =  not yet working on Spyre
+X      =  no weights available
+
+
+#### Path to models
+
+1 : ```/models/granite-3b-code-base```
+2 : ```/models/granite-7b-base```
+3 : ```/models/granite-8b-code-base```
+4 : ```/models/granite-3b-base```
+5 : ```/models/granite-13b-base```
+6 : ```/models/granite-20b-code-base```
+
+(PVC in dev pod)
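+
+These paths can be passed directly as the `model` argument of vLLM. As a quick orientation, here is a minimal, illustrative offline sketch (not a drop-in script): it assumes one of the Spyre-enabled model paths listed above and uses the warmup environment variables explained in the online demo section below.
+
+```python
+import os
+
+# Shapes the Spyre stack is warmed up for (comma-separated lists, see below).
+os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = "64"
+os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = "20"
+os.environ["VLLM_SPYRE_WARMUP_BATCH_SIZES"] = "1"
+
+from vllm import LLM, SamplingParams
+
+# Model path assumed from the table above; any Spyre-enabled path works.
+llm = LLM(model="/models/granite-8b-code-base",
+          tokenizer="/models/granite-8b-code-base",
+          max_model_len=2048,
+          block_size=2048,
+          device="spyre")
+
+sampling_params = SamplingParams(max_tokens=20, temperature=0.0)
+outputs = llm.generate(["Hello, my name is"], sampling_params)
+print(outputs[0].outputs[0].text)
+```
+
+The complete versions of this flow are in the offline demo below and in [./examples/offline_inference_spyre.ipynb](./examples/offline_inference_spyre.ipynb).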
+## Running ***offline*** demo on Spyre + +```bash +python3 examples/offline_inference_spyre.py +``` +## Running ***online*** demo on Spyre + +### Batch size 1 +Log in to the same pod with two terminal windows and launch the server in one and submit requests from the other. + +**1st terminal window**: Set up the server with a model provided at \ [above](#path-to-models) (slow, takes a long time due to Spyre compilation): +```bash +python3 -m vllm.entrypoints.openai.api_server --model --max-model-len=2048 --block-size=2048 +``` +Optionally set the desired prompt padding (*default 64*) to any multiple of 64 and specify the maximal number of generated output tokens (*default 20*) with **VLLM_SPYRE_WARMUP_PROMPT_LENS** and **VLLM_SPYRE_WARMUP_NEW_TOKENS**: +```bash +export VLLM_SPYRE_WARMUP_PROMPT_LENS=64 +export VLLM_SPYRE_WARMUP_NEW_TOKENS=20 +``` +before starting the server. +**2nd terminal window**: When the above warmup has completed, submit sample prompts for LLM completion (fast): +```bash +python3 examples/spyre_warmup_online_client.py +``` +### Batch size 4/8 + +Before launching the server specify the batch size to be used (below set to 8) via the environment variable **VLLM_SPYRE_WARMUP_BATCH_SIZES** (*default 1*): +```bash +export VLLM_SPYRE_WARMUP_BATCH_SIZES=4 +``` + +Finally continue as described [above](#batch-size-1) by launching the server in the 1st terminal window. +Before submitting prompts from the 2nd terminal window make sure to specify the batch size (same as set via **VLLM_SPYRE_WARMUP_BATCH_SIZES**) in the [client script](./examples/spyre_warmup_online_client.py) (line 44). +### Example notebooks + +- [./examples/online_inference_spyre.ipynb](./examples/online_inference_spyre.ipynb) +- [./examples/offline_inference_spyre.ipynb](./examples/offline_inference_spyre.ipynb) + + +--- *Latest News* 🔥 - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! diff --git a/examples/offline_inference_multi_spyre.py b/examples/offline_inference_multi_spyre.py new file mode 100644 index 000000000..7bf422d8c --- /dev/null +++ b/examples/offline_inference_multi_spyre.py @@ -0,0 +1,60 @@ +import gc +import os +import time + +from vllm import LLM, SamplingParams + +max_tokens = 3 + +os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64' +os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens) +os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '1' + +# stuff for multi-spyre +os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" +os.environ["DISTRIBUTED_STRATEGY_IGNORE_MODULES"] = "WordEmbedding" +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "12355" + +# Sample prompts. +template = ( + "Below is an instruction that describes a task. Write a response that " + "appropriately completes the request. Be polite in your response to the " + "user.\n\n### Instruction:\n{}\n\n### Response:") +prompt1 = template.format( + "Provide a list of instructions for preparing chicken soup for a family " + "of four.") +prompts = [ + prompt1, +] + +# Create a sampling params object. +sampling_params = SamplingParams(max_tokens=max_tokens, + temperature=0.0, + ignore_eos=True) +# Create an LLM. 
+llm = LLM( + model="/models/llama-194m", + tokenizer="/models/llama-194m", + max_model_len=2048, + block_size=2048, + device="spyre", + tensor_parallel_size=2, +) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +print("=============== GENERATE") +t0 = time.time() +outputs = llm.generate(prompts, sampling_params) +print("Time elaspsed for %d tokens is %.2f sec" % + (len(outputs[0].outputs[0].token_ids), time.time() - t0)) +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") +print(output.outputs[0]) + +# needed to prevent ugly stackdump caused by sigterm +del llm +gc.collect() diff --git a/examples/offline_inference_spyre.ipynb b/examples/offline_inference_spyre.ipynb new file mode 100644 index 000000000..792c73177 --- /dev/null +++ b/examples/offline_inference_spyre.ipynb @@ -0,0 +1,313 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bb1996e6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING 08-15 14:54:53 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError(\"No module named 'vllm._C'\")\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/vllm/lib64/python3.9/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:\n", + "No module named 'vllm.commit_id'\n", + " from vllm.version import __version__ as VLLM_VERSION\n" + ] + } + ], + "source": [ + "import time\n", + "%load_ext wurlitzer" + ] + }, + { + "cell_type": "markdown", + "id": "45172614", + "metadata": {}, + "source": [ + "Offline inference demo\n", + "----------------------------\n", + "This is just a brief demo to show that vLLM with Spyre can be used in the offline mode. \n", + "\n", + "vLLM will determine the Spyre config automatically and warmup the Spyre stack. \n", + "The startup of vLLM (including warmup of Spyre), is expected to take 15 min for prompt length of 64 and maximum number of decode tokens 5 (it will take ~20min for 64/20)." + ] + }, + { + "cell_type": "markdown", + "id": "bf37c07b", + "metadata": {}, + "source": [ + "### 1. create vLLM instance \n", + "(for offline usage, including warmup)\n", + "\n", + "The maximum prompt length and maximum number of decode tokens can be specified using the environment variables `VLLM_SPYRE_WARMUP_PROMPT_LENS`, and `VLLM_SPYRE_WARMUP_NEW_TOKENS`. \n", + "Otherwise the default max prompt length of 64 and maximum of 20 decode tokens will be used. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ecf0992b", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = '64'\n", + "os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = '5'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "88b984d7", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO 08-15 14:54:54 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='/tmp/7B-F', speculative_config=None, tokenizer='/tmp/7B-F', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/tmp/7B-F, use_v2_block_manager=False, enable_prefix_caching=False)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING: Disabled: dynamo_tracer\n", + "WARNING 08-15 14:54:55 utils.py:581] Pin memory is not supported on Spyre device.\n", + "[SpyreWorker] environment configured\n", + "[SpyreWorker] load model...\n", + ">> DEBUG SETUP\n", + "0 / 1 : Python Version : 3.9.18\n", + "0 / 1 : PyTorch Version: 2.2.2+cpu\n", + "0 / 1 : PCI Addr Rank 0 AIU_WORLD_RANK_0=0\n", + "0 / 1 : PCI Addr Rank 0 FLEX_RDMA_PCI_BUS_ADDR_0=0000:1e:00.0\n", + "0 / 1 : FLEX_COMPUTE=SENTIENT\n", + "0 / 1 : FLEX_DEVICE=VFIO\n", + "0 / 1 : DEEPRT_EXPORT_DIR=export/0\n", + "0 / 1 : DTCOMPILER_EXPORT_DIR=export/0\n", + "0 / 1 : AIU_CONFIG_FILE_0=/etc/aiu/senlib_config.json\n", + "0 / 1 : SENLIB_DEVEL_CONFIG_FILE=/etc/aiu/senlib_config.json\n", + "0 / 1 : FLEX_RDMA_PCI_BUS_ADDR_0=0000:1e:00.0\n", + "0 / 1 : FLEX_RDMA_LOCAL_RANK=0\n", + "0 / 1 : FLEX_RDMA_LOCAL_SIZE=1\n", + "0 / 1 : FLEX_RDMA_WORLD_RANK=0\n", + "0 / 1 : FLEX_RDMA_WORLD_SIZE=1\n", + "0 / 1 : Spyre: Enabled (0) (offset=0)\n", + "0 / 1 : Dynamo Backend : sendnn_decoder\n", + "0 / 1 : CPU Cores : 56 x 2 HW threads\n", + "------------------------------------------------------------\n", + "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit from 64 to 160 to accommodate prompt size of 64 and decode tokens of 5\n", + "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from 8 to 160 to accommodate prompt size of 64 and decode tokens of 5\n", + "\tload model took 62.92104411125183s\n", + "[SpyreWorker] Start warming up 1 different prompt/decode-shape combinations.\n", + "[SpyreWorker] Warmup 1/1 prompt/decode-shape combinations...\n", + "[SpyreWorker] 
warmup for prompt length 64 and max output tokens 5.\n", + "[SpyreWorker] warmup 1/2...\n", + "[SpyreWorker] warmup 2/2...\n", + "update_lazyhandle() done (duration: 134.3403525352478s)\n", + "[SpyreWorker] ... warmup finished.\n", + "\twarmup took 893.4236354827881s (for prompt length 64 and max output tokens 5)\n", + "[SpyreWorker] All warmups for 1 different prompt/decode-shape combinations finished. Total warmup time 893.4242045879364s.\n" + ] + } + ], + "source": [ + "from vllm import LLM, SamplingParams\n", + "\n", + "# Create an LLM.\n", + "llm = LLM(\n", + " model=\"/models/llama-7b-chat\",\n", + " tokenizer=\"/models/llama-7b-chat\",\n", + " max_model_len=2048,\n", + " block_size=2048,\n", + " device=\"spyre\")" + ] + }, + { + "cell_type": "markdown", + "id": "818488f4", + "metadata": {}, + "source": [ + "### 2. Create the prompt and `SamplingParams`" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "6c32e3e2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\\n\\n### Instruction:\\nProvide a list of instructions for preparing chicken soup for a family of four.\\n\\n### Response:']\n" + ] + } + ], + "source": [ + "# Sample prompts.\n", + "template = (\n", + " \"Below is an instruction that describes a task. Write a response that \"\n", + " \"appropriately completes the request. Be polite in your response to the \"\n", + " \"user.\\n\\n### Instruction:\\n{}\\n\\n### Response:\"\n", + ")\n", + "prompt1 = template.format(\n", + " \"Provide a list of instructions for preparing chicken soup for a family \"\n", + " \"of four.\"\n", + ")\n", + "prompts = [\n", + " prompt1,\n", + "]\n", + "print(prompts)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4cc1277e", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a sampling params object.\n", + "max_tokens = 5\n", + "sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)\n" + ] + }, + { + "cell_type": "markdown", + "id": "cffb16c5", + "metadata": {}, + "source": [ + "### 3. 
Generate the response" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "522c0610", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "=============== GENERATE\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processed prompts: 0%| | 0/1 [00:00 bool: def _is_cuda() -> bool: - has_cuda = torch.version.cuda is not None - return (VLLM_TARGET_DEVICE == "cuda" and has_cuda + return (VLLM_TARGET_DEVICE == "cuda" and (torch.version.cuda is not None) and not (_is_neuron() or _is_tpu() or _is_hpu())) @@ -299,6 +304,10 @@ def _is_cpu() -> bool: return VLLM_TARGET_DEVICE == "cpu" +def _is_spyre() -> bool: + return VLLM_TARGET_DEVICE == "spyre" + + def _is_openvino() -> bool: return VLLM_TARGET_DEVICE == "openvino" @@ -415,6 +424,8 @@ def get_vllm_version() -> str: if neuron_version != MAIN_CUDA_VERSION: neuron_version_str = neuron_version.replace(".", "")[:3] version += f"{sep}neuron{neuron_version_str}" + elif _is_spyre(): + version += f"{sep}spyre" elif _is_hpu(): # Get the Intel Gaudi Software Suite version gaudi_sw_version = str(get_gaudi_sw_version()) @@ -479,6 +490,8 @@ def _read_requirements(filename: str) -> List[str]: requirements = _read_requirements("requirements-rocm.txt") elif _is_neuron(): requirements = _read_requirements("requirements-neuron.txt") + elif _is_spyre(): + requirements = _read_requirements("requirements-spyre.txt") elif _is_hpu(): requirements = _read_requirements("requirements-hpu.txt") elif _is_openvino(): diff --git a/tests/spyre/spyre_util.py b/tests/spyre/spyre_util.py new file mode 100644 index 000000000..7fc3bc10d --- /dev/null +++ b/tests/spyre/spyre_util.py @@ -0,0 +1,277 @@ +import math +import os +from typing import Any, Dict, List, Tuple + +import numpy as np +from sentence_transformers import SentenceTransformer, util +from transformers import AutoModelForCausalLM, AutoTokenizer + +from vllm import LLM, SamplingParams + +DISABLE_ASSERTS = False # used for debugging + +ISCLOSE_REL_TOL_CPU = 0.1 +ISCLOSE_REL_TOL_SPYRE = 0.1 + + +# vLLM / Spyre +def generate_spyre_vllm_output(model: str, prompts: List[str], + warmup_shapes: List[Tuple[int, int, int]], + max_model_len: int, block_size: int, + sampling_params: SamplingParams, + tensor_parallel_size: int, + backend: str) -> List[Dict[str, Any]]: + + warmup_prompt_length = [t[0] for t in warmup_shapes] + warmup_new_tokens = [t[1] for t in warmup_shapes] + warmup_batch_size = [t[2] for t in warmup_shapes] + + os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = ','.join( + str(val) for val in warmup_prompt_length) + os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = ','.join( + str(val) for val in warmup_new_tokens) + os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = ','.join( + str(val) for val in warmup_batch_size) + os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = backend + + vllm_model = LLM(model=model, + tokenizer=model, + max_model_len=max_model_len, + block_size=block_size, + tensor_parallel_size=tensor_parallel_size, + device="spyre") + + vllm_outputs = vllm_model.generate(prompts, sampling_params) + + results = [] + for req_output in vllm_outputs: + result = {} + result['text'] = req_output.outputs[0].text + result['token_ids'] = req_output.outputs[0].token_ids + result['tokens'] = tuple([ + req_output.outputs[0].logprobs[i][t].decoded_token + for i, t in enumerate(result['token_ids']) + ]) + result['logprobs'] = tuple([ + req_output.outputs[0].logprobs[i][t].logprob + for i, t in enumerate(result['token_ids']) + ]) + 
results.append(result) + + return results + + +# Hugging Face +def generate_hf_output(model: str, prompts: List[str], + max_new_tokens: int) -> List[Dict[str, Any]]: + + hf_model = AutoModelForCausalLM.from_pretrained(model) + hf_tokenizer = AutoTokenizer.from_pretrained(model) + + results = [] + for prompt_index, prompt in enumerate(prompts): + hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids + hf_output = hf_model.generate(hf_input_tokens, + do_sample=False, + max_new_tokens=max_new_tokens, + return_dict_in_generate=True, + output_scores=True) + + # decode output tokens after first removing input tokens (prompt) + hf_generated_text = hf_tokenizer.batch_decode( + hf_output.sequences[:, len(hf_input_tokens[0]):])[0] + hf_transition_scores = hf_model.compute_transition_scores( + hf_output.sequences, hf_output.scores, normalize_logits=True) + + # return HF generated text, tokens, token ids and logprobs + result = {} + result['text'] = hf_generated_text + result['token_ids'] = [] + result['tokens'] = [] + result['logprobs'] = [] + for tok_index, hf_logprob in enumerate(hf_transition_scores[0]): + hf_token_id = hf_output.sequences[0][tok_index + + len(hf_input_tokens[0])] + result['token_ids'].append(hf_token_id.item()) + result['tokens'].append(hf_tokenizer.decode(hf_token_id)) + result['logprobs'].append(hf_logprob.item()) + result['token_ids'] = tuple(result['token_ids']) + result['tokens'] = tuple(result['tokens']) + result['logprobs'] = tuple(result['logprobs']) + results.append(result) + + return results + + +# compare results +def compare_results(model: str, prompts: List[str], + warmup_shapes: List[Tuple[int, int, + int]], tensor_parallel_size: int, + backend: str, vllm_results: List[Dict[str, Any]], + hf_results: List[Dict[str, Any]]): + + print(f"\nmodel: {model:s}") + print(f"warmup shapes: {warmup_shapes}") + print(f"tp size: {tensor_parallel_size}") + print(f"backend: {backend:s}") + print(f"\n#prompts: {len(prompts):d}") + print(f"#HF results: {len(hf_results):d}" + f"{'' if len(hf_results) == len(prompts) else ' ERROR':s}") + print(f"#vLLM results: {len(vllm_results):d}" + f"{'' if len(vllm_results) == len(prompts) else ' ERROR':s}") + print() + + assert DISABLE_ASSERTS or len(hf_results) == len(vllm_results) + assert DISABLE_ASSERTS or len(hf_results) == len(prompts) + + for prompt_index, (prompt, hf_result, vllm_result) in enumerate( + zip(prompts, hf_results, vllm_results)): + err_msg = '' if hf_result['text'] == vllm_result['text'] else ' ERROR' + print(f"\nprompt {prompt_index:3d}: {repr(prompt):s}") + print("generated:") + print(f" HF: {repr(hf_result['text']):s}") + print(f" vLLM: {repr(vllm_result['text']):s}{err_msg}") + print() + + assert DISABLE_ASSERTS or backend == 'sendnn_decoder' or\ + hf_result['text'] == vllm_result['text'] + + if len(hf_result['tokens']) > 0: + print(" token id. token logprob " + " token id. 
token logprob") + + logprob_abs_diff_list = [] + logprob_rel_diff_list = [] + + for hf_token, hf_token_id, hf_logprob, vllm_token,\ + vllm_token_id, vllm_logprob in zip( + hf_result['tokens'], hf_result['token_ids'], + hf_result['logprobs'], vllm_result['tokens'], + vllm_result['token_ids'], vllm_result['logprobs']): + logprob_abs_diff = math.fabs(hf_logprob - vllm_logprob) + logprob_abs_diff_list.append(logprob_abs_diff) + logprob_rel_diff = math.fabs(logprob_abs_diff / hf_logprob) + logprob_rel_diff_list.append(logprob_rel_diff) + + print( + f"HF: {hf_token_id:8d} {repr(hf_token):14s} " + f"{hf_logprob:14f} " + f"vLLM: {vllm_token_id:8d} {repr(vllm_token):14s} " + f"{vllm_logprob:14f} ", + end='') + + if backend == 'sendnn_decoder': + rel_tol = ISCLOSE_REL_TOL_SPYRE + else: + rel_tol = ISCLOSE_REL_TOL_CPU + + if hf_token_id != vllm_token_id: # different tokens + if backend == 'sendnn_decoder' and math.isclose( + hf_logprob, vllm_logprob, rel_tol=rel_tol): + # probably still OK + print('DIVERGING') + break + else: + print('ERROR') + assert DISABLE_ASSERTS or False + break + else: # identical tokens + if math.isclose(hf_logprob, vllm_logprob, rel_tol=rel_tol): + print() + else: + print('ERROR') + assert DISABLE_ASSERTS or False + break + + print() + print("logprob absolute differences: " + f"average={np.mean(logprob_abs_diff_list):f} " + f"maximum={np.max(logprob_abs_diff_list):f}") + print("logprob relative differences: " + f"average={np.mean(logprob_rel_diff_list):f} " + f"maximum={np.max(logprob_rel_diff_list):f}") + + print() + + +# vLLM / Spyre +def spyre_vllm_embeddings(model: str, prompts: List[str], + warmup_shapes: List[Tuple[int, + int]], max_model_len: int, + block_size: int, tensor_parallel_size: int, + backend: str) -> List[Dict[str, Any]]: + + warmup_prompt_length = [t[0] for t in warmup_shapes] + warmup_new_tokens = [0] * len(warmup_shapes) + warmup_batch_size = [t[1] for t in warmup_shapes] + + os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = ','.join( + str(val) for val in warmup_prompt_length) + os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = ','.join( + str(val) for val in warmup_new_tokens) + os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = ','.join( + str(val) for val in warmup_batch_size) + os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = backend + + vllm_model = LLM(model=model, + tokenizer=model, + max_model_len=max_model_len, + block_size=block_size, + tensor_parallel_size=tensor_parallel_size, + device="spyre") + + vllm_outputs = vllm_model.encode(prompts) + + results = [] + for req_output in vllm_outputs: + result = {} + result["embeddings"] = req_output.outputs.embedding + results.append(result) + + return results + + +# Hugging Face +def st_embeddings(model: str, prompts: List[str]) -> List[Dict[str, Any]]: + + model = SentenceTransformer(model) + + results = [] + for prompt in prompts: + embeddings = model.encode(prompt) + + # return ST generated embeddings + result = {} + result['embeddings'] = embeddings + results.append(result) + + return results + + +# compare results +def compare_embedding_results(model: str, prompts: List[str], + warmup_shapes: List[Tuple[int, int]], + tensor_parallel_size: int, backend: str, + vllm_results: List[Dict[str, Any]], + hf_results: List[Dict[str, Any]]): + + print(f"\nmodel: {model:s}") + print(f"warmup shapes: {warmup_shapes}") + print(f"tp size: {tensor_parallel_size}") + print(f"backend: {backend:s}") + print(f"\n#prompts: {len(prompts):d}") + print(f"#HF results: {len(hf_results):d}" + f"{'' if len(hf_results) == len(prompts) else ' 
ERROR':s}") + print(f"#vLLM results: {len(vllm_results):d}" + f"{'' if len(vllm_results) == len(prompts) else ' ERROR':s}") + print() + + assert DISABLE_ASSERTS or len(hf_results) == len(vllm_results) + assert DISABLE_ASSERTS or len(hf_results) == len(prompts) + + for hf_result, vllm_result in zip(hf_results, vllm_results): + + sim = util.pytorch_cos_sim(hf_result["embeddings"], + vllm_result["embeddings"]) + + assert math.isclose(sim, 1.0, rel_tol=0.05) diff --git a/tests/spyre/test_spyre_basic.py b/tests/spyre/test_spyre_basic.py new file mode 100644 index 000000000..0e2d73f32 --- /dev/null +++ b/tests/spyre/test_spyre_basic.py @@ -0,0 +1,73 @@ +"""Verification of vLLM output by comparing with HF + +Run `pytest tests/spyre/test_spyre_basic.py`. +""" + +from typing import List, Tuple + +import pytest +from spyre_util import (compare_results, generate_hf_output, + generate_spyre_vllm_output) + +from vllm import SamplingParams + + +@pytest.mark.parametrize("model", ["/models/llama-194m"]) +@pytest.mark.parametrize("prompts", [[ + "Provide a list of instructions for preparing" + " chicken soup for a family of four.", "Hello", + "What is the weather today like?", "Who are you?" +]]) +@pytest.mark.parametrize("warmup_shape", [(64, 20, 4), (64, 20, 8), + (128, 20, 4), (128, 20, 8)] + ) # (prompt_length/new_tokens/batch_size) +@pytest.mark.parametrize("backend", + ["eager"]) #, "inductor", "sendnn_decoder"]) +def test_output( + model: str, + prompts: List[str], + warmup_shape: Tuple[int, int, int], + backend: str, +) -> None: + ''' + The warmup is based on a single shape. After the warmup, + one request with the provided prompts is input to vLLM. + The same prompts are also input to HF. The generated output + including text, token ids, and logprobs, is verified to be + identical for vLLM and HF. + + If errors occur, these can be analyzed/debugged by setting + 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the + test using 'pytest --capture=no tests/spyre/test_spyre_basic.py' + After debugging, DISABLE_ASSERTS should be reset to 'False'. + ''' + + max_new_tokens = warmup_shape[1] + + vllm_sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=0, + logprobs=0, # return logprobs of generated tokens only + ignore_eos=True) + + vllm_results = generate_spyre_vllm_output( + model=model, + prompts=prompts, + warmup_shapes=[warmup_shape], + max_model_len=2048, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=1, + backend=backend) + + hf_results = generate_hf_output(model=model, + prompts=prompts, + max_new_tokens=max_new_tokens) + + compare_results(model=model, + prompts=prompts, + warmup_shapes=[warmup_shape], + tensor_parallel_size=1, + backend=backend, + vllm_results=vllm_results, + hf_results=hf_results) diff --git a/tests/spyre/test_spyre_embeddings.py b/tests/spyre/test_spyre_embeddings.py new file mode 100644 index 000000000..22e4eda7b --- /dev/null +++ b/tests/spyre/test_spyre_embeddings.py @@ -0,0 +1,55 @@ +"""Verification of vLLM output by comparing with HF + +Run `pytest tests/spyre/test_spyre_basic.py`. +""" + +from typing import List, Tuple + +import pytest +from spyre_util import (compare_embedding_results, spyre_vllm_embeddings, + st_embeddings) + + +@pytest.mark.skip("Skip until failure is resolved.") +@pytest.mark.parametrize("model", ["/models/all-roberta-large-v1"]) +@pytest.mark.parametrize("prompts", [[ + "The capital of France is Paris." 
+ "Provide a list of instructions for preparing" + " chicken soup for a family of four.", "Hello", + "What is the weather today like?", "Who are you?" +]]) +@pytest.mark.parametrize("warmup_shape", + [(64, 4), (64, 8), (128, 4), + (128, 8)]) # (prompt_length/new_tokens/batch_size) +@pytest.mark.parametrize("backend", + ["eager"]) #, "inductor", "sendnn_decoder"]) +def test_output( + model: str, + prompts: List[str], + warmup_shape: Tuple[int, int], + backend: str, +) -> None: + ''' + The warmup is based on a single shape. After the warmup, + one request with the provided prompts is input to vLLM. + The same prompts are also input to HF. The generated embeddings + are verified to be identical for vLLM and SentenceTransformers. + ''' + + vllm_results = spyre_vllm_embeddings(model=model, + prompts=prompts, + warmup_shapes=[warmup_shape], + max_model_len=256, + block_size=256, + tensor_parallel_size=1, + backend=backend) + + hf_results = st_embeddings(model=model, prompts=prompts) + + compare_embedding_results(model=model, + prompts=prompts, + warmup_shapes=[warmup_shape], + tensor_parallel_size=1, + backend=backend, + vllm_results=vllm_results, + hf_results=hf_results) diff --git a/tests/spyre/test_spyre_max_prompt_length.py b/tests/spyre/test_spyre_max_prompt_length.py new file mode 100644 index 000000000..37f02f3a6 --- /dev/null +++ b/tests/spyre/test_spyre_max_prompt_length.py @@ -0,0 +1,101 @@ +"""Verification of handling prompt length exceeding warmup shapes + +Run `pytest tests/spyre/test_spyre_max_prompt_length.py`. +""" + +from typing import List, Tuple + +import pytest +from spyre_util import (compare_results, generate_hf_output, + generate_spyre_vllm_output) +from transformers import AutoTokenizer + +from vllm import SamplingParams + + +@pytest.mark.parametrize("model", ["/models/llama-194m"]) +@pytest.mark.parametrize("prompts", [ + 7 * [ + "Hello", + "Below is an instruction that describes a task. Write a response" + " that appropriately completes the request. Be polite in your response" + " to the user. Provide a list of instructions for preparing chicken " + "soup for a family of four. Indicate if the weather forecast looks " + "good for today. Explain in a brief summary comprised of at most 50" + " words what you are." + ] +]) +@pytest.mark.parametrize("warmup_shapes", + [[(64, 20, 4)], [(64, 20, 4), (128, 20, 4)]] + ) # (prompt_length/new_tokens/batch_size) +@pytest.mark.parametrize("backend", + ["eager"]) #, "inductor", "sendnn_decoder"]) +def test_output( + model: str, + prompts: List[str], + warmup_shapes: List[Tuple[int, int, int]], + backend: str, +) -> None: + ''' + The warmup is based on one or multiple shapes. After the warmup, + one request with multiple provided prompts is input to vLLM. + At least one provided prompt should have a length longer than the + maximum prompt length defined by the warmup shapes. It is useful + to define enough prompts to fill multiple batches entirely and + partially, in order to test the maximum prompt length check + also in relation with the position of a prompt within a batch (not + likely that this will be an issue, but just to be sure). + It is verified that only for the prompts that + do not exceed the maximum prompt length, "non-empty" output is + generated. The output is verified using HF. 
+ + If errors occur, these can be analyzed/debugged by setting + 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the test + using 'pytest --capture=no tests/spyre/test_spyre_max_prompt_length.py' + After debugging, DISABLE_ASSERTS should be reset to 'False'. + ''' + + max_prompt_length = max([t[0] for t in warmup_shapes]) + max_new_tokens = max([t[1] for t in warmup_shapes]) + + vllm_sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=0, + logprobs=0, # return logprobs of generated tokens only + ignore_eos=True) + + vllm_results = generate_spyre_vllm_output( + model=model, + prompts=prompts, + warmup_shapes=warmup_shapes, + max_model_len=2048, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=1, + backend=backend) + + hf_results = generate_hf_output(model=model, + prompts=prompts, + max_new_tokens=max_new_tokens) + + # for prompts longer than the max_prompt_length, the corresponding + # output in hf_results is reset to 'empty' in order to create the + # expected output for vLLM + hf_tokenizer = AutoTokenizer.from_pretrained(model) + for prompt_index, prompt in enumerate(prompts): + hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids + if len(hf_input_tokens[0]) > max_prompt_length: + hf_results[prompt_index] = { + 'text': '', + 'token_ids': (), + 'tokens': (), + 'logprobs': () + } + + compare_results(model=model, + prompts=prompts, + warmup_shapes=warmup_shapes, + tensor_parallel_size=1, + backend=backend, + vllm_results=vllm_results, + hf_results=hf_results) diff --git a/tests/spyre/test_spyre_seed.py b/tests/spyre/test_spyre_seed.py new file mode 100644 index 000000000..ca55d7024 --- /dev/null +++ b/tests/spyre/test_spyre_seed.py @@ -0,0 +1,75 @@ +"""Verification of seeded random sampling to be deterministic + +Run `pytest tests/spyre/test_spyre_seed.py`. +""" + +import math +from typing import Tuple + +import pytest +from spyre_util import generate_spyre_vllm_output + +from vllm import SamplingParams + + +@pytest.mark.parametrize("model", ["/models/llama-194m"]) +@pytest.mark.parametrize("prompt", [ + "Provide a list of instructions for preparing" + " chicken soup for a family of four." +]) +@pytest.mark.parametrize("temperature", [0.1, 1.0]) +@pytest.mark.parametrize("seed", [42]) +@pytest.mark.parametrize("warmup_shape", [(64, 20, 4), (64, 20, 8), + (128, 20, 4), (128, 20, 8)] + ) # (prompt_length/new_tokens/batch_size) +@pytest.mark.parametrize("backend", + ["eager"]) #, "inductor", "sendnn_decoder"]) +def test_seed( + model: str, + prompt: str, + temperature: float, + seed: int, + warmup_shape: Tuple[int, int, int], + backend: str, +) -> None: + ''' + The warmup is based on a single shape. After the warmup, + output is generated for one request with 16 identical prompts + using random sampling (non-zero temperature) in combination + with a seed. The generated output, including text, token ids, + logprobs is verified to be identical for all 16 sequences. 
+ ''' + + max_new_tokens = warmup_shape[1] + + prompts = [prompt] * 16 + + vllm_sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=temperature, + logprobs=0, # return logprobs of generated tokens only + ignore_eos=True, + seed=seed) + + vllm_results = generate_spyre_vllm_output( + model=model, + prompts=prompts, + warmup_shapes=[warmup_shape], + max_model_len=2048, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=1, + backend=backend) + + # compare all generated outputs against the first generated output + for vllm_result in vllm_results: + assert vllm_result['text'] == vllm_results[0]['text'] + + # compare logprobs for all tokens between + # the current and the first sequence + assert len(vllm_result['logprobs']) == len(vllm_results[0]['logprobs']) + for token_id, logprob, token_id_0, logprob_0 in zip( + vllm_result['token_ids'], vllm_result['logprobs'], + vllm_results[0]['token_ids'], vllm_results[0]['logprobs']): + assert token_id == token_id_0 + assert math.isclose(logprob, logprob_0, rel_tol=0.1) diff --git a/tests/spyre/test_spyre_tensor_parallel.py b/tests/spyre/test_spyre_tensor_parallel.py new file mode 100644 index 000000000..b3622f688 --- /dev/null +++ b/tests/spyre/test_spyre_tensor_parallel.py @@ -0,0 +1,77 @@ +"""Verification of vLLM output by comparing with HF + +Run `pytest tests/spyre/test_spyre_tensor_parallel.py`. +""" + +from typing import List, Tuple + +import pytest +from spyre_util import (compare_results, generate_hf_output, + generate_spyre_vllm_output) + +from vllm import SamplingParams + + +@pytest.mark.skip("Skip until failure is resolved.") +@pytest.mark.parametrize("model", ["/models/llama-194m"]) +@pytest.mark.parametrize("prompts", [[ + "Provide a list of instructions for preparing" + " chicken soup for a family of four.", "Hello", + "What is the weather today like?", "Who are you?" +]]) +@pytest.mark.parametrize("warmup_shapes", [[(64, 20, 4)]] + ) #,[(64,20,8)],[(128,20,4)],[(128,20,8)]]) +# (prompt_length/new_tokens/batch_size) +@pytest.mark.parametrize("tp_size", [2]) +@pytest.mark.parametrize("backend", + ["eager"]) #, "inductor", "sendnn_decoder"]) +def test_output( + model: str, + prompts: List[str], + warmup_shapes: List[Tuple[int, int, int]], + tp_size: int, + backend: str, +) -> None: + ''' + The warmup is based on one or multiple shapes. After the warmup, + one request with the provided prompts is input to vLLM which + is executed in tensor-parallel fashion on Spyres. + The same prompts are also input to HF. The generated output + including text, token ids, and logprobs, is verified to be + identical for vLLM and HF. + + If errors occur, these can be analyzed/debugged by setting + 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the + test using 'pytest --capture=no tests/spyre/test_spyre_tensore_parallel.py' + After debugging, DISABLE_ASSERTS should be reset to 'False'. 
+ ''' + + max_new_tokens = max([t[1] for t in warmup_shapes]) + + vllm_sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=0, + logprobs=0, # return logprobs of generated tokens only + ignore_eos=True) + + vllm_results = generate_spyre_vllm_output( + model=model, + prompts=prompts, + warmup_shapes=warmup_shapes, + max_model_len=2048, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=tp_size, + backend=backend) + + hf_results = generate_hf_output(model=model, + prompts=prompts, + max_new_tokens=max_new_tokens) + + compare_results(model=model, + prompts=prompts, + warmup_shapes=warmup_shapes, + tensor_parallel_size=tp_size, + backend=backend, + vllm_results=vllm_results, + hf_results=hf_results) diff --git a/tests/spyre/test_spyre_warmup_shapes.py b/tests/spyre/test_spyre_warmup_shapes.py new file mode 100644 index 000000000..be58ea516 --- /dev/null +++ b/tests/spyre/test_spyre_warmup_shapes.py @@ -0,0 +1,85 @@ +"""Verification of Spyre warmup shapes + +Run `pytest tests/spyre/test_spyre_warmup_shapes.py`. +""" + +from typing import List, Tuple + +import pytest +from spyre_util import (compare_results, generate_hf_output, + generate_spyre_vllm_output) + +from vllm import SamplingParams + + +@pytest.mark.parametrize("model", ["/models/llama-194m"]) +@pytest.mark.parametrize("prompts", [ + 7 * [ + "Hello", + "Below is an instruction that describes a task. Write a response that " + "appropriately completes the request. Be polite in your response to " + "the user. Provide a list of instructions for preparing chicken soup" + " for a family of four. Indicate if the weather forecast looks good " + "for today. Explain in a brief summary comprised of at most 50 words" + " what you are." + ] +]) +@pytest.mark.parametrize("warmup_shapes", [[(64, 20, 8), (128, 20, 4)]] + ) # (prompt_length/new_tokens/batch_size) +@pytest.mark.parametrize("backend", + ["eager"]) #, "inductor", "sendnn_decoder"]) +def test_output( + model: str, + prompts: List[str], + warmup_shapes: List[Tuple[int, int, int]], + backend: str, +) -> None: + ''' + The warmup is based on two shapes, that 'overlap' each + other. After the warmup, one request with the provided + prompts is input to vLLM. There should be at least one + prompt corresponding to each of the two warmup shapes. + It is useful to define enough prompts to fill multiple + batches entirely and partially, in order to test the + handling of overlapping warmup shapes also in relation + with the position of a prompt within a batch (not + likely that this will be an issue, but just to be sure). + The same prompts are also input to HF. The generated + output including text, token ids, and logprobs, is + verified to be identical for vLLM and HF. + + If errors occur, these can be analyzed/debugged by setting + 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the + test using 'pytest --capture=no tests/spyre/test_spyre_warmup_shapes.py' + After debugging, DISABLE_ASSERTS should be reset to 'False'. 
+ ''' + + max_new_tokens = max([t[1] for t in warmup_shapes]) + + vllm_sampling_params = SamplingParams( + max_tokens=max_new_tokens, + temperature=0, + logprobs=0, # return logprobs of generated tokens only + ignore_eos=True) + + vllm_results = generate_spyre_vllm_output( + model=model, + prompts=prompts, + warmup_shapes=warmup_shapes, + max_model_len=2048, + block_size=2048, + sampling_params=vllm_sampling_params, + tensor_parallel_size=1, + backend=backend) + + hf_results = generate_hf_output(model=model, + prompts=prompts, + max_new_tokens=max_new_tokens) + + compare_results(model=model, + prompts=prompts, + warmup_shapes=warmup_shapes, + tensor_parallel_size=1, + backend=backend, + vllm_results=vllm_results, + hf_results=hf_results) diff --git a/vllm/config.py b/vllm/config.py index e69cbd3eb..4669abd78 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,6 +1,7 @@ import copy import enum import json +import operator import warnings from dataclasses import dataclass, field, replace from pathlib import Path @@ -382,6 +383,8 @@ def _verify_quantization(self) -> None: ] tpu_supported_quantization = ["tpu_int8"] neuron_supported_quantization = ["neuron_quant"] + spyre_supported_quantization = ["gptq"] + if self.quantization is not None: self.quantization = self.quantization.lower() @@ -441,6 +444,11 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in Neuron Backend.") + if current_platform.is_spyre( + ) and self.quantization not in spyre_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in Spyre Backend.") def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: @@ -1148,6 +1156,39 @@ def __init__(self, self.num_scheduler_steps = num_scheduler_steps self.multi_step_stream_outputs = multi_step_stream_outputs self.send_delta_data = send_delta_data + self.spyre_scheduling_enabled = current_platform.is_spyre() + if self.spyre_scheduling_enabled: + # load warmup shapes and sort by "speed" + wup_prompt_lens = envs.VLLM_SPYRE_WARMUP_PROMPT_LENS or [] + wup_batch_sizes = envs.VLLM_SPYRE_WARMUP_BATCH_SIZES or [] + if len(wup_prompt_lens) != len(wup_batch_sizes): + raise RuntimeError( + "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and " + "VLLM_SPYRE_WARMUP_BATCH_SIZES must have equal length") + if task == "embedding": + wup_new_tokens = [0] * len(wup_prompt_lens) + else: + wup_new_tokens = envs.VLLM_SPYRE_WARMUP_NEW_TOKENS or [] + if len(wup_new_tokens) != len(wup_prompt_lens): + raise RuntimeError( + "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and " + "VLLM_SPYRE_WARMUP_NEW_TOKENS must have equal length") + + print("[SchedulerConfig] VLLM_SPYRE_WARMUP_PROMPT_LENS =", + wup_prompt_lens) + print("[SchedulerConfig] VLLM_SPYRE_WARMUP_NEW_TOKENS =", + wup_new_tokens) + print("[SchedulerConfig] VLLM_SPYRE_WARMUP_BATCH_SIZES =", + wup_batch_sizes) + + self.spyre_warmup_shapes = tuple( + sorted([{ + 'prompt_length': pl, + 'new_tokens': nt, + 'batch_size': bs + } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens, + wup_batch_sizes)], + key=operator.itemgetter('batch_size', 'prompt_length'))) self.policy = policy self._verify_args() @@ -1195,6 +1236,8 @@ def __init__(self, device: str = "auto") -> None: self.device_type = "cuda" elif current_platform.is_neuron(): self.device_type = "neuron" + elif current_platform.is_spyre(): + self.device_type = "spyre" elif current_platform.is_hpu(): self.device_type = "hpu" elif 
current_platform.is_openvino(): @@ -1212,7 +1255,7 @@ def __init__(self, device: str = "auto") -> None: self.device_type = device # Some device types require processing inputs on CPU - if self.device_type in ["neuron", "openvino"]: + if self.device_type in ["neuron", "spyre", "openvino"]: self.device = torch.device("cpu") elif self.device_type in ["tpu"]: self.device = None diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec2..ef6854e45 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -872,6 +872,9 @@ def _schedule_prefills( ignored_seq_groups: List[SequenceGroup] = [] seq_groups: List[ScheduledSequenceGroup] = [] + applicable_spyre_warmup_shapes = list( + self.scheduler_config.spyre_warmup_shapes) + waiting_queue = self.waiting leftover_waiting_sequences: Deque[SequenceGroup] = deque() @@ -941,6 +944,54 @@ def _schedule_prefills( num_new_seqs=num_new_seqs)): break + # check if current request can be scheduled based on the applicable + # spyre warmup shapes + if self.scheduler_config.spyre_scheduling_enabled: + max_tokens = 0 + if seq_group.sampling_params is not None and\ + seq_group.sampling_params.max_tokens is not None: + max_tokens = seq_group.sampling_params.max_tokens + updated_spyre_warmup_shapes = [ + shape for shape in applicable_spyre_warmup_shapes + if num_new_tokens <= shape['prompt_length'] + and max_tokens <= shape['new_tokens'] + and len(seq_groups) < shape['batch_size'] + ] + if not updated_spyre_warmup_shapes: + if not seq_groups: + # request was tested against all spyre warmup shapes: + # request cannot be processed + if (seq_group.sampling_params is not None + and seq_group.sampling_params.max_tokens + is not None): + logger.warning( + "No applicable warmup shape exists for " + "combination of prompt length (%d tokens) " + "and maximum number of output tokens to be " + "generated (%d tokens)", num_new_tokens, + seq_group.sampling_params.max_tokens) + else: + logger.warning( + "No applicable warmup shape exists for " + "combination of prompt length (%d tokens) " + "and undefined maximum number of output " + "tokens", num_new_tokens) + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + waiting_queue.popleft() + continue + else: + # request was only tested against spyre warmup shapes + # that remain after processing previous requests in + # waiting queue: request will be evaluated again in + # a future scheduling step + leftover_waiting_sequences.appendleft(seq_group) + waiting_queue.popleft() + continue + else: + applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes + # Can schedule this request. if curr_loras is not None and lora_int_id > 0: curr_loras.add(lora_int_id) @@ -970,6 +1021,15 @@ def _schedule_prefills( budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) budget.add_num_seqs(seq_group.request_id, num_new_seqs) + # Check if number of scheduled requests has reached the maximum + # batch size of the applicable warmup shapes + if self.scheduler_config.spyre_scheduling_enabled and len( + seq_groups) >= max([ + shape['batch_size'] + for shape in applicable_spyre_warmup_shapes + ]): + break + # Queue requests that couldn't be scheduled. 
waiting_queue.extendleft(leftover_waiting_sequences) if len(seq_groups) > 0: @@ -1007,8 +1067,11 @@ def _schedule_default(self) -> SchedulerOutputs: running_scheduled = SchedulerRunningOutputs.create_empty() swapped_in = SchedulerSwappedInOutputs.create_empty() - # If any requests are swapped, prioritized swapped requests. - if not self.swapped: + # Schedule new prefills only when no requests have been swapped + # and all previous decodes have completed. + if not self.swapped and ( + not self.scheduler_config.spyre_scheduling_enabled + or not self.running): prefills = self._schedule_prefills(budget, curr_loras, enable_chunking=False) @@ -1198,8 +1261,13 @@ def _can_append_slots(self, seq_group: SequenceGroup, # chunked-prefill are enabled together. assert self.scheduler_config.is_multi_step and enable_chunking - return self.block_manager.can_append_slots( - seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) + if self.scheduler_config.spyre_scheduling_enabled: + # heuristic below doesn't make sense when using very large + # blocks + return True + else: + return self.block_manager.can_append_slots( + seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: # async_output_proc is allowed only when we have a single sequence diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9288cd22c..2626652fb 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -37,6 +37,7 @@ "openvino", "tpu", "xpu", + "spyre", "hpu", ] @@ -394,7 +395,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 64, 128], help='Token block size for contiguous chunks of ' 'tokens. 
This is ignored on neuron devices and ' 'set to max-model-len') diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a5388708..0b6f3a00d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -617,6 +617,14 @@ def _get_executor_cls( elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutorAsync executor_class = NeuronExecutorAsync + if engine_config.device_config.device_type == "spyre": + if distributed_executor_backend == "mp": + from vllm.executor.multiproc_spyre_executor import ( + MultiprocessingSpyreExecutorAsync) + executor_class = MultiprocessingSpyreExecutorAsync + else: + from vllm.executor.spyre_executor import SpyreExecutorAsync + executor_class = SpyreExecutorAsync elif engine_config.device_config.device_type == "tpu": if distributed_executor_backend == "ray": from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 2a5eaf134..9367714eb 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -516,6 +516,15 @@ def _get_executor_cls(cls, elif engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor executor_class = NeuronExecutor + elif engine_config.device_config.device_type == "spyre": + if distributed_executor_backend == "mp": + from vllm.executor.multiproc_spyre_executor import ( + MultiprocessingSpyreExecutor) + executor_class = MultiprocessingSpyreExecutor + else: + from vllm.executor.spyre_executor import SpyreExecutor + executor_class = SpyreExecutor + elif engine_config.device_config.device_type == "tpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) diff --git a/vllm/envs.py b/vllm/envs.py index 853c49bc4..77fd1da1c 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -67,6 +67,9 @@ VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False + VLLM_SPYRE_WARMUP_PROMPT_LENS: Optional[List[int]] = None + VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None + VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False VLLM_ENABLE_V1_MULTIPROCESSING: bool = False @@ -457,6 +460,41 @@ def get_default_config_root(): # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))), + + # Defines the prompt lengths the Spyre accelerator should be prepared + # for, formatted as comma separated list. + "VLLM_SPYRE_WARMUP_PROMPT_LENS": + lambda: [ + int(p) for p in os.getenv(key='VLLM_SPYRE_WARMUP_PROMPT_LENS', + default='64').split(',') + ], + + # Defines the max output tokens the Spyre accelerator should be prepared + # for, formatted as comma separated list. + "VLLM_SPYRE_WARMUP_NEW_TOKENS": + lambda: [ + int(d) for d in os.getenv(key='VLLM_SPYRE_WARMUP_NEW_TOKENS', + default='20').split(',') + ], + + # Defines the batch sizes the Spyre accelerator should be prepared + # for, formatted as comma separated list. 
+ "VLLM_SPYRE_WARMUP_BATCH_SIZES": + lambda: [ + int(b) for b in os.getenv(key='VLLM_SPYRE_WARMUP_BATCH_SIZES', + default='1').split(',') + ], + + # Defines the backend that torch.compile will use when using Spyre + # Available options: + # - "sendnn_decoder": Compile for execution on Spyre hardware for + # decoder models + # - "sendnn": Compile for execution on Spyre hardware for + # encoder models + # - "inductor": Compile for execution on CPU (for debug and testing) + # - "eager": Skip compile entirely (for debug and testing + "VLLM_SPYRE_DYNAMO_BACKEND": + lambda: os.getenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn_decoder"), } # end-env-vars-definition diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 9cba189dd..7786355db 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -12,8 +12,8 @@ class ExecutorBase(ABC): """Base class for all executors. An executor is responsible for executing the model on a specific device - type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor - that can execute the model on multiple devices. + type (e.g., CPU, GPU, Neuron, Spyre, etc.). Or it can be a distributed + executor that can execute the model on multiple devices. """ uses_ray: bool # whether the executor uses Ray for orchestration. diff --git a/vllm/executor/multiproc_spyre_executor.py b/vllm/executor/multiproc_spyre_executor.py new file mode 100644 index 000000000..2b3860431 --- /dev/null +++ b/vllm/executor/multiproc_spyre_executor.py @@ -0,0 +1,267 @@ +import json +import os +import platform +import signal +import threading +import weakref +from functools import partial +from typing import Any, List, Set, Tuple + +import torch + +from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, + ResultHandler, WorkerMonitor) +from vllm.executor.spyre_executor import SpyreExecutor, create_worker +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_open_port, + get_vllm_instance_id) + +logger = init_logger(__name__) + + +class MultiprocessingSpyreExecutor(SpyreExecutor): + """Python multiprocessing-based multi-Spyre executor""" + + uses_ray: bool = False + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. 
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def _init_executor(self) -> None: + self._check_executor_parameters() + + # Create the parallel GPU workers. + world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers + os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() + + # Disable torch async compiling which won't work with daemonic processes + # [tom] it doesn't seme to work setting this from the code, we need to + # set at command line. hopefully will be fixed by upgrading torch. + os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" + + # Configure thread parallelism if OMP_NUM_THREADS isn't set + # + # Helps to avoid CPU contention. The default of spawning a thread per + # core combined with multiprocessing for each GPU can have a negative + # impact on performance. The contention is amplified when running in a + # container where CPU limits can cause throttling. + default_omp_num_threads = 1 + if "OMP_NUM_THREADS" not in os.environ and ( + current_parallelism := + torch.get_num_threads()) > default_omp_num_threads: + logger.warning( + "Reducing Torch parallelism from %d threads to %d to avoid " + "unnecessary CPU contention. Set OMP_NUM_THREADS in the " + "external environment to tune this value as needed.", + current_parallelism, default_omp_num_threads) + os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) + torch.set_num_threads(default_omp_num_threads) + + # Multiprocessing-based executor does not support multi-node setting. + # Since it only works for single node, we can use the loopback address + # 127.0.0.1 for communication. + distributed_init_method = get_distributed_init_method( + "127.0.0.1", get_open_port()) + + self.workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[ProcessWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. 
+ self.non_driver_workers: List[ProcessWorkerWrapper] = [] + + if world_size == 1: + self.worker_monitor = None + else: + result_handler = ResultHandler() + for rank in range(1, world_size): + worker = ProcessWorkerWrapper( + result_handler, + partial( + create_worker, + **self._get_create_worker_kwargs( + rank=rank, + local_rank=rank, + distributed_init_method=distributed_init_method, + ))) + self.workers.append(worker) + if rank % tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + self.worker_monitor = WorkerMonitor(self.workers, result_handler) + result_handler.start() + self.worker_monitor.start() + + # Set up signal handlers to shutdown the executor cleanly + # sometimes gc does not work well + + # Use weakref to avoid holding a reference to self + ref = weakref.ref(self) + + def shutdown(signum, frame): + if executor := ref(): + executor.shutdown() + + if threading.current_thread() is threading.main_thread(): + signal.signal(signal.SIGINT, shutdown) + signal.signal(signal.SIGTERM, shutdown) + + self.driver_worker = self._create_worker( + distributed_init_method=distributed_init_method) + self._run_workers("init_device") + self._run_workers("load_model") + + def _check_executor_parameters(self): + world_size = self.parallel_config.world_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Read number of Spyre cards from senlib config file + if platform.machine() == "s390x": + spyre_device_count = int(os.getenv("AIU_WORLD_SIZE", 1)) + else: + with open("/etc/aiu/senlib_config.json", 'rb') as f: + config = json.load(f) + spyre_device_count = len(config["GENERAL"]["sen_bus_id"]) + + # Use confusing message for more common TP-only case. + assert tensor_parallel_size <= spyre_device_count, ( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({spyre_device_count})") + + assert world_size <= spyre_device_count, ( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({spyre_device_count})") + + def shutdown(self): + if (worker_monitor := getattr(self, "worker_monitor", + None)) is not None: + worker_monitor.close() + + def execute_model( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + + output = self._run_workers("execute_model", + execute_model_req=execute_model_req) + return output[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def pin_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "pin_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + def _run_workers( + self, + method: str, + *args, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + """ + + # Start all remote workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + driver_worker_method = getattr(self.driver_worker, method) + driver_worker_output = driver_worker_method(*args, **kwargs) + + # Get the results of the workers. 
+ return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + if self.worker_monitor is not None and not self.worker_monitor.is_alive( + ): + raise RuntimeError("Worker processes are not running") + + +class MultiprocessingSpyreExecutorAsync(MultiprocessingSpyreExecutor): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + + # this is not really async, rather this a blocking call. + # there may well be perf implications, not sure. + # However, Spyre does not seem to play nice with asyncio. + # I think vLLM is moving away from the AsyncLLMEngine, + # so let's not waste time here for now. + output = self._run_workers("execute_model", + execute_model_req=execute_model_req) + + return output[0] + + async def stop_remote_worker_execution_loop_async(self) -> None: + return diff --git a/vllm/executor/spyre_executor.py b/vllm/executor/spyre_executor.py new file mode 100644 index 000000000..b07b620ae --- /dev/null +++ b/vllm/executor/spyre_executor.py @@ -0,0 +1,180 @@ +import os +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type + +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) +from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase + +logger = init_logger(__name__) + + +def create_worker(worker_module_name: str, worker_class_name: str, + worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], + **kwargs): + wrapper = WorkerWrapperBase( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + ) + wrapper.init_worker(**kwargs) + return wrapper.worker + + +class SpyreExecutor(ExecutorBase): + + uses_ray: bool = False + + def _init_executor(self) -> None: + assert (self.lora_config is + None), "LoRA is not supported for Spyre backend." + assert (not self.speculative_config + ), "Speculative decoding not yet supported for Spyre backend." 
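The _run_workers pattern above (start the remote workers, run the driver inline, then collect the remote results) can be illustrated with a small self-contained toy; ToyWorker and ToyExecutor are invented here for illustration and are not part of the patch.

from concurrent.futures import ThreadPoolExecutor


class ToyWorker:

    def __init__(self, rank: int):
        self.rank = rank

    def execute_model(self, req: str) -> str:
        return f"rank{self.rank}: {req}"


class ToyExecutor:

    def __init__(self, world_size: int):
        self.driver_worker = ToyWorker(0)
        self.workers = [ToyWorker(r) for r in range(1, world_size)]
        self._pool = ThreadPoolExecutor(max_workers=max(1, world_size - 1))

    def _run_workers(self, method: str, *args, **kwargs):
        # Kick off the non-driver workers first, then run the driver inline,
        # mirroring the ordering in MultiprocessingSpyreExecutor._run_workers.
        futures = [self._pool.submit(getattr(w, method), *args, **kwargs)
                   for w in self.workers]
        driver_output = getattr(self.driver_worker, method)(*args, **kwargs)
        return [driver_output] + [f.result() for f in futures]


# execute_model() above returns output[0], i.e. only the driver's result.
print(ToyExecutor(world_size=3)._run_workers("execute_model", "request-0")[0])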
+ + #assert self.parallel_config.world_size == 1, ( + # "SpyreExecutor only supports single Spyre card.") + + if os.getenv(key='SPYRE_PYTEST_DEBUG', default='0') == '1': + import debugpy + host_addr = os.getenv(key='SPYRE_PYTEST_DBG_ADDR', + default='0.0.0.0') + debugpy.listen((host_addr, 5678)) + print(f"[debugpy] {host_addr}: wait for client...\n") + debugpy.wait_for_client() + + self.driver_worker = self._create_worker() + self.driver_worker.init_device() + self.driver_worker.load_model() + + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """Return worker init args for a given rank.""" + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + is_driver_worker=(not self.parallel_config) + or (rank % self.parallel_config.tensor_parallel_size == 0), + ) + + def _get_worker_module_and_class( + self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: + worker_class_fn = None + worker_module_name = "vllm.worker.spyre_worker" + worker_class_name = "SpyreWorker" + return (worker_module_name, worker_class_name, worker_class_fn) + + def _get_create_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict: + + worker_kwargs = self._get_worker_kwargs(local_rank, rank, + distributed_init_method) + + (worker_module_name, worker_class_name, + worker_class_fn) = self._get_worker_module_and_class() + worker_kwargs.update( + worker_module_name=worker_module_name, + worker_class_name=worker_class_name, + worker_class_fn=worker_class_fn, + ) + return worker_kwargs + + def _create_worker(self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None): + return create_worker(**self._get_create_worker_kwargs( + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method)) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_worker.determine_num_available_blocks() + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache by invoking the underlying worker. 
+ """ + self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) + + def execute_model( + self, + execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + + assert execute_model_req.num_lookahead_slots == 0, ( + "lookahead not supported for Spyre backend.") + + output = self.driver_worker.execute_model(execute_model_req) + + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.driver_worker.remove_lora(lora_id) + + def pin_lora(self, lora_id: int) -> bool: + return self.driver_worker.pin_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.driver_worker.list_loras() + + def add_prompt_adapter(self, prompt_adapter_request) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Spyre backend.") + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Spyre backend.") + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError( + "Soft prompt is currently not supported by the Spyre backend.") + + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError( + "Soft prompt is currently not supported by the Spyre backend.") + + def check_health(self) -> None: + # SpyreExecutor will always be healthy as long as + # it's running. + return + + +class SpyreExecutorAsync(SpyreExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: + output = await make_async( + self.driver_worker.execute_model + )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, ) + return output + + async def check_health_async(self) -> None: + # SpyreExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index c9366ca97..8006e0a91 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -61,10 +61,12 @@ def _check_marlin_supported( has_zp, device_capability) if quant_type not in supported_types: - return (False, f"Marlin does not support weight_bits = {quant_type}. " - f"Only types = {supported_types} " - f"are supported (for group_size = {group_size}, " - f"device_capability = {device_capability}, zp = {has_zp}).") + return ( + False, + f"Marlin does not support weight_bits = {quant_type.size_bits}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): return (False, f"Marlin does not support group_size = {group_size}. 
" f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " diff --git a/vllm/model_executor/model_loader/spyre.py b/vllm/model_executor/model_loader/spyre.py new file mode 100644 index 000000000..d2c4634e4 --- /dev/null +++ b/vllm/model_executor/model_loader/spyre.py @@ -0,0 +1,213 @@ +"""Utilities for selecting and loading Spyre models.""" +import sys +from typing import List, Optional + +import torch +import torch._inductor.config +import torch.distributed as dist +import torch.nn as nn +from fms.models import get_model +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.config import ModelConfig, ParallelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import SequenceGroupMetadata + +try: + from torch_sendnn import torch_sendnn # noqa: F401 +except ImportError: + print("WARNING: Disabled: torch_sendnn") + pass +try: + import backends.dynamo_tracer # noqa: F401 +except ImportError: + print("WARNING: Disabled: dynamo_tracer") + pass + +BACKEND_LIST = ['sendnn_decoder', 'inductor'] + +logger = init_logger(__name__) + + +class SpyreCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + ) -> None: + super().__init__() + self.config = config + self.logits_processor = LogitsProcessor(config.vocab_size, + logits_as_input=True) + self.sampler = Sampler() + self.past_key_value_states = None + self.dtype = torch.float16 if envs.VLLM_SPYRE_DYNAMO_BACKEND == \ + 'sendnn_decoder' else torch.float32 + # number of added padding sequences to fill + # batch to warmed up batch size + self.num_padded_sequences = 0 + + # Lazy initialized + self.model: nn.Module + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + masks: torch.Tensor, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> torch.Tensor: + + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: + self.past_key_value_states = None + + extra_kwargs = {} + if envs.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder": + # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash + # cpu impl when padding too much + extra_kwargs["attn_algorithm"] = "math" + + output = self.model( + input_ids, + position_ids=positions, + mask=masks, + past_key_value_states=self.past_key_value_states, + use_cache=True, + only_last_token=True, + **extra_kwargs, + ) + + logits, past_key_value_states = output + self.past_key_value_states = past_key_value_states + + # mark dynamic + if self.past_key_value_states is not None: + for layer in self.past_key_value_states: + for tensor in layer: + torch._dynamo.mark_dynamic(tensor, 2) + + # removing batch padding sequences to compute logits + batch_size = input_ids.shape[0] + + logits = logits[:batch_size - self.num_padded_sequences] + + return logits + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(None, hidden_states, sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, model_config: ModelConfig, max_prompt_length: int, + max_decode_length: int, + distributed_strategy: Optional[str], **kwargs): + + if self.dtype is not 
model_config.dtype: + logger.info( + "Ignoring user-provided dtype=%s and using dtype=%s instead.", + model_config.dtype, self.dtype) + + if model_config.quantization == "gptq": + + # note, we have to find a better way to package this + # shouldn't it be part of FMS? + sys.path.append("/home/senuser/aiu-fms") + + if envs.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder": + from aiu_as_addon import aiu_adapter, aiu_linear # noqa: F401 + linear_type = "gptq_aiu" + print("Loaded `aiu_as_addon` functionalities") + else: + from cpu_addon import cpu_linear # noqa: F401 + linear_type = "gptq_cpu" + print("Loaded `cpu_addon` functionalities") + + quant_cfg = model_config._parse_quant_hf_config() + + linear_config = { + "linear_type": linear_type, + "group_size": quant_cfg['group_size'], + "desc_act": quant_cfg['desc_act'], + } + data_type = None + model_source = "llama_gptq_hf_unfused_aiu" + else: + linear_config = {"linear_type": "torch_linear"} + data_type = self.dtype + model_source = "hf" + + # we can use fused weights unless running on Spyre + fused_weights = envs.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder" + + self.model = get_model(architecture="hf_configured", + variant=model_config.model, + model_path=model_config.model, + source=model_source, + data_type=data_type, + distributed_strategy=distributed_strategy, + group=dist.group.WORLD, + fused_weights=fused_weights, + linear_config=linear_config) + + compile_mode = "default" + + self.model.eval() + torch.set_grad_enabled(False) + + _target_cache_size = max(int(max_decode_length * 2), + int(max_prompt_length * 2.5)) + if hasattr(torch._dynamo.config, "accumulated_cache_size_limit") and \ + _target_cache_size > torch._dynamo.config.\ + accumulated_cache_size_limit: + _prev = torch._dynamo.config.accumulated_cache_size_limit + torch._dynamo.config.accumulated_cache_size_limit = \ + _target_cache_size + print("NOTICE: Adjusting " + "torch._dynamo.config.accumulated_cache_size_limit" + f" from {_prev} to " + f"{torch._dynamo.config.accumulated_cache_size_limit} " + f"to accommodate prompt size of {max_prompt_length} " + f"and decode tokens of {max_decode_length}") + + if _target_cache_size > torch._dynamo.config.cache_size_limit: + _prev = torch._dynamo.config.cache_size_limit + torch._dynamo.config.cache_size_limit = _target_cache_size + print( + "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from" + f" {_prev} to {torch._dynamo.config.cache_size_limit} to " + f"accommodate prompt size of {max_prompt_length} and " + f"decode tokens of {max_decode_length}") + + if envs.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST: + self.model = torch.compile(self.model, + mode=compile_mode, + backend=envs.VLLM_SPYRE_DYNAMO_BACKEND) + + +def get_spyre_model(model_config: ModelConfig, parallel_config: ParallelConfig, + max_prompt_length, max_decode_length) -> nn.Module: + + # Create a model instance. + model = SpyreCausalLM(model_config.hf_config) + + # Load the weights from the cached or downloaded files. 
+ model.load_weights( + model_config, + max_prompt_length=max_prompt_length, + max_decode_length=max_decode_length, + distributed_strategy="tp" if parallel_config.world_size > 1 else None) + + return model diff --git a/vllm/model_executor/model_loader/spyre_setup.py b/vllm/model_executor/model_loader/spyre_setup.py new file mode 100644 index 000000000..4ebb2c536 --- /dev/null +++ b/vllm/model_executor/model_loader/spyre_setup.py @@ -0,0 +1,145 @@ +import json +import os +import sys + +import torch + +# ============================================================== +# Common utilities +# ============================================================== +#------------- +# Discover the world size and my rank (envars set by torchrun) +# https://pytorch.org/docs/stable/elastic/run.html#environment-variables +#------------- +local_rank = int(os.getenv("LOCAL_RANK", 0)) +rank = int(os.getenv("RANK", 0)) +world_rank = rank +world_size = int(os.getenv("WORLD_SIZE", 1)) + +def dprint(text): + print(f"[{rank:2d}/{world_size:2d}]: {text}") + +# ============================================================== +# Common setup +# ============================================================== +def spyre_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False): + # ------------- + # Envar setup for backend + # ------------- + # Environment variable created by the runtime to identify the specific Spyre card that is assigned to this rank + spyre_config_file_envar = "AIU_CONFIG_FILE_" + str(rank) + + # Default to senulator backend unless user specified otherwise + os.environ.setdefault("FLEX_COMPUTE", "SENULATOR") + os.environ.setdefault("FLEX_DEVICE", "MOCK") + + # Each rank needs a unique space to write its binaries + # For both 'export' and '__pycache' + # https://docs.python.org/3/library/sys.html#sys.pycache_prefix + os.environ.setdefault("DEEPRT_EXPORT_DIR", "export") + os.environ.setdefault("DTCOMPILER_EXPORT_DIR", "export") + if world_size > 1: + os.environ["DEEPRT_EXPORT_DIR"] += f"/{rank}" + os.environ["DTCOMPILER_EXPORT_DIR"] += f"/{rank}" + sys.pycache_prefix=os.getenv("DEEPRT_EXPORT_DIR")+"/py-" + str(rank) + os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "1") + + # Inform Flex of the size of this job + os.environ.setdefault("FLEX_RDMA_WORLD_SIZE", str(world_size)) + os.environ.setdefault("FLEX_RDMA_WORLD_RANK", str(rank)) + os.environ.setdefault("FLEX_RDMA_LOCAL_SIZE", str(world_size)) + os.environ.setdefault("FLEX_RDMA_LOCAL_RANK", str(rank)) + for peer_rank in range(world_size): + pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank) + flex_env_str="FLEX_RDMA_PCI_BUS_ADDR_"+str(peer_rank) + if os.getenv(pcie_env_str) is None: + raise RuntimeError(f"Error: The environment variable {pcie_env_str} is not defined") + if os.getenv(flex_env_str) is None: + raise RuntimeError(f"Error: The environment variable {flex_env_str} is not defined") + if os.getenv("DUMP_MEMMAP") is not None: + if os.getenv("SDSC_REF_DIR") is None: + os.environ["SDSC_REF_DIR"] = os.environ["DEEPRT_EXPORT_DIR"] + else: + os.environ["SDSC_REF_DIR"] += f"/{rank}" + assert ( + os.getenv("DUMP_MEMMAP_DIR") is not None + ), "DUMP_MEMMAP_DIR not set while DUMP_MEMMAP set" + os.environ["DUMP_MEMMAP_DIR"] += f"/{rank}" + os.makedirs( + os.environ["DUMP_MEMMAP_DIR"], exist_ok=True + ) # directory needs to exist + + for peer_rank in range(world_size): + pcie_env_str = "AIU_WORLD_RANK_" + str(peer_rank) + flex_env_str = "FLEX_RDMA_PCI_BUS_ADDR_" + str(peer_rank) + if os.getenv("FLEX_COMPUTE") == "SENULATOR": + if os.getenv(pcie_env_str) 
is not None: + os.environ[flex_env_str] = os.getenv(pcie_env_str) + else: + os.environ[pcie_env_str] = f"0000:{rank:02x}:01.0" + os.environ[flex_env_str] = f"0000:{rank:02x}:01.0" + else: + if os.getenv(flex_env_str) is None: + if os.getenv("PCIDEVICE_IBM_COM_SENTIENT_PF") is not None: + os.environ[pcie_env_str] = os.getenv( + "PCIDEVICE_IBM_COM_SENTIENT_PF" + ) + + if os.getenv(pcie_env_str) is not None: + os.environ[flex_env_str] = os.getenv(pcie_env_str) + else: + raise RuntimeError( + f"[{rank}/{world_size}]: ERROR: {flex_env_str} and {pcie_env_str} were not set for peer {peer_rank}." + ) + if rank == 0 and verbose: + dprint(f"PCI Addr Rank {peer_rank} {pcie_env_str}={os.environ[pcie_env_str]}") + dprint(f"PCI Addr Rank {peer_rank} {flex_env_str}={os.environ[flex_env_str]}") + + if rank == 0 and verbose: + dprint(f"FLEX_COMPUTE=" + os.getenv("FLEX_COMPUTE")) + dprint(f"FLEX_DEVICE=" + os.getenv("FLEX_DEVICE")) + dprint(f"DEEPRT_EXPORT_DIR=" + os.getenv("DEEPRT_EXPORT_DIR")) + dprint(f"DTCOMPILER_EXPORT_DIR=" + os.getenv("DTCOMPILER_EXPORT_DIR")) + if os.getenv(spyre_config_file_envar) is not None: + dprint(f"{spyre_config_file_envar}=" + os.environ[spyre_config_file_envar]) + if os.getenv("SENLIB_DEVEL_CONFIG_FILE") is not None: + dprint(f"SENLIB_DEVEL_CONFIG_FILE=" + os.environ["SENLIB_DEVEL_CONFIG_FILE"]) + if os.getenv(flex_env_str) is not None: + dprint(f"{flex_env_str}=" + os.environ[flex_env_str]) + dprint(f"FLEX_RDMA_LOCAL_RANK=" + os.getenv("FLEX_RDMA_LOCAL_RANK")) + dprint(f"FLEX_RDMA_LOCAL_SIZE=" + os.getenv("FLEX_RDMA_LOCAL_SIZE")) + dprint(f"FLEX_RDMA_WORLD_RANK=" + os.getenv("FLEX_RDMA_WORLD_RANK")) + dprint(f"FLEX_RDMA_WORLD_SIZE=" + os.getenv("FLEX_RDMA_WORLD_SIZE")) + + if os.getenv("FLEX_COMPUTE") == "SENTIENT": + pcie_env_str = "AIU_WORLD_RANK_" + str(rank) + if os.getenv(pcie_env_str) is not None: + device_id = os.getenv(pcie_env_str) + else: + with open(os.getenv(spyre_config_file_envar)) as fd: + data = json.load(fd) + device_id = data["GENERAL"]["sen_bus_id"] + dprint(f"Spyre: Enabled ({device_id})") + else: + dprint(f"Spyre: Disabled (Senulator)") + + +# ============================================================== +# Distributed setup +# ============================================================== +def spyre_dist_setup(rank, world_size, local_rank=-0, local_size=-1, verbose=False): + if local_rank < 0: + local_rank = rank + if local_size < 0: + local_size = world_size + + if os.getenv("TORCHELASTIC_RUN_ID") is None: + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + elif rank == 0 or verbose: + dprint(f"Detected running via torchrun") + + if rank == 0 or verbose: + dprint(f"Parallel Backend: {torch.distributed.get_backend()}") + + spyre_setup(rank, world_size) \ No newline at end of file diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py index 1f68fc2e2..04b670c77 100644 --- a/vllm/platforms/__init__.py +++ b/vllm/platforms/__init__.py @@ -83,6 +83,13 @@ except Exception: pass +is_spyre = False +try: + from importlib.metadata import version + is_spyre = "spyre" in version("vllm") +except Exception: + pass + if is_tpu: # people might install pytorch built with cuda but run on tpu # so we need to check tpu first @@ -109,6 +116,9 @@ elif is_openvino: from .openvino import OpenVinoPlatform current_platform = OpenVinoPlatform() +elif is_spyre: + from .spyre import SpyrePlatform + current_platform = SpyrePlatform() else: current_platform = UnspecifiedPlatform() diff --git a/vllm/platforms/interface.py 
b/vllm/platforms/interface.py index f4849fa2c..1bf5434c5 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -34,6 +34,7 @@ class PlatformEnum(enum.Enum): CPU = enum.auto() NEURON = enum.auto() OPENVINO = enum.auto() + SPYRE = enum.auto() UNSPECIFIED = enum.auto() @@ -81,6 +82,9 @@ def is_neuron(self) -> bool: def is_openvino(self) -> bool: return self._enum == PlatformEnum.OPENVINO + def is_spyre(self) -> bool: + return self._enum == PlatformEnum.SPYRE + def is_cuda_alike(self) -> bool: """Stateless version of :func:`torch.cuda.is_available`.""" return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM) diff --git a/vllm/platforms/spyre.py b/vllm/platforms/spyre.py new file mode 100644 index 000000000..3634c1f4f --- /dev/null +++ b/vllm/platforms/spyre.py @@ -0,0 +1,9 @@ +from .interface import Platform, PlatformEnum + + +class SpyrePlatform(Platform): + _enum = PlatformEnum.SPYRE + + @classmethod + def get_device_name(cls, device_id: int = 0) -> str: + return "spyre" diff --git a/vllm/utils.py b/vllm/utils.py index 2bbdc8d1e..fe9bc9d77 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -726,6 +726,9 @@ def is_pin_memory_available() -> bool: elif current_platform.is_neuron(): print_warning_once("Pin memory is not supported on Neuron.") return False + elif current_platform.is_spyre(): + print_warning_once("Pin memory is not supported on Spyre device.") + return False elif current_platform.is_hpu(): print_warning_once("Pin memory is not supported on HPU.") return False diff --git a/vllm/worker/spyre_embedding_model_runner.py b/vllm/worker/spyre_embedding_model_runner.py new file mode 100644 index 000000000..447742d02 --- /dev/null +++ b/vllm/worker/spyre_embedding_model_runner.py @@ -0,0 +1,171 @@ +import time +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from transformers import AutoModel + +import vllm.envs as envs +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.pooling_params import PoolingParams +from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata + +from .spyre_model_runner import SpyreModelRunner + +logger = init_logger(__name__) + +BACKEND_LIST = ['sendnn', 'inductor'] + + +class SpyreEmbeddingModelRunner(SpyreModelRunner): + + # Map of request_id -> generator used for seeded random sampling + generators: Dict[str, torch.Generator] = {} + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + ): + super().__init__(model_config=model_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config) + + pooler_config = model_config.pooler_config + self.pooler = Pooler.from_config_with_defaults( + pooler_config, + pooling_type=PoolingType.CLS, + normalize=True, + softmax=False) + + def load_model(self, prompt_lens: Iterable[int], + num_decode_tokens: Iterable[int], + batch_sizes: Iterable[int]) -> None: + self.model = AutoModel.from_pretrained(self.model_config.model) + self.model.eval() + torch.set_grad_enabled(False) + if envs.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST: + self.model = torch.compile(self.model, + mode="default", + dynamic=False, + backend=envs.VLLM_SPYRE_DYNAMO_BACKEND) + + @property + def vocab_size(self) -> int: + return 
self.model.config.vocab_size + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, PoolingMetadata]: + # NOTE: We assume that all sequences in the group are all prompts + (input_tokens, input_positions, input_masks, + seq_lens) = self._prepare_prompt(seq_group_metadata_list) + + pooling_metadata = self._prepare_pooling( + seq_group_metadata_list=seq_group_metadata_list, + prompt_lens=seq_lens) + return (input_tokens, input_positions, input_masks, pooling_metadata) + + def _prepare_pooling( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + ) -> PoolingMetadata: + """Prepare PoolingMetadata for the sequence group metadata list.""" + seq_groups: List[Tuple[List[int], PoolingParams]] = [] + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + pooling_params = seq_group_metadata.pooling_params + seq_groups.append((seq_ids, pooling_params)) + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + pooling_metadata = PoolingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + ) + + return pooling_metadata + + def pad_input_ids( + self, + input_ids_list: List[torch.Tensor], + min_pad_length: int = 0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + padded_input_ids_list, mask_list, position_ids_list = self.\ + _prepare_pad_input_ids(input_ids_list, min_pad_length) + + input_ids = torch.stack(padded_input_ids_list) + mask = torch.stack(mask_list) + position_ids = torch.stack(position_ids_list) + + return input_ids, position_ids, mask + + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None, + ) -> Optional[PoolerOutput]: + (input_tokens, input_positions, input_masks, + pooling_metadata) = self.prepare_input_tensors( + seq_group_metadata_list, finished_requests_ids) + t0 = time.time() + + outputs = self.model( + input_ids=input_tokens, + # Let the Embedding layer use it's default + # because the rules can be a bit different + # e.g. 
For Roberta models the inputs start + # at padding_inx +1 + #position_ids=input_positions, + attention_mask=input_masks) + hidden_states = outputs["last_hidden_state"] + + unpadded = [] + max_len = hidden_states.shape[1] + + for i, seq_len in enumerate(pooling_metadata.prompt_lens): + unpadded.append(hidden_states[i, max_len - seq_len:, :]) + + hidden_states = torch.concat(unpadded) + + pooler_output = self.pooler(hidden_states=hidden_states, + pooling_metadata=pooling_metadata) + + t1 = time.time() - t0 + print("[spyre_model_runner:execute_model] t_token: %.2fms" % + (t1 * 1000)) + + return pooler_output + + def _raw_model_forward( + self, + input_ids: torch.Tensor, + mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_value_states: Optional[List[Tuple[torch.Tensor, + torch.Tensor]]] = None, + use_cache: bool = False, + only_last_token: bool = False, + attn_algorithm: Optional[str] = None + ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, + torch.Tensor]]]]: + + hidden_states, _ = self.model( + input_ids=input_ids, + attention_mask=mask, + #position_ids=position_ids + ) + return hidden_states, None diff --git a/vllm/worker/spyre_model_runner.py b/vllm/worker/spyre_model_runner.py new file mode 100644 index 000000000..e1925811c --- /dev/null +++ b/vllm/worker/spyre_model_runner.py @@ -0,0 +1,349 @@ +import time +from typing import Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.model_loader.spyre import get_spyre_model +from vllm.sequence import SequenceGroupMetadata +from vllm.utils import is_pin_memory_available + +logger = init_logger(__name__) + + +class SpyreModelRunner: + + # Map of request_id -> generator used for seeded random sampling + generators: Dict[str, torch.Generator] = {} + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + self.pad_token_id = 0 + if model_config is not None: + if model_config.hf_config is not None: + self.pad_token_id = getattr(model_config.hf_config, + "pad_token_id", None) or 0 + if model_config.get_sliding_window(): + logger.warning("Sliding window is not supported on Spyre. " + "The model will run without sliding window.") + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + # position_ids of all the sequences in current batch + self._position_ids: torch.Tensor = None + # attention masks of all the sequences in current batch + self._mask: torch.Tensor = None + # Lazy initialization: after load_model. 
+ self.model: nn.Module + + def load_model(self, prompt_lens: Iterable[int], + num_decode_tokens: Iterable[int], + batch_sizes: Iterable[int]) -> None: + max_pad_length = max(prompt_lens) + max_decode_length = max(num_decode_tokens) + self.model = get_spyre_model(self.model_config, + parallel_config=self.parallel_config, + max_prompt_length=max_pad_length, + max_decode_length=max_decode_length) + + @property + def vocab_size(self) -> int: + return self.model.model.config.src_vocab_size + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + assert len(seq_group_metadata_list) > 0 + input_token_list: List[torch.Tensor] = [] + + # find warmup shape to be used for padding and batching + applicable_spyre_warmup_shapes = [ + shape for shape in self.scheduler_config.spyre_warmup_shapes + if len(seq_group_metadata_list) <= shape['batch_size'] + ] + for seq_group_metadata in seq_group_metadata_list: + seq_data = seq_group_metadata.seq_data[list( + seq_group_metadata.seq_data.keys())[0]] + # retrieve initial (unpadded) tokens + prompt_tokens = seq_data.get_token_ids() + new_tokens = seq_group_metadata.sampling_params.max_tokens\ + if seq_group_metadata.sampling_params is not None else 0 + + updated_spyre_warmup_shapes = [ + shape for shape in applicable_spyre_warmup_shapes + if len(prompt_tokens) <= shape['prompt_length'] + and new_tokens <= shape['new_tokens'] + ] + applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes + + assert applicable_spyre_warmup_shapes + + # If multiple warmup shapes apply, the first one is selected. + # For improving performance, the warmup shapes in scheduler_config + # are ordered by "processing speed". + min_pad_length_batch = applicable_spyre_warmup_shapes[0][ + 'prompt_length'] + padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size'] + + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + # retrieve initial (unpadded) tokens + prompt_tokens = seq_data.get_token_ids() + + input_token_list.append( + torch.tensor(prompt_tokens, + dtype=torch.long, + device=torch.device("cpu"))) + + # set number of added padding sequences used for computing logits + self.model.num_padded_sequences = padded_batch_size - len( + input_token_list) + + # padding to compiled batch size + while len(input_token_list) < padded_batch_size: + input_token_list.append( + torch.zeros(min_pad_length_batch, + dtype=torch.long, + device=torch.device("cpu"))) + + # get position ids and attention mask + input_tokens, self._position_ids, self._mask = self.pad_input_ids( + input_token_list, min_pad_length=min_pad_length_batch) + + seq_lens = [t.shape[0] for t in input_token_list] + + return input_tokens, self._position_ids, self._mask, seq_lens + + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) 
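The prompt path above pads every request up to the selected warmup shape; the snippet below is a standalone illustration (toy token values, not taken from the patch) of the left-padding, position-id and causal-mask construction that _prepare_pad_input_ids and pad_input_ids implement further down in this file.

import torch

pad_token_id = 0
seqs = [torch.tensor([11, 12, 13]), torch.tensor([21, 22, 23, 24, 25])]
min_pad_length = 8  # e.g. the warmup prompt_length selected for this batch
max_len = max([min_pad_length] + [s.size(0) for s in seqs])

ids, masks, pos = [], [], []
for s in seqs:
    n_pad = max_len - s.size(0)
    pads = torch.full((n_pad, ), pad_token_id, dtype=torch.long)
    ids.append(torch.cat((pads, s)))  # pads go on the left
    masks.append(torch.cat((torch.zeros(n_pad), torch.ones(s.size(0)))))
    pos.append(torch.cat((pads, torch.arange(s.size(0)))))

mask = torch.stack(masks).bool()
# Expand the 2D padding mask into a 3D causal mask: entry [b, i, j] is 0.0
# where token i may attend to token j (same padding status and j <= i),
# and -inf everywhere else.
mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
mask = torch.where(mask.logical_not(), -torch.inf, 0.0)

print(torch.stack(ids))  # (2, 8) left-padded token ids
print(mask.shape)        # (2, 8, 8) additive attention mask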
+ + # padding to compiled batch size + actual_batch_size = len(seq_group_metadata_list) + padded_batch_size = self._position_ids.shape[0] + while actual_batch_size < padded_batch_size: + input_tokens.append([0]) + actual_batch_size += 1 + + # update position ids and attention mask + self._update_position_ids() + self._update_mask() + + input_tokens = torch.tensor(input_tokens, + dtype=torch.long, + device=self.device) + + return input_tokens, self._position_ids, self._mask + + def _update_position_ids(self) -> None: + """Updating the position ids of all sequences + in a batch. Will be called in decoding phase""" + + self._position_ids = self._position_ids[:, -1] + 1 + self._position_ids = self._position_ids.unsqueeze(-1) + + def _update_mask(self) -> None: + """Updating/extending the attention masks of all + sequences in a batch. Will be called in decoding phase""" + + assert self._mask is not None + masks = self._mask + + masks_new = [] + for mask in masks: + # get the last row of the 3d mask + mask_new = mask[-1:, :] + + # extend the mask one slot + mask_new = torch.cat( + ( + mask_new, + torch.zeros( + 1, 1, dtype=mask_new.dtype, device=mask_new.device), + ), + dim=1, + ) + masks_new.append(mask_new) + + self._mask = torch.stack(masks_new, dim=0) + + def prepare_input_tensors( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_masks, + _) = self._prepare_prompt(seq_group_metadata_list) + seq_lens = [ + input_tokens.shape[1] for i in range(input_tokens.shape[0]) + ] + else: + (input_tokens, input_positions, + input_masks) = self._prepare_decode(seq_group_metadata_list) + seq_lens = [] + + # Clean up generators from completed requests + if finished_requests_ids: + for request_id in finished_requests_ids: + self.generators.pop(request_id, None) + + sampling_metadata = SamplingMetadata.prepare( + seq_group_metadata_list, + seq_lens, + # query_lens is not needed if chunked prefill is not + # supported. Since Spyre worker doesn't support chunked prefill + # just use seq_lens instead. + seq_lens, + self.device, + self.pin_memory, + self.generators) + return (input_tokens, input_positions, input_masks, sampling_metadata) + + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + finished_requests_ids: Optional[List[str]] = None, + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, input_masks, + sampling_metadata) = self.prepare_input_tensors( + seq_group_metadata_list, finished_requests_ids) + t0 = time.time() + hidden_states = self.model( + input_ids=input_tokens, + positions=input_positions, + masks=input_masks, + seq_group_metadata_list=seq_group_metadata_list, + ) + + # Compute the logits. + logits = self.model.compute_logits(hidden_states, sampling_metadata) + + # Sample the next token. 
+        output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )
+        t1 = time.time() - t0
+        print("[spyre_model_runner:execute_model] t_token: %.2fms" %
+              (t1 * 1000))
+
+        return output
+
+    def _prepare_pad_input_ids(
+        self,
+        input_ids_list: List[torch.Tensor],
+        min_pad_length: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """left side padding implemented as
+        in fms.utils.generation.pad_input_id"""
+        max_len = max([min_pad_length] +
+                      [seq.size(0) for seq in input_ids_list])
+        padded_input_ids_list = []
+        mask_list = []
+        position_ids_list = []
+        for input_ids_i in input_ids_list:
+            seq_len = input_ids_i.size(0)
+            if max_len > seq_len:
+                print(f"[SpyreModelRunner] INFO: Padding request of length "
+                      f"{seq_len} tokens to {max_len} tokens.")
+            pads = torch.ones(max_len - seq_len,
+                              dtype=torch.long,
+                              device=input_ids_i.device) * self.pad_token_id
+            non_pads = torch.ones(seq_len,
+                                  dtype=torch.long,
+                                  device=input_ids_i.device)
+
+            pos_ids_pads = pads
+            pos_ids_seq = torch.arange(0,
+                                       seq_len,
+                                       dtype=torch.long,
+                                       device=input_ids_i.device)
+
+            # Setting this to 0, however if 0 is the eos, we will end up
+            # truncating the output if using truncate_after_eos once this
+            # workflow works for nested tensor, this can probably be removed
+            padded_input_ids_list.append(torch.cat((pads, input_ids_i)))
+            mask_list.append(torch.cat((torch.zeros_like(pads), non_pads)))
+            position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq)))
+
+        return padded_input_ids_list, mask_list, position_ids_list
+
+    def pad_input_ids(
+        self,
+        input_ids_list: List[torch.Tensor],
+        min_pad_length: int = 0,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+        padded_input_ids_list, mask_list, position_ids_list = self.\
+            _prepare_pad_input_ids(input_ids_list, min_pad_length)
+
+        input_ids = torch.stack(padded_input_ids_list)
+        mask = torch.stack(mask_list).bool()
+        # this is a causal mask for generation
+        mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
+        mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+        mask = mask.to(self.model.dtype)
+        position_ids = torch.stack(position_ids_list)
+
+        return input_ids, position_ids, mask
+
+    def _raw_model_forward(
+        self,
+        input_ids: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        past_key_value_states: Optional[List[Tuple[torch.Tensor,
+                                                   torch.Tensor]]] = None,
+        use_cache: bool = False,
+        only_last_token: bool = False,
+        attn_algorithm: Optional[str] = None
+    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
+                                                 torch.Tensor]]]]:
+
+        return self.model.model(input_ids,
+                                mask=mask,
+                                position_ids=position_ids,
+                                past_key_value_states=past_key_value_states,
+                                use_cache=use_cache,
+                                only_last_token=only_last_token,
+                                attn_algorithm=attn_algorithm)
diff --git a/vllm/worker/spyre_worker.py b/vllm/worker/spyre_worker.py
new file mode 100644
index 000000000..fa95ec444
--- /dev/null
+++ b/vllm/worker/spyre_worker.py
@@ -0,0 +1,337 @@
+"""A Spyre worker class."""
+import json
+import os
+import platform
+import time
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig)
+from vllm.distributed import (ensure_model_parallel_initialized,
+                              init_distributed_environment)
+from vllm.model_executor import set_random_seed
+from vllm.model_executor.layers.sampler import SamplerOutput
+from
vllm.model_executor.model_loader import spyre_setup +from vllm.sequence import ExecuteModelRequest +from vllm.worker.spyre_embedding_model_runner import SpyreEmbeddingModelRunner +from vllm.worker.spyre_model_runner import SpyreModelRunner +from vllm.worker.worker_base import LoraNotSupportedWorkerBase + + +class SpyreWorker(LoraNotSupportedWorkerBase): + """A worker class that executes the model on a group of Spyre cores. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + cache_config: CacheConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.is_driver_worker = is_driver_worker + if parallel_config and is_driver_worker: + assert rank % parallel_config.tensor_parallel_size == 0, \ + "Driver worker should be rank 0 of tensor parallel group." + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + if self.model_config.task == "embedding": + self.model_runner: SpyreModelRunner = SpyreEmbeddingModelRunner( + model_config, parallel_config, scheduler_config, device_config) + else: + self.model_runner = SpyreModelRunner(model_config, parallel_config, + scheduler_config, + device_config) + self._env_initialized = False + + def init_distributed_environment(self) -> None: + """Initialize the distributed environment.""" + + init_distributed_environment( + world_size=self.parallel_config.world_size, + rank=self.rank, + distributed_init_method="env://", + backend="gloo", + ) + + torch._C._distributed_c10d._register_process_group( + "default", dist.group.WORLD) + + if envs.VLLM_SPYRE_DYNAMO_BACKEND in ["sendnn", "sendnn_decoder"]: + spyre_setup.spyre_dist_setup( + rank=self.rank, + world_size=self.parallel_config.world_size, + verbose=True) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1).cpu()) + + ensure_model_parallel_initialized( + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + ) + + def init_device(self) -> None: + + if platform.machine() == "s390x": + from torch.serialization import LoadEndianness + torch.serialization.set_default_load_endianness( + LoadEndianness.LITTLE) + + if not self._env_initialized: + if self.parallel_config.world_size > 1: + self.init_distributed_environment() + elif envs.VLLM_SPYRE_DYNAMO_BACKEND in [ + "sendnn", "sendnn_decoder" + ]: + spyre_setup.spyre_setup(rank=0, world_size=1, verbose=True) + + self._env_initialized = True + + # Set random seed. 
+ set_random_seed(self.model_config.seed) + + def load_model(self): + assert self._env_initialized + + with open(os.path.join(self.model_config.model, 'config.json'), + 'rb') as f: + config = json.load(f) + + bos_token_id, eos_token_id = int(config["bos_token_id"]), int( + config["eos_token_id"]) + + print("[SpyreWorker] load model...") + # TODO: check additionally if the Spyre card has enough memory + # for all requested model warmups + # printing env variables for debugging purposes + load_model_start_t = time.time() + + wup_prompt_lens, wup_new_tokens, wup_batch_sizes = zip( + *[(s["prompt_length"], s["new_tokens"], s["batch_size"]) + for s in self.scheduler_config.spyre_warmup_shapes]) + + self.model_runner.load_model(prompt_lens=wup_prompt_lens, + num_decode_tokens=wup_new_tokens, + batch_sizes=wup_batch_sizes) + + load_model_end_t = time.time() + load_model_total_t = load_model_end_t - load_model_start_t + print(f"\tload model took {load_model_total_t}s") + + print(f"[SpyreWorker] Start warming up " + f"{len(wup_new_tokens)} " + f"different prompt/decode/batchsize-shape combinations.") + all_warmup_start_t = time.time() + for i, (prompt_len, num_decode_tokens, batch_size) in enumerate([ + (s["prompt_length"], s["new_tokens"], s["batch_size"]) + for s in self.scheduler_config.spyre_warmup_shapes + ]): + if self.model_config.task != "embedding": + # TODO: remove if spyre supports + # lower number of output tokens + assert num_decode_tokens >= 3, ( + "VLLM_SPYRE_WARMUP_NEW_TOKENS must be " + "at least 2 (spyre requirement).") + # warmup individual combination + print(f"[SpyreWorker] Warmup {i+1}/" + f"{len(wup_new_tokens)} " + f"prompt/decode/batchsize-shape combinations...") + print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, " + f"decoding {num_decode_tokens} tokens with batch " + f"size {batch_size}") + self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens, + (bos_token_id, eos_token_id), + batch_size) + all_warmup_end_t = time.time() + all_warmup_total_t = all_warmup_end_t - all_warmup_start_t + print(f"[SpyreWorker] All warmups for " + f"{len(wup_new_tokens)} different " + f"prompt/decode/batchsize-shape combinations finished. " + f"Total warmup time {all_warmup_total_t}s.") + + def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens, + special_token_ids, batch_size): + # warmup the model + warmup_start_t = time.time() + # NOTE(ngl): empty tensor causes spyre to hang, so using + # randint without 0 and the eos and bos token + + # Create a list of valid values between 1 (inclusive) and vocab + # size (exclusive) by excluding the eos and bos token ids + # (in special_token_ids) + vocab_size = self.model_runner.vocab_size + valid_token_ids = [ + i for i in range(1, vocab_size) if i not in set(special_token_ids) + ] + # Convert to tensor for sampling + valid_token_ids_tensor = torch.tensor(valid_token_ids, + dtype=torch.long, + device=torch.device("cpu")) + # Sample from the valid token ids + warmup_tokens_tensor = valid_token_ids_tensor[torch.randint( + 0, len(valid_token_ids_tensor), (batch_size, prompt_len))] + + extra_kwargs = {} + if envs.VLLM_SPYRE_DYNAMO_BACKEND not in ["sendnn", "sendnn_decoder"]: + # Bug in 2.3.1 fixed in 2.4.1 for SDPA flash cpu + # impl when padding too much + extra_kwargs["attn_algorithm"] = "math" + + print(f"[SpyreWorker] warmup for prompt length " + f"{prompt_len} and max output tokens {num_decode_tokens}.") + + # 1. trace + print("[SpyreWorker] warmup 1/2...") + # TODO: torch_sendnn.CleanGraph() should be necessary? 
+ # warmup 1st forward pass + self._warmup_model_forward_pass(warmup_tokens_tensor, + valid_token_ids_tensor, prompt_len, + num_decode_tokens, batch_size, + extra_kwargs) + + # 2. compile + print("[SpyreWorker] warmup 2/2...") + if envs.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder": + from torch_sendnn import torch_sendnn + ul_start_time = time.time() + torch_sendnn.update_lazyhandle() + ul_stop_time = time.time() + ul_total_t = ul_stop_time - ul_start_time + print(f"update_lazyhandle() done (duration: {ul_total_t}s)") + + # warmup 2nd forward pass + self._warmup_model_forward_pass(warmup_tokens_tensor, + valid_token_ids_tensor, prompt_len, + num_decode_tokens, batch_size, + extra_kwargs) + + warmup_end_t = time.time() + warmup_total_t = warmup_end_t - warmup_start_t + print("[SpyreWorker] ... warmup finished.") + print(f"\twarmup took {warmup_total_t}s (for prompt length" + f"{prompt_len} and max output tokens {num_decode_tokens})") + + def _warmup_model_forward_pass(self, warmup_tokens_tensor, + valid_token_ids_tensor, prompt_len, + num_decode_tokens, batch_size, + extra_kwargs): + # padding warmup tokens to obtain the + # corresponding position ids and mask + warmup_tokens_pad, self.model_runner._position_ids, \ + self.model_runner._mask = self.model_runner.pad_input_ids( + warmup_tokens_tensor, min_pad_length=prompt_len) + + logits, past_key_value_states = self.model_runner._raw_model_forward( + warmup_tokens_pad, + position_ids=self.model_runner._position_ids, + mask=self.model_runner._mask, + past_key_value_states=None, + use_cache=True, + only_last_token=True, + **extra_kwargs) + # decoding + for i in range(num_decode_tokens - 1): + # sampling next input token from vocab without bos and eos tokens + decode_tokens = valid_token_ids_tensor[torch.randint( + 0, len(valid_token_ids_tensor), (batch_size, 1))] + + # update mask and position_ids + self.model_runner._update_mask() + self.model_runner._update_position_ids() + + if past_key_value_states is not None: + for layer in past_key_value_states: + for tensor in layer: + torch._dynamo.mark_dynamic(tensor, 2) + + logits, past_key_value_states = self.model_runner.\ + _raw_model_forward( + decode_tokens, + position_ids=self.model_runner._position_ids, + mask=self.model_runner._mask, + past_key_value_states=past_key_value_states, + use_cache=True, + only_last_token=True, + **extra_kwargs) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + Swapping is not yet supported, so always return num_cpu_blocks=0. + + We configure num_gpu_blocks to be equal to max_num_seqs. + """ + # Set the number of GPU blocks to be the same as the maximum number of + # sequences that can be processed in a single batch. This is equivalent + # to schedule without PagedAttention. + num_gpu_blocks = self.scheduler_config.max_num_seqs + + # Swap not yet supported with Spyre backend. + num_cpu_blocks = 0 + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. + """ + + # Different values are not tested. + assert num_cpu_blocks == 0 + assert num_gpu_blocks == self.scheduler_config.max_num_seqs + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # TODO: why not inference mode? 
+ #@torch.inference_mode() + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + + torch.set_grad_enabled(False) + if execute_model_req is None: + return None + finished_requests_ids = execute_model_req.finished_requests_ids + seq_group_metadata_list = execute_model_req.seq_group_metadata_list + num_seq_groups = len(seq_group_metadata_list) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return [] + + output = self.model_runner.execute_model(seq_group_metadata_list, + finished_requests_ids) + + # Spyre worker only supports single-step output. Wrap the output in a + # list to conform to interface. + return [output] + + def get_cache_block_size_bytes(self) -> int: + """Determine the size in bytes of a cache block. + + This is required for speculative decoding; it is not yet implemented. + """ + raise NotImplementedError
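Taken together, a typical offline run on Spyre sets the warmup shapes before constructing the engine and then uses the normal vLLM API. The sketch below is modeled on the examples added by this patch but uses placeholder model and prompt values rather than copying them.

import os

# One warmup shape: prompts padded to 64 tokens, up to 20 generated tokens,
# batch size 4. These must be set before the engine is created.
os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = "64"
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = "20"
os.environ["VLLM_SPYRE_WARMUP_BATCH_SIZES"] = "4"
# "sendnn_decoder" targets Spyre hardware; "eager" or "inductor" run on CPU
# for debugging, as described in the envs.py hunk above.
os.environ["VLLM_SPYRE_DYNAMO_BACKEND"] = "eager"

from vllm import LLM, SamplingParams

llm = LLM(model="/models/llama-7b-chat", max_model_len=2048)  # placeholder path
params = SamplingParams(max_tokens=20, temperature=0.0)
for out in llm.generate(["Hello, my name is"], params):
    print(out.outputs[0].text)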