Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 724053185
  • Loading branch information
Grain Team authored and copybara-github committed Feb 27, 2025
1 parent 9022156 commit 4439601
Show file tree
Hide file tree
Showing 13 changed files with 74 additions and 2 deletions.
46 changes: 46 additions & 0 deletions grain/_src/python/dataset/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Sequence
import contextlib
import dataclasses
import functools
import pprint
import sys
import threading
Expand Down Expand Up @@ -51,6 +52,20 @@
bucketer=monitoring.Bucketer.PowersOf(2.0),
)

# Event metric holding the durations of `__next__` calls, reported in
# nanoseconds under "/grain/python/dataset/next_duration_ns". Samples are
# added by `record_next_duration_if_output`, which records one value per
# `__next__` call when the iterator's stats are marked as output.
# NOTE(review): `Bucketer.PowersOf(2.0)` presumably yields bucket boundaries
# that grow by powers of two -- confirm against the monitoring library.
_next_duration_ns_histogram = monitoring.EventMetric(
"/grain/python/dataset/next_duration_ns",
metadata=monitoring.Metadata(
description=(
"Histogram of durations of every `__next__` call on the output"
" iterator. Each data point is the duration value of `__next__`"
" call."
),
units=monitoring.Units.NANOSECONDS,
),
root=grain_monitoring.get_monitoring_root(),
bucketer=monitoring.Bucketer.PowersOf(2.0),
)

T = TypeVar("T")
# Time between two consecutive monitoring reports.
_REPORTING_PERIOD_SEC = 10
Expand Down Expand Up @@ -267,6 +282,37 @@ def _pretty_format_summary(
return table.get_pretty_wrapped_summary() # pylint: disable=protected-access


def record_next_duration_if_output(next_fn):
  """Decorates `__next__` to report its latency for output iterator nodes.

  The wrapped method always runs and its result is returned unchanged. After
  it returns, the elapsed time (in nanoseconds) is added to the `__next__`
  duration histogram, but only if `iterator._stats._is_output` is truthy.
  If the wrapped method raises (including `StopIteration`), the exception
  propagates and nothing is recorded.

  Intended usage:

  ```
  class MyMapDatasetIterator(DatasetIterator):
    ...

    @stats.record_next_duration_if_output
    def __next__(self):
      ...
  ```

  Args:
    next_fn: The `__next__` function to wrap.

  Returns:
    The wrapped `next_fn`.
  """

  def _timed_next(iterator):
    # The clock is started unconditionally; whether this node is the output
    # is only consulted after the wrapped call has returned.
    began_ns = time.perf_counter_ns()
    element = next_fn(iterator)
    if not iterator._stats._is_output:  # pylint:disable=protected-access
      return element
    _next_duration_ns_histogram.Record(time.perf_counter_ns() - began_ns)
    return element

  # Copy name, docstring and other metadata from the wrapped `__next__`.
  return functools.wraps(next_fn)(_timed_next)


class _Table:
"""Table class for pretty printing tabular data."""

Expand Down
10 changes: 8 additions & 2 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,10 @@ py_library(
name = "interleave",
srcs = ["interleave.py"],
srcs_version = "PY3",
deps = ["//grain/_src/python/dataset"],
deps = [
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:stats",
],
)

py_test(
Expand All @@ -247,7 +250,10 @@ py_library(
name = "limit",
srcs = ["limit.py"],
srcs_version = "PY3",
deps = ["//grain/_src/python/dataset"],
deps = [
"//grain/_src/python/dataset",
"//grain/_src/python/dataset:stats",
],
)

py_test(
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from grain._src.core import tree_lib
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
import numpy as np

T = TypeVar("T")
Expand Down Expand Up @@ -61,6 +62,7 @@ def __init__(
self._drop_remainder = drop_remainder
self._batch_fn = batch_fn

@stats.record_next_duration_if_output
def __next__(self) -> T:
values = []
for _ in range(self._batch_size):
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def _threshold_checker(self):
raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio,
)

@dataset_stats.record_next_duration_if_output
def __next__(self):
value = None
passed_filter = False
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/flatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def __init__(
def _has_consumed_all_buffer_elements(self):
return self._next_index_in_buffer >= len(self._buffer)

@dataset_stats.record_next_duration_if_output
def __next__(self):
timer = dataset_stats.Timer()
while self._has_consumed_all_buffer_elements():
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/interleave.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import TypeVar

from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from grain._src.python.dataset.transformations import prefetch

T = TypeVar("T")
Expand Down Expand Up @@ -53,6 +54,7 @@ def __init__(
None
] * self._cycle_length

@stats.record_next_duration_if_output
def __next__(self) -> T:
while True:
if iterator_to_use := self._iterators_in_use[self._next_index_in_cycle]:
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/limit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import Any, TypeVar

from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats

Element = Any
T = TypeVar("T") # pylint: disable=invalid-name
Expand All @@ -33,6 +34,7 @@ def __init__(
self._count = count
self._count_elements_read = 0

@stats.record_next_duration_if_output
def __next__(self):
if self._count_elements_read >= self._count:
raise StopIteration
Expand Down
3 changes: 3 additions & 0 deletions grain/_src/python/dataset/transformations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from grain._src.core import transforms
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
import numpy as np


Expand Down Expand Up @@ -215,6 +216,7 @@ def __init__(
# TODO: Move users away from this and remove.
self._index_for_rng = 0

@stats.record_next_duration_if_output
def __next__(self):
element = next(self._parent)
with self._stats.record_self_time():
Expand Down Expand Up @@ -254,6 +256,7 @@ def __init__(
self._rng = np.random.Generator(np.random.Philox(seed))
self._transform_name = transform_name

@stats.record_next_duration_if_output
def __next__(self):
element = next(self._parent)
with self._stats.record_self_time():
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from grain._src.core import exceptions
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from typing_extensions import override

Element = Any
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
self._index = 0
self._stop = False

@stats.record_next_duration_if_output
def __next__(self):
if self._stop:
# Although there may be elements available in some parent datasets, do not
Expand Down
1 change: 1 addition & 0 deletions grain/_src/python/dataset/transformations/packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def _finalize_current_batch(self, element_for_shapes):
meta_features=self._meta_features,
)

@dataset_stats.record_next_duration_if_output
def __next__(self):
timer = dataset_stats.Timer()
if self._packed_batch is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,7 @@ def _maybe_add_to_buffer(
tokens_in_buffer[k] += v
return remainder

@stats.record_next_duration_if_output
def __next__(self):
if self._packed_elements:
self._state.elements_from_buffer_after_checkpoint += 1
Expand Down
3 changes: 3 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def _threshold_checker(self):
raise_threshold=self._ctx.dataset_options.filter_raise_threshold_ratio,
)

@dataset_stats.record_next_duration_if_output
def __next__(self) -> T:
# The time recorded here is the time spent in prefetch node to return an
# element, including the time spent in parent node.
Expand Down Expand Up @@ -496,6 +497,7 @@ def _stats(self):
def __iter__(self) -> dataset.DatasetIterator[T]:
return self

@dataset_stats.record_next_duration_if_output
def __next__(self) -> T:
self._ensure_iterator_initialized()
result, state = next(self._iterator)
Expand Down Expand Up @@ -714,6 +716,7 @@ def _producer(
except Exception as e: # pylint: disable=broad-except
output_buffer.put((None, None, e))

@dataset_stats.record_next_duration_if_output
def __next__(self):
self.start_prefetch()
assert self._buffer is not None
Expand Down
2 changes: 2 additions & 0 deletions grain/_src/python/dataset/transformations/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import TypeVar

from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats
from grain._src.python.experimental.index_shuffle.python import index_shuffle_module as index_shuffle


Expand Down Expand Up @@ -188,6 +189,7 @@ def _fill_and_shuffle_window(self):
seed=self._global_seed + self._window_index, window=self._window
)

@stats.record_next_duration_if_output
def __next__(self):
# Window is empty, fill up the next window.
if not self._window:
Expand Down

0 comments on commit 4439601

Please sign in to comment.