Fix model kwargs (#35875)
* Save state

* Make a failing test

* Better test

* mpt -> done, many more to go

* Rm extraneous

* Bamba

* Bert

* big_bird

* biogpt

* bloom

* codegen

* ctrl

* data2vec

* dbrx

* Through up to Dbrx

* electra

* ernie

* falcon

* Fuyu/persimmon

* Include noop kwargs to base models

* Rebase

* Skip musicgen

* Refactor/skip mllama

* Revert makefile

* Rm file

* Fix PT failing, need to modify rest of loss funcs to not resize

* Propagate some

* Continue

* More

* More options

* Mostly fixed

* Proved that it's the same

* Bloom is good

* Make ability to override loss func possible

* Fixup

* Clean

* Fix xglm

* Quality tests

* Skip OCR2

* Make specific loss for xglm

* Make order the same/line up 1:1

* xglm

* Skip fx output loss bloom model

* Didn't pass in pad_token_id

* Fix quality
muellerzr authored and ArthurZucker committed Feb 6, 2025
1 parent 093bebc commit 3d6e55c
Showing 48 changed files with 365 additions and 241 deletions.
7 changes: 7 additions & 0 deletions src/transformers/modeling_utils.py
@@ -5142,6 +5142,9 @@ def tplize(mod: torch.nn.Module) -> None:

@property
def loss_function(self):
+if hasattr(self, "_loss_function"):
+    return self._loss_function

loss_type = getattr(self, "loss_type", None)

if loss_type is None or loss_type not in LOSS_MAPPING:
@@ -5152,6 +5155,10 @@ def loss_function(self):
loss_type = "ForCausalLM"
return LOSS_MAPPING[loss_type]

+@loss_function.setter
+def loss_function(self, value):
+    self._loss_function = value

def get_compiled_call(self, compile_config: CompileConfig):
"""Return a `torch.compile`'d version of `self.__call__`. This is useful to dynamically choose between
non-compiled/compiled `forward` during inference, especially to switch between prefill (where we don't
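Because `loss_function` was previously a read-only property resolved through `LOSS_MAPPING`, assigning a custom loss to a model instance raised an `AttributeError`. The setter above makes per-instance overrides possible. A minimal sketch of what this enables, assuming the model already routes its loss through `self.loss_function` and that the callable accepts the `(logits, labels, vocab_size=..., num_items_in_batch=..., **kwargs)` arguments the heads below forward (the function and checkpoint here are illustrative, not part of the PR):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

def token_sum_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # Hypothetical replacement loss: next-token cross entropy summed over tokens,
    # normalized by num_items_in_batch when the trainer provides it.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    loss = torch.nn.functional.cross_entropy(
        shift_logits.float(), shift_labels, ignore_index=ignore_index, reduction="sum"
    )
    if num_items_in_batch is not None:
        loss = loss / num_items_in_batch
    return loss

model.loss_function = token_sum_loss  # stored in `_loss_function` and returned by the property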
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modeling_bamba.py
@@ -1200,6 +1200,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
1 change: 1 addition & 0 deletions src/transformers/models/bamba/modular_bamba.py
@@ -947,6 +947,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
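The `**kwargs, # NOOP kwargs, for now` additions to the base models exist so that the language-modeling heads can forward loss-related keyword arguments (for example `num_items_in_batch`) through their backbone call without raising a `TypeError`, even though the backbones do not use them yet. A toy illustration of the pattern with hypothetical classes (not transformers code):

import torch

class ToyBackbone(torch.nn.Module):
    # Stand-in for a base model such as BambaModel; the extra kwargs are accepted and ignored.
    def forward(self, inputs_embeds, **kwargs):  # NOOP kwargs, for now
        return inputs_embeds

class ToyForCausalLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyBackbone()
        self.lm_head = torch.nn.Linear(4, 10)

    def forward(self, inputs_embeds, labels=None, **kwargs):
        hidden = self.model(inputs_embeds, **kwargs)  # num_items_in_batch etc. ride along harmlessly
        return self.lm_head(hidden)

logits = ToyForCausalLM()(torch.randn(1, 3, 4), num_items_in_batch=torch.tensor(3))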
@@ -20,7 +20,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...generation import GenerationMixin
@@ -734,6 +733,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -901,6 +901,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -963,18 +964,20 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+**kwargs,
)

sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)

lm_loss = None
if labels is not None:
-# we are doing next-token prediction; shift prediction scores and input ids by one
-shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-labels = labels[:, 1:].contiguous()
-loss_fct = CrossEntropyLoss()
-lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+lm_loss = self.loss_function(
+    prediction_scores,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (prediction_scores,) + outputs[1:]
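The removed `CrossEntropyLoss` blocks are replaced by `self.loss_function`, which for these heads resolves through `LOSS_MAPPING` to the shared causal-LM loss; the label shift now happens inside that function, and `**kwargs` lets it receive `num_items_in_batch`. A simplified sketch of what such a function computes, assuming it mirrors the `ForCausalLMLoss`-style helper rather than reproducing it exactly:

import torch

def causal_lm_loss(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100, **kwargs):
    # The shift happens here, so callers pass unshifted logits and labels.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
    reduction = "sum" if num_items_in_batch is not None else "mean"
    loss = torch.nn.functional.cross_entropy(
        shift_logits.float(), shift_labels, ignore_index=ignore_index, reduction=reduction
    )
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss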
14 changes: 9 additions & 5 deletions src/transformers/models/big_bird/modeling_big_bird.py
@@ -1983,6 +1983,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple[torch.FloatTensor]]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2540,6 +2541,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.FloatTensor]]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -2580,18 +2582,20 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+**kwargs,
)

sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)

lm_loss = None
if labels is not None:
-# we are doing next-token prediction; shift prediction scores and input ids by one
-shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-labels = labels[:, 1:].contiguous()
-loss_fct = CrossEntropyLoss()
-lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+lm_loss = self.loss_function(
+    prediction_scores,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
13 changes: 8 additions & 5 deletions src/transformers/models/biogpt/modeling_biogpt.py
@@ -588,6 +588,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -757,6 +758,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -783,11 +785,12 @@

lm_loss = None
if labels is not None:
-# we are doing next-token prediction; shift prediction scores and input ids by one
-shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-labels = labels[:, 1:].contiguous()
-loss_fct = CrossEntropyLoss()
-lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+lm_loss = self.loss_function(
+    prediction_scores,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (prediction_scores,) + outputs[1:]
14 changes: 7 additions & 7 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -958,6 +958,8 @@ def forward(
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
"""
+# Bloom has deprecated kwargs, so we need to pop num_items_in_batch explicitly
+num_items_in_batch = deprecated_arguments.pop("num_items_in_batch", None)
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
@@ -990,14 +992,12 @@ def forward(
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
-# Shift so that tokens < n predict n
-shift_logits = lm_logits[..., :-1, :].contiguous()
-shift_labels = labels[..., 1:].contiguous()
-batch_size, seq_length, vocab_size = shift_logits.shape
-# Flatten the tokens
-loss_fct = CrossEntropyLoss()
-loss = loss_fct(
-    shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
+loss = self.loss_function(
+    lm_logits,
+    labels,
+    vocab_size=self.config.vocab_size,
+    num_items_in_batch=num_items_in_batch,
)

if not return_dict:
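Unlike the other models, Bloom funnels unknown keyword arguments into `deprecated_arguments`, so `num_items_in_batch` is popped out explicitly before the deprecation check instead of flowing through `**kwargs`. The value lets the loss be normalized once by the token count of the whole (possibly gradient-accumulated) batch; a small self-contained illustration of the mismatch this avoids (toy tensors, not the test added in the PR):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 8, 16)   # (batch, seq, vocab)
labels = torch.randint(0, 16, (4, 8))
labels[2:, 4:] = -100            # uneven padding across examples

def shift(lg, y):
    return lg[..., :-1, :].reshape(-1, 16), y[..., 1:].reshape(-1)

# Whole batch: mean over all unmasked tokens.
full = F.cross_entropy(*shift(logits, labels))

# Two micro-batches averaged afterwards give a different number, because each
# mean weights its own token count equally.
halves = [F.cross_entropy(*shift(logits[i:i + 2], labels[i:i + 2])) for i in (0, 2)]
naive = sum(halves) / 2

# Summing and dividing once by num_items_in_batch reproduces the full-batch value.
num_items = (labels[..., 1:] != -100).sum()
summed = sum(
    F.cross_entropy(*shift(logits[i:i + 2], labels[i:i + 2]), reduction="sum") for i in (0, 2)
)
print(full.item(), naive.item(), (summed / num_items).item())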
12 changes: 7 additions & 5 deletions src/transformers/models/camembert/modeling_camembert.py
@@ -1584,6 +1584,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1655,11 +1656,12 @@ def forward(
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(prediction_scores.device)
-# we are doing next-token prediction; shift prediction scores and input ids by one
-shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-labels = labels[:, 1:].contiguous()
-loss_fct = CrossEntropyLoss()
-lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+lm_loss = self.loss_function(
+    prediction_scores,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
14 changes: 8 additions & 6 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -19,7 +19,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
-from torch.nn import CrossEntropyLoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
@@ -450,6 +449,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -741,6 +741,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+**kwargs,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -775,12 +776,13 @@ def forward(
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
-# Shift so that tokens < n predict n
-shift_logits = lm_logits[..., :-1, :].contiguous()
-shift_labels = labels[..., 1:].contiguous()
-# Flatten the tokens
-loss_fct = CrossEntropyLoss()
-loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+loss = self.loss_function(
+    lm_logits,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

loss = loss.to(hidden_states.dtype)

14 changes: 8 additions & 6 deletions src/transformers/models/ctrl/modeling_ctrl.py
@@ -360,6 +360,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
r"""
Returns:
@@ -537,6 +538,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -593,12 +595,12 @@

loss = None
if labels is not None:
-# Shift so that tokens < n predict n
-shift_logits = lm_logits[..., :-1, :].contiguous()
-shift_labels = labels[..., 1:].contiguous()
-# Flatten the tokens
-loss_fct = CrossEntropyLoss()
-loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+loss = self.loss_function(
+    lm_logits,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
14 changes: 7 additions & 7 deletions src/transformers/models/data2vec/modeling_data2vec_text.py
@@ -906,6 +906,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -975,13 +976,12 @@

lm_loss = None
if labels is not None:
-# we are doing next-token prediction; shift prediction scores and input ids by one
-shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-labels = labels[:, 1:].contiguous()
-loss_fct = CrossEntropyLoss()
-
-labels = labels.to(shifted_prediction_scores.device)
-lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+lm_loss = self.loss_function(
+    prediction_scores,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
18 changes: 8 additions & 10 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -977,6 +977,7 @@ def forward(
output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
+**kwargs, # NOOP kwargs, for now
) -> Union[Tuple, MoeModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1274,6 +1275,7 @@ def forward(
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
+**kwargs,
) -> Union[Tuple, MoeCausalLMOutputWithPast]:
r"""Forward function for causal language modeling.
@@ -1337,16 +1339,12 @@ def forward(

loss = None
if labels is not None:
-# Shift so that tokens < n predict n
-shift_logits = logits[..., :-1, :].contiguous()
-shift_labels = labels[..., 1:].contiguous()
-# Flatten the tokens
-loss_fct = nn.CrossEntropyLoss()
-shift_logits = shift_logits.view(-1, self.config.vocab_size)
-shift_labels = shift_labels.view(-1)
-# Enable model parallelism
-shift_labels = shift_labels.to(shift_logits.device)
-loss = loss_fct(shift_logits, shift_labels)
+loss = self.loss_function(
+    logits,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

aux_loss = None
if output_router_logits:
12 changes: 7 additions & 5 deletions src/transformers/models/electra/modeling_electra.py
@@ -1564,6 +1564,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+**kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1633,11 +1634,12 @@

lm_loss = None
if labels is not None:
-# we are doing next-token prediction; shift prediction scores and input ids by one
-shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
-labels = labels[:, 1:].contiguous()
-loss_fct = CrossEntropyLoss()
-lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+lm_loss = self.loss_function(
+    prediction_scores,
+    labels,
+    vocab_size=self.config.vocab_size,
+    **kwargs,
+)

if not return_dict:
output = (prediction_scores,) + outputs[1:]