Commit 920a070
feat: improvement of Ray sink API (#2237)

* Expose `max_bytes_per_file` via the Ray sink
* Add a hook to provide the `ray.data.Dataset.write_lance()` interface.
1 parent 8ac02bc commit 920a070
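
Both pieces of the change land in `python/python/lance/ray/sink.py`. Below is a minimal sketch of the resulting call surface, not part of the commit itself; the output path, row count, and byte cap are illustrative values.

```python
# Sketch only: illustrates the API surface this commit exposes.
# "/tmp/demo.lance" and the 64 MiB cap are made-up example values.
import ray

from lance.ray.sink import write_lance

ray.init()

ds = ray.data.range(10).map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})

# max_bytes_per_file is now threaded from the Ray sink down to write_fragments().
write_lance(
    ds,
    "/tmp/demo.lance",
    max_rows_per_file=1024 * 1024,
    max_bytes_per_file=64 * 1024 * 1024,
)
```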

3 files changed: +120 -6 lines changed

python/python/lance/fragment.py (+4 -1)
```diff
@@ -33,6 +33,9 @@
 from .schema import LanceSchema


+DEFAULT_MAX_BYTES_PER_FILE = 90 * 1024 * 1024 * 1024
+
+
 class FragmentMetadata:
     """Metadata of a Fragment in the dataset."""

@@ -496,7 +499,7 @@ def write_fragments(
     mode: str = "append",
     max_rows_per_file: int = 1024 * 1024,
     max_rows_per_group: int = 1024,
-    max_bytes_per_file: int = 90 * 1024 * 1024 * 1024,
+    max_bytes_per_file: int = DEFAULT_MAX_BYTES_PER_FILE,
     progress: Optional[FragmentWriteProgress] = None,
     use_experimental_writer: bool = False,
 ) -> List[FragmentMetadata]:
```
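
For reference, a minimal sketch of calling `write_fragments` with the newly named default; the table and URI below are illustrative and not part of the commit.

```python
# Sketch only: the table and URI are illustrative, not from the commit.
import pyarrow as pa

from lance.fragment import DEFAULT_MAX_BYTES_PER_FILE, write_fragments

tbl = pa.table({"id": list(range(4))})

# Behavior is unchanged: the named constant is the same 90 GiB cap as before.
fragments = write_fragments(
    tbl,
    "/tmp/frag_demo.lance",
    max_rows_per_file=1024 * 1024,
    max_bytes_per_file=DEFAULT_MAX_BYTES_PER_FILE,
)
print(fragments)
```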

python/python/lance/ray/sink.py (+94 -5)
```diff
@@ -20,14 +20,14 @@
 import pyarrow as pa

 import lance
-from lance.fragment import FragmentMetadata, write_fragments
+from lance.fragment import DEFAULT_MAX_BYTES_PER_FILE, FragmentMetadata, write_fragments

 from ..dependencies import ray

 if TYPE_CHECKING:
     import pandas as pd

-__all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter"]
+__all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"]


 def _pd_to_arrow(
```
```diff
@@ -52,6 +52,7 @@ def _write_fragment(
     *,
     schema: Optional[pa.Schema] = None,
     max_rows_per_file: int = 1024 * 1024,
+    max_bytes_per_file: Optional[int] = None,
     max_rows_per_group: int = 1024,  # Only useful for v1 writer.
     use_experimental_writer: bool = False,
 ) -> Tuple[FragmentMetadata, pa.Schema]:
```
```diff
@@ -74,13 +75,18 @@ def record_batch_converter():
             tbl = _pd_to_arrow(block, schema)
             yield from tbl.to_batches()

+    max_bytes_per_file = (
+        DEFAULT_MAX_BYTES_PER_FILE if max_bytes_per_file is None else max_bytes_per_file
+    )
+
     reader = pa.RecordBatchReader.from_batches(schema, record_batch_converter())
     fragments = write_fragments(
         reader,
         uri,
         schema=schema,
         max_rows_per_file=max_rows_per_file,
         max_rows_per_group=max_rows_per_group,
+        max_bytes_per_file=max_bytes_per_file,
         use_experimental_writer=use_experimental_writer,
     )
     return [(fragment, schema) for fragment in fragments]
```
```diff
@@ -213,24 +219,43 @@ class LanceFragmentWriter:
     in distributed fashion.

     Parameters
+    ----------
+    uri : str
+        The base URI of the dataset.
+    transform : Callable[[pa.Table], Union[pa.Table, Generator]], optional
+        A callable to transform the input batch. Default is None.
+    schema : pyarrow.Schema, optional
+        The schema of the dataset.
+    max_rows_per_file : int, optional
+        The maximum number of rows per file. Default is 1024 * 1024.
+    max_bytes_per_file : int, optional
+        The maximum number of bytes per file. Default is 90GB.
+    max_rows_per_group : int, optional
+        The maximum number of rows per group. Default is 1024.
+        Only useful for v1 writer.
+    use_experimental_writer : bool, optional
+        Set true to use v2 writer. Default is True.
+
     """

     def __init__(
         self,
         uri: str,
         *,
-        transform: Callable[[pa.Table], Union[pa.Table, Generator]] = lambda x: x,
+        transform: Optional[Callable[[pa.Table], Union[pa.Table, Generator]]] = None,
         schema: Optional[pa.Schema] = None,
-        max_rows_per_group: int = 1024,  # Only useful for v1 writer.
         max_rows_per_file: int = 1024 * 1024,
+        max_bytes_per_file: Optional[int] = None,
+        max_rows_per_group: Optional[int] = None,  # Only useful for v1 writer.
         use_experimental_writer: bool = True,
     ):
         self.uri = uri
         self.schema = schema
-        self.transform = transform
+        self.transform = transform if transform is not None else lambda x: x

         self.max_rows_per_group = max_rows_per_group
         self.max_rows_per_file = max_rows_per_file
+        self.max_bytes_per_file = max_bytes_per_file
         self.use_experimental_writer = use_experimental_writer

     def __call__(self, batch: Union[pa.Table, "pd.DataFrame"]) -> Dict[str, Any]:
```
````diff
@@ -282,3 +307,67 @@ def write(
     ):
         v.append((fragment, schema))
         return v
+
+
+def write_lance(
+    data: ray.data.Dataset,
+    output_uri: str,
+    *,
+    schema: Optional[pa.Schema] = None,
+    transform: Optional[
+        Callable[[pa.Table], Union[pa.Table, Generator[None, pa.Table, None]]]
+    ] = None,
+    max_rows_per_file: int = 1024 * 1024,
+    max_bytes_per_file: Optional[int] = None,
+) -> None:
+    """Write a Ray dataset at scale.
+
+    This method wraps `LanceFragmentWriter` and `LanceCommitter` to write
+    larger-than-memory Ray data to Lance files.
+
+    Parameters
+    ----------
+    data : ray.data.Dataset
+        The dataset to write.
+    output_uri : str
+        The output dataset URI.
+    transform : Callable[[pa.Table], Union[pa.Table, Generator]], optional
+        A callable to transform the input batch. Default is the identity function.
+    schema : pyarrow.Schema, optional
+        If provided, the schema of the dataset. Otherwise, it will be inferred.
+    max_rows_per_file : int, optional
+        The maximum number of rows per file. Default is 1024 * 1024.
+    max_bytes_per_file : int, optional
+        The maximum number of bytes per file. Default is 90GB.
+    """
+    data.map_batches(
+        LanceFragmentWriter(
+            output_uri,
+            schema=schema,
+            transform=transform,
+            max_rows_per_file=max_rows_per_file,
+            max_bytes_per_file=max_bytes_per_file,
+        ),
+        batch_size=max_rows_per_file,
+    ).write_datasink(LanceCommitter(output_uri, schema=schema))
+
+
+def _register_hooks():
+    """Register the Lance hook with Ray for better integration.
+
+    You can use `ray.data.Dataset.write_lance` to write a Ray dataset to Lance.
+
+    Example:
+
+    ```python
+    import ray
+    import lance
+    from lance.ray.sink import _register_hooks
+
+    _register_hooks()
+
+    (
+        ray.data.range(10)
+        .map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
+        .write_lance("~/data.lance")
+    )
+    ```
+    """
+    ray.data.Dataset.write_lance = write_lance
````
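
Putting the hook and the committer together, here is a minimal end-to-end sketch that follows the docstring example above; the output path is illustrative and not from the commit.

```python
# Sketch only: mirrors the docstring example; the path is an example value.
import ray
import lance

from lance.ray.sink import _register_hooks

_register_hooks()  # monkey-patches ray.data.Dataset.write_lance

(
    ray.data.range(10)
    .map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
    .write_lance("/tmp/hook_demo.lance")
)

print(lance.dataset("/tmp/hook_demo.lance").count_rows())  # expected: 10
```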

python/python/tests/test_ray.py (+22 -0)
```diff
@@ -14,8 +14,12 @@
     LanceCommitter,
     LanceDatasink,
     LanceFragmentWriter,
+    _register_hooks,
 )

+# Use this hook until we have an official DataSink in Ray.
+_register_hooks()
+
 ray.init()

```
```diff
@@ -77,3 +81,21 @@ def test_ray_committer(tmp_path: Path):
     assert sorted(tbl["id"].to_pylist()) == list(range(10))
     assert set(tbl["str"].to_pylist()) == set([f"str-{i}" for i in range(10)])
     assert len(ds.get_fragments()) == 2
+
+
+def test_ray_write_lance(tmp_path: Path):
+    schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())])
+
+    (
+        ray.data.range(10)
+        .map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
+        .write_lance(tmp_path, schema=schema)
+    )
+
+    ds = lance.dataset(tmp_path)
+    assert ds.count_rows() == 10
+    assert ds.schema == schema
+
+    tbl = ds.to_table()
+    assert sorted(tbl["id"].to_pylist()) == list(range(10))
+    assert set(tbl["str"].to_pylist()) == set([f"str-{i}" for i in range(10)])
```
