diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index 445fbbff904..e3214f0240a 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -271,8 +271,7 @@ def event( if loss is not None: self.state.loss = loss - self._log_model_info(event_type=event_type) - self._log_loss(event_type=event_type, loss=loss) + self.log(event_type=event_type, loss=loss) return ModifiedState( model=self.state.model.model if self.state.model else None, @@ -281,6 +280,16 @@ def event( modifier_data=mod_data, ) + def log(self, event_type: EventType, loss: Optional[Any] = None): + """ + Log model and loss information for the current event type + + :param event_type: the event type to log for + :param loss: the loss to log if any + """ + self._log_model_info() + self._log_loss(event_type=event_type, loss=loss) + def reset(self): """ Reset the session to its initial state @@ -301,8 +310,8 @@ def get_serialized_recipe(self) -> str: recipe = self.lifecycle.recipe_container.compiled_recipe return recipe.yaml() - def _log_model_info(self, event_type: EventType): - # Log model level logs if needed + def _log_model_info(self): + # Log model level logs if cadence reached epoch = self._lifecycle.event_lifecycle.current_index diff --git a/tests/sparseml/core/test_session.py b/tests/sparseml/core/test_session.py index 66f875357a8..9938c003d26 100644 --- a/tests/sparseml/core/test_session.py +++ b/tests/sparseml/core/test_session.py @@ -167,6 +167,24 @@ def test_apply_calls_lifecycle_initialize_and_finalize( lifecycle_mock._hit_count["finalize"] == 1 ), "apply did not invoke the lifecycle finalize method" + def test_log_methods_called_same_number_of_times( + self, mocker, monkeypatch, setup_active_session + ): + mock_log_model_info = mocker.patch.object( + setup_active_session, "_log_model_info" + ) + mock_log_loss = mocker.patch.object(setup_active_session, "_log_loss") + monkeypatch.setattr(setup_active_session, "_lifecycle", LifeCycleMock()) + + event_type = EventType.BATCH_START + batch_data = None + loss = None + + setup_active_session.initialize(framework=Framework.pytorch) + setup_active_session.event(event_type, batch_data, loss) + + assert mock_log_model_info.call_count == mock_log_loss.call_count + @pytest.mark.parametrize( "attribute_name",