Skip to content

Commit

Permalink
[minor] avoid nested if else
Browse files Browse the repository at this point in the history
  • Loading branch information
laurentd-lunit committed Sep 30, 2024
1 parent 610bc75 commit 353c610
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 73 deletions.
13 changes: 6 additions & 7 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,16 +513,15 @@ def forward(
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,16 +899,15 @@ def forward(
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -981,29 +981,27 @@ def forward(
if image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens == n_video_features:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
else:
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,29 +486,27 @@ def forward(
if image_features is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens == n_video_features:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
else:
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -623,19 +623,18 @@ def forward(
)
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index)
.unsqueeze(-1)
.expand_as(inputs_embeds)
.to(inputs_embeds.device)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

# Video are simply embedded and further pooled to decrease seq len
if pixel_values_videos is not None:
Expand Down
26 changes: 12 additions & 14 deletions src/transformers/models/video_llava/modeling_video_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,29 +622,27 @@ def forward(
if image_outputs is not None:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
if video_features is not None:
n_video_tokens = (input_ids == self.config.video_token_index).sum().item()
n_video_features = video_features.shape[0]
if n_video_tokens == n_video_features:
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)
else:
if n_video_tokens != n_video_features:
raise ValueError(
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
)
special_image_mask = (
(input_ids == self.config.video_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, video_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down
13 changes: 6 additions & 7 deletions src/transformers/models/vipllava/modeling_vipllava.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,16 +508,15 @@ def forward(
else:
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]
if n_image_tokens == n_image_features:
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
else:
if n_image_tokens != n_image_features:
raise ValueError(
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
)
special_image_mask = (
(input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

outputs = self.language_model(
attention_mask=attention_mask,
Expand Down

0 comments on commit 353c610

Please sign in to comment.