diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index df4a834b370c..9172f9599f77 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1107,8 +1107,9 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): self.setup(args, state, model) if state.is_world_process_zero: if self._experiment is not None: + rewritten_logs = rewrite_logs(logs) self._experiment.__internal_api__log_metrics__( - logs, step=state.global_step, epoch=state.epoch, framework="transformers" + rewritten_logs, step=state.global_step, epoch=state.epoch, framework="transformers" ) def on_train_end(self, args, state, control, **kwargs): @@ -1125,6 +1126,15 @@ def on_train_end(self, args, state, control, **kwargs): self._experiment.clean() self._initialized = False + def on_predict(self, args, state, control, metrics, **kwargs): + if not self._initialized: + self.setup(args, state, model=None) + if state.is_world_process_zero and self._experiment is not None: + rewritten_metrics = rewrite_logs(metrics) + self._experiment.__internal_api__log_metrics__( + rewritten_metrics, step=state.global_step, epoch=state.epoch, framework="transformers" + ) + class AzureMLCallback(TrainerCallback): """