Remove some Kosmos-2 "Copied from" comments (huggingface#27149)
* fix

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
ydshieh and ydshieh authored Oct 30, 2023
1 parent cd19b19 commit 3224c0c
Showing 1 changed file with 2 additions and 10 deletions.
12 changes: 2 additions & 10 deletions src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -52,7 +52,6 @@
]


-# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
@@ -67,7 +66,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


-# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
@@ -660,7 +658,7 @@ def forward(
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
-layer_outputs = self.gradient_checkpointing_func(
+layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
@@ -1114,7 +1112,6 @@ def __init__(self, config: Kosmos2TextConfig):

self.gradient_checkpointing = False

-# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1268,7 +1265,7 @@ def forward(
past_key_value = past_key_values[idx] if past_key_values is not None else None

if self.gradient_checkpointing and self.training:
-layer_outputs = self.gradient_checkpointing_func(
+layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
@@ -1428,11 +1425,6 @@ def _init_weights(self, module):
if module.embed_tokens.padding_idx is not None:
module.embed_tokens.weight.data[module.embed_tokens.padding_idx].zero_()

-def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
-    if isinstance(module, (Kosmos2TextTransformer, Kosmos2VisionEncoder)):
-        module.gradient_checkpointing_func = gradient_checkpointing_func
-        module.gradient_checkpointing = gradient_checkpointing_func is not None


class Kosmos2VisionModel(Kosmos2PreTrainedModel):
config_class = Kosmos2VisionConfig
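For context, a minimal usage sketch (not part of the commit): after this change the Kosmos-2 encoder and decoder call the private self._gradient_checkpointing_func, which the base PreTrainedModel is assumed to set up when checkpointing is enabled, so the model-specific _set_gradient_checkpointing override above could be dropped. The checkpoint name below is used only for illustration.

# Sketch under the assumptions above: enable gradient checkpointing on Kosmos-2
# through the standard PreTrainedModel API; no per-model hook is required.
from transformers import Kosmos2ForConditionalGeneration

model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224")
model.gradient_checkpointing_enable()  # base class wires up self._gradient_checkpointing_func
model.train()  # the checkpointing branch only runs while self.training is True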
