Skip to content

Commit

Permalink
Fix FlatMapMapDataset length overflow when combined with infinite repeat.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 706724840
  • Loading branch information
Grain Team authored and copybara-github committed Dec 16, 2024
1 parent bc87384 commit 29b5689
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
6 changes: 6 additions & 0 deletions grain/_src/python/dataset/transformations/flatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Flatmap transformation for MapDataset."""
import functools
import sys
from typing import Any, Callable, Sequence, TypeVar

from grain._src.core import transforms
Expand All @@ -37,6 +38,11 @@ def __init__(
self._transform = transform

def __len__(self) -> int:
  """Returns `max_fan_out * len(parent)`, capped at `sys.maxsize`.

  `len()` raises `OverflowError` for any value above `sys.maxsize`, so the
  result is clamped to that limit. This is what happens, for example, when
  the parent dataset is on infinite repeat and already reports
  `sys.maxsize` as its length.
  """
  # Python ints have arbitrary precision, so the product itself cannot
  # overflow; only the value handed back to `len()` has to fit. Clamping the
  # exact integer product avoids comparing against the float quotient
  # `sys.maxsize / max_fan_out`, which is rounded to 53 bits and can let
  # lengths just past the true threshold slip through unclamped (and which
  # divides by zero when `max_fan_out` is 0).
  return min(self._transform.max_fan_out * len(self._parent), sys.maxsize)

def __str__(self) -> str:
Expand Down
8 changes: 8 additions & 0 deletions grain/_src/python/dataset/transformations/flatmap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import dataclasses
import itertools
import sys
from typing import Any, Sequence

from absl.testing import absltest
Expand Down Expand Up @@ -73,6 +74,13 @@ def test_fixed_fan_out_size(self):
)
self.assertLen(flatmap_ds, self.fan_out * len(self.range_ds))

def test_flatmap_ds_length_after_repeat(self):
  """Flat-mapping a parent built with `repeat()` reports `sys.maxsize`."""
  # `repeat()` with no arguments is the infinite-repeat case this test
  # targets.
  repeated_parent = self.range_ds.repeat()
  split_transform = FixedSizeSplitWithNoTransform(max_fan_out=self.fan_out)
  flatmap_ds = flatmap.FlatMapMapDataset(repeated_parent, split_transform)
  self.assertLen(flatmap_ds, sys.maxsize)

def test_fixed_fan_out_data_no_transform(self):
flatmap_ds = flatmap.FlatMapMapDataset(
self.range_ds, FixedSizeSplitWithNoTransform(max_fan_out=self.fan_out)
Expand Down

0 comments on commit 29b5689

Please sign in to comment.