Skip to content

Commit

Permalink
[Logger Framework][Add] mode and frequency type to FrequencyManager (#…
Browse files Browse the repository at this point in the history
…1930)

* Add mode and frequency type to FrequencyManager

* move tests cases near the test
  • Loading branch information
rahul-tuli authored Jan 3, 2024
1 parent 63d0089 commit 8706a78
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 7 deletions.
114 changes: 107 additions & 7 deletions src/sparseml/core/logger/utils/frequency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,46 @@
# limitations under the License.


from typing import Union
from typing import Literal, Union


__all__ = ["FrequencyManager"]

LogStepType = Union[int, float, None]
PossibleLoggingMode = Literal["on_change", "exact"]
PossibleFrequencyType = Literal["epoch", "step", "batch"]


class FrequencyManager:
"""
Class for managing the frequency of logging and model updates
:param log_frequency: The frequency to log at
:param mode: The logging mode to use, either "on_change" or "exact",
"on_change" will log when the model has been updated since the last log,
"exact" will log at the given frequency regardless of model updates
:param frequency_type: The frequency type to use, either "epoch", "step", or "batch"
controls what the frequency manager is tracking, e.g. if the frequency type
is "epoch", then the frequency manager will track the number of epochs that
have passed since the last log, if the frequency type is "step", then the
frequency manager will track the number of optimizer steps
"""

def __init__(self, log_frequency: LogStepType = None):
def __init__(
self,
log_frequency: LogStepType = None,
mode: PossibleLoggingMode = "exact",
frequency_type: PossibleFrequencyType = "epoch",
):
# sets self._logging_mode and self._check_model_update
self._logging_mode = self._set_logging_mode(mode=mode)

# sets self._frequency_type and self._valid_python_types
self._frequency_type = self._set_frequency_type(frequency_type=frequency_type)

self._validate_log_frequency(log_frequency=log_frequency)
self.log_frequency = log_frequency

self.last_log_step: LogStepType = None
self.last_model_update_step: LogStepType = None

Expand All @@ -46,10 +68,10 @@ def log_ready(
- current log step is None
- current log step greater than or equal to the last log step
plus the log frequency
- if check_model_update is True, then the last model update step
must be greater than or equal to the last log step, and the current
log step must be greater than or equal to the last model update step
plus the log frequency
- if check_model_update is True, or self._check_model_update is True,
then the last model update step must be greater than or equal
to the last log step, and the current log step must be greater
than or equal to the last model update step plus the log frequency
:param current_log_step: The current log step
:param check_model_update: If True, will check if the model has been updated
Expand All @@ -58,6 +80,12 @@ def log_ready(
:return: True if the frequency manager is ready to log,
False otherwise
"""
# check_model_update is used to override self._check_model_update
# e.g. if check_model_update is True, then the model update check
# will be performed even if self._check_model_update is False

check_model_update = check_model_update or self._check_model_update

# format is used to avoid floating point errors
# e.g. 0.1 + 0.2 != 0.3
# format(0.1 + 0.2, ".4f") == format(0.3, ".4f")
Expand Down Expand Up @@ -133,7 +161,7 @@ def _validate_log_step(self, log_step):
# raise TypeError if not a number or None
# raises ValueError if negative number

if not isinstance(log_step, (int, float, type(None))) or isinstance(
if not isinstance(log_step, self._valid_python_types) or isinstance(
log_step, bool
):
raise TypeError(
Expand All @@ -144,3 +172,75 @@ def _validate_log_step(self, log_step):
raise ValueError(
f"log step must be greater than or equal to 0, given {log_step}"
)

def _set_logging_mode(self, mode: PossibleLoggingMode) -> PossibleLoggingMode:
"""
Set the logging mode for the frequency manager.
The logging mode determines how the frequency manager determines
if it is ready to log
:param mode: The logging mode to set
:post-cond: The self._logging_mode is set to the given mode
:post-cond: The self._check_model_update is set to True if the mode is
"on_change"
:raises ValueError: If the given mode is not one of "on_change" or "exact"
:return: The logging mode that was set
"""
mode = _basic_normalization(mode)
if mode == "on_change":
self._check_model_update = True
self._logging_mode = "on_change"
elif mode == "exact":
self._check_model_update = False
self._logging_mode = "exact"
else:
raise ValueError(
f"Invalid logging mode {mode}, must be one of 'on_change', 'exact'"
)
return self._logging_mode

def _set_frequency_type(
self, frequency_type: PossibleFrequencyType
) -> PossibleFrequencyType:
"""
Set the frequency type for the frequency manager.
The frequency type determines what the frequency manager is tracking.
For example, if the frequency type is "epoch", then the frequency manager
will track the number of epochs that have passed since the last log.
:param frequency_type: The frequency type to set
:post-cond: The self._frequency_type is set to the given frequency type
:post-cond: The self._valid_python_types is set to the valid python types
for the given frequency type, e.g. (int, float, type(None)) for "epoch"
and (int, type(None)) for "step" or "batch"
:raises ValueError: If the given frequency type is not one of "epoch",
"step", or "batch"
:return: The frequency type that was set
"""
frequency_type = _basic_normalization(frequency_type)
if frequency_type == "epoch":
self._frequency_type = "epoch"
self._valid_python_types = (int, float, type(None))
elif frequency_type == "step":
self._frequency_type = "step"
self._valid_python_types = (int, type(None))
elif frequency_type == "batch":
self._frequency_type = "batch"
self._valid_python_types = (int, type(None))
else:
raise ValueError(
f"Invalid frequency type {frequency_type}, must be one of "
"'epoch', 'step', 'batch'"
)
return self._frequency_type


def _basic_normalization(value: str) -> str:
"""
Basic normalization for string values.
Removes leading and trailing whitespace and converts to lowercase.
:param value: The value to normalize
:return: The normalized value
"""
return value.strip().lower()
20 changes: 20 additions & 0 deletions tests/sparseml/core/logger/utils/test_frequency_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,23 @@ def test_log_ready(
)

assert actual == expected


@pytest.mark.parametrize(
"log_frequency, frequency_type, expectation",
[
# epoch frequency type accepts floats
(0.1, "epoch", does_not_raise()),
# batch and step frequency types need
# log_frequency to be an integer
(0.1, "batch", pytest.raises(TypeError)),
(0.1, "step", pytest.raises(TypeError)),
# negative log_frequency is invalid
(-1, "epoch", pytest.raises(ValueError)),
(-1, "batch", pytest.raises(ValueError)),
(-1, "step", pytest.raises(ValueError)),
],
)
def test__validate_log_frequency(log_frequency, frequency_type, expectation):
with expectation:
FrequencyManager(log_frequency=log_frequency, frequency_type=frequency_type)

0 comments on commit 8706a78

Please sign in to comment.