Skip to content

Commit

Permalink
[RLlib] Smaller eval worker set fixes. (ray-project#28811)
Browse files: browse the repository at this point in the history
Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
  • Loading branch information
sven1977 authored and WeichenXu123 committed Dec 19, 2022
1 parent da43a54 commit a1985e0
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
14 changes: 7 additions & 7 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -781,13 +781,6 @@ py_test(
srcs = ["algorithms/tests/test_memory_leaks.py"]
)

py_test(
name = "test_worker_failures",
tags = ["team:rllib", "tests_dir", "tests_dir_W"],
size = "large",
srcs = ["tests/test_worker_failures.py"]
)

py_test(
name = "test_node_failure",
tags = ["team:rllib", "tests_dir", "tests_dir_N", "exclusive"],
Expand All @@ -802,6 +795,13 @@ py_test(
srcs = ["algorithms/tests/test_registry.py"],
)

py_test(
name = "test_worker_failures",
tags = ["team:rllib", "algorithms_dir", "algorithms_dir_generic"],
size = "large",
srcs = ["algorithms/tests/test_worker_failures.py"]
)

# Specific Algorithms

# A2C
Expand Down
12 changes: 10 additions & 2 deletions rllib/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,7 @@ def duration_fn(num_units_done):
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}

# Trigger `on_evaluate_end` callback.
self.callbacks.on_evaluate_end(
algorithm=self, evaluation_metrics=self.evaluation_metrics
)
Expand Down Expand Up @@ -1147,6 +1148,11 @@ def remote_fn(worker, w_ref, w_seq_no):
# subsequent step results as latest evaluation result.
self.evaluation_metrics = {"evaluation": metrics}

# Trigger `on_evaluate_end` callback.
self.callbacks.on_evaluate_end(
algorithm=self, evaluation_metrics=self.evaluation_metrics
)

# Return evaluation results.
return self.evaluation_metrics

Expand Down Expand Up @@ -2654,7 +2660,6 @@ def _run_one_evaluation(
"episode_reward_mean": np.nan,
}
}
eval_results["evaluation"]["num_recreated_workers"] = 0

eval_func_to_use = (
self._evaluate_async
Expand Down Expand Up @@ -2694,14 +2699,17 @@ def _run_one_evaluation(
"recreate_failed_workers"
),
)
# `self._evaluate_async` handles its own worker failures and already adds
# this metric, but `self.evaluate` doesn't.
if "num_recreated_workers" not in eval_results["evaluation"]:
eval_results["evaluation"]["num_recreated_workers"] = num_recreated

# Add number of healthy evaluation workers after this iteration.
eval_results["evaluation"]["num_healthy_workers"] = (
len(self.evaluation_workers.remote_workers())
if self.evaluation_workers is not None
else 0
)
eval_results["evaluation"]["num_recreated_workers"] = num_recreated

return eval_results

Expand Down
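2 changes: 1 addition & 1 deletion rllib/tests/test_worker_failures.py → rllib/algorithms/tests/test_worker_failures.py
Original file line number Diff line number Diff line change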
Expand Up @@ -153,7 +153,7 @@ def is_recreated(w):
)


class TestWorkerFailure(unittest.TestCase):
class TestWorkerFailures(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()
Expand Down
2 changes: 1 addition & 1 deletion rllib/evaluation/rollout_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,7 @@ def set_weights(
>>> worker.set_weights(weights, {"timestep": 42}) # doctest: +SKIP
"""
# Only update our weights, if no seq no given OR given seq no is different
# from ours
# from ours.
if weights_seq_no is None or weights_seq_no != self.weights_seq_no:
# If per-policy weights are object refs, `ray.get()` them first.
if weights and isinstance(next(iter(weights.values())), ObjectRef):
Expand Down

0 comments on commit a1985e0

Please sign in to comment.