[Feature Branch] Logger Framework (#1860)
* parent 5007b8c
author Rahul Tuli <rahul@neuralmagic.com> 1702308513 -0500
committer rahul-tuli <rahul@neuralmagic.com> 1704999325 -0500

[Feature] Add Model Level Logs to new Framework (#1783)

* Empty commit

* - Add loggable_items to the ModifiableModel contract
- Add PyTorch implementation of loggable_items
- Attach logger to state
- Make changes in lifecycle convenience functions to include loggers
- Add log_model_info method in Modifier which is called on
each `update_event` but logs only on epoch end

* - Move log call at the end of update_event
- Remove extra space

* - Move model level logging responsibility to
`ModelLoggingMixin_`

* - Style

* - Update docstring

* - Move mixin to its own file

* - Add test

* - update logic to check epoch end

* - Add more tests

* - Log model-level logs at session level
- Filter logs to only include percentages
- Filter logs to only include params with non-zero sparsity
- Filter logs to include only quantized modules
- Move LoggerMixin functions to helpers.py

* - Expose model log cadence
- Propagate cadence via session to state and loggers
- log model info in session
- update log condition in logger.py

* - Migrate tests

* Fix epoch number logging

* - fix failing test

* Address review comments!

* Remove event_type from should_log_model_info

* Update docstring

* Style

* mock _log_model_info

* Move to using log_scalar over log_string
Do not log at last_log_epoch; log only once the current epoch has reached the specified cadence
Remove model log cadence
Style
Update tests

`[Feature][Logger Framework]` Add convenience methods for logging strings (#1855)

* Adds convenience methods to the logger manager
to better conform to Python's `logging.Logger`

[Feature] Info logs to console and debug logs to file (#1861)

* Move FileHandler creation above StreamHandler creation

* Remove missed comment
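
The #1861 change above routes informational output to the console while debug detail goes to a local file (hence the sparse_logs/* entry added to .gitignore below). A minimal sketch of that split using Python's standard logging module; the logger name and file path here are illustrative assumptions, not the actual SparseML setup:

# Illustrative sketch only: INFO and above to the console, DEBUG and above
# to a local file, with the FileHandler created before the StreamHandler
# as in the commit above. Logger name and path are assumptions.
import logging
import os

os.makedirs("sparse_logs", exist_ok=True)

logger = logging.getLogger("sparseml")
logger.setLevel(logging.DEBUG)  # let each handler filter per destination

file_handler = logging.FileHandler("sparse_logs/debug.log")
file_handler.setLevel(logging.DEBUG)

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)

logger.addHandler(file_handler)
logger.addHandler(console_handler)

logger.debug("written to sparse_logs/debug.log only")
logger.info("written to both the console and the file")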

[Move] logger into its own package (#1924)

improves structure and separation of concerns

Add local logs directory to .gitignore (#1925)

[Add] frequency_manager to check log "readiness" (#1927)

* Add frequency_manager to check log "readiness"

* Refactor condition to be more readable

[Logger Framework][Add] mode and frequency type to FrequencyManager (#1930)

* Add mode and frequency type to FrequencyManager

* move tests cases near the test

* Style

* [Use] frequency manager in logger (#1931)

* Make log_frequency a property of FrequencyManager
Add FrequencyManager to LoggerManager
Add log_written and model_updated to LoggerManager (the readiness idea is sketched after this list)

* [Set] [OLD MODIFIERS]
_last_log_epoch to None

* Update old modifiers to use named arguments to log_ready

* [Update][Rename]
* `PossibleLoggingMode` --> `LoggingModeType`
* `PossibleFrequencyType` --> `FrequencyType`

* Fix log format when step is None (#1933)

* [Add] Wrap System and Metric Logging into their own classes (#1932)

* Add __repr__
Move string logging functions to SystemLoggingWrapper
Move metric logging functions to MetricLoggingWrapper

* minor fixes

* Remove warnings

* Add back old logger

* [Logger Refactor] Session based Logs (#1920)

* Add __repr__
Move string logging functions to SystemLoggingWrapper
Move metric logging functions to MetricLoggingWrapper

* minor fixes

* Remove warnings

* Add back old logger

* Add session level logs

* [Logger Framework] Log Losses (#1934)

* Add __repr__
Move string logging functions to SystemLoggingWrapper
Move metric logging functions to MetricLoggingWrapper

* minor fixes

* Remove warnings

* Add back old logger

* Add logs for loss

* Add test

* Use Sparsification Group logger (#1936)

* Save last log step
Add a stateless log_ready method

* log only when loss is not None

* Add TensorBoard-style logging functions (#1919)

* Add wandb log function (#1918)

* [Test Update] Fix failing tests after adding back old logger (#1963)

* [Logger Refactor] Add timer in logger manager (#1967)

* Add timer in logger manager

* Address review comments

* update docstring

* fix argument
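
Several of the changes above (#1927, #1930, #1931) gate logging on a cadence: a FrequencyManager tracks the last step a log was written, and a log_ready check decides whether the current step warrants another write. A rough sketch of the idea follows; the class and method names below are assumptions for illustration, not the actual SparseML API:

# Illustrative sketch only (assumed names, not the real FrequencyManager).
from typing import Optional


class SketchFrequencyManager:
    def __init__(self, log_frequency: Optional[float] = 0.1):
        self.log_frequency = log_frequency  # e.g. log every 0.1 epochs
        self.last_log_step: Optional[float] = None

    def log_ready(self, current_log_step: Optional[float]) -> bool:
        # Ready when logging is enabled and at least one full cadence has
        # elapsed since the last write (or nothing has been written yet)
        return (
            self.log_frequency is not None
            and current_log_step is not None
            and (
                self.last_log_step is None
                or current_log_step >= self.last_log_step + self.log_frequency
            )
        )

    def log_written(self, step: float):
        # Record the step at which a log was actually written
        self.last_log_step = step


manager = SketchFrequencyManager(log_frequency=0.5)
for epoch in [0.0, 0.25, 0.5, 1.0]:
    if manager.log_ready(current_log_step=epoch):
        manager.log_written(epoch)  # would also emit the actual log here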
rahul-tuli committed Jan 22, 2024
1 parent 923113e commit f77b42c
Showing 22 changed files with 2,516 additions and 103 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -798,3 +798,4 @@ integrations/pytorch/pytorch_vision*

# local log files
nm_temp_test_logs/*
sparse_logs/*
103 changes: 103 additions & 0 deletions src/sparseml/core/helpers.py
@@ -0,0 +1,103 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Any, Generator, Tuple

from sparseml.core.logger import LoggerManager
from sparseml.core.model.base import ModifiableModel
from sparseml.core.state import State


__all__ = [
    "should_log_model_info",
    "log_model_info",
]


def should_log_model_info(
    model: ModifiableModel,
    loggers: LoggerManager,
    epoch: float,
) -> bool:
    """
    Check if we should log model level info
    Criteria:
        - model has a loggable_items method
        - state has a logger manager
        - logger manager is ready to log based on cadence and last log epoch

    :param model: The model whose info we want to log
    :param loggers: The logger manager to log to
    :param epoch: The current epoch
    :return: True if we should log model level info, False otherwise
    """
    return (
        hasattr(model, "loggable_items")
        and isinstance(loggers, LoggerManager)
        and loggers.log_ready(current_log_step=epoch)
    )


def log_model_info(state: State, epoch):
    """
    Log model level info to the logger
    Relies on `state.model` having a `loggable_items` method
    that returns a generator of tuples of the loggable item
    name and value. Also relies on `state.loggers` being a
    `LoggerManager` instance.

    :param state: The current state of sparsification
    :param epoch: The epoch number to log model info at
    """
    _log_epoch(logger_manager=state.loggers, epoch=epoch)
    _log_model_loggable_items(
        logger_manager=state.loggers,
        loggable_items=state.model.loggable_items(),
        epoch=epoch,
    )


def _log_epoch(logger_manager: LoggerManager, epoch: int):
    """
    Log the epoch to the logger_manager

    :param logger_manager: The logger manager to log to
    :param epoch: The epoch to log
    """
    logger_manager.log_scalar(tag="Epoch", value=float(epoch), step=epoch)


def _log_model_loggable_items(
    logger_manager: LoggerManager,
    loggable_items: Generator[Tuple[str, Any], None, None],
    epoch: float,
):
    """
    Log the model level loggable items to the logger_manager

    :param logger_manager: The logger manager to log to
    :param loggable_items: The loggable items to log, must be a generator of tuples
        of the loggable item name and value
    :param epoch: The epoch to log
    """
    for loggable_item in loggable_items:
        log_tag, log_value = loggable_item
        if isinstance(log_value, dict):
            logger_manager.log_scalars(tag=log_tag, values=log_value, step=epoch)
        elif isinstance(log_value, (int, float)):
            logger_manager.log_scalar(tag=log_tag, value=log_value, step=epoch)
        else:
            logger_manager.log_string(tag=log_tag, string=log_value, step=epoch)
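
A hedged usage sketch of the helpers above, e.g. called from a modifier's update handler; the `event.current_index` attribute and the surrounding wiring are assumptions for illustration:

# Usage sketch (assumed call site, not code from this commit): state.model is
# expected to expose loggable_items() and state.loggers to be a LoggerManager.
from sparseml.core.helpers import log_model_info, should_log_model_info


def on_update_event(state, event):
    epoch = event.current_index  # assumed: the event carries the current epoch
    if should_log_model_info(model=state.model, loggers=state.loggers, epoch=epoch):
        # Logs the epoch scalar plus every (tag, value) pair yielded by
        # state.model.loggable_items(): dicts via log_scalars, numbers via
        # log_scalar, everything else via log_string
        log_model_info(state, epoch)
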
15 changes: 15 additions & 0 deletions src/sparseml/core/logger/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
from .logger import *
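
With the #1924 move, logger classes are importable from the package root, as helpers.py above already does:

from sparseml.core.logger import LoggerManager  # re-exported via `from .logger import *`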
