diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 6ab813ad9ade..5e7919a95a7d 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -550,8 +550,13 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_multimodal_rotary_pos_emb( query_states, key_states, cos, sin, position_ids, self.rope_scaling["mrope_section"] @@ -632,10 +637,19 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = cache_position[-1] + rotary_seq_len = ( + max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len + ) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_multimodal_rotary_pos_emb( diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index ba3b30d94533..536e0ab54abc 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -27,8 +27,9 @@ is_vision_available, ) from transformers.testing_utils import ( - require_bitsandbytes, + require_flash_attn, require_torch, + require_torch_gpu, slow, torch_device, ) @@ -311,7 +312,7 @@ def setUp(self): ], } ] - url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg" + url = "https://qianwen-res.oss-accelerate-overseas.aliyuncs.com/Qwen2-VL/demo_small.jpg" self.image = Image.open(requests.get(url, stream=True).raw) def tearDown(self): @@ -319,11 +320,9 @@ def tearDown(self): torch.cuda.empty_cache() @slow - @require_bitsandbytes def test_small_model_integration_test(self): model = Qwen2VLForConditionalGeneration.from_pretrained( - "Qwen/Qwen2-VL-7B-Instruct", - load_in_4bit=True, + "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto" ) text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) @@ -334,23 +333,23 @@ def test_small_model_integration_test(self): expected_pixel_slice = torch.tensor( [ - [0.8501, 0.8647, 0.8647], - [1.0106, 1.0106, 1.0252], - [0.9960, 1.0106, 1.0252], - [1.0982, 1.1128, 1.1274], - [1.0836, 1.0982, 1.0982], - [1.1858, 1.1858, 1.1858], + [0.8792, 0.8792, 0.9084], + [1.1858, 1.1858, 1.2296], + [1.2004, 1.2004, 1.2150], + [1.4340, 1.4340, 1.4194], + [1.3902, 1.4048, 1.4194], + [1.5216, 1.5362, 1.5362], ], dtype=torch.float32, device="cpu", ) - assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=1e-3) + assert torch.allclose(expected_pixel_slice, inputs.pixel_values[:6, :3], atol=3e-3) # verify generation inputs = inputs.to(torch_device) output = model.generate(**inputs, max_new_tokens=30) - EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?assistant\nThe dog in the picture appears to be a Labrador Retriever or a similar breed. Labradors are known for their friendly and intelligent nature," + EXPECTED_DECODED_TEXT = "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets" self.assertEqual( self.processor.decode(output[0], skip_special_tokens=True), @@ -358,9 +357,10 @@ def test_small_model_integration_test(self): ) @slow - @require_bitsandbytes def test_small_model_integration_test_batch(self): - model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", load_in_4bit=True) + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to( torch_device @@ -370,78 +370,125 @@ def test_small_model_integration_test_batch(self): output = model.generate(**inputs, max_new_tokens=30) EXPECTED_DECODED_TEXT = [ - "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?assistant\nThe dog in the picture appears to be a Labrador Retriever or a similar breed. Labradors are known for their friendly and intelligent nature,", - "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?assistant\nThe dog in the image appears to be a Labrador Retriever or a similar breed. Labradors are known for their friendly and outgoing nature,", - ] + 'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular choices', + 'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets' + ] # fmt: skip self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT, ) - self.assertEqual( - self.processor.batch_decode(output, skip_special_tokens=True)[0], - self.processor.batch_decode(output, skip_special_tokens=True)[1], - ) @slow - @require_bitsandbytes def test_small_model_integration_test_batch_wo_image(self): - model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", load_in_4bit=True) + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) messages2 = [ {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "Who are you?"}, ] text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) - inputs = self.processor(text=[text, text2], images=[self.image], return_tensors="pt").to(torch_device) + inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + 'system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets', + 'system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to assist with various tasks and answer questions to the best of my' + ] # fmt: skip + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + def test_small_model_integration_test_batch_different_resolutions(self): + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto" + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + text2 = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + image2 = self.image.resize((224, 224)) + inputs = self.processor(text=[text, text2], images=[self.image, image2], padding=True, return_tensors="pt").to( + torch_device + ) # it should not matter whether two images are the same size or not output = model.generate(**inputs, max_new_tokens=30) EXPECTED_DECODED_TEXT = [ - "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?assistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and outgoing personalities, as well as their", - "system\nYou are a helpful assistant.user\nWho are you?assistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to assist with various tasks and answer a wide range of questions to", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + ] + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + + @slow + @require_flash_attn + @require_torch_gpu + def test_small_model_integration_test_batch_flashatt2(self): + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", ] self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT, ) + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True)[0], + self.processor.batch_decode(output, skip_special_tokens=True)[1], + ) @slow - @require_bitsandbytes - def test_small_model_integration_test_batch_different_resolutions(self): - model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", load_in_4bit=True) - text, vision_infos = self.processor.apply_chat_template( - self.messages, tokenize=False, add_generation_prompt=True + @require_flash_attn + @require_torch_gpu + def test_small_model_integration_test_batch_wo_image_flashatt2(self): + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + device_map="auto", ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) messages2 = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", - "resized_height": 504, - "resized_width": 252, - }, - {"type": "text", "text": "What kind of dog is this?"}, - ], - } + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who are you?"}, ] - text2, vision_infos2 = self.processor.apply_chat_template( - messages2, tokenize=False, add_generation_prompt=True + text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to( + torch_device ) - inputs = self.processor( - text=[text, text2], vision_infos=[vision_infos, vision_infos2], return_tensors="pt" - ).to(torch_device) # it should not matter whether two images are the same size or not output = model.generate(**inputs, max_new_tokens=30) EXPECTED_DECODED_TEXT = [ - "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?assistant\nThe dog in the picture appears to be a Labrador Retriever or a similar breed. Labradors are known for their friendly and intelligent nature,", - "system\nYou are a helpful assistant.\nuser\nWho are you?assistant\nI am a large language model created by Alibaba Cloud. I am called Qwen.", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to answer a wide range of questions and provide information on various topics", ] + self.assertEqual( self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT,