Skip to content

Commit

Permalink
Add tb style logging functions (#1919)
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli authored Jan 12, 2024
1 parent bf0fb3f commit d67a1c8
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion src/sparseml/core/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from logging import CRITICAL, DEBUG, ERROR, INFO, WARN, Logger
from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

from sparseml.core.logger.utils import (
FrequencyManager,
Expand Down Expand Up @@ -73,6 +73,7 @@
"error": ERROR,
"critical": CRITICAL,
}
DEFAULT_TAG = "defaul_tag"


class BaseLogger(ABC):
Expand Down Expand Up @@ -1279,6 +1280,49 @@ def log_scalars(
level=level,
)

def add_scalar(
self,
value,
tag: str = DEFAULT_TAG,
step: Optional[int] = None,
wall_time: Union[int, float, None] = None,
**kwargs,
):
"""
Add a scalar value to the logger
:param value: value to log
:param tag: tag to log the value with, defaults to DEFAULT_TAG
:param step: global step for when the value was taken
:param wall_time: global wall time for when the value was taken
:param kwargs: additional logging arguments to to pass through to the
logger
"""
self.log_scalar(tag=tag, value=value, step=step, wall_time=wall_time, **kwargs)

def add_scalars(
self,
values: Dict[str, Any],
tag: str = DEFAULT_TAG,
step: Optional[int] = None,
wall_time: Union[int, float, None] = None,
**kwargs,
):
"""
Adds multiple scalar values to the logger
:param values: values to log, must be A dict of serializable
python objects i.e `str`, `ints`, `floats`, `Tensors`, `dicts`, etc
:param tag: tag to log the value with, defaults to DEFAULT_TAG
:param step: global step for when the value was taken
:param wall_time: global wall time for when the value was taken
:param kwargs: additional logging arguments to to pass through to the
logger
"""
self.log_scalars(
tag=tag, values=values, step=step, wall_time=wall_time, **kwargs
)


def _create_dirs(path: str):
path = Path(path).expanduser().absolute()
Expand Down

0 comments on commit d67a1c8

Please sign in to comment.