Skip to content

Commit 6e76529

Browse files
authored
feat: execute_uncommitted for merge insert (#3233)
Allows separating the write and commit steps of merge-insert.
1 parent cf49205 commit 6e76529

File tree

6 files changed

+250
-84
lines changed

6 files changed

+250
-84
lines changed

python/python/lance/dataset.py

+83-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Optional,
2727
Sequence,
2828
Set,
29+
Tuple,
2930
TypedDict,
3031
Union,
3132
)
@@ -102,6 +103,30 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
102103

103104
return super(MergeInsertBuilder, self).execute(reader)
104105

106+
def execute_uncommitted(
    self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None
) -> Tuple[Transaction, Dict[str, Any]]:
    """Executes the merge insert operation without committing

    This function stages the merge insert and returns the resulting
    transaction together with a dictionary of merge statistics - i.e. the
    number of inserted, updated, and deleted rows. Unlike :meth:`execute`,
    the original dataset is left unmodified; the returned transaction must
    be passed to :meth:`LanceDataset.commit` to make the changes visible.

    Parameters
    ----------
    data_obj: ReaderLike
        The new data to use as the source table for the operation. This parameter
        can be any source of data (e.g. table / dataset) that
        :func:`~lance.write_dataset` accepts.
    schema: Optional[pa.Schema]
        The schema of the data. This only needs to be supplied whenever the data
        source is some kind of generator.

    Returns
    -------
    Tuple[Transaction, Dict[str, Any]]
        The uncommitted transaction and the merge statistics
        (``num_inserted_rows``, ``num_updated_rows``, ``num_deleted_rows``).
    """
    # Normalize the input (table, dataset, generator, ...) into a reader
    # before delegating to the native implementation.
    reader = _coerce_reader(data_obj, schema)

    return super(MergeInsertBuilder, self).execute_uncommitted(reader)
129+
105130
# These next three overrides exist only to document the methods
106131

107132
def when_matched_update_all(
@@ -2220,7 +2245,7 @@ def _commit(
22202245
@staticmethod
22212246
def commit(
22222247
base_uri: Union[str, Path, LanceDataset],
2223-
operation: LanceOperation.BaseOperation,
2248+
operation: Union[LanceOperation.BaseOperation, Transaction],
22242249
blobs_op: Optional[LanceOperation.BaseOperation] = None,
22252250
read_version: Optional[int] = None,
22262251
commit_lock: Optional[CommitLock] = None,
@@ -2326,24 +2351,45 @@ def commit(
23262351
f"commit_lock must be a function, got {type(commit_lock)}"
23272352
)
23282353

2329-
if read_version is None and not isinstance(
2330-
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
2354+
if (
2355+
isinstance(operation, LanceOperation.BaseOperation)
2356+
and read_version is None
2357+
and not isinstance(
2358+
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
2359+
)
23312360
):
23322361
raise ValueError(
23332362
"read_version is required for all operations except "
23342363
"Overwrite and Restore"
23352364
)
2336-
new_ds = _Dataset.commit(
2337-
base_uri,
2338-
operation,
2339-
blobs_op,
2340-
read_version,
2341-
commit_lock,
2342-
storage_options=storage_options,
2343-
enable_v2_manifest_paths=enable_v2_manifest_paths,
2344-
detached=detached,
2345-
max_retries=max_retries,
2346-
)
2365+
if isinstance(operation, Transaction):
2366+
new_ds = _Dataset.commit_transaction(
2367+
base_uri,
2368+
operation,
2369+
commit_lock,
2370+
storage_options=storage_options,
2371+
enable_v2_manifest_paths=enable_v2_manifest_paths,
2372+
detached=detached,
2373+
max_retries=max_retries,
2374+
)
2375+
elif isinstance(operation, LanceOperation.BaseOperation):
2376+
new_ds = _Dataset.commit(
2377+
base_uri,
2378+
operation,
2379+
blobs_op,
2380+
read_version,
2381+
commit_lock,
2382+
storage_options=storage_options,
2383+
enable_v2_manifest_paths=enable_v2_manifest_paths,
2384+
detached=detached,
2385+
max_retries=max_retries,
2386+
)
2387+
else:
2388+
raise TypeError(
2389+
"operation must be a LanceOperation.BaseOperation or Transaction, "
2390+
f"got {type(operation)}"
2391+
)
2392+
23472393
ds = LanceDataset.__new__(LanceDataset)
23482394
ds._storage_options = storage_options
23492395
ds._ds = new_ds
@@ -2722,6 +2768,29 @@ class Delete(BaseOperation):
27222768
def __post_init__(self):
27232769
LanceOperation._validate_fragments(self.updated_fragments)
27242770

2771+
@dataclass
class Update(BaseOperation):
    """
    Operation that updates rows in the dataset.

    Attributes
    ----------
    removed_fragment_ids: list[int]
        The ids of the fragments that have been removed entirely.
    updated_fragments: list[FragmentMetadata]
        The fragments that have been updated with new deletion vectors.
    new_fragments: list[FragmentMetadata]
        The fragments that contain the new rows.
    """

    # Fragments deleted wholesale by the update.
    removed_fragment_ids: List[int]
    # Surviving fragments whose deletion vectors were rewritten.
    updated_fragments: List[FragmentMetadata]
    # Fragments holding the newly written rows.
    new_fragments: List[FragmentMetadata]

    def __post_init__(self):
        # Validate both fragment lists; removed_fragment_ids are plain ints
        # and need no fragment-level validation.
        LanceOperation._validate_fragments(self.updated_fragments)
        LanceOperation._validate_fragments(self.new_fragments)
2793+
27252794
@dataclass
27262795
class Merge(BaseOperation):
27272796
"""

python/python/tests/test_dataset.py

+25
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,31 @@ def test_restore_with_commit(tmp_path: Path):
10151015
assert tbl == table
10161016

10171017

1018+
def test_merge_insert_with_commit():
    # Start from a small in-memory dataset where no row is marked updated.
    source = pa.table({"id": range(10), "updated": [False] * 10})
    ds = lance.write_dataset(source, "memory://test")

    # Stage (but do not commit) a merge-insert that flips the row with id=1.
    new_rows = pa.Table.from_pylist([{"id": 1, "updated": True}])
    builder = ds.merge_insert(on="id").when_matched_update_all()
    transaction, stats = builder.execute_uncommitted(new_rows)

    # The statistics dict reports exactly one matched-and-updated row.
    assert isinstance(stats, dict)
    assert stats["num_updated_rows"] == 1
    assert stats["num_inserted_rows"] == 0
    assert stats["num_deleted_rows"] == 0

    # The staged result is a Transaction wrapping an Update operation.
    assert isinstance(transaction, lance.Transaction)
    assert isinstance(transaction.operation, lance.LanceOperation.Update)

    # Only after committing the transaction does the update become visible.
    ds = lance.LanceDataset.commit(ds, transaction)
    expected = pa.table(
        {"id": range(10), "updated": [False] + [True] + [False] * 8}
    )
    assert ds.to_table().sort_by("id") == expected
1041+
1042+
10181043
def test_merge_with_commit(tmp_path: Path):
10191044
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
10201045
base_dir = tmp_path / "test"

python/src/dataset.rs

+70-17
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ use lance::dataset::{
4343
WriteParams,
4444
};
4545
use lance::dataset::{
46-
BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
46+
BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore,
47+
WriteDestination,
4748
};
4849
use lance::dataset::{ColumnAlteration, ProjectionRequest};
4950
use lance::index::vector::utils::get_vector_type;
@@ -199,20 +200,46 @@ impl MergeInsertBuilder {
199200
.try_build()
200201
.map_err(|err| PyValueError::new_err(err.to_string()))?;
201202

202-
let new_self = RT
203+
let (new_dataset, stats) = RT
203204
.spawn(Some(py), job.execute_reader(new_data))?
204205
.map_err(|err| PyIOError::new_err(err.to_string()))?;
205206

206207
let dataset = self.dataset.bind(py);
207208

208-
dataset.borrow_mut().ds = new_self.0;
209-
let merge_stats = new_self.1;
210-
let merge_dict = PyDict::new_bound(py);
211-
merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?;
212-
merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?;
213-
merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?;
209+
dataset.borrow_mut().ds = new_dataset;
214210

215-
Ok(merge_dict.into())
211+
Ok(Self::build_stats(&stats, py)?.into())
212+
}
213+
214+
/// Runs the merge insert job against `new_data` without committing the
/// result, returning the uncommitted [`Transaction`] plus a Python dict of
/// merge statistics (`num_inserted_rows` / `num_updated_rows` /
/// `num_deleted_rows`).
pub fn execute_uncommitted<'a>(
    &mut self,
    new_data: &Bound<'a, PyAny>,
) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
    let py = new_data.py();
    // Coerce the incoming Python object into an Arrow record batch reader.
    let new_data = convert_reader(new_data)?;

    // Builder errors (e.g. invalid configuration) surface as ValueError.
    let job = self
        .builder
        .try_build()
        .map_err(|err| PyValueError::new_err(err.to_string()))?;

    // Run on the shared runtime. Unlike `execute`, this does not replace
    // the wrapped dataset — the caller must commit the transaction.
    let (transaction, stats) = RT
        .spawn(Some(py), job.execute_uncommitted(new_data))?
        .map_err(|err| PyIOError::new_err(err.to_string()))?;

    let stats = Self::build_stats(&stats, py)?;

    Ok((PyLance(transaction), stats))
}
234+
}
235+
236+
impl MergeInsertBuilder {
237+
/// Converts [`MergeStats`] into a Python dict exposing the row counts
/// (`num_inserted_rows`, `num_updated_rows`, `num_deleted_rows`).
fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult<Bound<'a, PyDict>> {
    let dict = PyDict::new_bound(py);
    dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
    dict.set_item("num_updated_rows", stats.num_updated_rows)?;
    dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
    Ok(dict)
}
217244
}
218245

@@ -1312,6 +1339,36 @@ impl Dataset {
13121339
enable_v2_manifest_paths: Option<bool>,
13131340
detached: Option<bool>,
13141341
max_retries: Option<u32>,
1342+
) -> PyResult<Self> {
1343+
let transaction = Transaction::new(
1344+
read_version.unwrap_or_default(),
1345+
operation.0,
1346+
blobs_op.map(|op| op.0),
1347+
None,
1348+
);
1349+
1350+
Self::commit_transaction(
1351+
dest,
1352+
PyLance(transaction),
1353+
commit_lock,
1354+
storage_options,
1355+
enable_v2_manifest_paths,
1356+
detached,
1357+
max_retries,
1358+
)
1359+
}
1360+
1361+
#[allow(clippy::too_many_arguments)]
1362+
#[staticmethod]
1363+
#[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
1364+
fn commit_transaction(
1365+
dest: &Bound<PyAny>,
1366+
transaction: PyLance<Transaction>,
1367+
commit_lock: Option<&Bound<'_, PyAny>>,
1368+
storage_options: Option<HashMap<String, String>>,
1369+
enable_v2_manifest_paths: Option<bool>,
1370+
detached: Option<bool>,
1371+
max_retries: Option<u32>,
13151372
) -> PyResult<Self> {
13161373
let object_store_params =
13171374
storage_options
@@ -1333,13 +1390,6 @@ impl Dataset {
13331390
WriteDestination::Uri(dest.extract()?)
13341391
};
13351392

1336-
let transaction = Transaction::new(
1337-
read_version.unwrap_or_default(),
1338-
operation.0,
1339-
blobs_op.map(|op| op.0),
1340-
None,
1341-
);
1342-
13431393
let mut builder = CommitBuilder::new(dest)
13441394
.enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
13451395
.with_detached(detached.unwrap_or(false))
@@ -1354,7 +1404,10 @@ impl Dataset {
13541404
}
13551405

13561406
let ds = RT
1357-
.block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))?
1407+
.block_on(
1408+
commit_lock.map(|cl| cl.py()),
1409+
builder.execute(transaction.0),
1410+
)?
13581411
.map_err(|err| PyIOError::new_err(err.to_string()))?;
13591412

13601413
let uri = ds.uri().to_string();

python/src/transaction.rs

+29
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ impl FromPyObject<'_> for PyLance<Operation> {
4747
};
4848
Ok(Self(op))
4949
}
50+
"Update" => {
51+
let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?;
52+
53+
let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?;
54+
55+
let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?;
56+
57+
let op = Operation::Update {
58+
removed_fragment_ids,
59+
updated_fragments,
60+
new_fragments,
61+
};
62+
Ok(Self(op))
63+
}
5064
"Merge" => {
5165
let schema = extract_schema(&ob.getattr("schema")?)?;
5266

@@ -143,6 +157,21 @@ impl ToPyObject for PyLance<&Operation> {
143157
.expect("Failed to create Overwrite instance")
144158
.to_object(py)
145159
}
160+
Operation::Update {
161+
removed_fragment_ids,
162+
updated_fragments,
163+
new_fragments,
164+
} => {
165+
let removed_fragment_ids = removed_fragment_ids.to_object(py);
166+
let updated_fragments = export_vec(py, updated_fragments.as_slice());
167+
let new_fragments = export_vec(py, new_fragments.as_slice());
168+
let cls = namespace
169+
.getattr("Update")
170+
.expect("Failed to get Update class");
171+
cls.call1((removed_fragment_ids, updated_fragments, new_fragments))
172+
.unwrap()
173+
.to_object(py)
174+
}
146175
_ => todo!(),
147176
}
148177
}

rust/lance/src/dataset.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ pub use schema_evolution::{
8585
};
8686
pub use take::TakeBuilder;
8787
pub use write::merge_insert::{
88-
MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
88+
MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched,
89+
WhenNotMatchedBySource,
8990
};
9091
pub use write::update::{UpdateBuilder, UpdateJob};
9192
#[allow(deprecated)]

0 commit comments

Comments (0)