Skip to content

Commit

Permalink
Use Sparsification Group logger (#1936)
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jan 11, 2024
1 parent c6b67b4 commit e6002e2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/sparseml/core/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -821,12 +821,16 @@ def __init__(
mode: LoggingModeType = "exact",
frequency_type: FrequencyType = "epoch",
):
self._loggers = loggers or []
self._name = name
if log_python and not any(
isinstance(log, PythonLogger) for log in self._loggers
):
self._loggers.append(PythonLogger())
self._loggers = (
loggers
or SparsificationGroupLogger(
python=log_python,
name=name,
tensorboard=True,
wandb_=True,
).loggers
)

self.frequency_manager = FrequencyManager(
mode=mode,
Expand Down
1 change: 1 addition & 0 deletions src/sparseml/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ def _log_loss(self, event_type: EventType, loss: Any):
return

current_step = self._lifecycle.event_lifecycle.current_index
loss = loss if isinstance(loss, dict) else {"loss": loss}
self.state.loggers.metric.log_scalars(
tag="Loss", values=loss, step=current_step
)
Expand Down

0 comments on commit e6002e2

Please sign in to comment.