diff --git a/rllib/BUILD b/rllib/BUILD
index 72e97d846854..20908f0d9060 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -2120,15 +2120,6 @@ py_test(
 # subdirectory: checkpoints/
 # ....................................
 
-#@OldAPIStack
-py_test(
-    name = "examples/checkpoints/cartpole_dqn_export",
-    main = "examples/checkpoints/cartpole_dqn_export.py",
-    tags = ["team:rllib", "exclusive", "examples"],
-    size = "small",
-    srcs = ["examples/checkpoints/cartpole_dqn_export.py"],
-)
-
 py_test(
     name = "examples/checkpoints/checkpoint_by_custom_criteria",
     main = "examples/checkpoints/checkpoint_by_custom_criteria.py",
@@ -2138,6 +2129,42 @@ py_test(
     args = ["--enable-new-api-stack", "--stop-reward=150.0", "--num-cpus=8"]
 )
 
+py_test(
+    name = "examples/checkpoints/continue_training_from_checkpoint",
+    main = "examples/checkpoints/continue_training_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "large",
+    srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
+    args = ["--enable-new-api-stack", "--as-test"]
+)
+
+py_test(
+    name = "examples/checkpoints/continue_training_from_checkpoint_multi_agent",
+    main = "examples/checkpoints/continue_training_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "large",
+    srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
+    args = ["--enable-new-api-stack", "--as-test", "--num-agents=2", "--stop-reward-crash=400.0", "--stop-reward=900.0"]
+)
+
+#@OldAPIStack
+py_test(
+    name = "examples/checkpoints/continue_training_from_checkpoint_old_api_stack",
+    main = "examples/checkpoints/continue_training_from_checkpoint.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "large",
+    srcs = ["examples/checkpoints/continue_training_from_checkpoint.py"],
+    args = ["--as-test"]
+)
+
+py_test(
+    name = "examples/checkpoints/cartpole_dqn_export",
+    main = "examples/checkpoints/cartpole_dqn_export.py",
+    tags = ["team:rllib", "exclusive", "examples"],
+    size = "small",
+    srcs = ["examples/checkpoints/cartpole_dqn_export.py"],
+)
+
 #@OldAPIStack
 py_test(
     name = "examples/checkpoints/onnx_tf2",
diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py
index 3bf51e0edd42..8ad7ae3780fe 100644
--- a/rllib/algorithms/dqn/dqn.py
+++ b/rllib/algorithms/dqn/dqn.py
@@ -630,22 +630,27 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
         )
 
         self.metrics.log_dict(
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED, default={}),
+            self.metrics.peek(
+                (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED), default={}
+            ),
             key=NUM_AGENT_STEPS_SAMPLED_LIFETIME,
             reduce="sum",
         )
         self.metrics.log_value(
             NUM_ENV_STEPS_SAMPLED_LIFETIME,
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED, default=0),
+            self.metrics.peek((ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED), default=0),
             reduce="sum",
         )
         self.metrics.log_value(
             NUM_EPISODES_LIFETIME,
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_EPISODES, default=0),
+            self.metrics.peek((ENV_RUNNER_RESULTS, NUM_EPISODES), default=0),
             reduce="sum",
         )
         self.metrics.log_dict(
-            self.metrics.peek(ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED, default={}),
+            self.metrics.peek(
+                (ENV_RUNNER_RESULTS, NUM_MODULE_STEPS_SAMPLED),
+                default={},
+            ),
             key=NUM_MODULE_STEPS_SAMPLED_LIFETIME,
             reduce="sum",
         )
@@ -708,7 +713,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
         self.metrics.log_value(
             NUM_ENV_STEPS_TRAINED_LIFETIME,
             self.metrics.peek(
-                LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED
+                (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
             ),
             reduce="sum",
         )
@@ -725,7 +730,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
         # TODO (sven): Uncomment this once agent steps are available in the
         # Learner stats.
         # self.metrics.log_dict(self.metrics.peek(
-        #     LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED, default={}
+        #     (LEARNER_RESULTS, NUM_AGENT_STEPS_TRAINED), default={}
        # ), key=NUM_AGENT_STEPS_TRAINED_LIFETIME, reduce="sum")
 
         # Update replay buffer priorities.
diff --git a/rllib/algorithms/dreamerv3/dreamerv3.py b/rllib/algorithms/dreamerv3/dreamerv3.py
index 50bcce11ad90..80159953f713 100644
--- a/rllib/algorithms/dreamerv3/dreamerv3.py
+++ b/rllib/algorithms/dreamerv3/dreamerv3.py
@@ -582,13 +582,13 @@ def training_step(self) -> ResultDict:
         self.metrics.log_dict(
             {
                 NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED)
                 ),
                 NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED)
                 ),
                 NUM_EPISODES_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_EPISODES
+                    (ENV_RUNNER_RESULTS, NUM_EPISODES)
                 ),
             },
             reduce="sum",
diff --git a/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py b/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py
index d35717e4aa44..15fda2445bbd 100644
--- a/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py
+++ b/rllib/algorithms/dreamerv3/tf/dreamerv3_tf_learner.py
@@ -158,7 +158,7 @@ def compute_gradients(
                 # Take individual loss term from the registered metrics for
                 # the main module.
                 self.metrics.peek(
-                    DEFAULT_MODULE_ID, component.upper() + "_L_total"
+                    (DEFAULT_MODULE_ID, component.upper() + "_L_total")
                 ),
                 self.filter_param_dict_for_optimizer(
                     self._params, self.get_optimizer(optimizer_name=component)
diff --git a/rllib/algorithms/dreamerv3/utils/summaries.py b/rllib/algorithms/dreamerv3/utils/summaries.py
index dd36adbb3160..15768e333848 100644
--- a/rllib/algorithms/dreamerv3/utils/summaries.py
+++ b/rllib/algorithms/dreamerv3/utils/summaries.py
@@ -217,9 +217,7 @@ def report_dreamed_eval_trajectory_vs_samples(
             the report/videos.
     """
     dream_data = metrics.peek(
-        LEARNER_RESULTS,
-        DEFAULT_MODULE_ID,
-        "dream_data",
+        (LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data"),
         default={},
     )
     metrics.delete(LEARNER_RESULTS, DEFAULT_MODULE_ID, "dream_data", key_error=False)
diff --git a/rllib/algorithms/ppo/ppo.py b/rllib/algorithms/ppo/ppo.py
index 60dfe4b6eed6..657bec2f1034 100644
--- a/rllib/algorithms/ppo/ppo.py
+++ b/rllib/algorithms/ppo/ppo.py
@@ -463,13 +463,13 @@ def _training_step_new_api_stack(self) -> ResultDict:
         self.metrics.log_dict(
             {
                 NUM_AGENT_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_AGENT_STEPS_SAMPLED)
                 ),
                 NUM_ENV_STEPS_SAMPLED_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED
+                    (ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED)
                 ),
                 NUM_EPISODES_LIFETIME: self.metrics.peek(
-                    ENV_RUNNER_RESULTS, NUM_EPISODES
+                    (ENV_RUNNER_RESULTS, NUM_EPISODES)
                 ),
             },
             reduce="sum",
@@ -494,10 +494,10 @@ def _training_step_new_api_stack(self) -> ResultDict:
         self.metrics.log_dict(
             {
                 NUM_ENV_STEPS_TRAINED_LIFETIME: self.metrics.peek(
-                    LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED
+                    (LEARNER_RESULTS, ALL_MODULES, NUM_ENV_STEPS_TRAINED)
                 ),
                 # NUM_MODULE_STEPS_TRAINED_LIFETIME: self.metrics.peek(
-                #     LEARNER_RESULTS, NUM_MODULE_STEPS_TRAINED
+                #     (LEARNER_RESULTS, NUM_MODULE_STEPS_TRAINED)
                 # ),
             },
             reduce="sum",
@@ -531,7 +531,9 @@ def _training_step_new_api_stack(self) -> ResultDict:
         if self.config.use_kl_loss:
             for mid in modules_to_update:
                 kl = convert_to_numpy(
-                    self.metrics.peek(LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
+                    self.metrics.peek(
+                        (LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
+                    )
                 )
                 if np.isnan(kl):
                     logger.warning(
diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py
index 229b8cc4549f..879ee2887db2 100644
--- a/rllib/algorithms/sac/torch/sac_torch_learner.py
+++ b/rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -314,7 +314,7 @@ def compute_gradients(
         for component in (
             ["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
         ):
-            self.metrics.peek(module_id, component + "_loss").backward(
+            self.metrics.peek((module_id, component + "_loss")).backward(
                 retain_graph=True
             )
             grads.update(
diff --git a/rllib/examples/checkpoints/checkpoint_by_custom_criteria.py b/rllib/examples/checkpoints/checkpoint_by_custom_criteria.py
index e0a7f7de4f62..0419a8ae1512 100644
--- a/rllib/examples/checkpoints/checkpoint_by_custom_criteria.py
+++ b/rllib/examples/checkpoints/checkpoint_by_custom_criteria.py
@@ -1,7 +1,7 @@
 """Example extracting a checkpoint from n trials using one or more custom criteria.
 
 This example:
-- runs a simple CartPole experiment with three different learning rates (three tune
+- runs a CartPole experiment with three different learning rates (three tune
 "trials"). During the experiment, for each trial, we create a checkpoint at each
 iteration.
 - at the end of the experiment, we compare the trials and pick the one that performed
diff --git a/rllib/examples/checkpoints/continue_training_from_checkpoint.py b/rllib/examples/checkpoints/continue_training_from_checkpoint.py
new file mode 100644
index 000000000000..a8400659d960
--- /dev/null
+++ b/rllib/examples/checkpoints/continue_training_from_checkpoint.py
@@ -0,0 +1,267 @@
+"""Example showing how to restore an Algorithm from a checkpoint and resume training.
+
+Use the setup shown in this script if your experiments tend to crash after some time,
+and you would therefore like to make your setup more robust and fault-tolerant.
+
+This example:
+- runs a single- or multi-agent CartPole experiment (for multi-agent, we use different
+learning rates), thereby checkpointing the state of the Algorithm every n iterations.
+- stops the experiment due to an expected crash in the algorithm's main process after
+a certain number of iterations.
+- just for testing purposes, restores the entire algorithm from the latest checkpoint
+and checks whether the state of the restored algo exactly matches the state of the
+crashed one.
+- then continues training with the restored algorithm until the desired final episode
+return is reached.
+
+
+How to run this script
+----------------------
+`python [script file name].py --enable-new-api-stack --num-agents=[0 or 2]
+--stop-reward-crash=[the episode return after which the algo should crash]
+--stop-reward=[the final episode return to achieve after(!) restoration from the
+checkpoint]
+`
+
+For debugging, use the following additional command line options
+`--no-tune --num-env-runners=0`
+which should allow you to set breakpoints anywhere in the RLlib code and
+have the execution stop there for inspection and debugging.
+
+For logging to your WandB account, use:
+`--wandb-key=[your WandB API key] --wandb-project=[some project name]
+--wandb-run-name=[optional: WandB run name (within the defined project)]`
+
+
+Results to expect
+-----------------
+First, you should see the initial tune.Tuner do its thing:
+
+Trial status: 1 RUNNING
+Current time: 2024-06-03 12:03:39. Total running time: 30s
+Logical resource usage: 3.0/12 CPUs, 0/0 GPUs
+╭────────────────────────────────────────────────────────────────────────
+│ Trial name                     status       iter     total time (s)
+├────────────────────────────────────────────────────────────────────────
+│ PPO_CartPole-v1_7b1eb_00000    RUNNING         6             15.362
+╰────────────────────────────────────────────────────────────────────────
+───────────────────────────────────────────────────────────────────────╮
+..._sampled_lifetime     ..._trained_lifetime     ...episodes_lifetime │
+───────────────────────────────────────────────────────────────────────┤
+               24000                    24000                      340 │
+───────────────────────────────────────────────────────────────────────╯
+...
+
+then, you should see the experiment crashing as soon as the `--stop-reward-crash`
+has been reached:
+
+```RuntimeError: Intended crash after reaching trigger return.```
+
+At some point, the experiment should resume exactly where it left off (using
+the checkpoint and restored Tuner):
+
+Trial status: 1 RUNNING
+Current time: 2024-06-03 12:05:00. Total running time: 1min 0s
+Logical resource usage: 3.0/12 CPUs, 0/0 GPUs
+╭────────────────────────────────────────────────────────────────────────
+│ Trial name                     status       iter     total time (s)
+├────────────────────────────────────────────────────────────────────────
+│ PPO_CartPole-v1_7b1eb_00000    RUNNING        27            66.1451
+╰────────────────────────────────────────────────────────────────────────
+───────────────────────────────────────────────────────────────────────╮
+..._sampled_lifetime     ..._trained_lifetime     ...episodes_lifetime │
+───────────────────────────────────────────────────────────────────────┤
+              108000                   108000                      531 │
+───────────────────────────────────────────────────────────────────────╯
+
+And if you are using the `--as-test` option, you should see a final message:
+
+```
+`env_runners/episode_return_mean` of 500.0 reached! ok
+```
+"""
+import re
+import time
+
+from ray import train, tune
+from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
+from ray.rllib.algorithms.callbacks import DefaultCallbacks
+from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
+from ray.rllib.policy.policy import PolicySpec
+from ray.rllib.utils.metrics import (
+    ENV_RUNNER_RESULTS,
+    EPISODE_RETURN_MEAN,
+)
+from ray.rllib.utils.test_utils import (
+    add_rllib_example_script_args,
+    check_learning_achieved,
+)
+from ray.tune.registry import get_trainable_cls, register_env
+from ray.air.integrations.wandb import WandbLoggerCallback
+
+
+parser = add_rllib_example_script_args(
+    default_reward=500.0, default_timesteps=10000000, default_iters=2000
+)
+parser.add_argument(
+    "--stop-reward-crash",
+    type=float,
+    default=200.0,
+    help="Mean episode return after which the Algorithm should crash.",
+)
+# By default, set `args.checkpoint_freq` to 1 and `args.checkpoint_at_end` to True.
+parser.set_defaults(checkpoint_freq=1, checkpoint_at_end=True)
+
+
+class CrashAfterNIters(DefaultCallbacks):
+    """Callback that makes the algo crash after a certain avg. return is reached."""
+
+    def __init__(self):
+        super().__init__()
+        # We have to delay crashing by one iteration just so the checkpoint still
+        # gets created by Tune after(!) we have reached the trigger avg. return.
+        self._should_crash = False
+
+    def on_train_result(self, *, algorithm, metrics_logger, result, **kwargs):
+        # We have already reached the mean return at which to crash; the last
+        # checkpoint written (previous iteration) should yield that exact avg. return.
+        if self._should_crash:
+            raise RuntimeError("Intended crash after reaching trigger return.")
+        # Reached crashing criterion, crash on next iteration.
+        elif result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward_crash:
+            print(
+                "Reached trigger return of "
+                f"{result[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]}"
+            )
+            self._should_crash = True
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+
+    register_env(
+        "ma_cart", lambda cfg: MultiAgentCartPole({"num_agents": args.num_agents})
+    )
+
+    # Simple generic config.
+    config = (
+        get_trainable_cls(args.algo)
+        .get_default_config()
+        .api_stack(
+            enable_rl_module_and_learner=args.enable_new_api_stack,
+            enable_env_runner_and_connector_v2=args.enable_new_api_stack,
+        )
+        .environment("CartPole-v1" if args.num_agents == 0 else "ma_cart")
+        .env_runners(create_env_on_local_worker=True)
+        .training(lr=0.0001)
+        .callbacks(CrashAfterNIters)
+    )
+
+    # Tune config.
+    # Need a WandB callback?
+    tune_callbacks = []
+    if args.wandb_key:
+        project = args.wandb_project or (
+            args.algo.lower() + "-" + re.sub("\\W+", "-", str(config.env).lower())
+        )
+        tune_callbacks.append(
+            WandbLoggerCallback(
+                api_key=args.wandb_key,
+                project=project,
+                upload_checkpoints=False,
+                **({"name": args.wandb_run_name} if args.wandb_run_name else {}),
+            )
+        )
+
+    # Setup multi-agent, if required.
+    if args.num_agents > 0:
+        config.multi_agent(
+            policies={
+                f"p{aid}": PolicySpec(
+                    config=AlgorithmConfig.overrides(
+                        lr=5e-5
+                        * (aid + 1),  # agent 1 has twice agent 0's learning rate.
+                    )
+                )
+                for aid in range(args.num_agents)
+            },
+            policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
+        )
+
+    # Define some stopping criterion. Note that this criterion is an avg episode return
+    # to be reached. The stop criterion does not consider the built-in crash we are
+    # triggering through our callback.
+    stop = {
+        f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": args.stop_reward,
+    }
+
+    # Run tune for some iterations and generate checkpoints.
+    tuner = tune.Tuner(
+        trainable=config.algo_class,
+        param_space=config,
+        run_config=train.RunConfig(
+            callbacks=tune_callbacks,
+            checkpoint_config=train.CheckpointConfig(
+                checkpoint_frequency=args.checkpoint_freq,
+                checkpoint_at_end=args.checkpoint_at_end,
+            ),
+            stop=stop,
+        ),
+    )
+    tuner_results = tuner.fit()
+
+    # Perform a very quick test to make sure our algo (upon restoration) did not lose
+    # its ability to perform well in the env.
+    # - Extract the best checkpoint.
+    metric = f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}"
+    best_result = tuner_results.get_best_result(metric=metric, mode="max")
+    assert (
+        best_result.metrics[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
+        >= args.stop_reward_crash
+    )
+    # - Change our config such that the restored algo will have an env on the local
+    # EnvRunner (to perform evaluation) and won't crash anymore (remove the crashing
+    # callback).
+    config.callbacks(None)
+    # Rebuild the algorithm (just for testing purposes).
+    test_algo = config.build()
+    # Load algo's state from best checkpoint.
+    test_algo.restore(best_result.checkpoint)
+    # Perform some checks on the restored state.
+    assert test_algo.training_iteration > 0
+    # Evaluate on the restored algorithm.
+    test_eval_results = test_algo.evaluate()
+    assert (
+        test_eval_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN]
+        >= args.stop_reward_crash
+    )
+    # Train one iteration to make sure the performance does not collapse (e.g. due
+    # to the optimizer weights not having been restored properly).
+    test_results = test_algo.train()
+    assert (
+        test_results[ENV_RUNNER_RESULTS][EPISODE_RETURN_MEAN] >= args.stop_reward_crash
+    )
+    # Stop the test algorithm again.
+    test_algo.stop()
+
+    # Create a new Tuner from the existing experiment path (which contains the tuner's
+    # own checkpoint file). Note that even the WandB logging will be continued without
+    # creating a new WandB run name.
+    restored_tuner = tune.Tuner.restore(
+        path=tuner_results.experiment_path,
+        trainable=config.algo_class,
+        param_space=config,
+        # Important to set this to True b/c the previous trial had failed (due to our
+        # `CrashAfterNIters` callback).
+        resume_errored=True,
+    )
+    # Continue the experiment exactly where we left off.
+    tuner_results = restored_tuner.fit()
+
+    # Not sure whether this is really necessary, but we have observed the WandB
+    # logger sometimes not logging some of the last iterations. This sleep here might
+    # give it enough time to do so.
+    time.sleep(20)
+
+    if args.as_test:
+        check_learning_achieved(tuner_results, args.stop_reward, metric=metric)
diff --git a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
index 4338791c71fa..fb53e2cb876f 100644
--- a/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
+++ b/rllib/examples/checkpoints/restore_1_of_n_agents_from_checkpoint.py
@@ -48,7 +48,11 @@
 from ray.air.constants import TRAINING_ITERATION
 from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
 from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
-from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS, NUM_ENV_STEPS_SAMPLED_LIFETIME
+from ray.rllib.utils.metrics import (
+    ENV_RUNNER_RESULTS,
+    EPISODE_RETURN_MEAN,
+    NUM_ENV_STEPS_SAMPLED_LIFETIME,
+)
 from ray.rllib.utils.test_utils import (
     add_rllib_example_script_args,
     run_rllib_example_script_experiment,
@@ -142,9 +146,7 @@
 )
 # Define stopping criteria.
 stop = {
-    # TODO (simon): Change to -800 once the metrics are fixed. Currently
-    # the combined return is not correctly computed.
-    f"{ENV_RUNNER_RESULTS}/episode_return_mean": -400,
+    f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -800,
     f"{NUM_ENV_STEPS_SAMPLED_LIFETIME}": 20000,
     TRAINING_ITERATION: 30,
 }
diff --git a/rllib/utils/metrics/metrics_logger.py b/rllib/utils/metrics/metrics_logger.py
index 611d6f7d9b8c..d22629041ef7 100644
--- a/rllib/utils/metrics/metrics_logger.py
+++ b/rllib/utils/metrics/metrics_logger.py
@@ -102,7 +102,7 @@ def log_value(
 
             # Log a value under a deeper nested key.
             logger.log_value(("some", "nested", "key"), -1.0)
-            check(logger.peek("some", "nested", "key"), -1.0)
+            check(logger.peek(("some", "nested", "key")), -1.0)
 
             # Log n values without reducing them (we want to just collect some items).
             logger.log_value("some_items", 5.0, reduce=None)
@@ -169,12 +169,16 @@ def log_value(
         if not self._key_in_stats(key):
             self._set_key(
                 key,
-                Stats(
-                    value,
-                    reduce=reduce,
-                    window=window,
-                    ema_coeff=ema_coeff,
-                    clear_on_reduce=clear_on_reduce,
+                (
+                    Stats.similar_to(value, init_value=value.values)
+                    if isinstance(value, Stats)
+                    else Stats(
+                        value,
+                        reduce=reduce,
+                        window=window,
+                        ema_coeff=ema_coeff,
+                        clear_on_reduce=clear_on_reduce,
+                    )
                 ),
             )
         # If value itself is a stat, we merge it on time axis into `self`.
@@ -229,7 +233,7 @@ def log_dict(
             # Peek at the current (reduced) values under "a" and "b".
             check(logger.peek("a"), 0.15)
             check(logger.peek("b"), -0.15)
-            check(logger.peek("c", "d"), 5.0)
+            check(logger.peek(("c", "d")), 5.0)
 
             # Reduced all stats.
             results = logger.reduce(return_stats_obj=False)
@@ -320,7 +324,7 @@ def merge_and_log_n_dicts(
                 [learner1_results, learner2_results],
                 key="learners",
             )
-            check(main_logger.peek("learners", "loss"), 0.15)
+            check(main_logger.peek(("learners", "loss")), 0.15)
 
             # Example: m EnvRunners logging episode returns to be merged.
             main_logger = MetricsLogger()
@@ -358,7 +362,7 @@ def merge_and_log_n_dicts(
                 main_logger.stats["env_runners"]["mean_ret"].values,
                 [325, 325, 425, 425],
            )
-            check(main_logger.peek("env_runners", "mean_ret"), (325 + 425 + 425) / 3)
+            check(main_logger.peek(("env_runners", "mean_ret")), (325 + 425 + 425) / 3)
 
             # Example: Lifetime sum over n parallel components' stats.
             main_logger = MetricsLogger()
@@ -604,7 +608,12 @@ def tensors_to_numpy(self, tensor_metrics):
     def tensor_mode(self):
         return self._tensor_mode
 
-    def peek(self, *key, default: Optional[Any] = None) -> Any:
+    def peek(
+        self,
+        key: Union[str, Tuple[str]],
+        *,
+        default: Optional[Any] = None,
+    ) -> Any:
         """Returns the (reduced) value(s) found under the given key or key sequence.
 
         If `key` only reaches to a nested dict deeper in `self`, that
@@ -635,7 +644,7 @@ def peek(self, *key, default: Optional[Any] = None) -> Any:
 
             # Peek at the (reduced) nested struct under ("some", "nested").
             check(
-                logger.peek("some", "nested"),  # <- *args work as well
+                logger.peek(("some", "nested")),
                 {"key": {"sequence": expected_reduced}},
             )