17
17
import static org .opensearch .neuralsearch .query .NeuralQueryBuilder .MODEL_ID_FIELD ;
18
18
import static org .opensearch .neuralsearch .query .NeuralQueryBuilder .QUERY_TEXT_FIELD ;
19
19
20
+ import java .util .HashSet ;
20
21
import java .util .Iterator ;
21
22
import java .util .List ;
22
23
import java .util .Map ;
23
24
import java .util .ArrayList ;
25
+ import java .util .Optional ;
26
+ import java .util .Set ;
24
27
import java .util .function .Supplier ;
28
+ import java .util .stream .Collectors ;
25
29
26
30
import org .apache .lucene .search .MatchNoDocsQuery ;
27
31
import org .apache .lucene .search .Query ;
28
32
import org .apache .lucene .search .TermQuery ;
33
+ import org .mockito .Mock ;
34
+ import org .mockito .MockitoAnnotations ;
29
35
import org .opensearch .Version ;
30
36
import org .opensearch .cluster .service .ClusterService ;
31
37
import org .opensearch .common .io .stream .BytesStreamOutput ;
38
+ import org .opensearch .common .settings .ClusterSettings ;
39
+ import org .opensearch .common .settings .Setting ;
40
+ import org .opensearch .common .settings .Settings ;
32
41
import org .opensearch .common .xcontent .XContentFactory ;
33
42
import org .opensearch .core .ParseField ;
34
43
import org .opensearch .core .common .ParsingException ;
47
56
import org .opensearch .index .query .QueryBuilders ;
48
57
import org .opensearch .index .query .QueryShardContext ;
49
58
import org .opensearch .index .query .TermQueryBuilder ;
59
+ import org .opensearch .knn .index .KNNSettings ;
50
60
import org .opensearch .knn .index .SpaceType ;
51
61
import org .opensearch .knn .index .VectorDataType ;
62
+ import org .opensearch .knn .index .engine .KNNEngine ;
63
+ import org .opensearch .knn .index .engine .KNNMethodContext ;
64
+ import org .opensearch .knn .index .engine .MethodComponentContext ;
65
+ import org .opensearch .knn .index .mapper .KNNMappingConfig ;
52
66
import org .opensearch .knn .index .mapper .KNNVectorFieldType ;
53
67
import org .opensearch .knn .index .query .KNNQuery ;
54
68
import org .opensearch .knn .index .query .KNNQueryBuilder ;
@@ -69,6 +83,26 @@ public class HybridQueryBuilderTests extends OpenSearchQueryTestCase {
69
83
static final float BOOST = 1.8f ;
70
84
static final Supplier <float []> TEST_VECTOR_SUPPLIER = () -> new float [4 ];
71
85
static final QueryBuilder TEST_FILTER = new MatchAllQueryBuilder ();
86
+ @ Mock
87
+ private ClusterService clusterService ;
88
+ private AutoCloseable openMocks ;
89
+
90
+ @ Override
91
+ public void setUp () throws Exception {
92
+ super .setUp ();
93
+ openMocks = MockitoAnnotations .openMocks (this );
94
+ // This is required to make sure that before every test we are initializing the KNNSettings. Not doing this
95
+ // leads to failures of unit tests cases when a unit test is run separately. Try running this test:
96
+ // ./gradlew ':test' --tests "org.opensearch.knn.training.TrainingJobTests.testRun_success" and see it fails
97
+ // but if run along with other tests this test passes.
98
+ initKNNSettings ();
99
+ }
100
+
101
+ @ Override
102
+ public void tearDown () throws Exception {
103
+ super .tearDown ();
104
+ openMocks .close ();
105
+ }
72
106
73
107
@ SneakyThrows
74
108
public void testDoToQuery_whenNoSubqueries_thenBuildSuccessfully () {
@@ -86,11 +120,14 @@ public void testDoToQuery_whenOneSubquery_thenBuildSuccessfully() {
86
120
Index dummyIndex = new Index ("dummy" , "dummy" );
87
121
QueryShardContext mockQueryShardContext = mock (QueryShardContext .class );
88
122
KNNVectorFieldType mockKNNVectorField = mock (KNNVectorFieldType .class );
123
+ KNNMappingConfig mockKNNMappingConfig = mock (KNNMappingConfig .class );
124
+ KNNMethodContext knnMethodContext = new KNNMethodContext (KNNEngine .FAISS , SpaceType .L2 , MethodComponentContext .EMPTY );
125
+ when (mockKNNVectorField .getKnnMappingConfig ()).thenReturn (mockKNNMappingConfig );
126
+ when (mockKNNMappingConfig .getKnnMethodContext ()).thenReturn (Optional .of (knnMethodContext ));
89
127
when (mockQueryShardContext .index ()).thenReturn (dummyIndex );
90
- when (mockKNNVectorField .getDimension ()).thenReturn (4 );
128
+ when (mockKNNVectorField .getKnnMappingConfig (). getDimension ()).thenReturn (4 );
91
129
when (mockKNNVectorField .getVectorDataType ()).thenReturn (VectorDataType .FLOAT );
92
130
when (mockQueryShardContext .fieldMapper (eq (VECTOR_FIELD_NAME ))).thenReturn (mockKNNVectorField );
93
- when (mockKNNVectorField .getSpaceType ()).thenReturn (SpaceType .L2 );
94
131
95
132
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder ().fieldName (VECTOR_FIELD_NAME )
96
133
.queryText (QUERY_TEXT )
@@ -116,10 +153,13 @@ public void testDoToQuery_whenMultipleSubqueries_thenBuildSuccessfully() {
116
153
Index dummyIndex = new Index ("dummy" , "dummy" );
117
154
QueryShardContext mockQueryShardContext = mock (QueryShardContext .class );
118
155
KNNVectorFieldType mockKNNVectorField = mock (KNNVectorFieldType .class );
156
+ KNNMappingConfig mockKNNMappingConfig = mock (KNNMappingConfig .class );
157
+ KNNMethodContext knnMethodContext = new KNNMethodContext (KNNEngine .FAISS , SpaceType .L2 , MethodComponentContext .EMPTY );
158
+ when (mockKNNVectorField .getKnnMappingConfig ()).thenReturn (mockKNNMappingConfig );
159
+ when (mockKNNMappingConfig .getKnnMethodContext ()).thenReturn (Optional .of (knnMethodContext ));
119
160
when (mockQueryShardContext .index ()).thenReturn (dummyIndex );
120
- when (mockKNNVectorField .getDimension ()).thenReturn (4 );
161
+ when (mockKNNVectorField .getKnnMappingConfig (). getDimension ()).thenReturn (4 );
121
162
when (mockKNNVectorField .getVectorDataType ()).thenReturn (VectorDataType .FLOAT );
122
- when (mockKNNVectorField .getSpaceType ()).thenReturn (SpaceType .L2 );
123
163
when (mockQueryShardContext .fieldMapper (eq (VECTOR_FIELD_NAME ))).thenReturn (mockKNNVectorField );
124
164
125
165
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder ().fieldName (VECTOR_FIELD_NAME )
@@ -367,8 +407,10 @@ public void testToXContent_whenIncomingJsonIsCorrect_thenSuccessful() {
367
407
Index dummyIndex = new Index ("dummy" , "dummy" );
368
408
QueryShardContext mockQueryShardContext = mock (QueryShardContext .class );
369
409
KNNVectorFieldType mockKNNVectorField = mock (KNNVectorFieldType .class );
410
+ KNNMappingConfig mockKNNMappingConfig = mock (KNNMappingConfig .class );
411
+ when (mockKNNVectorField .getKnnMappingConfig ()).thenReturn (mockKNNMappingConfig );
370
412
when (mockQueryShardContext .index ()).thenReturn (dummyIndex );
371
- when (mockKNNVectorField .getDimension ()).thenReturn (4 );
413
+ when (mockKNNVectorField .getKnnMappingConfig (). getDimension ()).thenReturn (4 );
372
414
when (mockQueryShardContext .fieldMapper (eq (VECTOR_FIELD_NAME ))).thenReturn (mockKNNVectorField );
373
415
374
416
NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder ().fieldName (VECTOR_FIELD_NAME )
@@ -584,9 +626,11 @@ public void testRewrite_whenMultipleSubQueries_thenReturnBuilderForEachSubQuery(
584
626
585
627
QueryShardContext mockQueryShardContext = mock (QueryShardContext .class );
586
628
KNNVectorFieldType mockKNNVectorField = mock (KNNVectorFieldType .class );
629
+ KNNMappingConfig mockKNNMappingConfig = mock (KNNMappingConfig .class );
630
+ when (mockKNNVectorField .getKnnMappingConfig ()).thenReturn (mockKNNMappingConfig );
587
631
Index dummyIndex = new Index ("dummy" , "dummy" );
588
632
when (mockQueryShardContext .index ()).thenReturn (dummyIndex );
589
- when (mockKNNVectorField .getDimension ()).thenReturn (4 );
633
+ when (mockKNNVectorField .getKnnMappingConfig (). getDimension ()).thenReturn (4 );
590
634
when (mockQueryShardContext .fieldMapper (eq (VECTOR_FIELD_NAME ))).thenReturn (mockKNNVectorField );
591
635
592
636
TextFieldMapper .TextFieldType fieldType = (TextFieldMapper .TextFieldType ) createMapperService ().fieldType (TEXT_FIELD_NAME );
@@ -737,4 +781,17 @@ private void setUpClusterService() {
737
781
ClusterService clusterService = NeuralSearchClusterTestUtils .mockClusterService (Version .CURRENT );
738
782
NeuralSearchClusterUtil .instance ().initialize (clusterService );
739
783
}
784
+
785
+ private void initKNNSettings () {
786
+ Set <Setting <?>> defaultClusterSettings = new HashSet <>(ClusterSettings .BUILT_IN_CLUSTER_SETTINGS );
787
+ defaultClusterSettings .addAll (
788
+ KNNSettings .state ()
789
+ .getSettings ()
790
+ .stream ()
791
+ .filter (s -> s .getProperties ().contains (Setting .Property .NodeScope ))
792
+ .collect (Collectors .toList ())
793
+ );
794
+ when (clusterService .getClusterSettings ()).thenReturn (new ClusterSettings (Settings .EMPTY , defaultClusterSettings ));
795
+ KNNSettings .state ().setClusterService (clusterService );
796
+ }
740
797
}
0 commit comments