From 08c22b61257461ff7f6bbb43ec03b4bd2a052ccd Mon Sep 17 00:00:00 2001 From: dberenbaum Date: Mon, 22 Jan 2024 15:19:24 -0500 Subject: [PATCH] integrations: fix DVCLiveCallback model logging --- .../integrations/integration_utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/integration_utils.py b/src/transformers/integrations/integration_utils.py index 145a3b25289f..dff98adbaf75 100644 --- a/src/transformers/integrations/integration_utils.py +++ b/src/transformers/integrations/integration_utils.py @@ -1635,16 +1635,21 @@ def __init__( raise RuntimeError("DVCLiveCallback requires dvclive to be installed. Run `pip install dvclive`.") from dvclive import Live - self._log_model = log_model - self._initialized = False self.live = None if isinstance(live, Live): self.live = live - self._initialized = True elif live is not None: raise RuntimeError(f"Found class {live.__class__} for live, expected dvclive.Live") + self._log_model = log_model + if self._log_model is None: + log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL", "FALSE") + if log_model_env.upper() in ENV_VARS_TRUE_VALUES: + self._log_model = True + elif log_model_env.lower() == "all": + self._log_model = "all" + def setup(self, args, state, model): """ Setup the optional DVCLive integration. To customize this callback beyond the environment variables below, see @@ -1659,12 +1664,6 @@ def setup(self, args, state, model): from dvclive import Live self._initialized = True - if self._log_model is not None: - log_model_env = os.getenv("HF_DVCLIVE_LOG_MODEL") - if log_model_env.upper() in ENV_VARS_TRUE_VALUES: - self._log_model = True - elif log_model_env.lower() == "all": - self._log_model = "all" if state.is_world_process_zero: if not self.live: self.live = Live()