diff --git a/doc/source/tune/examples/tune-wandb.ipynb b/doc/source/tune/examples/tune-wandb.ipynb index 80c8bc5d7609..5242b7015995 100644 --- a/doc/source/tune/examples/tune-wandb.ipynb +++ b/doc/source/tune/examples/tune-wandb.ipynb @@ -770,7 +770,7 @@ " :noindex:\n", "```\n", "\n", - "### Wandb-Mixin\n", + "### setup_wandb\n", "\n", "(air-wandb-setup)=\n", "\n", diff --git a/python/ray/_private/test_utils.py b/python/ray/_private/test_utils.py index a089c07a1280..faa5d56a64f3 100644 --- a/python/ray/_private/test_utils.py +++ b/python/ray/_private/test_utils.py @@ -1737,3 +1737,14 @@ def get_gcs_memory_used(): } assert "gcs_server" in m return sum(m.values()) + + +def wandb_populate_run_location_hook(): + """ + Example external hook to populate W&B project and group env vars in + test_wandb_logger_run_location_external_hook (test_integration_wandb.py). + """ + from ray.air.integrations.wandb import WANDB_GROUP_ENV_VAR, WANDB_PROJECT_ENV_VAR + + os.environ[WANDB_PROJECT_ENV_VAR] = "test_project" + os.environ[WANDB_GROUP_ENV_VAR] = "test_group" diff --git a/python/ray/air/integrations/wandb.py b/python/ray/air/integrations/wandb.py index 2b1dc733edb9..d07657353381 100644 --- a/python/ray/air/integrations/wandb.py +++ b/python/ray/air/integrations/wandb.py @@ -41,6 +41,13 @@ # It doesn't take in any arguments and returns the W&B API key. # Example: "your.module.wandb_setup_api_key_hook". WANDB_SETUP_API_KEY_HOOK = "WANDB_SETUP_API_KEY_HOOK" +# Hook that is invoked before wandb.init in the setup method of WandbLoggerCallback +# to populate environment variables to specify the location +# (project and group) of the W&B run. +# It doesn't take in any arguments and doesn't return anything, but it does populate +# WANDB_PROJECT_NAME and WANDB_GROUP_NAME. +# Example: "your.module.wandb_populate_run_location_hook". +WANDB_POPULATE_RUN_LOCATION_HOOK = "WANDB_POPULATE_RUN_LOCATION_HOOK" # Hook that is invoked after running wandb.init in WandbLoggerCallback # to process information about the W&B run. 
# It takes in a W&B run object and doesn't return anything. @@ -143,6 +150,12 @@ def _setup_wandb( api_key_file = os.path.expanduser(api_key_file) _set_api_key(api_key_file, wandb_config.pop("api_key", None)) + wandb_config["project"] = _get_wandb_project(wandb_config.get("project")) + wandb_config["group"] = ( + os.environ.get(WANDB_GROUP_ENV_VAR) + if (not wandb_config.get("group") and os.environ.get(WANDB_GROUP_ENV_VAR)) + else wandb_config.get("group") + ) # remove unpickleable items _config = _clean_log(_config) @@ -168,7 +181,9 @@ def _setup_wandb( _wandb = _wandb or wandb - return _wandb.init(**wandb_init_kwargs) + run = _wandb.init(**wandb_init_kwargs) + _run_wandb_process_run_info_hook(run) + return run def _is_allowed_type(obj): @@ -224,6 +239,31 @@ def _clean_log(obj: Any): return fallback +def _get_wandb_project(project: Optional[str] = None) -> Optional[str]: + """Get W&B project from environment variable or external hook if not passed + as an argument.""" + if ( + not project + and not os.environ.get(WANDB_PROJECT_ENV_VAR) + and os.environ.get(WANDB_POPULATE_RUN_LOCATION_HOOK) + ): + # Try to populate WANDB_PROJECT_ENV_VAR and WANDB_GROUP_ENV_VAR + # from external hook + try: + _load_class(os.environ[WANDB_POPULATE_RUN_LOCATION_HOOK])() + except Exception as e: + logger.exception( + f"Error executing {WANDB_POPULATE_RUN_LOCATION_HOOK} to " + f"populate {WANDB_PROJECT_ENV_VAR} and {WANDB_GROUP_ENV_VAR}: {e}", + exc_info=e, + ) + if not project and os.environ.get(WANDB_PROJECT_ENV_VAR): + # Try to get project and group from environment variables if not + # passed through WandbLoggerCallback. + project = os.environ.get(WANDB_PROJECT_ENV_VAR) + return project + + def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None): """Set WandB API key from `wandb_config`. 
Will pop the `api_key_file` and `api_key` keys from `wandb_config` parameter""" @@ -260,6 +300,17 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No ) +def _run_wandb_process_run_info_hook(run: Any) -> None: + """Run external hook to process information about wandb run""" + if WANDB_PROCESS_RUN_INFO_HOOK in os.environ: + try: + _load_class(os.environ[WANDB_PROCESS_RUN_INFO_HOOK])(run) + except Exception as e: + logger.exception( + f"Error calling {WANDB_PROCESS_RUN_INFO_HOOK}: {e}", exc_info=e + ) + + class _QueueItem(enum.Enum): END = enum.auto() RESULT = enum.auto() @@ -310,14 +361,7 @@ def run(self): run = self._wandb.init(*self.args, **self.kwargs) run.config.trial_log_path = self._logdir - # Run external hook to process information about wandb run - if WANDB_PROCESS_RUN_INFO_HOOK in os.environ: - try: - _load_class(os.environ[WANDB_PROCESS_RUN_INFO_HOOK])(run) - except Exception as e: - logger.exception( - f"Error calling {WANDB_PROCESS_RUN_INFO_HOOK}: {e}", exc_info=e - ) + _run_wandb_process_run_info_hook(run) while True: item_type, item_content = self.queue.get() @@ -468,10 +512,7 @@ def setup(self, *args, **kwargs): ) _set_api_key(self.api_key_file, self.api_key) - # Try to get project and group from environment variables if not - # passed through WandbLoggerCallback. 
- if not self.project and os.environ.get(WANDB_PROJECT_ENV_VAR): - self.project = os.environ.get(WANDB_PROJECT_ENV_VAR) + self.project = _get_wandb_project(self.project) if not self.project: raise ValueError( "Please pass the project name as argument or through " diff --git a/python/ray/tune/tests/test_integration_wandb.py b/python/ray/tune/tests/test_integration_wandb.py index 1724d870d325..15f2b4ca3a7f 100644 --- a/python/ray/tune/tests/test_integration_wandb.py +++ b/python/ray/tune/tests/test_integration_wandb.py @@ -28,6 +28,7 @@ from ray.air.integrations.wandb import ( WANDB_ENV_VAR, WANDB_GROUP_ENV_VAR, + WANDB_POPULATE_RUN_LOCATION_HOOK, WANDB_PROJECT_ENV_VAR, WANDB_SETUP_API_KEY_HOOK, ) @@ -222,6 +223,22 @@ def test_wandb_logger_api_key_external_hook(self, monkeypatch): logger.setup() assert os.environ[WANDB_ENV_VAR] == "abcd" + def test_wandb_logger_run_location_external_hook(self, monkeypatch): + # No project + with pytest.raises(ValueError): + logger = WandbTestExperimentLogger(api_key="1234") + logger.setup() + + # Project and group env vars from external hook + monkeypatch.setenv( + WANDB_POPULATE_RUN_LOCATION_HOOK, + "ray._private.test_utils.wandb_populate_run_location_hook", + ) + logger = WandbTestExperimentLogger(api_key="1234") + logger.setup() + assert os.environ[WANDB_PROJECT_ENV_VAR] == "test_project" + assert os.environ[WANDB_GROUP_ENV_VAR] == "test_group" + def test_wandb_logger_start(self, monkeypatch, trial): monkeypatch.setenv(WANDB_ENV_VAR, "9012") # API Key in env