Skip to content

Commit fd4508d

Browse files
committedDec 20, 2024
feat(rust): execute_uncommitted for merge_insert
expose in python refactor: make transaction marshalling easier cleanup fix tests fix path backward compatibility fix repr get changes back
1 parent 10e6454 commit fd4508d

File tree

6 files changed

+250
-79
lines changed

6 files changed

+250
-79
lines changed
 

‎python/python/lance/dataset.py

+80-13
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Optional,
2727
Sequence,
2828
Set,
29+
Tuple,
2930
TypedDict,
3031
Union,
3132
)
@@ -102,6 +103,30 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
102103

103104
return super(MergeInsertBuilder, self).execute(reader)
104105

106+
def execute_uncommitted(
107+
self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None
108+
) -> Tuple[Transaction, Dict[str, Any]]:
109+
"""Executes the merge insert operation without committing
110+
111+
This function updates the original dataset and returns a dictionary with
112+
information about merge statistics - i.e. the number of inserted, updated,
113+
and deleted rows.
114+
115+
Parameters
116+
----------
117+
118+
data_obj: ReaderLike
119+
The new data to use as the source table for the operation. This parameter
120+
can be any source of data (e.g. table / dataset) that
121+
:func:`~lance.write_dataset` accepts.
122+
schema: Optional[pa.Schema]
123+
The schema of the data. This only needs to be supplied whenever the data
124+
source is some kind of generator.
125+
"""
126+
reader = _coerce_reader(data_obj, schema)
127+
128+
return super(MergeInsertBuilder, self).execute_uncommitted(reader)
129+
105130
# These next three overrides exist only to document the methods
106131

107132
def when_matched_update_all(
@@ -2200,7 +2225,7 @@ def _commit(
22002225
@staticmethod
22012226
def commit(
22022227
base_uri: Union[str, Path, LanceDataset],
2203-
operation: LanceOperation.BaseOperation,
2228+
operation: Union[LanceOperation.BaseOperation, Transaction],
22042229
read_version: Optional[int] = None,
22052230
commit_lock: Optional[CommitLock] = None,
22062231
storage_options: Optional[Dict[str, str]] = None,
@@ -2305,24 +2330,44 @@ def commit(
23052330
f"commit_lock must be a function, got {type(commit_lock)}"
23062331
)
23072332

2308-
if read_version is None and not isinstance(
2309-
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
2333+
if (
2334+
isinstance(operation, LanceOperation.BaseOperation)
2335+
and read_version is None
2336+
and not isinstance(
2337+
operation, (LanceOperation.Overwrite, LanceOperation.Restore)
2338+
)
23102339
):
23112340
raise ValueError(
23122341
"read_version is required for all operations except "
23132342
"Overwrite and Restore"
23142343
)
23152344

2316-
new_ds = _Dataset.commit(
2317-
base_uri,
2318-
operation,
2319-
read_version,
2320-
commit_lock,
2321-
storage_options=storage_options,
2322-
enable_v2_manifest_paths=enable_v2_manifest_paths,
2323-
detached=detached,
2324-
max_retries=max_retries,
2325-
)
2345+
if isinstance(operation, Transaction):
2346+
new_ds = _Dataset.commit_transaction(
2347+
base_uri,
2348+
operation,
2349+
commit_lock,
2350+
storage_options=storage_options,
2351+
enable_v2_manifest_paths=enable_v2_manifest_paths,
2352+
detached=detached,
2353+
max_retries=max_retries,
2354+
)
2355+
elif isinstance(operation, LanceOperation.BaseOperation):
2356+
new_ds = _Dataset.commit(
2357+
base_uri,
2358+
operation,
2359+
read_version,
2360+
commit_lock,
2361+
storage_options=storage_options,
2362+
enable_v2_manifest_paths=enable_v2_manifest_paths,
2363+
detached=detached,
2364+
max_retries=max_retries,
2365+
)
2366+
else:
2367+
raise TypeError(
2368+
"operation must be a LanceOperation.BaseOperation or Transaction, "
2369+
f"got {type(operation)}"
2370+
)
23262371
ds = LanceDataset.__new__(LanceDataset)
23272372
ds._storage_options = storage_options
23282373
ds._ds = new_ds
@@ -2666,6 +2711,28 @@ class Delete(BaseOperation):
26662711
def __post_init__(self):
26672712
LanceOperation._validate_fragments(self.updated_fragments)
26682713

2714+
@dataclass
2715+
class Update(BaseOperation):
2716+
"""
2717+
Operation that updates rows in the dataset.
2718+
Attributes
2719+
----------
2720+
removed_fragment_ids: list[int]
2721+
The ids of the fragments that have been removed entirely.
2722+
updated_fragments: list[FragmentMetadata]
2723+
The fragments that have been updated with new deletion vectors.
2724+
new_fragments: list[FragmentMetadata]
2725+
The fragments that contain the new rows.
2726+
"""
2727+
2728+
removed_fragment_ids: List[int]
2729+
updated_fragments: List[FragmentMetadata]
2730+
new_fragments: List[FragmentMetadata]
2731+
2732+
def __post_init__(self):
2733+
LanceOperation._validate_fragments(self.updated_fragments)
2734+
LanceOperation._validate_fragments(self.new_fragments)
2735+
26692736
@dataclass
26702737
class Merge(BaseOperation):
26712738
"""

‎python/python/tests/test_dataset.py

+25
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,31 @@ def test_restore_with_commit(tmp_path: Path):
10151015
assert tbl == table
10161016

10171017

1018+
def test_merge_insert_with_commit():
1019+
table = pa.table({"id": range(10), "updated": [False] * 10})
1020+
dataset = lance.write_dataset(table, "memory://test")
1021+
1022+
updates = pa.Table.from_pylist([{"id": 1, "updated": True}])
1023+
transaction, stats = (
1024+
dataset.merge_insert(on="id")
1025+
.when_matched_update_all()
1026+
.execute_uncommitted(updates)
1027+
)
1028+
1029+
assert isinstance(stats, dict)
1030+
assert stats["num_updated_rows"] == 1
1031+
assert stats["num_inserted_rows"] == 0
1032+
assert stats["num_deleted_rows"] == 0
1033+
1034+
assert isinstance(transaction, lance.Transaction)
1035+
assert isinstance(transaction.operation, lance.LanceOperation.Update)
1036+
1037+
dataset = lance.LanceDataset.commit(dataset, transaction)
1038+
assert dataset.to_table().sort_by("id") == pa.table(
1039+
{"id": range(10), "updated": [False] + [True] + [False] * 8}
1040+
)
1041+
1042+
10181043
def test_merge_with_commit(tmp_path: Path):
10191044
table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
10201045
base_dir = tmp_path / "test"

‎python/src/dataset.rs

+66-13
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ use lance::dataset::{
4040
WriteParams,
4141
};
4242
use lance::dataset::{
43-
BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
43+
BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore,
44+
WriteDestination,
4445
};
4546
use lance::dataset::{ColumnAlteration, ProjectionRequest};
4647
use lance::index::{vector::VectorIndexParams, DatasetIndexInternalExt};
@@ -194,20 +195,46 @@ impl MergeInsertBuilder {
194195
.try_build()
195196
.map_err(|err| PyValueError::new_err(err.to_string()))?;
196197

197-
let new_self = RT
198+
let (new_dataset, stats) = RT
198199
.spawn(Some(py), job.execute_reader(new_data))?
199200
.map_err(|err| PyIOError::new_err(err.to_string()))?;
200201

201202
let dataset = self.dataset.bind(py);
202203

203-
dataset.borrow_mut().ds = new_self.0;
204-
let merge_stats = new_self.1;
205-
let merge_dict = PyDict::new_bound(py);
206-
merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?;
207-
merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?;
208-
merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?;
204+
dataset.borrow_mut().ds = new_dataset;
205+
206+
Ok(Self::build_stats(&stats, py)?.into())
207+
}
208+
209+
pub fn execute_uncommitted<'a>(
210+
&mut self,
211+
new_data: &Bound<'a, PyAny>,
212+
) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
213+
let py = new_data.py();
214+
let new_data = convert_reader(new_data)?;
209215

210-
Ok(merge_dict.into())
216+
let job = self
217+
.builder
218+
.try_build()
219+
.map_err(|err| PyValueError::new_err(err.to_string()))?;
220+
221+
let (transaction, stats) = RT
222+
.spawn(Some(py), job.execute_uncommitted(new_data))?
223+
.map_err(|err| PyIOError::new_err(err.to_string()))?;
224+
225+
let stats = Self::build_stats(&stats, py)?;
226+
227+
Ok((PyLance(transaction), stats))
228+
}
229+
}
230+
231+
impl MergeInsertBuilder {
232+
fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult<Bound<'a, PyDict>> {
233+
let dict = PyDict::new_bound(py);
234+
dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
235+
dict.set_item("num_updated_rows", stats.num_updated_rows)?;
236+
dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
237+
Ok(dict)
211238
}
212239
}
213240

@@ -1284,6 +1311,32 @@ impl Dataset {
12841311
enable_v2_manifest_paths: Option<bool>,
12851312
detached: Option<bool>,
12861313
max_retries: Option<u32>,
1314+
) -> PyResult<Self> {
1315+
let transaction =
1316+
Transaction::new(read_version.unwrap_or_default(), operation.0, None, None);
1317+
1318+
Self::commit_transaction(
1319+
dest,
1320+
PyLance(transaction),
1321+
commit_lock,
1322+
storage_options,
1323+
enable_v2_manifest_paths,
1324+
detached,
1325+
max_retries,
1326+
)
1327+
}
1328+
1329+
#[allow(clippy::too_many_arguments)]
1330+
#[staticmethod]
1331+
#[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
1332+
fn commit_transaction(
1333+
dest: &Bound<PyAny>,
1334+
transaction: PyLance<Transaction>,
1335+
commit_lock: Option<&Bound<'_, PyAny>>,
1336+
storage_options: Option<HashMap<String, String>>,
1337+
enable_v2_manifest_paths: Option<bool>,
1338+
detached: Option<bool>,
1339+
max_retries: Option<u32>,
12871340
) -> PyResult<Self> {
12881341
let object_store_params =
12891342
storage_options
@@ -1305,9 +1358,6 @@ impl Dataset {
13051358
WriteDestination::Uri(dest.extract()?)
13061359
};
13071360

1308-
let transaction =
1309-
Transaction::new(read_version.unwrap_or_default(), operation.0, None, None);
1310-
13111361
let mut builder = CommitBuilder::new(dest)
13121362
.enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
13131363
.with_detached(detached.unwrap_or(false))
@@ -1322,7 +1372,10 @@ impl Dataset {
13221372
}
13231373

13241374
let ds = RT
1325-
.block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))?
1375+
.block_on(
1376+
commit_lock.map(|cl| cl.py()),
1377+
builder.execute(transaction.0),
1378+
)?
13261379
.map_err(|err| PyIOError::new_err(err.to_string()))?;
13271380

13281381
let uri = ds.uri().to_string();

‎python/src/transaction.rs

+29
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@ impl FromPyObject<'_> for PyLance<Operation> {
4747
};
4848
Ok(Self(op))
4949
}
50+
"Update" => {
51+
let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?;
52+
53+
let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?;
54+
55+
let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?;
56+
57+
let op = Operation::Update {
58+
removed_fragment_ids,
59+
updated_fragments,
60+
new_fragments,
61+
};
62+
Ok(Self(op))
63+
}
5064
"Merge" => {
5165
let schema = extract_schema(&ob.getattr("schema")?)?;
5266

@@ -126,6 +140,21 @@ impl ToPyObject for PyLance<&Operation> {
126140
.expect("Failed to get Append class");
127141
cls.call1((fragments,)).unwrap().to_object(py)
128142
}
143+
Operation::Update {
144+
removed_fragment_ids,
145+
updated_fragments,
146+
new_fragments,
147+
} => {
148+
let removed_fragment_ids = removed_fragment_ids.to_object(py);
149+
let updated_fragments = export_vec(py, updated_fragments.as_slice());
150+
let new_fragments = export_vec(py, new_fragments.as_slice());
151+
let cls = namespace
152+
.getattr("Update")
153+
.expect("Failed to get Update class");
154+
cls.call1((removed_fragment_ids, updated_fragments, new_fragments))
155+
.unwrap()
156+
.to_object(py)
157+
}
129158
_ => todo!(),
130159
}
131160
}

‎rust/lance/src/dataset.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ pub use schema_evolution::{
8383
};
8484
pub use take::TakeBuilder;
8585
pub use write::merge_insert::{
86-
MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
86+
MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched,
87+
WhenNotMatchedBySource,
8788
};
8889
pub use write::update::{UpdateBuilder, UpdateJob};
8990
#[allow(deprecated)]

0 commit comments

Comments
 (0)