diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f7108b12b904..9263486ddddf 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -230,6 +230,7 @@
 from accelerate import __version__ as accelerate_version
 from accelerate.state import AcceleratorState
 from accelerate.utils import (
+    AutocastKwargs,
     DistributedDataParallelKwargs,
     DistributedType,
     load_fsdp_model,
@@ -1832,7 +1833,8 @@ def torch_jit_model_eval(self, model, dataloader, training=False):
                 # remove mixed precision hooks from the model
                 if original_forward:
                     jit_model.forward = original_forward
-                with self.accelerator.autocast(cache_enabled=False), torch.no_grad():
+                autocast_handler = AutocastKwargs(cache_enabled=False)
+                with self.accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
                     if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"):
                         if isinstance(example_batch, dict):
                             jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False)
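
Note on the change: Accelerate now routes `torch.autocast` options through an `AutocastKwargs` handler instead of accepting `cache_enabled` directly on `Accelerator.autocast`, which is why the call site is rewritten above. A minimal standalone sketch of the new pattern outside `Trainer` (the toy model and input shapes are placeholders, not from the diff):

```python
import torch
from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

accelerator = Accelerator()  # mixed precision, if any, comes from the launch config

# Placeholder model, moved to the accelerator's device by prepare().
model = accelerator.prepare(torch.nn.Linear(4, 4))

# Disable the autocast weight cache before tracing; this handler replaces
# the removed `cache_enabled=False` keyword argument on accelerator.autocast.
autocast_handler = AutocastKwargs(cache_enabled=False)

with accelerator.autocast(autocast_handler=autocast_handler), torch.no_grad():
    traced = torch.jit.trace(model, torch.randn(2, 4, device=accelerator.device))
```

`AutocastKwargs` also exposes an `enabled` field, so the same handler can turn autocast off entirely rather than only disabling its cache.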