Merge pull request #338 from fabianp:main
PiperOrigin-RevId: 605435144
copybara-github committed Feb 8, 2024
2 parents a5ad2ee + 0b59607 commit 059bffe
Showing 20 changed files with 158 additions and 110 deletions.
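The diffs below consistently swap Python 3.10-only constructs for 3.9-compatible equivalents: PEP 604 unions (X | Y) become typing.Union/Optional, match statements become isinstance chains, and dataclasses only receive slots=True when the running interpreter supports it. A minimal sketch of that last, version-gated pattern, using only the standard library and a hypothetical Record class rather than Grain's own dataclasses:

import dataclasses
import sys

# dataclasses.dataclass() only accepts slots=True on Python >= 3.10, so the
# keyword is added conditionally and passed to the decorator via **kwargs.
_IS_PY310 = sys.version_info >= (3, 10)


@dataclasses.dataclass(
    **({"slots": True, "frozen": True} if _IS_PY310 else {"frozen": True})
)
class Record:
  """Hypothetical immutable record; uses __slots__ only where supported."""

  index: int
  data: bytes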
73 changes: 38 additions & 35 deletions grain/_src/python/data_loader.py
@@ -23,7 +23,8 @@
import json
from multiprocessing import pool
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union

from absl import logging
from concurrent import futures
@@ -45,6 +46,7 @@

_T = TypeVar("_T")
_IteratorState = dict[str, Any]
_IS_PY310 = sys.version_info >= (3, 10)

# Dictionary keys used in checkpoints.
_VERSION = "version"
@@ -82,7 +84,9 @@ def _determine_worker_count(input_worker_count: int | None) -> int:
raise ValueError("Can't determine worker count. Please set worker count.")


@dataclasses.dataclass(frozen=True, slots=True)
@dataclasses.dataclass(
**({"slots": True, "frozen": True} if _IS_PY310 else {"frozen": True})
)
class _ReaderQueueElement:
"""Element to be added to the reader queue."""

@@ -99,7 +103,9 @@ class _GrainPoolProcessingComplete:


_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete()
_QueueElement = _ReaderQueueElement | _GrainPoolProcessingComplete | Exception
_QueueElement = Union[
_ReaderQueueElement, _GrainPoolProcessingComplete, Exception
]


@contextlib.contextmanager
@@ -447,38 +453,35 @@ def _apply_transform(
fn: Callable[[record.Record], Tuple[record.Record, bool]] = None
# pylint: disable=g-long-lambda
# pytype: disable=attribute-error
match transform:
case transforms.MapTransform():
fn = lambda r: (record.Record(r.metadata, transform.map(r.data)), True)
case transforms.RandomMapTransform():
fn = lambda r: (
record.Record(
r.metadata, transform.random_map(r.data, r.metadata.rng)
),
True,
)
case transforms.TfRandomMapTransform():
fn = lambda r: (
record.Record(
r.metadata, transform.np_random_map(r.data, r.metadata.rng)
),
True,
)
case transforms.FilterTransform():
fn = lambda r: (r, bool(transform.filter(r.data)))
case transforms.BatchTransform():
batch_op = BatchOperation(
batch_size=transform.batch_size,
drop_remainder=transform.drop_remainder,
)
batch_op.disable_deprecation_message()
for r in batch_op(input_iterator):
yield r
case _:
# Transform is a legacy style operation and __call__() yield output
# records.
for r in transform(input_iterator):
yield r
if isinstance(transform, transforms.MapTransform):
fn = lambda r: (record.Record(r.metadata, transform.map(r.data)), True)
elif isinstance(transform, transforms.RandomMapTransform):
fn = lambda r: (
record.Record(r.metadata, transform.random_map(r.data, r.metadata.rng)),
True,
)
elif isinstance(transform, transforms.TfRandomMapTransform):
fn = lambda r: (
record.Record(
r.metadata, transform.np_random_map(r.data, r.metadata.rng)
),
True,
)
elif isinstance(transform, transforms.FilterTransform):
fn = lambda r: (r, bool(transform.filter(r.data)))
elif isinstance(transform, transforms.BatchTransform):
batch_op = BatchOperation(
batch_size=transform.batch_size,
drop_remainder=transform.drop_remainder,
)
batch_op.disable_deprecation_message()
for r in batch_op(input_iterator):
yield r
else:
# Transform is a legacy style operation and __call__() yield output
# records.
for r in transform(input_iterator):
yield r
# pytype: enable=attribute-error
# pylint: enable=g-long-lambda
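For context on the rewrite above: class patterns in a match statement require Python 3.10, whereas an isinstance chain behaves the same on 3.9. A tiny standalone sketch with hypothetical Circle/Square classes (not Grain's transform types):

import dataclasses


@dataclasses.dataclass
class Circle:
  radius: float


@dataclasses.dataclass
class Square:
  side: float


def area(shape) -> float:
  # Equivalent to `match shape: case Circle(): ... case Square(): ...`,
  # but valid on Python 3.9, which has no match statement.
  if isinstance(shape, Circle):
    return 3.14159 * shape.radius**2
  elif isinstance(shape, Square):
    return shape.side**2
  raise TypeError(f"Unsupported shape: {type(shape).__name__}")


assert area(Square(side=2.0)) == 4.0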

7 changes: 6 additions & 1 deletion grain/_src/python/data_loader_test.py
@@ -15,6 +15,7 @@

from collections.abc import Sequence
import pathlib
from typing import Union
from unittest import mock

from absl import flags
@@ -144,7 +145,11 @@ def setUp(self):
self.testdata_dir = pathlib.Path(FLAGS.test_srcdir) / "testdata"

def _create_data_loader_for_short_sequence(
self, transformations, *, worker_count: int = 0, seed: int | None = None
self,
transformations,
*,
worker_count: int = 0,
seed: Union[int, None] = None,
) -> data_loader_lib.DataLoader:
# Generates elements [0, 1, 2, 3, 4, 5, 6, 7].
range_data_source = RangeDataSource(start=0, stop=8, step=1)
6 changes: 3 additions & 3 deletions grain/_src/python/data_sources.py
@@ -28,7 +28,7 @@
import os
import threading
import typing
from typing import Any, Generic, Protocol, SupportsIndex, TypeVar
from typing import Any, Generic, Optional, Protocol, SupportsIndex, TypeVar

from absl import logging
import array_record.python.array_record_data_source as array_record
@@ -113,9 +113,9 @@ class InMemoryDataSource(shared_memory.ShareableList):

def __init__(
self,
elements: Sequence[Any] | None = None,
elements: Optional[Sequence[Any]] = None,
*,
name: str | None = None,
name: Optional[str] = None,
):
"""Creates a new InMemoryDataSource object.
@@ -51,7 +51,7 @@
from collections.abc import Sequence
import dataclasses
import heapq
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

from grain._src.core import sharding
from grain._src.python import record
@@ -362,7 +362,7 @@ class SamplerWrapper:

def __init__(
self,
sampler: ContinualSequenceSampler | BatchedContinualSequenceSampler,
sampler: Union[ContinualSequenceSampler, BatchedContinualSequenceSampler],
start_index_ordered: np.ndarray,
seed: int,
):
4 changes: 2 additions & 2 deletions grain/_src/python/experimental/example_packing/packing.py
@@ -12,7 +12,7 @@
"""

import dataclasses
from typing import Generic, Iterator, TypeVar, cast
from typing import Generic, Iterator, TypeVar, Union, cast

from grain._src.core import tree
from grain._src.python import record
@@ -180,7 +180,7 @@ class PackAndBatchOperation(Generic[_T]):
length_struct: jt.PyTree[int]
batch_size: int
# We don't know input shapes and corresponding buffer shapes until __call__.
_cur_batch: _PackedBatch | None = None
_cur_batch: Union[_PackedBatch, None] = None

def __call__(
self, input_iterator: Iterator[record.Record[_T]]
16 changes: 12 additions & 4 deletions grain/_src/python/grain_pool.py
@@ -52,9 +52,10 @@
from multiprocessing import synchronize
import pstats
import queue
import sys
import threading
import traceback
from typing import Any, Protocol, TypeVar
from typing import Any, Protocol, TypeVar, Union

from absl import logging
import cloudpickle
@@ -68,6 +69,7 @@
from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member

T = TypeVar("T")
_IS_PY310 = sys.version_info >= (3, 10)

# Maximum number of threads for starting and stopping processes.
_PROCESS_MANAGEMENT_MAX_THREADS = 64
@@ -86,7 +88,9 @@ class _ProcessingComplete:
_PROCESSING_COMPLETE = _ProcessingComplete()


@dataclasses.dataclass(frozen=True, slots=True)
@dataclasses.dataclass(
**({"slots": True, "frozen": True} if _IS_PY310 else {"frozen": True})
)
class GrainPoolElement:
"""Wrapper for output records emited by Grain Pool."""

@@ -412,7 +416,9 @@ def _shutdown(self) -> None:
process.terminate()


@dataclasses.dataclass(frozen=True, slots=True)
@dataclasses.dataclass(
**({"slots": True, "frozen": True} if _IS_PY310 else {"frozen": True})
)
class _ReaderQueueElement:
"""Element to be added to the reader queue."""

@@ -427,7 +433,9 @@ class _GrainPoolProcessingComplete:


_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete()
_QueueElement = _ReaderQueueElement | _GrainPoolProcessingComplete | Exception
_QueueElement = Union[
_ReaderQueueElement, _GrainPoolProcessingComplete, Exception
]


class GrainPoolProcessingError(Exception):
4 changes: 2 additions & 2 deletions grain/_src/python/lazy_dataset/data_sources.py
@@ -13,7 +13,7 @@
# limitations under the License.
"""LazyDataset data sources."""

from typing import Protocol
from typing import Protocol, Union

from absl import logging
from grain._src.python.lazy_dataset import lazy_dataset
@@ -51,7 +51,7 @@ def log_lineage(self):


def log_lineage_for_sources(
root: lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset,
root: Union[lazy_dataset.LazyMapDataset, lazy_dataset.LazyIterDataset]
):
"""Traverses tree of transformations and logs lineage on source datasets."""
pass
32 changes: 17 additions & 15 deletions grain/_src/python/lazy_dataset/lazy_dataset.py
@@ -45,7 +45,7 @@
import copy
import functools
import time
from typing import Any, Callable, Optional, TypeVar, overload
from typing import Any, Callable, Optional, TypeVar, Union, overload

from concurrent import futures
from grain._src.core import sharding
@@ -66,7 +66,9 @@ class LazyMapDataset(Sequence[T], abc.ABC):

_functions: dict[str, Callable[[LazyMapDataset], Any]] = {}

def __init__(self, parents: LazyMapDataset | Sequence[LazyMapDataset] = ()):
def __init__(
self, parents: Union[LazyMapDataset, Sequence[LazyMapDataset]] = ()
):
if isinstance(parents, LazyMapDataset):
self._parents = (parents,)
else:
@@ -91,7 +93,7 @@ def __getitem__(self, index: slice) -> LazyMapDataset:
...

@overload
def __getitem__(self, index: int) -> T | None:
def __getitem__(self, index: int) -> Optional[T]:
...

@abc.abstractmethod
@@ -121,7 +123,7 @@ def __iter__(self) -> LazyDatasetIterator[T]:
return self.to_iter_dataset().__iter__()

def to_iter_dataset(
self, read_options: grain_options.ReadOptions | None = None
self, read_options: Optional[grain_options.ReadOptions] = None
) -> LazyIterDataset[T]:
"""Syntactic sugar to construct a LazyIterDataset."""
return PrefetchLazyIterDataset(
@@ -136,11 +138,11 @@ class LazyIterDataset(Iterable[T], abc.ABC):

def __init__(
self,
parents: (
LazyMapDataset
| LazyIterDataset
| Sequence[LazyMapDataset | LazyIterDataset]
) = (),
parents: Union[
LazyMapDataset,
LazyIterDataset,
Sequence[Union[LazyMapDataset, LazyIterDataset]],
] = (),
):
if isinstance(parents, (LazyMapDataset, LazyIterDataset)):
self._parents = (parents,)
@@ -149,11 +151,11 @@ def __init__(
usage_logging.log_event("LazyIterDataset", tag_3="PyGrain")

@property
def parents(self) -> Sequence[LazyMapDataset | LazyIterDataset]:
def parents(self) -> Sequence[Union[LazyMapDataset, LazyIterDataset]]:
return self._parents

@property
def _parent(self) -> LazyMapDataset | LazyIterDataset:
def _parent(self) -> Union[LazyMapDataset, LazyIterDataset]:
assert len(self._parents) == 1, self._parents
return self._parents[0]

@@ -452,7 +454,7 @@ def __next__(self) -> T:

def get_element_producer_fn(
worker_index: int, worker_count: int
) -> Iterator[tuple[T, dict[str, Any] | None]]:
) -> Iterator[tuple[T, Optional[dict[str, Any]]]]:
# Recover from the last recorded state for the given worker.
worker_state = state[_WORKERS_STATE][str(worker_index)]
parent.set_parent_maps_slice(slice(worker_index, None, worker_count))
@@ -502,7 +504,7 @@ def get_state(self) -> dict[str, Any]:
class RangeLazyMapDataset(LazyMapDataset[int]):
"""Range data source, similar to python range() function."""

def __init__(self, start: int, stop: int | None = None, step: int = 1):
def __init__(self, start: int, stop: Optional[int] = None, step: int = 1):
super().__init__()
self.start = 0 if stop is None else start
self.stop = start if stop is None else stop
@@ -522,7 +524,7 @@ def __getitem__(self, index):

def to_iter_dataset(
self,
read_options: grain_options.ReadOptions | None = None,
read_options: Optional[grain_options.ReadOptions] = None,
) -> LazyIterDataset[int]:
"""Syntactic sugar to construct a LazyIterDataset."""
return PrefetchLazyIterDataset(
@@ -550,7 +552,7 @@ def __init__(
def __len__(self) -> int:
return self._end - self._start

def __getitem__(self, index: int | slice) -> Optional[T]:
def __getitem__(self, index: Union[int, slice]) -> Optional[T]:
if isinstance(index, slice):
return self.slice(index)
epoch = index // len(self)
13 changes: 7 additions & 6 deletions grain/_src/python/lazy_dataset/transformations/batch.py
@@ -25,12 +25,13 @@


def _make_batch(values: Sequence[T]) -> T:
match len(values):
case 0:
return ()
case 1:
tree.map_structure(lambda x: np.expand_dims(x, axis=0), values[0])
return tree.map_structure(lambda *xs: np.stack(xs), values[0], *values[1:])
num_values = len(values)
if num_values == 0:
return ()
elif num_values == 1:
return tree.map_structure(lambda x: np.expand_dims(x, axis=0), values[0])
else:
return tree.map_structure(lambda *xs: np.stack(xs), values[0], *values[1:])


class _BatchLazyDatasetIterator(lazy_dataset.LazyDatasetIterator[T]):
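A rough usage sketch of the batching logic above, using plain NumPy arrays instead of Grain's internal tree utilities; the element shapes here are illustrative assumptions, not taken from the library:

import numpy as np


def make_batch(values):
  # Mirrors _make_batch for flat arrays: empty input yields an empty batch,
  # a single element gains a leading batch dimension, and multiple elements
  # are stacked along a new leading axis.
  if not values:
    return ()
  if len(values) == 1:
    return np.expand_dims(values[0], axis=0)
  return np.stack(values)


assert make_batch([np.arange(3)]).shape == (1, 3)
assert make_batch([np.arange(3), np.arange(3)]).shape == (2, 3)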