From 0a5fd5e99a924ad4fd20004e95dcca8917bcb974 Mon Sep 17 00:00:00 2001 From: Boris Feld Date: Tue, 27 Aug 2024 13:40:53 +0200 Subject: [PATCH] Log additional test metrics with the CometCallback (#33124) * Log additional test metrics with the CometCallback. Also follow the same metric naming convention as other callbacks * Merge 2 subsequent if-statements * Trigger Build --------- Co-authored-by: Aliaksandr Kuzmik --- src/transformers/integrations/integration_utils.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index df4a834b370c..9172f9599f77 100755 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1107,8 +1107,9 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs): self.setup(args, state, model) if state.is_world_process_zero: if self._experiment is not None: + rewritten_logs = rewrite_logs(logs) self._experiment.__internal_api__log_metrics__( - logs, step=state.global_step, epoch=state.epoch, framework="transformers" + rewritten_logs, step=state.global_step, epoch=state.epoch, framework="transformers" ) def on_train_end(self, args, state, control, **kwargs): @@ -1125,6 +1126,15 @@ def on_train_end(self, args, state, control, **kwargs): self._experiment.clean() self._initialized = False + def on_predict(self, args, state, control, metrics, **kwargs): + if not self._initialized: + self.setup(args, state, model=None) + if state.is_world_process_zero and self._experiment is not None: + rewritten_metrics = rewrite_logs(metrics) + self._experiment.__internal_api__log_metrics__( + rewritten_metrics, step=state.global_step, epoch=state.epoch, framework="transformers" + ) + class AzureMLCallback(TrainerCallback): """