Skip to content

Commit

Permalink
Log additional test metrics with the CometCallback (huggingface#33124)
Browse files Browse the repository at this point in the history
* Log additional test metrics with the CometCallback.

Also follow the same metric naming convention as other callbacks

* Merge 2 subsequent if-statements

* Trigger Build

---------

Co-authored-by: Aliaksandr Kuzmik <alexander.kuzmik99@gmail.com>
  • Loading branch information
2 people authored and BernardZach committed Dec 5, 2024
1 parent 0444a03 commit 0a5fd5e
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1107,8 +1107,9 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
self.setup(args, state, model)
if state.is_world_process_zero:
if self._experiment is not None:
rewritten_logs = rewrite_logs(logs)
self._experiment.__internal_api__log_metrics__(
logs, step=state.global_step, epoch=state.epoch, framework="transformers"
rewritten_logs, step=state.global_step, epoch=state.epoch, framework="transformers"
)

def on_train_end(self, args, state, control, **kwargs):
Expand All @@ -1125,6 +1126,15 @@ def on_train_end(self, args, state, control, **kwargs):
self._experiment.clean()
self._initialized = False

def on_predict(self, args, state, control, metrics, **kwargs):
if not self._initialized:
self.setup(args, state, model=None)
if state.is_world_process_zero and self._experiment is not None:
rewritten_metrics = rewrite_logs(metrics)
self._experiment.__internal_api__log_metrics__(
rewritten_metrics, step=state.global_step, epoch=state.epoch, framework="transformers"
)


class AzureMLCallback(TrainerCallback):
"""
Expand Down

0 comments on commit 0a5fd5e

Please sign in to comment.