Skip to content

Commit

Permalink
[Speculative decoding] Add ngram prompt lookup decoding (vllm-project…
Browse files Browse the repository at this point in the history
…#4237)

Co-authored-by: Lei Wen <wenlei03@qiyi.com>
  • Loading branch information
2 people authored and dtrifiro committed May 7, 2024
1 parent 0662b71 commit 2266a5e
Show file tree
Hide file tree
Showing 14 changed files with 1,004 additions and 319 deletions.
58 changes: 58 additions & 0 deletions tests/spec_decode/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from itertools import cycle
from typing import List, Optional, Tuple, Union

import pytest
Expand Down Expand Up @@ -185,3 +186,60 @@ def get_output_from_llm_generator(
del llm

return tokens, token_ids


def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos = force_output_len

sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)

spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)

(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)

for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@

from vllm import SamplingParams

from .conftest import get_output_from_llm_generator
from .conftest import (get_output_from_llm_generator,
run_greedy_equality_correctness_test)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -545,60 +546,3 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
batch_size,
max_output_len=output_len,
force_output_len=True)


def run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len,
force_output_len: bool,
print_tokens: bool = False):
"""Helper method that compares the outputs of both the baseline LLM and
the test LLM. It asserts greedy equality, e.g. that the outputs are exactly
the same when temperature is zero.
"""
temperature = 0.0

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"San Francisco is know for its",
"Facebook was created in 2004 by",
"Curious George is a",
"Python 3.11 brings improvements to its",
]

prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

# If the test requires that we generated max_output_len tokens, then set the
# sampling params to ignore eos token.
ignore_eos = force_output_len

sampling_params = SamplingParams(
max_tokens=max_output_len,
ignore_eos=ignore_eos,
temperature=temperature,
)

spec_batch_tokens, spec_batch_token_ids = get_output_from_llm_generator(
test_llm_generator, prompts, sampling_params)

(baseline_batch_tokens,
baseline_batch_token_ids) = get_output_from_llm_generator(
baseline_llm_generator, prompts, sampling_params)

assert len(baseline_batch_token_ids) == len(prompts)
assert len(spec_batch_token_ids) == len(prompts)

for i, (baseline_token_ids, baseline_tokens, spec_token_ids,
spec_tokens) in enumerate(
zip(baseline_batch_token_ids, baseline_batch_tokens,
spec_batch_token_ids, spec_batch_tokens)):
if print_tokens:
print(f'{i=} {baseline_tokens=}')
print(f'{i=} {spec_tokens=}')
print(f'{i=} {baseline_token_ids=}')
print(f'{i=} {spec_token_ids=}')
assert baseline_token_ids == spec_token_ids
172 changes: 172 additions & 0 deletions tests/spec_decode/e2e/test_ngram_correctness.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""This docstring details important information on the testing methodology.
Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal non-
speculative decoding.
Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.
For ngram lookup, its idea comes from https://github.com/apoorvumang/prompt-lookup-decoding,
and is merged into transform code base: https://github.com/huggingface/transformers/pull/27775.
Since there is no model is needed for generate the proposal, we could make
the testcase much simpler than drafter multi-step one.
However, we still need to verify below scenario could be passed:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Test greedy equality under preemption
* Test greedy equality under various ngram sizes / speculative sizes
With those tests, we can say at least, ngram spec would not break the correctess
for the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True,
# Print spec metrics.
"disable_log_stats": False,
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-68m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize("output_len", [
256,
])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
test_llm_generator, batch_size: int,
output_len: int):
"""Verify greedy equality on a tiny model with different batch size."""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"block_size": 8,
# 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
{
"model": "JackFram/llama-160m",
},
])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": 5,
"ngram_prompt_lookup_max": 3,
},
])
@pytest.mark.parametrize(
"output_len",
[
# Use small output len for fast test.
256,
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
test_llm_generator,
batch_size: int,
output_len: int):
"""Verify greedy equality, even when some sequences are preempted mid-
generation.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)


@pytest.mark.parametrize(
"common_llm_kwargs",
[{
"model": "JackFram/llama-68m",
# Skip cuda graph recording for fast test.
"enforce_eager": True,
# Required for spec decode.
"use_v2_block_manager": True
}])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
"test_llm_kwargs",
[
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 3,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
] + [
{
"speculative_model": "[ngram]",
"num_speculative_tokens": k,
"ngram_prompt_lookup_max": 1,
}
# Try a range of common k, as well as large speculation.
for k in [1, 3, 5]
])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
"output_len",
[
# Use smaller output len for fast test.
32,
])
@pytest.mark.parametrize("seed", [1])
def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
batch_size: int, output_len: int):
"""Verify that ngram speculative decoding produces exact equality
to without spec decode with many different values of k and
different ngram_prompt_lookup_max.
"""
run_greedy_equality_correctness_test(baseline_llm_generator,
test_llm_generator,
batch_size,
max_output_len=output_len,
force_output_len=True)
Loading

0 comments on commit 2266a5e

Please sign in to comment.