Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RLlib] MetricsLogger cleanup throughput logic. #49981

Merged
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rllib/core/learner/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1736,7 +1736,6 @@ def _log_steps_trained_metrics(self, batch: MultiAgentBatch):
(ALL_MODULES, NUM_ENV_STEPS_TRAINED_LIFETIME),
batch.env_steps(),
reduce="sum",
with_throughput=True,
)

@Deprecated(
Expand Down
2 changes: 0 additions & 2 deletions rllib/env/multi_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,6 @@ def set_state(self, state: StateDict) -> None:
key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
reduce="sum",
with_throughput=True,
)

@override(Checkpointable)
Expand Down Expand Up @@ -993,7 +992,6 @@ def _increase_sampled_metrics(self, num_steps, next_obs, episode):
NUM_ENV_STEPS_SAMPLED_LIFETIME,
num_steps,
reduce="sum",
with_throughput=True,
)
# Completed episodes.
if episode.is_done:
Expand Down
2 changes: 0 additions & 2 deletions rllib/env/single_agent_env_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ def set_state(self, state: StateDict) -> None:
key=NUM_ENV_STEPS_SAMPLED_LIFETIME,
value=state[NUM_ENV_STEPS_SAMPLED_LIFETIME],
reduce="sum",
with_throughput=True,
)

@override(Checkpointable)
Expand Down Expand Up @@ -799,7 +798,6 @@ def _increase_sampled_metrics(self, num_steps, num_episodes_completed):
NUM_ENV_STEPS_SAMPLED_LIFETIME,
num_steps,
reduce="sum",
with_throughput=True,
)
self.metrics.log_value(
(NUM_AGENT_STEPS_SAMPLED_LIFETIME, DEFAULT_AGENT_ID),
Expand Down
46 changes: 2 additions & 44 deletions rllib/utils/metrics/metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class MetricsLogger:
- Reducing these collected values using a user specified reduction method (for
example "min" or "mean") and other settings controlling the reduction and internal
data, such as sliding windows or EMA coefficients.
- Resetting the logged values after a `reduce()` call in order to make space for
new values to be logged.
- Optionally clearing all logged values after a `reduce()` call to make space for
new data.

.. testcode::

Expand Down Expand Up @@ -233,7 +233,6 @@ def log_value(
window: Optional[Union[int, float]] = None,
ema_coeff: Optional[float] = None,
clear_on_reduce: bool = False,
with_throughput: bool = False,
) -> None:
"""Logs a new value under a (possibly nested) key to the logger.

Expand Down Expand Up @@ -329,13 +328,6 @@ def log_value(
`self.reduce()` is called. Setting this to True is useful for cases,
in which the internal values list would otherwise grow indefinitely,
for example if reduce is None and there is no `window` provided.
with_throughput: Whether to track a throughput estimate together with this
metric. This is only supported for `reduce=sum` and
`clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
object under the logged key then keeps track of the time passed
between two consecutive calls to `reduce()` and update its throughput
estimate. The current throughput estimate of a key can be obtained
through: `MetricsLogger.peek([some key], throughput=True)`.
"""
# No reduction (continue appending to list) AND no window.
# -> We'll force-reset our values upon `reduce()`.
Expand All @@ -358,7 +350,6 @@ def log_value(
window=window,
ema_coeff=ema_coeff,
clear_on_reduce=clear_on_reduce,
throughput=with_throughput,
)
),
)
Expand Down Expand Up @@ -701,14 +692,9 @@ def log_time(
window: Optional[Union[int, float]] = None,
ema_coeff: Optional[float] = None,
clear_on_reduce: bool = False,
key_for_throughput: Optional[Union[str, Tuple[str, ...]]] = None,
key_for_unit_count: Optional[Union[str, Tuple[str, ...]]] = None,
) -> Stats:
"""Measures and logs a time delta value under `key` when used with a with-block.

Additionally, measures and logs the throughput for the timed code, iff
`key_for_throughput` and `key_for_unit_count` are provided.

.. testcode::

import time
Expand Down Expand Up @@ -769,32 +755,13 @@ def log_time(
clear_on_reduce = True

if not self._key_in_stats(key):
measure_throughput = None
if key_for_unit_count is not None:
measure_throughput = True
key_for_throughput = key_for_throughput or (key + "_throughput_per_s")

self._set_key(
key,
Stats(
reduce=reduce,
window=window,
ema_coeff=ema_coeff,
clear_on_reduce=clear_on_reduce,
on_exit=(
lambda time_delta_s, kt=key_for_throughput, ku=key_for_unit_count, r=reduce, w=window, e=ema_coeff, c=clear_on_reduce: ( # noqa
self.log_value(
kt,
value=self.peek(ku) / time_delta_s,
reduce=r,
window=w,
ema_coeff=e,
clear_on_reduce=c,
)
)
)
if measure_throughput
else None,
),
)

Expand Down Expand Up @@ -1005,7 +972,6 @@ def set_value(
window: Optional[Union[int, float]] = None,
ema_coeff: Optional[float] = None,
clear_on_reduce: bool = False,
with_throughput: bool = False,
) -> None:
"""Overrides the logged values under `key` with `value`.

Expand Down Expand Up @@ -1042,13 +1008,6 @@ def set_value(
in which the internal values list would otherwise grow indefinitely,
for example if reduce is None and there is no `window` provided.
Note that this is only applied if `key` does not exist in `self` yet.
with_throughput: Whether to track a throughput estimate together with this
metric. This is only supported for `reduce=sum` and
`clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
object under the logged key then keeps track of the time passed
between two consecutive calls to `reduce()` and update its throughput
estimate. The current throughput estimate of a key can be obtained
through: `MetricsLogger.peek([some key], throughput=True)`.
"""
# Key already in self -> Erase internal values list with [`value`].
if self._key_in_stats(key):
Expand All @@ -1064,7 +1023,6 @@ def set_value(
window=window,
ema_coeff=ema_coeff,
clear_on_reduce=clear_on_reduce,
with_throughput=with_throughput,
)

def reset(self) -> None:
Expand Down
28 changes: 12 additions & 16 deletions rllib/utils/metrics/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ def __init__(
ema_coeff: Optional[float] = None,
clear_on_reduce: bool = False,
on_exit: Optional[Callable] = None,
throughput: Union[bool, float] = False,
):
"""Initializes a Stats instance.

Expand Down Expand Up @@ -175,13 +174,6 @@ def __init__(
to True is useful for cases, in which the internal values list would
otherwise grow indefinitely, for example if reduce is None and there
is no `window` provided.
with_throughput: Whether to track a throughput estimate together with this
Stats. This is only supported for `reduce=sum` and
`clear_on_reduce=False` metrics (aka. "lifetime counts"). The `Stats`
then keeps track of the time passed between two consecutive calls to
`reduce()` and update its throughput estimate. The current throughput
estimate of a key can be obtained through:
`Stats.peek([some key], throughput=True)`.
"""
# Thus far, we only support mean, max, min, and sum.
if reduce not in [None, "mean", "min", "max", "sum"]:
Expand Down Expand Up @@ -228,10 +220,14 @@ def __init__(
# previous `reduce()` result in hist[1].
self._hist = deque([0, 0, 0], maxlen=3)

self._throughput = throughput if throughput is not True else 0.0
if self._throughput is not False:
assert self._reduce_method == "sum"
assert self._window in [None, float("inf")]
self._throughput = 0.0
self._measure_throughput = False
if (
self._reduce_method == "sum"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dumb question: Why is this only relevant for sum? Because these are the only lifetime stats?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only reduce setting for which it makes sense to measure throughput. It doesn't really make sense for min/max/mean.

and self._window in [None, float("inf")]
and not self._clear_on_reduce
):
self._measure_throughput = True
self._throughput_last_time = -1

def push(self, value) -> None:
Expand Down Expand Up @@ -295,7 +291,7 @@ def peek(self, *, previous: Optional[int] = None, throughput: bool = False) -> A
return self._hist[-abs(previous)]
# Return the last measured throughput.
elif throughput:
return self._throughput if self._throughput is not False else None
return self._throughput if self._measure_throughput else None
return self._reduced_values()[0]

def reduce(self) -> "Stats":
Expand All @@ -314,7 +310,7 @@ class for details on the reduction logic applied to the values list, based on
reduced, values = self._reduced_values()

# Keep track and update underlying throughput metric.
if self._throughput is not False:
if self._measure_throughput:
# Take the delta between the new (upcoming) reduced value and the most
# recently reduced value (one `reduce()` call ago).
delta_sum = reduced - self._hist[-1]
Expand Down Expand Up @@ -356,7 +352,7 @@ def merge_on_time_axis(self, other: "Stats") -> None:
self.values = self.values[-self._window :]

# Adopt `other`'s current throughput estimate (it's the newer one).
if self._throughput is not False:
if self._measure_throughput:
self._throughput = other._throughput

def merge_in_parallel(self, *others: "Stats") -> None:
Expand Down Expand Up @@ -647,8 +643,8 @@ def similar_to(
window=other._window,
ema_coeff=other._ema_coeff,
clear_on_reduce=other._clear_on_reduce,
throughput=other._throughput,
)
stats._throughput = other._throughput
stats._hist = other._hist
return stats

Expand Down