
Sparse tree #42

Merged
ctlllll merged 12 commits into main from sparse_tree Sep 20, 2023
Conversation

@ctlllll (Contributor) commented Sep 20, 2023

#34

@ctlllll merged commit 10a2697 into main Sep 20, 2023
Comment on lines +213 to +220

      Returns:
      - candidates (torch.Tensor): Cartesian product of candidate tokens across Medusa layers.
      - tree_candidates (torch.Tensor): Reshaped candidates matched to the tree structure.
      """
-     # Greedy decoding for original logits
-     candidates = [torch.argmax(logits[:, -1]).unsqueeze(0)]
-     for i in range(medusa_logits.shape[0]):
-         candidate_i = torch.topk(medusa_logits[i, 0, -1], medusa_topk[i]).indices
-         candidates.append(candidate_i)
-     candidates_flat = torch.cat(candidates)
-     candidates = torch.cartesian_prod(*candidates)
-     tree_candidates = candidates_flat[tree_indices].unsqueeze(0)
-     return candidates, tree_candidates
+     # Greedy decoding for original logits
+     candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)
+
+     # Extract the TOPK candidates from the medusa logits.
+     candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
+
+     # Combine the selected candidate from the original logits with the topk medusa logits.
+     candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)
+
+     # Map the combined candidates to the tree indices to get tree candidates.
+     tree_candidates = candidates[tree_indices]
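For orientation, here is a toy run of the candidates-to-tree_candidates gather in the added lines; the token values, TOPK = 2, and the tree layout below are invented for illustration, not taken from this PR:

import torch

# Flat candidate pool, as built by torch.cat above:
# position 0 = greedy token, positions 1-2 = head-1 top-2, 3-4 = head-2 top-2.
candidates = torch.tensor([42, 7, 19, 5, 11])

# A toy sparse tree: the greedy root, two head-1 children, and head-2
# continuations. Indices may repeat because head k's top-j prediction is
# shared by every branch at depth k (index 3, head 2's top-1, appears twice).
tree_indices = torch.tensor([0, 1, 2, 3, 4, 3])

tree_candidates = candidates[tree_indices]
print(tree_candidates)  # tensor([42,  7, 19,  5, 11,  5])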

What if we pruned low-probability subtrees before verifying with the base model? This would give the benefits of sparse tree attention without relying on manually specified sparse tree paths in medusa_choices.

Concretely (a rough sketch follows this list):

  1. a path's probability is the product of its nodes' probabilities according to medusa_logits (so it's not a formal verification of the sequence probability like running the base model, but we may be able to prune quite a few very low-probability subtrees)
  2. tree indices are updated to reflect the paths remaining after pruning
  3. generate_medusa_buffers creates buffers for a smaller size than the exponentially growing tree, so we prune to fit that size
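A minimal sketch of these steps, assuming medusa_logits is laid out [num_heads, batch, seq_len, vocab] as in the snippet above; prune_paths_by_medusa_prob, the paths format (one token id per tree depth), and max_paths are hypothetical names for illustration:

import torch

def prune_paths_by_medusa_prob(medusa_logits, paths, max_paths):
    # Step 1: each head's softmax distribution at the last position;
    # probs[k] scores the token proposed at depth k of a path.
    probs = torch.softmax(medusa_logits[:, 0, -1], dim=-1)  # [num_heads, vocab]

    # A path's probability is the product of its nodes' probabilities
    # (a heuristic from the heads, not a base-model verification).
    scores = []
    for path in paths:
        p = 1.0
        for depth, token in enumerate(path):
            p *= probs[depth, token].item()
        scores.append(p)

    # Steps 2-3: keep only the max_paths best paths so the buffers built by
    # generate_medusa_buffers stay a fixed size; tree indices are then
    # rebuilt from the surviving paths.
    keep = torch.tensor(scores).topk(min(max_paths, len(paths))).indices
    return [paths[i] for i in sorted(keep.tolist())]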

What's the intuition for why this should improve decoding speed?
Medusa may excel in cases where the probability density of the tree is heavily imbalanced, i.e., we have an easy subsequence coming up, which is reflected in the Medusa heads being confident in their predictions. But in other cases, the deeper Medusa heads are uncertain, and we're largely wasting computation verifying deep parts of the tree with full attention.

It would be interesting to see how much the shape of the tree's probability density changes across different contexts. If it has a large variance, it seems like it would be valuable to have a more dynamic sparse tree that takes an optimal shape based on the current context.
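One cheap way to probe that variance (a diagnostic sketch under the same shape assumption as above; head_confidence_profile is a made-up name, not an existing helper):

import torch

def head_confidence_profile(medusa_logits):
    # Top-1 softmax probability per Medusa head at the current position.
    # Logged across decoding steps, this shows how concentrated the tree's
    # probability mass is in a given context, and how much that varies.
    probs = torch.softmax(medusa_logits[:, 0, -1], dim=-1)  # [num_heads, vocab]
    return probs.max(dim=-1).values  # [num_heads]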

Basically, this could allow us to

  1. crank up the number of Medusa heads/size of the tree
  2. only verify deep paths of the tree with full tree attention if there's a decent Medusa probability
  3. not pay for a massive tree in cases where deep Medusa heads are uncertain

@leeyeehoo deleted the sparse_tree branch December 22, 2023 17:00