Skip to content

Commit

Permalink
[tune] Add hook to get project/group for W&B integration (ray-project…
Browse files Browse the repository at this point in the history
…#31035)

- Allow setting the W&B project and group environment variables from an external hook if they are not already passed to `WandbLoggerCallback` or `setup_wandb`
- Add remaining external hooks to `setup_wandb`

Signed-off-by: Nikita Vemuri <nikitavemuri@gmail.com>
Signed-off-by: tmynn <hovhannes.tamoyan@gmail.com>
  • Loading branch information
nikitavemuri authored and tamohannes committed Jan 25, 2023
1 parent 414116e commit 4e39d31
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 14 deletions.
2 changes: 1 addition & 1 deletion doc/source/tune/examples/tune-wandb.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@
" :noindex:\n",
"```\n",
"\n",
"### Wandb-Mixin\n",
"### setup_wandb\n",
"\n",
"(air-wandb-setup)=\n",
"\n",
Expand Down
11 changes: 11 additions & 0 deletions python/ray/_private/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1737,3 +1737,14 @@ def get_gcs_memory_used():
}
assert "gcs_server" in m
return sum(m.values())


def wandb_populate_run_location_hook():
    """Populate the W&B project and group environment variables.

    Example external hook for ``WANDB_POPULATE_RUN_LOCATION_HOOK``; it is
    referenced by ``test_wandb_logger_run_location_external_hook`` in the
    W&B integration tests. It takes no arguments and returns nothing: its
    only effect is writing fixed test values into ``os.environ``.
    """
    # Imported lazily so that merely importing this utilities module does not
    # pull in the W&B integration.
    from ray.air.integrations.wandb import WANDB_GROUP_ENV_VAR, WANDB_PROJECT_ENV_VAR

    os.environ.update(
        {
            WANDB_PROJECT_ENV_VAR: "test_project",
            WANDB_GROUP_ENV_VAR: "test_group",
        }
    )
67 changes: 54 additions & 13 deletions python/ray/air/integrations/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@
# It doesn't take in any arguments and returns the W&B API key.
# Example: "your.module.wandb_setup_api_key_hook".
WANDB_SETUP_API_KEY_HOOK = "WANDB_SETUP_API_KEY_HOOK"
# Hook that is invoked before wandb.init, both in _setup_wandb and in the
# setup method of WandbLoggerCallback, to populate environment variables
# that specify the location (project and group) of the W&B run.
# It doesn't take in any arguments and doesn't return anything, but it does populate
# WANDB_PROJECT_NAME and WANDB_GROUP_NAME.
# Example: "your.module.wandb_populate_run_location_hook".
WANDB_POPULATE_RUN_LOCATION_HOOK = "WANDB_POPULATE_RUN_LOCATION_HOOK"
# Hook that is invoked after running wandb.init in WandbLoggerCallback
# to process information about the W&B run.
# It takes in a W&B run object and doesn't return anything.
Expand Down Expand Up @@ -143,6 +150,12 @@ def _setup_wandb(
api_key_file = os.path.expanduser(api_key_file)

_set_api_key(api_key_file, wandb_config.pop("api_key", None))
wandb_config["project"] = _get_wandb_project(wandb_config.get("project"))
wandb_config["group"] = (
os.environ.get(WANDB_GROUP_ENV_VAR)
if (not wandb_config.get("group") and os.environ.get(WANDB_GROUP_ENV_VAR))
else wandb_config.get("group")
)

# remove unpickleable items
_config = _clean_log(_config)
Expand All @@ -168,7 +181,9 @@ def _setup_wandb(

_wandb = _wandb or wandb

return _wandb.init(**wandb_init_kwargs)
run = _wandb.init(**wandb_init_kwargs)
_run_wandb_process_run_info_hook(run)
return run


def _is_allowed_type(obj):
Expand Down Expand Up @@ -224,6 +239,31 @@ def _clean_log(obj: Any):
return fallback


def _get_wandb_project(project: Optional[str] = None) -> Optional[str]:
"""Get W&B project from environment variable or external hook if not passed
as and argument."""
if (
not project
and not os.environ.get(WANDB_PROJECT_ENV_VAR)
and os.environ.get(WANDB_POPULATE_RUN_LOCATION_HOOK)
):
# Try to populate WANDB_PROJECT_ENV_VAR and WANDB_GROUP_ENV_VAR
# from external hook
try:
_load_class(os.environ[WANDB_POPULATE_RUN_LOCATION_HOOK])()
except Exception as e:
logger.exception(
f"Error executing {WANDB_POPULATE_RUN_LOCATION_HOOK} to "
f"populate {WANDB_PROJECT_ENV_VAR} and {WANDB_GROUP_ENV_VAR}: {e}",
exc_info=e,
)
if not project and os.environ.get(WANDB_PROJECT_ENV_VAR):
# Try to get project and group from environment variables if not
# passed through WandbLoggerCallback.
project = os.environ.get(WANDB_PROJECT_ENV_VAR)
return project


def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = None):
"""Set WandB API key from `wandb_config`. Will pop the
`api_key_file` and `api_key` keys from `wandb_config` parameter"""
Expand Down Expand Up @@ -260,6 +300,17 @@ def _set_api_key(api_key_file: Optional[str] = None, api_key: Optional[str] = No
)


def _run_wandb_process_run_info_hook(run: Any) -> None:
    """Invoke the optional external hook that processes a W&B run.

    The hook is looked up through the ``WANDB_PROCESS_RUN_INFO_HOOK``
    environment variable. When the variable is not set this is a no-op.
    Any exception raised while loading or calling the hook is logged and
    swallowed.

    Args:
        run: The W&B run object handed to the hook as its only argument.
    """
    hook_path = os.environ.get(WANDB_PROCESS_RUN_INFO_HOOK)
    if hook_path is None:
        # No hook configured.
        return

    try:
        _load_class(hook_path)(run)
    except Exception as e:
        logger.exception(
            f"Error calling {WANDB_PROCESS_RUN_INFO_HOOK}: {e}", exc_info=e
        )


class _QueueItem(enum.Enum):
END = enum.auto()
RESULT = enum.auto()
Expand Down Expand Up @@ -310,14 +361,7 @@ def run(self):
run = self._wandb.init(*self.args, **self.kwargs)
run.config.trial_log_path = self._logdir

# Run external hook to process information about wandb run
if WANDB_PROCESS_RUN_INFO_HOOK in os.environ:
try:
_load_class(os.environ[WANDB_PROCESS_RUN_INFO_HOOK])(run)
except Exception as e:
logger.exception(
f"Error calling {WANDB_PROCESS_RUN_INFO_HOOK}: {e}", exc_info=e
)
_run_wandb_process_run_info_hook(run)

while True:
item_type, item_content = self.queue.get()
Expand Down Expand Up @@ -468,10 +512,7 @@ def setup(self, *args, **kwargs):
)
_set_api_key(self.api_key_file, self.api_key)

# Try to get project and group from environment variables if not
# passed through WandbLoggerCallback.
if not self.project and os.environ.get(WANDB_PROJECT_ENV_VAR):
self.project = os.environ.get(WANDB_PROJECT_ENV_VAR)
self.project = _get_wandb_project(self.project)
if not self.project:
raise ValueError(
"Please pass the project name as argument or through "
Expand Down
17 changes: 17 additions & 0 deletions python/ray/tune/tests/test_integration_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ray.air.integrations.wandb import (
WANDB_ENV_VAR,
WANDB_GROUP_ENV_VAR,
WANDB_POPULATE_RUN_LOCATION_HOOK,
WANDB_PROJECT_ENV_VAR,
WANDB_SETUP_API_KEY_HOOK,
)
Expand Down Expand Up @@ -222,6 +223,22 @@ def test_wandb_logger_api_key_external_hook(self, monkeypatch):
logger.setup()
assert os.environ[WANDB_ENV_VAR] == "abcd"

def test_wandb_logger_run_location_external_hook(self, monkeypatch):
    """Check that the run-location hook supplies the W&B project and group.

    Without a project (argument, environment variable, or hook), ``setup``
    must raise ``ValueError``. Once ``WANDB_POPULATE_RUN_LOCATION_HOOK``
    points at the example hook, ``setup`` must succeed and the hook's values
    must be visible in the environment.
    """
    # Start from a known-clean environment so the "no project" expectation
    # does not depend on whatever the surrounding process or an earlier test
    # left behind. monkeypatch restores any pre-existing values on teardown.
    monkeypatch.delenv(WANDB_PROJECT_ENV_VAR, raising=False)
    monkeypatch.delenv(WANDB_GROUP_ENV_VAR, raising=False)
    monkeypatch.delenv(WANDB_POPULATE_RUN_LOCATION_HOOK, raising=False)

    # No project
    with pytest.raises(ValueError):
        logger = WandbTestExperimentLogger(api_key="1234")
        logger.setup()

    # Project and group env vars from external hook
    monkeypatch.setenv(
        WANDB_POPULATE_RUN_LOCATION_HOOK,
        "ray._private.test_utils.wandb_populate_run_location_hook",
    )
    try:
        logger = WandbTestExperimentLogger(api_key="1234")
        logger.setup()
        assert os.environ[WANDB_PROJECT_ENV_VAR] == "test_project"
        assert os.environ[WANDB_GROUP_ENV_VAR] == "test_group"
    finally:
        # The hook writes to os.environ directly, bypassing monkeypatch, so
        # its values would otherwise leak into subsequent tests.
        os.environ.pop(WANDB_PROJECT_ENV_VAR, None)
        os.environ.pop(WANDB_GROUP_ENV_VAR, None)

def test_wandb_logger_start(self, monkeypatch, trial):
monkeypatch.setenv(WANDB_ENV_VAR, "9012")
# API Key in env
Expand Down

0 comments on commit 4e39d31

Please sign in to comment.