Addressing Cody's note
wenlei03 authored and root committed Apr 26, 2024
1 parent f6b8afe commit 748c687
Showing 4 changed files with 13 additions and 58 deletions.
4 changes: 2 additions & 2 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -90,7 +90,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": "NousResearch/Llama-2-7b-chat-hf",
+        "model": "meta-llama/Llama-2-7b-chat-hf",
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
@@ -112,7 +112,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
     },
     {
         # Speculative max model len > target max model len should raise.
-        # https://huggingface.co/NousResearch/Llama-2-7b-chat-hf/blob/37892f30c23786c0d5367d80481fa0d9fba93cf8/config.json#L11
+        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
         "speculative_max_model_len": 4096 + 1,
     },
 ])
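This parametrization drives a test that expects engine construction to fail when the draft model's max length exceeds the target's. As a minimal standalone sketch of that expectation (the test name is made up here, and it assumes vLLM's LLM constructor of this era accepts these speculative-decoding kwargs and rejects the oversized value with a ValueError):

import pytest

from vllm import LLM


def test_spec_max_model_len_exceeds_target():
    # Hypothetical sketch, not the repo's actual test body.
    # Llama-2-7b-chat-hf reports max_position_embeddings == 4096 in its
    # config.json (see the link in the diff above), so 4096 + 1 should be
    # rejected when the engine validates the speculative config.
    with pytest.raises(ValueError):
        LLM(
            model="meta-llama/Llama-2-7b-chat-hf",
            speculative_model="JackFram/llama-68m",
            num_speculative_tokens=5,
            speculative_max_model_len=4096 + 1,
            use_v2_block_manager=True,
        )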
4 changes: 2 additions & 2 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -264,7 +264,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     "common_llm_kwargs",
     [{
         # A "real" model (not tiny).
-        "model": "NousResearch/Llama-2-7b-chat-hf",
+        "model": "meta-llama/Llama-2-7b-chat-hf",

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
@@ -308,7 +308,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     "common_llm_kwargs",
     [{
         # A "real" model (not tiny).
-        "model": "NousResearch/Llama-2-7b-chat-hf",
+        "model": "meta-llama/Llama-2-7b-chat-hf",

         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
60 changes: 8 additions & 52 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -64,13 +64,14 @@
     "output_len",
     [
         # Use long output len for the small model test.
-        1536,
+        256,
     ])
-@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("batch_size", [1, 64])
 @pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
+def test_spec_decode_e2e_greedy_correctness_tiny_model(baseline_llm_generator,
+                                                       test_llm_generator,
+                                                       batch_size: int,
+                                                       output_len: int):
     """Verify greedy equality on a tiny model with batch size of one.

     Since this test is cheaper than other e2e correctness tests, we generate
@@ -83,51 +84,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
                                          force_output_len=True)


-@pytest.mark.parametrize(
-    "common_llm_kwargs",
-    [{
-        # Skip cuda graph recording for fast test.
-        "enforce_eager": True,
-
-        # Required for spec decode.
-        "use_v2_block_manager": True,
-
-        # Print spec metrics.
-        "disable_log_stats": False,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "model": "JackFram/llama-68m",
-    },
-])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "[ngram]",
-        "num_speculative_tokens": 5,
-        "ngram_prompt_lookup_max": 3,
-    },
-])
-@pytest.mark.parametrize(
-    "output_len",
-    [
-        # Use small output len for fast test.
-        256,
-    ])
-@pytest.mark.parametrize("batch_size", [64])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
-    """Verify greedy equality on a tiny model and large batch size.
-    """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
-
-
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -198,15 +154,15 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
         "ngram_prompt_lookup_max": 3,
     }
     # Try a range of common k, as well as large speculation.
-    for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    for k in [1, 3, 5, 7, 10, 63]
 ] + [
     {
         "speculative_model": "[ngram]",
         "num_speculative_tokens": k,
         "ngram_prompt_lookup_max": 1,
     }
     # Try a range of common k, as well as large speculation.
-    for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    for k in [1, 3, 5, 7, 10, 63]
 ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize(
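The merge in this file relies on pytest taking the cross product of stacked parametrize decorators, which is why a single function with batch_size [1, 64] replaces the separate bs1 and large_bs tests; likewise, the list comprehensions above build one test_llm_kwargs entry per k. A self-contained toy demonstrating the cross product (the test body is illustrative, not from the repo):

import pytest


@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_cross_product(output_len: int, batch_size: int, seed: int):
    # pytest generates one test per combination: 1 x 2 x 1 = 2 cases here,
    # so batch sizes 1 and 64 are both covered by a single function.
    assert output_len == 256 and batch_size in (1, 64)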
3 changes: 1 addition & 2 deletions vllm/spec_decode/ngram_worker.py
@@ -115,8 +115,7 @@ def sampler_output(
                                            ngram_size + sample_len]
                 res_len = len(res)
                 # Pad with zeros so exactly sample_len tokens are returned.
-                for i in range(res_len, sample_len):
-                    res.append(0)
+                res += [0] * (sample_len - res_len)

                 break
         else:
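The one-line padding is equivalent to the removed loop: it appends sample_len - res_len zeros when the slice of tokens after a matched n-gram runs off the end of the sequence. A simplified, self-contained sketch of this lookup-and-pad step (ngram_lookup is an illustrative helper, not the worker's actual code):

from typing import List


def ngram_lookup(token_ids: List[int], ngram_size: int,
                 sample_len: int) -> List[int]:
    """Propose sample_len tokens by prompt lookup (simplified sketch)."""
    ngram = token_ids[-ngram_size:]
    # Scan backwards for an earlier occurrence of the trailing n-gram.
    for i in range(len(token_ids) - ngram_size - 1, -1, -1):
        if token_ids[i:i + ngram_size] == ngram:
            # Take up to sample_len tokens that followed the match...
            res = token_ids[i + ngram_size:i + ngram_size + sample_len]
            res_len = len(res)
            # ...and pad with zeros, exactly as the diff above does.
            res += [0] * (sample_len - res_len)
            return res
    return []


# The trigram (3, 4, 5) recurs: the 5 tokens after its first occurrence
# are proposed, but only 4 exist, so a single zero is padded.
print(ngram_lookup([3, 4, 5, 6, 3, 4, 5], 3, 5))  # [6, 3, 4, 5, 0]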
