From d67a1c8ddf13292ceab1758aad325ccee983a9fd Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Fri, 12 Jan 2024 11:05:21 -0500 Subject: [PATCH] Add tb style logging functions (#1919) --- src/sparseml/core/logger/logger.py | 46 +++++++++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/src/sparseml/core/logger/logger.py b/src/sparseml/core/logger/logger.py index b9303753130..8c01ba3b60b 100644 --- a/src/sparseml/core/logger/logger.py +++ b/src/sparseml/core/logger/logger.py @@ -25,7 +25,7 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, WARN, Logger from pathlib import Path from types import ModuleType -from typing import Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from sparseml.core.logger.utils import ( FrequencyManager, @@ -73,6 +73,7 @@ "error": ERROR, "critical": CRITICAL, } +DEFAULT_TAG = "defaul_tag" class BaseLogger(ABC): @@ -1279,6 +1280,49 @@ def log_scalars( level=level, ) + def add_scalar( + self, + value, + tag: str = DEFAULT_TAG, + step: Optional[int] = None, + wall_time: Union[int, float, None] = None, + **kwargs, + ): + """ + Add a scalar value to the logger + + :param value: value to log + :param tag: tag to log the value with, defaults to DEFAULT_TAG + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param kwargs: additional logging arguments to to pass through to the + logger + """ + self.log_scalar(tag=tag, value=value, step=step, wall_time=wall_time, **kwargs) + + def add_scalars( + self, + values: Dict[str, Any], + tag: str = DEFAULT_TAG, + step: Optional[int] = None, + wall_time: Union[int, float, None] = None, + **kwargs, + ): + """ + Adds multiple scalar values to the logger + + :param values: values to log, must be A dict of serializable + python objects i.e `str`, `ints`, `floats`, `Tensors`, `dicts`, etc + :param tag: tag to log the value with, defaults to DEFAULT_TAG + :param step: global step for when the value was taken + :param wall_time: global wall time for when the value was taken + :param kwargs: additional logging arguments to to pass through to the + logger + """ + self.log_scalars( + tag=tag, values=values, step=step, wall_time=wall_time, **kwargs + ) + def _create_dirs(path: str): path = Path(path).expanduser().absolute()