Fix the bug that Trainer cannot correctly call `torch_jit_model_eval` (#35722)

Fix the bug that `accelerator.autocast` does not pass parameters correctly when calling `torch_jit_model_eval` (#35706)
Wanguy authored Jan 16, 2025
1 parent 2cbcc58 commit 8b78d9d
Showing 1 changed file with 3 additions and 1 deletion.
src/transformers/trainer.py (3 additions, 1 deletion)
@@ -230,6 +230,7 @@
 from accelerate import __version__ as accelerate_version
 from accelerate.state import AcceleratorState
 from accelerate.utils import (
+    AutocastKwargs,
     DistributedDataParallelKwargs,
     DistributedType,
     load_fsdp_model,
@@ -1832,7 +1833,8 @@ def torch_jit_model_eval(self, model, dataloader, training=False):
                 # remove mixed precision hooks from the model
                 if original_forward:
                     jit_model.forward = original_forward
-                with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
+                autocast_handler = AutocastKwargs(cache_enabled=False)
+                with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
                     if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
                         if isinstance(example_batch, dict):
                             jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
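For context, here is a minimal sketch of the corrected call pattern outside of Trainer, assuming a recent accelerate release where Accelerator.autocast() accepts the AutocastKwargs handler rather than a bare cache_enabled keyword. The model and batch below are illustrative, not taken from the patch:

# Minimal sketch, not the Trainer code itself; model and batch are made up.
import torch
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

accelerator = Accelerator()
model = torch.nn.Linear(4, 4).eval()

# Wrap cache_enabled in AutocastKwargs instead of passing it directly,
# mirroring the fix in this commit.
autocast_handler = AutocastKwargs(cache_enabled=False)
with accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
    # PyTorch >= 2.0 can trace with keyword inputs, as the patched code does.
    example_batch = {"input": torch.randn(2, 4)}
    jit_model = torch.jit.trace(model, example_kwarg_inputs=example_batch, strict=False)

Passing cache_enabled=False directly, as the removed line did, is presumably the failure mode here: accelerate versions that dropped the deprecated keyword reject it, so the handler object is the supported way to forward the setting.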
