@@ -822,6 +822,193 @@ def test_delete_data(tmp_path: Path):
     assert dataset.count_rows() == 0
 
 
+def test_merge_insert(tmp_path: Path):
+    nrows = 1000
+    table = pa.Table.from_pydict({"a": range(nrows), "b": [1 for _ in range(nrows)]})
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+    version = dataset.version
+
+    new_table = pa.Table.from_pydict(
+        {"a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)]}
+    )
+
+    is_new = pc.field("b") == 2
+
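+    # Target keys are [0, 1000) and source keys are [300, 1300): 700 keys
+    # overlap and the 300 source keys in [1000, 1300) are new.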
+    dataset.merge_insert("a").when_not_matched_insert_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 300
+
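+    # An update-only merge modifies the 700 matched rows in place; the row
+    # count is unchanged.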
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert("a").when_matched_update_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1000
+    assert table.filter(is_new).num_rows == 700
+
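+    # Combining both clauses gives upsert semantics: 700 updates plus 300
+    # inserts, so every source row lands in the target.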
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert(
+        "a"
+    ).when_not_matched_insert_all().when_matched_update_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 1000
+
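+    # With no condition, when_not_matched_by_source_delete removes every
+    # target row absent from the source, i.e. keys [0, 300).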
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert("a").when_not_matched_by_source_delete().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 700
+    assert table.filter(is_new).num_rows == 0
+
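+    # A delete condition narrows the purge to keys [0, 100); the insert
+    # clause still adds the 300 new source rows: 1000 - 100 + 300 = 1200.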
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    dataset.merge_insert("a").when_not_matched_by_source_delete(
+        "a < 100"
+    ).when_not_matched_insert_all().execute(new_table)
+
+    table = dataset.to_table()
+    assert table.num_rows == 1200
+    assert table.filter(is_new).num_rows == 300
+
+    # If the user doesn't specify anything then the merge_insert is
+    # a no-op and the operation fails
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+    with pytest.raises(ValueError):
+        dataset.merge_insert("a").execute(new_table)
+
+
+def test_merge_insert_source_is_dataset(tmp_path: Path):
+    nrows = 1000
+    table = pa.Table.from_pydict({"a": range(nrows), "b": [1 for _ in range(nrows)]})
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+    version = dataset.version
+
+    new_table = pa.Table.from_pydict(
+        {"a": range(300, 300 + nrows), "b": [2 for _ in range(nrows)]}
+    )
+    new_dataset = lance.write_dataset(
+        new_table, tmp_path / "dataset2", mode="create", max_rows_per_file=80
+    )
+
+    is_new = pc.field("b") == 2
+
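+    # The merge source can be another Dataset rather than an in-memory table.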
+    dataset.merge_insert("a").when_not_matched_insert_all().execute(new_dataset)
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 300
+
+    dataset = lance.dataset(tmp_path / "dataset", version=version)
+    dataset.restore()
+
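+    # A stream of record batches also works as the source; the schema is
+    # passed explicitly alongside the reader.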
+    reader = new_dataset.to_batches()
+
+    dataset.merge_insert("a").when_not_matched_insert_all().execute(
+        reader, schema=new_dataset.schema
+    )
+    table = dataset.to_table()
+    assert table.num_rows == 1300
+    assert table.filter(is_new).num_rows == 300
+
+
+def test_merge_insert_multiple_keys(tmp_path: Path):
+    nrows = 1000
+    # a - [0, 1, 2, ..., 999]
+    # b - [1, 1, 1, ..., 1]
+    # c - [0, 1, 0, ..., 1]
+    table = pa.Table.from_pydict(
+        {
+            "a": range(nrows),
+            "b": [1 for _ in range(nrows)],
+            "c": [i % 2 for i in range(nrows)],
+        }
+    )
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+
+    # a - [300, 301, 302, ..., 1299]
+    # b - [2, 2, 2, ..., 2]
+    # c - [0, 0, 0, ..., 0]
+    new_table = pa.Table.from_pydict(
+        {
+            "a": range(300, 300 + nrows),
+            "b": [2 for _ in range(nrows)],
+            "c": [0 for _ in range(nrows)],
+        }
+    )
+
+    is_new = pc.field("b") == 2
+
+    dataset.merge_insert(["a", "c"]).when_matched_update_all().execute(new_table)
+    table = dataset.to_table()
+    assert table.num_rows == 1000
+    assert table.filter(is_new).num_rows == 350
+
+
+def test_merge_insert_incompatible_schema(tmp_path: Path):
+    nrows = 1000
+    table = pa.Table.from_pydict(
+        {
+            "a": range(nrows),
+            "b": [1 for _ in range(nrows)],
+        }
+    )
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+
+    new_table = pa.Table.from_pydict(
+        {
+            "a": range(300, 300 + nrows),
+        }
+    )
+
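+    # The source is missing column "b", so the merge is rejected.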
+    with pytest.raises(OSError):
+        dataset.merge_insert("a").when_matched_update_all().execute(new_table)
+
+
+def test_merge_insert_vector_column(tmp_path: Path):
+    table = pa.Table.from_pydict(
+        {
+            "vec": pa.array([[1, 2, 3], [4, 5, 6]], pa.list_(pa.float32(), 3)),
+            "key": [1, 2],
+        }
+    )
+
+    new_table = pa.Table.from_pydict(
+        {
+            "vec": pa.array([[7, 8, 9], [10, 11, 12]], pa.list_(pa.float32(), 3)),
+            "key": [2, 3],
+        }
+    )
+
+    dataset = lance.write_dataset(
+        table, tmp_path / "dataset", mode="create", max_rows_per_file=100
+    )
+
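+    # Upsert on "key": key 2 gets its vector updated, key 3 is inserted,
+    # and key 1 is left untouched.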
+    dataset.merge_insert(
+        ["key"]
+    ).when_not_matched_insert_all().when_matched_update_all().execute(new_table)
+
+    expected = pa.Table.from_pydict(
+        {
+            "vec": pa.array(
+                [[1, 2, 3], [7, 8, 9], [10, 11, 12]], pa.list_(pa.float32(), 3)
+            ),
+            "key": [1, 2, 3],
+        }
+    )
+
+    assert dataset.to_table().sort_by("key") == expected
+
+
 def test_update_dataset(tmp_path: Path):
     nrows = 100
     vecs = pa.FixedSizeListArray.from_arrays(