diff --git a/.yapfignore b/.yapfignore
index 2d6dcf838..7a23d8913 100644
--- a/.yapfignore
+++ b/.yapfignore
@@ -1 +1,3 @@
collect_env.py
+
+vllm/model_executor/model_loader/spyre_setup.py
diff --git a/Dockerfile.spyre b/Dockerfile.spyre
new file mode 100644
index 000000000..c68dc5b4d
--- /dev/null
+++ b/Dockerfile.spyre
@@ -0,0 +1,28 @@
+# Global Args #################################################################
+ARG BASE_UBI_IMAGE_TAG=9.4
+ARG PYTHON_VERSION=3.12
+
+# Base Layer ##################################################################
+FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base
+ARG PYTHON_VERSION
+ENV PYTHON_VERSION=${PYTHON_VERSION}
+WORKDIR /workspace/vllm
+
+# Install some basic utilities ##################################################################
+RUN microdnf update -y && microdnf install -y \
+    python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel git vim gcc g++ \
+ && microdnf clean all
+
+# Install build dependencies ##################################################################
+RUN --mount=type=bind,source=requirements-build.txt,target=requirements-build.txt \
+    python${PYTHON_VERSION} -m pip install --upgrade pip && \
+ pip install -r requirements-build.txt
+
+# Build vLLM ##################################################################
+COPY . .
+
+ENV VLLM_TARGET_DEVICE=spyre
+RUN --mount=type=bind,source=.git,target=.git \
+ pip install --no-build-isolation -v -e .
+
+CMD ["/bin/bash"]
diff --git a/README.md b/README.md
index 0ef073210..cf24582d0 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,76 @@ Easy, fast, and cheap LLM serving for everyone
---
+## What is the purpose of this fork?
+
+This is a private fork of vLLM that we are using to develop support for IBM Research's AI accelerator (Spyre).
+The main branch of this repo should not diverge significantly from upstream beyond the changes required to enable Spyre.
+We will rebase against upstream frequently and plan to contribute these changes to the upstream repository in the future.
+
+---
+## Supported IBM Granite models on Spyre
+
+| Model | 3b | 7b | 8b | 13b | 20b |
+|:------------:|:------------:|:------------:|:------------:|:------------:|:------------:|
+| **llama** | NO<sup>1</sup><br>[weights](https://huggingface.co/ibm-granite/granite-3b-code-base) | YES<sup>2</sup><br>[weights](https://huggingface.co/ibm-granite/granite-7b-base) | YES<sup>3</sup><br>[weights](https://huggingface.co/ibm-granite/granite-8b-code-base) | X | X |
+| **gpt big code** | YES<sup>4</sup><br>- | X | X | YES<sup>5</sup><br>- | YES<sup>6</sup><br>[weights](https://huggingface.co/ibm-granite/granite-20b-code-base) |
+
+
+
+- YES = working on Spyre
+- NO = not yet working on Spyre
+- X = no weights available
+
+
+#### Path to models
+
+1 : ```/models/granite-3b-code-base```
+2 : ```/models/granite-7b-base```
+3 : ```/models/granite-8b-code-base```
+4 : ```/models/granite-3b-base```
+5 : ```/models/granite-13b-base```
+6 : ```/models/granite-20b-code-base```
+(PVC in dev pod)
+## Running ***offline*** demo on Spyre
+
+```bash
+python3 examples/offline_inference_spyre.py
+```
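+
+For reference, a minimal sketch of what that script does is shown below (the model path and warmup settings here are illustrative; see [examples/offline_inference_spyre.py](./examples/offline_inference_spyre.py) for the full version):
+```python
+import os
+
+from vllm import LLM, SamplingParams
+
+# Warmup shapes must be exported before the LLM engine is created.
+os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = "64"
+os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = "20"
+os.environ["VLLM_SPYRE_WARMUP_BATCH_SIZES"] = "1"
+
+llm = LLM(model="/models/llama-7b-chat",
+          max_model_len=2048,
+          block_size=2048,
+          device="spyre")
+
+sampling_params = SamplingParams(max_tokens=20, temperature=0.0)
+outputs = llm.generate(["Hello"], sampling_params)
+print(outputs[0].outputs[0].text)
+```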
+## Running ***online*** demo on Spyre
+
+### Batch size 1
+Log in to the same pod from two terminal windows: launch the server in one and submit requests from the other.
+
+**1st terminal window**: Start the server with one of the model paths listed [above](#path-to-models) (slow, takes a long time due to Spyre compilation):
+```bash
+python3 -m vllm.entrypoints.openai.api_server --model <path to model> --max-model-len=2048 --block-size=2048
+```
+Optionally, set the desired prompt padding (*default 64*) to any multiple of 64 and the maximum number of generated output tokens (*default 20*) via **VLLM_SPYRE_WARMUP_PROMPT_LENS** and **VLLM_SPYRE_WARMUP_NEW_TOKENS**:
+```bash
+export VLLM_SPYRE_WARMUP_PROMPT_LENS=64
+export VLLM_SPYRE_WARMUP_NEW_TOKENS=20
+```
+before starting the server.
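+
+Multiple prompt-length/new-token shapes can also be warmed up at once by passing comma-separated lists (as used in [examples/online_inference_spyre_multiple.ipynb](./examples/online_inference_spyre_multiple.ipynb)), for example:
+```bash
+export VLLM_SPYRE_WARMUP_PROMPT_LENS=64,128
+export VLLM_SPYRE_WARMUP_NEW_TOKENS=20,10
+export VLLM_SPYRE_WARMUP_BATCH_SIZES=1,1
+```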
+**2nd terminal window**: When the above warmup has completed, submit sample prompts for LLM completion (fast):
+```bash
+python3 examples/spyre_warmup_online_client.py
+```
+### Batch size 4/8
+
+Before launching the server, specify the batch size to be used (below set to 4) via the environment variable **VLLM_SPYRE_WARMUP_BATCH_SIZES** (*default 1*):
+```bash
+export VLLM_SPYRE_WARMUP_BATCH_SIZES=4
+```
+
+Finally, continue as described [above](#batch-size-1) by launching the server in the 1st terminal window.
+Before submitting prompts from the 2nd terminal window, make sure the batch size set in the [client script](./examples/spyre_warmup_online_client.py) (the `batch_size` variable) matches the value set via **VLLM_SPYRE_WARMUP_BATCH_SIZES**, as sketched below.
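+
+A sketch of the relevant loop in the client script, which submits the prompts in chunks of the warmup batch size (set `batch_size` to the same value exported above):
+```python
+batch_size = 4  # must match VLLM_SPYRE_WARMUP_BATCH_SIZES
+for i in range(0, len(prompts), batch_size):
+    batch = prompts[i:i + batch_size]
+    completion = client.completions.create(model=model,
+                                           prompt=batch,
+                                           temperature=0.0,
+                                           max_tokens=20)
+    print(completion.choices)
+```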
+### Example notebooks
+
+- [./examples/online_inference_spyre.ipynb](./examples/online_inference_spyre.ipynb)
+- [./examples/offline_inference_spyre.ipynb](./examples/offline_inference_spyre.ipynb)
+
+
+---
*Latest News* 🔥
- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing).
- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there!
diff --git a/examples/offline_inference_multi_spyre.py b/examples/offline_inference_multi_spyre.py
new file mode 100644
index 000000000..7bf422d8c
--- /dev/null
+++ b/examples/offline_inference_multi_spyre.py
@@ -0,0 +1,60 @@
+import gc
+import os
+import time
+
+from vllm import LLM, SamplingParams
+
+max_tokens = 3
+
+os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64'
+os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
+os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '1'
+
+# Environment settings for multi-Spyre (tensor parallel) execution
+os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
+os.environ["DISTRIBUTED_STRATEGY_IGNORE_MODULES"] = "WordEmbedding"
+os.environ["MASTER_ADDR"] = "localhost"
+os.environ["MASTER_PORT"] = "12355"
+
+# Sample prompts.
+template = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request. Be polite in your response to the "
+ "user.\n\n### Instruction:\n{}\n\n### Response:")
+prompt1 = template.format(
+ "Provide a list of instructions for preparing chicken soup for a family "
+ "of four.")
+prompts = [
+ prompt1,
+]
+
+# Create a sampling params object.
+sampling_params = SamplingParams(max_tokens=max_tokens,
+ temperature=0.0,
+ ignore_eos=True)
+# Create an LLM.
+llm = LLM(
+ model="/models/llama-194m",
+ tokenizer="/models/llama-194m",
+ max_model_len=2048,
+ block_size=2048,
+ device="spyre",
+ tensor_parallel_size=2,
+)
+
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+print("=============== GENERATE")
+t0 = time.time()
+outputs = llm.generate(prompts, sampling_params)
+print("Time elaspsed for %d tokens is %.2f sec" %
+ (len(outputs[0].outputs[0].token_ids), time.time() - t0))
+for output in outputs:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+print(output.outputs[0])
+
+# needed to prevent ugly stackdump caused by sigterm
+del llm
+gc.collect()
diff --git a/examples/offline_inference_spyre.ipynb b/examples/offline_inference_spyre.ipynb
new file mode 100644
index 000000000..792c73177
--- /dev/null
+++ b/examples/offline_inference_spyre.ipynb
@@ -0,0 +1,313 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "bb1996e6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING 08-15 14:54:53 _custom_ops.py:14] Failed to import from vllm._C with ModuleNotFoundError(\"No module named 'vllm._C'\")\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/vllm/lib64/python3.9/site-packages/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:\n",
+ "No module named 'vllm.commit_id'\n",
+ " from vllm.version import __version__ as VLLM_VERSION\n"
+ ]
+ }
+ ],
+ "source": [
+ "import time\n",
+ "%load_ext wurlitzer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "45172614",
+ "metadata": {},
+ "source": [
+ "Offline inference demo\n",
+ "----------------------------\n",
+ "This is just a brief demo to show that vLLM with Spyre can be used in the offline mode. \n",
+ "\n",
+ "vLLM will determine the Spyre config automatically and warmup the Spyre stack. \n",
+ "The startup of vLLM (including warmup of Spyre), is expected to take 15 min for prompt length of 64 and maximum number of decode tokens 5 (it will take ~20min for 64/20)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bf37c07b",
+ "metadata": {},
+ "source": [
+ "### 1. create vLLM instance \n",
+ "(for offline usage, including warmup)\n",
+ "\n",
+ "The maximum prompt length and maximum number of decode tokens can be specified using the environment variables `VLLM_SPYRE_WARMUP_PROMPT_LENS`, and `VLLM_SPYRE_WARMUP_NEW_TOKENS`. \n",
+ "Otherwise the default max prompt length of 64 and maximum of 20 decode tokens will be used. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "ecf0992b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = '64'\n",
+ "os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = '5'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "88b984d7",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "INFO 08-15 14:54:54 llm_engine.py:176] Initializing an LLM engine (v0.5.3.post1) with config: model='/tmp/7B-F', speculative_config=None, tokenizer='/tmp/7B-F', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cpu, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/tmp/7B-F, use_v2_block_manager=False, enable_prefix_caching=False)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "You are using the default legacy behaviour of the . This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING: Disabled: dynamo_tracer\n",
+ "WARNING 08-15 14:54:55 utils.py:581] Pin memory is not supported on Spyre device.\n",
+ "[SpyreWorker] environment configured\n",
+ "[SpyreWorker] load model...\n",
+ ">> DEBUG SETUP\n",
+ "0 / 1 : Python Version : 3.9.18\n",
+ "0 / 1 : PyTorch Version: 2.2.2+cpu\n",
+ "0 / 1 : PCI Addr Rank 0 AIU_WORLD_RANK_0=0\n",
+ "0 / 1 : PCI Addr Rank 0 FLEX_RDMA_PCI_BUS_ADDR_0=0000:1e:00.0\n",
+ "0 / 1 : FLEX_COMPUTE=SENTIENT\n",
+ "0 / 1 : FLEX_DEVICE=VFIO\n",
+ "0 / 1 : DEEPRT_EXPORT_DIR=export/0\n",
+ "0 / 1 : DTCOMPILER_EXPORT_DIR=export/0\n",
+ "0 / 1 : AIU_CONFIG_FILE_0=/etc/aiu/senlib_config.json\n",
+ "0 / 1 : SENLIB_DEVEL_CONFIG_FILE=/etc/aiu/senlib_config.json\n",
+ "0 / 1 : FLEX_RDMA_PCI_BUS_ADDR_0=0000:1e:00.0\n",
+ "0 / 1 : FLEX_RDMA_LOCAL_RANK=0\n",
+ "0 / 1 : FLEX_RDMA_LOCAL_SIZE=1\n",
+ "0 / 1 : FLEX_RDMA_WORLD_RANK=0\n",
+ "0 / 1 : FLEX_RDMA_WORLD_SIZE=1\n",
+ "0 / 1 : Spyre: Enabled (0) (offset=0)\n",
+ "0 / 1 : Dynamo Backend : sendnn_decoder\n",
+ "0 / 1 : CPU Cores : 56 x 2 HW threads\n",
+ "------------------------------------------------------------\n",
+ "NOTICE: Adjusting torch._dynamo.config.accumulated_cache_size_limit from 64 to 160 to accommodate prompt size of 64 and decode tokens of 5\n",
+ "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from 8 to 160 to accommodate prompt size of 64 and decode tokens of 5\n",
+ "\tload model took 62.92104411125183s\n",
+ "[SpyreWorker] Start warming up 1 different prompt/decode-shape combinations.\n",
+ "[SpyreWorker] Warmup 1/1 prompt/decode-shape combinations...\n",
+ "[SpyreWorker] warmup for prompt length 64 and max output tokens 5.\n",
+ "[SpyreWorker] warmup 1/2...\n",
+ "[SpyreWorker] warmup 2/2...\n",
+ "update_lazyhandle() done (duration: 134.3403525352478s)\n",
+ "[SpyreWorker] ... warmup finished.\n",
+ "\twarmup took 893.4236354827881s (for prompt length 64 and max output tokens 5)\n",
+ "[SpyreWorker] All warmups for 1 different prompt/decode-shape combinations finished. Total warmup time 893.4242045879364s.\n"
+ ]
+ }
+ ],
+ "source": [
+ "from vllm import LLM, SamplingParams\n",
+ "\n",
+ "# Create an LLM.\n",
+ "llm = LLM(\n",
+ " model=\"/models/llama-7b-chat\",\n",
+ " tokenizer=\"/models/llama-7b-chat\",\n",
+ " max_model_len=2048,\n",
+ " block_size=2048,\n",
+ " device=\"spyre\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "818488f4",
+ "metadata": {},
+ "source": [
+ "### 2. Create the prompt and `SamplingParams`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "6c32e3e2",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "['Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\\n\\n### Instruction:\\nProvide a list of instructions for preparing chicken soup for a family of four.\\n\\n### Response:']\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Sample prompts.\n",
+ "template = (\n",
+ " \"Below is an instruction that describes a task. Write a response that \"\n",
+ " \"appropriately completes the request. Be polite in your response to the \"\n",
+ " \"user.\\n\\n### Instruction:\\n{}\\n\\n### Response:\"\n",
+ ")\n",
+ "prompt1 = template.format(\n",
+ " \"Provide a list of instructions for preparing chicken soup for a family \"\n",
+ " \"of four.\"\n",
+ ")\n",
+ "prompts = [\n",
+ " prompt1,\n",
+ "]\n",
+ "print(prompts)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "4cc1277e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create a sampling params object.\n",
+ "max_tokens = 5\n",
+ "sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "cffb16c5",
+ "metadata": {},
+ "source": [
+ "### 3. Generate the response"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "522c0610",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "=============== GENERATE\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processed prompts: 0%| | 0/1 [00:00, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[spyre_model_runner:execute_model] t_token: 195.92ms\n",
+ "[spyre_model_runner:execute_model] t_token: 158.88ms\n",
+ "[spyre_model_runner:execute_model] t_token: 158.49ms\n",
+ "[spyre_model_runner:execute_model] t_token: 158.88ms\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Processed prompts: 100%|██████████| 1/1 [00:00<00:00, 1.20it/s, est. speed input: 76.65 toks/s, output: 5.99 toks/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[spyre_model_runner:execute_model] t_token: 158.67ms\n",
+ "Time elaspsed for 5 tokens is 0.84 sec\n",
+ "Prompt: 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\\n\\n### Instruction:\\nProvide a list of instructions for preparing chicken soup for a family of four.\\n\\n### Response:', Generated text: '\\nOf course! Here'\n",
+ "CompletionOutput(index=0, text='\\nOf course! Here', token_ids=(13, 2776, 3236, 29991, 2266), cumulative_logprob=-0.9147708129385137, logprobs=None, finish_reason=length, stop_reason=None)\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"=============== GENERATE\")\n",
+ "t0 = time.time()\n",
+ "outputs = llm.generate(prompts, sampling_params)\n",
+ "print(\"Time elaspsed for %d tokens is %.2f sec\" % \n",
+ " (len(outputs[0].outputs[0].token_ids), time.time()-t0))\n",
+ "for output in outputs:\n",
+ " prompt = output.prompt\n",
+ " generated_text = output.outputs[0].text\n",
+ " print(f\"Prompt: {prompt!r}, Generated text: {generated_text!r}\")\n",
+ "print(output.outputs[0])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "221b2d2b",
+ "metadata": {},
+ "source": [
+ "Expectation (for the 2nd+ tokens): \n",
+ "- ~158ms per token if the model was warmed up with max 5 output tokens\n",
+ "- ~162ms per token if the model was warmed up with max 20 output tokens"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bd23c547",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/offline_inference_spyre.py b/examples/offline_inference_spyre.py
new file mode 100644
index 000000000..ea95a5c91
--- /dev/null
+++ b/examples/offline_inference_spyre.py
@@ -0,0 +1,46 @@
+import os
+import time
+
+from vllm import LLM, SamplingParams
+
+max_tokens = 3
+
+os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64'
+os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens)
+os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '1'
+
+# Sample prompts.
+template = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request. Be polite in your response to the "
+ "user.\n\n### Instruction:\n{}\n\n### Response:")
+prompt1 = template.format(
+ "Provide a list of instructions for preparing chicken soup for a family "
+ "of four.")
+prompts = [
+ prompt1,
+]
+
+# Create a sampling params object.
+sampling_params = SamplingParams(max_tokens=max_tokens,
+ temperature=0.0,
+ ignore_eos=True)
+# Create an LLM.
+llm = LLM(model="/models/llama-7b-chat",
+ tokenizer="/models/llama-7b-chat",
+ max_model_len=2048,
+ block_size=2048,
+ device="spyre")
+
+# Generate texts from the prompts. The output is a list of RequestOutput objects
+# that contain the prompt, generated text, and other information.
+print("=============== GENERATE")
+t0 = time.time()
+outputs = llm.generate(prompts, sampling_params)
+print("Time elaspsed for %d tokens is %.2f sec" %
+ (len(outputs[0].outputs[0].token_ids), time.time() - t0))
+for output in outputs:
+ prompt = output.prompt
+ generated_text = output.outputs[0].text
+ print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+print(output.outputs[0])
diff --git a/examples/online_inference_spyre.ipynb b/examples/online_inference_spyre.ipynb
new file mode 100644
index 000000000..9cffdc8f1
--- /dev/null
+++ b/examples/online_inference_spyre.ipynb
@@ -0,0 +1,250 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "786d5912",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "import time\n",
+ "%load_ext wurlitzer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7f873ca1",
+ "metadata": {},
+ "source": [
+ "Online inference demo\n",
+ "----------------------------\n",
+ "This is just a brief demo to show that vLLM with Spyre can be used in the online mode. \n",
+ "\n",
+ "Hence, a vLLM server must be started before (outside of this notebook):\n",
+ "```bash\n",
+ "python3 -m vllm.entrypoints.openai.api_server --model /models/llama-7b-chat --max-model-len=2048 --block-size=2048\n",
+ "```\n",
+ "\n",
+ "and waited until vLLM is ready, which is after the following log messages were printed (otherwise, there will be `ConnectError`s in the code below):\n",
+ "```log\n",
+ "INFO: Started server process [1840]\n",
+ "INFO: Waiting for application startup.\n",
+ "INFO: Application startup complete.\n",
+ "INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)\n",
+ "```\n",
+ "\n",
+ "(The startup of vLLM, including warmup of Spyre, is expected to take 15 min.)\n",
+ "\n",
+ "Here, the default max prompt length of 64 and maximum of 20 decode tokens is used. Otherwise change this behavior with the environment variables `VLLM_SPYRE_WARMUP_PROMPT_LENS`, and `VLLM_SPYRE_WARMUP_NEW_TOKENS`. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7d4924c",
+ "metadata": {},
+ "source": [
+ "### 1. Create the prompts "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "bb328e33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "template = (\n",
+ " \"Below is an instruction that describes a task. Write a response that \"\n",
+ " \"appropriately completes the request. Be polite in your response to the \"\n",
+ " \"user.\\n\\n### Instruction:\\n{}\\n\\n### Response:\"\n",
+ ")\n",
+ "prompt1 = template.format(\n",
+ " \"Provide a list of instructions for preparing chicken soup for a family \"\n",
+ " \"of four.\"\n",
+ ")\n",
+ "\n",
+ "prompt2 = template.format(\n",
+ " \"Please compare New York City and Zurich and provide a list of attractions \"\n",
+ " \"for each city.\"\n",
+ ")\n",
+ "\n",
+ "prompt3 = template.format(\n",
+ " \"Provide detailed instructions for preparing asparagus soup for a family \"\n",
+ " \"of four.\"\n",
+ ")\n",
+ "\n",
+ "prompts = [prompt1, prompt2, prompt3]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "2d0cee55",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\\n\\n### Instruction:\\nPlease compare New York City and Zurich and provide a list of attractions for each city.\\n\\n### Response:'"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# e.g. prompt 2\n",
+ "prompt2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "37c92859",
+ "metadata": {},
+ "source": [
+ "### 2. Initialize client and connect to vLLM\n",
+ "\n",
+ "(Adapt the `openai_api_base` URL to point to the (forwarded/tunneled) vLLM instance. E.g. forward it to localhost with `oc port-forward $DEV_POD 8000`.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f7dee2e2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Modify OpenAI's API key and API base to use vLLM's API server.\n",
+ "openai_api_key = \"EMPTY\"\n",
+ "openai_api_base = \"http://localhost:8000/v1\"\n",
+ "\n",
+ "client = OpenAI(\n",
+ " # defaults to os.environ.get(\"OPENAI_API_KEY\")\n",
+ " api_key=openai_api_key,\n",
+ " base_url=openai_api_base,\n",
+ ")\n",
+ "\n",
+ "models = client.models.list()\n",
+ "model = models.data[0].id\n",
+ "\n",
+ "# Completion API\n",
+ "stream = False\n",
+ "max_tokens = 20 # default\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "026153e2",
+ "metadata": {},
+ "source": [
+ "### 3. Submit requests and await responses"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "bceb2c21",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prompt: Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\n",
+ "\n",
+ "### Instruction:\n",
+ "Provide a list of instructions for preparing chicken soup for a family of four.\n",
+ "\n",
+ "### Response:\n",
+ "Results:\n",
+ "[CompletionChoice(finish_reason='length', index=0, logprobs=None, text='\\nOf course! Here are the steps to prepare chicken soup for a family of four:\\n', stop_reason=None)]\n",
+ "Duration: 3.3749101161956787s\n",
+ "---------------------------\n",
+ "\n",
+ "Prompt: Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\n",
+ "\n",
+ "### Instruction:\n",
+ "Please compare New York City and Zurich and provide a list of attractions for each city.\n",
+ "\n",
+ "### Response:\n",
+ "Results:\n",
+ "[CompletionChoice(finish_reason='length', index=0, logprobs=None, text='\\nThank you for reaching out! Both New York City and Zurich are incredible destinations with', stop_reason=None)]\n",
+ "Duration: 3.367875576019287s\n",
+ "---------------------------\n",
+ "\n",
+ "Prompt: Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\n",
+ "\n",
+ "### Instruction:\n",
+ "Provide detailed instructions for preparing asparagus soup for a family of four.\n",
+ "\n",
+ "### Response:\n",
+ "Results:\n",
+ "[CompletionChoice(finish_reason='length', index=0, logprobs=None, text='\\nOf course! Preparing asparagus soup for a family of four is a straightforward', stop_reason=None)]\n",
+ "Duration: 3.3706459999084473s\n",
+ "---------------------------\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "for prompt in prompts:\n",
+ " print(f\"Prompt: {prompt}\")\n",
+ " start_t = time.time()\n",
+ "\n",
+ " completion = client.completions.create(\n",
+ " model=model,\n",
+ " prompt=prompt,\n",
+ " echo=False,\n",
+ " n=1,\n",
+ " stream=stream,\n",
+ " temperature=0.0,\n",
+ " max_tokens=max_tokens)\n",
+ "\n",
+ " end_t = time.time()\n",
+ " print(\"Results:\")\n",
+ " if stream:\n",
+ " for c in completion:\n",
+ " print(c)\n",
+ " else:\n",
+ " print(completion.choices)\n",
+ "\n",
+ " total_t = end_t - start_t\n",
+ " print(f\"Duration: {total_t}s\")\n",
+ " print(\"---------------------------\\n\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea1e686d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/online_inference_spyre_multiple.ipynb b/examples/online_inference_spyre_multiple.ipynb
new file mode 100644
index 000000000..1c32a40a4
--- /dev/null
+++ b/examples/online_inference_spyre_multiple.ipynb
@@ -0,0 +1,256 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "cae04a1f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from openai import OpenAI\n",
+ "import time\n",
+ "%load_ext wurlitzer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "7f873ca1",
+ "metadata": {},
+ "source": [
+ "Online inference demo with multiple prompt lengths\n",
+ "------------------------------------------------------\n",
+ "This is just a brief demo to show that vLLM with Spyre can be used in the online mode with MULTIPLE prompt-length/max-decode shapes! \n",
+ "\n",
+ "Hence, a vLLM server must be started before (outside of this notebook). We assume the following:\n",
+ "```bash\n",
+ "export VLLM_SPYRE_WARMUP_PROMPT_LENS=64,128\n",
+ "export VLLM_SPYRE_WARMUP_NEW_TOKENS=20,10\n",
+ "export VLLM_SPYRE_WARMUP_BATCH_SIZES=1,1\n",
+ "python3 -m vllm.entrypoints.openai.api_server --model /models/llama-7b-chat --max-model-len=2048 --block-size=2048\n",
+ "```\n",
+ "\n",
+ "Then, we need to wait until vLLM is ready, which is after the following log messages were printed (otherwise, there will be `ConnectError`s in the code below):\n",
+ "```log\n",
+ "INFO: Started server process [1840]\n",
+ "INFO: Waiting for application startup.\n",
+ "INFO: Application startup complete.\n",
+ "INFO: Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)\n",
+ "```\n",
+ "\n",
+ "(The startup of vLLM, including warmup of Spyre for both shapes, is expected to take ~35 min.)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f7d4924c",
+ "metadata": {},
+ "source": [
+ "### 1. Create the prompts "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "bb328e33",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "template = (\n",
+ " \"Below is an instruction that describes a task. Write a response that \"\n",
+ " \"appropriately completes the request. Be polite in your response to the \"\n",
+ " \"user.\\n\\n### Instruction:\\n{}\\n\\n### Response:\"\n",
+ ")\n",
+ "\n",
+ "prompt1 = template.format(\n",
+ " \"Provide a list of instructions for preparing chicken soup for a family \"\n",
+ " \"of four.\"\n",
+ ")\n",
+ "\n",
+ "prompt2 = template.format(\n",
+ " \"Please compare the Cities of New York and Zurich and provide a list of \"\n",
+ " \"attractions for each city to visit in one day.\"\n",
+ ")\n",
+ "\n",
+ "prompt3 = template.format(\n",
+ " \"Provide detailed instructions for preparing asparagus soup for a family \"\n",
+ " \"of four using lots of cream.\"\n",
+ ")\n",
+ "\n",
+ "prompts = [prompt1, prompt2, prompt3]\n",
+ "max_tokens_list = [20, 10, 10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "ee30282f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\\n\\n### Instruction:\\nPlease compare the Cities of New York and Zurich and provide a list of attractions for each city to visit in one day.\\n\\n### Response:'"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# e.g. prompt 2\n",
+ "prompt2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "37c92859",
+ "metadata": {},
+ "source": [
+ "### 2. Initialize client and connect to vLLM\n",
+ "\n",
+ "(Adapt the `openai_api_base` URL to point to the (forwarded/tunneled) vLLM instance. E.g. forward it to localhost with `oc port-forward $DEV_POD 8000`.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "f7dee2e2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Modify OpenAI's API key and API base to use vLLM's API server.\n",
+ "openai_api_key = \"EMPTY\"\n",
+ "openai_api_base = \"http://localhost:8000/v1\"\n",
+ "\n",
+ "client = OpenAI(\n",
+ " # defaults to os.environ.get(\"OPENAI_API_KEY\")\n",
+ " api_key=openai_api_key,\n",
+ " base_url=openai_api_base,\n",
+ ")\n",
+ "\n",
+ "models = client.models.list()\n",
+ "model = models.data[0].id\n",
+ "\n",
+ "# Completion API\n",
+ "stream = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "026153e2",
+ "metadata": {},
+ "source": [
+ "### 3. Submit requests and await responses"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "bceb2c21",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Prompt: Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\n",
+ "\n",
+ "### Instruction:\n",
+ "Provide a list of instructions for preparing chicken soup for a family of four.\n",
+ "\n",
+ "### Response:\n",
+ "Results:\n",
+ "\n",
+ "Of course! Here are the steps to prepare chicken soup for a family of four:\n",
+ "\n",
+ "Duration: 3.6192898750305176s\n",
+ "---------------------------\n",
+ "\n",
+ "Prompt: Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\n",
+ "\n",
+ "### Instruction:\n",
+ "Please compare the Cities of New York and Zurich and provide a list of attractions for each city to visit in one day.\n",
+ "\n",
+ "### Response:\n",
+ "Results:\n",
+ "\n",
+ "Thank you for reaching out! Both New York\n",
+ "Duration: 1.6599247455596924s\n",
+ "---------------------------\n",
+ "\n",
+ "Prompt: Below is an instruction that describes a task. Write a response that appropriately completes the request. Be polite in your response to the user.\n",
+ "\n",
+ "### Instruction:\n",
+ "Provide detailed instructions for preparing asparagus soup for a family of four using lots of cream.\n",
+ "\n",
+ "### Response:\n",
+ "Results:\n",
+ "\n",
+ "Of course! Here are the steps to prepare\n",
+ "Duration: 1.6462435722351074s\n",
+ "---------------------------\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "for prompt, max_tokens in zip(prompts, max_tokens_list):\n",
+ " print(f\"Prompt: {prompt}\")\n",
+ " start_t = time.time()\n",
+ "\n",
+ " completion = client.completions.create(\n",
+ " model=model,\n",
+ " prompt=prompt,\n",
+ " echo=False,\n",
+ " n=1,\n",
+ " stream=stream,\n",
+ " temperature=0.0,\n",
+ " max_tokens=max_tokens)\n",
+ "\n",
+ " end_t = time.time()\n",
+ " print(\"Results:\")\n",
+ " if stream:\n",
+ " for c in completion:\n",
+ " print(c)\n",
+ " else:\n",
+ " # print(completion.choices)\n",
+ " print(completion.choices[0].text)\n",
+ "\n",
+ " total_t = end_t - start_t\n",
+ " print(f\"Duration: {total_t}s\")\n",
+ " print(\"---------------------------\\n\")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ea1e686d",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/spyre_warmup_online_client.py b/examples/spyre_warmup_online_client.py
new file mode 100644
index 000000000..02e168d9c
--- /dev/null
+++ b/examples/spyre_warmup_online_client.py
@@ -0,0 +1,82 @@
+""" Test for online serving.
+
+On the server side, run the following command:
+ python3 -m vllm.entrypoints.openai.api_server \
+ --model /models/llama-7b-chat/ \
+ --max-model-len=2048 \
+ --block-size=2048
+
+Then the default batch size 1, max prompt length of 64 and maximum of 20
+decode tokens are used. Otherwise change the behavior with the environment
+variables `VLLM_SPYRE_WARMUP_BATCH_SIZES`, `VLLM_SPYRE_WARMUP_PROMPT_LENS`,
+and `VLLM_SPYRE_WARMUP_NEW_TOKENS`.
+"""
+
+import time
+
+from openai import OpenAI
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+ # defaults to os.environ.get("OPENAI_API_KEY")
+ api_key=openai_api_key,
+ base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+template = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request. Be polite in your response to the "
+ "user.\n\n### Instruction:\n{}\n\n### Response:")
+prompt1 = template.format(
+ "Provide a list of instructions for preparing chicken soup for a family "
+ "of four.")
+
+prompt2 = template.format(
+ "Please compare New York City and Zurich and provide a list of attractions "
+ "for each city.")
+
+prompt3 = template.format(
+ "Provide detailed instructions for preparing asparagus soup for a family "
+ "of four.")
+
+prompts = [prompt1, prompt2, prompt3]
+
+# make sure that the specified batch size is in VLLM_SPYRE_WARMUP_BATCH_SIZES
+batch_size = 1
+print('submitting prompts of batch size', batch_size)
+
+# making sure not to submit more prompts than the batch size
+for i in range(0, len(prompts), batch_size):
+ prompt = prompts[i:i + batch_size]
+
+ # Completion API
+ stream = False
+ max_tokens = 20
+
+ print(f"Prompt: {prompt}")
+ start_t = time.time()
+
+ completion = client.completions.create(model=model,
+ prompt=prompt,
+ echo=False,
+ n=1,
+ stream=stream,
+ temperature=0.0,
+ max_tokens=max_tokens)
+
+ end_t = time.time()
+ print("Results:")
+ if stream:
+ for c in completion:
+ print(c)
+ else:
+ print(completion)
+
+ total_t = end_t - start_t
+ print(f"Duration: {total_t}s")
diff --git a/format.sh b/format.sh
index 0b196de9d..d35b388b3 100755
--- a/format.sh
+++ b/format.sh
@@ -195,7 +195,7 @@ if [[ "$1" == '--files' ]]; then
# If `--all` is passed, then any further arguments are ignored and the
# entire python directory is linted.
elif [[ "$1" == '--all' ]]; then
- lint vllm tests
+ lint vllm tests examples
else
# Format only the files that changed in last commit.
lint_changed
@@ -278,7 +278,7 @@ clang_format_changed() {
# Format all files with clang-format
clang_format_all() {
- find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
+ find csrc \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
| grep -vFf <(printf "%s\n" "${CLANG_FORMAT_EXCLUDES[@]}") \
| xargs clang-format -i
}
diff --git a/pyproject.toml b/pyproject.toml
index 3c8c46cc8..4be60cb03 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,8 @@ build-backend = "setuptools.build_meta"
line-length = 80
exclude = [
# External file, leaving license intact
- "examples/fp8/quantizer/quantize.py"
+ "examples/fp8/quantizer/quantize.py",
+ "vllm/model_executor/model_loader/spyre_setup.py"
]
[tool.ruff.lint.per-file-ignores]
@@ -79,7 +80,8 @@ files = [
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
- 'vllm/attention/ops/.*\.py$'
+ 'vllm/attention/ops/.*\.py$',
+ 'vllm/model_executor/model_loader/spyre_setup.py'
]
[tool.codespell]
diff --git a/requirements-spyre.txt b/requirements-spyre.txt
new file mode 100644
index 000000000..ad8b1cbd2
--- /dev/null
+++ b/requirements-spyre.txt
@@ -0,0 +1,7 @@
+# Common dependencies
+-r requirements-common.txt
+
+# IBM foundation model stack
+ibm-fms==0.0.7
+
+wurlitzer
diff --git a/setup.py b/setup.py
index b93658986..9831559c7 100644
--- a/setup.py
+++ b/setup.py
@@ -8,12 +8,10 @@
from shutil import which
from typing import Dict, List
-import torch
from packaging.version import Version, parse
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools_scm import get_version
-from torch.utils.cpp_extension import CUDA_HOME
def load_module_from_path(module_name, path):
@@ -33,6 +31,14 @@ def load_module_from_path(module_name, path):
VLLM_TARGET_DEVICE = envs.VLLM_TARGET_DEVICE
+if VLLM_TARGET_DEVICE in ["cuda", "rocm"]:
+    # Only import torch when needed: torch==2.4.0 is pinned in the build
+    # dependencies but is currently not available on s390x, and this import
+    # is only required when building for CUDA or ROCm.
+ import torch
+ from torch.utils.cpp_extension import CUDA_HOME
+
if not sys.platform.startswith("linux"):
logger.warning(
"vLLM only supports Linux platform (including WSL). "
@@ -272,8 +278,7 @@ def _no_device() -> bool:
def _is_cuda() -> bool:
- has_cuda = torch.version.cuda is not None
- return (VLLM_TARGET_DEVICE == "cuda" and has_cuda
+ return (VLLM_TARGET_DEVICE == "cuda" and (torch.version.cuda is not None)
and not (_is_neuron() or _is_tpu() or _is_hpu()))
@@ -299,6 +304,10 @@ def _is_cpu() -> bool:
return VLLM_TARGET_DEVICE == "cpu"
+def _is_spyre() -> bool:
+ return VLLM_TARGET_DEVICE == "spyre"
+
+
def _is_openvino() -> bool:
return VLLM_TARGET_DEVICE == "openvino"
@@ -415,6 +424,8 @@ def get_vllm_version() -> str:
if neuron_version != MAIN_CUDA_VERSION:
neuron_version_str = neuron_version.replace(".", "")[:3]
version += f"{sep}neuron{neuron_version_str}"
+ elif _is_spyre():
+ version += f"{sep}spyre"
elif _is_hpu():
# Get the Intel Gaudi Software Suite version
gaudi_sw_version = str(get_gaudi_sw_version())
@@ -479,6 +490,8 @@ def _read_requirements(filename: str) -> List[str]:
requirements = _read_requirements("requirements-rocm.txt")
elif _is_neuron():
requirements = _read_requirements("requirements-neuron.txt")
+ elif _is_spyre():
+ requirements = _read_requirements("requirements-spyre.txt")
elif _is_hpu():
requirements = _read_requirements("requirements-hpu.txt")
elif _is_openvino():
diff --git a/tests/spyre/spyre_util.py b/tests/spyre/spyre_util.py
new file mode 100644
index 000000000..7fc3bc10d
--- /dev/null
+++ b/tests/spyre/spyre_util.py
@@ -0,0 +1,277 @@
+import math
+import os
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+from sentence_transformers import SentenceTransformer, util
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from vllm import LLM, SamplingParams
+
+DISABLE_ASSERTS = False # used for debugging
+
+ISCLOSE_REL_TOL_CPU = 0.1
+ISCLOSE_REL_TOL_SPYRE = 0.1
+
+
+# vLLM / Spyre
+def generate_spyre_vllm_output(model: str, prompts: List[str],
+ warmup_shapes: List[Tuple[int, int, int]],
+ max_model_len: int, block_size: int,
+ sampling_params: SamplingParams,
+ tensor_parallel_size: int,
+ backend: str) -> List[Dict[str, Any]]:
+
+ warmup_prompt_length = [t[0] for t in warmup_shapes]
+ warmup_new_tokens = [t[1] for t in warmup_shapes]
+ warmup_batch_size = [t[2] for t in warmup_shapes]
+
+ os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = ','.join(
+ str(val) for val in warmup_prompt_length)
+ os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = ','.join(
+ str(val) for val in warmup_new_tokens)
+ os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = ','.join(
+ str(val) for val in warmup_batch_size)
+ os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = backend
+
+ vllm_model = LLM(model=model,
+ tokenizer=model,
+ max_model_len=max_model_len,
+ block_size=block_size,
+ tensor_parallel_size=tensor_parallel_size,
+ device="spyre")
+
+ vllm_outputs = vllm_model.generate(prompts, sampling_params)
+
+ results = []
+ for req_output in vllm_outputs:
+ result = {}
+ result['text'] = req_output.outputs[0].text
+ result['token_ids'] = req_output.outputs[0].token_ids
+ result['tokens'] = tuple([
+ req_output.outputs[0].logprobs[i][t].decoded_token
+ for i, t in enumerate(result['token_ids'])
+ ])
+ result['logprobs'] = tuple([
+ req_output.outputs[0].logprobs[i][t].logprob
+ for i, t in enumerate(result['token_ids'])
+ ])
+ results.append(result)
+
+ return results
+
+
+# Hugging Face
+def generate_hf_output(model: str, prompts: List[str],
+ max_new_tokens: int) -> List[Dict[str, Any]]:
+
+ hf_model = AutoModelForCausalLM.from_pretrained(model)
+ hf_tokenizer = AutoTokenizer.from_pretrained(model)
+
+ results = []
+ for prompt_index, prompt in enumerate(prompts):
+ hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids
+ hf_output = hf_model.generate(hf_input_tokens,
+ do_sample=False,
+ max_new_tokens=max_new_tokens,
+ return_dict_in_generate=True,
+ output_scores=True)
+
+ # decode output tokens after first removing input tokens (prompt)
+ hf_generated_text = hf_tokenizer.batch_decode(
+ hf_output.sequences[:, len(hf_input_tokens[0]):])[0]
+ hf_transition_scores = hf_model.compute_transition_scores(
+ hf_output.sequences, hf_output.scores, normalize_logits=True)
+
+ # return HF generated text, tokens, token ids and logprobs
+ result = {}
+ result['text'] = hf_generated_text
+ result['token_ids'] = []
+ result['tokens'] = []
+ result['logprobs'] = []
+ for tok_index, hf_logprob in enumerate(hf_transition_scores[0]):
+ hf_token_id = hf_output.sequences[0][tok_index +
+ len(hf_input_tokens[0])]
+ result['token_ids'].append(hf_token_id.item())
+ result['tokens'].append(hf_tokenizer.decode(hf_token_id))
+ result['logprobs'].append(hf_logprob.item())
+ result['token_ids'] = tuple(result['token_ids'])
+ result['tokens'] = tuple(result['tokens'])
+ result['logprobs'] = tuple(result['logprobs'])
+ results.append(result)
+
+ return results
+
+
+# compare results
+def compare_results(model: str, prompts: List[str],
+ warmup_shapes: List[Tuple[int, int,
+ int]], tensor_parallel_size: int,
+ backend: str, vllm_results: List[Dict[str, Any]],
+ hf_results: List[Dict[str, Any]]):
+
+ print(f"\nmodel: {model:s}")
+ print(f"warmup shapes: {warmup_shapes}")
+ print(f"tp size: {tensor_parallel_size}")
+ print(f"backend: {backend:s}")
+ print(f"\n#prompts: {len(prompts):d}")
+ print(f"#HF results: {len(hf_results):d}"
+ f"{'' if len(hf_results) == len(prompts) else ' ERROR':s}")
+ print(f"#vLLM results: {len(vllm_results):d}"
+ f"{'' if len(vllm_results) == len(prompts) else ' ERROR':s}")
+ print()
+
+ assert DISABLE_ASSERTS or len(hf_results) == len(vllm_results)
+ assert DISABLE_ASSERTS or len(hf_results) == len(prompts)
+
+ for prompt_index, (prompt, hf_result, vllm_result) in enumerate(
+ zip(prompts, hf_results, vllm_results)):
+ err_msg = '' if hf_result['text'] == vllm_result['text'] else ' ERROR'
+ print(f"\nprompt {prompt_index:3d}: {repr(prompt):s}")
+ print("generated:")
+ print(f" HF: {repr(hf_result['text']):s}")
+ print(f" vLLM: {repr(vllm_result['text']):s}{err_msg}")
+ print()
+
+ assert DISABLE_ASSERTS or backend == 'sendnn_decoder' or\
+ hf_result['text'] == vllm_result['text']
+
+ if len(hf_result['tokens']) > 0:
+ print(" token id. token logprob "
+ " token id. token logprob")
+
+ logprob_abs_diff_list = []
+ logprob_rel_diff_list = []
+
+ for hf_token, hf_token_id, hf_logprob, vllm_token,\
+ vllm_token_id, vllm_logprob in zip(
+ hf_result['tokens'], hf_result['token_ids'],
+ hf_result['logprobs'], vllm_result['tokens'],
+ vllm_result['token_ids'], vllm_result['logprobs']):
+ logprob_abs_diff = math.fabs(hf_logprob - vllm_logprob)
+ logprob_abs_diff_list.append(logprob_abs_diff)
+ logprob_rel_diff = math.fabs(logprob_abs_diff / hf_logprob)
+ logprob_rel_diff_list.append(logprob_rel_diff)
+
+ print(
+ f"HF: {hf_token_id:8d} {repr(hf_token):14s} "
+ f"{hf_logprob:14f} "
+ f"vLLM: {vllm_token_id:8d} {repr(vllm_token):14s} "
+ f"{vllm_logprob:14f} ",
+ end='')
+
+ if backend == 'sendnn_decoder':
+ rel_tol = ISCLOSE_REL_TOL_SPYRE
+ else:
+ rel_tol = ISCLOSE_REL_TOL_CPU
+
+ if hf_token_id != vllm_token_id: # different tokens
+ if backend == 'sendnn_decoder' and math.isclose(
+ hf_logprob, vllm_logprob, rel_tol=rel_tol):
+ # probably still OK
+ print('DIVERGING')
+ break
+ else:
+ print('ERROR')
+ assert DISABLE_ASSERTS or False
+ break
+ else: # identical tokens
+ if math.isclose(hf_logprob, vllm_logprob, rel_tol=rel_tol):
+ print()
+ else:
+ print('ERROR')
+ assert DISABLE_ASSERTS or False
+ break
+
+ print()
+ print("logprob absolute differences: "
+ f"average={np.mean(logprob_abs_diff_list):f} "
+ f"maximum={np.max(logprob_abs_diff_list):f}")
+ print("logprob relative differences: "
+ f"average={np.mean(logprob_rel_diff_list):f} "
+ f"maximum={np.max(logprob_rel_diff_list):f}")
+
+ print()
+
+
+# vLLM / Spyre
+def spyre_vllm_embeddings(model: str, prompts: List[str],
+ warmup_shapes: List[Tuple[int,
+ int]], max_model_len: int,
+ block_size: int, tensor_parallel_size: int,
+ backend: str) -> List[Dict[str, Any]]:
+
+ warmup_prompt_length = [t[0] for t in warmup_shapes]
+ warmup_new_tokens = [0] * len(warmup_shapes)
+ warmup_batch_size = [t[1] for t in warmup_shapes]
+
+ os.environ['VLLM_SPYRE_WARMUP_PROMPT_LENS'] = ','.join(
+ str(val) for val in warmup_prompt_length)
+ os.environ['VLLM_SPYRE_WARMUP_NEW_TOKENS'] = ','.join(
+ str(val) for val in warmup_new_tokens)
+ os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = ','.join(
+ str(val) for val in warmup_batch_size)
+ os.environ['VLLM_SPYRE_DYNAMO_BACKEND'] = backend
+
+ vllm_model = LLM(model=model,
+ tokenizer=model,
+ max_model_len=max_model_len,
+ block_size=block_size,
+ tensor_parallel_size=tensor_parallel_size,
+ device="spyre")
+
+ vllm_outputs = vllm_model.encode(prompts)
+
+ results = []
+ for req_output in vllm_outputs:
+ result = {}
+ result["embeddings"] = req_output.outputs.embedding
+ results.append(result)
+
+ return results
+
+
+# Hugging Face
+def st_embeddings(model: str, prompts: List[str]) -> List[Dict[str, Any]]:
+
+ model = SentenceTransformer(model)
+
+ results = []
+ for prompt in prompts:
+ embeddings = model.encode(prompt)
+
+ # return ST generated embeddings
+ result = {}
+ result['embeddings'] = embeddings
+ results.append(result)
+
+ return results
+
+
+# compare results
+def compare_embedding_results(model: str, prompts: List[str],
+ warmup_shapes: List[Tuple[int, int]],
+ tensor_parallel_size: int, backend: str,
+ vllm_results: List[Dict[str, Any]],
+ hf_results: List[Dict[str, Any]]):
+
+ print(f"\nmodel: {model:s}")
+ print(f"warmup shapes: {warmup_shapes}")
+ print(f"tp size: {tensor_parallel_size}")
+ print(f"backend: {backend:s}")
+ print(f"\n#prompts: {len(prompts):d}")
+ print(f"#HF results: {len(hf_results):d}"
+ f"{'' if len(hf_results) == len(prompts) else ' ERROR':s}")
+ print(f"#vLLM results: {len(vllm_results):d}"
+ f"{'' if len(vllm_results) == len(prompts) else ' ERROR':s}")
+ print()
+
+ assert DISABLE_ASSERTS or len(hf_results) == len(vllm_results)
+ assert DISABLE_ASSERTS or len(hf_results) == len(prompts)
+
+ for hf_result, vllm_result in zip(hf_results, vllm_results):
+
+ sim = util.pytorch_cos_sim(hf_result["embeddings"],
+ vllm_result["embeddings"])
+
+ assert math.isclose(sim, 1.0, rel_tol=0.05)
diff --git a/tests/spyre/test_spyre_basic.py b/tests/spyre/test_spyre_basic.py
new file mode 100644
index 000000000..0e2d73f32
--- /dev/null
+++ b/tests/spyre/test_spyre_basic.py
@@ -0,0 +1,73 @@
+"""Verification of vLLM output by comparing with HF
+
+Run `pytest tests/spyre/test_spyre_basic.py`.
+"""
+
+from typing import List, Tuple
+
+import pytest
+from spyre_util import (compare_results, generate_hf_output,
+ generate_spyre_vllm_output)
+
+from vllm import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["/models/llama-194m"])
+@pytest.mark.parametrize("prompts", [[
+ "Provide a list of instructions for preparing"
+ " chicken soup for a family of four.", "Hello",
+ "What is the weather today like?", "Who are you?"
+]])
+@pytest.mark.parametrize("warmup_shape", [(64, 20, 4), (64, 20, 8),
+ (128, 20, 4), (128, 20, 8)]
+ ) # (prompt_length/new_tokens/batch_size)
+@pytest.mark.parametrize("backend",
+ ["eager"]) #, "inductor", "sendnn_decoder"])
+def test_output(
+ model: str,
+ prompts: List[str],
+ warmup_shape: Tuple[int, int, int],
+ backend: str,
+) -> None:
+ '''
+ The warmup is based on a single shape. After the warmup,
+ one request with the provided prompts is input to vLLM.
+ The same prompts are also input to HF. The generated output
+ including text, token ids, and logprobs, is verified to be
+ identical for vLLM and HF.
+
+ If errors occur, these can be analyzed/debugged by setting
+ 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the
+ test using 'pytest --capture=no tests/spyre/test_spyre_basic.py'
+ After debugging, DISABLE_ASSERTS should be reset to 'False'.
+ '''
+
+ max_new_tokens = warmup_shape[1]
+
+ vllm_sampling_params = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=0,
+ logprobs=0, # return logprobs of generated tokens only
+ ignore_eos=True)
+
+ vllm_results = generate_spyre_vllm_output(
+ model=model,
+ prompts=prompts,
+ warmup_shapes=[warmup_shape],
+ max_model_len=2048,
+ block_size=2048,
+ sampling_params=vllm_sampling_params,
+ tensor_parallel_size=1,
+ backend=backend)
+
+ hf_results = generate_hf_output(model=model,
+ prompts=prompts,
+ max_new_tokens=max_new_tokens)
+
+ compare_results(model=model,
+ prompts=prompts,
+ warmup_shapes=[warmup_shape],
+ tensor_parallel_size=1,
+ backend=backend,
+ vllm_results=vllm_results,
+ hf_results=hf_results)
diff --git a/tests/spyre/test_spyre_embeddings.py b/tests/spyre/test_spyre_embeddings.py
new file mode 100644
index 000000000..22e4eda7b
--- /dev/null
+++ b/tests/spyre/test_spyre_embeddings.py
@@ -0,0 +1,55 @@
+"""Verification of vLLM output by comparing with HF
+
+Run `pytest tests/spyre/test_spyre_basic.py`.
+"""
+
+from typing import List, Tuple
+
+import pytest
+from spyre_util import (compare_embedding_results, spyre_vllm_embeddings,
+ st_embeddings)
+
+
+@pytest.mark.skip("Skip until failure is resolved.")
+@pytest.mark.parametrize("model", ["/models/all-roberta-large-v1"])
+@pytest.mark.parametrize("prompts", [[
+ "The capital of France is Paris."
+ "Provide a list of instructions for preparing"
+ " chicken soup for a family of four.", "Hello",
+ "What is the weather today like?", "Who are you?"
+]])
+@pytest.mark.parametrize("warmup_shape",
+ [(64, 4), (64, 8), (128, 4),
+                          (128, 8)])  # (prompt_length/batch_size)
+@pytest.mark.parametrize("backend",
+ ["eager"]) #, "inductor", "sendnn_decoder"])
+def test_output(
+ model: str,
+ prompts: List[str],
+ warmup_shape: Tuple[int, int],
+ backend: str,
+) -> None:
+ '''
+ The warmup is based on a single shape. After the warmup,
+ one request with the provided prompts is input to vLLM.
+ The same prompts are also input to HF. The generated embeddings
+ are verified to be identical for vLLM and SentenceTransformers.
+ '''
+
+ vllm_results = spyre_vllm_embeddings(model=model,
+ prompts=prompts,
+ warmup_shapes=[warmup_shape],
+ max_model_len=256,
+ block_size=256,
+ tensor_parallel_size=1,
+ backend=backend)
+
+ hf_results = st_embeddings(model=model, prompts=prompts)
+
+ compare_embedding_results(model=model,
+ prompts=prompts,
+ warmup_shapes=[warmup_shape],
+ tensor_parallel_size=1,
+ backend=backend,
+ vllm_results=vllm_results,
+ hf_results=hf_results)
diff --git a/tests/spyre/test_spyre_max_prompt_length.py b/tests/spyre/test_spyre_max_prompt_length.py
new file mode 100644
index 000000000..37f02f3a6
--- /dev/null
+++ b/tests/spyre/test_spyre_max_prompt_length.py
@@ -0,0 +1,101 @@
+"""Verification of handling prompt length exceeding warmup shapes
+
+Run `pytest tests/spyre/test_spyre_max_prompt_length.py`.
+"""
+
+from typing import List, Tuple
+
+import pytest
+from spyre_util import (compare_results, generate_hf_output,
+ generate_spyre_vllm_output)
+from transformers import AutoTokenizer
+
+from vllm import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["/models/llama-194m"])
+@pytest.mark.parametrize("prompts", [
+ 7 * [
+ "Hello",
+ "Below is an instruction that describes a task. Write a response"
+ " that appropriately completes the request. Be polite in your response"
+ " to the user. Provide a list of instructions for preparing chicken "
+ "soup for a family of four. Indicate if the weather forecast looks "
+ "good for today. Explain in a brief summary comprised of at most 50"
+ " words what you are."
+ ]
+])
+@pytest.mark.parametrize("warmup_shapes",
+ [[(64, 20, 4)], [(64, 20, 4), (128, 20, 4)]]
+ ) # (prompt_length/new_tokens/batch_size)
+@pytest.mark.parametrize("backend",
+ ["eager"]) #, "inductor", "sendnn_decoder"])
+def test_output(
+ model: str,
+ prompts: List[str],
+ warmup_shapes: List[Tuple[int, int, int]],
+ backend: str,
+) -> None:
+ '''
+ The warmup is based on one or multiple shapes. After the warmup,
+ one request with multiple provided prompts is input to vLLM.
+ At least one provided prompt should have a length longer than the
+ maximum prompt length defined by the warmup shapes. It is useful
+ to define enough prompts to fill multiple batches entirely and
+ partially, in order to test the maximum prompt length check
+    also in relation to the position of a prompt within a batch (not
+ likely that this will be an issue, but just to be sure).
+ It is verified that only for the prompts that
+ do not exceed the maximum prompt length, "non-empty" output is
+ generated. The output is verified using HF.
+
+ If errors occur, these can be analyzed/debugged by setting
+ 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the test
+ using 'pytest --capture=no tests/spyre/test_spyre_max_prompt_length.py'
+ After debugging, DISABLE_ASSERTS should be reset to 'False'.
+ '''
+
+ max_prompt_length = max([t[0] for t in warmup_shapes])
+ max_new_tokens = max([t[1] for t in warmup_shapes])
+
+ vllm_sampling_params = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=0,
+ logprobs=0, # return logprobs of generated tokens only
+ ignore_eos=True)
+
+ vllm_results = generate_spyre_vllm_output(
+ model=model,
+ prompts=prompts,
+ warmup_shapes=warmup_shapes,
+ max_model_len=2048,
+ block_size=2048,
+ sampling_params=vllm_sampling_params,
+ tensor_parallel_size=1,
+ backend=backend)
+
+ hf_results = generate_hf_output(model=model,
+ prompts=prompts,
+ max_new_tokens=max_new_tokens)
+
+ # for prompts longer than the max_prompt_length, the corresponding
+ # output in hf_results is reset to 'empty' in order to create the
+ # expected output for vLLM
+ hf_tokenizer = AutoTokenizer.from_pretrained(model)
+ for prompt_index, prompt in enumerate(prompts):
+ hf_input_tokens = hf_tokenizer(prompt, return_tensors="pt").input_ids
+ if len(hf_input_tokens[0]) > max_prompt_length:
+ hf_results[prompt_index] = {
+ 'text': '',
+ 'token_ids': (),
+ 'tokens': (),
+ 'logprobs': ()
+ }
+
+ compare_results(model=model,
+ prompts=prompts,
+ warmup_shapes=warmup_shapes,
+ tensor_parallel_size=1,
+ backend=backend,
+ vllm_results=vllm_results,
+ hf_results=hf_results)
diff --git a/tests/spyre/test_spyre_seed.py b/tests/spyre/test_spyre_seed.py
new file mode 100644
index 000000000..ca55d7024
--- /dev/null
+++ b/tests/spyre/test_spyre_seed.py
@@ -0,0 +1,75 @@
+"""Verification of seeded random sampling to be deterministic
+
+Run `pytest tests/spyre/test_spyre_seed.py`.
+"""
+
+import math
+from typing import Tuple
+
+import pytest
+from spyre_util import generate_spyre_vllm_output
+
+from vllm import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["/models/llama-194m"])
+@pytest.mark.parametrize("prompt", [
+ "Provide a list of instructions for preparing"
+ " chicken soup for a family of four."
+])
+@pytest.mark.parametrize("temperature", [0.1, 1.0])
+@pytest.mark.parametrize("seed", [42])
+@pytest.mark.parametrize("warmup_shape", [(64, 20, 4), (64, 20, 8),
+ (128, 20, 4), (128, 20, 8)]
+ ) # (prompt_length/new_tokens/batch_size)
+@pytest.mark.parametrize("backend",
+ ["eager"]) #, "inductor", "sendnn_decoder"])
+def test_seed(
+ model: str,
+ prompt: str,
+ temperature: float,
+ seed: int,
+ warmup_shape: Tuple[int, int, int],
+ backend: str,
+) -> None:
+ '''
+ The warmup is based on a single shape. After the warmup,
+ output is generated for one request with 16 identical prompts
+ using random sampling (non-zero temperature) in combination
+    with a seed. The generated output, including text, token ids, and
+    logprobs, is verified to be identical for all 16 sequences.
+ '''
+
+ max_new_tokens = warmup_shape[1]
+
+ prompts = [prompt] * 16
+
+ vllm_sampling_params = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=temperature,
+ logprobs=0, # return logprobs of generated tokens only
+ ignore_eos=True,
+ seed=seed)
+
+ vllm_results = generate_spyre_vllm_output(
+ model=model,
+ prompts=prompts,
+ warmup_shapes=[warmup_shape],
+ max_model_len=2048,
+ block_size=2048,
+ sampling_params=vllm_sampling_params,
+ tensor_parallel_size=1,
+ backend=backend)
+
+ # compare all generated outputs against the first generated output
+ for vllm_result in vllm_results:
+ assert vllm_result['text'] == vllm_results[0]['text']
+
+ # compare logprobs for all tokens between
+ # the current and the first sequence
+ assert len(vllm_result['logprobs']) == len(vllm_results[0]['logprobs'])
+ for token_id, logprob, token_id_0, logprob_0 in zip(
+ vllm_result['token_ids'], vllm_result['logprobs'],
+ vllm_results[0]['token_ids'], vllm_results[0]['logprobs']):
+ assert token_id == token_id_0
+ assert math.isclose(logprob, logprob_0, rel_tol=0.1)
diff --git a/tests/spyre/test_spyre_tensor_parallel.py b/tests/spyre/test_spyre_tensor_parallel.py
new file mode 100644
index 000000000..b3622f688
--- /dev/null
+++ b/tests/spyre/test_spyre_tensor_parallel.py
@@ -0,0 +1,77 @@
+"""Verification of vLLM output by comparing with HF
+
+Run `pytest tests/spyre/test_spyre_tensor_parallel.py`.
+"""
+
+from typing import List, Tuple
+
+import pytest
+from spyre_util import (compare_results, generate_hf_output,
+ generate_spyre_vllm_output)
+
+from vllm import SamplingParams
+
+
+@pytest.mark.skip("Skip until failure is resolved.")
+@pytest.mark.parametrize("model", ["/models/llama-194m"])
+@pytest.mark.parametrize("prompts", [[
+ "Provide a list of instructions for preparing"
+ " chicken soup for a family of four.", "Hello",
+ "What is the weather today like?", "Who are you?"
+]])
+@pytest.mark.parametrize("warmup_shapes", [[(64, 20, 4)]]
+ ) #,[(64,20,8)],[(128,20,4)],[(128,20,8)]])
+# (prompt_length/new_tokens/batch_size)
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("backend",
+ ["eager"]) #, "inductor", "sendnn_decoder"])
+def test_output(
+ model: str,
+ prompts: List[str],
+ warmup_shapes: List[Tuple[int, int, int]],
+ tp_size: int,
+ backend: str,
+) -> None:
+ '''
+ The warmup is based on one or multiple shapes. After the warmup,
+ one request with the provided prompts is input to vLLM which
+ is executed in tensor-parallel fashion on Spyres.
+ The same prompts are also input to HF. The generated output
+ including text, token ids, and logprobs, is verified to be
+ identical for vLLM and HF.
+
+ If errors occur, these can be analyzed/debugged by setting
+ 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the
+    test using 'pytest --capture=no tests/spyre/test_spyre_tensor_parallel.py'
+ After debugging, DISABLE_ASSERTS should be reset to 'False'.
+ '''
+
+ max_new_tokens = max([t[1] for t in warmup_shapes])
+
+ vllm_sampling_params = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=0,
+ logprobs=0, # return logprobs of generated tokens only
+ ignore_eos=True)
+
+ vllm_results = generate_spyre_vllm_output(
+ model=model,
+ prompts=prompts,
+ warmup_shapes=warmup_shapes,
+ max_model_len=2048,
+ block_size=2048,
+ sampling_params=vllm_sampling_params,
+ tensor_parallel_size=tp_size,
+ backend=backend)
+
+ hf_results = generate_hf_output(model=model,
+ prompts=prompts,
+ max_new_tokens=max_new_tokens)
+
+ compare_results(model=model,
+ prompts=prompts,
+ warmup_shapes=warmup_shapes,
+ tensor_parallel_size=tp_size,
+ backend=backend,
+ vllm_results=vllm_results,
+ hf_results=hf_results)
diff --git a/tests/spyre/test_spyre_warmup_shapes.py b/tests/spyre/test_spyre_warmup_shapes.py
new file mode 100644
index 000000000..be58ea516
--- /dev/null
+++ b/tests/spyre/test_spyre_warmup_shapes.py
@@ -0,0 +1,85 @@
+"""Verification of Spyre warmup shapes
+
+Run `pytest tests/spyre/test_spyre_warmup_shapes.py`.
+"""
+
+from typing import List, Tuple
+
+import pytest
+from spyre_util import (compare_results, generate_hf_output,
+ generate_spyre_vllm_output)
+
+from vllm import SamplingParams
+
+
+@pytest.mark.parametrize("model", ["/models/llama-194m"])
+@pytest.mark.parametrize("prompts", [
+ 7 * [
+ "Hello",
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request. Be polite in your response to "
+ "the user. Provide a list of instructions for preparing chicken soup"
+ " for a family of four. Indicate if the weather forecast looks good "
+ "for today. Explain in a brief summary comprised of at most 50 words"
+ " what you are."
+ ]
+])
+@pytest.mark.parametrize("warmup_shapes", [[(64, 20, 8), (128, 20, 4)]]
+ ) # (prompt_length/new_tokens/batch_size)
+@pytest.mark.parametrize("backend",
+ ["eager"]) #, "inductor", "sendnn_decoder"])
+def test_output(
+ model: str,
+ prompts: List[str],
+ warmup_shapes: List[Tuple[int, int, int]],
+ backend: str,
+) -> None:
+ '''
+    The warmup is based on two shapes that 'overlap' each
+    other. After the warmup, one request with the provided
+    prompts is input to vLLM. There should be at least one
+    prompt corresponding to each of the two warmup shapes.
+    It is useful to define enough prompts to fill multiple
+    batches both entirely and partially, so that the handling
+    of overlapping warmup shapes is also exercised for
+    different positions of a prompt within a batch (unlikely
+    to be an issue, but checked to be sure).
+ The same prompts are also input to HF. The generated
+ output including text, token ids, and logprobs, is
+ verified to be identical for vLLM and HF.
+
+ If errors occur, these can be analyzed/debugged by setting
+ 'DISABLE_ASSERTS = True' in spyre_util.py and by rerunning the
+ test using 'pytest --capture=no tests/spyre/test_spyre_warmup_shapes.py'
+ After debugging, DISABLE_ASSERTS should be reset to 'False'.
+ '''
+
+ max_new_tokens = max([t[1] for t in warmup_shapes])
+
+ vllm_sampling_params = SamplingParams(
+ max_tokens=max_new_tokens,
+ temperature=0,
+ logprobs=0, # return logprobs of generated tokens only
+ ignore_eos=True)
+
+ vllm_results = generate_spyre_vllm_output(
+ model=model,
+ prompts=prompts,
+ warmup_shapes=warmup_shapes,
+ max_model_len=2048,
+ block_size=2048,
+ sampling_params=vllm_sampling_params,
+ tensor_parallel_size=1,
+ backend=backend)
+
+ hf_results = generate_hf_output(model=model,
+ prompts=prompts,
+ max_new_tokens=max_new_tokens)
+
+ compare_results(model=model,
+ prompts=prompts,
+ warmup_shapes=warmup_shapes,
+ tensor_parallel_size=1,
+ backend=backend,
+ vllm_results=vllm_results,
+ hf_results=hf_results)
diff --git a/vllm/config.py b/vllm/config.py
index e69cbd3eb..4669abd78 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -1,6 +1,7 @@
import copy
import enum
import json
+import operator
import warnings
from dataclasses import dataclass, field, replace
from pathlib import Path
@@ -382,6 +383,8 @@ def _verify_quantization(self) -> None:
]
tpu_supported_quantization = ["tpu_int8"]
neuron_supported_quantization = ["neuron_quant"]
+ spyre_supported_quantization = ["gptq"]
+
if self.quantization is not None:
self.quantization = self.quantization.lower()
@@ -441,6 +444,11 @@ def _verify_quantization(self) -> None:
raise ValueError(
f"{self.quantization} quantization is currently not "
f"supported in Neuron Backend.")
+ if current_platform.is_spyre(
+ ) and self.quantization not in spyre_supported_quantization:
+ raise ValueError(
+ f"{self.quantization} quantization is currently not "
+ f"supported in Spyre Backend.")
def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
@@ -1148,6 +1156,39 @@ def __init__(self,
self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data
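+        # Spyre requires requests to be scheduled onto pre-compiled
+        # warmup shapes (see the Spyre checks in the scheduler)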
+ self.spyre_scheduling_enabled = current_platform.is_spyre()
+ if self.spyre_scheduling_enabled:
+ # load warmup shapes and sort by "speed"
+ wup_prompt_lens = envs.VLLM_SPYRE_WARMUP_PROMPT_LENS or []
+ wup_batch_sizes = envs.VLLM_SPYRE_WARMUP_BATCH_SIZES or []
+ if len(wup_prompt_lens) != len(wup_batch_sizes):
+ raise RuntimeError(
+ "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
+ "VLLM_SPYRE_WARMUP_BATCH_SIZES must have equal length")
+ if task == "embedding":
+ wup_new_tokens = [0] * len(wup_prompt_lens)
+ else:
+ wup_new_tokens = envs.VLLM_SPYRE_WARMUP_NEW_TOKENS or []
+ if len(wup_new_tokens) != len(wup_prompt_lens):
+ raise RuntimeError(
+ "The lists in VLLM_SPYRE_WARMUP_PROMPT_LENS and "
+ "VLLM_SPYRE_WARMUP_NEW_TOKENS must have equal length")
+
+ print("[SchedulerConfig] VLLM_SPYRE_WARMUP_PROMPT_LENS =",
+ wup_prompt_lens)
+ print("[SchedulerConfig] VLLM_SPYRE_WARMUP_NEW_TOKENS =",
+ wup_new_tokens)
+ print("[SchedulerConfig] VLLM_SPYRE_WARMUP_BATCH_SIZES =",
+ wup_batch_sizes)
+
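+            # Store the shapes as an immutable tuple, sorted ascending by
+            # batch size and then prompt length (roughly "fastest" first).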
+ self.spyre_warmup_shapes = tuple(
+ sorted([{
+ 'prompt_length': pl,
+ 'new_tokens': nt,
+ 'batch_size': bs
+ } for pl, nt, bs in zip(wup_prompt_lens, wup_new_tokens,
+ wup_batch_sizes)],
+ key=operator.itemgetter('batch_size', 'prompt_length')))
self.policy = policy
self._verify_args()
@@ -1195,6 +1236,8 @@ def __init__(self, device: str = "auto") -> None:
self.device_type = "cuda"
elif current_platform.is_neuron():
self.device_type = "neuron"
+ elif current_platform.is_spyre():
+ self.device_type = "spyre"
elif current_platform.is_hpu():
self.device_type = "hpu"
elif current_platform.is_openvino():
@@ -1212,7 +1255,7 @@ def __init__(self, device: str = "auto") -> None:
self.device_type = device
# Some device types require processing inputs on CPU
- if self.device_type in ["neuron", "openvino"]:
+ if self.device_type in ["neuron", "spyre", "openvino"]:
self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index af4671ec2..ef6854e45 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -872,6 +872,9 @@ def _schedule_prefills(
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
+ applicable_spyre_warmup_shapes = list(
+ self.scheduler_config.spyre_warmup_shapes)
+
waiting_queue = self.waiting
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
@@ -941,6 +944,54 @@ def _schedule_prefills(
num_new_seqs=num_new_seqs)):
break
+ # check if current request can be scheduled based on the applicable
+ # spyre warmup shapes
+ if self.scheduler_config.spyre_scheduling_enabled:
+ max_tokens = 0
+ if seq_group.sampling_params is not None and\
+ seq_group.sampling_params.max_tokens is not None:
+ max_tokens = seq_group.sampling_params.max_tokens
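+                # A shape remains applicable if the prompt fits its warmed-up
+                # prompt length, the requested output fits its new-token
+                # budget, and its batch size is not yet exhausted.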
+ updated_spyre_warmup_shapes = [
+ shape for shape in applicable_spyre_warmup_shapes
+ if num_new_tokens <= shape['prompt_length']
+ and max_tokens <= shape['new_tokens']
+ and len(seq_groups) < shape['batch_size']
+ ]
+ if not updated_spyre_warmup_shapes:
+ if not seq_groups:
+ # request was tested against all spyre warmup shapes:
+ # request cannot be processed
+ if (seq_group.sampling_params is not None
+ and seq_group.sampling_params.max_tokens
+ is not None):
+ logger.warning(
+ "No applicable warmup shape exists for "
+ "combination of prompt length (%d tokens) "
+ "and maximum number of output tokens to be "
+ "generated (%d tokens)", num_new_tokens,
+ seq_group.sampling_params.max_tokens)
+ else:
+ logger.warning(
+ "No applicable warmup shape exists for "
+ "combination of prompt length (%d tokens) "
+ "and undefined maximum number of output "
+ "tokens", num_new_tokens)
+ for seq in waiting_seqs:
+ seq.status = SequenceStatus.FINISHED_IGNORED
+ ignored_seq_groups.append(seq_group)
+ waiting_queue.popleft()
+ continue
+ else:
+ # request was only tested against spyre warmup shapes
+ # that remain after processing previous requests in
+ # waiting queue: request will be evaluated again in
+ # a future scheduling step
+ leftover_waiting_sequences.appendleft(seq_group)
+ waiting_queue.popleft()
+ continue
+ else:
+ applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes
+
# Can schedule this request.
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
@@ -970,6 +1021,15 @@ def _schedule_prefills(
budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens)
budget.add_num_seqs(seq_group.request_id, num_new_seqs)
+ # Check if number of scheduled requests has reached the maximum
+ # batch size of the applicable warmup shapes
+ if self.scheduler_config.spyre_scheduling_enabled and len(
+ seq_groups) >= max([
+ shape['batch_size']
+ for shape in applicable_spyre_warmup_shapes
+ ]):
+ break
+
# Queue requests that couldn't be scheduled.
waiting_queue.extendleft(leftover_waiting_sequences)
if len(seq_groups) > 0:
@@ -1007,8 +1067,11 @@ def _schedule_default(self) -> SchedulerOutputs:
running_scheduled = SchedulerRunningOutputs.create_empty()
swapped_in = SchedulerSwappedInOutputs.create_empty()
- # If any requests are swapped, prioritized swapped requests.
- if not self.swapped:
+ # Schedule new prefills only when no requests have been swapped
+ # and all previous decodes have completed.
+ if not self.swapped and (
+ not self.scheduler_config.spyre_scheduling_enabled
+ or not self.running):
prefills = self._schedule_prefills(budget,
curr_loras,
enable_chunking=False)
@@ -1198,8 +1261,13 @@ def _can_append_slots(self, seq_group: SequenceGroup,
# chunked-prefill are enabled together.
assert self.scheduler_config.is_multi_step and enable_chunking
- return self.block_manager.can_append_slots(
- seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
+ if self.scheduler_config.spyre_scheduling_enabled:
+ # heuristic below doesn't make sense when using very large
+ # blocks
+ return True
+ else:
+ return self.block_manager.can_append_slots(
+ seq_group=seq_group, num_lookahead_slots=num_lookahead_slots)
def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
# async_output_proc is allowed only when we have a single sequence
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9288cd22c..2626652fb 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -37,6 +37,7 @@
"openvino",
"tpu",
"xpu",
+ "spyre",
"hpu",
]
@@ -394,7 +395,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
- choices=[8, 16, 32, 64, 128],
help='Token block size for contiguous chunks of '
'tokens. This is ignored on neuron devices and '
'set to max-model-len')
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 5a5388708..0b6f3a00d 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -617,6 +617,14 @@ def _get_executor_cls(
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
+    elif engine_config.device_config.device_type == "spyre":
+ if distributed_executor_backend == "mp":
+ from vllm.executor.multiproc_spyre_executor import (
+ MultiprocessingSpyreExecutorAsync)
+ executor_class = MultiprocessingSpyreExecutorAsync
+ else:
+ from vllm.executor.spyre_executor import SpyreExecutorAsync
+ executor_class = SpyreExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 2a5eaf134..9367714eb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -516,6 +516,15 @@ def _get_executor_cls(cls,
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor
+ elif engine_config.device_config.device_type == "spyre":
+ if distributed_executor_backend == "mp":
+ from vllm.executor.multiproc_spyre_executor import (
+ MultiprocessingSpyreExecutor)
+ executor_class = MultiprocessingSpyreExecutor
+ else:
+ from vllm.executor.spyre_executor import SpyreExecutor
+ executor_class = SpyreExecutor
+
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
diff --git a/vllm/envs.py b/vllm/envs.py
index 853c49bc4..77fd1da1c 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -67,6 +67,9 @@
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
+ VLLM_SPYRE_WARMUP_PROMPT_LENS: Optional[List[int]] = None
+ VLLM_SPYRE_WARMUP_NEW_TOKENS: Optional[List[int]] = None
+ VLLM_SPYRE_WARMUP_BATCH_SIZES: Optional[List[int]] = None
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
@@ -457,6 +460,41 @@ def get_default_config_root():
# If set, enable multiprocessing in LLM for the V1 code path.
"VLLM_ENABLE_V1_MULTIPROCESSING":
lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0"))),
+
+ # Defines the prompt lengths the Spyre accelerator should be prepared
+ # for, formatted as comma separated list.
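+    # Example (illustrative values): VLLM_SPYRE_WARMUP_PROMPT_LENS=64,128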
+ "VLLM_SPYRE_WARMUP_PROMPT_LENS":
+ lambda: [
+ int(p) for p in os.getenv(key='VLLM_SPYRE_WARMUP_PROMPT_LENS',
+ default='64').split(',')
+ ],
+
+ # Defines the max output tokens the Spyre accelerator should be prepared
+ # for, formatted as comma separated list.
+ "VLLM_SPYRE_WARMUP_NEW_TOKENS":
+ lambda: [
+ int(d) for d in os.getenv(key='VLLM_SPYRE_WARMUP_NEW_TOKENS',
+ default='20').split(',')
+ ],
+
+ # Defines the batch sizes the Spyre accelerator should be prepared
+ # for, formatted as comma separated list.
+ "VLLM_SPYRE_WARMUP_BATCH_SIZES":
+ lambda: [
+ int(b) for b in os.getenv(key='VLLM_SPYRE_WARMUP_BATCH_SIZES',
+ default='1').split(',')
+ ],
+
+ # Defines the backend that torch.compile will use when using Spyre
+ # Available options:
+ # - "sendnn_decoder": Compile for execution on Spyre hardware for
+ # decoder models
+ # - "sendnn": Compile for execution on Spyre hardware for
+ # encoder models
+ # - "inductor": Compile for execution on CPU (for debug and testing)
+ # - "eager": Skip compile entirely (for debug and testing
+ "VLLM_SPYRE_DYNAMO_BACKEND":
+ lambda: os.getenv("VLLM_SPYRE_DYNAMO_BACKEND", "sendnn_decoder"),
}
# end-env-vars-definition
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 9cba189dd..7786355db 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -12,8 +12,8 @@ class ExecutorBase(ABC):
"""Base class for all executors.
An executor is responsible for executing the model on a specific device
- type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
- that can execute the model on multiple devices.
+ type (e.g., CPU, GPU, Neuron, Spyre, etc.). Or it can be a distributed
+ executor that can execute the model on multiple devices.
"""
uses_ray: bool # whether the executor uses Ray for orchestration.
diff --git a/vllm/executor/multiproc_spyre_executor.py b/vllm/executor/multiproc_spyre_executor.py
new file mode 100644
index 000000000..2b3860431
--- /dev/null
+++ b/vllm/executor/multiproc_spyre_executor.py
@@ -0,0 +1,267 @@
+import json
+import os
+import platform
+import signal
+import threading
+import weakref
+from functools import partial
+from typing import Any, List, Set, Tuple
+
+import torch
+
+from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
+ ResultHandler, WorkerMonitor)
+from vllm.executor.spyre_executor import SpyreExecutor, create_worker
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.utils import (get_distributed_init_method, get_open_port,
+ get_vllm_instance_id)
+
+logger = init_logger(__name__)
+
+
+class MultiprocessingSpyreExecutor(SpyreExecutor):
+ """Python multiprocessing-based multi-Spyre executor"""
+
+ uses_ray: bool = False
+
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
+ """Determine the number of available KV blocks.
+
+ This invokes `determine_num_available_blocks` on each worker and takes
+ the min of the results, guaranteeing that the selected cache sizes are
+ compatible with all workers.
+
+ Returns:
+ - tuple[num_gpu_blocks, num_cpu_blocks]
+ """
+ # Get the maximum number of blocks that can be allocated on GPU and CPU.
+ num_blocks = self._run_workers("determine_num_available_blocks", )
+
+ # Since we use a shared centralized controller, we take the minimum
+ # number of blocks across all workers to make sure all the memory
+ # operators can be applied to all workers.
+ num_gpu_blocks = min(b[0] for b in num_blocks)
+ num_cpu_blocks = min(b[1] for b in num_blocks)
+
+ return num_gpu_blocks, num_cpu_blocks
+
+ def initialize_cache(self, num_gpu_blocks: int,
+ num_cpu_blocks: int) -> None:
+ """Initialize the KV cache in all workers.
+ """
+
+ # NOTE: We log here to avoid multiple logs when number of workers is
+ # greater than one. We could log in the engine, but not all executors
+ # have GPUs.
+ logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
+ num_cpu_blocks)
+
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+ self._run_workers("initialize_cache",
+ num_gpu_blocks=num_gpu_blocks,
+ num_cpu_blocks=num_cpu_blocks)
+
+ def _init_executor(self) -> None:
+ self._check_executor_parameters()
+
+ # Create the parallel GPU workers.
+ world_size = self.parallel_config.world_size
+ tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+ # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers
+ os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id()
+
+ # Disable torch async compiling which won't work with daemonic processes
+        # [tom] Setting this from code does not seem to work; it needs to be
+        # set on the command line. Hopefully fixed by upgrading torch.
+ os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
+
+ # Configure thread parallelism if OMP_NUM_THREADS isn't set
+ #
+ # Helps to avoid CPU contention. The default of spawning a thread per
+ # core combined with multiprocessing for each GPU can have a negative
+ # impact on performance. The contention is amplified when running in a
+ # container where CPU limits can cause throttling.
+ default_omp_num_threads = 1
+ if "OMP_NUM_THREADS" not in os.environ and (
+ current_parallelism :=
+ torch.get_num_threads()) > default_omp_num_threads:
+ logger.warning(
+ "Reducing Torch parallelism from %d threads to %d to avoid "
+ "unnecessary CPU contention. Set OMP_NUM_THREADS in the "
+ "external environment to tune this value as needed.",
+ current_parallelism, default_omp_num_threads)
+ os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
+ torch.set_num_threads(default_omp_num_threads)
+
+ # Multiprocessing-based executor does not support multi-node setting.
+ # Since it only works for single node, we can use the loopback address
+ # 127.0.0.1 for communication.
+ distributed_init_method = get_distributed_init_method(
+ "127.0.0.1", get_open_port())
+
+ self.workers: List[ProcessWorkerWrapper] = []
+ # This is the list of workers that are rank 0 of each TP group EXCEPT
+ # global rank 0. These are the workers that will broadcast to the
+ # rest of the workers.
+ self.tp_driver_workers: List[ProcessWorkerWrapper] = []
+ # This is the list of workers that are not drivers and not the first
+ # worker in a TP group. These are the workers that will be
+ # broadcasted to.
+ self.non_driver_workers: List[ProcessWorkerWrapper] = []
+
+ if world_size == 1:
+ self.worker_monitor = None
+ else:
+ result_handler = ResultHandler()
+ for rank in range(1, world_size):
+ worker = ProcessWorkerWrapper(
+ result_handler,
+ partial(
+ create_worker,
+ **self._get_create_worker_kwargs(
+ rank=rank,
+ local_rank=rank,
+ distributed_init_method=distributed_init_method,
+ )))
+ self.workers.append(worker)
+ if rank % tensor_parallel_size == 0:
+ self.tp_driver_workers.append(worker)
+ else:
+ self.non_driver_workers.append(worker)
+
+ self.worker_monitor = WorkerMonitor(self.workers, result_handler)
+ result_handler.start()
+ self.worker_monitor.start()
+
+ # Set up signal handlers to shutdown the executor cleanly
+ # sometimes gc does not work well
+
+ # Use weakref to avoid holding a reference to self
+ ref = weakref.ref(self)
+
+ def shutdown(signum, frame):
+ if executor := ref():
+ executor.shutdown()
+
+ if threading.current_thread() is threading.main_thread():
+ signal.signal(signal.SIGINT, shutdown)
+ signal.signal(signal.SIGTERM, shutdown)
+
+ self.driver_worker = self._create_worker(
+ distributed_init_method=distributed_init_method)
+ self._run_workers("init_device")
+ self._run_workers("load_model")
+
+ def _check_executor_parameters(self):
+ world_size = self.parallel_config.world_size
+ tensor_parallel_size = self.parallel_config.tensor_parallel_size
+
+ # Read number of Spyre cards from senlib config file
+ if platform.machine() == "s390x":
+ spyre_device_count = int(os.getenv("AIU_WORLD_SIZE", 1))
+ else:
+ with open("/etc/aiu/senlib_config.json", 'rb') as f:
+ config = json.load(f)
+ spyre_device_count = len(config["GENERAL"]["sen_bus_id"])
+
+        # Keep the message focused on the more common TP-only case.
+        assert tensor_parallel_size <= spyre_device_count, (
+            f"please set tensor_parallel_size ({tensor_parallel_size}) "
+            f"to at most the number of available Spyre devices "
+            f"({spyre_device_count})")
+
+        assert world_size <= spyre_device_count, (
+            f"please ensure that world_size ({world_size}) does not "
+            f"exceed the number of available Spyre devices "
+            f"({spyre_device_count})")
+
+ def shutdown(self):
+ if (worker_monitor := getattr(self, "worker_monitor",
+ None)) is not None:
+ worker_monitor.close()
+
+ def execute_model(
+ self,
+ execute_model_req: ExecuteModelRequest,
+ ) -> List[SamplerOutput]:
+
+ output = self._run_workers("execute_model",
+ execute_model_req=execute_model_req)
+ return output[0]
+
+ def add_lora(self, lora_request: LoRARequest) -> bool:
+ assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
+ return self._run_workers(
+ "add_lora",
+ lora_request=lora_request,
+ )
+
+ def remove_lora(self, lora_id: int) -> bool:
+ assert lora_id > 0, "lora_id must be greater than 0."
+ return self._run_workers(
+ "remove_lora",
+ lora_id=lora_id,
+ )
+
+ def pin_lora(self, lora_id: int) -> bool:
+ assert lora_id > 0, "lora_id must be greater than 0."
+ return self._run_workers(
+ "pin_lora",
+ lora_id=lora_id,
+ )
+
+ def list_loras(self) -> Set[int]:
+ return self._run_workers("list_loras")
+
+ def _run_workers(
+ self,
+ method: str,
+ *args,
+ **kwargs,
+ ) -> Any:
+ """Runs the given method on all workers.
+ """
+
+ # Start all remote workers first.
+ worker_outputs = [
+ worker.execute_method(method, *args, **kwargs)
+ for worker in self.workers
+ ]
+
+ driver_worker_method = getattr(self.driver_worker, method)
+ driver_worker_output = driver_worker_method(*args, **kwargs)
+
+ # Get the results of the workers.
+ return [driver_worker_output
+ ] + [output.get() for output in worker_outputs]
+
+ def check_health(self) -> None:
+ """Raises an error if engine is unhealthy."""
+        if (self.worker_monitor is not None
+                and not self.worker_monitor.is_alive()):
+ raise RuntimeError("Worker processes are not running")
+
+
+class MultiprocessingSpyreExecutorAsync(MultiprocessingSpyreExecutor):
+
+ async def execute_model_async(
+ self,
+ execute_model_req: ExecuteModelRequest,
+ ) -> List[SamplerOutput]:
+
+        # This is not really async; it is a blocking call, which may have
+        # performance implications. However, Spyre does not seem to play
+        # nicely with asyncio, and vLLM is moving away from the
+        # AsyncLLMEngine, so this is left as-is for now.
+ output = self._run_workers("execute_model",
+ execute_model_req=execute_model_req)
+
+ return output[0]
+
+ async def stop_remote_worker_execution_loop_async(self) -> None:
+ return
diff --git a/vllm/executor/spyre_executor.py b/vllm/executor/spyre_executor.py
new file mode 100644
index 000000000..b07b620ae
--- /dev/null
+++ b/vllm/executor/spyre_executor.py
@@ -0,0 +1,180 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
+
+from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
+ make_async)
+from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase
+
+logger = init_logger(__name__)
+
+
+def create_worker(worker_module_name: str, worker_class_name: str,
+ worker_class_fn: Optional[Callable[[], Type[WorkerBase]]],
+ **kwargs):
+ wrapper = WorkerWrapperBase(
+ worker_module_name=worker_module_name,
+ worker_class_name=worker_class_name,
+ worker_class_fn=worker_class_fn,
+ )
+ wrapper.init_worker(**kwargs)
+ return wrapper.worker
+
+
+class SpyreExecutor(ExecutorBase):
+
+ uses_ray: bool = False
+
+ def _init_executor(self) -> None:
+ assert (self.lora_config is
+ None), "LoRA is not supported for Spyre backend."
+ assert (not self.speculative_config
+ ), "Speculative decoding not yet supported for Spyre backend."
+
+ #assert self.parallel_config.world_size == 1, (
+ # "SpyreExecutor only supports single Spyre card.")
+
+ if os.getenv(key='SPYRE_PYTEST_DEBUG', default='0') == '1':
+ import debugpy
+ host_addr = os.getenv(key='SPYRE_PYTEST_DBG_ADDR',
+ default='0.0.0.0')
+ debugpy.listen((host_addr, 5678))
+ print(f"[debugpy] {host_addr}: wait for client...\n")
+ debugpy.wait_for_client()
+
+ self.driver_worker = self._create_worker()
+ self.driver_worker.init_device()
+ self.driver_worker.load_model()
+
+ def _get_worker_kwargs(
+ self,
+ local_rank: int = 0,
+ rank: int = 0,
+ distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
+ """Return worker init args for a given rank."""
+ if distributed_init_method is None:
+ distributed_init_method = get_distributed_init_method(
+ get_ip(), get_open_port())
+ return dict(
+ model_config=self.model_config,
+ parallel_config=self.parallel_config,
+ scheduler_config=self.scheduler_config,
+ device_config=self.device_config,
+ cache_config=self.cache_config,
+ local_rank=local_rank,
+ rank=rank,
+ distributed_init_method=distributed_init_method,
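+            # rank 0 of each tensor-parallel group acts as the driver worker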
+ is_driver_worker=(not self.parallel_config)
+ or (rank % self.parallel_config.tensor_parallel_size == 0),
+ )
+
+ def _get_worker_module_and_class(
+ self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]:
+ worker_class_fn = None
+ worker_module_name = "vllm.worker.spyre_worker"
+ worker_class_name = "SpyreWorker"
+ return (worker_module_name, worker_class_name, worker_class_fn)
+
+ def _get_create_worker_kwargs(
+ self,
+ local_rank: int = 0,
+ rank: int = 0,
+ distributed_init_method: Optional[str] = None) -> Dict:
+
+ worker_kwargs = self._get_worker_kwargs(local_rank, rank,
+ distributed_init_method)
+
+ (worker_module_name, worker_class_name,
+ worker_class_fn) = self._get_worker_module_and_class()
+ worker_kwargs.update(
+ worker_module_name=worker_module_name,
+ worker_class_name=worker_class_name,
+ worker_class_fn=worker_class_fn,
+ )
+ return worker_kwargs
+
+ def _create_worker(self,
+ local_rank: int = 0,
+ rank: int = 0,
+ distributed_init_method: Optional[str] = None):
+ return create_worker(**self._get_create_worker_kwargs(
+ local_rank=local_rank,
+ rank=rank,
+ distributed_init_method=distributed_init_method))
+
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
+ """Determine the number of available KV blocks by invoking the
+ underlying worker.
+ """
+ return self.driver_worker.determine_num_available_blocks()
+
+ def initialize_cache(self, num_gpu_blocks: int,
+ num_cpu_blocks: int) -> None:
+ """Initialize the KV cache by invoking the underlying worker.
+ """
+ self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
+
+ def execute_model(
+ self,
+ execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+
+ assert execute_model_req.num_lookahead_slots == 0, (
+ "lookahead not supported for Spyre backend.")
+
+ output = self.driver_worker.execute_model(execute_model_req)
+
+ return output
+
+ def add_lora(self, lora_request: LoRARequest) -> bool:
+ return self.driver_worker.add_lora(lora_request)
+
+ def remove_lora(self, lora_id: int) -> bool:
+ return self.driver_worker.remove_lora(lora_id)
+
+ def pin_lora(self, lora_id: int) -> bool:
+ return self.driver_worker.pin_lora(lora_id)
+
+ def list_loras(self) -> Set[int]:
+ return self.driver_worker.list_loras()
+
+ def add_prompt_adapter(self, prompt_adapter_request) -> bool:
+ raise NotImplementedError(
+ "Soft prompt is currently not supported by the Spyre backend.")
+
+ def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+ raise NotImplementedError(
+ "Soft prompt is currently not supported by the Spyre backend.")
+
+ def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+ raise NotImplementedError(
+ "Soft prompt is currently not supported by the Spyre backend.")
+
+ def list_prompt_adapters(self) -> Set[int]:
+ raise NotImplementedError(
+ "Soft prompt is currently not supported by the Spyre backend.")
+
+ def check_health(self) -> None:
+ # SpyreExecutor will always be healthy as long as
+ # it's running.
+ return
+
+
+class SpyreExecutorAsync(SpyreExecutor, ExecutorAsyncBase):
+
+ async def execute_model_async(
+ self,
+ execute_model_req: ExecuteModelRequest,
+ ) -> List[SamplerOutput]:
+ output = await make_async(
+ self.driver_worker.execute_model
+ )(seq_group_metadata_list=execute_model_req.seq_group_metadata_list, )
+ return output
+
+ async def check_health_async(self) -> None:
+ # SpyreExecutor will always be healthy as long as
+ # it's running.
+ return
diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
index c9366ca97..8006e0a91 100644
--- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py
@@ -61,10 +61,12 @@ def _check_marlin_supported(
has_zp, device_capability)
if quant_type not in supported_types:
- return (False, f"Marlin does not support weight_bits = {quant_type}. "
- f"Only types = {supported_types} "
- f"are supported (for group_size = {group_size}, "
- f"device_capability = {device_capability}, zp = {has_zp}).")
+ return (
+ False,
+ f"Marlin does not support weight_bits = {quant_type.size_bits}. "
+ f"Only types = {supported_types} "
+ f"are supported (for group_size = {group_size}, "
+ f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES):
return (False, f"Marlin does not support group_size = {group_size}. "
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
diff --git a/vllm/model_executor/model_loader/spyre.py b/vllm/model_executor/model_loader/spyre.py
new file mode 100644
index 000000000..d2c4634e4
--- /dev/null
+++ b/vllm/model_executor/model_loader/spyre.py
@@ -0,0 +1,213 @@
+"""Utilities for selecting and loading Spyre models."""
+import sys
+from typing import List, Optional
+
+import torch
+import torch._inductor.config
+import torch.distributed as dist
+import torch.nn as nn
+from fms.models import get_model
+from transformers import PretrainedConfig
+
+import vllm.envs as envs
+from vllm.config import ModelConfig, ParallelConfig
+from vllm.logger import init_logger
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SequenceGroupMetadata
+
+try:
+ from torch_sendnn import torch_sendnn # noqa: F401
+except ImportError:
+ print("WARNING: Disabled: torch_sendnn")
+ pass
+try:
+ import backends.dynamo_tracer # noqa: F401
+except ImportError:
+ print("WARNING: Disabled: dynamo_tracer")
+ pass
+
+BACKEND_LIST = ['sendnn_decoder', 'inductor']
+
+logger = init_logger(__name__)
+
+
+class SpyreCausalLM(nn.Module):
+
+ def __init__(
+ self,
+ config: PretrainedConfig,
+ ) -> None:
+ super().__init__()
+ self.config = config
+ self.logits_processor = LogitsProcessor(config.vocab_size,
+ logits_as_input=True)
+ self.sampler = Sampler()
+ self.past_key_value_states = None
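+        # fp16 on the Spyre compiler backend, fp32 on the CPU debug backends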
+ self.dtype = torch.float16 if envs.VLLM_SPYRE_DYNAMO_BACKEND == \
+ 'sendnn_decoder' else torch.float32
+ # number of added padding sequences to fill
+ # batch to warmed up batch size
+ self.num_padded_sequences = 0
+
+ # Lazy initialized
+ self.model: nn.Module
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ masks: torch.Tensor,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ ) -> torch.Tensor:
+
+ is_prompt = seq_group_metadata_list[0].is_prompt
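+        # a prefill starts a new generation: drop any cached KV states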
+ if is_prompt:
+ self.past_key_value_states = None
+
+ extra_kwargs = {}
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder":
+            # Bug in torch 2.3.1 (fixed in 2.4.1) in the SDPA flash CPU
+            # implementation when padding heavily: force the math algorithm.
+ extra_kwargs["attn_algorithm"] = "math"
+
+ output = self.model(
+ input_ids,
+ position_ids=positions,
+ mask=masks,
+ past_key_value_states=self.past_key_value_states,
+ use_cache=True,
+ only_last_token=True,
+ **extra_kwargs,
+ )
+
+ logits, past_key_value_states = output
+ self.past_key_value_states = past_key_value_states
+
+ # mark dynamic
+ if self.past_key_value_states is not None:
+ for layer in self.past_key_value_states:
+ for tensor in layer:
+ torch._dynamo.mark_dynamic(tensor, 2)
+
+ # removing batch padding sequences to compute logits
+ batch_size = input_ids.shape[0]
+
+ logits = logits[:batch_size - self.num_padded_sequences]
+
+ return logits
+
+ def compute_logits(self, hidden_states: torch.Tensor,
+ sampling_metadata: SamplingMetadata) -> torch.Tensor:
+ logits = self.logits_processor(None, hidden_states, sampling_metadata)
+ return logits
+
+ def sample(
+ self,
+ logits: torch.Tensor,
+ sampling_metadata: SamplingMetadata,
+ ) -> Optional[SamplerOutput]:
+ next_tokens = self.sampler(logits, sampling_metadata)
+ return next_tokens
+
+ def load_weights(self, model_config: ModelConfig, max_prompt_length: int,
+ max_decode_length: int,
+ distributed_strategy: Optional[str], **kwargs):
+
+ if self.dtype is not model_config.dtype:
+ logger.info(
+ "Ignoring user-provided dtype=%s and using dtype=%s instead.",
+ model_config.dtype, self.dtype)
+
+ if model_config.quantization == "gptq":
+
+            # TODO: find a better way to package this; arguably it should
+            # be part of FMS
+ sys.path.append("/home/senuser/aiu-fms")
+
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+ from aiu_as_addon import aiu_adapter, aiu_linear # noqa: F401
+ linear_type = "gptq_aiu"
+ print("Loaded `aiu_as_addon` functionalities")
+ else:
+ from cpu_addon import cpu_linear # noqa: F401
+ linear_type = "gptq_cpu"
+ print("Loaded `cpu_addon` functionalities")
+
+ quant_cfg = model_config._parse_quant_hf_config()
+
+ linear_config = {
+ "linear_type": linear_type,
+ "group_size": quant_cfg['group_size'],
+ "desc_act": quant_cfg['desc_act'],
+ }
+ data_type = None
+ model_source = "llama_gptq_hf_unfused_aiu"
+ else:
+ linear_config = {"linear_type": "torch_linear"}
+ data_type = self.dtype
+ model_source = "hf"
+
+ # we can use fused weights unless running on Spyre
+ fused_weights = envs.VLLM_SPYRE_DYNAMO_BACKEND != "sendnn_decoder"
+
+ self.model = get_model(architecture="hf_configured",
+ variant=model_config.model,
+ model_path=model_config.model,
+ source=model_source,
+ data_type=data_type,
+ distributed_strategy=distributed_strategy,
+ group=dist.group.WORLD,
+ fused_weights=fused_weights,
+ linear_config=linear_config)
+
+ compile_mode = "default"
+
+ self.model.eval()
+ torch.set_grad_enabled(False)
+
+ _target_cache_size = max(int(max_decode_length * 2),
+ int(max_prompt_length * 2.5))
+ if hasattr(torch._dynamo.config, "accumulated_cache_size_limit") and \
+ _target_cache_size > torch._dynamo.config.\
+ accumulated_cache_size_limit:
+ _prev = torch._dynamo.config.accumulated_cache_size_limit
+ torch._dynamo.config.accumulated_cache_size_limit = \
+ _target_cache_size
+ print("NOTICE: Adjusting "
+ "torch._dynamo.config.accumulated_cache_size_limit"
+ f" from {_prev} to "
+ f"{torch._dynamo.config.accumulated_cache_size_limit} "
+ f"to accommodate prompt size of {max_prompt_length} "
+ f"and decode tokens of {max_decode_length}")
+
+ if _target_cache_size > torch._dynamo.config.cache_size_limit:
+ _prev = torch._dynamo.config.cache_size_limit
+ torch._dynamo.config.cache_size_limit = _target_cache_size
+ print(
+ "NOTICE: Adjusting torch._dynamo.config.cache_size_limit from"
+ f" {_prev} to {torch._dynamo.config.cache_size_limit} to "
+ f"accommodate prompt size of {max_prompt_length} and "
+ f"decode tokens of {max_decode_length}")
+
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
+ self.model = torch.compile(self.model,
+ mode=compile_mode,
+ backend=envs.VLLM_SPYRE_DYNAMO_BACKEND)
+
+
+def get_spyre_model(model_config: ModelConfig, parallel_config: ParallelConfig,
+ max_prompt_length, max_decode_length) -> nn.Module:
+
+ # Create a model instance.
+ model = SpyreCausalLM(model_config.hf_config)
+
+ # Load the weights from the cached or downloaded files.
+ model.load_weights(
+ model_config,
+ max_prompt_length=max_prompt_length,
+ max_decode_length=max_decode_length,
+ distributed_strategy="tp" if parallel_config.world_size > 1 else None)
+
+ return model
diff --git a/vllm/model_executor/model_loader/spyre_setup.py b/vllm/model_executor/model_loader/spyre_setup.py
new file mode 100644
index 000000000..4ebb2c536
--- /dev/null
+++ b/vllm/model_executor/model_loader/spyre_setup.py
@@ -0,0 +1,145 @@
+import json
+import os
+import sys
+
+import torch
+
+# ==============================================================
+# Common utilities
+# ==============================================================
+#-------------
+# Discover the world size and my rank (envars set by torchrun)
+# https://pytorch.org/docs/stable/elastic/run.html#environment-variables
+#-------------
+local_rank = int(os.getenv("LOCAL_RANK", 0))
+rank = int(os.getenv("RANK", 0))
+world_rank = rank
+world_size = int(os.getenv("WORLD_SIZE", 1))
+
+def dprint(text):
+ print(f"[{rank:2d}/{world_size:2d}]: {text}")
+
+# ==============================================================
+# Common setup
+# ==============================================================
+def spyre_setup(rank=0, world_size=1, local_rank=0, local_size=1, verbose=False):
+ # -------------
+ # Envar setup for backend
+ # -------------
+ # Environment variable created by the runtime to identify the specific Spyre card that is assigned to this rank
+ spyre_config_file_envar = "AIU_CONFIG_FILE_" + str(rank)
+
+ # Default to senulator backend unless user specified otherwise
+ os.environ.setdefault("FLEX_COMPUTE", "SENULATOR")
+ os.environ.setdefault("FLEX_DEVICE", "MOCK")
+
+ # Each rank needs a unique space to write its binaries
+ # For both 'export' and '__pycache'
+ # https://docs.python.org/3/library/sys.html#sys.pycache_prefix
+ os.environ.setdefault("DEEPRT_EXPORT_DIR", "export")
+ os.environ.setdefault("DTCOMPILER_EXPORT_DIR", "export")
+ if world_size > 1:
+ os.environ["DEEPRT_EXPORT_DIR"] += f"/{rank}"
+ os.environ["DTCOMPILER_EXPORT_DIR"] += f"/{rank}"
+ sys.pycache_prefix=os.getenv("DEEPRT_EXPORT_DIR")+"/py-" + str(rank)
+ os.environ.setdefault("DTCOMPILER_KEEP_EXPORT", "1")
+
+ # Inform Flex of the size of this job
+ os.environ.setdefault("FLEX_RDMA_WORLD_SIZE", str(world_size))
+ os.environ.setdefault("FLEX_RDMA_WORLD_RANK", str(rank))
+ os.environ.setdefault("FLEX_RDMA_LOCAL_SIZE", str(world_size))
+ os.environ.setdefault("FLEX_RDMA_LOCAL_RANK", str(rank))
+ for peer_rank in range(world_size):
+ pcie_env_str="AIU_WORLD_RANK_"+str(peer_rank)
+ flex_env_str="FLEX_RDMA_PCI_BUS_ADDR_"+str(peer_rank)
+ if os.getenv(pcie_env_str) is None:
+ raise RuntimeError(f"Error: The environment variable {pcie_env_str} is not defined")
+ if os.getenv(flex_env_str) is None:
+ raise RuntimeError(f"Error: The environment variable {flex_env_str} is not defined")
+ if os.getenv("DUMP_MEMMAP") is not None:
+ if os.getenv("SDSC_REF_DIR") is None:
+ os.environ["SDSC_REF_DIR"] = os.environ["DEEPRT_EXPORT_DIR"]
+ else:
+ os.environ["SDSC_REF_DIR"] += f"/{rank}"
+ assert (
+ os.getenv("DUMP_MEMMAP_DIR") is not None
+ ), "DUMP_MEMMAP_DIR not set while DUMP_MEMMAP set"
+ os.environ["DUMP_MEMMAP_DIR"] += f"/{rank}"
+ os.makedirs(
+ os.environ["DUMP_MEMMAP_DIR"], exist_ok=True
+ ) # directory needs to exist
+
+ for peer_rank in range(world_size):
+ pcie_env_str = "AIU_WORLD_RANK_" + str(peer_rank)
+ flex_env_str = "FLEX_RDMA_PCI_BUS_ADDR_" + str(peer_rank)
+ if os.getenv("FLEX_COMPUTE") == "SENULATOR":
+ if os.getenv(pcie_env_str) is not None:
+ os.environ[flex_env_str] = os.getenv(pcie_env_str)
+ else:
+ os.environ[pcie_env_str] = f"0000:{rank:02x}:01.0"
+ os.environ[flex_env_str] = f"0000:{rank:02x}:01.0"
+ else:
+ if os.getenv(flex_env_str) is None:
+ if os.getenv("PCIDEVICE_IBM_COM_SENTIENT_PF") is not None:
+ os.environ[pcie_env_str] = os.getenv(
+ "PCIDEVICE_IBM_COM_SENTIENT_PF"
+ )
+
+ if os.getenv(pcie_env_str) is not None:
+ os.environ[flex_env_str] = os.getenv(pcie_env_str)
+ else:
+ raise RuntimeError(
+ f"[{rank}/{world_size}]: ERROR: {flex_env_str} and {pcie_env_str} were not set for peer {peer_rank}."
+ )
+ if rank == 0 and verbose:
+ dprint(f"PCI Addr Rank {peer_rank} {pcie_env_str}={os.environ[pcie_env_str]}")
+ dprint(f"PCI Addr Rank {peer_rank} {flex_env_str}={os.environ[flex_env_str]}")
+
+ if rank == 0 and verbose:
+ dprint(f"FLEX_COMPUTE=" + os.getenv("FLEX_COMPUTE"))
+ dprint(f"FLEX_DEVICE=" + os.getenv("FLEX_DEVICE"))
+ dprint(f"DEEPRT_EXPORT_DIR=" + os.getenv("DEEPRT_EXPORT_DIR"))
+ dprint(f"DTCOMPILER_EXPORT_DIR=" + os.getenv("DTCOMPILER_EXPORT_DIR"))
+ if os.getenv(spyre_config_file_envar) is not None:
+ dprint(f"{spyre_config_file_envar}=" + os.environ[spyre_config_file_envar])
+ if os.getenv("SENLIB_DEVEL_CONFIG_FILE") is not None:
+ dprint(f"SENLIB_DEVEL_CONFIG_FILE=" + os.environ["SENLIB_DEVEL_CONFIG_FILE"])
+ if os.getenv(flex_env_str) is not None:
+ dprint(f"{flex_env_str}=" + os.environ[flex_env_str])
+ dprint(f"FLEX_RDMA_LOCAL_RANK=" + os.getenv("FLEX_RDMA_LOCAL_RANK"))
+ dprint(f"FLEX_RDMA_LOCAL_SIZE=" + os.getenv("FLEX_RDMA_LOCAL_SIZE"))
+ dprint(f"FLEX_RDMA_WORLD_RANK=" + os.getenv("FLEX_RDMA_WORLD_RANK"))
+ dprint(f"FLEX_RDMA_WORLD_SIZE=" + os.getenv("FLEX_RDMA_WORLD_SIZE"))
+
+ if os.getenv("FLEX_COMPUTE") == "SENTIENT":
+ pcie_env_str = "AIU_WORLD_RANK_" + str(rank)
+ if os.getenv(pcie_env_str) is not None:
+ device_id = os.getenv(pcie_env_str)
+ else:
+ with open(os.getenv(spyre_config_file_envar)) as fd:
+ data = json.load(fd)
+ device_id = data["GENERAL"]["sen_bus_id"]
+ dprint(f"Spyre: Enabled ({device_id})")
+ else:
+ dprint(f"Spyre: Disabled (Senulator)")
+
+
+# ==============================================================
+# Distributed setup
+# ==============================================================
+def spyre_dist_setup(rank, world_size, local_rank=-1, local_size=-1, verbose=False):
+ if local_rank < 0:
+ local_rank = rank
+ if local_size < 0:
+ local_size = world_size
+
+ if os.getenv("TORCHELASTIC_RUN_ID") is None:
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "12355"
+ elif rank == 0 or verbose:
+ dprint(f"Detected running via torchrun")
+
+ if rank == 0 or verbose:
+ dprint(f"Parallel Backend: {torch.distributed.get_backend()}")
+
+ spyre_setup(rank, world_size)
\ No newline at end of file
diff --git a/vllm/platforms/__init__.py b/vllm/platforms/__init__.py
index 1f68fc2e2..04b670c77 100644
--- a/vllm/platforms/__init__.py
+++ b/vllm/platforms/__init__.py
@@ -83,6 +83,13 @@
except Exception:
pass
+is_spyre = False
+try:
+ from importlib.metadata import version
+ is_spyre = "spyre" in version("vllm")
+except Exception:
+ pass
+
if is_tpu:
# people might install pytorch built with cuda but run on tpu
# so we need to check tpu first
@@ -109,6 +116,9 @@
elif is_openvino:
from .openvino import OpenVinoPlatform
current_platform = OpenVinoPlatform()
+elif is_spyre:
+ from .spyre import SpyrePlatform
+ current_platform = SpyrePlatform()
else:
current_platform = UnspecifiedPlatform()
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index f4849fa2c..1bf5434c5 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -34,6 +34,7 @@ class PlatformEnum(enum.Enum):
CPU = enum.auto()
NEURON = enum.auto()
OPENVINO = enum.auto()
+ SPYRE = enum.auto()
UNSPECIFIED = enum.auto()
@@ -81,6 +82,9 @@ def is_neuron(self) -> bool:
def is_openvino(self) -> bool:
return self._enum == PlatformEnum.OPENVINO
+ def is_spyre(self) -> bool:
+ return self._enum == PlatformEnum.SPYRE
+
def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)
diff --git a/vllm/platforms/spyre.py b/vllm/platforms/spyre.py
new file mode 100644
index 000000000..3634c1f4f
--- /dev/null
+++ b/vllm/platforms/spyre.py
@@ -0,0 +1,9 @@
+from .interface import Platform, PlatformEnum
+
+
+class SpyrePlatform(Platform):
+ _enum = PlatformEnum.SPYRE
+
+ @classmethod
+ def get_device_name(cls, device_id: int = 0) -> str:
+ return "spyre"
diff --git a/vllm/utils.py b/vllm/utils.py
index 2bbdc8d1e..fe9bc9d77 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -726,6 +726,9 @@ def is_pin_memory_available() -> bool:
elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
+ elif current_platform.is_spyre():
+ print_warning_once("Pin memory is not supported on Spyre device.")
+ return False
elif current_platform.is_hpu():
print_warning_once("Pin memory is not supported on HPU.")
return False
diff --git a/vllm/worker/spyre_embedding_model_runner.py b/vllm/worker/spyre_embedding_model_runner.py
new file mode 100644
index 000000000..447742d02
--- /dev/null
+++ b/vllm/worker/spyre_embedding_model_runner.py
@@ -0,0 +1,171 @@
+import time
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+from transformers import AutoModel
+
+import vllm.envs as envs
+from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
+ SchedulerConfig)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.pooling_params import PoolingParams
+from vllm.sequence import PoolerOutput, SequenceData, SequenceGroupMetadata
+
+from .spyre_model_runner import SpyreModelRunner
+
+logger = init_logger(__name__)
+
+BACKEND_LIST = ['sendnn', 'inductor']
+
+
+class SpyreEmbeddingModelRunner(SpyreModelRunner):
+
+ # Map of request_id -> generator used for seeded random sampling
+ generators: Dict[str, torch.Generator] = {}
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig,
+ device_config: DeviceConfig,
+ ):
+ super().__init__(model_config=model_config,
+ parallel_config=parallel_config,
+ scheduler_config=scheduler_config,
+ device_config=device_config)
+
+ pooler_config = model_config.pooler_config
+ self.pooler = Pooler.from_config_with_defaults(
+ pooler_config,
+ pooling_type=PoolingType.CLS,
+ normalize=True,
+ softmax=False)
+
+ def load_model(self, prompt_lens: Iterable[int],
+ num_decode_tokens: Iterable[int],
+ batch_sizes: Iterable[int]) -> None:
+ self.model = AutoModel.from_pretrained(self.model_config.model)
+ self.model.eval()
+ torch.set_grad_enabled(False)
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND in BACKEND_LIST:
+ self.model = torch.compile(self.model,
+ mode="default",
+ dynamic=False,
+ backend=envs.VLLM_SPYRE_DYNAMO_BACKEND)
+
+ @property
+ def vocab_size(self) -> int:
+ return self.model.config.vocab_size
+
+ def prepare_input_tensors(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ finished_requests_ids: Optional[List[str]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, PoolingMetadata]:
+ # NOTE: We assume that all sequences in the group are all prompts
+ (input_tokens, input_positions, input_masks,
+ seq_lens) = self._prepare_prompt(seq_group_metadata_list)
+
+ pooling_metadata = self._prepare_pooling(
+ seq_group_metadata_list=seq_group_metadata_list,
+ prompt_lens=seq_lens)
+ return (input_tokens, input_positions, input_masks, pooling_metadata)
+
+ def _prepare_pooling(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ prompt_lens: List[int],
+ ) -> PoolingMetadata:
+ """Prepare PoolingMetadata for the sequence group metadata list."""
+ seq_groups: List[Tuple[List[int], PoolingParams]] = []
+ for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+ pooling_params = seq_group_metadata.pooling_params
+ seq_groups.append((seq_ids, pooling_params))
+
+ seq_data: Dict[int, SequenceData] = {}
+ for seq_group_metadata in seq_group_metadata_list:
+ seq_data.update(seq_group_metadata.seq_data)
+
+ pooling_metadata = PoolingMetadata(
+ seq_groups=seq_groups,
+ seq_data=seq_data,
+ prompt_lens=prompt_lens,
+ )
+
+ return pooling_metadata
+
+ def pad_input_ids(
+ self,
+ input_ids_list: List[torch.Tensor],
+ min_pad_length: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ padded_input_ids_list, mask_list, position_ids_list = self.\
+ _prepare_pad_input_ids(input_ids_list, min_pad_length)
+
+ input_ids = torch.stack(padded_input_ids_list)
+ mask = torch.stack(mask_list)
+ position_ids = torch.stack(position_ids_list)
+
+ return input_ids, position_ids, mask
+
+ def execute_model(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ finished_requests_ids: Optional[List[str]] = None,
+ ) -> Optional[PoolerOutput]:
+ (input_tokens, input_positions, input_masks,
+ pooling_metadata) = self.prepare_input_tensors(
+ seq_group_metadata_list, finished_requests_ids)
+ t0 = time.time()
+
+ outputs = self.model(
+ input_ids=input_tokens,
+            # Let the Embedding layer use its default position ids,
+            # because the rules can differ slightly;
+            # e.g. for RoBERTa models the positions start
+            # at padding_idx + 1.
+            # position_ids=input_positions,
+ attention_mask=input_masks)
+ hidden_states = outputs["last_hidden_state"]
+
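+        # Strip the left-side padding from each sequence's hidden states
+        # before pooling, keeping only the last seq_len positions per row.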
+ unpadded = []
+ max_len = hidden_states.shape[1]
+
+ for i, seq_len in enumerate(pooling_metadata.prompt_lens):
+ unpadded.append(hidden_states[i, max_len - seq_len:, :])
+
+ hidden_states = torch.concat(unpadded)
+
+ pooler_output = self.pooler(hidden_states=hidden_states,
+ pooling_metadata=pooling_metadata)
+
+ t1 = time.time() - t0
+ print("[spyre_model_runner:execute_model] t_token: %.2fms" %
+ (t1 * 1000))
+
+ return pooler_output
+
+ def _raw_model_forward(
+ self,
+ input_ids: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value_states: Optional[List[Tuple[torch.Tensor,
+ torch.Tensor]]] = None,
+ use_cache: bool = False,
+ only_last_token: bool = False,
+ attn_algorithm: Optional[str] = None
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
+ torch.Tensor]]]]:
+
+ hidden_states, _ = self.model(
+ input_ids=input_ids,
+ attention_mask=mask,
+ #position_ids=position_ids
+ )
+ return hidden_states, None
diff --git a/vllm/worker/spyre_model_runner.py b/vllm/worker/spyre_model_runner.py
new file mode 100644
index 000000000..e1925811c
--- /dev/null
+++ b/vllm/worker/spyre_model_runner.py
@@ -0,0 +1,349 @@
+import time
+from typing import Dict, Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig,
+ SchedulerConfig)
+from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadata
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader.spyre import get_spyre_model
+from vllm.sequence import SequenceGroupMetadata
+from vllm.utils import is_pin_memory_available
+
+logger = init_logger(__name__)
+
+
+class SpyreModelRunner:
+
+ # Map of request_id -> generator used for seeded random sampling
+ generators: Dict[str, torch.Generator] = {}
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig,
+ device_config: DeviceConfig,
+ ):
+ self.model_config = model_config
+ self.parallel_config = parallel_config
+ self.scheduler_config = scheduler_config
+ self.device_config = device_config
+
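+        # Fall back to pad token id 0 if the HF config does not define one.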
+ self.pad_token_id = 0
+ if model_config is not None:
+ if model_config.hf_config is not None:
+ self.pad_token_id = getattr(model_config.hf_config,
+ "pad_token_id", None) or 0
+ if model_config.get_sliding_window():
+ logger.warning("Sliding window is not supported on Spyre. "
+ "The model will run without sliding window.")
+ self.device_config = (device_config
+ if device_config is not None else DeviceConfig())
+ self.device = self.device_config.device
+ self.pin_memory = is_pin_memory_available()
+        # position ids of all the sequences in the current batch
+        self._position_ids: Optional[torch.Tensor] = None
+        # attention masks of all the sequences in the current batch
+        self._mask: Optional[torch.Tensor] = None
+        # Lazy initialization: set by load_model().
+        self.model: nn.Module
+
+ def load_model(self, prompt_lens: Iterable[int],
+ num_decode_tokens: Iterable[int],
+ batch_sizes: Iterable[int]) -> None:
+ max_pad_length = max(prompt_lens)
+ max_decode_length = max(num_decode_tokens)
+ self.model = get_spyre_model(self.model_config,
+ parallel_config=self.parallel_config,
+ max_prompt_length=max_pad_length,
+ max_decode_length=max_decode_length)
+
+ @property
+ def vocab_size(self) -> int:
+ return self.model.model.config.src_vocab_size
+
+ def _prepare_prompt(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
+ assert len(seq_group_metadata_list) > 0
+ input_token_list: List[torch.Tensor] = []
+
+ # find warmup shape to be used for padding and batching
+ applicable_spyre_warmup_shapes = [
+ shape for shape in self.scheduler_config.spyre_warmup_shapes
+ if len(seq_group_metadata_list) <= shape['batch_size']
+ ]
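+        # Narrow the applicable shapes further: keep only those whose
+        # prompt_length and new_tokens budget fit every sequence in the batch.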
+ for seq_group_metadata in seq_group_metadata_list:
+ seq_data = seq_group_metadata.seq_data[list(
+ seq_group_metadata.seq_data.keys())[0]]
+ # retrieve initial (unpadded) tokens
+ prompt_tokens = seq_data.get_token_ids()
+ new_tokens = seq_group_metadata.sampling_params.max_tokens\
+ if seq_group_metadata.sampling_params is not None else 0
+
+ updated_spyre_warmup_shapes = [
+ shape for shape in applicable_spyre_warmup_shapes
+ if len(prompt_tokens) <= shape['prompt_length']
+ and new_tokens <= shape['new_tokens']
+ ]
+ applicable_spyre_warmup_shapes = updated_spyre_warmup_shapes
+
+ assert applicable_spyre_warmup_shapes
+
+ # If multiple warmup shapes apply, the first one is selected.
+        # To improve performance, the warmup shapes in scheduler_config
+ # are ordered by "processing speed".
+ min_pad_length_batch = applicable_spyre_warmup_shapes[0][
+ 'prompt_length']
+ padded_batch_size = applicable_spyre_warmup_shapes[0]['batch_size']
+
+ for seq_group_metadata in seq_group_metadata_list:
+ assert seq_group_metadata.is_prompt
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+ assert len(seq_ids) == 1
+ seq_id = seq_ids[0]
+
+ seq_data = seq_group_metadata.seq_data[seq_id]
+ # retrieve initial (unpadded) tokens
+ prompt_tokens = seq_data.get_token_ids()
+
+ input_token_list.append(
+ torch.tensor(prompt_tokens,
+ dtype=torch.long,
+ device=torch.device("cpu")))
+
+ # set number of added padding sequences used for computing logits
+ self.model.num_padded_sequences = padded_batch_size - len(
+ input_token_list)
+
+ # padding to compiled batch size
+ while len(input_token_list) < padded_batch_size:
+ input_token_list.append(
+ torch.zeros(min_pad_length_batch,
+ dtype=torch.long,
+ device=torch.device("cpu")))
+
+ # get position ids and attention mask
+ input_tokens, self._position_ids, self._mask = self.pad_input_ids(
+ input_token_list, min_pad_length=min_pad_length_batch)
+
+ seq_lens = [t.shape[0] for t in input_token_list]
+
+ return input_tokens, self._position_ids, self._mask, seq_lens
+
+ def _prepare_decode(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ assert len(seq_group_metadata_list) > 0
+ input_tokens: List[List[int]] = []
+
+ for seq_group_metadata in seq_group_metadata_list:
+ assert not seq_group_metadata.is_prompt
+ seq_ids = list(seq_group_metadata.seq_data.keys())
+ assert len(seq_ids) == 1
+ seq_id = seq_ids[0]
+
+ seq_data = seq_group_metadata.seq_data[seq_id]
+ generation_token = seq_data.get_last_token_id()
+ input_tokens.append([generation_token])
+
+ # padding to compiled batch size
+ actual_batch_size = len(seq_group_metadata_list)
+ padded_batch_size = self._position_ids.shape[0]
+ while actual_batch_size < padded_batch_size:
+ input_tokens.append([0])
+ actual_batch_size += 1
+
+ # update position ids and attention mask
+ self._update_position_ids()
+ self._update_mask()
+
+ input_tokens = torch.tensor(input_tokens,
+ dtype=torch.long,
+ device=self.device)
+
+ return input_tokens, self._position_ids, self._mask
+
+ def _update_position_ids(self) -> None:
+ """Updating the position ids of all sequences
+ in a batch. Will be called in decoding phase"""
+
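+        # During decode each sequence grows by exactly one token, so the new
+        # position id is the previous last position plus one.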
+ self._position_ids = self._position_ids[:, -1] + 1
+ self._position_ids = self._position_ids.unsqueeze(-1)
+
+ def _update_mask(self) -> None:
+ """Updating/extending the attention masks of all
+ sequences in a batch. Will be called in decoding phase"""
+
+ assert self._mask is not None
+ masks = self._mask
+
+ masks_new = []
+ for mask in masks:
+ # get the last row of the 3d mask
+ mask_new = mask[-1:, :]
+
+ # extend the mask one slot
+ mask_new = torch.cat(
+ (
+ mask_new,
+ torch.zeros(
+ 1, 1, dtype=mask_new.dtype, device=mask_new.device),
+ ),
+ dim=1,
+ )
+ masks_new.append(mask_new)
+
+ self._mask = torch.stack(masks_new, dim=0)
+
+ def prepare_input_tensors(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ finished_requests_ids: Optional[List[str]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]:
+ # NOTE: We assume that all sequences in the group are all prompts or
+ # all decodes.
+ is_prompt = seq_group_metadata_list[0].is_prompt
+ # Prepare input tensors.
+ if is_prompt:
+ (input_tokens, input_positions, input_masks,
+ _) = self._prepare_prompt(seq_group_metadata_list)
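+            # After left-side padding, all sequences share the padded length.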
+ seq_lens = [
+ input_tokens.shape[1] for i in range(input_tokens.shape[0])
+ ]
+ else:
+ (input_tokens, input_positions,
+ input_masks) = self._prepare_decode(seq_group_metadata_list)
+ seq_lens = []
+
+ # Clean up generators from completed requests
+ if finished_requests_ids:
+ for request_id in finished_requests_ids:
+ self.generators.pop(request_id, None)
+
+ sampling_metadata = SamplingMetadata.prepare(
+ seq_group_metadata_list,
+ seq_lens,
+ # query_lens is not needed if chunked prefill is not
+ # supported. Since Spyre worker doesn't support chunked prefill
+ # just use seq_lens instead.
+ seq_lens,
+ self.device,
+ self.pin_memory,
+ self.generators)
+ return (input_tokens, input_positions, input_masks, sampling_metadata)
+
+ def execute_model(
+ self,
+ seq_group_metadata_list: List[SequenceGroupMetadata],
+ finished_requests_ids: Optional[List[str]] = None,
+ ) -> Optional[SamplerOutput]:
+ (input_tokens, input_positions, input_masks,
+ sampling_metadata) = self.prepare_input_tensors(
+ seq_group_metadata_list, finished_requests_ids)
+ t0 = time.time()
+ hidden_states = self.model(
+ input_ids=input_tokens,
+ positions=input_positions,
+ masks=input_masks,
+ seq_group_metadata_list=seq_group_metadata_list,
+ )
+
+ # Compute the logits.
+ logits = self.model.compute_logits(hidden_states, sampling_metadata)
+
+ # Sample the next token.
+ output = self.model.sample(
+ logits=logits,
+ sampling_metadata=sampling_metadata,
+ )
+ t1 = time.time() - t0
+ print("[spyre_model_runner:execute_model] t_token: %.2fms" %
+ (t1 * 1000))
+
+ return output
+
+ def _prepare_pad_input_ids(
+ self,
+ input_ids_list: List[torch.Tensor],
+ min_pad_length: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """left side padding implemented as
+ in fms.utils.generation.pad_input_id"""
+ max_len = max([min_pad_length] +
+ [seq.size(0) for seq in input_ids_list])
+ padded_input_ids_list = []
+ mask_list = []
+ position_ids_list = []
+ for input_ids_i in input_ids_list:
+ seq_len = input_ids_i.size(0)
+ if max_len > seq_len:
+ print(f"[SpyreModelRunner] INFO: Padding request of length "
+ f"{seq_len} tokens to {max_len} tokens.")
+ pads = torch.ones(max_len - seq_len,
+ dtype=torch.long,
+ device=input_ids_i.device) * self.pad_token_id
+ non_pads = torch.ones(seq_len,
+ dtype=torch.long,
+ device=input_ids_i.device)
+
+ pos_ids_pads = pads
+ pos_ids_seq = torch.arange(0,
+ seq_len,
+ dtype=torch.long,
+ device=input_ids_i.device)
+
+            # Padding is set to 0; however, if 0 is the eos token, the output
+            # will be truncated when using truncate_after_eos. Once this
+            # workflow works for nested tensors, this can probably be removed.
+ padded_input_ids_list.append(torch.cat((pads, input_ids_i)))
+ mask_list.append(torch.cat((torch.zeros_like(pads), non_pads)))
+ position_ids_list.append(torch.cat((pos_ids_pads, pos_ids_seq)))
+
+ return padded_input_ids_list, mask_list, position_ids_list
+
+ def pad_input_ids(
+ self,
+ input_ids_list: List[torch.Tensor],
+ min_pad_length: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+
+ padded_input_ids_list, mask_list, position_ids_list = self.\
+ _prepare_pad_input_ids(input_ids_list, min_pad_length)
+
+ input_ids = torch.stack(padded_input_ids_list)
+ mask = torch.stack(mask_list).bool()
+ # this is a causal mask for generation
+ mask = (mask.unsqueeze(-1) == mask.unsqueeze(-2)).tril()
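+        # Convert the boolean mask into an additive attention mask:
+        # 0.0 where attention is allowed, -inf where it is masked out.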
+ mask = torch.where(mask.logical_not(), -torch.inf, 0.0)
+ mask = mask.to(self.model.dtype)
+ position_ids = torch.stack(position_ids_list)
+
+ return input_ids, position_ids, mask
+
+ def _raw_model_forward(
+ self,
+ input_ids: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.Tensor] = None,
+ past_key_value_states: Optional[List[Tuple[torch.Tensor,
+ torch.Tensor]]] = None,
+ use_cache: bool = False,
+ only_last_token: bool = False,
+ attn_algorithm: Optional[str] = None
+ ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor,
+ torch.Tensor]]]]:
+
+ return self.model.model(input_ids,
+ mask=mask,
+ position_ids=position_ids,
+ past_key_value_states=past_key_value_states,
+ use_cache=use_cache,
+ only_last_token=only_last_token,
+ attn_algorithm=attn_algorithm)
diff --git a/vllm/worker/spyre_worker.py b/vllm/worker/spyre_worker.py
new file mode 100644
index 000000000..fa95ec444
--- /dev/null
+++ b/vllm/worker/spyre_worker.py
@@ -0,0 +1,337 @@
+"""A Spyre worker class."""
+import json
+import os
+import platform
+import time
+from typing import List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+import vllm.envs as envs
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+ ParallelConfig, SchedulerConfig)
+from vllm.distributed import (ensure_model_parallel_initialized,
+ init_distributed_environment)
+from vllm.model_executor import set_random_seed
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.model_executor.model_loader import spyre_setup
+from vllm.sequence import ExecuteModelRequest
+from vllm.worker.spyre_embedding_model_runner import SpyreEmbeddingModelRunner
+from vllm.worker.spyre_model_runner import SpyreModelRunner
+from vllm.worker.worker_base import LoraNotSupportedWorkerBase
+
+
+class SpyreWorker(LoraNotSupportedWorkerBase):
+ """A worker class that executes the model on a group of Spyre cores.
+ """
+
+ def __init__(
+ self,
+ model_config: ModelConfig,
+ parallel_config: ParallelConfig,
+ scheduler_config: SchedulerConfig,
+ device_config: DeviceConfig,
+ cache_config: CacheConfig,
+ local_rank: int,
+ rank: int,
+ distributed_init_method: str,
+ is_driver_worker: bool = False,
+ ) -> None:
+ self.model_config = model_config
+ self.parallel_config = parallel_config
+ self.scheduler_config = scheduler_config
+ self.device_config = device_config
+ self.cache_config = cache_config
+ self.local_rank = local_rank
+ self.rank = rank
+ self.distributed_init_method = distributed_init_method
+ self.is_driver_worker = is_driver_worker
+ if parallel_config and is_driver_worker:
+ assert rank % parallel_config.tensor_parallel_size == 0, \
+ "Driver worker should be rank 0 of tensor parallel group."
+ if self.model_config.trust_remote_code:
+ # note: lazy import to avoid importing torch before initializing
+ from vllm.utils import init_cached_hf_modules
+ init_cached_hf_modules()
+
+ if self.model_config.task == "embedding":
+ self.model_runner: SpyreModelRunner = SpyreEmbeddingModelRunner(
+ model_config, parallel_config, scheduler_config, device_config)
+ else:
+ self.model_runner = SpyreModelRunner(model_config, parallel_config,
+ scheduler_config,
+ device_config)
+ self._env_initialized = False
+
+ def init_distributed_environment(self) -> None:
+ """Initialize the distributed environment."""
+
+ init_distributed_environment(
+ world_size=self.parallel_config.world_size,
+ rank=self.rank,
+ distributed_init_method="env://",
+ backend="gloo",
+ )
+
+ torch._C._distributed_c10d._register_process_group(
+ "default", dist.group.WORLD)
+
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND in ["sendnn", "sendnn_decoder"]:
+ spyre_setup.spyre_dist_setup(
+ rank=self.rank,
+ world_size=self.parallel_config.world_size,
+ verbose=True)
+
+ # A small all_reduce for warmup.
+ torch.distributed.all_reduce(torch.zeros(1).cpu())
+
+ ensure_model_parallel_initialized(
+ self.parallel_config.tensor_parallel_size,
+ self.parallel_config.pipeline_parallel_size,
+ )
+
+ def init_device(self) -> None:
+
+ if platform.machine() == "s390x":
+ from torch.serialization import LoadEndianness
+ torch.serialization.set_default_load_endianness(
+ LoadEndianness.LITTLE)
+
+ if not self._env_initialized:
+ if self.parallel_config.world_size > 1:
+ self.init_distributed_environment()
+ elif envs.VLLM_SPYRE_DYNAMO_BACKEND in [
+ "sendnn", "sendnn_decoder"
+ ]:
+ spyre_setup.spyre_setup(rank=0, world_size=1, verbose=True)
+
+ self._env_initialized = True
+
+ # Set random seed.
+ set_random_seed(self.model_config.seed)
+
+ def load_model(self):
+ assert self._env_initialized
+
+ with open(os.path.join(self.model_config.model, 'config.json'),
+ 'rb') as f:
+ config = json.load(f)
+
+ bos_token_id, eos_token_id = int(config["bos_token_id"]), int(
+ config["eos_token_id"])
+
+ print("[SpyreWorker] load model...")
+ # TODO: check additionally if the Spyre card has enough memory
+ # for all requested model warmups
+ # printing env variables for debugging purposes
+ load_model_start_t = time.time()
+
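+        # Unpack the configured warmup shapes into separate tuples of
+        # prompt lengths, numbers of new tokens and batch sizes.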
+ wup_prompt_lens, wup_new_tokens, wup_batch_sizes = zip(
+ *[(s["prompt_length"], s["new_tokens"], s["batch_size"])
+ for s in self.scheduler_config.spyre_warmup_shapes])
+
+ self.model_runner.load_model(prompt_lens=wup_prompt_lens,
+ num_decode_tokens=wup_new_tokens,
+ batch_sizes=wup_batch_sizes)
+
+ load_model_end_t = time.time()
+ load_model_total_t = load_model_end_t - load_model_start_t
+ print(f"\tload model took {load_model_total_t}s")
+
+ print(f"[SpyreWorker] Start warming up "
+ f"{len(wup_new_tokens)} "
+ f"different prompt/decode/batchsize-shape combinations.")
+ all_warmup_start_t = time.time()
+ for i, (prompt_len, num_decode_tokens, batch_size) in enumerate([
+ (s["prompt_length"], s["new_tokens"], s["batch_size"])
+ for s in self.scheduler_config.spyre_warmup_shapes
+ ]):
+ if self.model_config.task != "embedding":
+ # TODO: remove if spyre supports
+ # lower number of output tokens
+                assert num_decode_tokens >= 3, (
+                    "VLLM_SPYRE_WARMUP_NEW_TOKENS must be "
+                    "at least 3 (Spyre requirement).")
+ # warmup individual combination
+ print(f"[SpyreWorker] Warmup {i+1}/"
+ f"{len(wup_new_tokens)} "
+ f"prompt/decode/batchsize-shape combinations...")
+ print(f"[SpyreWorker] Warming up for prompt length {prompt_len}, "
+ f"decoding {num_decode_tokens} tokens with batch "
+ f"size {batch_size}")
+ self._warmup_spyre_fixed_size(prompt_len, num_decode_tokens,
+ (bos_token_id, eos_token_id),
+ batch_size)
+ all_warmup_end_t = time.time()
+ all_warmup_total_t = all_warmup_end_t - all_warmup_start_t
+ print(f"[SpyreWorker] All warmups for "
+ f"{len(wup_new_tokens)} different "
+ f"prompt/decode/batchsize-shape combinations finished. "
+ f"Total warmup time {all_warmup_total_t}s.")
+
+ def _warmup_spyre_fixed_size(self, prompt_len, num_decode_tokens,
+ special_token_ids, batch_size):
+ # warmup the model
+ warmup_start_t = time.time()
+        # NOTE(ngl): an empty tensor causes Spyre to hang, so we sample with
+        # randint, excluding 0 as well as the eos and bos token ids
+
+ # Create a list of valid values between 1 (inclusive) and vocab
+ # size (exclusive) by excluding the eos and bos token ids
+ # (in special_token_ids)
+ vocab_size = self.model_runner.vocab_size
+ valid_token_ids = [
+ i for i in range(1, vocab_size) if i not in set(special_token_ids)
+ ]
+ # Convert to tensor for sampling
+ valid_token_ids_tensor = torch.tensor(valid_token_ids,
+ dtype=torch.long,
+ device=torch.device("cpu"))
+ # Sample from the valid token ids
+ warmup_tokens_tensor = valid_token_ids_tensor[torch.randint(
+ 0, len(valid_token_ids_tensor), (batch_size, prompt_len))]
+
+ extra_kwargs = {}
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND not in ["sendnn", "sendnn_decoder"]:
+            # Work around a bug in 2.3.1 (fixed in 2.4.1) in the SDPA flash
+            # CPU implementation when padding too much, by falling back to
+            # the math attention algorithm.
+ extra_kwargs["attn_algorithm"] = "math"
+
+ print(f"[SpyreWorker] warmup for prompt length "
+ f"{prompt_len} and max output tokens {num_decode_tokens}.")
+
+ # 1. trace
+ print("[SpyreWorker] warmup 1/2...")
+        # TODO: is torch_sendnn.CleanGraph() necessary here?
+ # warmup 1st forward pass
+ self._warmup_model_forward_pass(warmup_tokens_tensor,
+ valid_token_ids_tensor, prompt_len,
+ num_decode_tokens, batch_size,
+ extra_kwargs)
+
+ # 2. compile
+ print("[SpyreWorker] warmup 2/2...")
+ if envs.VLLM_SPYRE_DYNAMO_BACKEND == "sendnn_decoder":
+ from torch_sendnn import torch_sendnn
+ ul_start_time = time.time()
+ torch_sendnn.update_lazyhandle()
+ ul_stop_time = time.time()
+ ul_total_t = ul_stop_time - ul_start_time
+ print(f"update_lazyhandle() done (duration: {ul_total_t}s)")
+
+ # warmup 2nd forward pass
+ self._warmup_model_forward_pass(warmup_tokens_tensor,
+ valid_token_ids_tensor, prompt_len,
+ num_decode_tokens, batch_size,
+ extra_kwargs)
+
+ warmup_end_t = time.time()
+ warmup_total_t = warmup_end_t - warmup_start_t
+ print("[SpyreWorker] ... warmup finished.")
+ print(f"\twarmup took {warmup_total_t}s (for prompt length"
+ f"{prompt_len} and max output tokens {num_decode_tokens})")
+
+ def _warmup_model_forward_pass(self, warmup_tokens_tensor,
+ valid_token_ids_tensor, prompt_len,
+ num_decode_tokens, batch_size,
+ extra_kwargs):
+ # padding warmup tokens to obtain the
+ # corresponding position ids and mask
+ warmup_tokens_pad, self.model_runner._position_ids, \
+ self.model_runner._mask = self.model_runner.pad_input_ids(
+ warmup_tokens_tensor, min_pad_length=prompt_len)
+
+ logits, past_key_value_states = self.model_runner._raw_model_forward(
+ warmup_tokens_pad,
+ position_ids=self.model_runner._position_ids,
+ mask=self.model_runner._mask,
+ past_key_value_states=None,
+ use_cache=True,
+ only_last_token=True,
+ **extra_kwargs)
+ # decoding
+ for i in range(num_decode_tokens - 1):
+ # sampling next input token from vocab without bos and eos tokens
+ decode_tokens = valid_token_ids_tensor[torch.randint(
+ 0, len(valid_token_ids_tensor), (batch_size, 1))]
+
+ # update mask and position_ids
+ self.model_runner._update_mask()
+ self.model_runner._update_position_ids()
+
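+            # Mark the third dimension of the cached KV tensors as dynamic so
+            # that dynamo does not recompile as the cache grows during decode.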
+ if past_key_value_states is not None:
+ for layer in past_key_value_states:
+ for tensor in layer:
+ torch._dynamo.mark_dynamic(tensor, 2)
+
+ logits, past_key_value_states = self.model_runner.\
+ _raw_model_forward(
+ decode_tokens,
+ position_ids=self.model_runner._position_ids,
+ mask=self.model_runner._mask,
+ past_key_value_states=past_key_value_states,
+ use_cache=True,
+ only_last_token=True,
+ **extra_kwargs)
+
+ def determine_num_available_blocks(self) -> Tuple[int, int]:
+ """Determine the number of available KV blocks.
+
+ Swapping is not yet supported, so always return num_cpu_blocks=0.
+
+ We configure num_gpu_blocks to be equal to max_num_seqs.
+ """
+ # Set the number of GPU blocks to be the same as the maximum number of
+ # sequences that can be processed in a single batch. This is equivalent
+ # to schedule without PagedAttention.
+ num_gpu_blocks = self.scheduler_config.max_num_seqs
+
+ # Swap not yet supported with Spyre backend.
+ num_cpu_blocks = 0
+
+ return num_gpu_blocks, num_cpu_blocks
+
+ def initialize_cache(self, num_gpu_blocks: int,
+ num_cpu_blocks: int) -> None:
+ """Initialize the KV cache.
+ """
+
+ # Different values are not tested.
+ assert num_cpu_blocks == 0
+ assert num_gpu_blocks == self.scheduler_config.max_num_seqs
+
+ self.cache_config.num_gpu_blocks = num_gpu_blocks
+ self.cache_config.num_cpu_blocks = num_cpu_blocks
+
+ # TODO: why not inference mode?
+ #@torch.inference_mode()
+ def execute_model(
+ self,
+ execute_model_req: Optional[ExecuteModelRequest] = None
+ ) -> Optional[List[SamplerOutput]]:
+
+ torch.set_grad_enabled(False)
+ if execute_model_req is None:
+ return None
+ finished_requests_ids = execute_model_req.finished_requests_ids
+ seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+ num_seq_groups = len(seq_group_metadata_list)
+
+ # If there is no input, we don't need to execute the model.
+ if num_seq_groups == 0:
+ return []
+
+ output = self.model_runner.execute_model(seq_group_metadata_list,
+ finished_requests_ids)
+
+ # Spyre worker only supports single-step output. Wrap the output in a
+ # list to conform to interface.
+ return [output]
+
+ def get_cache_block_size_bytes(self) -> int:
+ """Determine the size in bytes of a cache block.
+
+ This is required for speculative decoding; it is not yet implemented.
+ """
+ raise NotImplementedError