-
-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial code drop with Spyre support Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com> Signed-off-by: Nikolaos Papandreou <npo@zurich.ibm.com> Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Nick Hill <nickhill@us.ibm.com> Co-authored-by: Thomas Parnell <tpa@zurich.ibm.com> Co-authored-by: Nikolaos Papandreou <npo@zurich.ibm.com> Co-authored-by: TRAVIS JOHNSON <tsjohnso@us.ibm.com> Co-authored-by: Burkhard Ringlein <NGL@zurich.ibm.com> Co-authored-by: Yannick Schnider <Yannick.Schnider1@ibm.com> Co-authored-by: Jan van Lunteren <jvl@zurich.ibm.com> Co-authored-by: Maximilien Philippe Marie de Bayser <mbayser@br.ibm.com>
- Loading branch information
1 parent
772a667
commit 3e43bb2
Showing
39 changed files
with
3,748 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
collect_env.py | ||
|
||
vllm/model_executor/model_loader/spyre_setup.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Global Args ################################################################# | ||
ARG BASE_UBI_IMAGE_TAG=9.4 | ||
ARG PYTHON_VERSION=3.12 | ||
|
||
# Base Layer ################################################################## | ||
FROM registry.access.redhat.com/ubi9/ubi-minimal:${BASE_UBI_IMAGE_TAG} AS base | ||
ARG PYTHON_VERSION | ||
ENV PYTHON_VERSION=${PYTHON_VERSION} | ||
WORKDIR /workspace/vllm | ||
|
||
# Install some basic utilities ################################################################## | ||
RUN microdnf update -y && microdnf install -y \ | ||
python${PYTHON_VERSION}-devel python${PYTHON_VERSION}-pip python${PYTHON_VERSION}-wheel git vim gcc g++\ | ||
&& microdnf clean all | ||
|
||
# Install build dependencies ################################################################## | ||
RUN --mount=type=bind,source=requirements-build.txt,target=requirements-build.txt \ | ||
python3.12 -m pip install --upgrade pip && \ | ||
pip install -r requirements-build.txt | ||
|
||
# Build vLLM ################################################################## | ||
COPY . . | ||
|
||
ENV VLLM_TARGET_DEVICE=spyre | ||
RUN --mount=type=bind,source=.git,target=.git \ | ||
pip install --no-build-isolation -v -e . | ||
|
||
CMD ["/bin/bash"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import gc | ||
import os | ||
import time | ||
|
||
from vllm import LLM, SamplingParams | ||
|
||
max_tokens = 3 | ||
|
||
os.environ["VLLM_SPYRE_WARMUP_PROMPT_LENS"] = '64' | ||
os.environ["VLLM_SPYRE_WARMUP_NEW_TOKENS"] = str(max_tokens) | ||
os.environ['VLLM_SPYRE_WARMUP_BATCH_SIZES'] = '1' | ||
|
||
# stuff for multi-spyre | ||
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" | ||
os.environ["DISTRIBUTED_STRATEGY_IGNORE_MODULES"] = "WordEmbedding" | ||
os.environ["MASTER_ADDR"] = "localhost" | ||
os.environ["MASTER_PORT"] = "12355" | ||
|
||
# Sample prompts. | ||
template = ( | ||
"Below is an instruction that describes a task. Write a response that " | ||
"appropriately completes the request. Be polite in your response to the " | ||
"user.\n\n### Instruction:\n{}\n\n### Response:") | ||
prompt1 = template.format( | ||
"Provide a list of instructions for preparing chicken soup for a family " | ||
"of four.") | ||
prompts = [ | ||
prompt1, | ||
] | ||
|
||
# Create a sampling params object. | ||
sampling_params = SamplingParams(max_tokens=max_tokens, | ||
temperature=0.0, | ||
ignore_eos=True) | ||
# Create an LLM. | ||
llm = LLM( | ||
model="/models/llama-194m", | ||
tokenizer="/models/llama-194m", | ||
max_model_len=2048, | ||
block_size=2048, | ||
device="spyre", | ||
tensor_parallel_size=2, | ||
) | ||
|
||
# Generate texts from the prompts. The output is a list of RequestOutput objects | ||
# that contain the prompt, generated text, and other information. | ||
print("=============== GENERATE") | ||
t0 = time.time() | ||
outputs = llm.generate(prompts, sampling_params) | ||
print("Time elaspsed for %d tokens is %.2f sec" % | ||
(len(outputs[0].outputs[0].token_ids), time.time() - t0)) | ||
for output in outputs: | ||
prompt = output.prompt | ||
generated_text = output.outputs[0].text | ||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||
print(output.outputs[0]) | ||
|
||
# needed to prevent ugly stackdump caused by sigterm | ||
del llm | ||
gc.collect() |
Oops, something went wrong.