import pyarrow as pa

import lance
- from lance.fragment import FragmentMetadata, write_fragments
+ from lance.fragment import DEFAULT_MAX_BYTES_PER_FILE, FragmentMetadata, write_fragments

from ..dependencies import ray

if TYPE_CHECKING:
    import pandas as pd

- __all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter"]
+ __all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"]


def _pd_to_arrow(
@@ -52,6 +52,7 @@ def _write_fragment(
    *,
    schema: Optional[pa.Schema] = None,
    max_rows_per_file: int = 1024 * 1024,
+   max_bytes_per_file: Optional[int] = None,
    max_rows_per_group: int = 1024,  # Only useful for v1 writer.
    use_experimental_writer: bool = False,
) -> Tuple[FragmentMetadata, pa.Schema]:
@@ -74,13 +75,18 @@ def record_batch_converter():
            tbl = _pd_to_arrow(block, schema)
            yield from tbl.to_batches()

+   max_bytes_per_file = (
+       DEFAULT_MAX_BYTES_PER_FILE if max_bytes_per_file is None else max_bytes_per_file
+   )
+
    reader = pa.RecordBatchReader.from_batches(schema, record_batch_converter())
    fragments = write_fragments(
        reader,
        uri,
        schema=schema,
        max_rows_per_file=max_rows_per_file,
        max_rows_per_group=max_rows_per_group,
+       max_bytes_per_file=max_bytes_per_file,
        use_experimental_writer=use_experimental_writer,
    )
    return [(fragment, schema) for fragment in fragments]
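`_write_fragment` above defers the actual size cap to `lance.fragment.write_fragments`. A minimal sketch of calling `write_fragments` directly with the new `max_bytes_per_file` knob; the table contents, local path, and 64 MiB limit are illustrative assumptions, not values from this change:

```python
# Sketch: capping data file size via write_fragments directly.
# The table, path, and 64 MiB limit are illustrative assumptions.
import pyarrow as pa

from lance.fragment import write_fragments

table = pa.table({"id": list(range(1_000_000))})
fragments = write_fragments(
    table,
    "/tmp/demo.lance",                    # hypothetical dataset URI
    max_rows_per_file=1024 * 1024,
    max_bytes_per_file=64 * 1024 * 1024,  # start a new file once ~64 MiB is written
)
print(fragments)  # list of FragmentMetadata for the files that were written
```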
@@ -213,24 +219,43 @@ class LanceFragmentWriter:
    in distributed fashion.

    Parameters
+   ----------
+   uri : str
+       The base URI of the dataset.
+   transform : Callable[[pa.Table], Union[pa.Table, Generator]], optional
+       A callable to transform the input batch. Default is None (identity).
+   schema : pyarrow.Schema, optional
+       The schema of the dataset.
+   max_rows_per_file : int, optional
+       The maximum number of rows per file. Default is 1024 * 1024.
+   max_bytes_per_file : int, optional
+       The maximum number of bytes per file. Default is 90GB.
+   max_rows_per_group : int, optional
+       The maximum number of rows per group. Default is None.
+       Only useful for the v1 writer.
+   use_experimental_writer : bool, optional
+       Set to True to use the v2 writer. Default is True.
+
    """

    def __init__(
        self,
        uri: str,
        *,
-       transform: Callable[[pa.Table], Union[pa.Table, Generator]] = lambda x: x,
+       transform: Optional[Callable[[pa.Table], Union[pa.Table, Generator]]] = None,
        schema: Optional[pa.Schema] = None,
-       max_rows_per_group: int = 1024,  # Only useful for v1 writer.
        max_rows_per_file: int = 1024 * 1024,
+       max_bytes_per_file: Optional[int] = None,
+       max_rows_per_group: Optional[int] = None,  # Only useful for v1 writer.
        use_experimental_writer: bool = True,
    ):
        self.uri = uri
        self.schema = schema
-       self.transform = transform
+       self.transform = transform if transform is not None else lambda x: x

        self.max_rows_per_group = max_rows_per_group
        self.max_rows_per_file = max_rows_per_file
+       self.max_bytes_per_file = max_bytes_per_file
        self.use_experimental_writer = use_experimental_writer

    def __call__(self, batch: Union[pa.Table, "pd.DataFrame"]) -> Dict[str, Any]:
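The writer and committer are meant to be chained through Ray. A minimal sketch of wiring them by hand, mirroring what `write_lance` below does; the URI, schema, and size limits are illustrative assumptions:

```python
# Sketch: LanceFragmentWriter + LanceCommitter wired by hand through Ray Data.
# The URI, schema, and size limits are made-up values for illustration.
import pyarrow as pa
import ray

from lance.ray.sink import LanceCommitter, LanceFragmentWriter

uri = "/tmp/fragments_demo.lance"  # hypothetical output location
schema = pa.schema([pa.field("id", pa.int64())])

(
    ray.data.range(100_000)
    .map_batches(
        LanceFragmentWriter(
            uri,
            schema=schema,
            max_rows_per_file=1024 * 1024,
            max_bytes_per_file=128 * 1024 * 1024,  # cap each data file at ~128 MiB
        ),
        batch_size=1024 * 1024,
    )
    .write_datasink(LanceCommitter(uri, schema=schema))
)
```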
@@ -282,3 +307,67 @@ def write(
        ):
            v.append((fragment, schema))
        return v
+
+
+ def write_lance(
+     data: ray.data.Dataset,
+     output_uri: str,
+     *,
+     schema: Optional[pa.Schema] = None,
+     transform: Optional[
+         Callable[[pa.Table], Union[pa.Table, Generator[None, pa.Table, None]]]
+     ] = None,
+     max_rows_per_file: int = 1024 * 1024,
+     max_bytes_per_file: Optional[int] = None,
+ ) -> None:
+     """Write a Ray dataset at scale.
+
+     This function wraps `LanceFragmentWriter` and `LanceCommitter` to write
+     larger-than-memory Ray data to Lance files.
+
+     Parameters
+     ----------
+     data : ray.data.Dataset
+         The dataset to write.
+     output_uri : str
+         The output dataset URI.
+     schema : pyarrow.Schema, optional
+         If provided, the schema of the dataset. Otherwise, it will be inferred.
+     transform : Callable[[pa.Table], Union[pa.Table, Generator]], optional
+         A callable to transform the input batch. Default is the identity function.
+     max_rows_per_file : int, optional
+         The maximum number of rows per file. Default is 1024 * 1024.
+     max_bytes_per_file : int, optional
+         The maximum number of bytes per file. Default is 90GB.
+     """
+     data.map_batches(
+         LanceFragmentWriter(
+             output_uri,
+             schema=schema,
+             transform=transform,
+             max_rows_per_file=max_rows_per_file,
+             max_bytes_per_file=max_bytes_per_file,
+         ),
+         batch_size=max_rows_per_file,
+     ).write_datasink(LanceCommitter(output_uri, schema=schema))
+
+
+ def _register_hooks():
+     """Register the Lance hooks with Ray for better integration.
+
+     You can use `ray.data.Dataset.write_lance` to write a Ray dataset to Lance.
+     Example:
+
+     ```python
+     import ray
+     import lance
+     from lance.ray.sink import _register_hooks
+
+     _register_hooks()
+
+     (
+         ray.data.range(10)
+         .map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
+         .write_lance("~/data.lance")
+     )
+     ```
+     """
+     ray.data.Dataset.write_lance = write_lance
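A minimal sketch of calling `write_lance` directly, without registering the `ray.data.Dataset.write_lance` hook; the output path, schema, and row limit are illustrative assumptions:

```python
# Sketch: writing a Ray dataset with write_lance() directly.
# The output path, schema, and max_rows_per_file value are illustrative assumptions.
import pyarrow as pa
import ray

from lance.ray.sink import write_lance

schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())])

ds = ray.data.range(10).map(lambda x: {"id": x["id"], "str": f"str-{x['id']}"})
write_lance(ds, "/tmp/data.lance", schema=schema, max_rows_per_file=4)
```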