diff --git a/.gitignore b/.gitignore index 2fc41d97b0c..df33339cbfc 100644 --- a/.gitignore +++ b/.gitignore @@ -798,3 +798,4 @@ integrations/pytorch/pytorch_vision* # local log files nm_temp_test_logs/* +sparse_logs/* diff --git a/src/sparseml/core/helpers.py b/src/sparseml/core/helpers.py new file mode 100644 index 00000000000..d28ceaec7fa --- /dev/null +++ b/src/sparseml/core/helpers.py @@ -0,0 +1,103 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Generator, Tuple + +from sparseml.core.logger import LoggerManager +from sparseml.core.model.base import ModifiableModel +from sparseml.core.state import State + + +__all__ = [ + "should_log_model_info", + "log_model_info", +] + + +def should_log_model_info( + model: ModifiableModel, + loggers: LoggerManager, + epoch: float, +) -> bool: + """ + Check if we should log model level info + Criteria: + - model has a loggable_items method + - state has a logger manager + - logger manager is ready to log based on cadence and last log epoch + + :param model: The model whose info we want to log + :param loggers: The logger manager to log to + :param epoch: The current epoch + :return: True if we should log model level info, False otherwise + """ + return ( + hasattr(model, "loggable_items") + and isinstance(loggers, LoggerManager) + and loggers.log_ready(current_log_step=epoch) + ) + + +def log_model_info(state: State, epoch): + """ + Log model level info to the logger + Relies on `state.model` having a `loggable_items` method + that returns a generator of tuples of the loggable item + name and value. Also relies on `state.loggers` being a + `LoggerManager` instance. 
+ + :param state: The current state of sparsification + :param epoch: The epoch number to log model info + at + """ + _log_epoch(logger_manager=state.loggers, epoch=epoch) + _log_model_loggable_items( + logger_manager=state.loggers, + loggable_items=state.model.loggable_items(), + epoch=epoch, + ) + + +def _log_epoch(logger_manager: LoggerManager, epoch: int): + """ + Log the epoch to the logger_manager + + :param logger_manager: The logger manager to log to + :param epoch: The epoch to log + """ + logger_manager.log_scalar(tag="Epoch", value=float(epoch), step=epoch) + + +def _log_model_loggable_items( + logger_manager: LoggerManager, + loggable_items: Generator[Tuple[str, Any], None, None], + epoch: float, +): + """ + Log the model level loggable items to the logger_manager + + :param logger_manager: The logger manager to log to + :param loggable_items: The loggable items to log, must be a generator of tuples + of the loggable item name and value + :param epoch: The epoch to log + """ + for loggable_item in loggable_items: + log_tag, log_value = loggable_item + if isinstance(log_value, dict): + logger_manager.log_scalars(tag=log_tag, values=log_value, step=epoch) + elif isinstance(log_value, (int, float)): + logger_manager.log_scalar(tag=log_tag, value=log_value, step=epoch) + else: + logger_manager.log_string(tag=log_tag, string=log_value, step=epoch) diff --git a/src/sparseml/core/logger/__init__.py b/src/sparseml/core/logger/__init__.py new file mode 100644 index 00000000000..0831d9c6602 --- /dev/null +++ b/src/sparseml/core/logger/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
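A minimal usage sketch for the helpers above, assuming `state` is a populated `State` carrying a `ModifiableModel` and a `LoggerManager` (this mirrors what `SparseSession._log_model_info` does later in this diff):

    from sparseml.core.helpers import log_model_info, should_log_model_info

    def maybe_log_model_info(state, epoch: float):
        # only log when the model exposes loggable_items, a LoggerManager is
        # attached, and the logging cadence has been reached for this epoch
        if should_log_model_info(model=state.model, loggers=state.loggers, epoch=epoch):
            log_model_info(state=state, epoch=epoch)
            # record the write so the cadence check advances
            state.loggers.log_written(epoch)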
+# flake8: noqa +from .logger import * diff --git a/src/sparseml/core/logger.py b/src/sparseml/core/logger/logger.py similarity index 68% rename from src/sparseml/core/logger.py rename to src/sparseml/core/logger/logger.py index 732be491c38..681e5c54d7c 100644 --- a/src/sparseml/core/logger.py +++ b/src/sparseml/core/logger/logger.py @@ -19,12 +19,21 @@ import logging import os import time +import warnings from abc import ABC +from contextlib import contextmanager from datetime import datetime from logging import CRITICAL, DEBUG, ERROR, INFO, WARN, Logger from pathlib import Path from types import ModuleType -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union + +from sparseml.core.logger.utils import ( + FrequencyManager, + FrequencyType, + LoggingModeType, + LogStepType, +) try: @@ -65,6 +74,7 @@ "error": ERROR, "critical": CRITICAL, } +DEFAULT_TAG = "defaul_tag" class BaseLogger(ABC): @@ -100,6 +110,9 @@ def enabled(self, value: bool): """ self._enabled = value + def __repr__(self): + return f"{self.__class__.__name__}(name={self._name}, enabled={self._enabled})" + def log_hyperparams(self, params: Dict[str, float]) -> bool: """ :param params: Each key-value pair in the dictionary is the name of the @@ -355,6 +368,16 @@ def _create_default_logger(self, log_level: Optional[int] = None) -> logging.Log """ logger = logging.getLogger(__name__) + # Console handler, for logging high level modifier logs + # must be created before the file handler + # as file handler is also a stream handler + if not any( + isinstance(handler, logging.StreamHandler) for handler in logger.handlers + ): + stream_handler = logging.StreamHandler() + stream_handler.setLevel(log_level or logging.getLogger("sparseml").level) + logger.addHandler(stream_handler) + # File handler setup, for logging modifier debug statements if not any( isinstance(handler, logging.FileHandler) for handler in logger.handlers @@ -376,14 +399,6 @@ def _create_default_logger(self, log_level: Optional[int] = None) -> logging.Log logger.addHandler(file_handler) logger.info(f"Logging all SparseML modifier-level logs to {log_path}") - if not any( - isinstance(handler, logging.StreamHandler) for handler in logger.handlers - ): - # Console handler, for logging high level modifier logs - stream_handler = logging.StreamHandler() - stream_handler.setLevel(log_level or logging.getLogger("sparseml").level) - logger.addHandler(stream_handler) - logger.setLevel(LOGGING_LEVELS["debug"]) logger.propagate = False @@ -412,13 +427,17 @@ def _log_lambda( level = LOGGING_LEVELS["debug"] if level > LOGGING_LEVELS["debug"]: - format = "%s %s step %s: %s" - log_args = [ - self.name, - tag, - step, - values or value, - ] + if step is not None: + format = "%s %s step %s: %s" + log_args = [ + self.name, + tag, + step, + values or value, + ] + else: + format = "%s %s: %s" + log_args = [self.name, tag, values or value] else: format = "%s %s [%s - %s]: %s" log_args = [self.name, tag, step, wall_time, values or value] @@ -787,8 +806,16 @@ class LoggerManager(ABC): loggers. :param loggers: list of loggers assigned to this manager - :param log_frequency: number of epochs or fraction of epochs to wait between logs - + :param log_frequency: number of stes or fraction of steps to wait between logs + :param mode: The logging mode to use, either "on_change" or "exact", + "on_change" will log when the model has been updated since the last log, + "exact" will log at the given frequency regardless of model updates. 
+ Defaults to "exact" + :param frequency_type: The frequency type to use, either "epoch", "step", or "batch" + controls what the frequency manager is tracking, e.g. if the frequency type + is "epoch", then the frequency manager will track the number of epochs that + have passed since the last log, if the frequency type is "step", then the + frequency manager will track the number of optimizer steps """ def __init__( @@ -797,14 +824,32 @@ def __init__( log_frequency: Union[float, None] = 0.1, log_python: bool = True, name: str = "manager", + mode: LoggingModeType = "exact", + frequency_type: FrequencyType = "epoch", ): - self._loggers = loggers or [] - self._log_frequency = log_frequency self._name = name - if log_python and not any( - isinstance(log, PythonLogger) for log in self._loggers - ): - self._loggers.append(PythonLogger()) + self._loggers = ( + loggers + or SparsificationGroupLogger( + python=log_python, + name=name, + tensorboard=True, + wandb_=True, + ).loggers + ) + + self.frequency_manager = FrequencyManager( + mode=mode, + frequency_type=frequency_type, + log_frequency=log_frequency, + ) + + self.system = SystemLoggingWraper( + loggers=self._loggers, frequency_manager=self.frequency_manager + ) + self.metric = MetricLoggingWrapper( + loggers=self._loggers, frequency_manager=self.frequency_manager + ) def __len__(self): return len(self.loggers) @@ -822,24 +867,43 @@ def add_logger(self, logger: BaseLogger): raise ValueError(f"logger {type(logger)} must be of type BaseLogger") self._loggers.append(logger) - def log_ready(self, epoch, last_log_epoch): + def log_ready( + self, current_log_step, last_log_step=None, check_model_update: bool = False + ): """ Check if there is a logger that is ready to accept a log - :param epoch: current epoch log is requested at - :param last_log_epoch: last time a log was recorder for this object + :param current_log_step: current step log is requested at + :param last_log_step: last time a log was recorder for this object. (Deprecated) + :param check_model_update: if True, will check if the model has been updated, + if False, will only check the log frequency :return: True if a logger is ready to accept a log. 
""" - return ( - self._log_frequency is not None - and ( - epoch is None - or epoch == last_log_epoch - or epoch >= last_log_epoch + self._log_frequency - ) - and any(log.enabled for log in self.loggers) + log_enabled = any(logger.enabled for logger in self.loggers) + if last_log_step is not None: + self.frequency_manager.log_written(step=last_log_step) + + return log_enabled and self.frequency_manager.log_ready( + current_log_step=current_log_step, + check_model_update=check_model_update, ) + def log_written(self, step: LogStepType): + """ + Update the frequency manager with the last log step written + + :param step: step that was last logged + """ + self.frequency_manager.log_written(step=step) + + def model_updated(self, step: LogStepType): + """ + Update the frequency manager with the last model update step + + :param step: step that was last logged + """ + self.frequency_manager.model_updated(step=step) + @staticmethod def epoch_to_step(epoch, steps_per_epoch): return round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch) @@ -863,14 +927,14 @@ def log_frequency(self) -> Union[str, float, None]: """ :return: number of epochs or fraction of epochs to wait between logs """ - return self._log_frequency + return self.frequency_manager._log_frequency @log_frequency.setter def log_frequency(self, value: Union[str, float, None]): """ :param value: number of epochs or fraction of epochs to wait between logs """ - self._log_frequency = value + self.frequency_manager._log_frequency = value @property def name(self) -> str: @@ -899,6 +963,9 @@ def log_scalar( level: Optional[int] = None, ): """ + (Note: this method is deprecated and will be removed in a future version, + use LoggerManager().metric.log_scalar instead) + :param tag: identifying tag to log the value with :param value: value to save :param step: global step for when the value was taken @@ -906,26 +973,29 @@ def log_scalar( :param kwargs: additional logging arguments to support Python and custom loggers :return: True if logged, False otherwise. """ - for log in self.loggers: - if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): - log.log_scalar( - tag=tag, - value=value, - step=step, - wall_time=wall_time, - level=level, - ) + + self.metric.log_scalar( + tag=tag, + value=value, + step=step, + wall_time=wall_time, + log_types=log_types, + level=level, + ) def log_scalars( self, tag: str, - values: float, + values: Dict[str, float], step: Optional[int] = None, wall_time: Optional[float] = None, log_types: Union[str, List[str]] = ALL_TOKEN, level: Optional[int] = None, ): """ + (Note: this method is deprecated and will be removed in a future version, + use LoggerManager().metric.log_scalars instead) + :param tag: identifying tag to log the values with :param values: values to save :param step: global step for when the values were taken @@ -933,15 +1003,118 @@ def log_scalars( :param kwargs: additional logging arguments to support Python and custom loggers :return: True if logged, False otherwise. 
""" - for log in self.loggers: - if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): - log.log_scalars( - tag=tag, - values=values, - step=step, - wall_time=wall_time, - level=level, - ) + + self.metric.log_scalars( + tag=tag, + values=values, + step=step, + wall_time=wall_time, + log_types=log_types, + level=level, + ) + + def log_hyperparams( + self, + params: Dict, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + (Note: this method is deprecated and will be removed in a future version, + use LoggerManager().metric.log_hyperparams instead) + + :param params: Each key-value pair in the dictionary is the name of the + hyper parameter and it's corresponding value. + """ + + self.metric.log_hyperparams( + params=params, + log_types=log_types, + level=level, + ) + + def log_string( + self, + tag: str, + string: str, + step: Optional[int] = None, + wall_time: Optional[float] = None, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + (Note: this method is deprecated and will be removed in a future version, + use LoggerManager().system.log_string instead) + + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. + """ + self.system.log_string( + tag=tag, + string=string, + step=step, + wall_time=wall_time, + log_types=log_types, + level=level, + ) + + def save( + self, + file_path: str, + **kwargs, + ): + """ + :param file_path: path to a file to be saved + :param kwargs: additional arguments that a specific logger might use + """ + for log in self._loggers: + if log.enabled: + log.save(file_path, **kwargs) + + @contextmanager + def time(self, tag: Optional[str] = None, *args, **kwargs): + """ + Context manager to log the time it takes to run the block of code + + Usage: + >>> with LoggerManager().time("my_block"): + >>> time.sleep(1) + + :param tag: identifying tag to log the values with + """ + + start = time.time() + yield + elapsed = time.time() - start + if not tag: + tag = f"{DEFAULT_TAG}_time_secs" + self.log_scalar(tag=tag, value=float(f"{elapsed:.3f}"), *args, **kwargs) + + +class LoggingWrapperBase: + """ + Base class that holds a reference to the loggers and frequency manager + """ + + def __init__(self, loggers: List[BaseLogger], frequency_manager: FrequencyManager): + self.loggers = loggers + self._frequency_manager = frequency_manager + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"loggers={self.loggers}, frequency_manager={self._frequency_manager})" + ) + + +class SystemLoggingWraper(LoggingWrapperBase): + """ + Wraps utilities and convenience methods for logging strings to the system + """ def log_string( self, @@ -970,6 +1143,95 @@ def log_string( level=level, ) + def debug(self, tag, string, *args, **kwargs): + """ + logs a string message with level DEBUG on all + loggers that are enabled + + :param tag: Identifying tag to log the string with + :param string: The string to log + :param args: additional arguments to pass to the logger, + see `log_string` for more details + :param kwargs: additional arguments to pass to the logger, + see `log_string` for more details + """ + kwargs["level"] = logging.DEBUG + self.log_string(tag=tag, string=string, *args, **kwargs) + + def 
info(self, tag, string, *args, **kwargs): + """ + logs a string message with level INFO on all + loggers that are enabled + + :param tag: Identifying tag to log the string with + :param string: The string to log + :param args: additional arguments to pass to the logger, + see `log_string` for more details + :param kwargs: additional arguments to pass to the logger, + see `log_string` for more details + """ + kwargs["level"] = logging.INFO + self.log_string(tag=tag, string=string, *args, **kwargs) + + def warning(self, tag, string, *args, **kwargs): + """ + logs a string message with level WARNING on all + loggers that are enabled + + :param tag: Identifying tag to log the string with + :param string: The string to log + :param args: additional arguments to pass to the logger, + see `log_string` for more details + :param kwargs: additional arguments to pass to the logger, + see `log_string` for more details + """ + kwargs["level"] = logging.WARNING + self.log_string(tag=tag, string=string, *args, **kwargs) + + def warn(self, tag, string, *args, **kwargs): + warnings.warn( + "The 'warn' method is deprecated, use 'warning' instead", + DeprecationWarning, + 2, + ) + self.warning(tag=tag, string=string, *args, **kwargs) + + def error(self, tag, string, *args, **kwargs): + """ + logs a string message with level ERROR on all + loggers that are enabled + + :param tag: Identifying tag to log the string with + :param string: The string to log + :param args: additional arguments to pass to the logger, + see `log_string` for more details + :param kwargs: additional arguments to pass to the logger, + see `log_string` for more details + """ + kwargs["level"] = logging.ERROR + self.log_string(tag=tag, string=string, *args, **kwargs) + + def critical(self, tag, string, *args, **kwargs): + """ + logs a string message with level CRITICAL on all + loggers that are enabled + + :param tag: Identifying tag to log the string with + :param string: The string to log + :param args: additional arguments to pass to the logger, + see `log_string` for more details + :param kwargs: additional arguments to pass to the logger, + see `log_string` for more details + """ + kwargs["level"] = logging.CRITICAL + self.log_string(tag=tag, string=string, *args, **kwargs) + + +class MetricLoggingWrapper(LoggingWrapperBase): + """ + Wraps utilities and convenience methods for logging metrics to the system + """ + def log_hyperparams( self, params: Dict, @@ -980,22 +1242,123 @@ def log_hyperparams( :param params: Each key-value pair in the dictionary is the name of the hyper parameter and it's corresponding value. """ - for log in self._loggers: + for log in self.loggers: if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): log.log_hyperparams(params, level) - def save( + def log_scalar( self, - file_path: str, + tag: str, + value: float, + step: Optional[int] = None, + wall_time: Optional[float] = None, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the value with + :param value: value to save + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. 
+ """ + for log in self.loggers: + if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): + log.log_scalar( + tag=tag, + value=value, + step=step, + wall_time=wall_time, + level=level, + ) + + def log_scalars( + self, + tag: str, + values: Dict[str, float], + step: Optional[int] = None, + wall_time: Optional[float] = None, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. + """ + for log in self.loggers: + if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): + log.log_scalars( + tag=tag, + values=values, + step=step, + wall_time=wall_time, + level=level, + ) + + def add_scalar( + self, + value, + tag: str = DEFAULT_TAG, + step: Optional[int] = None, + wall_time: Union[int, float, None] = None, **kwargs, ): """ - :param file_path: path to a file to be saved - :param kwargs: additional arguments that a specific logger might use + Add a scalar value to the logger + + :param value: value to log + :param tag: tag to log the value with, defaults to DEFAULT_TAG + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param kwargs: additional logging arguments to to pass through to the + logger """ - for log in self._loggers: - if log.enabled: - log.save(file_path, **kwargs) + self.log_scalar(tag=tag, value=value, step=step, wall_time=wall_time, **kwargs) + + def add_scalars( + self, + values: Dict[str, Any], + tag: str = DEFAULT_TAG, + step: Optional[int] = None, + wall_time: Union[int, float, None] = None, + **kwargs, + ): + """ + Adds multiple scalar values to the logger + + :param values: values to log, must be A dict of serializable + python objects i.e `str`, `ints`, `floats`, `Tensors`, `dicts`, etc + :param tag: tag to log the value with, defaults to DEFAULT_TAG + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param kwargs: additional logging arguments to to pass through to the + logger + """ + self.log_scalars( + tag=tag, values=values, step=step, wall_time=wall_time, **kwargs + ) + + def log( + self, + data: Dict[str, Any], + step: Optional[int] = None, + tag: Optional[str] = DEFAULT_TAG, + **kwargs, + ) -> None: + """ + :param data: A dict of serializable python objects i.e `str`, + `ints`, `floats`, `Tensors`, `dicts`, etc + :param step: global step for when the values were taken + :param tag: identifying tag to log the values with, defaults to DEFAULT_TAG + :param kwargs: additional logging arguments to support + Python and custom loggers + """ + self.log_scalars(tag=tag, values=data, step=step, **kwargs) def _create_dirs(path: str): diff --git a/src/sparseml/core/logger/utils/__init__.py b/src/sparseml/core/logger/utils/__init__.py new file mode 100644 index 00000000000..686adfc96c0 --- /dev/null +++ b/src/sparseml/core/logger/utils/__init__.py @@ -0,0 +1,17 @@ +# flake8: noqa + +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .frequency_manager import * diff --git a/src/sparseml/core/logger/utils/frequency_manager.py b/src/sparseml/core/logger/utils/frequency_manager.py new file mode 100644 index 00000000000..6d78cc2968f --- /dev/null +++ b/src/sparseml/core/logger/utils/frequency_manager.py @@ -0,0 +1,310 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Literal, Optional, Union + + +__all__ = [ + "FrequencyManager", + "LoggingModeType", + "FrequencyType", + "LogStepType", + "log_ready", +] + +LogStepType = Union[int, float, None] +LoggingModeType = Literal["on_change", "exact"] +FrequencyType = Literal["epoch", "step", "batch"] + + +class FrequencyManager: + """ + Class for managing the frequency of logging and model updates + + :param log_frequency: The frequency to log at + :param mode: The logging mode to use, either "on_change" or "exact", + "on_change" will log when the model has been updated since the last log, + "exact" will log at the given frequency regardless of model updates + :param frequency_type: The frequency type to use, either "epoch", "step", or "batch" + controls what the frequency manager is tracking, e.g. 
if the frequency type + is "epoch", then the frequency manager will track the number of epochs that + have passed since the last log, if the frequency type is "step", then the + frequency manager will track the number of optimizer steps + """ + + def __init__( + self, + log_frequency: LogStepType = None, + mode: LoggingModeType = "exact", + frequency_type: FrequencyType = "epoch", + ): + # sets self._logging_mode and self._check_model_update + self._logging_mode = self._set_logging_mode(mode=mode) + + # sets self._frequency_type and self._valid_python_types + self._frequency_type = self._set_frequency_type(frequency_type=frequency_type) + + self._validate_log_frequency(log_frequency=log_frequency) + self._log_frequency = log_frequency + + self.last_log_step: LogStepType = None + self.last_model_update_step: LogStepType = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}(log_frequency={self.log_frequency}, " + f"mode={self._logging_mode}, frequency_type={self._frequency_type})" + ) + + def log_ready( + self, + current_log_step: LogStepType, + check_model_update: bool = False, + ): + """ + Check if the frequency manager is ready to log + Conditions for readiness: + - log frequency is not None + - current log step is None + - current log step greater than or equal to the last log step + plus the log frequency + - if check_model_update is True, or self._check_model_update is True, + then the last model update step must be greater than or equal + to the last log step, and the current log step must be greater + than or equal to the last model update step plus the log frequency + + :param current_log_step: The current log step + :param check_model_update: If True, will check if the model has been updated + since the last log step and if _log_frequency steps have passed since the + last model update; Defaults to False. + :return: True if the frequency manager is ready to log, + False otherwise + """ + # check_model_update is used to override self._check_model_update + # e.g. 
if check_model_update is True, then the model update check + # will be performed even if self._check_model_update is False + + check_model_update = check_model_update or self._check_model_update + + return log_ready( + current_log_step=current_log_step, + last_log_step=self.last_log_step, + log_frequency=self.log_frequency, + last_model_update_step=self.last_model_update_step, + check_model_update=check_model_update, + ) + + def model_updated(self, step: LogStepType = None) -> None: + """ + Sets the last model update to the given step + + :param step: The step to set the last model update to + :post-cond: The last model update step is set to the given step + """ + self._validate_log_step(log_step=step) + self.last_model_update_step = step + + def log_written(self, step: LogStepType = None) -> None: + """ + Sets the last log step to the given step + + :param step: The step to set the last log step to + :post-cond: The last log step is set to the given step + """ + self._validate_log_step(log_step=step) + self.last_log_step = step + + @property + def log_frequency(self) -> LogStepType: + """ + :return: The log frequency + """ + return self._log_frequency + + @log_frequency.setter + def log_frequency(self, log_frequency: LogStepType) -> None: + """ + Sets the log frequency to the given value + + :param log_frequency: The log frequency to set + :post-cond: The log frequency is set to the given value + """ + self._validate_log_frequency(log_frequency=log_frequency) + self._log_frequency = log_frequency + + def _validate_log_frequency(self, log_frequency): + # checks that log frequency is a positive number or None + # raise TypeError if not a number or None + # raises ValueError if not a positive number + + try: + self._validate_log_step(log_step=log_frequency) + if log_frequency == 0: + raise ValueError() + # except clauses update the error message + except TypeError: + raise TypeError( + f"log frequency must be a number or None, given {type(log_frequency)}" + ) + except ValueError: + raise ValueError( + f"log frequency must be greater than 0, given {log_frequency}" + ) + + def _validate_log_step(self, log_step): + # checks that log step is a non negative number or None + # raise TypeError if not a number or None + # raises ValueError if negative number + + if not isinstance(log_step, self._valid_python_types) or isinstance( + log_step, bool + ): + raise TypeError( + f"log step must be a number or None, given {type(log_step)}" + ) + + if log_step is not None and log_step < 0: + raise ValueError( + f"log step must be greater than or equal to 0, given {log_step}" + ) + + def _set_logging_mode(self, mode: LoggingModeType) -> LoggingModeType: + """ + Set the logging mode for the frequency manager. 
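A minimal sketch of the two logging modes, assuming epoch-fraction frequencies:

    from sparseml.core.logger.utils import FrequencyManager

    # "exact" mode: readiness depends only on the cadence
    exact = FrequencyManager(log_frequency=0.1, mode="exact", frequency_type="epoch")
    exact.log_written(step=1.0)
    assert exact.log_ready(current_log_step=1.1)  # 1.1 >= 1.0 + 0.1

    # "on_change" mode: additionally requires a model update since the last write
    on_change = FrequencyManager(log_frequency=0.1, mode="on_change", frequency_type="epoch")
    on_change.model_updated(step=0.9)
    on_change.log_written(step=1.0)
    assert not on_change.log_ready(current_log_step=1.1)  # cadence reached, model unchanged
    on_change.model_updated(step=1.0)
    assert on_change.log_ready(current_log_step=1.1)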
+ The logging mode determines how the frequency manager determines + if it is ready to log + + :param mode: The logging mode to set + :post-cond: The self._logging_mode is set to the given mode + :post-cond: The self._check_model_update is set to True if the mode is + "on_change" + :raises ValueError: If the given mode is not one of "on_change" or "exact" + :return: The logging mode that was set + """ + mode = _basic_normalization(mode) + if mode == "on_change": + self._check_model_update = True + self._logging_mode = "on_change" + elif mode == "exact": + self._check_model_update = False + self._logging_mode = "exact" + else: + raise ValueError( + f"Invalid logging mode {mode}, must be one of 'on_change', 'exact'" + ) + return self._logging_mode + + def _set_frequency_type(self, frequency_type: FrequencyType) -> FrequencyType: + """ + Set the frequency type for the frequency manager. + The frequency type determines what the frequency manager is tracking. + For example, if the frequency type is "epoch", then the frequency manager + will track the number of epochs that have passed since the last log. + + :param frequency_type: The frequency type to set + :post-cond: The self._frequency_type is set to the given frequency type + :post-cond: The self._valid_python_types is set to the valid python types + for the given frequency type, e.g. (int, float, type(None)) for "epoch" + and (int, type(None)) for "step" or "batch" + :raises ValueError: If the given frequency type is not one of "epoch", + "step", or "batch" + :return: The frequency type that was set + """ + frequency_type = _basic_normalization(frequency_type) + if frequency_type == "epoch": + self._frequency_type = "epoch" + self._valid_python_types = (int, float, type(None)) + elif frequency_type == "step": + self._frequency_type = "step" + self._valid_python_types = (int, type(None)) + elif frequency_type == "batch": + self._frequency_type = "batch" + self._valid_python_types = (int, type(None)) + else: + raise ValueError( + f"Invalid frequency type {frequency_type}, must be one of " + "'epoch', 'step', 'batch'" + ) + return self._frequency_type + + +def log_ready( + current_log_step: Optional[LogStepType], + last_log_step: Optional[LogStepType], + log_frequency: Optional[LogStepType], + last_model_update_step: Optional[LogStepType] = None, + check_model_update: bool = False, +): + """ + Check if we are ready to log again based on the given parameters + (Stateless version of FrequencyManager().log_ready) + + Conditions for readiness: + - log frequency is not None + - current log step is None + - current log step greater than or equal to the last log step + plus the log frequency + - if check_model_update is True, then the last model update step + must be greater than or equal to the last log step, and the + current log step must be greater than or equal to the + last model update step plus the log frequency + + :param current_log_step: The current log step + :param last_log_step: The last step at which logging occurred + :param log_frequency: The frequency to log at + :param last_model_update_step: The last step at which the model was updated + :param check_model_update: If True, will check if the model has been updated + since the last log step and if log_frequency steps have passed since the + last model update; Defaults to False. + :return: True if logging cadence has been reached again False otherwise + """ + # format is used to avoid floating point errors + # e.g. 
0.1 + 0.2 != 0.3 + # format(0.1 + 0.2, ".4f") == format(0.3, ".4f") + + cadence_reached: bool = log_frequency is not None and ( + current_log_step is None + or last_log_step is None + or current_log_step >= float(format(last_log_step + log_frequency, ".4f")) + ) + + if not cadence_reached or not check_model_update: + # early return if cadence not reached or, + # model update check not requested + return cadence_reached + + model_updated_since_last_log: bool = ( + last_model_update_step is None + or last_log_step is None + or current_log_step is None + or ( + last_model_update_step >= last_log_step + and current_log_step + >= float(format(log_frequency + last_model_update_step, ".4f")) + ) + ) + + return cadence_reached and model_updated_since_last_log + + +def _basic_normalization(value: str) -> str: + """ + Basic normalization for string values. + Removes leading and trailing whitespace and converts to lowercase. + + :param value: The value to normalize + :return: The normalized value + """ + return value.strip().lower() diff --git a/src/sparseml/core/model/base.py b/src/sparseml/core/model/base.py index bee11706ade..d54fd9edc98 100644 --- a/src/sparseml/core/model/base.py +++ b/src/sparseml/core/model/base.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Dict, Generic, List, Optional, Tuple, TypeVar, Union +from typing import Any, Dict, Generator, Generic, List, Optional, Tuple, TypeVar, Union from sparseml.core.framework import Framework from sparseml.core.framework_object import MultiFrameworkObject @@ -125,6 +125,15 @@ def set_param(self, target: str, param: PT): """ raise NotImplementedError() + def loggable_items(self) -> Generator[Tuple[str, Any], None, None]: + """ + Model level information to be logged for the model + + :return a generator that yields a tuple of: + - the name of the loggable item + - the value of the loggable item + """ + @property def layer_prefix(self) -> Optional[str]: """ diff --git a/src/sparseml/core/model/pytorch.py b/src/sparseml/core/model/pytorch.py index f3a5701a3fe..1bb91af5a16 100644 --- a/src/sparseml/core/model/pytorch.py +++ b/src/sparseml/core/model/pytorch.py @@ -12,12 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple, Union from torch.nn import Module, Parameter from sparseml.core.framework import Framework from sparseml.core.model.base import ModelParameterizedLayer, ModifiableModel +from sparseml.pytorch.utils.sparsification_info.module_sparsification_info import ( + ModuleSparsificationInfo, +) from sparseml.utils.pytorch import ( get_layer, get_layers, @@ -102,6 +105,30 @@ def set_param(self, target: str, param: Parameter): """ return set_param(target, param, self.model) + def loggable_items(self) -> Generator[Tuple[str, Any], None, None]: + """ + PyTorch specific logging info for the model. + loggable items are defined in the `ModuleSparsificationInfo` class, + and include sparsity, quantization, and pruning information. 
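For reference, a few worked calls against the stateless `log_ready` helper defined above in frequency_manager.py (illustrative only), using an epoch-fraction frequency of 0.1:

    from sparseml.core.logger.utils import log_ready

    log_ready(current_log_step=0.30, last_log_step=0.20, log_frequency=0.1)  # True: 0.30 >= 0.20 + 0.1
    log_ready(current_log_step=0.25, last_log_step=0.20, log_frequency=0.1)  # False: cadence not reached
    log_ready(current_log_step=0.25, last_log_step=None, log_frequency=0.1)  # True: nothing logged yet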
+ + This includes: + - Total operation counts + - Total parameter counts + - sparsity percentages per layer (with non-zero sparsity only) + - quantization bitwidth (for quantized layers) + + :return a generator that yields a tuple of: + - the name of the loggable item + - the value of the loggable item + """ + sparsification_info = ModuleSparsificationInfo.from_module(self.model) + + yield from sparsification_info.loggable_items( + percentages_only=True, + non_zero_only=True, + enabled_only=True, + ) + def get_matching_layer( self, target: str, name_to_match: str, model: Module ) -> Optional[Tuple[str, Module]]: diff --git a/src/sparseml/core/modifier/stage.py b/src/sparseml/core/modifier/stage.py index ab0dcdb25b3..804d0874531 100644 --- a/src/sparseml/core/modifier/stage.py +++ b/src/sparseml/core/modifier/stage.py @@ -123,6 +123,7 @@ def pre_initialize_structure(self, state: "State", **kwargs): modifier.pre_initialize_structure(state, **kwargs) self.applied = True + state.loggers.system.info(tag="stage", string="Model structure initialized") def initialize(self, state: "State", **kwargs): """ @@ -138,6 +139,7 @@ def initialize(self, state: "State", **kwargs): for modifier in self.modifiers: modifier.initialize(state, **kwargs) + state.loggers.system.info(tag="stage", string="Modifiers initialized") def finalize(self, state: "State", **kwargs): """ @@ -155,6 +157,7 @@ def finalize(self, state: "State", **kwargs): modifier.finalize(state, **kwargs) self.applied = True + state.loggers.system.info(tag="stage", string="Modifiers finalized") def update_event(self, state: "State", event: "Event", **kwargs): """ diff --git a/src/sparseml/core/session.py b/src/sparseml/core/session.py index ac18b9add90..9cc3574fbb6 100644 --- a/src/sparseml/core/session.py +++ b/src/sparseml/core/session.py @@ -19,7 +19,10 @@ from sparseml.core.event import EventType from sparseml.core.framework import Framework +from sparseml.core.helpers import log_model_info, should_log_model_info from sparseml.core.lifecycle import SparsificationLifecycle +from sparseml.core.logger import BaseLogger, LoggerManager +from sparseml.core.logger.utils import log_ready from sparseml.core.recipe import Recipe from sparseml.core.state import ModifiedState, State @@ -63,6 +66,7 @@ class SparseSession: def __init__(self): self._lifecycle = SparsificationLifecycle() + self._last_loss_log_step = None @property def lifecycle(self) -> SparsificationLifecycle: @@ -146,6 +150,7 @@ def initialize( start: Optional[float] = None, steps_per_epoch: Optional[int] = None, batches_per_step: Optional[int] = None, + loggers: Union[None, LoggerManager, List[BaseLogger]] = None, **kwargs, ) -> ModifiedState: """ @@ -174,6 +179,8 @@ def initialize( sparsification :param batches_per_step: the number of batches per step to use for sparsification + :param loggers: the logger manager to setup logging important info + and milestones to, also accepts a list of BaseLogger(s) :param kwargs: additional kwargs to pass to the lifecycle's initialize method :return: the modified state of the session after initializing """ @@ -195,6 +202,7 @@ def initialize( start=start, steps_per_epoch=steps_per_epoch, batches_per_step=batches_per_step, + loggers=loggers, **kwargs, ) @@ -214,6 +222,9 @@ def finalize(self, **kwargs) -> ModifiedState: :param kwargs: additional kwargs to pass to the lifecycle's finalize method :return: the modified state of the session after finalizing """ + # log losses on finalization + self._log_loss(event_type=EventType.LOSS_CALCULATED, 
loss=self.state.loss) + mod_data = self._lifecycle.finalize(**kwargs) return ModifiedState( @@ -256,6 +267,12 @@ def event( event_type=event_type, batch_data=batch_data, loss=loss, **kwargs ) + # Update loss + if loss is not None: + self.state.loss = loss + + self.log(event_type=event_type, loss=loss) + return ModifiedState( model=self.state.model.model if self.state.model else None, optimizer=self.state.optimizer.optimizer if self.state.optimizer else None, @@ -263,6 +280,16 @@ def event( modifier_data=mod_data, ) + def log(self, event_type: EventType, loss: Optional[Any] = None): + """ + Log model and loss information for the current event type + + :param event_type: the event type to log for + :param loss: the loss to log if any + """ + self._log_model_info() + self._log_loss(event_type=event_type, loss=loss) + def reset(self): """ Reset the session to its initial state @@ -283,6 +310,45 @@ def get_serialized_recipe(self) -> str: recipe = self.lifecycle.recipe_container.compiled_recipe return recipe.yaml() + def _log_model_info(self): + # Log model level logs if cadence reached + + epoch = self._lifecycle.event_lifecycle.current_index + + if should_log_model_info( + model=self.state.model, + loggers=self.state.loggers, + epoch=epoch, + ): + log_model_info( + state=self.state, + epoch=epoch, + ) + # update last log epoch + self.state.loggers.log_written(epoch) + + def _log_loss(self, event_type: EventType, loss: Any): + if event_type != EventType.LOSS_CALCULATED: + # only log loss when loss is calculated + return + + current_step = self._lifecycle.event_lifecycle.current_index + + # No need to check model update for loss logging + check_model_update = False + + if loss is not None and log_ready( + current_log_step=current_step, + last_log_step=self._last_loss_log_step, + log_frequency=self.state.loggers.log_frequency, + check_model_update=check_model_update, + ): + loss = loss if isinstance(loss, dict) else {"loss": loss} + self.state.loggers.metric.log_scalars( + tag="Loss", values=loss, step=current_step + ) + self._last_loss_log_step = current_step + _global_session = SparseSession() _local_storage = threading.local() diff --git a/src/sparseml/core/state.py b/src/sparseml/core/state.py index b6fe6172dcb..64014e3c479 100644 --- a/src/sparseml/core/state.py +++ b/src/sparseml/core/state.py @@ -14,13 +14,12 @@ from copy import deepcopy from dataclasses import dataclass -from typing import Any, Dict, List, Optional - -from pydantic import Field +from typing import Any, Dict, List, Optional, Union from sparseml.core.data import ModifiableData from sparseml.core.event import Event from sparseml.core.framework import Framework +from sparseml.core.logger import BaseLogger, LoggerManager from sparseml.core.model import ModifiableModel from sparseml.core.optimizer import ModifiableOptimizer @@ -96,7 +95,9 @@ class State: :param hardware: Hardware Instance holding info about the target hardware being used :param start_event: The start event to begin training :param last_event: The last event to stop training - :param loggers: List of loggers to use for logging training information + :param loggers: LoggerManager instance holding all the loggers to log + :param model_log_cadence: The cadence to log model information w.r.t epochs. + If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1. 
""" framework: Framework @@ -110,7 +111,9 @@ class State: hardware = Hardware() start_event: Event = None last_event: Event = None - loggers = Field(default_factory=list) + loggers: Optional[LoggerManager] = None + model_log_cadence: Optional[float] = None + _last_log_epoch: Optional[float] = None @property def sparsification_ready(self) -> bool: @@ -135,6 +138,8 @@ def update( start: float = None, steps_per_epoch: int = None, batches_per_step: int = None, + loggers: Union[None, LoggerManager, List[BaseLogger]] = None, + model_log_cadence: Optional[float] = None, **kwargs, ) -> Dict: """ @@ -152,6 +157,10 @@ def update( :param start: The start index to update the state with :param steps_per_epoch: The steps per epoch to update the state with :param batches_per_step: The batches per step to update the state with + :param loggers: the logger manager to setup logging important info and + milestones to, also accepts a list of BaseLogger(s) + :param model_log_cadence: The cadence to log model information w.r.t epochs. + If 1, logs every epoch. If 2, logs every other epoch, etc. Default is 1. :param kwargs: Additional keyword arguments to update the state with """ if model is not None: @@ -193,6 +202,13 @@ def update( if batches_per_step is not None: self.start_event.batches_per_step = batches_per_step + loggers = loggers or [] + if isinstance(loggers, List): + loggers = LoggerManager(loggers) + self.loggers = loggers + + if model_log_cadence is not None: + self.model_log_cadence = model_log_cadence return kwargs diff --git a/src/sparseml/pytorch/sparsification/modifier.py b/src/sparseml/pytorch/sparsification/modifier.py index f202fae0f87..d9405f7e6ac 100644 --- a/src/sparseml/pytorch/sparsification/modifier.py +++ b/src/sparseml/pytorch/sparsification/modifier.py @@ -438,7 +438,9 @@ def wrapper(*args, **kwargs): epoch = kwargs.get("epoch", None) steps_per_epoch = kwargs.get("steps_per_epoch", None) # Log call state - if self.loggers and self.loggers.log_ready(epoch, self._last_log_epoch): + if self.loggers and self.loggers.log_ready( + epoch=epoch, last_log_epoch=self._last_log_epoch + ): self.log_string( string=( f"Calling {func.__name__} with:\n" @@ -451,7 +453,9 @@ def wrapper(*args, **kwargs): ) out = func(*args, **kwargs) # Log return state - if self.loggers and self.loggers.log_ready(epoch, self._last_log_epoch): + if self.loggers and self.loggers.log_ready( + epoch=epoch, last_log_epoch=self._last_log_epoch + ): out_print = out if isinstance(out, Tuple) else [out] self.log_string( string=(f"\nReturned: {format_args(out_print)}\n"), @@ -671,7 +675,9 @@ def scheduled_log_update( if not self._enabled: raise RuntimeError("modifier must be enabled") - if self.loggers and self.loggers.log_ready(epoch, self._last_log_epoch): + if self.loggers and self.loggers.log_ready( + epoch=epoch, last_log_epoch=self._last_log_epoch + ): self._last_log_epoch = epoch self._scheduled_log_called = True self.log_update(module, optimizer, epoch, steps_per_epoch) diff --git a/src/sparseml/pytorch/utils/logger.py b/src/sparseml/pytorch/utils/logger.py index 6ed2ae47802..0e9a5bc0ff6 100644 --- a/src/sparseml/pytorch/utils/logger.py +++ b/src/sparseml/pytorch/utils/logger.py @@ -12,21 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Contains code for loggers that help visualize the information from each modifier +""" -# This file has been moved to src/sparseml/core/logger.py -# and is kept here for backwards compatibility. 
-# It will be removed in a future release. +import logging +import os +import time +from abc import ABC +from datetime import datetime +from logging import CRITICAL, DEBUG, ERROR, INFO, WARN, Logger +from types import ModuleType +from typing import Callable, Dict, List, Optional, Union -from sparseml.core.logger import ( - LOGGING_LEVELS, - BaseLogger, - LambdaLogger, - LoggerManager, - PythonLogger, - SparsificationGroupLogger, - TensorBoardLogger, - WANDBLogger, -) + +try: + try: + from torch.utils.tensorboard import SummaryWriter + except (ModuleNotFoundError, ImportError): + from tensorboardX import SummaryWriter + tensorboard_import_error = None +except Exception as tensorboard_err: + SummaryWriter = object + tensorboard_import_error = tensorboard_err + + +try: + import wandb + + wandb_err = None +except Exception as err: + wandb = None + wandb_err = err + +from sparseml.utils import ALL_TOKEN, create_dirs __all__ = [ @@ -39,3 +58,942 @@ "LoggerManager", "LOGGING_LEVELS", ] + +LOGGING_LEVELS = { + "debug": DEBUG, + "info": INFO, + "warn": WARN, + "error": ERROR, + "critical": CRITICAL, +} + + +class BaseLogger(ABC): + """ + Base class that all modifier loggers must implement. + + :param name: name given to the logger, used for identification + :param enabled: True to log, False otherwise + """ + + def __init__(self, name: str, enabled: bool = True): + self._name = name + self._enabled = enabled + + @property + def name(self) -> str: + """ + :return: name given to the logger, used for identification + """ + return self._name + + @property + def enabled(self) -> bool: + """ + :return: True to log, False otherwise + """ + return self._enabled + + @enabled.setter + def enabled(self, value: bool): + """ + :param value: True to log, False otherwise + """ + self._enabled = value + + def log_hyperparams(self, params: Dict[str, float]) -> bool: + """ + :param params: Each key-value pair in the dictionary is the name of the + hyper parameter and it's corresponding value. + :return: True if logged, False otherwise. + """ + return False + + def log_scalar( + self, + tag: str, + value: float, + step: Optional[int] = None, + wall_time: Optional[float] = None, + **kwargs, + ) -> bool: + """ + :param tag: identifying tag to log the value with + :param value: value to save + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. + """ + return False + + def log_scalars( + self, + tag: str, + values: Dict[str, float], + step: Optional[int] = None, + wall_time: Optional[float] = None, + **kwargs, + ) -> bool: + """ + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. 
+ """ + return False + + def log_string( + self, + tag: str, + string: str, + step: Optional[int] = None, + wall_time: Optional[float] = None, + **kwargs, + ) -> bool: + """ + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. + """ + return False + + def save( + self, + file_path: str, + **kwargs, + ) -> bool: + """ + :param file_path: path to a file to be saved + :param kwargs: additional arguments that a specific logger might use + :return: True if saved, False otherwise + """ + return False + + +class LambdaLogger(BaseLogger): + """ + Logger that handles calling back to a lambda function with any logs. + + :param lambda_func: the lambda function to call back into with any logs. + The expected call sequence is (tag, value, values, step, wall_time) -> bool + The return type is True if logged and False otherwise. + :param name: name given to the logger, used for identification; + defaults to lambda + :param enabled: True to log, False otherwise + """ + + def __init__( + self, + lambda_func: Callable[ + [ + Optional[str], + Optional[Union[float, str]], + Optional[Dict[str, float]], + Optional[int], + Optional[float], + Optional[int], + ], + bool, + ], + name: str = "lambda", + enabled: bool = True, + ): + super().__init__(name, enabled) + self._lambda_func = lambda_func + assert lambda_func, "lambda_func must be set to a callable function" + + @property + def lambda_func( + self, + ) -> Callable[ + [ + Optional[str], + Optional[Union[float, str]], + Optional[Dict[str, float]], + Optional[int], + Optional[float], + Optional[int], + ], + bool, + ]: + """ + :return: the lambda function to call back into with any logs. + The expected call sequence is (tag, value, values, step, wall_time) + """ + return self._lambda_func + + def log_hyperparams( + self, + params: Dict, + level: Optional[int] = None, + ) -> bool: + """ + :param params: Each key-value pair in the dictionary is the name of the + hyper parameter and it's corresponding value. + :return: True if logged, False otherwise. + """ + if not self.enabled: + return False + + return self._lambda_func( + tag=None, + value=None, + values=params, + step=None, + wall_time=None, + level=level, + ) + + def log_scalar( + self, + tag: str, + value: float, + step: Optional[int] = None, + wall_time: Optional[float] = None, + level: Optional[int] = None, + ) -> bool: + """ + :param tag: identifying tag to log the value with + :param value: value to save + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken, + defaults to time.time() + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. 
+ """ + if not wall_time: + wall_time = time.time() + + return self._lambda_func( + tag=tag, + value=value, + values=None, + step=step, + wall_time=wall_time, + level=level, + ) + + def log_scalars( + self, + tag: str, + values: Dict[str, float], + step: Optional[int] = None, + wall_time: Optional[float] = None, + level: Optional[int] = None, + ) -> bool: + """ + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken, + defaults to time.time() + :param kwargs: additional logging arguments to support Python and custom loggers + :return: True if logged, False otherwise. + """ + if not wall_time: + wall_time = time.time() + + return self._lambda_func( + tag=tag, + value=None, + values=values, + step=step, + wall_time=wall_time, + level=level, + ) + + +class PythonLogger(LambdaLogger): + """ + Modifier logger that handles printing values into a python logger instance. + + :param logger: a logger instance to log to, if None then will create it's own + :param log_level: default level to log any incoming data at on the logging.Logger + instance when an explicit log level isn't provided + :param name: name given to the logger, used for identification; + defaults to python + :param enabled: True to log, False otherwise + """ + + def __init__( + self, + logger: Logger = None, + log_level: int = None, + name: str = "python", + enabled: bool = True, + ): + self._logger = logger or self._create_default_logger(log_level=log_level) + + super().__init__( + lambda_func=self._log_lambda, + name=name, + enabled=enabled, + ) + + def __getattr__(self, item): + return getattr(self._logger, item) + + @property + def logger(self) -> Logger: + """ + :return: a logger instance to log to, if None then will create it's own + """ + return self._logger + + def _create_default_logger(self, log_level: Optional[int] = None) -> logging.Logger: + """ + Create a default modifier logger, with a file handler logging at the debug level + and a stream handler logging to console at the specified level + + :param log_level: logging level for the console logger + :return: logger + """ + logger = logging.getLogger(__name__) + + # File handler setup, for logging modifier debug statements + if not any( + isinstance(handler, logging.FileHandler) for handler in logger.handlers + ): + base_log_path = ( + os.environ.get("NM_TEST_LOG_DIR") + if os.environ.get("NM_TEST_MODE") + else "sparse_logs" + ) + now = datetime.now() + dt_string = now.strftime("%d-%m-%Y_%H.%M.%S") + log_path = os.path.join(base_log_path, f"{dt_string}.log") + os.makedirs(base_log_path, exist_ok=True) + file_handler = logging.FileHandler( + log_path, + delay=True, + ) + file_handler.setLevel(LOGGING_LEVELS["debug"]) + logger.addHandler(file_handler) + logger.info(f"Logging all SparseML modifier-level logs to {log_path}") + + if not any( + isinstance(handler, logging.StreamHandler) for handler in logger.handlers + ): + # Console handler, for logging high level modifier logs + stream_handler = logging.StreamHandler() + stream_handler.setLevel(log_level or logging.getLogger("sparseml").level) + logger.addHandler(stream_handler) + + logger.setLevel(LOGGING_LEVELS["debug"]) + logger.propagate = False + + return logger + + def _log_lambda( + self, + tag: Optional[str], + value: Optional[Union[float, str]], + values: Optional[Dict[str, float]], + step: Optional[int], + wall_time: Optional[float], + level: Optional[int] = 
None, + ) -> bool: + """ + :param tag: identifying tag to log the values with + :param value: value to save + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken, + defaults to time.time() + :param level: level to log at. Corresponds to default logging package levels + :return: True if logged, False otherwise. + """ + if not level: + level = LOGGING_LEVELS["debug"] + + if level > LOGGING_LEVELS["debug"]: + format = "%s %s step %s: %s" + log_args = [ + self.name, + tag, + step, + values or value, + ] + else: + format = "%s %s [%s - %s]: %s" + log_args = [self.name, tag, step, wall_time, values or value] + + self._logger.log(level, format, *log_args) + + return True + + def log_string( + self, + tag: Optional[str], + string: Optional[str], + step: Optional[int], + wall_time: Optional[float] = None, + level: Optional[int] = None, + ) -> bool: + """ + :param tag: identifying tag to log the values with + :param string: string to log + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken, + defaults to time.time() + :param level: level to log at. Corresponds to default logging package levels + :return: True if logged, False otherwise. + """ + if not wall_time: + wall_time = time.time() + + return self._lambda_func( + tag=tag, + value=string, + values=None, + step=step, + level=level, + wall_time=wall_time, + ) + + +class TensorBoardLogger(LambdaLogger): + """ + Modifier logger that handles outputting values into a TensorBoard log directory + for viewing in TensorBoard. + + :param log_path: the path to create a SummaryWriter at. writer must be None + to use if not supplied (and writer is None), + will create a TensorBoard dir in cwd + :param writer: the writer to log results to, + if none is given creates a new one at the log_path + :param name: name given to the logger, used for identification; + defaults to tensorboard + :param enabled: True to log, False otherwise + """ + + def __init__( + self, + log_path: str = None, + writer: SummaryWriter = None, + name: str = "tensorboard", + enabled: bool = True, + ): + if tensorboard_import_error: + raise tensorboard_import_error + + if writer and log_path: + raise ValueError( + ( + "log_path given:{} and writer object passed in, " + "to create a writer at the log path set writer=None" + ).format(log_path) + ) + elif not writer and not log_path: + log_path = os.path.join(".", "tensorboard") + + if os.environ.get("NM_TEST_MODE"): + test_log_root = os.environ.get("NM_TEST_LOG_DIR") + log_path = ( + os.path.join(test_log_root, log_path) if log_path else test_log_root + ) + + if log_path: + create_dirs(log_path) + + self._writer = writer if writer is not None else SummaryWriter(log_path) + super().__init__( + lambda_func=self._log_lambda, + name=name, + enabled=enabled, + ) + + @property + def writer(self) -> SummaryWriter: + """ + :return: the writer to log results to, + if none is given creates a new one at the log_path + """ + return self._writer + + def _log_lambda( + self, + tag: Optional[str], + value: Optional[float], + values: Optional[Dict[str, float]], + step: Optional[int], + wall_time: Optional[float], + level: Optional[int] = None, + ) -> bool: + if value is not None: + self._writer.add_scalar(tag, value, step, wall_time) + + if values and tag: + self._writer.add_scalars(tag, values, step, wall_time) + elif values: + for name, val in values.items(): + # hyperparameters 
logging case + self._writer.add_scalar(name, val, step, wall_time) + + return True + + +class WANDBLogger(LambdaLogger): + """ + Modifier logger that handles outputting values to Weights and Biases. + + :param init_kwargs: the args to call into wandb.init with; + ex: wandb.init(**init_kwargs). If not supplied, wandb.init() is called + with no arguments + :param name: name given to the logger, used for identification; + defaults to wandb + :param enabled: True to log, False otherwise + :param wandb_err: optional import error to raise when wandb is unavailable; + defaults to the module-level wandb import error + """ + + @staticmethod + def available() -> bool: + """ + :return: True if wandb is available and installed, False otherwise + """ + return not wandb_err + + def __init__( + self, + init_kwargs: Optional[Dict] = None, + name: str = "wandb", + enabled: bool = True, + wandb_err: Optional[Exception] = wandb_err, + ): + if wandb_err: + raise ModuleNotFoundError( + "Error: Failed to import wandb. " + "Please install the wandb library in order to use it." + ) from wandb_err + + super().__init__( + lambda_func=self._log_lambda, + name=name, + enabled=enabled, + ) + + if os.environ.get("NM_TEST_MODE"): + test_log_path = os.environ.get("NM_TEST_LOG_DIR") + create_dirs(test_log_path) + if init_kwargs: + init_kwargs["dir"] = test_log_path + else: + init_kwargs = {"dir": test_log_path} + + if init_kwargs: + wandb.init(**init_kwargs) + else: + wandb.init() + + self.wandb = wandb + + def _log_lambda( + self, + tag: Optional[str], + value: Optional[float], + values: Optional[Dict[str, float]], + step: Optional[int], + wall_time: Optional[float], + level: Optional[int] = None, + ) -> bool: + params = {} + + if value is not None: + params[tag] = value + + if values: + if tag: + values = {f"{tag}/{key}": val for key, val in values.items()} + params.update(values) + + wandb.log(params, step=step) + + return True + + def save( + self, + file_path: str, + ) -> bool: + """ + :param file_path: path to a file to be saved + """ + wandb.save(file_path) + return True + + +class SparsificationGroupLogger(BaseLogger): + """ + Modifier logger that handles outputting values to other supported systems. + Supported ones include: + - Python logging + - Tensorboard + - Weights and Biases + - Lambda callback + + All are optional and can be bulk disabled and enabled by this root. + + :param lambda_func: an optional lambda function to call back into with any logs. + The expected call sequence is (tag, value, values, step, wall_time) -> bool + The return type is True if logged and False otherwise. + :param python: an optional argument for logging to a python logger. + May be a logging.Logger instance to log to, True to create a logger instance, + or non truthy to not log anything (False, None) + :param python_log_level: if python, + the level to log any incoming data at on the logging.Logger instance + :param tensorboard: an optional argument for logging to a tensorboard writer. + May be a SummaryWriter instance to log to, a string representing the directory + to create a new SummaryWriter to log to, True to create a new SummaryWriter, + or non truthy to not log anything (False, None) + :param wandb_: an optional argument for logging to wandb.
+ May be a dictionary to pass to the init call for wandb, + True to log to wandb (will not call init), + or non truthy to not log anything (False, None) + :param name: name given to the logger, used for identification; + defaults to sparsification + :param enabled: True to log, False otherwise + """ + + def __init__( + self, + lambda_func: Optional[ + Callable[ + [ + Optional[str], + Optional[float], + Optional[Dict[str, float]], + Optional[int], + Optional[float], + ], + bool, + ] + ] = None, + python: Optional[Union[bool, Logger]] = None, + python_log_level: int = logging.INFO, + tensorboard: Optional[Union[bool, str, SummaryWriter]] = None, + wandb_: Optional[Union[bool, Dict]] = None, + name: str = "sparsification", + enabled: bool = True, + ): + super().__init__(name, enabled) + self._loggers: List[BaseLogger] = [] + + if lambda_func: + self._loggers.append( + LambdaLogger(lambda_func=lambda_func, name=name, enabled=enabled) + ) + + if python: + self._loggers.append( + PythonLogger( + logger=python if isinstance(python, Logger) else None, + log_level=python_log_level, + name=name, + enabled=enabled, + ) + ) + + if tensorboard: + self._loggers.append( + TensorBoardLogger( + log_path=tensorboard if isinstance(tensorboard, str) else None, + writer=( + tensorboard if isinstance(tensorboard, SummaryWriter) else None + ), + name=name, + enabled=enabled, + ) + ) + + if wandb_ and WANDBLogger.available(): + self._loggers.append( + WANDBLogger( + init_kwargs=wandb_ if isinstance(wandb_, Dict) else None, + name=name, + enabled=enabled, + ) + ) + + @BaseLogger.enabled.setter + def enabled(self, value: bool): + """ + :param value: True to log, False otherwise + """ + self._enabled = value + + for logger in self._loggers: + logger.enabled = value + + @property + def loggers(self) -> List[BaseLogger]: + """ + :return: the created logger sub instances for this logger + """ + return self._loggers + + def log_hyperparams(self, params: Dict, level: Optional[int] = None): + """ + :param params: Each key-value pair in the dictionary is the name of the + hyper parameter and it's corresponding value. + """ + for logger in self._loggers: + logger.log_hyperparams(params, level) + + def log_scalar( + self, + tag: str, + value: float, + step: Optional[int] = None, + wall_time: Optional[float] = None, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the value with + :param value: value to save + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken, + defaults to time.time() + """ + for logger in self._loggers: + logger.log_scalar(tag, value, step, wall_time, level) + + def log_scalars( + self, + tag: str, + values: Dict[str, float], + step: Optional[int] = None, + wall_time: Optional[float] = None, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken, + defaults to time.time() + """ + for logger in self._loggers: + logger.log_scalars(tag, values, step, wall_time, level) + + +class LoggerManager(ABC): + """ + Wrapper around loggers that handles log scheduling and handing off logs to intended + loggers. 
+ + :param loggers: list of loggers assigned to this manager + :param log_frequency: number of epochs or fraction of epochs to wait between logs + + """ + + def __init__( + self, + loggers: Optional[List[BaseLogger]] = None, + log_frequency: Union[float, None] = 0.1, + log_python: bool = True, + name: str = "manager", + ): + self._loggers = loggers or [] + self._log_frequency = log_frequency + self._name = name + if log_python and not any( + isinstance(log, PythonLogger) for log in self._loggers + ): + self._loggers.append(PythonLogger()) + + def __len__(self): + return len(self.loggers) + + def __iter__(self): + return iter(self.loggers) + + def add_logger(self, logger: BaseLogger): + """ + Add a BaseLogger implementation to the loggers of this manager + + :param logger: logger object to add + """ + if not isinstance(logger, BaseLogger): + raise ValueError(f"logger {type(logger)} must be of type BaseLogger") + self._loggers.append(logger) + + def log_ready(self, epoch, last_log_epoch): + """ + Check if there is a logger that is ready to accept a log + + :param epoch: current epoch log is requested at + :param last_log_epoch: last time a log was recorded for this object + :return: True if a logger is ready to accept a log. + """ + return ( + self._log_frequency is not None + and ( + epoch is None + or epoch == last_log_epoch + or epoch >= last_log_epoch + self._log_frequency + ) + and any(log.enabled for log in self.loggers) + ) + + @staticmethod + def epoch_to_step(epoch, steps_per_epoch): + return round(epoch) if steps_per_epoch <= 0 else round(epoch * steps_per_epoch) + + @property + def loggers(self) -> List[BaseLogger]: + """ + :return: list of loggers assigned to this manager + """ + return self._loggers + + @loggers.setter + def loggers(self, value: List[BaseLogger]): + """ + :param value: list of loggers assigned to this manager + """ + self._loggers = value + + @property + def log_frequency(self) -> Union[str, float, None]: + """ + :return: number of epochs or fraction of epochs to wait between logs + """ + return self._log_frequency + + @log_frequency.setter + def log_frequency(self, value: Union[str, float, None]): + """ + :param value: number of epochs or fraction of epochs to wait between logs + """ + self._log_frequency = value + + @property + def name(self) -> str: + """ + :return: name given to the logger, used for identification + """ + return self._name + + @property + def wandb(self) -> Optional[ModuleType]: + """ + :return: wandb module if initialized + """ + for log in self.loggers: + if isinstance(log, WANDBLogger) and log.enabled: + return log.wandb + return None + + def log_scalar( + self, + tag: str, + value: float, + step: Optional[int] = None, + wall_time: Optional[float] = None, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the value with + :param value: value to save + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param log_types: restrict logging to loggers whose name is in this list; + defaults to all loggers + :param level: level to log at, corresponding to the standard logging levels
+ """ + for log in self.loggers: + if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): + log.log_scalar( + tag=tag, + value=value, + step=step, + wall_time=wall_time, + level=level, + ) + + def log_scalars( + self, + tag: str, + values: Dict[str, float], + step: Optional[int] = None, + wall_time: Optional[float] = None, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the values with + :param values: values to save + :param step: global step for when the values were taken + :param wall_time: global wall time for when the values were taken + :param log_types: restrict logging to loggers whose name is in this list; + defaults to all loggers + :param level: level to log at, corresponding to the standard logging levels + """ + for log in self.loggers: + if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): + log.log_scalars( + tag=tag, + values=values, + step=step, + wall_time=wall_time, + level=level, + ) + + def log_string( + self, + tag: str, + string: str, + step: Optional[int] = None, + wall_time: Optional[float] = None, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + :param tag: identifying tag to log the string with + :param string: string to log + :param step: global step for when the string was taken + :param wall_time: global wall time for when the string was taken + :param log_types: restrict logging to loggers whose name is in this list; + defaults to all loggers + :param level: level to log at, corresponding to the standard logging levels + """ + for log in self.loggers: + if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): + log.log_string( + tag=tag, + string=string, + step=step, + wall_time=wall_time, + level=level, + ) + + def log_hyperparams( + self, + params: Dict, + log_types: Union[str, List[str]] = ALL_TOKEN, + level: Optional[int] = None, + ): + """ + :param params: Each key-value pair in the dictionary is the name of the + hyperparameter and its corresponding value. + """ + for log in self._loggers: + if log.enabled and (log_types == ALL_TOKEN or log.name in log_types): + log.log_hyperparams(params, level) + + def save( + self, + file_path: str, + **kwargs, + ): + """ + :param file_path: path to a file to be saved + :param kwargs: additional arguments that a specific logger might use + """ + for log in self._loggers: + if log.enabled: + log.save(file_path, **kwargs) diff --git a/src/sparseml/pytorch/utils/sparsification_info/configs.py b/src/sparseml/pytorch/utils/sparsification_info/configs.py index 32292eee008..af6c2f9e48d 100644 --- a/src/sparseml/pytorch/utils/sparsification_info/configs.py +++ b/src/sparseml/pytorch/utils/sparsification_info/configs.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from collections import Counter, defaultdict -from typing import Dict, Generator, Tuple, Union +from typing import Any, Dict, Generator, Tuple, Union import torch.nn from pydantic import BaseModel, Field @@ -54,6 +54,7 @@ def from_module( @abstractmethod def loggable_items( self, + **kwargs, ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: """ Yield the loggable items for SparsificationInfo object.
@@ -62,6 +63,67 @@ def loggable_items( """ raise NotImplementedError() + @staticmethod + def filter_loggable_items_percentages_only( + items_to_log: Generator[Tuple[str, Any], None, None], + percentage_only: bool = False, + ): + """ + Filter the loggable items to only yield the percentages of the loggable items + + :param items_to_log: A generator that yields the loggable items for this object. + :param percentage_only: If True, only yield the percentages of the loggable + items. If False, yield both the counts and percentages. Defaults to False + :return: A generator that yields the loggable items for this object. + """ + + def filter_percentage(log): + # log tag ends with percent + return log[0].endswith("percent") + + yield from SparsificationInfo._filter_items_to_log( + items_to_log, + filter_function=filter_percentage, + to_filter=percentage_only, + ) + + @staticmethod + def filter_loggable_items_non_zero_only(items_to_log, non_zero_only): + """ + Filter the loggable items to only yield the non-zero items + + :param items_to_log: A generator that yields the loggable items for this object. + :param non_zero_only: If True, only yield information for non-zero items. + :return: A generator that yields the loggable items for this object. + """ + + def filter_non_zero_values(log): + # log value must be non-zero + return log[1] != 0 + + yield from SparsificationInfo._filter_items_to_log( + items_to_log, + filter_function=filter_non_zero_values, + to_filter=non_zero_only, + ) + + @staticmethod + def _filter_items_to_log(items_to_log, filter_function, to_filter: bool = True): + """ + Utility function to filter the loggable items based on a filter function + + :param items_to_log: A generator that yields the loggable items for this object. + :param filter_function: A function that takes in a loggable item and returns + True if the item should be yielded, False otherwise. + :param to_filter: If True, filter the loggable items. If False, do not filter. + :return: A generator that yields the loggable items for this object. + """ + for loggable_item in items_to_log: + if not to_filter: + yield loggable_item + elif filter_function(loggable_item): + yield loggable_item + class CountAndPercent(BaseModel): count: int = Field(description="The count of items") @@ -141,19 +203,37 @@ def from_module( def loggable_items( self, + non_zero_only: bool = False, + percentages_only: bool = True, + **kwargs, ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: """ Yield the loggable items for SparsificationSummaries object. + :param non_zero_only: If True, only yield information for non-zero items. + :param percentages_only: If True, only yield the percentages of the loggable + items. If False, yield both the counts and percentages. Defaults to True :return: A generator that yields the loggable items for this object.
""" main_tag = self.__class__.__name__ yield f"{main_tag}/OperationCounts", self.operation_counts yield f"{main_tag}/ParameterCounts", self.parameter_counts - yield f"{main_tag}/QuantizedOperations/count", self.quantized.count - yield f"{main_tag}/QuantizedOperations/percent", self.quantized.percent - yield f"{main_tag}/PrunedParameters/count", self.pruned.count - yield f"{main_tag}/PrunedParameters/percent", self.pruned.percent + + items_to_log = ( + (f"{main_tag}/QuantizedOperations/count", self.quantized.count), + (f"{main_tag}/QuantizedOperations/percent", self.quantized.percent), + (f"{main_tag}/PrunedParameters/count", self.pruned.count), + (f"{main_tag}/PrunedParameters/percent", self.pruned.percent), + ) + + items_to_log = SparsificationInfo.filter_loggable_items_percentages_only( + items_to_log, percentages_only + ) + items_to_log = SparsificationInfo.filter_loggable_items_non_zero_only( + items_to_log, non_zero_only + ) + + yield from items_to_log class SparsificationPruning(SparsificationInfo): @@ -191,16 +271,43 @@ def from_module(cls, module: torch.nn.Module) -> "SparsificationPruning": def loggable_items( self, + percentages_only: bool = False, + non_zero_only: bool = False, + **kwargs, ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: """ Yield the loggable items for SparsificationPruning object. + :param percentages_only: If True, only yield the percentages of the loggable + items. If False, yield both the counts and percentages. Default is False. + :param non_zero_only: If True, only yield information for non-zero + counts/percentages. Default is False. :return: A generator that yields the loggable items for this object. """ main_tag = self.__class__.__name__ + items_to_log = [] for param_name, count_and_percent in self.sparse_parameters.items(): - yield f"{main_tag}/SparseParameters/{param_name}/count", count_and_percent.count # noqa: E501 - yield f"{main_tag}/SparseParameters/{param_name}/percent", count_and_percent.percent # noqa: E501 + items_to_log.append( + ( + f"{main_tag}/SparseParameters/{param_name}/count", + count_and_percent.count, + ) + ) # noqa: E501 + items_to_log.append( + ( + f"{main_tag}/SparseParameters/{param_name}/percent", + count_and_percent.percent, + ) + ) # noqa: E501 + + items_to_log = SparsificationInfo.filter_loggable_items_percentages_only( + items_to_log, percentages_only + ) + items_to_log = SparsificationInfo.filter_loggable_items_non_zero_only( + items_to_log, non_zero_only + ) + + yield from items_to_log class SparsificationQuantization(SparsificationInfo): @@ -250,14 +357,22 @@ def from_module( def loggable_items( self, + enabled_only: bool = False, + **kwargs, ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: """ Yield the loggable items for SparsificationQuantization object. + :param enabled_only: If True, only yield loggable items for + operations where quantization is enabled. If False, yield irrespective + of whether quantization is enabled or not. Defaults to False. :return: A generator that yields the loggable items for this object. 
""" main_tag = self.__class__.__name__ for operation in self.enabled.keys(): + if enabled_only and not self.enabled[operation]: + continue + yield f"{main_tag}/{operation}/enabled", self.enabled[operation] precision = self.precision[operation] diff --git a/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py b/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py index c5ab483f449..b0d65fdf1f9 100644 --- a/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py +++ b/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py @@ -59,20 +59,15 @@ def from_module(cls, module: torch.nn.Module) -> "ModuleSparsificationInfo": quantization_info=SparsificationQuantization.from_module(module), ) - return cls( - summary_info=SparsificationSummaries.from_module(module), - pruning_info=SparsificationPruning.from_module(module), - quantization_info=SparsificationQuantization.from_module(module), - ) - - def loggable_items(self) -> Generator[Tuple[str, Any], None, None]: + def loggable_items(self, **kwargs) -> Generator[Tuple[str, Any], None, None]: """ A generator that yields the loggable items of the ModuleSparsificationInfo object. + :param kwargs: additional kwargs to pass to the loggable items :return a generator that yields a tuple of: - the name of the loggable item - the value of the loggable item """ for info in [self.summary_info, self.pruning_info, self.quantization_info]: - yield from info.loggable_items() + yield from info.loggable_items(**kwargs) diff --git a/tests/sparseml/core/logger/__init__.py b/tests/sparseml/core/logger/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/core/logger/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/core/test_logger.py b/tests/sparseml/core/logger/test_logger.py similarity index 98% rename from tests/sparseml/core/test_logger.py rename to tests/sparseml/core/logger/test_logger.py index d26fe3bee42..86603cb8e2b 100644 --- a/tests/sparseml/core/test_logger.py +++ b/tests/sparseml/core/logger/test_logger.py @@ -19,7 +19,7 @@ import pytest -from sparseml.core.logger import ( +from sparseml.core import ( LambdaLogger, LoggerManager, PythonLogger, diff --git a/tests/sparseml/core/logger/utils/test_frequency_manager.py b/tests/sparseml/core/logger/utils/test_frequency_manager.py new file mode 100644 index 00000000000..381598810cb --- /dev/null +++ b/tests/sparseml/core/logger/utils/test_frequency_manager.py @@ -0,0 +1,159 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from contextlib import nullcontext as does_not_raise + +import pytest + +from sparseml.core.logger.utils import FrequencyManager + + +@pytest.fixture(scope="function") +def frequency_manager(): + return FrequencyManager(log_frequency=1) + + +@pytest.mark.parametrize( + "kwargs, expectation", + [ + ( + {"log_frequency": 0}, + pytest.raises(ValueError, match="must be greater than 0, given 0"), + ), + ( + {"log_frequency": -1}, + pytest.raises(ValueError, match="must be greater than 0, given -1"), + ), + ( + {"log_frequency": True}, + pytest.raises(TypeError, match="must be a number or None"), + ), + ( + {"log_frequency": []}, + pytest.raises(TypeError, match="must be a number or None"), + ), + ( + {"log_frequency": {}}, + pytest.raises(TypeError, match="must be a number or None"), + ), + ({"log_frequency": 1}, does_not_raise()), + ({"log_frequency": None}, does_not_raise()), + ], +) +def test_frequency_manager_creation(kwargs, expectation): + with expectation: + _ = FrequencyManager(**kwargs) + + +@pytest.mark.parametrize( + "step, expectation", + [ + (0.1, does_not_raise()), + (0, does_not_raise()), + (-1, pytest.raises(ValueError, match="must be greater than or equal to 0")), + (True, pytest.raises(TypeError, match="must be a number or None")), + ([], pytest.raises(TypeError, match="must be a number or None")), + ({}, pytest.raises(TypeError, match="must be a number or None")), + ], +) +class TestFrequencyManagerUpdationUtilities: + def test_model_updated(self, frequency_manager, step, expectation): + # test that model_updated sets last_model_update_step + # to the given step + + with expectation: + frequency_manager.model_updated(step=step) + assert frequency_manager.last_model_update_step == step + + def test_log_written(self, frequency_manager, step, expectation): + # test that log_written sets last_log_step + # to the given step + + with expectation: + frequency_manager.log_written(step=step) + assert frequency_manager.last_log_step == step + + +def _log_ready_test_cases(): + # test cases for log_ready + + # each test case is a tuple of: + # (log_frequency, current_log_step, last_log_step, + # last_model_update_step, check_model_update, expected) + + return [ + # None values should give True + (0.1, None, None, None, False, True), + (0.1, None, 1, 1, False, True), + (0.1, 1, None, 1, False, True), + (0.1, 0.3, 0.2, None, False, True), + (0.1, 0.3, 0.2, None, True, True), + # log frequency is None + (None, 1, 2, 3, False, False), + (None, 1, 2, 3, True, False), + # cadence not reached + (0.1, 1, 1, 0.1, False, False), + (0.1, 1, 1, 0.1, True, False), + # cadence reached + (0.1, 0.3, 0.1, 0.3, False, True), + (0.1, 0.3, 0.1, 0.1, True, True), + # model updated long back and + # and cadence reached + (0.1, 0.3, 0.1, 0.1, True, True), + ] + + +@pytest.mark.parametrize( + "log_frequency, current_log_step, last_log_step," + " last_model_update_step, check_model_update, expected", + _log_ready_test_cases(), +) +def test_log_ready( + log_frequency, + current_log_step, + last_log_step, + last_model_update_step, + check_model_update, + expected, +): + frequency_manager = 
FrequencyManager(log_frequency=log_frequency) + frequency_manager.last_log_step = last_log_step + frequency_manager.last_model_update_step = last_model_update_step + + actual = frequency_manager.log_ready( + current_log_step=current_log_step, check_model_update=check_model_update + ) + + assert actual == expected + + +@pytest.mark.parametrize( + "log_frequency, frequency_type, expectation", + [ + # epoch frequency type accepts floats + (0.1, "epoch", does_not_raise()), + # batch and step frequency types need + # log_frequency to be an integer + (0.1, "batch", pytest.raises(TypeError)), + (0.1, "step", pytest.raises(TypeError)), + # negative log_frequency is invalid + (-1, "epoch", pytest.raises(ValueError)), + (-1, "batch", pytest.raises(ValueError)), + (-1, "step", pytest.raises(ValueError)), + ], +) +def test__validate_log_frequency(log_frequency, frequency_type, expectation): + with expectation: + FrequencyManager(log_frequency=log_frequency, frequency_type=frequency_type) diff --git a/tests/sparseml/core/modifier/__init__.py b/tests/sparseml/core/modifier/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/core/modifier/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/core/test_helpers.py b/tests/sparseml/core/test_helpers.py new file mode 100644 index 00000000000..1aba52be9fc --- /dev/null +++ b/tests/sparseml/core/test_helpers.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import defaultdict + +import pytest + +from sparseml.core.event import EventType +from sparseml.core.helpers import _log_epoch, _log_model_loggable_items, log_model_info +from sparseml.core.logger import LoggerManager + + +class ModelMock: + def loggable_items(self): + for value in [("a", 1), ("b", 2), ("c", 3)]: + yield value + + +class LoggerManagerMock: + def __init__(self): + self.hit_count = defaultdict(int) + + def epoch_to_step(self, epoch, steps_per_epoch): + self.hit_count["epoch_to_step"] += 1 + return epoch * steps_per_epoch + + def log_string(self, tag, string, step): + self.hit_count["log_string"] += 1 + + def log_scalars(self, tag, values, step): + self.hit_count["log_scalars"] += 1 + + def log_scalar(self, tag, value, step): + self.hit_count["log_scalar"] += 1 + + def log_ready(self, *args, **kwargs): + pass + + @property + def __class__(self): + return LoggerManager + + +class StateMock: + def __init__(self, steps_per_epoch=10, epoch=1): + self.model = ModelMock() + self.loggers = LoggerManagerMock() + self.epoch = epoch + self.steps_per_epoch = steps_per_epoch + + +class EventMock: + def __init__(self, type_=EventType.BATCH_END, current_index=1, steps_per_epoch=10): + self.type_ = type_ + self.current_index = current_index + self.steps_per_epoch = steps_per_epoch + + +@pytest.fixture +def state_mock(): + yield StateMock() + + +def test__log_epoch_invokes_log_scalar(): + logger_manager = LoggerManagerMock() + _log_epoch( + logger_manager=logger_manager, + epoch=1, + ) + # log epoch should invoke log_string + assert logger_manager.hit_count["log_scalar"] == 1 + + +def test_log_model_info_logs_epoch_and_loggable_items(): + state = StateMock() + epoch = 3 + log_model_info(state, epoch=epoch) + + # loggable items will invoke log_scalar for each + # int/float value + 1 for the epoch + assert state.loggers.hit_count["log_scalar"] == epoch + 1 + + +@pytest.mark.parametrize( + "loggable_items", [ModelMock().loggable_items(), [("a", {}), ("b", 2), ("c", {})]] +) +def test__log_model_loggable_items_routes_appropriately(loggable_items, monkeypatch): + logger_manager = LoggerManagerMock() + loggable_items = list(loggable_items) + + scalar_count = scalars_count = string_count = 0 + for _, value in loggable_items: + if isinstance(value, (int, float)): + scalar_count += 1 + elif isinstance(value, dict): + scalars_count += 1 + else: + string_count += 1 + + _log_model_loggable_items( + logger_manager=logger_manager, loggable_items=loggable_items, epoch=1 + ) + + # loggable items will invoke log_scalar for each + # int/float value + assert logger_manager.hit_count["log_scalar"] == scalar_count + + # loggable items will invoke log_scalars for each + # dict value + assert logger_manager.hit_count["log_scalars"] == scalars_count + + # All other value types will invoke log_string + assert logger_manager.hit_count["log_string"] == string_count diff --git a/tests/sparseml/core/test_session.py b/tests/sparseml/core/test_session.py index b9710566ebc..976a1b3ae3e 100644 --- a/tests/sparseml/core/test_session.py +++ b/tests/sparseml/core/test_session.py @@ -71,6 +71,10 @@ def get_linear_net(): return LinearNet() +def empty_mock(*args, **kwargs): + pass + + class TestSparseSession: def test_session_has_a_sparsification_lifecycle(self, setup_active_session): assert hasattr( @@ -132,6 +136,9 @@ def test_session_methods_invoke_lifecycle_methods( "_lifecycle", lifecycle_mock := LifeCycleMock(model=kwargs.get("model")), ) + monkeypatch.setattr(setup_active_session, "_log_model_info", 
empty_mock) + monkeypatch.setattr(setup_active_session, "_log_loss", empty_mock) + method = getattr(setup_active_session, method_name) result = method(**kwargs) @@ -150,6 +157,7 @@ def test_apply_calls_lifecycle_initialize_and_finalize( monkeypatch.setattr( setup_active_session, "_lifecycle", lifecycle_mock := LifeCycleMock() ) + monkeypatch.setattr(setup_active_session, "_log_loss", empty_mock) setup_active_session.apply() # check initialize was called once diff --git a/tests/sparseml/pytorch/utils/test_logger.py b/tests/sparseml/pytorch/utils/test_logger.py new file mode 100644 index 00000000000..7cceeff3017 --- /dev/null +++ b/tests/sparseml/pytorch/utils/test_logger.py @@ -0,0 +1,92 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import time +from abc import ABC + +import pytest + +from sparseml.pytorch.utils import ( + LambdaLogger, + LoggerManager, + PythonLogger, + SparsificationGroupLogger, + TensorBoardLogger, + WANDBLogger, +) + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "logger", + [ + PythonLogger(), + TensorBoardLogger(), + LambdaLogger( + lambda_func=lambda tag, value, values, step, wall_time, level: logging.info( + f"{tag}, {value}, {values}, {step}, {wall_time}, {level}" + ) + or True + ), + *([WANDBLogger()] if WANDBLogger.available() else []), + SparsificationGroupLogger( + lambda_func=lambda tag, value, values, step, wall_time, level: logging.info( + f"{tag}, {value}, {values}, {step}, {wall_time}, {level}" + ) + or True, + python=True, + tensorboard=True, + wandb_=True, + ), + LoggerManager(), + LoggerManager( + [ + TensorBoardLogger(), + WANDBLogger() if WANDBLogger.available() else PythonLogger(), + ] + ), + ], +) +class TestModifierLogger(ABC): + def test_name(self, logger): + assert logger.name is not None + + def test_log_hyperparams(self, logger): + logger.log_hyperparams({"param1": 0.0, "param2": 1.0}) + logger.log_hyperparams({"param1": 0.0, "param2": 1.0}, level=10) + + def test_log_scalar(self, logger): + logger.log_scalar("test-scalar-tag", 0.1) + logger.log_scalar("test-scalar-tag", 0.1, 1) + logger.log_scalar("test-scalar-tag", 0.1, 2, time.time() - 1) + logger.log_scalar("test-scalar-tag", 0.1, 2, time.time() - 1, level=10) + + def test_log_scalars(self, logger): + logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}) + logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 1) + logger.log_scalars( + "test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1 + ) + logger.log_scalars( + "test-scalars-tag", + {"scalar1": 0.0, "scalar2": 1.0}, + 2, + time.time() - 1, + level=10, + )
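
For reference, a minimal usage sketch of the LoggerManager API introduced above. It only uses names defined in this diff (LoggerManager, TensorBoardLogger, log_ready, epoch_to_step, log_scalar, log_scalars, log_string); the import from sparseml.core mirrors the renamed test, the tags, values, and 0.1-epoch cadence are illustrative, and it assumes tensorboard is installed so TensorBoardLogger can build its SummaryWriter.

import time

from sparseml.core import LoggerManager, TensorBoardLogger

# a PythonLogger is appended automatically because log_python defaults to True
manager = LoggerManager(
    loggers=[TensorBoardLogger(log_path="./tensorboard")],
    log_frequency=0.1,
)

last_log_epoch = 0.0
for epoch in (0.05, 0.1, 0.25):
    # log_ready gates on the 0.1-epoch cadence configured above:
    # 0.05 is skipped, 0.1 and 0.25 are logged
    if manager.log_ready(epoch, last_log_epoch):
        step = LoggerManager.epoch_to_step(epoch, steps_per_epoch=100)
        manager.log_scalar(tag="loss", value=0.42, step=step, wall_time=time.time())
        manager.log_scalars(tag="sparsity", values={"layer1": 0.8, "layer2": 0.9}, step=step)
        manager.log_string(tag="status", string=f"epoch {epoch} logged", step=step)
        last_log_epoch = epoch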
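
A similar sketch for the new loggable_items filtering in the sparsification-info configs. ModuleSparsificationInfo, from_module, and the percentages_only / non_zero_only / enabled_only keywords come from the files changed above; the toy Sequential model and the exact import path are assumptions and may need adjusting to the actual package layout.

import torch

from sparseml.pytorch.utils.sparsification_info.module_sparsification_info import (
    ModuleSparsificationInfo,
)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
info = ModuleSparsificationInfo.from_module(model)

# percentages_only / non_zero_only / enabled_only are forwarded through
# loggable_items(**kwargs) to the summary, pruning, and quantization configs,
# so only percentage tags with non-zero values and enabled quantized ops remain
for tag, value in info.loggable_items(
    percentages_only=True, non_zero_only=True, enabled_only=True
):
    print(tag, value)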