@@ -10,9 +10,9 @@ use std::sync::Arc;
 use crate::scalar::lance_format::LanceIndexStore;
 use crate::scalar::{IndexReader, IndexStore, IndexWriter, InvertedIndexParams};
 use crate::vector::graph::OrderedFloat;
-use arrow::array::AsArray;
+use arrow::array::{ArrayBuilder, AsArray, Int32Builder, StringBuilder};
 use arrow::datatypes;
-use arrow_array::RecordBatch;
+use arrow_array::{Int32Array, RecordBatch, StringArray};
 use arrow_schema::SchemaRef;
 use crossbeam_queue::ArrayQueue;
 use datafusion::execution::SendableRecordBatchStream;
@@ -131,8 +131,8 @@ impl InvertedIndexBuilder {
             senders.push(sender);
             result_futs.push(tokio::spawn({
                 async move {
-                    while let Some((row_id, tokens)) = receiver.recv().await {
-                        worker.add(row_id, tokens).await?;
+                    while let Some((row_id, tokens, positions)) = receiver.recv().await {
+                        worker.add(row_id, tokens, positions).await?;
                     }
                     let reader = worker.into_reader(inverted_list).await?;
                     Result::Ok(reader)
@@ -143,18 +143,15 @@ impl InvertedIndexBuilder {
         let start = std::time::Instant::now();
         let senders = Arc::new(senders);
         let tokenizer_pool = Arc::new(ArrayQueue::new(num_shards));
-        let token_buffers_pool = Arc::new(ArrayQueue::new(num_shards));
+        let tokenizer = self.params.tokenizer_config.build()?;
         for _ in 0..num_shards {
-            let _ = tokenizer_pool.push(TOKENIZER.clone());
-            token_buffers_pool
-                .push(vec![Vec::new(); num_shards])
-                .unwrap();
+            let _ = tokenizer_pool.push(tokenizer.clone());
         }
         let mut stream = stream
             .map(move |batch| {
                 let senders = senders.clone();
                 let tokenizer_pool = tokenizer_pool.clone();
-                let token_buffers_pool = token_buffers_pool.clone();
+                // let token_buffers_pool = token_buffers_pool.clone();
                 CPU_RUNTIME.spawn_blocking(move || {
                     let batch = batch?;
                     let doc_iter = iter_str_array(batch.column(0));
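
The pooling here is worth a note: the old code cloned a process-global `TOKENIZER`, while the new code builds one tokenizer from the per-index `tokenizer_config` and pools clones of it in a lock-free `ArrayQueue`. A minimal sketch of the pattern, with `String` standing in for the tokenizer type (`build_pool` is an illustrative helper, not from the patch):

```rust
use std::sync::Arc;

use crossbeam_queue::ArrayQueue;

/// Fixed-size pool of reusable resources. `ArrayQueue` gives lock-free
/// pop/push, so blocking workers can check a tokenizer out and return it
/// without holding a Mutex.
fn build_pool<T: Clone>(prototype: T, num_shards: usize) -> Arc<ArrayQueue<T>> {
    let pool = Arc::new(ArrayQueue::new(num_shards));
    for _ in 0..num_shards {
        // push only errors when the queue is full; capacity == num_shards here
        let _ = pool.push(prototype.clone());
    }
    pool
}

fn main() {
    let pool = build_pool("tokenizer".to_owned(), 4);
    let t = pool.pop().unwrap(); // check out on a worker thread
    // ... tokenize with `t` ...
    let _ = pool.push(t); // return it for reuse
}
```
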
@@ -164,37 +161,55 @@ impl InvertedIndexBuilder {
                         .filter_map(|(doc, row_id)| doc.map(|doc| (doc, *row_id)));
 
                     let mut tokenizer = tokenizer_pool.pop().unwrap();
-                    let mut token_buffers = token_buffers_pool.pop().unwrap();
 
                     let num_tokens = docs
                         .map(|(doc, row_id)| {
                             // tokenize the document
+                            let predicted_num_tokens = doc.len() / 5 / num_shards;
+                            let mut token_buffers = std::iter::repeat_with(|| {
+                                (
+                                    StringBuilder::with_capacity(
+                                        predicted_num_tokens,
+                                        doc.len() / num_shards,
+                                    ),
+                                    Int32Builder::with_capacity(predicted_num_tokens),
+                                )
+                            })
+                            .take(num_shards)
+                            .collect_vec();
                             let mut num_tokens = 0;
                             let mut token_stream = tokenizer.token_stream(doc);
                             while token_stream.advance() {
                                 let token = token_stream.token_mut();
                                 let mut hasher = DefaultHasher::new();
                                 hasher.write(token.text.as_bytes());
                                 let shard = hasher.finish() as usize % num_shards;
-                                token_buffers[shard]
-                                    .push((std::mem::take(&mut token.text), token.position as i32));
+                                let (ref mut token_builder, ref mut position_builder) =
+                                    &mut token_buffers[shard];
+                                token_builder.append_value(&token.text);
+                                position_builder.append_value(token.position as i32);
                                 num_tokens += 1;
                             }
 
-                            for (shard, buffer) in token_buffers.iter_mut().enumerate() {
-                                if buffer.is_empty() {
+                            for (shard, (token_builder, position_builder)) in
+                                token_buffers.iter_mut().enumerate()
+                            {
+                                if token_builder.is_empty() {
                                     continue;
                                 }
-                                let buffer = std::mem::take(buffer);
-                                senders[shard].blocking_send((row_id, buffer)).unwrap();
+
+                                let tokens = token_builder.finish();
+                                let positions = position_builder.finish();
+                                senders[shard]
+                                    .blocking_send((row_id, tokens, positions))
+                                    .unwrap();
                             }
 
                             (row_id, num_tokens)
                         })
                         .collect_vec();
 
                     let _ = tokenizer_pool.push(tokenizer);
-                    token_buffers_pool.push(token_buffers).unwrap();
                     Result::Ok(num_tokens)
                 })
             })
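
This hunk replaces the per-shard `Vec<(String, i32)>` buffers with Arrow `StringBuilder`/`Int32Builder` pairs, pre-sized from a rough one-token-per-five-bytes guess, so token text is appended directly into Arrow buffers instead of allocating a `String` per token. A self-contained sketch of the underlying hash-sharding idea (`route_tokens` is an illustrative helper, not from the patch):

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;

use arrow_array::builder::{Int32Builder, StringBuilder};
use arrow_array::{Int32Array, StringArray};

/// Route (token, position) pairs to `num_shards` Arrow batches, choosing the
/// shard by hashing the token text so the same token always lands on the
/// same worker. Illustrative only.
fn route_tokens(
    tokens: impl Iterator<Item = (String, i32)>,
    num_shards: usize,
) -> Vec<(StringArray, Int32Array)> {
    let mut buffers: Vec<_> = (0..num_shards)
        .map(|_| (StringBuilder::new(), Int32Builder::new()))
        .collect();
    for (text, position) in tokens {
        let mut hasher = DefaultHasher::new();
        hasher.write(text.as_bytes());
        let shard = hasher.finish() as usize % num_shards;
        buffers[shard].0.append_value(&text);
        buffers[shard].1.append_value(position);
    }
    buffers
        .into_iter()
        .map(|(mut tokens, mut positions)| (tokens.finish(), positions.finish()))
        .collect()
}
```
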
@@ -355,7 +370,10 @@ impl InvertedIndexBuilder {
         let batch = tokens.to_batch()?;
         let mut writer = store.new_index_file(TOKENS_FILE, batch.schema()).await?;
         writer.write_record_batch(batch).await?;
-        writer.finish().await?;
+
+        let tokenizer = serde_json::to_string(&self.params.tokenizer_config)?;
+        let metadata = HashMap::from_iter(vec![("tokenizer".to_owned(), tokenizer)]);
+        writer.finish_with_metadata(metadata).await?;
 
         log::info!("finished writing tokens");
         Ok(())
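
Storing the serialized `TokenizerConfig` in the token file's metadata lets a reader rebuild the exact same tokenizer at query time. A sketch of the round trip, assuming `TokenizerConfig` derives serde's `Serialize` and `Deserialize` (the `to_string` call above only proves `Serialize`); both helpers are illustrative:

```rust
use std::collections::HashMap;

use crate::scalar::inverted::TokenizerConfig;

// Illustrative helpers, not part of the patch.
fn tokenizer_metadata(config: &TokenizerConfig) -> serde_json::Result<HashMap<String, String>> {
    let json = serde_json::to_string(config)?;
    Ok(HashMap::from_iter([("tokenizer".to_owned(), json)]))
}

fn tokenizer_from_metadata(metadata: &HashMap<String, String>) -> Option<TokenizerConfig> {
    // A missing key or bad JSON both fall back to None; a real reader would
    // likely default to the legacy tokenizer instead.
    serde_json::from_str(metadata.get("tokenizer")?).ok()
}
```
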
@@ -421,21 +439,26 @@ impl IndexWorker {
         self.schema.column_with_name(POSITION_COL).is_some()
     }
 
-    async fn add(&mut self, row_id: u64, tokens: Vec<(String, i32)>) -> Result<()> {
+    async fn add(&mut self, row_id: u64, tokens: StringArray, positions: Int32Array) -> Result<()> {
         let mut token_occurrences = HashMap::new();
-        for (token, position) in tokens {
+        for (token, position) in tokens.iter().zip(positions.values().into_iter()) {
+            let token = if let Some(token) = token {
+                token
+            } else {
+                continue;
+            };
             token_occurrences
                 .entry(token)
                 .or_insert_with(Vec::new)
-                .push(position);
+                .push(*position);
         }
         let with_position = self.has_position();
         token_occurrences
             .into_iter()
             .for_each(|(token, term_positions)| {
                 let posting_list = self
                     .posting_lists
-                    .entry(token.clone())
+                    .entry(token.to_owned())
                     .or_insert_with(|| PostingListBuilder::empty(with_position));
 
                 let old_size = if posting_list.is_empty() {
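
Null handling is the subtle point in the new `add`: `StringArray::iter()` yields `Option<&str>` (one `None` per null slot), while `positions.values()` exposes the raw `i32` buffer, so null tokens are skipped before occurrences are counted. A small sketch of the zip-and-filter pattern (`token_positions` is an illustrative helper):

```rust
use arrow_array::{Int32Array, StringArray};

/// Pair each non-null token with its position. Illustrative only.
fn token_positions(tokens: &StringArray, positions: &Int32Array) -> Vec<(String, i32)> {
    tokens
        .iter() // Option<&str>: nulls come back as None
        .zip(positions.values().iter())
        .filter_map(|(token, &pos)| token.map(|t| (t.to_owned(), pos)))
        .collect()
}
```
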
@@ -702,20 +725,23 @@ mod tests {
     use lance_io::object_store::ObjectStore;
     use object_store::path::Path;
 
+    use crate::scalar::inverted::TokenizerConfig;
     use crate::scalar::lance_format::LanceIndexStore;
     use crate::scalar::{FullTextSearchQuery, SargableQuery, ScalarIndex};
 
     use super::InvertedIndex;
 
     async fn create_index<Offset: arrow::array::OffsetSizeTrait>(
         with_position: bool,
+        tokenizer: TokenizerConfig,
     ) -> Arc<InvertedIndex> {
         let tempdir = tempfile::tempdir().unwrap();
         let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap();
         let cache = FileMetadataCache::with_capacity(128 * 1024 * 1024, CapacityMode::Bytes);
         let store = LanceIndexStore::new(ObjectStore::local(), index_dir, cache);
 
-        let params = super::InvertedIndexParams::default().with_position(with_position);
+        let mut params = super::InvertedIndexParams::default().with_position(with_position);
+        params.tokenizer_config = tokenizer;
         let mut invert_index = super::InvertedIndexBuilder::new(params);
         let doc_col = GenericStringArray::<Offset>::from(vec![
             "lance database the search",
@@ -724,6 +750,7 @@ mod tests {
             "database search",
             "unrelated doc",
             "unrelated",
+            "mots accentués",
         ]);
         let row_id_col = UInt64Array::from(Vec::from_iter(0..doc_col.len() as u64));
         let batch = RecordBatch::try_new(
@@ -750,7 +777,7 @@ mod tests {
     }
 
     async fn test_inverted_index<Offset: arrow::array::OffsetSizeTrait>() {
-        let invert_index = create_index::<Offset>(false).await;
+        let invert_index = create_index::<Offset>(false, TokenizerConfig::default()).await;
         let row_ids = invert_index
             .search(&SargableQuery::FullTextSearch(
                 FullTextSearchQuery::new("lance".to_owned()).limit(Some(3)),
@@ -800,7 +827,7 @@ mod tests {
         assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position"));
 
         // recreate the index with position
-        let invert_index = create_index::<Offset>(true).await;
+        let invert_index = create_index::<Offset>(true, TokenizerConfig::default()).await;
         let row_ids = invert_index
             .search(&SargableQuery::FullTextSearch(
                 FullTextSearchQuery::new("lance database".to_owned()).limit(Some(10)),
@@ -857,4 +884,43 @@ mod tests {
     async fn test_inverted_index_with_large_string() {
         test_inverted_index::<i64>().await;
     }
+
+    #[tokio::test]
+    async fn test_accented_chars() {
+        let invert_index = create_index::<i32>(false, TokenizerConfig::default()).await;
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(1));
+
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(0));
+
+        // with ascii folding enabled, the search should be accent-insensitive
+        let invert_index =
+            create_index::<i32>(true, TokenizerConfig::default().ascii_folding(true)).await;
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(1));
+
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(1));
+    }
 }
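
The new test drives `TokenizerConfig::default().ascii_folding(true)`. The tokenizer calls in the builder (`token_stream`, `advance`, `token_mut`) are tantivy's API, so the folding plausibly comes from tantivy's `AsciiFoldingFilter`; a hedged sketch of such an analyzer chain in tantivy itself, not lance's actual wiring:

```rust
use tantivy::tokenizer::{AsciiFoldingFilter, LowerCaser, SimpleTokenizer, TextAnalyzer};

fn main() {
    // "accentués" folds to "accentues", making search accent-insensitive.
    let mut analyzer = TextAnalyzer::builder(SimpleTokenizer::default())
        .filter(LowerCaser)
        .filter(AsciiFoldingFilter)
        .build();
    let mut stream = analyzer.token_stream("mots accentués");
    while stream.advance() {
        println!("{}", stream.token().text); // prints: mots, accentues
    }
}
```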