From e95dbe8a6106ecd6fda105863609c427e435b0cd Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Tue, 14 May 2024 13:43:36 +0200 Subject: [PATCH] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Upgrade=20wandb=20(#2040)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * upgrade wandb Signed-off-by: Ashwin Vaidya * upgrade wandb Signed-off-by: Ashwin Vaidya * sort imports Signed-off-by: Ashwin Vaidya * limit wandb version Signed-off-by: Ashwin Vaidya * revert changes to InferenceModel Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya --- pyproject.toml | 2 +- src/anomalib/loggers/wandb.py | 30 ++++++++++++++++++------------ 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 85d48c13e6..a8a57bfc1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ loggers = [ "comet-ml>=3.31.7", "gradio>=4", "tensorboard", - "wandb==0.12.17", + "wandb>=0.12.17,<=0.15.9", "mlflow >=1.0.0", ] notebooks = ["gitpython", "ipykernel", "ipywidgets", "notebook"] diff --git a/src/anomalib/loggers/wandb.py b/src/anomalib/loggers/wandb.py index ddb7e66d19..0a23c25192 100644 --- a/src/anomalib/loggers/wandb.py +++ b/src/anomalib/loggers/wandb.py @@ -3,24 +3,24 @@ # Copyright (C) 2022-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +from typing import TYPE_CHECKING, Literal import numpy as np +from lightning.fabric.utilities.types import _PATH from lightning.pytorch.loggers.wandb import WandbLogger from lightning.pytorch.utilities import rank_zero_only from matplotlib.figure import Figure from anomalib.utils.exceptions import try_import +from .base import ImageLoggerBase + if try_import("wandb"): import wandb -from typing import TYPE_CHECKING - -from .base import ImageLoggerBase - if TYPE_CHECKING: from wandb.sdk.lib import RunDisabled - from wandb.wandb_run import Run + from wandb.sdk.wandb_run import Run class AnomalibWandbLogger(ImageLoggerBase, WandbLogger): @@ -44,8 +44,10 @@ class AnomalibWandbLogger(ImageLoggerBase, WandbLogger): Defaults to ``None``. save_dir: Path where data is saved (wandb dir by default). Defaults to ``None``. + version: Sets the version, mainly used to resume a previous run. offline: Run offline (data can be streamed later to wandb servers). Defaults to ``False``. + dir: Alias for save_dir. id: Sets the version, mainly used to resume a previous run. Defaults to ``None``. anonymous: Enables or explicitly disables anonymous logging. @@ -89,28 +91,32 @@ class AnomalibWandbLogger(ImageLoggerBase, WandbLogger): def __init__( self, name: str | None = None, - save_dir: str | None = None, - offline: bool | None = False, + save_dir: _PATH = ".", + version: str | None = None, + offline: bool = False, + dir: _PATH | None = None, # kept to match wandb init # noqa: A002 id: str | None = None, # kept to match wandb init # noqa: A002 anonymous: bool | None = None, - version: str | None = None, project: str | None = None, - log_model: str | bool = False, - experiment: type["Run"] | type["RunDisabled"] | None = None, - prefix: str | None = "", + log_model: Literal["all"] | bool = False, + experiment: "Run | RunDisabled | None" = None, + prefix: str = "", + checkpoint_name: str | None = None, **kwargs, ) -> None: super().__init__( name=name, save_dir=save_dir, + version=version, offline=offline, + dir=dir, id=id, anonymous=anonymous, - version=version, project=project, log_model=log_model, experiment=experiment, prefix=prefix, + checkpoint_name=checkpoint_name, **kwargs, ) self.image_list: list[wandb.Image] = [] # Cache images