
Commit

Merge branch 'vllm-project:main' into main
sroy745 authored Oct 3, 2024
2 parents 0dd96ed + 9aaf14c commit 9d4d969
Showing 56 changed files with 2,069 additions and 620 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -208,7 +208,7 @@ steps:
- tests/spec_decode
commands:
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
- pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py

- label: LoRA Test %N # 15min each
mirror_hardwares: [amd]
11 changes: 11 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}

int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x >= num_query_blocks) {
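
The added lines above handle CUDA-graph padding: when the batch is padded (`num_seqs > num_queries`), the first thread block zeroes the padded tail of the token and position buffers and marks the corresponding slots as unused. A rough Python equivalent of that branch, as a sketch for orientation only (the helper name `pad_advance_step_inputs` is hypothetical, not code from this commit):

    def pad_advance_step_inputs(input_tokens, input_positions, slot_mapping,
                                num_queries, num_seqs):
        """Mirror the kernel's padding branch for the (num_seqs - num_queries)
        CUDA-graph padding entries at the tail of each buffer."""
        for i in range(num_queries, num_seqs):
            input_tokens[i] = 0       # dummy token id
            input_positions[i] = 0    # dummy position
            slot_mapping[i] = -1      # -1 marks the entry as having no KV-cache slot
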
35 changes: 29 additions & 6 deletions docs/source/getting_started/openvino-installation.rst
@@ -3,7 +3,7 @@
Installation with OpenVINO
==========================

vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features:
vLLM powered by OpenVINO supports all LLM models from the :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with at least AVX2 support, as well as on both integrated and discrete Intel® GPUs (`the list of supported GPUs <https://docs.openvino.ai/2024/about-openvino/release-notes-openvino/system-requirements.html#gpu>`_). The OpenVINO vLLM backend supports the following advanced vLLM features:

- Prefix caching (``--enable-prefix-caching``)
- Chunked prefill (``--enable-chunked-prefill``)
@@ -53,34 +53,57 @@ Install from source
$ pip install --upgrade pip
$ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu
- Finally, install vLLM with OpenVINO backend:

.. code-block:: console
$ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v .
- [Optional] To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: `https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html <https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html>`_.

.. _openvino_backend_performance_tips:

Performance tips
----------------

vLLM OpenVINO backend uses the following environment variables to control behavior:
vLLM OpenVINO backend environment variables
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

- ``VLLM_OPENVINO_DEVICE`` to specify which device to use for inference. If there are multiple GPUs in the system, additional indexes can be used to choose the proper one (e.g., ``VLLM_OPENVINO_DEVICE=GPU.1``). If the value is not specified, the CPU device is used by default.

- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weight compression during the model loading stage. By default, compression is turned off. You can also export the model with different compression techniques using ``optimum-cli`` and pass the exported folder as ``<model_id>``. Both variables are illustrated in the sketch below.
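
As a minimal sketch of using both variables from Python (assuming they are exported before the vLLM engine is created; the model name is only a placeholder):

.. code-block:: python

    import os

    # Assumption: the variables must be set before vLLM initializes the OpenVINO backend.
    os.environ["VLLM_OPENVINO_DEVICE"] = "GPU.1"                 # pick the second GPU
    os.environ["VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS"] = "ON"  # U8 weight compression

    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")  # placeholder model
    outputs = llm.generate(["Hello!"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)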

CPU performance tips
~~~~~~~~~~~~~~~~~~~~

The CPU backend uses the following environment variables to control its behavior:

- ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV cache size (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB of space for the KV cache); a larger setting allows vLLM to run more requests in parallel. This parameter should be set based on the hardware configuration and the memory management pattern of users.

- ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control the KV cache precision. By default, FP16 / BF16 is used, depending on the platform.

- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `<model_id>`

To improve TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``).

OpenVINO best known configuration is:
OpenVINO best known configuration for CPU is:

.. code-block:: console
$ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256
GPU performance tips
~~~~~~~~~~~~~~~~~~~~
The GPU device implements automatic detection of the available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking the ``gpu_memory_utilization`` option into account). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using the ``VLLM_OPENVINO_KVCACHE_SPACE`` environment variable (e.g., ``VLLM_OPENVINO_KVCACHE_SPACE=8`` means 8 GB of space for the KV cache).

Currently, the best GPU performance is achieved with the default vLLM execution parameters for models with quantized weights (8- and 4-bit integer data types are supported) and ``preemption-mode=swap``.

OpenVINO best known configuration for GPU is:

.. code-block:: console
$ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \
python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json
.. _openvino_backend_limitations:

Limitations
8 changes: 6 additions & 2 deletions docs/source/models/supported_models.rst
@@ -92,8 +92,12 @@ Decoder-only Language Models
- :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc.
-
* - :code:`GraniteForCausalLM`
- Granite, Power-LM
- :code:`ibm/granite-7b-base`, :code:`ibm/PowerLM-3b` etc.
- PowerLM
- :code:`ibm/PowerLM-3b` etc.
- ✅︎
* - :code:`GraniteMoeForCausalLM`
- PowerMoE
- :code:`ibm/PowerMoE-3b` etc.
- ✅︎
* - :code:`InternLMForCausalLM`
- InternLM
3 changes: 1 addition & 2 deletions examples/tool_chat_template_mistral_parallel.jinja
@@ -6,8 +6,7 @@
{%- endif %}
{%- if not tools is defined %}
{%- set tools = none %}
{%- endif %}
{%- if tools is defined %}
{%- elif tools is not none %}
{%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %}
{%- if system_message is defined %}
{%- set system_message = parallel_tool_prompt + "\n\n" + system_message %}
5 changes: 3 additions & 2 deletions requirements-openvino.txt
@@ -3,5 +3,6 @@

# OpenVINO dependencies
torch >= 2.1.2
openvino ~= 2024.3.0
optimum-intel[openvino] >= 1.18.2
openvino ~= 2024.4.0
openvino-tokenizers[transformers] ~= 2024.4.0
optimum-intel[openvino] >= 1.19.0
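
A quick sanity check that the pinned packages above installed correctly (a sketch; `openvino_tokenizers` is the expected import name of the `openvino-tokenizers` wheel):

    from openvino.runtime import get_version

    import openvino_tokenizers  # noqa: F401  (extension provided by openvino-tokenizers)

    # The pins above target the 2024.4 release line.
    print(get_version())
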
48 changes: 34 additions & 14 deletions tests/engine/test_arg_utils.py
@@ -42,22 +42,42 @@ def test_bad_nullable_kvs(arg):
nullable_kvs(arg)


@pytest.mark.parametrize(("arg", "expected"), [
(None, None),
("{}", {}),
('{"num_crops": 4}', {
"num_crops": 4
}),
('{"foo": {"bar": "baz"}}', {
"foo": {
"bar": "baz"
}
}),
# yapf: disable
@pytest.mark.parametrize(("arg", "expected", "option"), [
(None, None, "mm-processor-kwargs"),
("{}", {}, "mm-processor-kwargs"),
(
'{"num_crops": 4}',
{
"num_crops": 4
},
"mm-processor-kwargs"
),
(
'{"foo": {"bar": "baz"}}',
{
"foo":
{
"bar": "baz"
}
},
"mm-processor-kwargs"
),
(
'{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}',
{
"cast_logits_dtype": "bfloat16",
"sequence_parallel_norm": True,
"sequence_parallel_norm_threshold": 2048,
},
"override-neuron-config"
),
])
def test_mm_processor_kwargs_prompt_parser(arg, expected):
# yapf: enable
def test_composite_arg_parser(arg, expected, option):
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
if arg is None:
args = parser.parse_args([])
else:
args = parser.parse_args(["--mm-processor-kwargs", arg])
assert args.mm_processor_kwargs == expected
args = parser.parse_args([f"--{option}", arg])
assert getattr(args, option.replace("-", "_")) == expected
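
For orientation, a standalone sketch of what one of the parametrized cases above exercises (assuming the `EngineArgs` and `FlexibleArgumentParser` import paths used elsewhere in vLLM):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    # Build the full vLLM CLI parser and feed it one JSON-valued composite option.
    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args(
        ["--override-neuron-config", '{"cast_logits_dtype": "bfloat16"}'])

    # The JSON string is parsed into a plain Python dict on the namespace.
    print(args.override_neuron_config)  # {'cast_logits_dtype': 'bfloat16'}
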
56 changes: 7 additions & 49 deletions tests/kernels/test_flash_attn.py
@@ -3,9 +3,9 @@
import pytest
import torch

import vllm.attention.backends.flash_attn # noqa: F401
from tests.kernels.utils import opcheck
from vllm.utils import seed_everything
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache)

NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
HEAD_SIZES = [128, 256]
@@ -112,36 +112,17 @@ def test_flash_attn_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = torch.ops.vllm.flash_attn_with_kvcache(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
output = flash_attn_with_kvcache(
q=query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
).squeeze(1)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

opcheck(torch.ops.vllm.flash_attn_with_kvcache,
args=tuple(),
kwargs=dict(
decode_query=query.unsqueeze(1),
key_cache=key_cache,
value_cache=value_cache,
softmax_scale=scale,
causal=True,
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

output = torch.ops.vllm.flash_attn_varlen_func(
output = flash_attn_varlen_func(
q=query,
k=key_cache,
v=value_cache,
@@ -228,29 +209,6 @@
softcap=soft_cap if soft_cap is not None else 0,
)

if num_blocks <= 2048:
test_utils = ["test_faketensor", "test_schema"]
else:
test_utils = ["test_faketensor"]

opcheck(torch.ops.vllm.flash_attn_varlen_func,
args=tuple(),
kwargs=dict(
q=query,
k=key_cache,
v=value_cache,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_kv_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
softmax_scale=scale,
causal=True,
window_size=window_size,
block_table=block_tables,
softcap=soft_cap if soft_cap is not None else 0,
),
test_utils=test_utils)

ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
39 changes: 39 additions & 0 deletions tests/models/decoder_only/language/test_granitemoe.py
@@ -0,0 +1,39 @@
"""Compare the outputs of HF and vLLM for Granite models using greedy sampling.
Run `pytest tests/models/decoder_only/language/test_granitemoe.py`.
"""
import pytest

from ...utils import check_logprobs_close

MODELS = [
"ibm/PowerMoE-3b",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)