diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index c4eb1eade6e2..2bef9544dfb7 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -1343,6 +1343,12 @@ def forward(
         if pixel_values is not None:
             image_tokens = self.get_image_tokens(pixel_values)
+            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
+            n_image_features = image_tokens.shape[0]
+            if n_image_tokens_in_text != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, features {n_image_features}"
+                )
             special_image_mask = input_ids == self.vocabulary_mapping.image_token_id
             image_tokens = image_tokens.to(input_ids.device, input_ids.dtype)
             input_ids = input_ids.masked_scatter(special_image_mask, image_tokens)
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 938ec4d5e423..d27263320d4e 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -1689,6 +1689,12 @@ def forward(
         if pixel_values is not None:
             pixel_values = pixel_values.type(self.visual.get_dtype())
             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+            n_image_features = image_embeds.shape[0]
+            if n_image_tokens != n_image_features:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )
             image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
@@ -1696,6 +1702,12 @@
         if pixel_values_videos is not None:
             pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+            n_video_features = video_embeds.shape[0]
+            if n_video_tokens != n_video_features:
+                raise ValueError(
+                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                )
             video_mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
             video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
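
For context, a minimal self-contained sketch of the pattern all three hunks guard: count the multimodal placeholder tokens in `input_ids`, compare against the number of feature rows the vision tower produced, and fail fast with an explicit error instead of letting `masked_scatter` silently misalign them. `merge_image_features` and `IMAGE_TOKEN_ID` are hypothetical names for illustration only; the real models read the placeholder id from their config (e.g. `config.image_token_id`) and run this inside `forward`.

```python
import torch

# Hypothetical placeholder id for this sketch; real models take it from config.
IMAGE_TOKEN_ID = 32000


def merge_image_features(input_ids, inputs_embeds, image_embeds):
    """Validate and splice image features into the text embedding sequence.

    input_ids:     (batch, seq_len) token ids containing IMAGE_TOKEN_ID placeholders
    inputs_embeds: (batch, seq_len, hidden) text embeddings
    image_embeds:  (n_features, hidden), one row per placeholder token
    """
    n_image_tokens = (input_ids == IMAGE_TOKEN_ID).sum().item()
    n_image_features = image_embeds.shape[0]
    if n_image_tokens != n_image_features:
        # The check added by this diff: counts must match exactly.
        raise ValueError(
            f"Image features and image tokens do not match: "
            f"tokens: {n_image_tokens}, features {n_image_features}"
        )
    # Broadcast the placeholder mask over the hidden dimension, then scatter
    # the feature rows (in order) into the masked positions.
    image_mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
    image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds.masked_scatter(image_mask, image_embeds)


ids = torch.tensor([[1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2]])
embeds = torch.zeros(1, 4, 8)
merge_image_features(ids, embeds, torch.ones(2, 8))  # ok: 2 tokens, 2 rows
merge_image_features(ids, embeds, torch.ones(3, 8))  # raises ValueError
```

Without the explicit check, a token/feature mismatch surfaces only as a silent misalignment or as an opaque shape error deep inside `masked_scatter`; the early `ValueError` names both counts, which makes bad preprocessing (e.g. a wrong number of placeholder tokens in the prompt) much easier to diagnose.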