Commit 2cd296f

feat: add "merge insert" operation based on the merge operation in other databases (#1647)

The "merge insert" operation can insert new rows, delete old rows, and update old rows, all in a single transaction. It is a generic operation used to provide upsert, find-or-create, and "replace range". Closes #1456.

Co-authored-by: Will Jones <willjones127@gmail.com>
1 parent b3db3cc commit 2cd296f
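
For orientation, here is a minimal sketch of the Python API this commit adds, assembled from the tests further down; the table contents and dataset path are illustrative, not part of the commit:

```python
import lance
import pyarrow as pa

# Illustrative data and path (not from the commit itself)
table = pa.Table.from_pydict({"id": [1, 2], "value": ["a", "b"]})
dataset = lance.write_dataset(table, "/tmp/demo_dataset")

new_data = pa.Table.from_pydict({"id": [2, 3], "value": ["b2", "c"]})

# Upsert: update rows whose key matches, insert rows that don't
(
    dataset.merge_insert("id")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .execute(new_data)
)

# Find-or-create: only insert keys that are missing, leave matches untouched
dataset.merge_insert("id").when_not_matched_insert_all().execute(new_data)
```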

20 files changed: +1436 -113 lines changed

python/python/lance/__init__.py (+2)

@@ -20,6 +20,7 @@
     LanceDataset,
     LanceOperation,
     LanceScanner,
+    MergeInsertBuilder,
     __version__,
     write_dataset,
 )
@@ -41,6 +42,7 @@
     "LanceDataset",
     "LanceOperation",
     "LanceScanner",
+    "MergeInsertBuilder",
     "__version__",
     "write_dataset",
     "schema_to_json",

python/python/lance/dataset.py (+20 -1)

@@ -50,7 +50,14 @@
 from .dependencies import numpy as np
 from .dependencies import pandas as pd
 from .fragment import FragmentMetadata, LanceFragment
-from .lance import CleanupStats, _Dataset, _Operation, _Scanner, _write_dataset
+from .lance import (
+    CleanupStats,
+    _Dataset,
+    _MergeInsertBuilder,
+    _Operation,
+    _Scanner,
+    _write_dataset,
+)
 from .lance import CompactionMetrics as CompactionMetrics
 from .lance import __version__ as __version__
 from .optimize import Compaction
@@ -80,6 +87,12 @@
 ]


+class MergeInsertBuilder(_MergeInsertBuilder):
+    def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None):
+        reader = _coerce_reader(data_obj, schema)
+        super(MergeInsertBuilder, self).execute(reader)
+
+
 class LanceDataset(pa.dataset.Dataset):
     """A dataset in Lance format where the data is stored at the given uri."""

@@ -630,6 +643,12 @@ def delete(self, predicate: Union[str, pa.compute.Expression]):
             predicate = str(predicate)
         self._ds.delete(predicate)

+    def merge_insert(
+        self,
+        on: Union[str, Iterable[str]],
+    ):
+        return MergeInsertBuilder(self._ds, on)
+
     def update(
         self,
         updates: Dict[str, str],
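
As a usage note on the wrapper above: `execute` funnels its argument through `_coerce_reader`, so tables, datasets, and batch readers are all accepted as the source. A hedged sketch (the dataset paths are hypothetical and both datasets are assumed to exist):

```python
import lance

dst = lance.dataset("/tmp/dst_dataset")
src = lance.dataset("/tmp/src_dataset")

# A bare stream of record batches carries no schema of its own,
# so it must be supplied via the keyword-only `schema` argument
(
    dst.merge_insert("a")
    .when_matched_update_all()
    .execute(src.to_batches(), schema=src.schema)
)
```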

python/python/tests/test_dataset.py (+187)

@@ -822,6 +822,193 @@ def test_delete_data(tmp_path: Path):
     assert dataset.count_rows() == 0


+def test_merge_insert(tmp_path: Path):
+    nrows = 1000
+    table = pa.Table.from_pydict({"a": range(nrows), "b": [1 for _ in range(nrows)]})
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+    version = dataset.version
+
+    new_table = pa.Table.from_pydict(
+        {"a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)]}
+    )
+
+    is_new = pc.field("b") == 2
+
+    dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 300
+
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert("a").when_matched_update_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1000
+    assert table.filter(is_new).num_rows == 700
+
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert(
+        "a"
+    ).when_not_matched_insert_all().when_matched_update_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 1000
+
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 700
+    assert table.filter(is_new).num_rows == 0
+
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert("a").when_not_matched_by_source_delete(
+        "a < 100"
+    ).when_not_matched_insert_all().execute(new_table)
+
+    table = dataset.to_table()
+    assert table.num_rows == 1200
+    assert table.filter(is_new).num_rows == 300
+
+    # If the user doesn't specify anything then the merge_insert is
+    # a no-op and the operation fails
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    with pytest.raises(ValueError):
+        dataset.merge_insert("a").execute(new_table)
+
+
+def test_merge_insert_source_is_dataset(tmp_path: Path):
+    nrows = 1000
+    table = pa.Table.from_pydict({"a": range(nrows), "b": [1 for _ in range(nrows)]})
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+    version = dataset.version
+
+    new_table = pa.Table.from_pydict(
+        {"a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)]}
+    )
+    new_dataset = lance.write_dataset(
+        new_table, tmp_path / "dataset2", mode="create", max_rows_per_file=80
+    )
+
+    is_new = pc.field("b") == 2
+
+    dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset)
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 300
+
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+
+    reader = new_dataset.to_batches()
+
+    dataset.merge_insert("a").when_not_matched_insert_all().execute(
+        reader, schema=new_dataset.schema
+    )
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 300
+
+
+def test_merge_insert_multiple_keys(tmp_path: Path):
+    nrows = 1000
+    # a - [0, 1, 2, ..., 999]
+    # b - [1, 1, 1, ..., 1]
+    # c - [0, 1, 0, ..., 1]
+    table = pa.Table.from_pydict(
+        {
+            "a": range(nrows),
+            "b": [1 for _ in range(nrows)],
+            "c": [i % 2 for i in range(nrows)],
+        }
+    )
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+
+    # a - [300, 301, 302, ..., 1299]
+    # b - [2, 2, 2, ..., 2]
+    # c - [0, 0, 0, ..., 0]
+    new_table = pa.Table.from_pydict(
+        {
+            "a": range(300, 300 + nrows),
+            "b": [2 for _ in range(nrows)],
+            "c": [0 for _ in range(nrows)],
+        }
+    )
+
+    is_new = pc.field("b") == 2
+
+    dataset.merge_insert(["a", "c"]).when_matched_update_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1000
+    assert table.filter(is_new).num_rows == 350
+
+
+def test_merge_insert_incompatible_schema(tmp_path: Path):
+    nrows = 1000
+    table = pa.Table.from_pydict(
+        {
+            "a": range(nrows),
+            "b": [1 for _ in range(nrows)],
+        }
+    )
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+
+    new_table = pa.Table.from_pydict(
+        {
+            "a": range(300, 300 + nrows),
+        }
+    )
+
+    with pytest.raises(OSError):
+        dataset.merge_insert("a").when_matched_update_all().execute(new_table)
+
+
+def test_merge_insert_vector_column(tmp_path: Path):
+    table = pa.Table.from_pydict(
+        {
+            "vec": pa.array([[1, 2, 3], [4, 5, 6]], pa.list_(pa.float32(), 3)),
+            "key": [1, 2],
+        }
+    )
+
+    new_table = pa.Table.from_pydict(
+        {
+            "vec": pa.array([[7, 8, 9], [10, 11, 12]], pa.list_(pa.float32(), 3)),
+            "key": [2, 3],
+        }
+    )
+
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+
+    dataset.merge_insert(
+        ["key"]
+    ).when_not_matched_insert_all().when_matched_update_all().execute(new_table)
+
+    expected = pa.Table.from_pydict(
+        {
+            "vec": pa.array(
+                [[1, 2, 3], [7, 8, 9], [10, 11, 12]], pa.list_(pa.float32(), 3)
+            ),
+            "key": [1, 2, 3],
+        }
+    )
+
+    assert dataset.to_table().sort_by("key") == expected
+
+
 def test_update_dataset(tmp_path: Path):
     nrows = 100
     vecs = pa.FixedSizeListArray.from_arrays(
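
The tests above exercise upsert and find-or-create directly; the "replace range" use case named in the commit message can be expressed by combining the clauses. A sketch, assuming the new data spans keys `a >= 300` as in the tests:

```python
# Within the key range covered by new_table (a >= 300 here), make the
# target match new_table exactly; rows outside that range are untouched
(
    dataset.merge_insert("a")
    .when_matched_update_all()
    .when_not_matched_insert_all()
    .when_not_matched_by_source_delete("a >= 300")
    .execute(new_table)
)
```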

python/python/tests/test_fragment.py (+2 -1)

@@ -136,7 +136,8 @@ def test_dataset_progress(tmp_path: Path):

     assert fragment == FragmentMetadata.from_json(json.dumps(metadata))

-    p = multiprocessing.Process(target=failing_write, args=(progress_uri, dataset_uri))
+    ctx = multiprocessing.get_context("spawn")
+    p = ctx.Process(target=failing_write, args=(progress_uri, dataset_uri))
     p.start()
     try:
         p.join()

python/src/dataset.rs (+100 -4)

@@ -25,11 +25,12 @@ use chrono::Duration;

 use futures::{StreamExt, TryFutureExt};
 use lance::dataset::builder::DatasetBuilder;
-use lance::dataset::UpdateBuilder;
 use lance::dataset::{
     fragment::FileFragment as LanceFileFragment, progress::WriteFragmentProgress,
     scanner::Scanner as LanceScanner, transaction::Operation as LanceOperation,
-    Dataset as LanceDataset, ReadParams, Version, WriteMode, WriteParams,
+    Dataset as LanceDataset, MergeInsertBuilder as LanceMergeInsertBuilder, ReadParams,
+    UpdateBuilder, Version, WhenMatched, WhenNotMatched, WhenNotMatchedBySource, WriteMode,
+    WriteParams,
 };
 use lance::index::{
     scalar::ScalarIndexParams,
@@ -47,9 +48,9 @@ use lance_linalg::distance::MetricType;
 use lance_table::format::Fragment;
 use lance_table::io::commit::CommitHandler;
 use object_store::path::Path;
-use pyo3::exceptions::PyStopIteration;
+use pyo3::exceptions::{PyStopIteration, PyTypeError};
 use pyo3::prelude::*;
-use pyo3::types::{PyList, PySet};
+use pyo3::types::{PyList, PySet, PyString};
 use pyo3::{
     exceptions::{PyIOError, PyKeyError, PyValueError},
     pyclass,
@@ -93,6 +94,101 @@ fn convert_schema(arrow_schema: &ArrowSchema) -> PyResult<Schema> {
     })
 }

+#[pyclass(name = "_MergeInsertBuilder", module = "_lib", subclass)]
+pub struct MergeInsertBuilder {
+    builder: LanceMergeInsertBuilder,
+    dataset: Py<Dataset>,
+}
+
+#[pymethods]
+impl MergeInsertBuilder {
+    #[new]
+    pub fn new(dataset: &PyAny, on: &PyAny) -> PyResult<Self> {
+        let dataset: Py<Dataset> = dataset.extract()?;
+        let ds = dataset.borrow(on.py()).ds.clone();
+        // Either a single string, which we put in a vector, or an iterator
+        // of strings, which we collect into a vector
+        let on = PyAny::downcast::<PyString>(on)
+            .map(|val| vec![val.to_string()])
+            .or_else(|_| {
+                let iterator = on.iter().map_err(|_| {
+                    PyTypeError::new_err(
+                        "The `on` argument to merge_insert must be a str or iterable of str",
+                    )
+                })?;
+                let mut keys = Vec::new();
+                for key in iterator {
+                    keys.push(PyAny::downcast::<PyString>(key?)?.to_string());
+                }
+                PyResult::Ok(keys)
+            })?;
+
+        let mut builder = LanceMergeInsertBuilder::try_new(ds, on)
+            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+        // We don't have do_nothing methods in python so we start with a blank slate
+        builder
+            .when_matched(WhenMatched::DoNothing)
+            .when_not_matched(WhenNotMatched::DoNothing);
+
+        Ok(Self { builder, dataset })
+    }
+
+    pub fn when_matched_update_all(mut slf: PyRefMut<Self>) -> PyResult<PyRefMut<Self>> {
+        slf.builder.when_matched(WhenMatched::UpdateAll);
+        Ok(slf)
+    }
+
+    pub fn when_not_matched_insert_all(mut slf: PyRefMut<Self>) -> PyResult<PyRefMut<Self>> {
+        slf.builder.when_not_matched(WhenNotMatched::InsertAll);
+        Ok(slf)
+    }
+
+    pub fn when_not_matched_by_source_delete<'a>(
+        mut slf: PyRefMut<'a, Self>,
+        expr: Option<&str>,
+    ) -> PyResult<PyRefMut<'a, Self>> {
+        let new_val = if let Some(expr) = expr {
+            let dataset = slf.dataset.borrow(slf.py());
+            WhenNotMatchedBySource::delete_if(&dataset.ds, expr)
+                .map_err(|err| PyValueError::new_err(err.to_string()))?
+        } else {
+            WhenNotMatchedBySource::Delete
+        };
+        slf.builder.when_not_matched_by_source(new_val);
+        Ok(slf)
+    }
+
+    pub fn execute(&mut self, new_data: &PyAny) -> PyResult<()> {
+        let py = new_data.py();
+
+        let new_data: Box<dyn RecordBatchReader + Send> = if new_data.is_instance_of::<Scanner>() {
+            let scanner: Scanner = new_data.extract()?;
+            Box::new(
+                RT.spawn(Some(py), async move { scanner.to_reader().await })?
+                    .map_err(|err| PyValueError::new_err(err.to_string()))?,
+            )
+        } else {
+            Box::new(ArrowArrayStreamReader::from_pyarrow(new_data)?)
+        };
+
+        let job = self
+            .builder
+            .try_build()
+            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+        let new_self = RT
+            .spawn(Some(py), job.execute_reader(new_data))?
+            .map_err(|err| PyIOError::new_err(err.to_string()))?;
+
+        let dataset = self.dataset.as_ref(py);
+
+        dataset.borrow_mut().ds = new_self;
+
+        Ok(())
+    }
+}
+
 #[pymethods]
 impl Operation {
     fn __repr__(&self) -> String {
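
From the Python side, the constructor's `on` handling above means both of these spellings work, while anything that is neither a str nor an iterable of str raises TypeError. A sketch, where `dataset` is any open LanceDataset:

```python
dataset.merge_insert("a")         # single key: the PyString downcast succeeds
dataset.merge_insert(["a", "c"])  # composite key: iterated and collected

# Neither a str nor an iterable of str:
# dataset.merge_insert(42)
# -> TypeError: The `on` argument to merge_insert must be a str or iterable of str
```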
