diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 48cf500b7454..6857fb624c0f 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1690,8 +1690,13 @@ def prepare_inputs_for_generation(
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
+
+            # If past_key_values are present, slice the position ids to keep only the unprocessed tokens.
             if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+                if inputs_embeds is not None and input_ids.shape[1] == 0:
+                    position_ids = position_ids[:, -inputs_embeds.shape[1] :]
+                else:
+                    position_ids = position_ids[:, -input_ids.shape[1] :]
 
             # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
             position_ids = position_ids.clone(memory_format=torch.contiguous_format)
diff --git a/tests/models/idefics/test_modeling_idefics.py b/tests/models/idefics/test_modeling_idefics.py
index 01871e81b30e..cc9efc967db2 100644
--- a/tests/models/idefics/test_modeling_idefics.py
+++ b/tests/models/idefics/test_modeling_idefics.py
@@ -755,6 +755,65 @@ def test_generate_without_input_ids(self):
         )
         self.assertIsNotNone(output_ids_generate)
 
+    @pytest.mark.generate
+    def test_generate_continue_from_inputs_embeds(self):
+        """Overwrite for IDEFICS: Ensure image attention mask is processed while continuing from `inputs_embeds`."""
+
+        for model_class in self.all_generative_model_classes:
+            config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
+            print(inputs)
+
+            model = model_class(config).to(torch_device).eval()
+
+            model.generation_config.pad_token_id = model.generation_config.eos_token_id = -1
+            model.generation_config.forced_eos_token_id = None
+            model.generation_config.use_cache = True
+
+            input_ids = inputs.pop("input_ids")
+            input_embeds = model.get_input_embeddings()(input_ids)
+
+            generation_kwargs = {
+                "return_dict_in_generate": True,
+                "do_sample": False,
+            }
+
+            inputs["inputs_embeds"] = input_embeds
+
+            # Traditional way of generating text, with `return_dict_in_generate` to return the past key values
+            outputs = model.generate(**inputs, max_new_tokens=4, **generation_kwargs)
+            # Let's generate again, but passing the past key values in between (3 + 1 = 4 tokens). Note that the
+            # inputs may need to be tweaked across `generate` calls (like the attention mask).
+            initial_output = model.generate(**inputs, max_new_tokens=3, **generation_kwargs)
+            inputs["past_key_values"] = initial_output.past_key_values
+
+            new_attention_len = input_ids.shape[1] + initial_output.sequences.shape[-1]
+            continued_embeds = torch.cat([input_embeds, model.get_input_embeddings()(initial_output.sequences)], dim=1)
+            inputs["inputs_embeds"] = continued_embeds
+
+            if "attention_mask" in inputs:
+                inputs["attention_mask"] = torch.nn.functional.pad(
+                    inputs["attention_mask"],
+                    (0, new_attention_len - inputs["attention_mask"].shape[1]),
+                    mode="constant",
+                    value=1,
+                )
+            if "image_attention_mask" in inputs:
+                inputs["image_attention_mask"] = inputs["image_attention_mask"][..., -1:, :]
+
+            cached_output = model.generate(**inputs, max_new_tokens=1, **generation_kwargs)
+
+            # Verify that the combined outputs match the full generation.
+            combined_output_sequences = torch.concat([initial_output.sequences, cached_output.sequences], axis=1)
+            self.assertListEqual(outputs.sequences.tolist(), combined_output_sequences.tolist())
+            for layer_idx in range(len(cached_output.past_key_values)):
+                for kv_idx in range(len(cached_output.past_key_values[layer_idx])):
+                    self.assertTrue(
+                        torch.allclose(
+                            outputs.past_key_values[layer_idx][kv_idx],
+                            cached_output.past_key_values[layer_idx][kv_idx],
+                        )
+                    )
+
     def _check_attentions_for_generate(
         self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
     ):
diff --git a/tests/models/moshi/test_modeling_moshi.py b/tests/models/moshi/test_modeling_moshi.py
index f51d0199156c..09278f0d24c4 100644
--- a/tests/models/moshi/test_modeling_moshi.py
+++ b/tests/models/moshi/test_modeling_moshi.py
@@ -358,7 +358,7 @@ def test_disk_offload_bin(self):
     def test_disk_offload_safetensors(self):
         pass
 
-    @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple modalities input.")
+    @unittest.skip(reason="Test becomes too complex with Moshi requiring multiple input modalities.")
     def test_generate_continue_from_inputs_embeds(self):
         pass
 
@@ -828,6 +828,7 @@ def test_generate_without_input_ids(self):
         output_ids_generate = model.generate(
             do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
        )
+        print(output_ids_generate)
         self.assertIsNotNone(output_ids_generate)
 
     @unittest.skip(reason="The audio encoder has no gradients.")
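For context, a small standalone sketch (not part of the patch) of what the `prepare_inputs_for_generation` change above does: when generation continues from `inputs_embeds` with a populated cache, `input_ids` can be empty, so the position ids derived from the attention mask are sliced by the embedding length instead of by `input_ids.shape[1]`. The tensor shapes and the `past_key_values = True` stand-in below are made up for illustration.

import torch

# Five positions already attended to: three cached tokens plus two new embedded tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1      # tensor([[0, 1, 2, 3, 4]])
position_ids.masked_fill_(attention_mask == 0, 1)

# Continuing from embeddings only: no new token ids, two new embedded positions.
input_ids = torch.empty((1, 0), dtype=torch.long)
inputs_embeds = torch.zeros((1, 2, 8))                   # hypothetical hidden size of 8
past_key_values = True                                    # stand-in for a non-empty cache

if past_key_values:
    if inputs_embeds is not None and input_ids.shape[1] == 0:
        # New branch from the patch: slice by the embedding length when input_ids is empty.
        position_ids = position_ids[:, -inputs_embeds.shape[1] :]
    else:
        position_ids = position_ids[:, -input_ids.shape[1] :]

print(position_ids)  # tensor([[3, 4]])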