Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli committed Jan 11, 2024
1 parent 3875046 commit c6b67b4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/sparseml/core/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,7 @@ def event(
if loss is not None:
self.state.loss = loss

self._log_model_info(event_type=event_type)
self._log_loss(event_type=event_type, loss=loss)
self.log(event_type=event_type, loss=loss)

return ModifiedState(
model=self.state.model.model if self.state.model else None,
Expand All @@ -281,6 +280,16 @@ def event(
modifier_data=mod_data,
)

def log(self, event_type: EventType, loss: Optional[Any] = None):
"""
Log model and loss information for the current event type
:param event_type: the event type to log for
:param loss: the loss to log if any
"""
self._log_model_info()
self._log_loss(event_type=event_type, loss=loss)

def reset(self):
"""
Reset the session to its initial state
Expand All @@ -301,8 +310,8 @@ def get_serialized_recipe(self) -> str:
recipe = self.lifecycle.recipe_container.compiled_recipe
return recipe.yaml()

def _log_model_info(self, event_type: EventType):
# Log model level logs if needed
def _log_model_info(self):
# Log model level logs if cadence reached

epoch = self._lifecycle.event_lifecycle.current_index

Expand Down
18 changes: 18 additions & 0 deletions tests/sparseml/core/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,24 @@ def test_apply_calls_lifecycle_initialize_and_finalize(
lifecycle_mock._hit_count["finalize"] == 1
), "apply did not invoke the lifecycle finalize method"

def test_log_methods_called_same_number_of_times(
self, mocker, monkeypatch, setup_active_session
):
mock_log_model_info = mocker.patch.object(
setup_active_session, "_log_model_info"
)
mock_log_loss = mocker.patch.object(setup_active_session, "_log_loss")
monkeypatch.setattr(setup_active_session, "_lifecycle", LifeCycleMock())

event_type = EventType.BATCH_START
batch_data = None
loss = None

setup_active_session.initialize(framework=Framework.pytorch)
setup_active_session.event(event_type, batch_data, loss)

assert mock_log_model_info.call_count == mock_log_loss.call_count


@pytest.mark.parametrize(
"attribute_name",
Expand Down

0 comments on commit c6b67b4

Please sign in to comment.