Generate: speculative decoding #27979

Merged · 8 commits merged on Dec 19, 2023

Changes from 6 commits
85 changes: 67 additions & 18 deletions src/transformers/generation/utils.py
@@ -4624,40 +4624,89 @@ def assisted_decoding(
for i in range(candidate_length + 1):
    new_logits[:, i, :] = logits_warper(candidate_input_ids[:, : cur_len + i], new_logits[:, i, :])

# 3. Select the accepted tokens. There are two possible cases:
# Case 1: `do_sample=True` and we have logits for the candidates (originally from speculative decoding)
# 👉 Apply algorithm 1 from the speculative decoding paper (https://arxiv.org/pdf/2211.17192.pdf).
# NOTE: Unless otherwise stated, the variable names match those in the paper.
if do_sample and candidate_logits is not None:
    # Gets the probabilities from the logits. q_i and p_i denote the assistant and model probabilities of
    # the tokens selected by the assistant, respectively.
    q = candidate_logits.softmax(dim=-1)
Member Author: These are not the best variable names, but it's hard to compare against the original algorithm if they don't match 🤔 As such, I've decided to keep the original names.

Collaborator: I'm fine with it as there are good comments and the other variables are well named, e.g. is_rejected :)

    q_i = q[
        :,
        torch.range(0, candidate_length - 1, dtype=torch.int),
        candidate_input_ids[:, -candidate_length:],
    ].squeeze(0, 1)
    p = new_logits.softmax(dim=-1)
    p_i = p[
        :,
        torch.range(0, candidate_length - 1, dtype=torch.int),
        candidate_input_ids[:, -candidate_length:],
    ].squeeze(0, 1)
    probability_ratio = p_i / q_i

    # When probability_ratio > 1 (i.e. q_i(x) < p_i(x)), keep the token. Otherwise reject with
    # p = 1 - probability_ratio (= keep with p = probability_ratio). Keep all the tokens until the first
    # rejection.
    r_i = torch.rand_like(probability_ratio)
    is_rejected = r_i > probability_ratio  # equivalent: is_accepted = r_i <= probability_ratio
    n_matches = (is_rejected.cumsum(dim=-1) < 1).sum()  # this is `n` in algorithm 1

    # Ensure we don't generate beyond max_len or an EOS token (not in algorithm 1, but needed for correct
    # behavior)
    if last_assistant_token_is_eos and n_matches == candidate_length:
        n_matches -= 1
    n_matches = min(n_matches, max_len - cur_len - 1)

    # Next token selection: if there is a rejection, adjust the distribution from the main model before
    # sampling.
    gamma = candidate_logits.shape[1]
    p_n_plus_1 = p[:, n_matches, :]
    if n_matches < gamma:
        q_n_plus_1 = q[:, n_matches, :]
        p_prime = torch.clamp((p_n_plus_1 - q_n_plus_1), min=0).softmax(dim=-1)
    else:
        p_prime = p_n_plus_1
    t = torch.multinomial(p_prime, num_samples=1).squeeze(1)[None, :]

    # The selected tokens include the matches plus the next sampled token
    selected_tokens = torch.cat((candidate_input_ids[:, :n_matches], t), dim=-1)

# Case 2: all other cases (originally from assisted generation) 👉 Compare the tokens selected from the
# original model logits with the candidate tokens. We can keep the candidate tokens until the first
# mismatch, or until the max length is reached.
else:
    if do_sample:
        probs = new_logits.softmax(dim=-1)
Contributor: Is this case still relevant? Not sure it's a good idea to have two "assisted decoding" do_sample=True cases in our generate. Should we maybe just deprecate this case?

        selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
    else:
        selected_tokens = new_logits.argmax(dim=-1)
Comment on lines +4647 to +4651
Contributor (@patrickvonplaten), Dec 15, 2023:
Suggested change:

if do_sample:
    probs = new_logits.softmax(dim=-1)
    selected_tokens = torch.multinomial(probs[0, :, :], num_samples=1).squeeze(1)[None, :]
else:
    selected_tokens = new_logits.argmax(dim=-1)

It's probably time to soon factor this out into something like:

selected_tokens = Categorical(new_logits / temperature).sample()

everywhere in generate

Member Author: Yes! Then equivalent sampling/non-sampling methods (e.g. greedy decoding/sampling) could be merged into a single function, facilitating maintenance. I'm going to leave it to a follow-up PR, though, to keep this PR exclusively about speculative decoding.
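A minimal sketch of what such a shared selection helper could look like (hypothetical, not part of this PR; the helper name, signature, and temperature handling are assumptions):

import torch
from torch.distributions import Categorical


def select_tokens(logits, do_sample, temperature=1.0):
    # Hypothetical shared helper: greedy pick or temperature sampling over the vocab dimension.
    if do_sample:
        # Categorical normalizes the temperature-scaled logits internally.
        return Categorical(logits=logits / temperature).sample()
    return logits.argmax(dim=-1)

Passing logits (rather than probabilities) to Categorical keeps the helper numerically stable and would let one call replace both the multinomial and argmax branches above.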


    candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
    n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

    # Ensure we don't generate beyond max_len or an EOS token
    if last_assistant_token_is_eos and n_matches == candidate_length:
        n_matches -= 1
    n_matches = min(n_matches, max_len - cur_len - 1)

# 4. Update variables according to the number of matching assistant tokens. Remember: the token generated
# by the model after the last candidate match is also valid, as it is generated from a correct sequence.
# Because of this last token, assisted generation search reduces to a normal greedy search/sample if there
# is no match.

# 4.1. Get the valid continuation, after the matching tokens
valid_tokens = selected_tokens[:, : n_matches + 1]
input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
if streamer is not None:
    streamer.put(valid_tokens.cpu())
new_cur_len = input_ids.shape[-1]

# 4.2. Discard past key values relative to unused assistant tokens
new_cache_size = new_cur_len - 1
outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)

# 5. Update the candidate generation strategy if needed
candidate_generator.update_candidate_strategy(input_ids, new_logits, n_matches)

if synced_gpus and this_peer_finished:
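For reference, a minimal self-contained sketch of the acceptance/resampling rule from Algorithm 1 that the Case 1 branch above implements. It is illustrative only: the shapes, the function name, and the explicit renormalization of max(0, p - q) are assumptions rather than the transformers API.

import torch


def speculative_sampling_step(p, q, candidate_tokens):
    # Illustrative sketch of Algorithm 1, not the implementation in this PR.
    # q: [gamma, vocab_size] assistant probabilities at the gamma candidate positions
    # p: [gamma + 1, vocab_size] main-model probabilities at the same positions plus one extra
    # candidate_tokens: [gamma] token ids proposed by the assistant
    gamma = candidate_tokens.shape[0]
    positions = torch.arange(gamma)
    probability_ratio = p[positions, candidate_tokens] / q[positions, candidate_tokens]
    is_accepted = torch.rand(gamma) <= probability_ratio  # accept token i with prob min(1, p_i / q_i)
    n_matches = int((~is_accepted).cumsum(dim=-1).eq(0).sum())  # leading run of accepted candidates

    if n_matches < gamma:
        # First rejection: resample that position from the adjusted distribution norm(max(0, p - q)).
        p_prime = torch.clamp(p[n_matches] - q[n_matches], min=0)
        p_prime = p_prime / p_prime.sum()
    else:
        # All candidates accepted: sample the extra token directly from the main model.
        p_prime = p[gamma]
    next_token = torch.multinomial(p_prime, num_samples=1)
    return n_matches, next_token

The n_matches accepted candidates plus next_token correspond to the valid_tokens continuation that the hunk above appends to input_ids.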