diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..6bed1e17 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,54 @@ +name: tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +jobs: + build-and-test: + name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" + runs-on: "${{ matrix.os }}" + + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + os: [ubuntu-latest] + + steps: + - uses: "actions/checkout@v2" + # - uses: "actions/setup-python@v4" + # with: + # python-version: "${{ matrix.python-version }}" + - name: Create directory + run: | + mkdir -p /tmp/grain + cp -r . /tmp/grain + - name: Build package + run: | + set -xe + export PYTHON_VERSION=${{ matrix.python-version }} + export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1) + export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f2) + export CP_VERSION="cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}" + export BAZEL_VERSION="5.4.0" + export AUDITWHEEL_PLATFORM="manylinux2014_x86_64" + cd /tmp/grain + DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \ + --build-arg PYTHON_VERSION=${PYTHON_VERSION} \ + --build-arg PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \ + --build-arg PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \ + --build-arg BAZEL_VERSION=${BAZEL_VERSION} \ + -t grain:${PYTHON_VERSION} - < grain/oss/build.Dockerfile + docker run --rm -a stdin -a stdout -a stderr \ + --env PYTHON_VERSION=${PYTHON_VERSION} \ + --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \ + --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \ + --env PYTHON_BIN_PATH="/opt/python/${CP_VERSION}-${CP_VERSION}/bin/python" \ + --env BAZEL_VERSION=${BAZEL_VERSION} \ + -v /tmp/grain:/tmp/grain \ + --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \ + --name grain grain:${PYTHON_VERSION} \ + bash grain/oss/build_whl.sh + diff --git a/grain/_src/python/data_loader.py b/grain/_src/python/data_loader.py index 005a6fad..6ff6ef8e 100644 --- a/grain/_src/python/data_loader.py +++ b/grain/_src/python/data_loader.py @@ -23,7 +23,8 @@ import json from multiprocessing import pool import os -from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar +import sys +from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union from absl import logging from concurrent import futures @@ -31,7 +32,6 @@ from grain._src.core import transforms from grain._src.core import tree from grain._src.core import usage_logging -import multiprocessing as mp from grain._src.python import grain_pool from grain._src.python import options from grain._src.python import record @@ -45,6 +45,7 @@ _T = TypeVar("_T") _IteratorState = dict[str, Any] +PY310 = sys.version_info >= (3, 10) # Dictionary keys used in checkpoints. _VERSION = "version" @@ -82,7 +83,7 @@ def _determine_worker_count(input_worker_count: int | None) -> int: raise ValueError("Can't determine worker count. 
Please set worker count.") -@dataclasses.dataclass(frozen=True, slots=True) +@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True})) class _ReaderQueueElement: """Element to be added to the reader queue.""" @@ -99,7 +100,7 @@ class _GrainPoolProcessingComplete: _GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete() -_QueueElement = _ReaderQueueElement | _GrainPoolProcessingComplete | Exception +_QueueElement = Union[_ReaderQueueElement, _GrainPoolProcessingComplete, Exception] @contextlib.contextmanager @@ -440,54 +441,50 @@ def __str__(self): def _apply_transform( - transform: transforms.Transformation | Operation, - input_iterator: Iterator[record.Record], + transform: transforms.Transformation | Operation, + input_iterator: Iterator[record.Record], ) -> Iterator[record.Record]: """Applies the `transform` to records in the iterator.""" fn: Callable[[record.Record], Tuple[record.Record, bool]] = None - # pylint: disable=g-long-lambda - # pytype: disable=attribute-error - match transform: - case transforms.MapTransform(): - fn = lambda r: (record.Record(r.metadata, transform.map(r.data)), True) - case transforms.RandomMapTransform(): - fn = lambda r: ( - record.Record( - r.metadata, transform.random_map(r.data, r.metadata.rng) - ), - True, - ) - case transforms.TfRandomMapTransform(): - fn = lambda r: ( - record.Record( - r.metadata, transform.np_random_map(r.data, r.metadata.rng) - ), - True, - ) - case transforms.FilterTransform(): - fn = lambda r: (r, bool(transform.filter(r.data))) - case transforms.BatchTransform(): - batch_op = BatchOperation( - batch_size=transform.batch_size, - drop_remainder=transform.drop_remainder, - ) - batch_op.disable_deprecation_message() - for r in batch_op(input_iterator): - yield r - case _: - # Transform is a legacy style operation and __call__() yield output - # records. - for r in transform(input_iterator): - yield r - # pytype: enable=attribute-error - # pylint: enable=g-long-lambda + + if isinstance(transform, transforms.MapTransform): + fn = lambda r: (record.Record(r.metadata, transform.map(r.data)), True) + elif isinstance(transform, transforms.RandomMapTransform): + fn = lambda r: ( + record.Record( + r.metadata, transform.random_map(r.data, r.metadata.rng) + ), + True, + ) + elif isinstance(transform, transforms.TfRandomMapTransform): + fn = lambda r: ( + record.Record( + r.metadata, transform.np_random_map(r.data, r.metadata.rng) + ), + True, + ) + elif isinstance(transform, transforms.FilterTransform): + fn = lambda r: (r, bool(transform.filter(r.data))) + elif isinstance(transform, transforms.BatchTransform): + batch_op = BatchOperation( + batch_size=transform.batch_size, + drop_remainder=transform.drop_remainder, + ) + batch_op.disable_deprecation_message() + for r in batch_op(input_iterator): + yield r + else: + # Transform is a legacy style operation and __call__() yield output + # records. + for r in transform(input_iterator): + yield r for input_record in input_iterator: try: output_record, filter_result = fn(input_record) except Exception as e: raise ValueError( - f"PyGrain encountered an error when applying {transform}." + f"PyGrain encountered an error when applying {transform}." 
) from e if filter_result: yield output_record diff --git a/grain/_src/python/data_loader_test.py b/grain/_src/python/data_loader_test.py index d3fa37bc..1a0cb685 100644 --- a/grain/_src/python/data_loader_test.py +++ b/grain/_src/python/data_loader_test.py @@ -16,6 +16,7 @@ from collections.abc import Sequence import pathlib from unittest import mock +from typing import Union from absl import flags from absl.testing import absltest @@ -144,7 +145,7 @@ def setUp(self): self.testdata_dir = pathlib.Path(FLAGS.test_srcdir) / "testdata" def _create_data_loader_for_short_sequence( - self, transformations, *, worker_count: int = 0, seed: int | None = None + self, transformations, *, worker_count: int = 0, seed: Union[int, None] = None ) -> data_loader_lib.DataLoader: # Generates elements [0, 1, 2, 3, 4, 5, 6, 7]. range_data_source = RangeDataSource(start=0, stop=8, step=1) diff --git a/grain/_src/python/data_sources.py b/grain/_src/python/data_sources.py index 23899a61..35cee799 100644 --- a/grain/_src/python/data_sources.py +++ b/grain/_src/python/data_sources.py @@ -25,15 +25,11 @@ from collections.abc import Sequence import math from multiprocessing import shared_memory -import os -import threading import typing -from typing import Any, Generic, Protocol, SupportsIndex, TypeVar +from typing import Any, Generic, Protocol, SupportsIndex, TypeVar, Union from absl import logging import array_record.python.array_record_data_source as array_record -from etils import epath -from grain._src.core import usage_logging T = TypeVar("T") @@ -113,9 +109,9 @@ class InMemoryDataSource(shared_memory.ShareableList): def __init__( self, - elements: Sequence[Any] | None = None, + elements: Union[Sequence[Any], None] = None, *, - name: str | None = None, + name: Union[str, None] = None, ): """Creates a new InMemoryDataSource object. diff --git a/grain/_src/python/experimental/continual_sequence_sampler/continual_sequence_sampler.py b/grain/_src/python/experimental/continual_sequence_sampler/continual_sequence_sampler.py index 746df160..d65f5a3b 100644 --- a/grain/_src/python/experimental/continual_sequence_sampler/continual_sequence_sampler.py +++ b/grain/_src/python/experimental/continual_sequence_sampler/continual_sequence_sampler.py @@ -51,7 +51,7 @@ from collections.abc import Sequence import dataclasses import heapq -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from grain._src.core import sharding from grain._src.python import record @@ -362,7 +362,7 @@ class SamplerWrapper: def __init__( self, - sampler: ContinualSequenceSampler | BatchedContinualSequenceSampler, + sampler: Union[ContinualSequenceSampler, BatchedContinualSequenceSampler], start_index_ordered: np.ndarray, seed: int, ): diff --git a/grain/_src/python/experimental/example_packing/packing.py b/grain/_src/python/experimental/example_packing/packing.py index 50045b9a..c3e5e805 100644 --- a/grain/_src/python/experimental/example_packing/packing.py +++ b/grain/_src/python/experimental/example_packing/packing.py @@ -12,7 +12,7 @@ """ import dataclasses -from typing import Generic, Iterator, TypeVar, cast +from typing import Generic, Iterator, TypeVar, cast, Union from grain._src.core import tree from grain._src.python import record @@ -180,7 +180,7 @@ class PackAndBatchOperation(Generic[_T]): length_struct: jt.PyTree[int] batch_size: int # We don't know input shapes and corresponding buffer shapes until __call__. 
- _cur_batch: _PackedBatch | None = None + _cur_batch: Union[_PackedBatch, None] = None def __call__( self, input_iterator: Iterator[record.Record[_T]] diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py index 4a16a99c..8a038d3b 100644 --- a/grain/_src/python/grain_pool.py +++ b/grain/_src/python/grain_pool.py @@ -42,7 +42,7 @@ """ from __future__ import annotations - +import sys from collections.abc import Iterator import cProfile import dataclasses @@ -54,7 +54,7 @@ import queue import threading import traceback -from typing import Any, Protocol, TypeVar +from typing import Any, Protocol, TypeVar, Union from absl import logging import cloudpickle @@ -68,6 +68,7 @@ from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member T = TypeVar("T") +PY310 = sys.version_info >= (3, 10) # Maximum number of threads for starting and stopping processes. _PROCESS_MANAGEMENT_MAX_THREADS = 64 @@ -86,7 +87,7 @@ class _ProcessingComplete: _PROCESSING_COMPLETE = _ProcessingComplete() -@dataclasses.dataclass(frozen=True, slots=True) +@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True})) class GrainPoolElement: """Wrapper for output records emited by Grain Pool.""" @@ -412,7 +413,7 @@ def _shutdown(self) -> None: process.terminate() -@dataclasses.dataclass(frozen=True, slots=True) +@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True})) class _ReaderQueueElement: """Element to be added to the reader queue.""" @@ -427,7 +428,7 @@ class _GrainPoolProcessingComplete: _GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete() -_QueueElement = _ReaderQueueElement | _GrainPoolProcessingComplete | Exception +_QueueElement = Union[_ReaderQueueElement, _GrainPoolProcessingComplete, Exception] class GrainPoolProcessingError(Exception): diff --git a/grain/_src/python/lazy_dataset/data_sources.py b/grain/_src/python/lazy_dataset/data_sources.py index a68f2376..7956be11 100644 --- a/grain/_src/python/lazy_dataset/data_sources.py +++ b/grain/_src/python/lazy_dataset/data_sources.py @@ -13,9 +13,8 @@ # limitations under the License. 
"""LazyDataset data sources.""" -from typing import Protocol +from typing import Protocol, Union -from absl import logging from grain._src.python.lazy_dataset import lazy_dataset @@ -51,7 +50,7 @@ def log_lineage(self): def log_lineage_for_sources( - root: lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset, + root: Union[lazy_dataset.LazyMapDataset, lazy_dataset.LazyIterDataset] ): """Traverses tree of transformations and logs lineage on source datasets.""" pass diff --git a/grain/_src/python/lazy_dataset/lazy_dataset.py b/grain/_src/python/lazy_dataset/lazy_dataset.py index f969aa47..fbc1eca9 100644 --- a/grain/_src/python/lazy_dataset/lazy_dataset.py +++ b/grain/_src/python/lazy_dataset/lazy_dataset.py @@ -45,13 +45,12 @@ import copy import functools import time -from typing import Any, Callable, Optional, TypeVar, overload +from typing import Any, Callable, Optional, TypeVar, overload, Union from concurrent import futures from grain._src.core import sharding from grain._src.core import tree from grain._src.core import usage_logging -import multiprocessing as mp from grain._src.python import grain_pool from grain._src.python import options as grain_options from grain._src.python import shared_memory_array @@ -66,7 +65,7 @@ class LazyMapDataset(Sequence[T], abc.ABC): _functions: dict[str, Callable[[LazyMapDataset], Any]] = {} - def __init__(self, parents: LazyMapDataset | Sequence[LazyMapDataset] = ()): + def __init__(self, parents: Union[LazyMapDataset, Sequence[LazyMapDataset]] = ()): if isinstance(parents, LazyMapDataset): self._parents = (parents,) else: @@ -91,7 +90,7 @@ def __getitem__(self, index: slice) -> LazyMapDataset: ... @overload - def __getitem__(self, index: int) -> T | None: + def __getitem__(self, index: int) -> Union[T, None]: ... @abc.abstractmethod @@ -121,7 +120,7 @@ def __iter__(self) -> LazyDatasetIterator[T]: return self.to_iter_dataset().__iter__() def to_iter_dataset( - self, read_options: grain_options.ReadOptions | None = None + self, read_options: Union[grain_options.ReadOptions, None] = None ) -> LazyIterDataset[T]: """Syntactic sugar to construct a LazyIterDataset.""" return PrefetchLazyIterDataset( @@ -137,9 +136,10 @@ class LazyIterDataset(Iterable[T], abc.ABC): def __init__( self, parents: ( - LazyMapDataset - | LazyIterDataset - | Sequence[LazyMapDataset | LazyIterDataset] + Union[ + LazyMapDataset, + LazyIterDataset, + Sequence[Union[LazyMapDataset, LazyIterDataset]]] ) = (), ): if isinstance(parents, (LazyMapDataset, LazyIterDataset)): @@ -149,11 +149,11 @@ def __init__( usage_logging.log_event("LazyIterDataset", tag_3="PyGrain") @property - def parents(self) -> Sequence[LazyMapDataset | LazyIterDataset]: + def parents(self) -> Sequence[Union[LazyMapDataset, LazyIterDataset]]: return self._parents @property - def _parent(self) -> LazyMapDataset | LazyIterDataset: + def _parent(self) -> Union[LazyMapDataset, LazyIterDataset]: assert len(self._parents) == 1, self._parents return self._parents[0] @@ -452,7 +452,7 @@ def __next__(self) -> T: def get_element_producer_fn( worker_index: int, worker_count: int - ) -> Iterator[tuple[T, dict[str, Any] | None]]: + ) -> Iterator[tuple[T, Union[dict[str, Any], None]]]: # Recover from the last recorded state for the given worker. 
worker_state = state[_WORKERS_STATE][str(worker_index)] parent.set_parent_maps_slice(slice(worker_index, None, worker_count)) @@ -502,7 +502,7 @@ def get_state(self) -> dict[str, Any]: class RangeLazyMapDataset(LazyMapDataset[int]): """Range data source, similar to python range() function.""" - def __init__(self, start: int, stop: int | None = None, step: int = 1): + def __init__(self, start: int, stop: Union[int, None] = None, step: int = 1): super().__init__() self.start = 0 if stop is None else start self.stop = start if stop is None else stop @@ -522,7 +522,7 @@ def __getitem__(self, index): def to_iter_dataset( self, - read_options: grain_options.ReadOptions | None = None, + read_options: Union[grain_options.ReadOptions, None] = None, ) -> LazyIterDataset[int]: """Syntactic sugar to construct a LazyIterDataset.""" return PrefetchLazyIterDataset( @@ -550,7 +550,7 @@ def __init__( def __len__(self) -> int: return self._end - self._start - def __getitem__(self, index: int | slice) -> Optional[T]: + def __getitem__(self, index: Union[int, slice]) -> Optional[T]: if isinstance(index, slice): return self.slice(index) epoch = index // len(self) diff --git a/grain/_src/python/lazy_dataset/transformations/batch.py b/grain/_src/python/lazy_dataset/transformations/batch.py index f40ccf5f..d84ac8e5 100644 --- a/grain/_src/python/lazy_dataset/transformations/batch.py +++ b/grain/_src/python/lazy_dataset/transformations/batch.py @@ -25,13 +25,12 @@ def _make_batch(values: Sequence[T]) -> T: - match len(values): - case 0: - return () - case 1: - tree.map_structure(lambda x: np.expand_dims(x, axis=0), values[0]) - return tree.map_structure(lambda *xs: np.stack(xs), values[0], *values[1:]) - + if len(values) == 0: + return () + elif len(values) == 1: + return tree.map_structure(lambda x: np.expand_dims(x, axis=0), values[0]) + else: + return tree.map_structure(lambda *xs: np.stack(xs), values[0], *values[1:]) class _BatchLazyDatasetIterator(lazy_dataset.LazyDatasetIterator[T]): """Iterator that batches elements.""" diff --git a/grain/_src/python/lazy_dataset/transformations/filter.py b/grain/_src/python/lazy_dataset/transformations/filter.py index a927528b..8db633c7 100644 --- a/grain/_src/python/lazy_dataset/transformations/filter.py +++ b/grain/_src/python/lazy_dataset/transformations/filter.py @@ -13,7 +13,7 @@ # limitations under the License. 
"""Filter transformation for LazyDataset.""" -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TypeVar, Union from grain._src.core import transforms from grain._src.python.lazy_dataset import lazy_dataset @@ -30,7 +30,7 @@ class FilterLazyMapDataset(lazy_dataset.LazyMapDataset[T]): def __init__( self, parent: lazy_dataset.LazyMapDataset[T], - transform: transforms.FilterTransform | Callable[[T], bool], + transform: Union[transforms.FilterTransform, Callable[[T], bool]], ): super().__init__(parent) if isinstance(transform, transforms.FilterTransform): @@ -97,7 +97,7 @@ class FilterLazyIterDataset(lazy_dataset.LazyIterDataset[T]): def __init__( self, parent: lazy_dataset.LazyIterDataset, - transform: transforms.FilterTransform | Callable[[T], bool], + transform: Union[transforms.FilterTransform, Callable[[T], bool]], ): super().__init__(parent) if isinstance(transform, transforms.FilterTransform): diff --git a/grain/_src/python/lazy_dataset/transformations/map.py b/grain/_src/python/lazy_dataset/transformations/map.py index d340dd50..2a4b5ab2 100644 --- a/grain/_src/python/lazy_dataset/transformations/map.py +++ b/grain/_src/python/lazy_dataset/transformations/map.py @@ -73,8 +73,8 @@ def release_rng(self, rng: np.random.Generator): def _get_map_fn_and_seed( - transform: _MapTransformType, seed: int | None = None -) -> tuple[Callable[..., T], int | None]: + transform: _MapTransformType, seed: Union[int, None] = None +) -> tuple[Callable[..., T], Union[int, None]]: """Extracts a map fn from `transform`. If a seed is returned map fn requires a seed. @@ -117,7 +117,7 @@ def __init__( self, parent: lazy_dataset.LazyMapDataset, transform: _MapTransformType, - seed: int | None = None, + seed: Union[int, None] = None, ): super().__init__(parent) self._map_fn, seed = _get_map_fn_and_seed(transform, seed) @@ -148,7 +148,7 @@ class MapWithIndexLazyMapDataset(lazy_dataset.LazyMapDataset[T]): def __init__( self, parent: lazy_dataset.LazyMapDataset, - transform: transforms.MapWithIndexTransform | Callable[[int, Any], T], + transform: Union[transforms.MapWithIndexTransform, Callable[[int, Any], T]], ): super().__init__(parent) if isinstance(transform, transforms.MapWithIndexTransform): @@ -176,7 +176,7 @@ def __init__( self, parent: lazy_dataset.LazyDatasetIterator, map_fn: Callable[..., T], - seed: int | None = None, + seed: Union[int, None] = None, ): super().__init__() self._parent = parent @@ -223,7 +223,7 @@ def __init__( self, parent: lazy_dataset.LazyIterDataset, transform: _MapTransformType, - seed: int | None = None, + seed: Union[int, None] = None, ): super().__init__(parent) self._map_fn, self._seed = _get_map_fn_and_seed(transform, seed) diff --git a/grain/_src/python/lazy_dataset/transformations/mix.py b/grain/_src/python/lazy_dataset/transformations/mix.py index dadaee6f..fa845e48 100644 --- a/grain/_src/python/lazy_dataset/transformations/mix.py +++ b/grain/_src/python/lazy_dataset/transformations/mix.py @@ -52,7 +52,7 @@ class SelectionWithProportionsMap(DatasetSelectionMap): def __init__( self, parents: Sequence[lazy_dataset.LazyMapDataset], - proportions: Sequence[float | int] | None = None, + proportions: Union[Sequence[Union[float, int]], None] = None, ): # Normalize proportions if proportions is None: @@ -90,8 +90,8 @@ class MixedLazyMapDataset(lazy_dataset.LazyMapDataset[T]): def __init__( self, parents: Sequence[lazy_dataset.LazyMapDataset[T]], - proportions: Sequence[float | int] | None = None, - selection_map: DatasetSelectionMap | None = None, + 
proportions: Union[Sequence[Union[float, int]], None] = None, + selection_map: Union[DatasetSelectionMap, None] = None, ): """Initializes the mixed dataset. @@ -139,7 +139,7 @@ class _MixedLazyDatasetIterator(lazy_dataset.LazyDatasetIterator[T]): def __init__( self, parents: Sequence[lazy_dataset.LazyDatasetIterator[T]], - proportions: Sequence[float | int] | None = None, + proportions: Union[Sequence[Union[float, int]], None] = None, ): super().__init__() self._parents = parents @@ -190,7 +190,7 @@ class MixedLazyIterDataset(lazy_dataset.LazyIterDataset[T]): def __init__( self, parents: Sequence[lazy_dataset.LazyIterDataset], - proportions: Sequence[float | int] | None = None, + proportions: Union[Sequence[Union[float, int]], None] = None, ): super().__init__(parents) # Normalize proportions diff --git a/grain/_src/python/lazy_dataset/transformations/packing.py b/grain/_src/python/lazy_dataset/transformations/packing.py index cd193ee3..c1bfa17b 100644 --- a/grain/_src/python/lazy_dataset/transformations/packing.py +++ b/grain/_src/python/lazy_dataset/transformations/packing.py @@ -14,7 +14,7 @@ """Implements packing transformations.""" import collections import copy -from typing import Any +from typing import Any, Union from grain._src.core import tree from grain._src.python.lazy_dataset import lazy_dataset @@ -81,7 +81,7 @@ class SingleBinPackLazyIterDataset(lazy_dataset.LazyIterDataset): def __init__( self, parent: lazy_dataset.LazyIterDataset, - length_struct: PyTree[int | None], + length_struct: PyTree[Union[int, None]], ): super().__init__(parent) self._length_struct = length_struct @@ -101,13 +101,13 @@ class SingleBinPackLazyDatasetIterator(lazy_dataset.LazyDatasetIterator): def __init__( self, parent: lazy_dataset.LazyDatasetIterator, - length_struct: PyTree[int | None], + length_struct: PyTree[Union[int, None]], ): self._parent = parent self._length_struct = length_struct # Same as above but flattened. Some operations are easier using the # flattened representation. - self._flat_lengths: list[int | None] = tree.flatten(length_struct) + self._flat_lengths: list[Union[int, None]] = tree.flatten(length_struct) # Buffer for fully packed elements (not flattened) self._packed_elements = collections.deque() # Variable length list of flat elements going into the next packed example. diff --git a/grain/_src/python/lazy_dataset/transformations/packing_test.py b/grain/_src/python/lazy_dataset/transformations/packing_test.py index c1c5e850..16e62b7a 100644 --- a/grain/_src/python/lazy_dataset/transformations/packing_test.py +++ b/grain/_src/python/lazy_dataset/transformations/packing_test.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for batch transformation.""" +import sys from absl.testing import absltest from absl.testing import parameterized from grain._src.python.lazy_dataset import data_sources @@ -25,6 +26,7 @@ # pylint: enable=unused-import import numpy as np +PY310 = sys.version_info >= (3, 10) class SingleBinPackLazyIterDatasetTest(parameterized.TestCase): @@ -47,7 +49,7 @@ def test_pack_single_feature(self): # Second, fourth and five element packed together. ([5, 6, 7, 8], [1, 1, 2, 3], [0, 1, 0, 0]), ] - for actual, expected in zip(ds_iter, expected_elements, strict=True): + for actual, expected in zip(ds_iter, expected_elements, **({"strict": True} if PY310 else {})): # Elements are tuples with (inputs, inputs_segment_ids, inputs_positions). 
self.assertLen(actual, 3) np.testing.assert_array_equal(actual, expected) @@ -72,7 +74,7 @@ def test_pack_single_feature_remainder_is_padded(self): ([5, 6, 7, 0], [1, 1, 2, 0], [0, 1, 0, 0]), ] - for actual, expected in zip(ds_iter, expected_elements, strict=True): + for actual, expected in zip(ds_iter, expected_elements, **({"strict": True} if PY310 else {})): # Elements are tuples with (inputs, inputs_segment_ids, inputs_positions). self.assertLen(actual, 3) np.testing.assert_array_equal(actual, expected) @@ -129,7 +131,7 @@ def test_pack_single_feature_in_dict(self, feature: str): }, ] - for actual, expected in zip(ds_iter, expected_elements, strict=True): + for actual, expected in zip(ds_iter, expected_elements, **({"strict": True} if PY310 else {})): # Compare keys. self.assertSequenceEqual(sorted(actual), sorted(expected)) np.testing.assert_array_equal(actual[feature], expected[feature]) @@ -199,7 +201,7 @@ def test_pack_multiple_features_same_sequences_length(self, feature: str): "targets_positions": [0, 1, 2, 0], }, ] - for actual, expected in zip(ds_iter, expected_elements, strict=True): + for actual, expected in zip(ds_iter, expected_elements, **({"strict": True} if PY310 else {})): # Compare keys. self.assertSequenceEqual(sorted(actual), sorted(expected)) np.testing.assert_array_equal(actual[feature], expected[feature]) @@ -259,7 +261,7 @@ def test_pack_multiple_features_different_sequences_length( "targets_positions": [0, 1, 2, 0], }, ] - for actual, expected in zip(ds_iter, expected_elements, strict=True): + for actual, expected in zip(ds_iter, expected_elements, **({"strict": True} if PY310 else {})): np.testing.assert_array_equal(actual[feature], expected[feature]) @parameterized.parameters( @@ -311,7 +313,7 @@ def test_pack_two_dimensional_features(self, feature: str): "input_vectors_positions": [0, 1, 0], }, ] - for actual, expected in zip(ds_iter, expected_elements, strict=True): + for actual, expected in zip(ds_iter, expected_elements, **({"strict": True} if PY310 else {})): np.testing.assert_array_equal(actual[feature], expected[feature]) def test_checkpointing(self): diff --git a/grain/_src/python/lazy_dataset/transformations/repeat.py b/grain/_src/python/lazy_dataset/transformations/repeat.py index 1194d1ca..70e0961a 100644 --- a/grain/_src/python/lazy_dataset/transformations/repeat.py +++ b/grain/_src/python/lazy_dataset/transformations/repeat.py @@ -13,7 +13,7 @@ # limitations under the License. 
"""Implements repeat transformation.""" import sys -from typing import TypeVar +from typing import TypeVar, Union from grain._src.python.lazy_dataset import lazy_dataset @@ -32,7 +32,7 @@ class RepeatLazyMapDataset(lazy_dataset.LazyMapDataset[T]): def __init__( self, parent: lazy_dataset.LazyMapDataset[T], - num_epochs: int | None = None, + num_epochs: Union[int, None] = None, ): super().__init__(parent) if len(parent) >= sys.maxsize: diff --git a/grain/_src/python/load.py b/grain/_src/python/load.py index 0f00b760..0e727c3a 100644 --- a/grain/_src/python/load.py +++ b/grain/_src/python/load.py @@ -1,5 +1,7 @@ """High level APIs that serve as a single endpoint for very common use cases.""" +from typing import Union + from grain._src.core import sharding from grain._src.core import transforms from grain._src.core import usage_logging @@ -12,15 +14,15 @@ def load( source: data_sources.RandomAccessDataSource, *, - num_epochs: int | None = None, + num_epochs: Union[int, None] = None, shuffle: bool = False, - seed: int | None = None, + seed: Union[int, None] = None, shard_options: sharding.ShardOptions = sharding.NoSharding(), transformations: transforms.Transformations = (), - batch_size: int | None = None, + batch_size: Union[int, None] = None, drop_remainder: bool = False, - worker_count: int | None = 0, - read_options: options.ReadOptions | None = None, + worker_count: Union[int, None] = 0, + read_options: Union[options.ReadOptions, None] = None, ) -> data_loader.DataLoader: """Convenient method for simple pipelines on top of a data source. diff --git a/grain/_src/python/options.py b/grain/_src/python/options.py index 4e8b966f..9e29e195 100644 --- a/grain/_src/python/options.py +++ b/grain/_src/python/options.py @@ -13,9 +13,11 @@ # limitations under the License. """Dataclasses for holdings options.""" import dataclasses +import sys +PY310 = sys.version_info >= (3, 10) -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**({"slots": True} if PY310 else {})) class ReadOptions: """Options for reading data from the DataSource. @@ -38,7 +40,7 @@ class ReadOptions: prefetch_buffer_size: int = 500 -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**({"slots": True} if PY310 else {})) class MultiprocessingOptions: """Options for using Python multiprocessing. diff --git a/grain/_src/python/record.py b/grain/_src/python/record.py index 66ad33c5..510c3f54 100644 --- a/grain/_src/python/record.py +++ b/grain/_src/python/record.py @@ -14,13 +14,14 @@ """Define record class used by various modules in the Grain Python Backend.""" import dataclasses +import sys from typing import Optional, Generic, TypeVar import numpy as np T = TypeVar("T") +PY310 = sys.version_info >= (3, 10) - -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**({"slots": True} if PY310 else {})) class RecordMetadata: """RecordMetadata contains metadata about indidivual records. 
@@ -47,7 +48,7 @@ def __str__(self): ) -@dataclasses.dataclass(slots=True) +@dataclasses.dataclass(**({"slots": True} if PY310 else {})) class Record(Generic[T]): metadata: RecordMetadata data: T diff --git a/grain/_src/python/shared_memory_array.py b/grain/_src/python/shared_memory_array.py index 53b4cd13..d3d35f94 100644 --- a/grain/_src/python/shared_memory_array.py +++ b/grain/_src/python/shared_memory_array.py @@ -14,6 +14,7 @@ """Shared memory array.""" from __future__ import annotations +import sys import dataclasses import math import mmap @@ -25,8 +26,9 @@ import numpy as np import numpy.typing as npt +PY310 = sys.version_info >= (3, 10) -@dataclasses.dataclass(frozen=True, slots=True) +@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True})) class SharedMemoryArrayMetadata: name: str shape: Iterable[int] diff --git a/grain/oss/build_whl.sh b/grain/oss/build_whl.sh index 39ee3ae5..7488c0c2 100644 --- a/grain/oss/build_whl.sh +++ b/grain/oss/build_whl.sh @@ -47,4 +47,4 @@ function main() { echo $(date) : "=== Output wheel file is in: ${DEST}" } -main "$@" +main "$@" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index cb042b2e..344e5e7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ ] readme = "README.md" license = { file = "LICENSE" } -requires-python = ">=3.10" +requires-python = ">=3.9" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License",
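
Note on the compatibility pattern this patch applies throughout: Python 3.9 lacks several features the original code relied on, so the diff (a) replaces PEP 604 unions (X | Y) with typing.Union, (b) rewrites match statements as if/elif chains, and (c) gates keyword arguments that only exist on 3.10+, namely dataclasses.dataclass(slots=...) and zip(strict=...), behind a module-level PY310 flag, while pyproject.toml and the new CI matrix drop the floor to 3.9. The sketch below restates that idiom in isolation; it is illustrative only, uses hypothetical names not taken from the Grain codebase, and assumes nothing beyond the standard library.

# Illustrative sketch of the 3.9-compatibility idiom used in this patch
# (hypothetical names; not part of the diff itself).
import dataclasses
import sys
from typing import Union

PY310 = sys.version_info >= (3, 10)


# dataclasses.dataclass() only accepts slots= on Python 3.10+, so that keyword
# is supplied conditionally; frozen= exists on 3.9 and is always passed.
@dataclasses.dataclass(
    **({"slots": True, "frozen": True} if PY310 else {"frozen": True})
)
class Element:
  index: int
  # "bytes | None" would raise TypeError when the class body is evaluated on
  # 3.9 (unless annotations are deferred), hence Union[...] here.
  payload: Union[bytes, None] = None


def paired(xs, ys):
  # zip() gained the strict= keyword in 3.10; on 3.9 it is simply omitted,
  # mirroring the packing_test.py changes above.
  return list(zip(xs, ys, **({"strict": True} if PY310 else {})))

Each touched module (options.py, record.py, grain_pool.py, data_loader.py, shared_memory_array.py) defines its own PY310 constant rather than importing a shared one, which keeps every file independently importable at the cost of a little duplication.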