Sparse tree #42
Conversation
update roadmap
This reverts commit 922689a.
-    Returns:
-    - candidates (torch.Tensor): Cartesian product of candidate tokens across Medusa layers.
-    - tree_candidates (torch.Tensor): Reshaped candidates matched to the tree structure.
-    """
-    # Greedy decoding for original logits
-    candidates = [torch.argmax(logits[:, -1]).unsqueeze(0)]
-    for i in range(medusa_logits.shape[0]):
-        candidate_i = torch.topk(medusa_logits[i, 0, -1], medusa_topk[i]).indices
-        candidates.append(candidate_i)
-    candidates_flat = torch.cat(candidates)
-    candidates = torch.cartesian_prod(*candidates)
-    tree_candidates = candidates_flat[tree_indices].unsqueeze(0)
-    return candidates, tree_candidates
+    # Extract the TOPK candidates from the medusa logits.
+    candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices
+
+    # Combine the selected candidate from the original logits with the topk medusa logits.
+    candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)
+
+    # Map the combined candidates to the tree indices to get tree candidates.
+    tree_candidates = candidates[tree_indices]
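For intuition about the shapes involved, here is a minimal, self-contained sketch of the new candidate-generation path (the tensor sizes and the toy `tree_indices` are invented for illustration; in the repo these buffers are derived from `medusa_choices`):

```python
import torch

TOPK = 3
num_heads, seq_len, vocab = 2, 8, 32000
medusa_logits = torch.randn(num_heads, 1, seq_len, vocab)  # (heads, batch, seq, vocab)
logits = torch.randn(1, seq_len, vocab)                    # base-model logits

# Greedy token from the base model at the last position.
candidates_logit = torch.argmax(logits[:, -1]).unsqueeze(0)

# Top-K tokens per Medusa head at the last position: shape (num_heads, TOPK).
candidates_medusa_logits = torch.topk(medusa_logits[:, 0, -1], TOPK, dim=-1).indices

# One flat vector: [base, head0 top-3, head1 top-3] -> length 1 + 2 * 3 = 7.
candidates = torch.cat([candidates_logit, candidates_medusa_logits.view(-1)], dim=-1)

# tree_indices gathers entries of `candidates` in tree-visit order (toy layout here).
tree_indices = torch.tensor([0, 1, 2, 4, 5])
tree_candidates = candidates[tree_indices]  # tokens the base model verifies in one pass
```

Unlike the removed cartesian-product version, the full tree is never materialized as an explicit product of per-head candidates; the gather through `tree_indices` selects only the paths the sparse tree keeps.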
What if we pruned low probability subtrees before verifying with the base model? This would give the benefits of sparse tree attention without relying on manually specifying the sparse tree paths in `medusa_choices`.
Concretely,

- a path's probability is the product of its nodes' probabilities according to `medusa_logits` (so it's not a formal verification of the sequence probability like running the base model, but we may be able to prune quite a few very low probability subtrees)
- tree indices are updated to reflect the paths remaining after pruning
- `generate_medusa_buffers` creates buffers for a smaller size than the exponentially growing tree, so we prune to fit that size (see the sketch after this list)
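A rough sketch of the pruning step (the helper name, the path encoding, and `max_paths` are assumptions for illustration, not existing repo API):

```python
import torch

def prune_medusa_choices(medusa_choices, medusa_logits, max_paths):
    """Keep the `max_paths` candidate paths with the highest estimated probability.

    medusa_choices: list of paths; each path is a tuple of top-k ranks, one per
                    depth, e.g. (0, 1) = head 0's top-1 token, then head 1's top-2.
    medusa_logits:  raw head outputs, shape (num_heads, batch, seq_len, vocab).
    """
    # Per-head probabilities over the vocabulary at the last position (batch 0).
    probs = torch.softmax(medusa_logits[:, 0, -1], dim=-1)
    # Probability of each head's k-th ranked token (topk sorts descending).
    k = max(max(path) for path in medusa_choices) + 1
    topk_probs = torch.topk(probs, k, dim=-1).values

    def path_prob(path):
        # A path's score is the product of its nodes' probabilities.
        p = 1.0
        for depth, rank in enumerate(path):
            p *= topk_probs[depth, rank].item()
        return p

    return sorted(medusa_choices, key=path_prob, reverse=True)[:max_paths]
```

The surviving paths would then be fed to `generate_medusa_buffers` to rebuild `tree_indices` and the attention mask at the smaller, fixed size.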
What's the intuition for why this should improve decoding speed?
Medusa may excel in cases where the probability density of the tree is heavily imbalanced, i.e., when an easy subsequence is coming up and the Medusa heads are confident in their predictions. But in other cases, the deeper Medusa heads are uncertain, and we're largely wasting computation verifying deep parts of the subtree with full attention.
It would be interesting to see how much the shape of the tree's probability density changes across different contexts. If it has a large variance, it seems like it would be valuable to have a more dynamic sparse tree that takes an optimal shape based on the current context.
Basically, this could allow us to
- crank up the number of Medusa heads/size of the tree
- only verify deep paths of the tree with full tree attention if there's a decent Medusa probability
- not pay for a massive tree in cases where deep Medusa heads are uncertain
#34