Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Speculative Decoding 1/2 ] Add typical acceptance sampling as one of the sampling techniques in the verifier #5131

Merged
merged 50 commits into from
Jun 18, 2024
Merged
Changes from 1 commit
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
00d74c0
Test commit
sroy745 May 8, 2024
5650b95
Merge pull request #1 from vllm-project/main
sroy745 May 29, 2024
d936a25
Merge branch 'main' of https://github.com/sroy745/vllm into test-branch
sroy745 May 29, 2024
81eb966
Refactoring some of the logic in rejection_sampling to a base class a…
sroy745 May 30, 2024
b09a20f
Adding comments and some new tests.
sroy745 May 31, 2024
9340244
Formatting and comments.
sroy745 May 31, 2024
7a7f9bd
Added comments and some more tests.
sroy745 May 31, 2024
aecc2e8
Dummy commit
sroy745 May 31, 2024
513e252
Reverting change to llm_engine_example.py
sroy745 May 31, 2024
69e52f0
Updating a comment
sroy745 May 31, 2024
8f36146
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
5559757
Merge remote-tracking branch 'origin/main' into acceptance_sampling_s…
sroy745 Jun 3, 2024
312bc49
Dummy commit
sroy745 Jun 3, 2024
9e75057
Merge branch 'vllm-project:main' into main
sroy745 Jun 3, 2024
ca41215
Fix device for tensors
sroy745 Jun 3, 2024
ba4d3fb
Fixing review comments
sroy745 Jun 6, 2024
c9c7a8b
Addressing comments
sroy745 Jun 7, 2024
0ad9afd
Add a new test for non default posteriors
sroy745 Jun 7, 2024
db2c679
Merge branch 'vllm-project:main' into main
sroy745 Jun 7, 2024
9b572f7
Documentation for test
sroy745 Jun 7, 2024
920ffa4
Merge remote-tracking branch 'origin/main' into acceptance_sampling_s…
sroy745 Jun 7, 2024
644cae4
Fix ruff errors
sroy745 Jun 7, 2024
5ee4018
Fix spell corrections
sroy745 Jun 7, 2024
b414d42
Fixing spell errors
sroy745 Jun 7, 2024
7aa132b
Ran format.sh
sroy745 Jun 7, 2024
8d7512c
Merge branch 'vllm-project:main' into main
sroy745 Jun 10, 2024
d0e0827
Merge branch 'main' into acceptance_sampling_spec_decode
sroy745 Jun 12, 2024
1473f74
Merge branch 'vllm-project:main' into main
sroy745 Jun 12, 2024
26694a7
Fix formatting
sroy745 Jun 12, 2024
4013e1a
Merge branch 'vllm-project:main' into main
sroy745 Jun 14, 2024
2dbdd78
Merge branch 'vllm-project:main' into main
sroy745 Jun 17, 2024
1b8fd3e
Test commit
sroy745 May 8, 2024
5e80dd8
Refactoring some of the logic in rejection_sampling to a base class a…
sroy745 May 30, 2024
77dcb79
Adding comments and some new tests.
sroy745 May 31, 2024
db3e6fa
Formatting and comments.
sroy745 May 31, 2024
7d876d4
Added comments and some more tests.
sroy745 May 31, 2024
78d011b
Dummy commit
sroy745 May 31, 2024
f67ffcd
Reverting change to llm_engine_example.py
sroy745 May 31, 2024
1bdbe09
Updating a comment
sroy745 May 31, 2024
07f7b9c
Dummy commit
sroy745 Jun 3, 2024
74782dc
Fix device for tensors
sroy745 Jun 3, 2024
7bcbbdb
Fixing review comments
sroy745 Jun 6, 2024
f0c2b79
Addressing comments
sroy745 Jun 7, 2024
0b7bc75
Add a new test for non default posteriors
sroy745 Jun 7, 2024
197ded6
Documentation for test
sroy745 Jun 7, 2024
bd435df
Fix ruff errors
sroy745 Jun 7, 2024
d4c750e
Fix spell corrections
sroy745 Jun 7, 2024
2c2004f
Fixing spell errors
sroy745 Jun 7, 2024
85c48f5
Ran format.sh
sroy745 Jun 7, 2024
b841f90
Merge branch 'acceptance_sampling_spec_decode' of https://github.com/…
sroy745 Jun 17, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Documentation for test
  • Loading branch information
sroy745 committed Jun 17, 2024
commit 197ded69348a93a30d3189f78a9d1f7d0cf042ec
28 changes: 13 additions & 15 deletions tests/samplers/test_typical_acceptance_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,18 +374,10 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
def test_accept_tokens_set_non_default_posteriors(
seed: int, disable_bonus_tokens: bool, device: str):
"""
Test the TypicalAcceptanceSampler's behavior when only a subset of draft
tokens should be accepted.

This test verifies that the TypicalAcceptanceSampler correctly accepts or
rejects draft tokens based on a zero-temperature target probability
distribution. Specifically, it ensures that:

- When all draft tokens match tokens with a probability of 1.0 in the
target distribution, all draft tokens are accepted.
- When only some draft tokens match tokens with a probability of 1.0 in
the target distribution, only those matching tokens are accepted, and the
rest are rejected.
Test the TypicalAcceptanceSampler with custom posterior thresholds and
alpha values. This test verifies that by modifying the posterior
thresholds and alpha values we can change the acceptance behavior of the
sampler.
"""
set_random_seed(seed)
k = 5
Expand All @@ -395,9 +387,12 @@ def test_accept_tokens_set_non_default_posteriors(
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
typical_acceptance_sampler.init_gpu_tensors(rank=0)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
# 0.00001. Populate draft_token_ids such that they exclude the token_ids
# with probability = 1.0. Without any changes to the posterior thresholds
# none of the draft tokens are accepted.
target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist(
batch_size, k, vocab_size))
target_probs[target_probs == 0] = 0.00001
Expand All @@ -414,6 +409,9 @@ def test_accept_tokens_set_non_default_posteriors(
assert output_token_ids.shape[1] == (k + 1)
assert torch.all(output_token_ids[:, 1:-1] == -1)

# Change the posterior threshold values to 0.0 so that we will
# now accept even draft tokens with very low probability in the
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, disable_bonus_tokens=disable_bonus_tokens,
posterior_threshold=0.0, posterior_alpha=0.0)
Expand Down