Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Dataset.pipe method, based on pandas.DataFrame.pipe. #734

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,61 @@ def _default_seed(self) -> int | None:
seed_sequence = np.random.SeedSequence(aggregated_seed)
return seed_sequence.generate_state(1, dtype=np.uint32)[0]

# TODO: Define a more precise type signature for this method,
# once pytype fully supports Concatenate and ParamSpec
# (b/217789659, https://github.com/google/pytype/issues/786):
# P = ParamSpec("P")
# def pipe(
# self,
# func: Callable[Concatenate[Self, P], T],
# /,
# *args: P.args,
# **kwargs: P.kwargs,
# ) -> T:
def pipe(self, func: Callable[..., T], /, *args, **kwargs) -> T:
"""Syntactic sugar for applying a callable to this dataset.
The `pipe` method, borrowed from `pandas.DataFrame`, is convenient because
it allows for using method chaining syntax in an extensible fashion, with
transformations that are not built-in methods on `Dataset`.
For example, suppose you want to shuffle a dataset within a window.
Functionality for this is available in `WindowShuffleMapDataset`, but not as
a method on `MapDataset`, e.g.,
```
dataset = (
grain.experimental.WindowShuffleMapDataset(
grain.MapDataset.range(1000),
window_size=128,
seed=0,
)
.batch(16)
)
This solution suffers from readability, because the shuffle transformation
appears out of order from the data flow.
In contrast, with `pipe` you can write:
```
dataset = (
grain.MapDataset.range(1000)
.pipe(
grain.experimental.WindowShuffleMapDataset, window_size=128, seed=0
)
.batch(16)
)
```
Args:
func: The callable to apply to this dataset.
*args: Additional positional arguments to pass to the callable.
**kwargs: Keyword arguments to pass to the callable.
Returns:
The result of calling `func(self, *args, **kwargs)`.
"""
return func(self, *args, **kwargs)


class _MapDatasetMeta(abc.ABCMeta):
"""Metaclass for `MapDataset` containing factory transfromations."""
Expand Down
5 changes: 5 additions & 0 deletions grain/_src/python/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,11 @@ def test_iterator_restore_with_dictionary_elements(self):
ds = ds.to_iter_dataset()
test_util.assert_equal_output_after_checkpoint(ds)

def test_pipe(self):
ds = dataset.MapDataset.range(10)
outputs = ds.pipe(lambda self, *args, **kwargs: (args, kwargs), 1, 2, x=3)
self.assertEqual(outputs, ((1, 2), {"x": 3}))


class TfRandomMapAlwaysAddingOne(transforms.TfRandomMapTransform):

Expand Down