Skip to content

Commit

Permalink
[RLlib] MetricsLogger cleanup throughput logic. (ray-project#49981)
Browse files Browse the repository at this point in the history
  • Loading branch information
sven1977 authored and win5923 committed Jan 23, 2025
1 parent fe296b8 commit 4541731
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 33 deletions.
38 changes: 9 additions & 29 deletions rllib/utils/metrics/metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class MetricsLogger:
- Reducing these collected values using a user specified reduction method (for
example "min" or "mean") and other settings controlling the reduction and internal
data, such as sliding windows or EMA coefficients.
- Resetting the logged values after a `reduce()` call in order to make space for
new values to be logged.
- Optionally clearing all logged values after a `reduce()` call to make space for
new data.
.. testcode::
Expand Down Expand Up @@ -335,7 +335,8 @@ def log_value(
object under the logged key then keeps track of the time passed
between two consecutive calls to `reduce()` and updates its throughput
estimate. The current throughput estimate of a key can be obtained
through: `MetricsLogger.peek([some key], throughput=True)`.
through: peeked_value, throughput_per_sec =
<MetricsLogger>.peek([key], throughput=True).
"""
# No reduction (continue appending to list) AND no window.
# -> We'll force-reset our values upon `reduce()`.
Expand Down Expand Up @@ -701,14 +702,9 @@ def log_time(
window: Optional[Union[int, float]] = None,
ema_coeff: Optional[float] = None,
clear_on_reduce: bool = False,
key_for_throughput: Optional[Union[str, Tuple[str, ...]]] = None,
key_for_unit_count: Optional[Union[str, Tuple[str, ...]]] = None,
) -> Stats:
"""Measures and logs a time delta value under `key` when used with a with-block.
Additionally, measures and logs the throughput for the timed code, iff
`key_for_throughput` and `key_for_unit_count` are provided.
.. testcode::
import time
Expand Down Expand Up @@ -769,32 +765,13 @@ def log_time(
clear_on_reduce = True

if not self._key_in_stats(key):
measure_throughput = None
if key_for_unit_count is not None:
measure_throughput = True
key_for_throughput = key_for_throughput or (key + "_throughput_per_s")

self._set_key(
key,
Stats(
reduce=reduce,
window=window,
ema_coeff=ema_coeff,
clear_on_reduce=clear_on_reduce,
on_exit=(
lambda time_delta_s, kt=key_for_throughput, ku=key_for_unit_count, r=reduce, w=window, e=ema_coeff, c=clear_on_reduce: ( # noqa
self.log_value(
kt,
value=self.peek(ku) / time_delta_s,
reduce=r,
window=w,
ema_coeff=e,
clear_on_reduce=c,
)
)
)
if measure_throughput
else None,
),
)

Expand Down Expand Up @@ -933,7 +910,9 @@ def _reduce(path, stats):

try:
with self._threading_lock:
assert not self.tensor_mode
assert (
not self.tensor_mode
), "Can't reduce if `self.tensor_mode` is True!"
reduced = copy.deepcopy(
tree.map_structure_with_path(_reduce, stats_to_return)
)
Expand Down Expand Up @@ -1048,7 +1027,8 @@ def set_value(
object under the logged key then keeps track of the time passed
between two consecutive calls to `reduce()` and updates its throughput
estimate. The current throughput estimate of a key can be obtained
through: `MetricsLogger.peek([some key], throughput=True)`.
through: peeked_value, throughput_per_sec =
<MetricsLogger>.peek([key], throughput=True).
"""
# Key already in self -> Erase internal values list with [`value`].
if self._key_in_stats(key):
Expand Down
11 changes: 7 additions & 4 deletions rllib/utils/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,15 @@ def __init__(
to True is useful for cases in which the internal values list would
otherwise grow indefinitely, for example if `reduce` is None and no
`window` is provided.
with_throughput: Whether to track a throughput estimate together with this
throughput: If True, track a throughput estimate together with this
Stats. This is only supported for `reduce=sum` and
`clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
then keeps track of the time passed between two consecutive calls to
`reduce()` and updates its throughput estimate. The current throughput
estimate of a key can be obtained through:
`Stats.peek([some key], throughput=True)`.
`peeked_val, throughput_per_sec = Stats.peek([key], throughput=True)`.
If a float, track throughput and also set current throughput estimate
to the given value.
"""
# Thus far, we only support mean, max, min, and sum.
if reduce not in [None, "mean", "min", "max", "sum"]:
Expand Down Expand Up @@ -318,9 +320,10 @@ class for details on the reduction logic applied to the values list, based on
# Take the delta between the new (upcoming) reduced value and the most
# recently reduced value (one `reduce()` call ago).
delta_sum = reduced - self._hist[-1]
assert delta_sum >= 0
time_now = time.perf_counter()
if self._throughput_last_time == -1:
# `delta_sum` may be < 0.0 if user overrides a metric through
# `.set_value()`.
if self._throughput_last_time == -1 or delta_sum < 0.0:
self._throughput = np.nan
else:
delta_time = time_now - self._throughput_last_time
Expand Down

0 comments on commit 4541731

Please sign in to comment.