Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AIR] Allow users to exclude config values with WandbLoggerCallback #31624

Merged
merged 1 commit into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions python/ray/air/integrations/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,10 @@ def _handle_result(self, result: Dict) -> Tuple[Dict, Dict]:
flat_result = flatten_dict(result, delimiter="/")

for k, v in flat_result.items():
if any(k.startswith(item + "/") or k == item for item in self._to_config):
config_update[k] = v
elif any(k.startswith(item + "/") or k == item for item in self._exclude):
if any(k.startswith(item + "/") or k == item for item in self._exclude):
continue
elif any(k.startswith(item + "/") or k == item for item in self._to_config):
config_update[k] = v
elif not _is_allowed_type(v):
continue
else:
Expand All @@ -448,7 +448,7 @@ class WandbLoggerCallback(LoggerCallback):
file only needs to be present on the node running the Tune script
if using the WandbLogger.
api_key: Wandb API Key. Alternative to setting ``api_key_file``.
excludes: List of metrics that should be excluded from
excludes: List of metrics and config that should be excluded from
the log.
log_config: Boolean indicating if the ``config`` parameter of
the ``results`` dict should be logged. This makes sense if
Expand Down Expand Up @@ -488,8 +488,7 @@ class WandbLoggerCallback(LoggerCallback):
# Do not log these result keys
_exclude_results = ["done", "should_checkpoint"]

# Use these result keys to update `wandb.config`
_config_results = [
AUTO_CONFIG_KEYS = [
"trial_id",
"experiment_tag",
"node_ip",
Expand All @@ -498,6 +497,7 @@ class WandbLoggerCallback(LoggerCallback):
"pid",
"date",
]
"""Results that are saved with `wandb.config` instead of `wandb.log`."""

_logger_actor_cls = _WandbLoggingActor

Expand Down Expand Up @@ -570,6 +570,9 @@ def log_trial_start(self, trial: "Trial"):

# remove unpickleable items!
config = _clean_log(config)
config = {
key: value for key, value in config.items() if key not in self.excludes
}

wandb_init_kwargs = dict(
id=trial_id,
Expand Down Expand Up @@ -606,7 +609,7 @@ def _start_logging_actor(
logdir=trial.logdir,
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
to_config=self.AUTO_CONFIG_KEYS,
**wandb_init_kwargs,
)
self._trial_logging_futures[trial] = self._trial_logging_actors[
Expand Down
57 changes: 52 additions & 5 deletions python/ray/air/tests/test_integration_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,28 @@ class _MockWandbConfig:
kwargs: Dict


class _FakeConfig:
def update(self, config, *args, **kwargs):
for key, value in config.items():
setattr(self, key, value)

def __iter__(self):
return iter(self.__dict__)


class _MockWandbAPI:
def __init__(self):
self.logs = Queue()
self.config = _FakeConfig()

def init(self, *args, **kwargs):
mock = Mock()
mock.args = args
mock.kwargs = kwargs

if "config" in kwargs:
self.config.update(kwargs["config"])

return mock

def log(self, data):
Expand All @@ -85,10 +99,6 @@ def log(self, data):
def finish(self):
pass

@property
def config(self):
return Mock()


class _MockWandbLoggingActor(_WandbLoggingActor):
def __init__(self, logdir, queue, exclude, to_config, *args, **kwargs):
Expand All @@ -109,7 +119,7 @@ def _start_logging_actor(self, trial, exclude_results, **wandb_init_kwargs):
logdir=trial.logdir,
queue=self._trial_queues[trial],
exclude=exclude_results,
to_config=self._config_results,
to_config=self.AUTO_CONFIG_KEYS,
**wandb_init_kwargs,
)
self._trial_logging_actors[trial] = local_actor
Expand Down Expand Up @@ -300,6 +310,43 @@ def test_wandb_logger_reporting(self, trial):
assert "const" not in logged
assert "config" not in logged

def test_wandb_logger_auto_config_keys(self, trial):
    """Check that ``AUTO_CONFIG_KEYS`` results end up in the wandb config."""
    auto_keys = WandbLoggerCallback.AUTO_CONFIG_KEYS

    logger = WandbTestExperimentLogger(project="test_project", api_key="1234")
    logger.on_trial_start(iteration=0, trials=[], trial=trial)
    config = logger.trial_processes[trial]._wandb.config

    # Report a result made up solely of the auto-config keys.
    result = dict.fromkeys(auto_keys, 0)
    logger.on_trial_result(0, [], trial, result)
    logger.on_trial_complete(0, [], trial)

    # The results in `AUTO_CONFIG_KEYS` should be saved as training
    # configuration instead of output metrics. The strict-subset comparison
    # additionally requires the config to hold at least one other key.
    assert set(auto_keys) < set(config)

def test_wandb_logger_exclude_config(self):
    """Check that ``excludes`` filters trial config and auto-config keys."""
    auto_keys = WandbLoggerCallback.AUTO_CONFIG_KEYS

    trial = Trial(
        config={"param1": 0, "param2": 0},
        trial_id=0,
        trial_name="trial_0",
        experiment_dir_name="trainable",
        placement_group_factory=PlacementGroupFactory([{"CPU": 1}]),
        logdir=tempfile.gettempdir(),
    )

    # Exclude one user-supplied parameter plus every auto-config key.
    excluded = ["param2", *auto_keys]
    logger = WandbTestExperimentLogger(
        project="test_project",
        api_key="1234",
        excludes=excluded,
    )
    logger.on_trial_start(iteration=0, trials=[], trial=trial)
    config = logger.trial_processes[trial]._wandb.config

    # We need to test that `excludes` also applies to `AUTO_CONFIG_KEYS`,
    # so report a result consisting only of those keys.
    result = dict.fromkeys(auto_keys, 0)
    logger.on_trial_result(0, [], trial, result)
    logger.on_trial_complete(0, [], trial)

    # Only the parameter that was not excluded may remain in the config.
    assert set(config) == {"param1"}

def test_set_serializability_result(self, trial):
"""Tests that objects that contain sets can be serialized by wandb."""
logger = WandbTestExperimentLogger(
Expand Down