Add Dataset.pipe method, based on pandas.DataFrame.pipe.
`pipe` is convenient because it allows using method chaining syntax in an extensible fashion, with transformations that are not built-in methods on `Dataset`.

For example, consider shuffling a dataset in windows. It would be convenient if we could write something like:
```
ds = (
    dataset.MapDataset.range(400)
    .window_shuffle(window_size=10, seed=42)
    .batch(16)
    .repeat()
)
```

Unfortunately, this doesn't work because there is no `window_shuffle()` method. Instead, you would need to write something like:

```
ds = (
    shuffle.WindowShuffleMapDataset(
        dataset.MapDataset.range(400),
        window_size=10,
        seed=42,
    )
    .batch(16)
    .repeat()
)
```

Readability suffers here because the shuffle transformation appears out of order relative to the data flow.

Instead, `pipe` lets us keep transformations in the order in which they are applied, by writing something like:
```
ds = (
    dataset.MapDataset.range(400)
    .pipe(
        shuffle.WindowShuffleMapDataset,
        window_size=10,
        seed=42,
    )
    .batch(16)
    .repeat()
)
```
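
Because `pipe` simply calls `func(self, *args, **kwargs)`, it also works with plain user-defined functions, not just dataset constructors. A minimal sketch, reusing the module names from above; the `add_offset` helper is hypothetical:
```
def add_offset(ds, offset):
  # Any callable that takes the dataset as its first argument can be piped.
  return ds.map(lambda x: x + offset)

ds = (
    dataset.MapDataset.range(400)
    .pipe(add_offset, offset=5)
    .batch(16)
)
```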
PiperOrigin-RevId: 729289880
shoyer authored and copybara-github committed Feb 21, 2025
1 parent 8600444 commit 42039a1
Showing 2 changed files with 60 additions and 0 deletions.
55 changes: 55 additions & 0 deletions grain/_src/python/dataset/dataset.py
@@ -113,6 +113,61 @@ def _default_seed(self) -> int | None:
    seed_sequence = np.random.SeedSequence(aggregated_seed)
    return seed_sequence.generate_state(1, dtype=np.uint32)[0]

  # TODO: Define a more precise type signature for this method,
  # once pytype fully supports Concatenate and ParamSpec
  # (b/217789659, https://github.com/google/pytype/issues/786):
  # P = ParamSpec("P")
  # def pipe(
  #     self,
  #     func: Callable[Concatenate[Self, P], T],
  #     /,
  #     *args: P.args,
  #     **kwargs: P.kwargs,
  # ) -> T:
  def pipe(self, func: Callable[..., T], /, *args, **kwargs) -> T:
    """Syntactic sugar for applying a callable to this dataset.

    The `pipe` method, borrowed from `pandas.DataFrame`, is convenient because
    it allows using method chaining syntax in an extensible fashion, with
    transformations that are not built-in methods on `Dataset`.

    For example, suppose you want to shuffle a dataset within a window.
    Functionality for this is available in `WindowShuffleMapDataset`, but not
    as a method on `MapDataset`, e.g.,

    ```
    dataset = (
        grain.experimental.WindowShuffleMapDataset(
            grain.MapDataset.range(1000),
            window_size=128,
            seed=0,
        )
        .batch(16)
    )
    ```

    This solution hurts readability, because the shuffle transformation
    appears out of order from the data flow.

    In contrast, with `pipe` you can write:

    ```
    dataset = (
        grain.MapDataset.range(1000)
        .pipe(
            grain.experimental.WindowShuffleMapDataset, window_size=128, seed=0
        )
        .batch(16)
    )
    ```

    Args:
      func: The callable to apply to this dataset.
      *args: Additional positional arguments to pass to the callable.
      **kwargs: Keyword arguments to pass to the callable.

    Returns:
      The result of calling `func(self, *args, **kwargs)`.
    """
    return func(self, *args, **kwargs)


class _MapDatasetMeta(abc.ABCMeta):
"""Metaclass for `MapDataset` containing factory transfromations."""
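As an aside, the two docstring examples above construct equivalent pipelines. A minimal sketch of that equivalence, assuming the public aliases used in the docstring (`grain.MapDataset`, `grain.experimental.WindowShuffleMapDataset`) are reachable through a top-level `import grain` in your environment:
```
import grain  # Import path is an assumption; adjust to how grain is exposed in your setup.

base = grain.MapDataset.range(1000)

# Direct construction: the shuffle wrapper appears before the source dataset.
direct = grain.experimental.WindowShuffleMapDataset(
    base, window_size=128, seed=0
).batch(16)

# Chained form via the new pipe method: transformations read in application order.
chained = base.pipe(
    grain.experimental.WindowShuffleMapDataset, window_size=128, seed=0
).batch(16)

# Both pipelines yield the same batches.
```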
5 changes: 5 additions & 0 deletions grain/_src/python/dataset/dataset_test.py
@@ -763,6 +763,11 @@ def test_iterator_restore_with_dictionary_elements(self):
    ds = ds.to_iter_dataset()
    test_util.assert_equal_output_after_checkpoint(ds)

  def test_pipe(self):
    ds = dataset.MapDataset.range(10)
    outputs = ds.pipe(lambda self, *args, **kwargs: (args, kwargs), 1, 2, x=3)
    self.assertEqual(outputs, ((1, 2), {"x": 3}))


class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform):
