TypeError: ModernBertModel.forward() got an unexpected keyword argument 'num_items_in_batch' #36074

Open · Bachstelze opened this issue Feb 6, 2025 · 15 comments · May be fixed by #36095

@Bachstelze

System Info

Reopening: #35838

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

see #35838

Expected behavior

see #35838

Bachstelze added the bug label Feb 6, 2025
@Rocketknight1
Member

cc @muellerzr, but also @Bachstelze, have you tried installing the latest version from main to see if that fixes it?

@Bachstelze
Author

I installed transformers with these commands after the PR:

git clone https://github.com/huggingface/transformers.git
cd transformers
pip install .
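
To confirm that a source install like this actually replaced the release build, a quick generic check (run in the same environment) is:

import transformers
print(transformers.__version__)  # a version like "4.49.0.dev0" indicates an install from source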

@Rocketknight1
Member

Hmm, okay, seems like the issue still exists then!

@muellerzr
Contributor

@Bachstelze Can you please provide the full stack trace and code you're running for this instance?

@Bachstelze
Author

@Pappasad also mentions that TrOCR's forward fails the same way.
Here is the stack trace for ModernBERT as the encoder in an encoder-decoder model, from this Colab notebook:

TypeError                                 Traceback (most recent call last)

<ipython-input-9-a71da896f123> in <cell line: 0>()
    131     optimizers=(adam, lr_scheduler)
    132 )
--> 133 trainer.train()
    134 print("training finished", flush=True)
    135 #wandb.finish()

/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2183                 hf_hub_utils.enable_progress_bars()
   2184         else:
-> 2185             return inner_training_loop(
   2186                 args=args,
   2187                 resume_from_checkpoint=resume_from_checkpoint,

/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in _inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2489                     )
   2490                     with context():
-> 2491                         tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2492 
   2493                     if (

/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in training_step(self, model, inputs, num_items_in_batch)
   3608 
   3609         with self.compute_loss_context_manager():
-> 3610             loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3611 
   3612         del inputs

/usr/local/lib/python3.11/dist-packages/transformers/trainer.py in compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3669                 loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3670             inputs = {**inputs, **loss_kwargs}
-> 3671         outputs = model(**inputs)
   3672         # Save past state if it exists
   3673         # TODO: this needs to be fixed and made cleaner later.

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749     result = None

/usr/local/lib/python3.11/dist-packages/transformers/models/encoder_decoder/modeling_encoder_decoder.py in forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    601 
    602         if encoder_outputs is None:
--> 603             encoder_outputs = self.encoder(
    604                 input_ids=input_ids,
    605                 attention_mask=attention_mask,

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1734             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1735         else:
-> 1736             return self._call_impl(*args, **kwargs)
   1737 
   1738     # torchrec tests the code consistency with the following code

/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1745                 or _global_backward_pre_hooks or _global_backward_hooks
   1746                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747             return forward_call(*args, **kwargs)
   1748 
   1749     result = None

TypeError: ModernBertModel.forward() got an unexpected keyword argument 'num_items_in_batch'
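
For anyone needing an immediate stop-gap: the frames above show Trainer.compute_loss merging num_items_in_batch into inputs (trainer.py lines 3669-3671) and EncoderDecoderModel.forward passing its **kwargs straight through to the encoder, whose forward() does not accept them. A minimal workaround sketch, assuming (as the indentation in the trace suggests) the kwarg is only merged in when it is not None, is to override compute_loss so it is never handed down:

from transformers import Trainer

class NoLossKwargsTrainer(Trainer):
    # Workaround sketch, not the fix from the linked PR: drop num_items_in_batch
    # so the `inputs = {**inputs, **loss_kwargs}` merge seen in the trace is a no-op.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

The trade-off: without num_items_in_batch, loss normalization under gradient accumulation falls back to the old per-step averaging.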

@muellerzr
Contributor

@Bachstelze can you try via pip install git+https://github.com/huggingface/transformers@muellerzr-more-models-sadface

muellerzr linked a pull request Feb 7, 2025 that will close this issue
@Bachstelze
Author

Still the same error.

@Bachstelze
Author

@muellerzr don't make a sad face, take your time and enjoy your weekend! ;)

@wagpa

wagpa commented Feb 16, 2025

@muellerzr I have a similar issue with Donut

DonutSwinModel.forward() got an unexpected keyword argument 'num_items_in_batch'

Looking at the linked PR, does it also fix the Donut model?

I'm quite new to this, but the changes in the PR are in encoder_decoder/modeling_encoder_decoder.py, while Donut seems to use vision_encoder_decoder/modeling_vision_encoder_decoder.py.
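
One way to check which modeling file a checkpoint actually routes through (and therefore whether the PR touches it) is to print the module path of the loaded classes; a small sketch, assuming the standard Donut base checkpoint:

from transformers import VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
print(type(model).__module__)          # transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder
print(type(model.encoder).__module__)  # transformers.models.donut.modeling_donut_swin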

@bluestealth

bluestealth commented Feb 23, 2025

@muellerzr I also hit a similar issue with DistilBert when upgrading from 4.46.3 -> 4.49.0:

def _call_impl(self, *args, **kwargs):
        forward_call = (self._slow_forward if torch._C._get_tracing_state() else self.forward)
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
                or _global_backward_pre_hooks or _global_backward_hooks
                or _global_forward_hooks or _global_forward_pre_hooks):
>           return forward_call(*args, **kwargs)
E           TypeError: DistilBertForSequenceClassification.forward() got an unexpected keyword argument 'num_items_in_batch'

../pip_deps_torch/site-packages/torch/nn/modules/module.py:1520: TypeError
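
Where the top-level model itself rejects the kwarg, as in this DistilBert case, another stop-gap (again only a sketch, unrelated to the linked PR) is to strip it in a small forward wrapper before handing the model to Trainer:

import functools

def drop_num_items_in_batch(model):
    # Shim: discard the Trainer-injected kwarg before it reaches a forward()
    # signature that does not accept it.
    original_forward = model.forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        kwargs.pop("num_items_in_batch", None)
        return original_forward(*args, **kwargs)

    model.forward = forward
    return model

Call model = drop_num_items_in_batch(model) once after loading and train as usual; pinning to 4.46.3 also sidesteps the problem, at the cost of the other 4.49.0 changes.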

@milanalimova

milanalimova commented Feb 25, 2025

@muellerzr I get the error ViTModel.forward() got an unexpected keyword argument 'num_items_in_batch' with TrOCR even using your approach:

> @Bachstelze can you try via pip install git+https://github.com/huggingface/transformers@muellerzr-more-models-sadface

@zheka77111

zheka77111 commented Feb 25, 2025

I got the same error after updating my transformers library.

@milanalimova

> I got the same error after updating my transformers library.

@zheka77111 What was the previous version you used? Did that version work for you?

@LinYan-lab

Same error, DeiTModel.forward() got an unexpected keyword argument 'num_items_in_batch', with transformers == 4.49.0.dev0.


@zheka77111

> I got the same error after updating my transformers library.
>
> @zheka77111 What was the previous version you used? Did that version work for you?

4.36, but I have now rolled back to it and I get the same error.
