
Commit ae70478

feat: support merge fragment with dataset (#3256)
This PR allows a dataset merge to be performed concurrently, one fragment at a time.
1 parent c40164b commit ae70478
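
Because each fragment is merged independently and the results are committed in a single Merge operation, the per-fragment joins can be fanned out across workers. A minimal sketch of that workflow (the thread pool and the demo path/data are illustrative, not part of this diff):

    from concurrent.futures import ThreadPoolExecutor

    import lance
    import pyarrow as pa

    tbl = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    ds = lance.write_dataset(tbl, "demo")
    right = pa.table({"x": [1, 2, 3], "z": ["d", "e", "f"]})

    def merge_one(frag):
        # Join the new columns into a single fragment; returns
        # (FragmentMetadata, LanceSchema).
        return frag.merge(right, "x")

    with ThreadPoolExecutor() as pool:
        results = list(pool.map(merge_one, ds.get_fragments()))

    merged = [meta for meta, _ in results]
    schema = results[0][1]  # all fragments derive the same merged schema
    op = lance.LanceOperation.Merge(merged, schema)
    ds = lance.LanceDataset.commit("demo", op, read_version=ds.latest_version)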

6 files changed (+236 −3 lines)

python/python/lance/dataset.py (+7)

@@ -490,6 +490,13 @@ def data_storage_version(self) -> str:
         """
         return self._ds.data_storage_version
 
+    @property
+    def max_field_id(self) -> int:
+        """
+        The maximum field id in the manifest.
+        """
+        return self._ds.max_field_id
+
     def to_table(
         self,
         columns: Optional[Union[List[str], Dict[str, str]]] = None,
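
Exposing max_field_id is what makes the per-fragment merges composable: every fragment-level merge below receives the same manifest-wide maximum field id, so each one assigns identical field ids to the new columns and the resulting schemas agree at commit time. Reading the new property is straightforward (the path is illustrative):

    import lance

    ds = lance.dataset("./dataset")
    print(ds.max_field_id)  # highest field id recorded in the manifest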

python/python/lance/fragment.py (+75)

@@ -35,6 +35,7 @@
     _write_fragments,
 )
 from .progress import FragmentWriteProgress, NoopFragmentWriteProgress
+from .types import _coerce_reader
 from .udf import BatchUDF, normalize_transform
 
 if TYPE_CHECKING:
@@ -406,6 +407,7 @@ def scanner(
         limit: Optional[int] = None,
         offset: Optional[int] = None,
         with_row_id: bool = False,
+        with_row_address: bool = False,
         batch_readahead: int = 16,
     ) -> "LanceScanner":
         """See Dataset::scanner for details"""
@@ -424,6 +426,7 @@ def scanner(
             limit=limit,
             offset=offset,
             with_row_id=with_row_id,
+            with_row_address=with_row_address,
             batch_readahead=batch_readahead,
             **columns_arg,
         )
@@ -475,6 +478,78 @@ def to_table(
             with_row_id=with_row_id,
         ).to_table()
 
+    def merge(
+        self,
+        data_obj: ReaderLike,
+        left_on: str,
+        right_on: Optional[str] = None,
+        schema=None,
+    ) -> Tuple[FragmentMetadata, LanceSchema]:
+        """
+        Merge another dataset into this fragment.
+
+        Performs a left join, where the fragment is the left side and data_obj
+        is the right side. Rows in the fragment that have no match in data_obj
+        will have the new columns filled with null values, unless Lance doesn't
+        support null values for some types, in which case an error will be
+        raised.
+
+        Parameters
+        ----------
+        data_obj: Reader-like
+            The data to be merged. Acceptable types are:
+            - Pandas DataFrame, Pyarrow Table, Dataset, Scanner,
+              Iterator[RecordBatch], or RecordBatchReader
+        left_on: str
+            The name of the column in the dataset to join on.
+        right_on: str or None
+            The name of the column in data_obj to join on. If None, defaults to
+            left_on.
+
+        Examples
+        --------
+
+        >>> import lance
+        >>> import pyarrow as pa
+        >>> df = pa.table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']})
+        >>> dataset = lance.write_dataset(df, "dataset")
+        >>> dataset.to_table().to_pandas()
+           x  y
+        0  1  a
+        1  2  b
+        2  3  c
+        >>> fragments = dataset.get_fragments()
+        >>> new_df = pa.table({'x': [1, 2, 3], 'z': ['d', 'e', 'f']})
+        >>> merged = []
+        >>> schema = None
+        >>> for f in fragments:
+        ...     f, schema = f.merge(new_df, 'x')
+        ...     merged.append(f)
+        >>> merge = lance.LanceOperation.Merge(merged, schema)
+        >>> dataset = lance.LanceDataset.commit("dataset", merge, read_version=1)
+        >>> dataset.to_table().to_pandas()
+           x  y  z
+        0  1  a  d
+        1  2  b  e
+        2  3  c  f
+
+        See Also
+        --------
+        LanceDataset.merge_columns :
+            Add columns to this Fragment.
+
+        Returns
+        -------
+        Tuple[FragmentMetadata, LanceSchema]
+            A new fragment with the merged column(s) and the final schema.
+        """
+        if right_on is None:
+            right_on = left_on
+
+        reader = _coerce_reader(data_obj, schema)
+        max_field_id = self._ds.max_field_id
+        metadata, schema = self._fragment.merge(reader, left_on, right_on, max_field_id)
+        return metadata, schema
+
     def merge_columns(
         self,
         value_func: Dict[str, str]
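
Since data_obj is routed through _coerce_reader, any reader-like input works, not just a pyarrow Table. A sketch merging a pandas DataFrame fragment by fragment (path and values are illustrative):

    import lance
    import pandas as pd

    ds = lance.dataset("dataset")
    new_df = pd.DataFrame({"x": [1, 2, 3], "z": ["d", "e", "f"]})

    merged, schema = [], None
    for frag in ds.get_fragments():
        meta, schema = frag.merge(new_df, "x")
        merged.append(meta)

    op = lance.LanceOperation.Merge(merged, schema)
    ds = lance.LanceDataset.commit("dataset", op, read_version=ds.latest_version)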

python/python/tests/test_fragment.py (+61)

@@ -361,3 +361,64 @@ def test_create_from_file(tmp_path):
     assert dataset.count_rows() == 1600
     assert len(dataset.get_fragments()) == 1
     assert dataset.get_fragments()[0].fragment_id == 2
+
+
+def test_fragment_merge(tmp_path):
+    schema = pa.schema([pa.field("a", pa.string())])
+    batches = pa.RecordBatchReader.from_batches(
+        schema,
+        [
+            pa.record_batch([pa.array(["0" * 1024] * 1024 * 8)], names=["a"]),
+            pa.record_batch([pa.array(["0" * 1024] * 1024 * 8)], names=["a"]),
+        ],
+    )
+
+    progress = ProgressForTest()
+    fragments = write_fragments(
+        batches,
+        tmp_path,
+        max_rows_per_group=512,
+        max_bytes_per_file=1024,
+        progress=progress,
+    )
+
+    operation = lance.LanceOperation.Overwrite(schema, fragments)
+    dataset = lance.LanceDataset.commit(tmp_path, operation)
+    merged = []
+    schema = None
+    for fragment in dataset.get_fragments():
+        table = fragment.scanner(with_row_id=True, columns=[]).to_table()
+        table = table.add_column(0, "b", [[i for i in range(len(table))]])
+        fragment, schema = fragment.merge(table, "_rowid")
+        merged.append(fragment)
+
+    merge = lance.LanceOperation.Merge(merged, schema)
+    dataset = lance.LanceDataset.commit(
+        tmp_path, merge, read_version=dataset.latest_version
+    )
+
+    merged = []
+    schema = None
+    for fragment in dataset.get_fragments():
+        table = fragment.scanner(with_row_address=True, columns=[]).to_table()
+        table = table.add_column(0, "c", [[i + 1 for i in range(len(table))]])
+        fragment, schema = fragment.merge(table, "_rowaddr")
+        merged.append(fragment)
+
+    merge = lance.LanceOperation.Merge(merged, schema)
+    dataset = lance.LanceDataset.commit(
+        tmp_path, merge, read_version=dataset.latest_version
+    )
+
+    merged = []
+    for fragment in dataset.get_fragments():
+        table = fragment.scanner(columns=["b"]).to_table()
+        table = table.add_column(0, "d", [[i + 2 for i in range(len(table))]])
+        fragment, schema = fragment.merge(table, "b")
+        merged.append(fragment)
+
+    merge = lance.LanceOperation.Merge(merged, schema)
+    dataset = lance.LanceDataset.commit(
+        tmp_path, merge, read_version=dataset.latest_version
+    )
+    assert [f.name for f in dataset.schema] == ["a", "b", "c", "d"]
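
The test exercises the three kinds of join keys the new merge path accepts: the logical row id (_rowid, fetched with with_row_id=True), the physical row address (_rowaddr, fetched with the new with_row_address=True flag), and an ordinary data column (b), then checks that the schema accumulates to [a, b, c, d].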

python/src/dataset.rs (+5)

@@ -332,6 +332,11 @@ impl Dataset {
         self.clone()
     }
 
+    #[getter(max_field_id)]
+    fn max_field_id(self_: PyRef<'_, Self>) -> PyResult<i32> {
+        Ok(self_.ds.manifest().max_field_id())
+    }
+
     #[getter(schema)]
     fn schema(self_: PyRef<'_, Self>) -> PyResult<PyObject> {
         let arrow_schema = ArrowSchema::from(self_.ds.schema());

python/src/fragment.rs (+25 −2)

@@ -16,7 +16,7 @@ use std::fmt::Write as _;
 use std::sync::Arc;
 
 use arrow::ffi_stream::ArrowArrayStreamReader;
-use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
 use arrow_array::RecordBatchReader;
 use arrow_schema::Schema as ArrowSchema;
 use futures::TryFutureExt;
@@ -163,7 +163,7 @@ impl FileFragment {
     }
 
     #[allow(clippy::too_many_arguments)]
-    #[pyo3(signature=(columns=None, columns_with_transform=None, batch_size=None, filter=None, limit=None, offset=None, with_row_id=None, batch_readahead=None))]
+    #[pyo3(signature=(columns=None, columns_with_transform=None, batch_size=None, filter=None, limit=None, offset=None, with_row_id=None, with_row_address=None, batch_readahead=None))]
     fn scanner(
         self_: PyRef<'_, Self>,
         columns: Option<Vec<String>>,
@@ -173,6 +173,7 @@ impl FileFragment {
         limit: Option<i64>,
         offset: Option<i64>,
         with_row_id: Option<bool>,
+        with_row_address: Option<bool>,
         batch_readahead: Option<usize>,
     ) -> PyResult<Scanner> {
         let mut scanner = self_.fragment.scan();
@@ -212,6 +213,9 @@ impl FileFragment {
         if with_row_id.unwrap_or(false) {
             scanner.with_row_id();
         }
+        if with_row_address.unwrap_or(false) {
+            scanner.with_row_address();
+        }
         if let Some(batch_readahead) = batch_readahead {
             scanner.batch_readahead(batch_readahead);
         }
@@ -261,6 +265,25 @@ impl FileFragment {
         Ok((PyLance(fragment), LanceSchema(schema)))
     }
 
+    fn merge(
+        &mut self,
+        reader: PyArrowType<ArrowArrayStreamReader>,
+        left_on: String,
+        right_on: String,
+        max_field_id: i32,
+    ) -> PyResult<(PyLance<Fragment>, LanceSchema)> {
+        let mut fragment = self.fragment.clone();
+        let (fragment, schema) = RT
+            .spawn(None, async move {
+                fragment
+                    .merge_columns(reader.0, &left_on, &right_on, max_field_id)
+                    .await
+            })?
+            .infer_error()?;
+
+        Ok((PyLance(fragment), LanceSchema(schema)))
+    }
+
     fn delete(&self, predicate: &str) -> PyResult<Option<Self>> {
         let old_fragment = self.fragment.clone();
         let updated_fragment = RT
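
On the Python side, the new with_row_address flag is reachable through FileFragment.scanner. A minimal sketch of fetching only row addresses, mirroring the test above (the dataset path is illustrative):

    import lance

    ds = lance.dataset("dataset")
    frag = ds.get_fragments()[0]
    # Request only the physical row address column, no data columns.
    addrs = frag.scanner(with_row_address=True, columns=[]).to_table()
    print(addrs.column_names)  # ["_rowaddr"]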

rust/lance/src/dataset/fragment.rs (+63 −1)

@@ -12,7 +12,9 @@ use std::sync::Arc;
 
 use arrow::compute::concat_batches;
 use arrow_array::cast::as_primitive_array;
-use arrow_array::{new_null_array, RecordBatch, StructArray, UInt32Array, UInt64Array};
+use arrow_array::{
+    new_null_array, RecordBatch, RecordBatchReader, StructArray, UInt32Array, UInt64Array,
+};
 use arrow_schema::Schema as ArrowSchema;
 use datafusion::logical_expr::Expr;
 use datafusion::scalar::ScalarValue;
@@ -1331,6 +1333,66 @@ impl FileFragment {
         Updater::try_new(self.clone(), reader, deletion_vector, schemas, batch_size)
     }
 
+    pub async fn merge_columns(
+        &mut self,
+        stream: impl RecordBatchReader + Send + 'static,
+        left_on: &str,
+        right_on: &str,
+        max_field_id: i32,
+    ) -> Result<(Fragment, Schema)> {
+        let stream = Box::new(stream);
+        if self.schema().field(left_on).is_none() && left_on != ROW_ID && left_on != ROW_ADDR {
+            return Err(Error::invalid_input(
+                format!(
+                    "Column {} does not exist in the left side fragment",
+                    left_on
+                ),
+                location!(),
+            ));
+        };
+        let right_schema = stream.schema();
+        if right_schema.field_with_name(right_on).is_err() {
+            return Err(Error::invalid_input(
+                format!(
+                    "Column {} does not exist in the right side fragment",
+                    right_on
+                ),
+                location!(),
+            ));
+        };
+
+        for field in right_schema.fields() {
+            if field.name() == right_on {
+                // right_on is allowed to exist in the dataset, since it may be
+                // the same as left_on.
+                continue;
+            }
+            if self.schema().field(field.name()).is_some() {
+                return Err(Error::invalid_input(
+                    format!(
+                        "Column {} exists in left side fragment and right side dataset",
+                        field.name()
+                    ),
+                    location!(),
+                ));
+            }
+        }
+        // Hash join
+        let joiner = Arc::new(HashJoiner::try_new(stream, right_on).await?);
+        // Final schema is union of current schema, plus the RHS schema without
+        // the right_on key.
+        let mut new_schema: Schema = self.schema().merge(joiner.out_schema().as_ref())?;
+        new_schema.set_field_id(Some(max_field_id));
+
+        let new_fragment = self
+            .clone()
+            .merge(left_on, &joiner)
+            .await
+            .map(|f| f.metadata)?;
+
+        Ok((new_fragment, new_schema))
+    }
+
     pub(crate) async fn merge(mut self, join_column: &str, joiner: &HashJoiner) -> Result<Self> {
         let mut updater = self.updater(Some(&[join_column]), None, None).await?;
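
merge_columns performs a hash left join with the fragment on the left, so fragment rows with no match on the right get nulls in the new columns (or an error for types where Lance cannot represent nulls). A sketch of that documented behavior from Python (path and values illustrative):

    import lance
    import pyarrow as pa

    tbl = pa.table({"x": [1, 2, 3], "y": ["a", "b", "c"]})
    ds = lance.write_dataset(tbl, "join_demo")

    # The right side only covers x = 1 and x = 3; x = 2 should get a null z.
    right = pa.table({"x": [1, 3], "z": ["d", "f"]})

    merged, schema = [], None
    for frag in ds.get_fragments():
        meta, schema = frag.merge(right, "x")
        merged.append(meta)

    op = lance.LanceOperation.Merge(merged, schema)
    ds = lance.LanceDataset.commit("join_demo", op, read_version=ds.latest_version)
    print(ds.to_table().to_pandas())  # z is null where x == 2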
