
add self.head_dim for VisionAttention in Qwen2-VL #33211

Merged (14 commits) on Sep 6, 2024
src/transformers/models/qwen2_vl/modeling_qwen2_vl.py (1 addition, 0 deletions)

@@ -275,6 +275,7 @@ class VisionAttention(nn.Module):
     def __init__(self, dim: int, num_heads: int = 16) -> None:
         super().__init__()
         self.num_heads = num_heads
+        self.head_dim = dim // num_heads
Collaborator

Good, I am a bit baffled as to how this was not caught; the math.sqrt could not have run 😅
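
For context, the VisionAttention forward pass divides the attention scores by math.sqrt(self.head_dim), so without this attribute that line would raise an AttributeError the first time the vision tower runs. A minimal, self-contained sketch of the pattern (simplified from the actual forward, which also applies rotary position embeddings and an attention mask; the shapes and names follow the diff, the rest is illustrative):

```python
import math

import torch
from torch import nn


class VisionAttention(nn.Module):
    # Simplified sketch of the vision attention block; not the exact modeling code.
    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads  # the attribute added by this PR
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (seq_len, dim)
        seq_len = hidden_states.shape[0]
        q, k, v = (
            self.qkv(hidden_states)
            .reshape(seq_len, 3, self.num_heads, self.head_dim)
            .permute(1, 2, 0, 3)  # (3, num_heads, seq_len, head_dim)
            .unbind(0)
        )
        # Without self.head_dim set in __init__, this scaling line fails.
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = attn.softmax(dim=-1)
        out = torch.matmul(attn, v).transpose(0, 1).reshape(seq_len, -1)
        return self.proj(out)
```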

Collaborator

We have a few failing tests: https://github.com/huggingface/transformers/actions/runs/10656977518/job/29536379001#step:13:694 but this was not caught.

    @require_bitsandbytes
    def test_small_model_integration_test_batch_different_resolutions(self):
        model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", load_in_4bit=True)
>       text, vision_infos = self.processor.apply_chat_template(
            self.messages, tokenize=False, add_generation_prompt=True
        )
E       ValueError: too many values to unpack (expected 2)

This one needs to be updated.
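
For reference, the unpacking fails because apply_chat_template called with tokenize=False returns a single formatted prompt string rather than a (text, vision_infos) pair. A hedged sketch of how the failing call could be updated (the actual fix in this PR may look different):

```python
# apply_chat_template(tokenize=False) returns one prompt string, so there is nothing to unpack.
text = self.processor.apply_chat_template(
    self.messages, tokenize=False, add_generation_prompt=True
)
```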

Contributor Author

@GeLee-Q, Sep 2, 2024

@ArthurZucker Hello, I found the code related to vision_infos in the vision_process.py file in QwenLM. However, the Qwen2-VL processor in transformers does not have an interface to process vision_info, so I added a function to process this information.
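
A minimal sketch of the kind of helper described, assuming messages follow the chat-template format with nested content entries whose type is "image" or "video"; the function name and return structure are assumptions for illustration, not the code added in this PR:

```python
from typing import Any, Dict, List


def extract_vision_info(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Collect image/video entries from chat messages (hypothetical helper)."""
    vision_infos: List[Dict[str, Any]] = []
    for message in messages:
        content = message.get("content", [])
        if not isinstance(content, list):
            continue
        for item in content:
            if isinstance(item, dict) and item.get("type") in ("image", "video"):
                vision_infos.append(item)
    return vision_infos
```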

         self.qkv = nn.Linear(dim, dim * 3, bias=True)
         self.proj = nn.Linear(dim, dim)

tests/models/qwen2_vl/test_modeling_qwen2_vl.py (3 additions, 1 deletion)

@@ -164,7 +164,9 @@ def prepare_config_and_inputs_for_common(self):
         attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
         input_ids[:, torch.arange(vision_seqlen, device=torch_device) + 1] = self.image_token_id
         labels = torch.zeros(
-            (self.batch_size, self.seq_length - 1 + vision_seqlen), dtype=torch.long, device=torch_device
+            (self.batch_size, self.seq_length - 1 + vision_seqlen),
+            dtype=torch.long,
+            device=torch_device,
         )
         patch_size = self.vision_config["patch_size"]
         inputs_dict = {