[Bugfix] Fix speculative decode seeded test #6733

Closed
8 changes: 4 additions & 4 deletions tests/spec_decode/e2e/conftest.py
@@ -140,18 +140,18 @@ async def get_output(prompt, sampling_param) -> RequestOutput:
 @pytest.fixture
 def baseline_llm_generator(request, common_llm_kwargs,
                            per_test_common_llm_kwargs, baseline_llm_kwargs,
-                           seed):
+                           baseline_seed):
     return create_llm_generator("baseline", request, common_llm_kwargs,
                                 per_test_common_llm_kwargs,
-                                baseline_llm_kwargs, seed)
+                                baseline_llm_kwargs, baseline_seed)


 @pytest.fixture
 def test_llm_generator(request, common_llm_kwargs, per_test_common_llm_kwargs,
-                       test_llm_kwargs, seed):
+                       test_llm_kwargs, test_seed):
     return create_llm_generator("test", request, common_llm_kwargs,
                                 per_test_common_llm_kwargs, test_llm_kwargs,
-                                seed)
+                                test_seed)


 def create_llm_generator(baseline_or_test, request, common_llm_kwargs,
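
The conftest.py change above is the core of the fix: the single shared seed fixture is split into baseline_seed and test_seed so the baseline and speculative-decode generators can be seeded independently. The per-test changes below rely on pytest's override behavior: a @pytest.mark.parametrize mark whose argname matches a fixture replaces that fixture's value, even when the test only requests it transitively through another fixture. A minimal, self-contained sketch of that mechanism (the names below are illustrative, not taken from this PR):

    import pytest

    @pytest.fixture
    def baseline_seed():
        return 0  # default; replaced by the parametrize mark below

    @pytest.fixture
    def baseline_llm_generator(baseline_seed):
        # Stand-in for create_llm_generator(...); just records the seed.
        return f"generator(seed={baseline_seed})"

    # baseline_seed is in this test's fixture closure (via
    # baseline_llm_generator), so parametrizing it overrides the fixture.
    @pytest.mark.parametrize("baseline_seed", [1])
    def test_uses_overridden_seed(baseline_llm_generator):
        assert baseline_llm_generator == "generator(seed=1)"
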
6 changes: 3 additions & 3 deletions tests/spec_decode/e2e/test_compatibility.py
@@ -21,7 +21,7 @@
     },
 ])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
     """Verify that speculative decoding with chunked prefill fails.
     """
@@ -74,7 +74,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
     },
 ])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
     """Verify that speculative decoding validates speculative_max_model_len.
     """
@@ -103,7 +103,7 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_xfail_block_manager_v1(test_llm_generator):
     """Verify that speculative decoding with block manager v1 fails.
     """
3 changes: 2 additions & 1 deletion tests/spec_decode/e2e/test_integration.py
@@ -30,7 +30,8 @@
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("output_len", [32])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
                                 batch_size, output_len):
     """Verify spec decode equality when cuda graphs are enabled.
6 changes: 4 additions & 2 deletions tests/spec_decode/e2e/test_integration_dist_tp2.py
@@ -51,7 +51,8 @@
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
                               batch_size: int, output_len: int):
     """Verify greedy equality when tensor parallelism is used.
@@ -113,7 +114,8 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
     })
 ])
 @pytest.mark.parametrize("batch_size", [2])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_draft_model_tp_lt_target_model_tp2(test_llm_generator,
                                             baseline_llm_generator,
                                             batch_size: int):
6 changes: 4 additions & 2 deletions tests/spec_decode/e2e/test_integration_dist_tp4.py
@@ -47,7 +47,8 @@
     },
 ])
 @pytest.mark.parametrize("batch_size", [2])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
                                             baseline_llm_generator,
                                             batch_size: int):
@@ -104,7 +105,8 @@ def test_draft_model_tp_lt_target_model_tp4(test_llm_generator,
     # ensure fast test.
     64,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_skip_speculation(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
     """Verify job failure with RuntimeError when all sequences skip speculation.
15 changes: 10 additions & 5 deletions tests/spec_decode/e2e/test_logprobs.py
@@ -35,7 +35,8 @@
     # Use smaller output len for fast test.
     7,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int):
     """Verify output logprobs are equal with and without speculative decoding.
@@ -75,7 +76,8 @@ def test_logprobs_equality(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     7,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int,
                            num_logprobs: int):
@@ -120,7 +122,8 @@ def test_diff_num_logprobs(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
                               batch_size: int, output_len: int):
     """Verify logprob greedy equality with different speculation lens.
@@ -163,7 +166,8 @@ def test_logprobs_different_k(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_logprobs_when_skip_speculation(baseline_llm_generator,
                                         test_llm_generator, batch_size: int,
                                         output_len: int):
@@ -202,7 +206,8 @@ def test_logprobs_when_skip_speculation(baseline_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_logprobs_temp_1(baseline_llm_generator, test_llm_generator,
                          batch_size: int, output_len: int):
     """Verify at least one logprob result has num_logprobs+1, which tests the
12 changes: 8 additions & 4 deletions tests/spec_decode/e2e/test_medusa_correctness.py
@@ -69,7 +69,8 @@
     128,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                     batch_size: int, output_len: int):
     """Verify greedy equality with different batch size."""
@@ -115,7 +116,8 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
     128,
 ])
 @pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                     test_llm_generator,
                                                     batch_size: int,
@@ -164,7 +166,8 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
                          batch_size: int, output_len: int):
     """Verify that mlp speculative decoding produces exact equality
@@ -207,7 +210,8 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int):
     """Verify that mlp speculative decoding produces exact equality
12 changes: 8 additions & 4 deletions tests/spec_decode/e2e/test_mlp_correctness.py
@@ -66,7 +66,8 @@
     128,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                     batch_size: int, output_len: int):
     """Verify greedy equality with different batch size."""
@@ -111,7 +112,8 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
     128,
 ])
 @pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                     test_llm_generator,
                                                     batch_size: int,
@@ -160,7 +162,8 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
                          batch_size: int, output_len: int):
     """Verify that mlp speculative decoding produces exact equality
@@ -202,7 +205,8 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int):
     """Verify that mlp speculative decoding produces exact equality
38 changes: 25 additions & 13 deletions tests/spec_decode/e2e/test_multistep_correctness.py
@@ -72,7 +72,7 @@
 ])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [1, 32])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_with_detokenization(test_llm_generator,
                                              batch_size: int):
     """Run generation with speculative decoding on a batch. Verify the engine
@@ -141,7 +141,8 @@ def test_spec_decode_e2e_with_detokenization(test_llm_generator,
 ])
 @pytest.mark.parametrize("test_llm_kwargs", [{}])
 @pytest.mark.parametrize("batch_size", [2])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_with_async_engine(test_llm_generator,
                                            baseline_llm_generator,
                                            batch_size: int):
@@ -192,7 +193,8 @@ def test_spec_decode_e2e_with_async_engine(test_llm_generator,
     1536,
 ])
 @pytest.mark.parametrize("batch_size", [1])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
         baseline_llm_generator, test_llm_generator, batch_size: int,
         output_len: int):
@@ -252,7 +254,8 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
     256,
 ])
 @pytest.mark.parametrize("batch_size", [64])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
         baseline_llm_generator, test_llm_generator, batch_size: int,
         output_len: int):
@@ -297,7 +300,8 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
     256,
 ])
 @pytest.mark.parametrize("batch_size", [32])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
         baseline_llm_generator, test_llm_generator, batch_size: int,
         max_output_len: int):
@@ -341,7 +345,8 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     # Use decently long output len for a high quality test.
     256,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
         baseline_llm_generator, test_llm_generator, batch_size: int,
         output_len: int):
@@ -385,7 +390,8 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     # Use smaller output len for fast test.
     64,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
         baseline_llm_generator, test_llm_generator, batch_size: int,
         output_len: int):
@@ -432,7 +438,8 @@ def test_spec_decode_e2e_greedy_correctness_real_model_large_bs(
     256,
 ])
 @pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_e2e_greedy_correctness_with_preemption(
         baseline_llm_generator, test_llm_generator, batch_size: int,
         output_len: int):
@@ -486,7 +493,8 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_spec_decode_different_block_size(baseline_llm_generator,
                                           test_llm_generator, batch_size: int,
                                           output_len: int):
@@ -533,7 +541,8 @@ def test_spec_decode_different_block_size(baseline_llm_generator,
     # ensure fast test.
     64,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_skip_speculation(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
     """Verify greedy equality when some (or all) sequences skip speculation.
@@ -570,7 +579,8 @@ def test_skip_speculation(baseline_llm_generator, test_llm_generator,
 ])
 @pytest.mark.parametrize("batch_size", [8])
 @pytest.mark.parametrize("output_len", [10])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_disable_speculation(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
     """Verify greedy equality when all sequences disable speculation.
@@ -612,7 +622,8 @@ def test_disable_speculation(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                 output_len: int):
     """Verify that speculative decoding produces exact equality to without spec
@@ -656,7 +667,8 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_typical_acceptance_sampling(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
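
One hunk in tests/spec_decode/e2e/test_multistep_correctness.py differs from the rest: test_spec_decode_e2e_with_detokenization swaps seed for test_seed only, because it requests just test_llm_generator, and pytest rejects a parametrize mark whose argname is not in the test's fixture closure. A rough illustrative sketch of that failure mode (names are hypothetical, not from this PR):

    import pytest

    @pytest.fixture
    def test_seed():
        return 0

    # OK: test_seed is in this test's fixture closure.
    @pytest.mark.parametrize("test_seed", [1])
    def test_ok(test_seed):
        assert test_seed == 1

    # This would fail at collection time with an error along the lines of
    # "In test_broken: function uses no argument 'baseline_seed'", since
    # nothing here requests a baseline_seed fixture:
    #
    # @pytest.mark.parametrize("baseline_seed", [1])
    # def test_broken(test_seed):
    #     ...
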
12 changes: 8 additions & 4 deletions tests/spec_decode/e2e/test_ngram_correctness.py
@@ -58,7 +58,8 @@
     256,
 ])
 @pytest.mark.parametrize("batch_size", [1, 32])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
                                       test_llm_generator, batch_size: int,
                                       output_len: int):
@@ -104,7 +105,8 @@ def test_ngram_e2e_greedy_correctness(baseline_llm_generator,
     256,
 ])
 @pytest.mark.parametrize("batch_size", [4])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                       test_llm_generator,
                                                       batch_size: int,
@@ -158,7 +160,8 @@ def test_ngram_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
                            batch_size: int, output_len: int):
     """Verify that ngram speculative decoding produces exact equality
@@ -199,7 +202,8 @@ def test_ngram_different_k(baseline_llm_generator, test_llm_generator,
     # Use smaller output len for fast test.
     32,
 ])
-@pytest.mark.parametrize("seed", [1])
+@pytest.mark.parametrize("baseline_seed", [1])
+@pytest.mark.parametrize("test_seed", [1])
 def test_ngram_disable_queue(baseline_llm_generator, test_llm_generator,
                              batch_size: int, output_len: int):
     """Verify that ngram speculative decoding produces exact equality
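
With baseline_seed and test_seed parametrized separately, stacked @pytest.mark.parametrize marks combine multiplicatively, so these tests could later sweep seed pairs without further refactoring. A small sketch of that behavior (the seed values here are hypothetical; the PR itself keeps every seed at [1]):

    import pytest

    @pytest.mark.parametrize("baseline_seed", [1, 2])
    @pytest.mark.parametrize("test_seed", [3, 4])
    def test_seed_pairs(baseline_seed, test_seed):
        # Stacked marks yield the cross product of four test instances:
        # (1, 3), (1, 4), (2, 3), (2, 4).
        assert baseline_seed != test_seed
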