diff --git a/pl_bolts/callbacks/data_monitor.py b/pl_bolts/callbacks/data_monitor.py
index 95879847d4..97ba087229 100644
--- a/pl_bolts/callbacks/data_monitor.py
+++ b/pl_bolts/callbacks/data_monitor.py
@@ -3,7 +3,7 @@
 import numpy as np
 import torch
 from pytorch_lightning import Callback, LightningModule, Trainer
-from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger, WandbLogger
+from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from torch import Tensor, nn
@@ -14,6 +14,11 @@
 from pl_bolts.utils.stability import under_review
 from pl_bolts.utils.warnings import warn_missing_pkg
 
+try:
+    from pytorch_lightning.loggers import Logger
+except ImportError:
+    from pytorch_lightning.loggers import LightningLoggerBase as Logger
+
 if _WANDB_AVAILABLE:
     import wandb
 else:  # pragma: no cover
@@ -96,7 +101,7 @@ def log_histogram(self, tensor: Tensor, name: str) -> None:
         if isinstance(logger, WandbLogger):
             logger.experiment.log(data={name: wandb.Histogram(tensor)}, commit=False)
 
-    def _is_logger_available(self, logger: LightningLoggerBase) -> bool:
+    def _is_logger_available(self, logger: Logger) -> bool:
         available = True
         if not logger:
             rank_zero_warn("Cannot log histograms because Trainer has no logger.")