feat: improvement of Ray sink API #2237

Merged 17 commits on Apr 23, 2024
python/python/lance/fragment.py (4 additions, 1 deletion)
@@ -33,6 +33,9 @@
from .schema import LanceSchema


DEFAULT_MAX_BYTES_PER_FILE = 90 * 1024 * 1024 * 1024


class FragmentMetadata:
"""Metadata of a Fragment in the dataset."""

@@ -496,7 +499,7 @@ def write_fragments(
mode: str = "append",
max_rows_per_file: int = 1024 * 1024,
max_rows_per_group: int = 1024,
max_bytes_per_file: int = 90 * 1024 * 1024 * 1024,
max_bytes_per_file: int = DEFAULT_MAX_BYTES_PER_FILE,
progress: Optional[FragmentWriteProgress] = None,
use_experimental_writer: bool = False,
) -> List[FragmentMetadata]:
python/python/lance/ray/sink.py (94 additions, 5 deletions)
@@ -20,14 +20,14 @@
import pyarrow as pa

import lance
from lance.fragment import FragmentMetadata, write_fragments
from lance.fragment import DEFAULT_MAX_BYTES_PER_FILE, FragmentMetadata, write_fragments

from ..dependencies import ray

if TYPE_CHECKING:
import pandas as pd

__all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter"]
__all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"]


def _pd_to_arrow(
@@ -52,6 +52,7 @@ def _write_fragment(
*,
schema: Optional[pa.Schema] = None,
max_rows_per_file: int = 1024 * 1024,
max_bytes_per_file: Optional[int] = None,
max_rows_per_group: int = 1024, # Only useful for v1 writer.
use_experimental_writer: bool = False,
) -> Tuple[FragmentMetadata, pa.Schema]:
@@ -74,13 +74,18 @@ def record_batch_converter():
tbl = _pd_to_arrow(block, schema)
yield from tbl.to_batches()

max_bytes_per_file = (
DEFAULT_MAX_BYTES_PER_FILE if max_bytes_per_file is None else max_bytes_per_file
)
Comment on lines +78 to +80

Contributor:
Why not just put DEFAULT_MAX_BYTES_PER_FILE in the signature as the default value?

Contributor (author):
Using None here lets us detect later whether the user specified this value or not. During a benchmark, the default 90 GB caused an OOM (we now know it was a bug in Arrow). This way we can provide a better default later.

Contributor:
I see. Seems like a good thing to write a TODO comment for.

reader = pa.RecordBatchReader.from_batches(schema, record_batch_converter())
fragments = write_fragments(
reader,
uri,
schema=schema,
max_rows_per_file=max_rows_per_file,
max_rows_per_group=max_rows_per_group,
max_bytes_per_file=max_bytes_per_file,
use_experimental_writer=use_experimental_writer,
)
return [(fragment, schema) for fragment in fragments]
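
The review thread above settles on `None` as the default so that an explicit user choice can be told apart from the library default. A generic sketch of that sentinel pattern follows; the names below are illustrative, not part of the Lance API:

```python
from typing import Optional

# Illustrative constant; the real default lives in lance.fragment.
_LIBRARY_DEFAULT_MAX_BYTES = 90 * 1024 * 1024 * 1024


def resolve_max_bytes_per_file(max_bytes_per_file: Optional[int] = None) -> int:
    # None means "the caller did not choose a value", so the library can swap
    # in a better default later without overriding explicit user choices.
    if max_bytes_per_file is None:
        return _LIBRARY_DEFAULT_MAX_BYTES
    return max_bytes_per_file
```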
@@ -213,24 +219,43 @@ class LanceFragmentWriter:
in distributed fashion.

Parameters
----------
uri : str
The base URI of the dataset.
transform : Callable[[pa.Table], Union[pa.Table, Generator]], optional
A callable to transform the input batch. Defaults to the identity function.
schema : pyarrow.Schema, optional
The schema of the dataset.
max_rows_per_file : int, optional
The maximum number of rows per file. Default is 1024 * 1024.
max_bytes_per_file : int, optional
The maximum number of bytes per file. Default is 90GB.
max_rows_per_group : int, optional
The maximum number of rows per group. Default is 1024.
Only useful for v1 writer.
use_experimental_writer : bool, optional
Set to True to use the v2 writer. Default is True.

"""

def __init__(
self,
uri: str,
*,
transform: Callable[[pa.Table], Union[pa.Table, Generator]] = lambda x: x,
transform: Optional[Callable[[pa.Table], Union[pa.Table, Generator]]] = None,
schema: Optional[pa.Schema] = None,
max_rows_per_group: int = 1024, # Only useful for v1 writer.
max_rows_per_file: int = 1024 * 1024,
max_bytes_per_file: Optional[int] = None,
max_rows_per_group: Optional[int] = None, # Only useful for v1 writer.
use_experimental_writer: bool = True,
):
self.uri = uri
self.schema = schema
self.transform = transform
self.transform = transform if transform is not None else lambda x: x

self.max_rows_per_group = max_rows_per_group
self.max_rows_per_file = max_rows_per_file
self.max_bytes_per_file = max_bytes_per_file
self.use_experimental_writer = use_experimental_writer

def __call__(self, batch: Union[pa.Table, "pd.DataFrame"]) -> Dict[str, Any]:
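
As a reading aid, here is a minimal sketch of how the writer and committer compose in a Ray pipeline; it mirrors the `write_lance` helper added further below, with a placeholder path and toy data:

```python
import pyarrow as pa
import ray

from lance.ray.sink import LanceCommitter, LanceFragmentWriter

uri = "/tmp/demo.lance"  # placeholder path
schema = pa.schema([pa.field("id", pa.int64())])

# Each Ray block is converted into Lance fragments by the writer; the
# committer then commits all fragment metadata as a single transaction.
(
    ray.data.range(100)
    .map_batches(LanceFragmentWriter(uri, schema=schema))
    .write_datasink(LanceCommitter(uri, schema=schema))
)
```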
@@ -282,3 +307,67 @@ def write(
):
v.append((fragment, schema))
return v


def write_lance(
data: ray.data.Dataset,
output_uri: str,
*,
schema: Optional[pa.Schema] = None,
transform: Optional[
Callable[[pa.Table], Union[pa.Table, Generator[None, pa.Table, None]]]
] = None,
max_rows_per_file: int = 1024 * 1024,
max_bytes_per_file: Optional[int] = None,
) -> None:
"""Write Ray dataset at scale.

This method wraps the `LanceFragmentWriter` and `LanceCommitter` to write
large-than-memory ray data to lance files.

Parameters
----------
data : ray.data.Dataset
The dataset to write.
output_uri : str
The output dataset URI.
transform : Callable[[pa.Table], Union[pa.Table, Generator]], optional
A callable to transform the input batch. Defaults to the identity function.
schema : pyarrow.Schema, optional
If provided, the schema of the dataset. Otherwise, it will be inferred.
max_rows_per_file: int, optional
The maximum number of rows per file. Default is 1024 * 1024.
max_bytes_per_file: int, optional
The maximum number of bytes per file. Default is 90GB.
"""
data.map_batches(
LanceFragmentWriter(
output_uri,
schema=schema,
transform=transform,
max_rows_per_file=max_rows_per_file,
max_bytes_per_file=max_bytes_per_file,
),
batch_size=max_rows_per_file,
).write_datasink(LanceCommitter(output_uri, schema=schema))
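
A minimal usage sketch of `write_lance` called directly, without the `_register_hooks` monkey-patch below; the output URI and row counts are placeholders:

```python
import pyarrow as pa
import ray

from lance.ray.sink import write_lance

schema = pa.schema([pa.field("id", pa.int64())])

write_lance(
    ray.data.range(10),
    "/tmp/demo_write_lance.lance",  # placeholder output URI
    schema=schema,
    max_rows_per_file=5,  # small value only to force more than one fragment
)
```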


def _register_hooks():
"""Register lance hook to Ray for better integration.

You can use `ray.data.Dataset.write_lance` to write a Ray dataset to Lance.
Example:

```python
import ray
import lance
from lance.ray.sink import _register_hooks

_register_hooks()

(
ray.data.range(10)
.map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
.write_lance("~/data.lance")
)
```
"""
ray.data.Dataset.write_lance = write_lance
python/python/tests/test_ray.py (22 additions)
@@ -14,8 +14,12 @@
LanceCommitter,
LanceDatasink,
LanceFragmentWriter,
_register_hooks,
)

# Use this hook until we have an official DataSink in Ray.
_register_hooks()

ray.init()


@@ -77,3 +81,21 @@ def test_ray_committer(tmp_path: Path):
assert sorted(tbl["id"].to_pylist()) == list(range(10))
assert set(tbl["str"].to_pylist()) == set([f"str-{i}" for i in range(10)])
assert len(ds.get_fragments()) == 2


def test_ray_write_lance(tmp_path: Path):
schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())])

(
ray.data.range(10)
.map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
.write_lance(tmp_path, schema=schema)
)

ds = lance.dataset(tmp_path)
assert ds.count_rows() == 10
assert ds.schema == schema

tbl = ds.to_table()
assert sorted(tbl["id"].to_pylist()) == list(range(10))
assert set(tbl["str"].to_pylist()) == set([f"str-{i}" for i in range(10)])