[minor] add error handling to all llava models
laurentd-lunit committed Sep 30, 2024
1 parent 4ea3e0e commit 610bc75
Showing 6 changed files with 112 additions and 51 deletions.
17 changes: 12 additions & 5 deletions src/transformers/models/llava/modeling_llava.py
@@ -511,11 +511,18 @@ def forward(

             # TODO: @raushan retain only the new behavior after v4.47
             else:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0]
+                if n_image_tokens == n_image_features:
+                    special_image_mask = (
+                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    )
+                    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                else:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )

         outputs = self.language_model(
             attention_mask=attention_mask,
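The guard added here is the pattern repeated across every file in this commit: count the placeholder tokens in input_ids, compare that count to the number of projected feature vectors, and only scatter the features into the embeddings when the two agree. A minimal sketch of the mechanics with toy tensors — the token id, shapes, and values are illustrative, not taken from the commit:

import torch

IMAGE_TOKEN_INDEX = 32000  # illustrative placeholder id
hidden_size = 4

input_ids = torch.tensor([[1, IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 2]])  # 2 image tokens
inputs_embeds = torch.zeros(1, 4, hidden_size)                            # (batch, seq, hidden)
image_features = torch.ones(2, hidden_size)                               # (n_features, hidden)

n_image_tokens = (input_ids == IMAGE_TOKEN_INDEX).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )
# The mask marks every hidden dim of each placeholder position; masked_scatter
# fills those slots with the feature rows, in order.
special_image_mask = (input_ids == IMAGE_TOKEN_INDEX).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
print(inputs_embeds[0])  # rows 1 and 2 now hold the image features

The video branches in the hunks below are identical modulo config.video_token_index.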
@@ -979,17 +979,31 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_features is not None:
-                    special_image_mask = (
-                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                    )
-                    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens == n_image_features:
+                        special_image_mask = (
+                            (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        )
+                        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                    else:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                 if video_features is not None:
-                    special_image_mask = (
-                        (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                    )
-                    video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens == n_video_features:
+                        special_image_mask = (
+                            (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        )
+                        video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+                    else:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )

         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -484,17 +484,31 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_features is not None:
-                    special_image_mask = (
-                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                    )
-                    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens == n_image_features:
+                        special_image_mask = (
+                            (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        )
+                        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                    else:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
                 if video_features is not None:
-                    special_image_mask = (
-                        (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                    )
-                    video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens == n_video_features:
+                        special_image_mask = (
+                            (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        )
+                        video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+                    else:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )

         outputs = self.language_model(
             attention_mask=attention_mask,
@@ -621,15 +621,21 @@ def forward(
                 image_newline=self.image_newline,
                 vision_aspect_ratio=vision_aspect_ratio,
             )

-            special_image_mask = (
-                (input_ids == self.config.image_token_index)
-                .unsqueeze(-1)
-                .expand_as(inputs_embeds)
-                .to(inputs_embeds.device)
-            )
-            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+            n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+            n_image_features = image_features.shape[0]
+            if n_image_tokens == n_image_features:
+                special_image_mask = (
+                    (input_ids == self.config.image_token_index)
+                    .unsqueeze(-1)
+                    .expand_as(inputs_embeds)
+                    .to(inputs_embeds.device)
+                )
+                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+            else:
+                raise ValueError(
+                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                )

         # Video are simply embedded and further pooled to decrease seq len
         if pixel_values_videos is not None:
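In this hunk the check runs on features already flattened by pack_image_features, which concatenates each image's patch embeddings (and its image_newline separators) along dim 0; that is what makes image_features.shape[0] directly comparable to a placeholder-token count. A sketch of that invariant, with illustrative shapes rather than real model outputs:

import torch

hidden_size = 8
# Hypothetical per-image patch features before packing: 2 images with 3 and 5 patches.
per_image = [torch.randn(3, hidden_size), torch.randn(5, hidden_size)]

# Packing concatenates along dim 0: one feature row per placeholder token the prompt needs.
image_features = torch.cat(per_image, dim=0)  # (8, hidden)
n_image_features = image_features.shape[0]

assert n_image_features == sum(f.shape[0] for f in per_image)  # 8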
37 changes: 25 additions & 12 deletions src/transformers/models/video_llava/modeling_video_llava.py
@@ -620,18 +620,31 @@ def forward(
             # TODO: @raushan retain only the new behavior after v4.47
             else:
                 if image_outputs is not None:
-                    special_image_mask = (
-                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                    )
-                    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-                if video_outputs is not None:
-                    special_image_mask = (
-                        (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                    )
-                    video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+                    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                    n_image_features = image_features.shape[0]
+                    if n_image_tokens == n_image_features:
+                        special_image_mask = (
+                            (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        )
+                        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                    else:
+                        raise ValueError(
+                            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                        )
+                if video_features is not None:
+                    n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
+                    n_video_features = video_features.shape[0]
+                    if n_video_tokens == n_video_features:
+                        special_image_mask = (
+                            (input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                        )
+                        video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
+                    else:
+                        raise ValueError(
+                            f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
+                        )

         outputs = self.language_model(
             attention_mask=attention_mask,
17 changes: 12 additions & 5 deletions src/transformers/models/vipllava/modeling_vipllava.py
@@ -506,11 +506,18 @@ def forward(

             # TODO: @raushan retain only the new behavior after v4.47
             else:
-                special_image_mask = (
-                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
-                )
-                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
+                n_image_features = image_features.shape[0]
+                if n_image_tokens == n_image_features:
+                    special_image_mask = (
+                        (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+                    )
+                    image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+                    inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+                else:
+                    raise ValueError(
+                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
+                    )

         outputs = self.language_model(
             attention_mask=attention_mask,
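For callers, the practical effect of the commit is that a token/feature mismatch now fails fast with a descriptive ValueError instead of an opaque shape error deep inside masked_scatter. A usage sketch — the checkpoint id is a real llava-hf repo, the rest is illustrative, and no error is raised when processor and model are consistent:

from PIL import Image
import requests
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="USER: <image>\nDescribe the image. ASSISTANT:", images=image, return_tensors="pt")

try:
    outputs = model(**inputs)
except ValueError as e:
    # With this commit, a mismatch reports, e.g.:
    # "Image features and image tokens do not match: tokens: 576, features 575"
    print(e)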
