Commit 1173939
Committed Dec 20, 2024
feat(rust): execute_uncommitted for merge_insert

Squashed commit messages:
- expose in python
- refactor: make transaction marshalling easier
- cleanup
- fix tests
- fix path backward compatibility
- fix repr
- get changes back
Parent: 10e6454
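The upshot: a merge insert can now be split into two phases - compute the changes without committing, then commit the returned transaction explicitly. A minimal sketch of the new flow in Python, mirroring the test added in this commit (the in-memory URI is illustrative):

```python
import lance
import pyarrow as pa

table = pa.table({"id": range(10), "updated": [False] * 10})
dataset = lance.write_dataset(table, "memory://demo")

# Phase 1: run the merge insert but defer the commit. We get back the
# transaction describing the changes plus a dict of merge statistics.
updates = pa.Table.from_pylist([{"id": 1, "updated": True}])
transaction, stats = (
    dataset.merge_insert(on="id")
    .when_matched_update_all()
    .execute_uncommitted(updates)
)
print(stats)  # {'num_inserted_rows': 0, 'num_updated_rows': 1, 'num_deleted_rows': 0}

# Phase 2: commit the transaction explicitly, at a time of our choosing.
dataset = lance.LanceDataset.commit(dataset, transaction)
```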

File tree: 6 files changed, +243 -79 lines

- python/python/lance/dataset.py
- python/python/tests/test_dataset.py
- python/src/dataset.rs
- python/src/transaction.rs
- rust/lance/src/dataset.rs
- rust/lance/src/dataset/write/merge_insert.rs

python/python/lance/dataset.py (+80 -13)
@@ -26,6 +26,7 @@
     Optional,
     Sequence,
     Set,
+    Tuple,
     TypedDict,
     Union,
 )
@@ -102,6 +103,30 @@ def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
 
         return super(MergeInsertBuilder, self).execute(reader)
 
+    def execute_uncommitted(
+        self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None
+    ) -> Tuple[Transaction, Dict[str, Any]]:
+        """Executes the merge insert operation without committing.
+
+        This function returns the transaction describing the changes along with
+        a dictionary of merge statistics - i.e. the number of inserted, updated,
+        and deleted rows.
+
+        Parameters
+        ----------
+        data_obj: ReaderLike
+            The new data to use as the source table for the operation. This parameter
+            can be any source of data (e.g. table / dataset) that
+            :func:`~lance.write_dataset` accepts.
+        schema: Optional[pa.Schema]
+            The schema of the data. This only needs to be supplied whenever the data
+            source is some kind of generator.
+        """
+        reader = _coerce_reader(data_obj, schema)
+
+        return super(MergeInsertBuilder, self).execute_uncommitted(reader)
+
     # These next three overrides exist only to document the methods
 
     def when_matched_update_all(
@@ -2200,7 +2225,7 @@ def _commit(
     @staticmethod
     def commit(
         base_uri: Union[str, Path, LanceDataset],
-        operation: LanceOperation.BaseOperation,
+        operation: Union[LanceOperation.BaseOperation, Transaction],
         read_version: Optional[int] = None,
         commit_lock: Optional[CommitLock] = None,
         storage_options: Optional[Dict[str, str]] = None,
@@ -2305,24 +2330,44 @@ def commit(
                 f"commit_lock must be a function, got {type(commit_lock)}"
             )
 
-        if read_version is None and not isinstance(
-            operation, (LanceOperation.Overwrite, LanceOperation.Restore)
+        if (
+            isinstance(operation, LanceOperation.BaseOperation)
+            and read_version is None
+            and not isinstance(
+                operation, (LanceOperation.Overwrite, LanceOperation.Restore)
+            )
         ):
             raise ValueError(
                 "read_version is required for all operations except "
                 "Overwrite and Restore"
             )
 
-        new_ds = _Dataset.commit(
-            base_uri,
-            operation,
-            read_version,
-            commit_lock,
-            storage_options=storage_options,
-            enable_v2_manifest_paths=enable_v2_manifest_paths,
-            detached=detached,
-            max_retries=max_retries,
-        )
+        if isinstance(operation, Transaction):
+            new_ds = _Dataset.commit_transaction(
+                base_uri,
+                operation,
+                commit_lock,
+                storage_options=storage_options,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+                detached=detached,
+                max_retries=max_retries,
+            )
+        elif isinstance(operation, LanceOperation.BaseOperation):
+            new_ds = _Dataset.commit(
+                base_uri,
+                operation,
+                read_version,
+                commit_lock,
+                storage_options=storage_options,
+                enable_v2_manifest_paths=enable_v2_manifest_paths,
+                detached=detached,
+                max_retries=max_retries,
+            )
+        else:
+            raise TypeError(
+                "operation must be a LanceOperation.BaseOperation or Transaction, "
+                f"got {type(operation)}"
+            )
         ds = LanceDataset.__new__(LanceDataset)
         ds._storage_options = storage_options
         ds._ds = new_ds
@@ -2666,6 +2711,28 @@ class Delete(BaseOperation):
         def __post_init__(self):
             LanceOperation._validate_fragments(self.updated_fragments)
 
+    @dataclass
+    class Update(BaseOperation):
+        """
+        Operation that updates rows in the dataset.
+
+        Attributes
+        ----------
+        removed_fragment_ids: list[int]
+            The ids of the fragments that have been removed entirely.
+        updated_fragments: list[FragmentMetadata]
+            The fragments that have been updated with new deletion vectors.
+        new_fragments: list[FragmentMetadata]
+            The fragments that contain the new rows.
+        """
+
+        removed_fragment_ids: List[int]
+        updated_fragments: List[FragmentMetadata]
+        new_fragments: List[FragmentMetadata]
+
+        def __post_init__(self):
+            LanceOperation._validate_fragments(self.updated_fragments)
+            LanceOperation._validate_fragments(self.new_fragments)
+
     @dataclass
     class Merge(BaseOperation):
         """

python/python/tests/test_dataset.py (+25 -0)
@@ -1015,6 +1015,31 @@ def test_restore_with_commit(tmp_path: Path):
     assert tbl == table
 
 
+def test_merge_insert_with_commit():
+    table = pa.table({"id": range(10), "updated": [False] * 10})
+    dataset = lance.write_dataset(table, "memory://test")
+
+    updates = pa.Table.from_pylist([{"id": 1, "updated": True}])
+    transaction, stats = (
+        dataset.merge_insert(on="id")
+        .when_matched_update_all()
+        .execute_uncommitted(updates)
+    )
+
+    assert isinstance(stats, dict)
+    assert stats["num_updated_rows"] == 1
+    assert stats["num_inserted_rows"] == 0
+    assert stats["num_deleted_rows"] == 0
+
+    assert isinstance(transaction, lance.Transaction)
+    assert isinstance(transaction.operation, lance.LanceOperation.Update)
+
+    dataset = lance.LanceDataset.commit(dataset, transaction)
+    assert dataset.to_table().sort_by("id") == pa.table(
+        {"id": range(10), "updated": [False] + [True] + [False] * 8}
+    )
+
+
 def test_merge_with_commit(tmp_path: Path):
     table = pa.Table.from_pydict({"a": range(100), "b": range(100)})
     base_dir = tmp_path / "test"

python/src/dataset.rs (+66 -13)
@@ -40,7 +40,8 @@ use lance::dataset::{
     WriteParams,
 };
 use lance::dataset::{
-    BatchInfo, BatchUDF, CommitBuilder, NewColumnTransform, UDFCheckpointStore, WriteDestination,
+    BatchInfo, BatchUDF, CommitBuilder, MergeStats, NewColumnTransform, UDFCheckpointStore,
+    WriteDestination,
 };
 use lance::dataset::{ColumnAlteration, ProjectionRequest};
 use lance::index::{vector::VectorIndexParams, DatasetIndexInternalExt};
@@ -194,20 +195,46 @@ impl MergeInsertBuilder {
             .try_build()
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
 
-        let new_self = RT
+        let (new_dataset, stats) = RT
             .spawn(Some(py), job.execute_reader(new_data))?
             .map_err(|err| PyIOError::new_err(err.to_string()))?;
 
         let dataset = self.dataset.bind(py);
 
-        dataset.borrow_mut().ds = new_self.0;
-        let merge_stats = new_self.1;
-        let merge_dict = PyDict::new_bound(py);
-        merge_dict.set_item("num_inserted_rows", merge_stats.num_inserted_rows)?;
-        merge_dict.set_item("num_updated_rows", merge_stats.num_updated_rows)?;
-        merge_dict.set_item("num_deleted_rows", merge_stats.num_deleted_rows)?;
+        dataset.borrow_mut().ds = new_dataset;
+
+        Ok(Self::build_stats(&stats, py)?.into())
+    }
+
+    pub fn execute_uncommitted<'a>(
+        &mut self,
+        new_data: &Bound<'a, PyAny>,
+    ) -> PyResult<(PyLance<Transaction>, Bound<'a, PyDict>)> {
+        let py = new_data.py();
+        let new_data = convert_reader(new_data)?;
 
-        Ok(merge_dict.into())
+        let job = self
+            .builder
+            .try_build()
+            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+        let (transaction, stats) = RT
+            .spawn(Some(py), job.execute_uncommitted(new_data))?
+            .map_err(|err| PyIOError::new_err(err.to_string()))?;
+
+        let stats = Self::build_stats(&stats, py)?;
+
+        Ok((PyLance(transaction), stats))
+    }
+}
+
+impl MergeInsertBuilder {
+    fn build_stats<'a>(stats: &MergeStats, py: Python<'a>) -> PyResult<Bound<'a, PyDict>> {
+        let dict = PyDict::new_bound(py);
+        dict.set_item("num_inserted_rows", stats.num_inserted_rows)?;
+        dict.set_item("num_updated_rows", stats.num_updated_rows)?;
+        dict.set_item("num_deleted_rows", stats.num_deleted_rows)?;
+        Ok(dict)
     }
 }
 
@@ -1284,6 +1311,32 @@ impl Dataset {
         enable_v2_manifest_paths: Option<bool>,
         detached: Option<bool>,
         max_retries: Option<u32>,
+    ) -> PyResult<Self> {
+        let transaction =
+            Transaction::new(read_version.unwrap_or_default(), operation.0, None, None);
+
+        Self::commit_transaction(
+            dest,
+            PyLance(transaction),
+            commit_lock,
+            storage_options,
+            enable_v2_manifest_paths,
+            detached,
+            max_retries,
+        )
+    }
+
+    #[allow(clippy::too_many_arguments)]
+    #[staticmethod]
+    #[pyo3(signature = (dest, transaction, commit_lock = None, storage_options = None, enable_v2_manifest_paths = None, detached = None, max_retries = None))]
+    fn commit_transaction(
+        dest: &Bound<PyAny>,
+        transaction: PyLance<Transaction>,
+        commit_lock: Option<&Bound<'_, PyAny>>,
+        storage_options: Option<HashMap<String, String>>,
+        enable_v2_manifest_paths: Option<bool>,
+        detached: Option<bool>,
+        max_retries: Option<u32>,
     ) -> PyResult<Self> {
         let object_store_params =
             storage_options
@@ -1305,9 +1358,6 @@
             WriteDestination::Uri(dest.extract()?)
         };
 
-        let transaction =
-            Transaction::new(read_version.unwrap_or_default(), operation.0, None, None);
-
         let mut builder = CommitBuilder::new(dest)
             .enable_v2_manifest_paths(enable_v2_manifest_paths.unwrap_or(false))
             .with_detached(detached.unwrap_or(false))
@@ -1322,7 +1372,10 @@
         }
 
         let ds = RT
-            .block_on(commit_lock.map(|cl| cl.py()), builder.execute(transaction))?
+            .block_on(
+                commit_lock.map(|cl| cl.py()),
+                builder.execute(transaction.0),
+            )?
             .map_err(|err| PyIOError::new_err(err.to_string()))?;
 
         let uri = ds.uri().to_string();

python/src/transaction.rs (+29 -0)
@@ -47,6 +47,20 @@ impl FromPyObject<'_> for PyLance<Operation> {
                 };
                 Ok(Self(op))
             }
+            "Update" => {
+                let removed_fragment_ids = ob.getattr("removed_fragment_ids")?.extract()?;
+
+                let updated_fragments = extract_vec(&ob.getattr("updated_fragments")?)?;
+
+                let new_fragments = extract_vec(&ob.getattr("new_fragments")?)?;
+
+                let op = Operation::Update {
+                    removed_fragment_ids,
+                    updated_fragments,
+                    new_fragments,
+                };
+                Ok(Self(op))
+            }
             "Merge" => {
                 let schema = extract_schema(&ob.getattr("schema")?)?;
@@ -126,6 +140,21 @@ impl ToPyObject for PyLance<&Operation> {
                     .expect("Failed to get Append class");
                 cls.call1((fragments,)).unwrap().to_object(py)
             }
+            Operation::Update {
+                removed_fragment_ids,
+                updated_fragments,
+                new_fragments,
+            } => {
+                let removed_fragment_ids = removed_fragment_ids.to_object(py);
+                let updated_fragments = export_vec(py, updated_fragments.as_slice());
+                let new_fragments = export_vec(py, new_fragments.as_slice());
+                let cls = namespace
+                    .getattr("Update")
+                    .expect("Failed to get Update class");
+                cls.call1((removed_fragment_ids, updated_fragments, new_fragments))
+                    .unwrap()
+                    .to_object(py)
+            }
             _ => todo!(),
         }
     }
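Because `Operation::Update` now converts in both directions between the Rust enum and the Python `LanceOperation.Update` dataclass, the transaction returned by `execute_uncommitted` can be inspected from Python. A small sketch, reusing the `transaction` from the example near the top:

```python
op = transaction.operation
assert isinstance(op, lance.LanceOperation.Update)

# The three fields mirror the Rust enum variant.
print(op.removed_fragment_ids)  # ids of fragments dropped entirely, e.g. []
print(op.updated_fragments)     # fragments rewritten with new deletion vectors
print(op.new_fragments)         # fragments holding newly inserted rows
```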

rust/lance/src/dataset.rs (+2 -1)
@@ -83,7 +83,8 @@ pub use schema_evolution::{
 };
 pub use take::TakeBuilder;
 pub use write::merge_insert::{
-    MergeInsertBuilder, MergeInsertJob, WhenMatched, WhenNotMatched, WhenNotMatchedBySource,
+    MergeInsertBuilder, MergeInsertJob, MergeStats, WhenMatched, WhenNotMatched,
+    WhenNotMatchedBySource,
 };
 pub use write::update::{UpdateBuilder, UpdateJob};
 #[allow(deprecated)]

rust/lance/src/dataset/write/merge_insert.rs (+41 -52)
@@ -82,17 +82,13 @@ use crate::{
         write::open_writer,
     },
     index::DatasetIndexInternalExt,
-    io::{
-        commit::commit_transaction,
-        exec::{
-            project, scalar_index::MapIndexExec, utils::ReplayExec, AddRowAddrExec, Planner,
-            TakeExec,
-        },
+    io::exec::{
+        project, scalar_index::MapIndexExec, utils::ReplayExec, AddRowAddrExec, Planner, TakeExec,
     },
     Dataset,
 };
 
-use super::{write_fragments_internal, WriteParams};
+use super::{write_fragments_internal, CommitBuilder, WriteParams};
@@ -927,6 +923,27 @@ impl MergeInsertJob {
         self,
         source: SendableRecordBatchStream,
     ) -> Result<(Arc<Dataset>, MergeStats)> {
+        let ds = self.dataset.clone();
+        let (transaction, stats) = self.execute_uncommitted_impl(source).await?;
+        let dataset = CommitBuilder::new(ds).execute(transaction).await?;
+        Ok((Arc::new(dataset), stats))
+    }
+
+    /// Execute the merge insert job without committing the changes.
+    ///
+    /// Use [`CommitBuilder`] to commit the returned transaction.
+    pub async fn execute_uncommitted(
+        self,
+        source: impl StreamingWriteSource,
+    ) -> Result<(Transaction, MergeStats)> {
+        let stream = source.into_stream();
+        self.execute_uncommitted_impl(stream).await
+    }
+
+    async fn execute_uncommitted_impl(
+        self,
+        source: SendableRecordBatchStream,
+    ) -> Result<(Transaction, MergeStats)> {
         let schema = source.schema();
 
         let full_schema = Schema::from(self.dataset.local_schema());
@@ -942,7 +959,7 @@
             .try_flatten();
         let stream = RecordBatchStreamAdapter::new(merger_schema, stream);
 
-        let committed_ds = if !is_full_schema {
+        let operation = if !is_full_schema {
             if !matches!(
                 self.params.delete_not_matched_by_source,
                 WhenNotMatchedBySource::Keep
@@ -956,7 +973,11 @@
             let (updated_fragments, new_fragments) =
                 Self::update_fragments(self.dataset.clone(), Box::pin(stream)).await?;
 
-            Self::commit(self.dataset, Vec::new(), updated_fragments, new_fragments).await?
+            Operation::Update {
+                removed_fragment_ids: Vec::new(),
+                updated_fragments,
+                new_fragments,
+            }
         } else {
             let written = write_fragments_internal(
                 Some(&self.dataset),
@@ -978,21 +999,26 @@
             Self::apply_deletions(&self.dataset, &removed_row_ids).await?;
 
             // Commit updated and new fragments
-            Self::commit(
-                self.dataset,
+            Operation::Update {
                 removed_fragment_ids,
-                old_fragments,
+                updated_fragments: old_fragments,
                 new_fragments,
-            )
-            .await?
+            }
         };
 
         let stats = Arc::into_inner(merge_statistics)
             .unwrap()
            .into_inner()
            .unwrap();
 
-        Ok((committed_ds, stats))
+        let transaction = Transaction::new(
+            self.dataset.manifest.version,
+            operation,
+            /*blobs_op=*/ None,
+            None,
+        );
+
+        Ok((transaction, stats))
     }
 
     // Delete a batch of rows by id, returns the fragments modified and the fragments removed
@@ -1041,43 +1067,6 @@ impl MergeInsertJob {
 
         Ok((updated_fragments, removed_fragments))
     }
-
-    // Commit the operation
-    async fn commit(
-        dataset: Arc<Dataset>,
-        removed_fragment_ids: Vec<u64>,
-        updated_fragments: Vec<Fragment>,
-        new_fragments: Vec<Fragment>,
-    ) -> Result<Arc<Dataset>> {
-        let operation = Operation::Update {
-            removed_fragment_ids,
-            updated_fragments,
-            new_fragments,
-        };
-        let transaction = Transaction::new(
-            dataset.manifest.version,
-            operation,
-            /*blobs_op=*/ None,
-            None,
-        );
-
-        let (manifest, manifest_path) = commit_transaction(
-            dataset.as_ref(),
-            dataset.object_store(),
-            dataset.commit_handler.as_ref(),
-            &transaction,
-            &Default::default(),
-            &Default::default(),
-            dataset.manifest_naming_scheme,
-        )
-        .await?;
-
-        let mut dataset = dataset.as_ref().clone();
-        dataset.manifest = Arc::new(manifest);
-        dataset.manifest_file = manifest_path;
-
-        Ok(Arc::new(dataset))
-    }
 }
 
 /// Merger will store these statistics as it runs (for each batch)
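After this refactor, `execute` is a thin wrapper: it runs `execute_uncommitted_impl` and immediately commits the resulting transaction through `CommitBuilder`, so the committed and deferred paths share one implementation. From Python, the two spellings below should therefore be equivalent in effect (a sketch reusing the setup from the first example):

```python
# One-shot: merge and commit in a single call; returns the stats dict.
stats = (
    dataset.merge_insert(on="id")
    .when_matched_update_all()
    .execute(updates)
)

# Deferred: the same work, with the commit as a separate, explicit step.
transaction, stats = (
    dataset.merge_insert(on="id")
    .when_matched_update_all()
    .execute_uncommitted(updates)
)
dataset = lance.LanceDataset.commit(dataset, transaction)
```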
