Skip to content

Commit

Permalink
Introduce LimitIterDataset & .limit() transform to limit number of elements produced by an IterDataset.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 724076602
  • Loading branch information
aayooush authored and copybara-github committed Feb 7, 2025
1 parent 8757578 commit 2919edb
Show file tree
Hide file tree
Showing 5 changed files with 187 additions and 0 deletions.
1 change: 1 addition & 0 deletions grain/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ py_library(
"//grain/_src/python/dataset:visualize",
"//grain/_src/python/dataset/transformations:flatmap",
"//grain/_src/python/dataset/transformations:interleave",
"//grain/_src/python/dataset/transformations:limit",
"//grain/_src/python/dataset/transformations:packing",
"//grain/_src/python/dataset/transformations:zip",
"//grain/_src/python/experimental/example_packing:packing",
Expand Down
20 changes: 20 additions & 0 deletions grain/_src/python/dataset/transformations/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,23 @@ py_test(
"@abseil-py//absl/testing:parameterized",
],
)

# Library target for the limit transformation implemented in limit.py.
# Its only declared dependency is the core dataset package.
py_library(
name = "limit",
srcs = ["limit.py"],
srcs_version = "PY3",
deps = ["//grain/_src/python/dataset"],
)

# Unit tests for the ":limit" library (limit_test.py). Depends on the library
# under test, the core dataset package, the experimental testing helpers, and
# the absl test/parameterized frameworks.
py_test(
name = "limit_test",
srcs = ["limit_test.py"],
srcs_version = "PY3",
deps = [
":limit",
"//grain/_src/python/dataset",
"//grain/_src/python/testing:experimental",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
],
)
84 changes: 84 additions & 0 deletions grain/_src/python/dataset/transformations/limit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements limit transformations."""

from typing import Any, TypeVar

from grain._src.python.dataset import dataset

Element = Any
T = TypeVar("T") # pylint: disable=invalid-name


class _LimitDatasetIterator(dataset.DatasetIterator[T]):
  """Iterator that yields at most `count` elements from its parent iterator.

  The number of elements already produced is tracked and is included in the
  checkpoint state, together with the parent iterator's state.
  """

  def __init__(
      self,
      parent: dataset.DatasetIterator[T],
      count: int,
  ):
    """Wraps `parent` so that iteration stops after `count` elements.

    Args:
      parent: The iterator to read elements from.
      count: The maximum number of elements this iterator will produce.
    """
    super().__init__(parent)
    self._count = count
    self._count_elements_read = 0

  def __next__(self):
    """Returns the next parent element, or raises once the limit is reached.

    Raises:
      StopIteration: When `count` elements have already been produced. A
        StopIteration raised by the parent iterator propagates unchanged.
    """
    if self._count_elements_read < self._count:
      element = next(self._parent)
      # Only count the element after the parent produced it successfully.
      self._count_elements_read += 1
      return element
    raise StopIteration

  def get_state(self):
    """Returns a dict holding the parent's state and the elements-read count."""
    parent_state = self._parent.get_state()
    elements_read = self._count_elements_read
    return {
        "parent": parent_state,
        "count_elements_read": elements_read,
    }

  def set_state(self, state):
    """Restores the parent's state and the elements-read count from `state`.

    Args:
      state: A dict with the keys produced by `get_state`.
    """
    parent_state = state["parent"]
    self._parent.set_state(parent_state)
    self._count_elements_read = state["count_elements_read"]


class LimitIterDataset(dataset.IterDataset[T]):
  """Limits the number of elements in the dataset.

  Example usage:

  ```
  list(LimitIterDataset(MapDataset.range(5).to_iter_dataset(), 2)) == [0, 1]
  ```

  Attributes:
    parent: The dataset to limit.
    count: The maximum number of elements to include in the dataset.
  """

  def __init__(
      self,
      parent: dataset.IterDataset[T],
      count: int,
  ):
    """Initializes the limit dataset.

    Args:
      parent: The dataset to limit.
      count: The maximum number of elements to produce. Must be strictly
        positive.

    Raises:
      ValueError: If `count` is zero or negative.
    """
    # Zero is rejected as well as negatives, so the requirement is "positive",
    # not merely "non-negative".
    if count <= 0:
      raise ValueError(f"Count must be a positive integer. Got {count}")
    super().__init__(parent)
    self._count = count

  def __iter__(self) -> _LimitDatasetIterator[T]:
    """Creates an iterator over the parent dataset capped at `count` elements."""
    parent_iter = self._parent.__iter__()
    return _LimitDatasetIterator(parent_iter, self._count)

  def __str__(self) -> str:
    """Returns a short description that includes the configured count."""
    return f"LimitIterDataset(count={self._count})"
81 changes: 81 additions & 0 deletions grain/_src/python/dataset/transformations/limit_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for limit transformations."""

from absl.testing import absltest
from absl.testing import parameterized
from grain._src.python.dataset import dataset
from grain._src.python.dataset.transformations import limit
import grain._src.python.testing.experimental as test_util


class LimitIterDatasetTest(parameterized.TestCase):
  """Unit tests for `limit.LimitIterDataset`."""

  @parameterized.parameters([0, -1, -5])
  def test_non_positive_count_raises_error(self, count):
    """A zero or negative count is rejected at construction time."""
    source = dataset.MapDataset.range(0, 10).to_iter_dataset()
    with self.assertRaises(ValueError):
      limit.LimitIterDataset(source, count=count)

  def test_stop_iteration_raised_after_limit_reached(self):
    """After `count` elements the iterator raises StopIteration."""
    source = dataset.MapDataset.range(0, 10).to_iter_dataset()
    limited_iter = iter(limit.LimitIterDataset(source, count=1))
    # Consume the single allowed element; the next call must stop.
    next(limited_iter)
    with self.assertRaises(StopIteration):
      next(limited_iter)

  @parameterized.parameters([1, 3, 5, 7, 10])
  def test_count(self, count):
    """The limited dataset yields exactly the first `count` elements."""
    source = dataset.MapDataset.range(0, 10).to_iter_dataset()
    elements = list(limit.LimitIterDataset(source, count=count))
    self.assertLen(elements, count)
    self.assertEqual(elements, list(range(count)))

  def test_count_over_epochs(self):
    """A limit larger than one epoch carries over into the repeated data."""
    source = dataset.MapDataset.range(0, 10).repeat(2).to_iter_dataset()
    elements = list(limit.LimitIterDataset(source, count=15))
    self.assertLen(elements, 15)
    self.assertEqual(elements, [*range(10), *range(5)])

  def test_limit_after_batch(self):
    """When applied after batching, the limit counts batches, not elements."""

    def flatten_batches(batches):
      return [item for batch in batches for item in batch.tolist()]

    batched = dataset.MapDataset.range(0, 10).batch(3).to_iter_dataset()

    # Two batches of three elements each.
    first_two = list(limit.LimitIterDataset(batched, count=2))
    self.assertEqual(flatten_batches(first_two), list(range(6)))

    # The limit exceeds the number of available batches, so all of them are
    # returned.
    all_batches = list(limit.LimitIterDataset(batched, count=5))
    flattened = flatten_batches(all_batches)
    self.assertLen(all_batches, 4)
    self.assertEqual(flattened, list(range(10)))

  def test_checkpointing(self):
    """Output is unchanged when iteration is resumed from a checkpoint."""
    batched = dataset.MapDataset.range(0, 10).batch(3).to_iter_dataset()
    limited = limit.LimitIterDataset(batched, count=2)
    test_util.assert_equal_output_after_checkpoint(limited)


# Run the absl test runner when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()
1 change: 1 addition & 0 deletions grain/python/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from grain._src.python.dataset.transformations.interleave import (
InterleaveIterDataset,
)
from grain._src.python.dataset.transformations.limit import LimitIterDataset
from grain._src.python.dataset.transformations.map import RngPool
from grain._src.python.dataset.transformations.mix import ConcatenateMapDataset
from grain._src.python.dataset.transformations.packing import FirstFitPackIterDataset
Expand Down

0 comments on commit 2919edb

Please sign in to comment.