Skip to content

Commit

Permalink
Add logs for loss
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jan 11, 2024
1 parent df53768 commit 3875046
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
5 changes: 3 additions & 2 deletions src/sparseml/core/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,8 @@ class LoggerManager(ABC):
:param log_frequency: number of steps or fraction of steps to wait between logs
:param mode: The logging mode to use, either "on_change" or "exact",
"on_change" will log when the model has been updated since the last log,
"exact" will log at the given frequency regardless of model updates
"exact" will log at the given frequency regardless of model updates.
Defaults to "exact"
:param frequency_type: The frequency type to use, either "epoch", "step", or "batch"
controls what the frequency manager is tracking, e.g. if the frequency type
is "epoch", then the frequency manager will track the number of epochs that
Expand All @@ -817,7 +818,7 @@ def __init__(
log_frequency: Union[float, None] = 0.1,
log_python: bool = True,
name: str = "manager",
mode: LoggingModeType = "on_change",
mode: LoggingModeType = "exact",
frequency_type: FrequencyType = "epoch",
):
self._loggers = loggers or []
Expand Down
31 changes: 27 additions & 4 deletions src/sparseml/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class SparseSession:

def __init__(self):
    """Initialize the session with a fresh sparsification lifecycle."""
    # Tracks whether the current step's loss has already been logged;
    # reset elsewhere when the logging cadence advances.
    self._loss_logged = False
    self._lifecycle = SparsificationLifecycle()

@property
def lifecycle(self) -> SparsificationLifecycle:
Expand Down Expand Up @@ -220,6 +221,10 @@ def finalize(self, **kwargs) -> ModifiedState:
:param kwargs: additional kwargs to pass to the lifecycle's finalize method
:return: the modified state of the session after finalizing
"""
# log losses on finalization
self._loss_logged = False
self._log_loss(event_type=EventType.LOSS_CALCULATED, loss=self.state.loss)

mod_data = self._lifecycle.finalize(**kwargs)

return ModifiedState(
Expand Down Expand Up @@ -261,7 +266,14 @@ def event(
mod_data = self._lifecycle.event(
event_type=event_type, batch_data=batch_data, loss=loss, **kwargs
)

# Update loss
if loss is not None:
self.state.loss = loss

self._log_model_info(event_type=event_type)
self._log_loss(event_type=event_type, loss=loss)

return ModifiedState(
model=self.state.model.model if self.state.model else None,
optimizer=self.state.optimizer.optimizer if self.state.optimizer else None,
Expand Down Expand Up @@ -294,20 +306,31 @@ def _log_model_info(self, event_type: EventType):

epoch = self._lifecycle.event_lifecycle.current_index

# override logging cadence temporarily

if should_log_model_info(
model=self.state.model,
loggers=self.state.loggers,
epoch=epoch,
last_log_epoch=self.state._last_log_epoch,
):
log_model_info(
state=self.state,
epoch=epoch,
)
# update last log epoch
self.state._last_log_epoch = epoch
self.state.loggers.log_written(epoch)

# loss was not logged for this cadence
# reset flag
self._loss_logged = False

def _log_loss(self, event_type: EventType, loss: Any):
    """Log the given loss as a scalar metric, at most once per cadence.

    :param event_type: the type of event that triggered this call; only
        ``EventType.LOSS_CALCULATED`` events are logged
    :param loss: the loss value(s) to record via the metric loggers
    """
    # Skip anything that is not a loss-calculation event, and avoid
    # double-logging when the loss was already recorded for this cadence
    if event_type == EventType.LOSS_CALCULATED and not self._loss_logged:
        step = self._lifecycle.event_lifecycle.current_index
        self.state.loggers.metric.log_scalars(
            tag="Loss", values=loss, step=step
        )
        self._loss_logged = True


_global_session = SparseSession()
Expand Down

0 comments on commit 3875046

Please sign in to comment.