diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index df87dec07582..fe9eb6ea8ab6 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -513,16 +513,15 @@ def forward(
         else:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index ddda3c0d43a5..52c37df57aa1 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -899,16 +899,15 @@ def forward(
         else:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index d732fcefdf75..aad21e28a6b2 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -981,29 +981,27 @@ def forward(
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
         if video_features is not None:
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
-            if n_video_tokens == n_video_features:
-                special_image_mask = (
-                    (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
-            else:
+            if n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
diff --git a/src/transformers/models/llava_next_video/modular_llava_next_video.py b/src/transformers/models/llava_next_video/modular_llava_next_video.py
index 5e7bf97761eb..9df11d6e688c 100644
--- a/src/transformers/models/llava_next_video/modular_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modular_llava_next_video.py
@@ -486,29 +486,27 @@ def forward(
         if image_features is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
         if video_features is not None:
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
-            if n_video_tokens == n_video_features:
-                special_image_mask = (
-                    (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
-            else:
+            if n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
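Every hunk in this patch applies the same mechanical refactor: instead of nesting the embedding splice under `if n_image_tokens == n_image_features:` and raising in an `else:`, the mismatch is rejected up front and the happy path is flattened by one indentation level. A minimal sketch of the pattern, with hypothetical names and counts standing in for the real model state:

```python
def splice_features(n_tokens: int, n_features: int) -> None:
    # Guard clause: fail fast on the mismatch so the success path
    # below is not indented under the comparison.
    if n_tokens != n_features:
        raise ValueError(f"tokens: {n_tokens}, features {n_features}")
    # ... happy path: build the placeholder mask and scatter the features ...
```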
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 4b0737023d1e..1bf1e9dfa984 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -623,19 +623,18 @@ def forward(
             )
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index)
-                    .unsqueeze(-1)
-                    .expand_as(inputs_embeds)
-                    .to(inputs_embeds.device)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index)
+                .unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         # Video are simply embedded and further pooled to decrease seq len
         if pixel_values_videos is not None:
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 80b9df25234c..2508dbf652a1 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -622,29 +622,27 @@ def forward(
         if image_outputs is not None:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
         if video_features is not None:
             n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
             n_video_features = video_features.shape[0]
-            if n_video_tokens == n_video_features:
-                special_image_mask = (
-                    (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
-            else:
+            if n_video_tokens != n_video_features:
                 raise ValueError(
                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 4fbcf8dd5ab3..e39ff4a76606 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -508,16 +508,15 @@ def forward(
         else:
             n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
             n_image_features = image_features.shape[0]
-            if n_image_tokens == n_image_features:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-            else:
+            if n_image_tokens != n_image_features:
                 raise ValueError(
                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                 )
+            special_image_mask = (
+                (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+            )
+            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

         outputs = self.language_model(
             attention_mask=attention_mask,
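The `llava_onevision` hunk is the one variant that also moves the mask with `.to(inputs_embeds.device)`: `input_ids` can still sit on CPU while `inputs_embeds` lives on an accelerator, and `masked_scatter` needs its mask on the same device as the tensor it writes into. A small sketch of that device handling, again with a toy placeholder id:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_TOKEN_INDEX = 32000                            # hypothetical placeholder id
input_ids = torch.tensor([[101, 32000, 102]])        # stays on CPU
inputs_embeds = torch.zeros(1, 3, 4, device=device)  # may live on the accelerator
image_features = torch.ones(1, 4)

special_image_mask = (
    (input_ids == IMAGE_TOKEN_INDEX)
    .unsqueeze(-1)
    .expand_as(inputs_embeds)       # expand_as only reads the shape, not the device
    .to(inputs_embeds.device)       # mask must match the device of inputs_embeds
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
```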
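The image and video branches in `llava_next_video`, its modular source, and `video_llava` differ only in the token index, the feature tensor, and the error wording. A hypothetical helper, not part of this PR, shows how the repeated guard-plus-scatter could be expressed once:

```python
import torch

def scatter_placeholder_features(
    inputs_embeds: torch.Tensor,  # (batch, seq_len, hidden)
    input_ids: torch.Tensor,      # (batch, seq_len)
    token_index: int,             # placeholder id (image or video token)
    features: torch.Tensor,       # (n_placeholders, hidden)
    kind: str,                    # "Image" or "Video", for the error message
) -> torch.Tensor:
    n_tokens = (input_ids == token_index).sum().item()
    n_features = features.shape[0]
    if n_tokens != n_features:
        raise ValueError(
            f"{kind} features and {kind.lower()} tokens do not match: "
            f"tokens: {n_tokens}, features {n_features}"
        )
    mask = (input_ids == token_index).unsqueeze(-1).expand_as(inputs_embeds)
    features = features.to(inputs_embeds.device, inputs_embeds.dtype)
    return inputs_embeds.masked_scatter(mask, features)
```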