Implement GitHub Actions for CI and fix Python 3.9 errors
Testing on Python 3.9, 3.10, and 3.11.

Unfortunately not 3.12, as TensorFlow doesn't have wheels for it yet.

Fix Python 3.9 errors:
* dataclasses.dataclass doesn't accept a slots argument before Python 3.10 (workaround sketched below)
* replace | with Union[...] in type hints
* remove match statements
* remove unused imports
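Since the notes above are terse, here is a minimal, self-contained sketch (assuming only the standard library) of the 3.9-compatible patterns this commit applies. The PY310 flag mirrors the one added in the diff below; the _DATACLASS_KWARGS, _Example, and _describe names are illustrative only and do not appear in the changed files.

import dataclasses
import sys
from typing import Union

# dataclasses.dataclass only accepts `slots=True` on Python 3.10+, so the
# keyword is passed conditionally based on the interpreter version.
PY310 = sys.version_info >= (3, 10)
_DATACLASS_KWARGS = {"frozen": True, "slots": True} if PY310 else {"frozen": True}


@dataclasses.dataclass(**_DATACLASS_KWARGS)
class _Example:
  # `int | None` in an annotation evaluated at class-creation time raises
  # TypeError on 3.9; Union[int, None] (or Optional[int]) works everywhere.
  index: int
  seed: Union[int, None] = None


def _describe(value: Union[int, str]) -> str:
  # Python 3.9 has no `match` statement; a plain isinstance chain is the
  # straightforward replacement used in data_loader.py below.
  if isinstance(value, int):
    return f"int: {value}"
  if isinstance(value, str):
    return f"str: {value}"
  raise TypeError(f"Unsupported type: {type(value)}")

On 3.10+ the conditional yields the same frozen, slotted dataclass that slots=True produced; on 3.9 it simply drops the slots optimization instead of failing at import time.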
fabianp committed Feb 6, 2024
1 parent db639ce commit 0b59607
Showing 22 changed files with 179 additions and 123 deletions.
54 changes: 54 additions & 0 deletions .github/workflows/tests.yml
@@ -0,0 +1,54 @@
name: tests

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

jobs:
  build-and-test:
    name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
    runs-on: "${{ matrix.os }}"

    strategy:
      matrix:
        python-version: ["3.9", "3.10", "3.11"]
        os: [ubuntu-latest]

    steps:
      - uses: "actions/checkout@v2"
      # - uses: "actions/setup-python@v4"
      #   with:
      #     python-version: "${{ matrix.python-version }}"
      - name: Create directory
        run: |
          mkdir -p /tmp/grain
          cp -r . /tmp/grain
      - name: Build package
        run: |
          set -xe
          export PYTHON_VERSION=${{ matrix.python-version }}
          export PYTHON_MAJOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f1)
          export PYTHON_MINOR_VERSION=$(echo $PYTHON_VERSION | cut -d. -f2)
          export CP_VERSION="cp${PYTHON_MAJOR_VERSION}${PYTHON_MINOR_VERSION}"
          export BAZEL_VERSION="5.4.0"
          export AUDITWHEEL_PLATFORM="manylinux2014_x86_64"
          cd /tmp/grain
          DOCKER_BUILDKIT=1 docker build --progress=plain --no-cache \
            --build-arg PYTHON_VERSION=${PYTHON_VERSION} \
            --build-arg PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
            --build-arg PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
            --build-arg BAZEL_VERSION=${BAZEL_VERSION} \
            -t grain:${PYTHON_VERSION} - < grain/oss/build.Dockerfile
          docker run --rm -a stdin -a stdout -a stderr \
            --env PYTHON_VERSION=${PYTHON_VERSION} \
            --env PYTHON_MAJOR_VERSION=${PYTHON_MAJOR_VERSION} \
            --env PYTHON_MINOR_VERSION=${PYTHON_MINOR_VERSION} \
            --env PYTHON_BIN_PATH="/opt/python/${CP_VERSION}-${CP_VERSION}/bin/python" \
            --env BAZEL_VERSION=${BAZEL_VERSION} \
            -v /tmp/grain:/tmp/grain \
            --env AUDITWHEEL_PLATFORM=${AUDITWHEEL_PLATFORM} \
            --name grain grain:${PYTHON_VERSION} \
            bash grain/oss/build_whl.sh
83 changes: 40 additions & 43 deletions grain/_src/python/data_loader.py
@@ -23,15 +23,15 @@
import json
from multiprocessing import pool
import os
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar
import sys
from typing import Any, Callable, Optional, Sequence, Tuple, TypeVar, Union

from absl import logging
from concurrent import futures
from grain._src.core import sharding
from grain._src.core import transforms
from grain._src.core import tree
from grain._src.core import usage_logging
import multiprocessing as mp
from grain._src.python import grain_pool
from grain._src.python import options
from grain._src.python import record
@@ -45,6 +45,7 @@

_T = TypeVar("_T")
_IteratorState = dict[str, Any]
PY310 = sys.version_info >= (3, 10)

# Dictionary keys used in checkpoints.
_VERSION = "version"
@@ -82,7 +83,7 @@ def _determine_worker_count(input_worker_count: int | None) -> int:
raise ValueError("Can't determine worker count. Please set worker count.")


@dataclasses.dataclass(frozen=True, slots=True)
@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True}))
class _ReaderQueueElement:
"""Element to be added to the reader queue."""

@@ -99,7 +100,7 @@ class _GrainPoolProcessingComplete:


_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete()
_QueueElement = _ReaderQueueElement | _GrainPoolProcessingComplete | Exception
_QueueElement = Union[_ReaderQueueElement, _GrainPoolProcessingComplete, Exception]


@contextlib.contextmanager
@@ -440,54 +441,50 @@ def __str__(self):


def _apply_transform(
transform: transforms.Transformation | Operation,
input_iterator: Iterator[record.Record],
transform: transforms.Transformation | Operation,
input_iterator: Iterator[record.Record],
) -> Iterator[record.Record]:
"""Applies the `transform` to records in the iterator."""
fn: Callable[[record.Record], Tuple[record.Record, bool]] = None
# pylint: disable=g-long-lambda
# pytype: disable=attribute-error
match transform:
case transforms.MapTransform():
fn = lambda r: (record.Record(r.metadata, transform.map(r.data)), True)
case transforms.RandomMapTransform():
fn = lambda r: (
record.Record(
r.metadata, transform.random_map(r.data, r.metadata.rng)
),
True,
)
case transforms.TfRandomMapTransform():
fn = lambda r: (
record.Record(
r.metadata, transform.np_random_map(r.data, r.metadata.rng)
),
True,
)
case transforms.FilterTransform():
fn = lambda r: (r, bool(transform.filter(r.data)))
case transforms.BatchTransform():
batch_op = BatchOperation(
batch_size=transform.batch_size,
drop_remainder=transform.drop_remainder,
)
batch_op.disable_deprecation_message()
for r in batch_op(input_iterator):
yield r
case _:
# Transform is a legacy style operation and __call__() yield output
# records.
for r in transform(input_iterator):
yield r
# pytype: enable=attribute-error
# pylint: enable=g-long-lambda

if isinstance(transform, transforms.MapTransform):
fn = lambda r: (record.Record(r.metadata, transform.map(r.data)), True)
elif isinstance(transform, transforms.RandomMapTransform):
fn = lambda r: (
record.Record(
r.metadata, transform.random_map(r.data, r.metadata.rng)
),
True,
)
elif isinstance(transform, transforms.TfRandomMapTransform):
fn = lambda r: (
record.Record(
r.metadata, transform.np_random_map(r.data, r.metadata.rng)
),
True,
)
elif isinstance(transform, transforms.FilterTransform):
fn = lambda r: (r, bool(transform.filter(r.data)))
elif isinstance(transform, transforms.BatchTransform):
batch_op = BatchOperation(
batch_size=transform.batch_size,
drop_remainder=transform.drop_remainder,
)
batch_op.disable_deprecation_message()
for r in batch_op(input_iterator):
yield r
else:
# Transform is a legacy style operation and __call__() yield output
# records.
for r in transform(input_iterator):
yield r

for input_record in input_iterator:
try:
output_record, filter_result = fn(input_record)
except Exception as e:
raise ValueError(
f"PyGrain encountered an error when applying {transform}."
f"PyGrain encountered an error when applying {transform}."
) from e
if filter_result:
yield output_record
3 changes: 2 additions & 1 deletion grain/_src/python/data_loader_test.py
@@ -16,6 +16,7 @@
from collections.abc import Sequence
import pathlib
from unittest import mock
from typing import Union

from absl import flags
from absl.testing import absltest
@@ -144,7 +145,7 @@ def setUp(self):
self.testdata_dir = pathlib.Path(FLAGS.test_srcdir) / "testdata"

def _create_data_loader_for_short_sequence(
self, transformations, *, worker_count: int = 0, seed: int | None = None
self, transformations, *, worker_count: int = 0, seed: Union[int, None] = None
) -> data_loader_lib.DataLoader:
# Generates elements [0, 1, 2, 3, 4, 5, 6, 7].
range_data_source = RangeDataSource(start=0, stop=8, step=1)
10 changes: 3 additions & 7 deletions grain/_src/python/data_sources.py
@@ -25,15 +25,11 @@
from collections.abc import Sequence
import math
from multiprocessing import shared_memory
import os
import threading
import typing
from typing import Any, Generic, Protocol, SupportsIndex, TypeVar
from typing import Any, Generic, Protocol, SupportsIndex, TypeVar, Union

from absl import logging
import array_record.python.array_record_data_source as array_record
from etils import epath
from grain._src.core import usage_logging

T = TypeVar("T")

@@ -113,9 +109,9 @@ class InMemoryDataSource(shared_memory.ShareableList):

def __init__(
self,
elements: Sequence[Any] | None = None,
elements: Union[Sequence[Any], None] = None,
*,
name: str | None = None,
name: Union[str, None] = None,
):
"""Creates a new InMemoryDataSource object.
@@ -51,7 +51,7 @@
from collections.abc import Sequence
import dataclasses
import heapq
from typing import Optional, Tuple
from typing import Optional, Tuple, Union

from grain._src.core import sharding
from grain._src.python import record
@@ -362,7 +362,7 @@ class SamplerWrapper:

def __init__(
self,
sampler: ContinualSequenceSampler | BatchedContinualSequenceSampler,
sampler: Union[ContinualSequenceSampler, BatchedContinualSequenceSampler],
start_index_ordered: np.ndarray,
seed: int,
):
4 changes: 2 additions & 2 deletions grain/_src/python/experimental/example_packing/packing.py
@@ -12,7 +12,7 @@
"""

import dataclasses
from typing import Generic, Iterator, TypeVar, cast
from typing import Generic, Iterator, TypeVar, cast, Union

from grain._src.core import tree
from grain._src.python import record
@@ -180,7 +180,7 @@ class PackAndBatchOperation(Generic[_T]):
length_struct: jt.PyTree[int]
batch_size: int
# We don't know input shapes and corresponding buffer shapes until __call__.
_cur_batch: _PackedBatch | None = None
_cur_batch: Union[_PackedBatch, None] = None

def __call__(
self, input_iterator: Iterator[record.Record[_T]]
11 changes: 6 additions & 5 deletions grain/_src/python/grain_pool.py
@@ -42,7 +42,7 @@
"""

from __future__ import annotations

import sys
from collections.abc import Iterator
import cProfile
import dataclasses
@@ -54,7 +54,7 @@
import queue
import threading
import traceback
from typing import Any, Protocol, TypeVar
from typing import Any, Protocol, TypeVar, Union

from absl import logging
import cloudpickle
@@ -68,6 +68,7 @@
from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member

T = TypeVar("T")
PY310 = sys.version_info >= (3, 10)

# Maximum number of threads for starting and stopping processes.
_PROCESS_MANAGEMENT_MAX_THREADS = 64
@@ -86,7 +87,7 @@ class _ProcessingComplete:
_PROCESSING_COMPLETE = _ProcessingComplete()


@dataclasses.dataclass(frozen=True, slots=True)
@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True}))
class GrainPoolElement:
"""Wrapper for output records emited by Grain Pool."""

@@ -412,7 +413,7 @@ def _shutdown(self) -> None:
process.terminate()


@dataclasses.dataclass(frozen=True, slots=True)
@dataclasses.dataclass(**({"slots": True, "frozen": True} if PY310 else {"frozen": True}))
class _ReaderQueueElement:
"""Element to be added to the reader queue."""

@@ -427,7 +428,7 @@ class _GrainPoolProcessingComplete:


_GRAIN_POOL_PROCESSING_COMPLETE = _GrainPoolProcessingComplete()
_QueueElement = _ReaderQueueElement | _GrainPoolProcessingComplete | Exception
_QueueElement = Union[_ReaderQueueElement, _GrainPoolProcessingComplete, Exception]


class GrainPoolProcessingError(Exception):
5 changes: 2 additions & 3 deletions grain/_src/python/lazy_dataset/data_sources.py
@@ -13,9 +13,8 @@
# limitations under the License.
"""LazyDataset data sources."""

from typing import Protocol
from typing import Protocol, Union

from absl import logging
from grain._src.python.lazy_dataset import lazy_dataset


@@ -51,7 +50,7 @@ def log_lineage(self):


def log_lineage_for_sources(
root: lazy_dataset.LazyMapDataset | lazy_dataset.LazyIterDataset,
root: Union[lazy_dataset.LazyMapDataset, lazy_dataset.LazyIterDataset]
):
"""Traverses tree of transformations and logs lineage on source datasets."""
pass