Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Qwen2VL vision model by precomputing cos/sin embeds before ViT blocks #35837

Merged
merged 4 commits into from
Feb 13, 2025

Conversation

li-plus
Copy link
Contributor

@li-plus li-plus commented Jan 22, 2025

What does this PR do?

The current implementation of Qwen2VL recomputes cos/sin embeddings and repeats them on each layer of the ViT model. This is computational inefficient. This PR attempts to precompute the rotary embeddings at the beginning of the ViT forward function to save computation and memory.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @amyeroberts @qubvel

@qubvel
Copy link
Member

qubvel commented Jan 22, 2025

cc @zucchini-nlp for vlms

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes total sense, thanks! One tiny nit, we'd need to keep backwards compatibility for the vision attention classes and add a small deprecation warning

As an example of how it was in llama:

if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(value_states, position_ids)
else:
cos, sin = position_embeddings

@li-plus li-plus force-pushed the fast-qwen2vl-vision-rope branch from 2e912f3 to 343e47e Compare January 24, 2025 02:01
@li-plus
Copy link
Contributor Author

li-plus commented Jan 24, 2025

I just rebased and the CI is failing on this check. Is it related to this PR?

...
No differences found for 
src/transformers/models/moonshine/modeling_moonshine.py.
Traceback (most recent call last):
  File "/root/transformers/utils/check_modular_conversion.py", line 80, in <module>
    raise ValueError("Some diff and their modeling code did not match.")
ValueError: Some diff and their modeling code did not match.

Exited with code exit status 1

@zucchini-nlp
Copy link
Member

@li-plus Right, you need to add the changes in a "modular_model.py" file and then running make fix-copies will copy everything to the actual "modeling_model.py". So you make CI green by adding the same changes to the modular_qwen2vl.py and then with make fix-copies

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for iterating. Left a small comment and we need to make CI green. Approving and after changes feel free to @ ArthurZucker , the core maintainer

For the core maintainer: not sure if we need to keep the signature in functions (apply_rotary_pos_emb_vision) for BC

@li-plus li-plus force-pushed the fast-qwen2vl-vision-rope branch from 343e47e to 64c2827 Compare January 25, 2025 13:57
@li-plus
Copy link
Contributor Author

li-plus commented Jan 25, 2025

I see. It seems that this change also affects the implementation of Qwen2.5VL since it reuses the modules in Qwen2VL. I need to get a Qwen2.5VL model to ensure it's bug-free.

@zucchini-nlp
Copy link
Member

@li-plus running make- fix-copies should fix it and copy content from Qwen2VL to 2.5VL

@li-plus li-plus force-pushed the fast-qwen2vl-vision-rope branch from 64c2827 to 8c22089 Compare January 28, 2025 03:31
@li-plus
Copy link
Contributor Author

li-plus commented Jan 28, 2025

Running make fix-copies ports changes to Qwen2.5VL incompatible with other modules in it. Fortunately just now Qwen2.5VL has been released so I have tested and fixed it. Now this PR is compatible for both Qwen2VL & Qwen2.5VL.

@li-plus
Copy link
Contributor Author

li-plus commented Jan 28, 2025

@ArthurZucker Hi, would you take a look at this PR? It's been approved and the CI turns green.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super super late review! Yeah LGTM and it is welcome!

Comment on lines 221 to 229
q = q.float()
k = k.float()
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
q = q.float()
k = k.float()
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)

would just like it to be a tad less verbose!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Already fixed.

Comment on lines +165 to +167
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need chunk here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because they use apply_rotary_emb provided by flash-attn when attn impl is FA2. It only needs half of the cos/sin embeddings, i.e. [seq_len, head_dim//2], while the naive apply_rotary_pos_emb_vision for eager/sdpa requires full embeds ([seq_len, head_dim]). Conventionally we pass full embeddings to this API so I have to slice it for FA2. Do you have an elegant and efficient implementation for this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope thanks for explaining!

@ArthurZucker
Copy link
Collaborator

Could you update to fix the conflicts? 🤗

@li-plus li-plus force-pushed the fast-qwen2vl-vision-rope branch from 8c22089 to 7990710 Compare February 13, 2025 03:11
@li-plus
Copy link
Contributor Author

li-plus commented Feb 13, 2025

Conflicts are just solved.

@li-plus li-plus force-pushed the fast-qwen2vl-vision-rope branch from 40d625d to 65b620d Compare February 13, 2025 03:25
Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Comment on lines +165 to +167
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope thanks for explaining!

@ArthurZucker ArthurZucker merged commit 5f0fd11 into huggingface:main Feb 13, 2025
16 checks passed
@li-plus li-plus deleted the fast-qwen2vl-vision-rope branch February 13, 2025 16:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants