15
15
16
16
import org .apache .lucene .search .ScoreDoc ;
17
17
import org .apache .lucene .search .TopDocs ;
18
+ import org .apache .lucene .search .Sort ;
19
+ import org .apache .lucene .search .TopFieldDocs ;
20
+ import org .apache .lucene .search .FieldDoc ;
18
21
import org .opensearch .common .lucene .search .TopDocsAndMaxScore ;
22
+ import org .opensearch .neuralsearch .processor .combination .CombineScoresDto ;
19
23
import org .opensearch .neuralsearch .processor .combination .ScoreCombinationTechnique ;
20
24
import org .opensearch .neuralsearch .processor .combination .ScoreCombiner ;
21
25
import org .opensearch .neuralsearch .processor .normalization .ScoreNormalizationTechnique ;
27
31
28
32
import lombok .AllArgsConstructor ;
29
33
import lombok .extern .log4j .Log4j2 ;
34
+ import static org .opensearch .neuralsearch .processor .combination .ScoreCombiner .MAX_SCORE_WHEN_NO_HITS_FOUND ;
35
+ import static org .opensearch .neuralsearch .search .util .HybridSearchSortUtil .evaluateSortCriteria ;
30
36
31
37
/**
32
38
* Class abstracts steps required for score normalization and combination, this includes pre-processing of incoming data
@@ -62,13 +68,20 @@ public void execute(
62
68
log .debug ("Do score normalization" );
63
69
scoreNormalizer .normalizeScores (queryTopDocs , normalizationTechnique );
64
70
71
+ CombineScoresDto combineScoresDTO = CombineScoresDto .builder ()
72
+ .queryTopDocs (queryTopDocs )
73
+ .scoreCombinationTechnique (combinationTechnique )
74
+ .querySearchResults (querySearchResults )
75
+ .sort (evaluateSortCriteria (querySearchResults , queryTopDocs ))
76
+ .build ();
77
+
65
78
// combine
66
79
log .debug ("Do score combination" );
67
- scoreCombiner .combineScores (queryTopDocs , combinationTechnique );
80
+ scoreCombiner .combineScores (combineScoresDTO );
68
81
69
82
// post-process data
70
83
log .debug ("Post-process query results after score normalization and combination" );
71
- updateOriginalQueryResults (querySearchResults , queryTopDocs );
84
+ updateOriginalQueryResults (combineScoresDTO );
72
85
updateOriginalFetchResults (querySearchResults , fetchSearchResultOptional , unprocessedDocIds );
73
86
}
74
87
@@ -96,7 +109,23 @@ private List<CompoundTopDocs> getQueryTopDocs(final List<QuerySearchResult> quer
96
109
return queryTopDocs ;
97
110
}
98
111
99
- private void updateOriginalQueryResults (final List <QuerySearchResult > querySearchResults , final List <CompoundTopDocs > queryTopDocs ) {
112
+ private void updateOriginalQueryResults (final CombineScoresDto combineScoresDTO ) {
113
+ final List <QuerySearchResult > querySearchResults = combineScoresDTO .getQuerySearchResults ();
114
+ final List <CompoundTopDocs > queryTopDocs = getCompoundTopDocs (combineScoresDTO , querySearchResults );
115
+ final Sort sort = combineScoresDTO .getSort ();
116
+ for (int index = 0 ; index < querySearchResults .size (); index ++) {
117
+ QuerySearchResult querySearchResult = querySearchResults .get (index );
118
+ CompoundTopDocs updatedTopDocs = queryTopDocs .get (index );
119
+ TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore (
120
+ buildTopDocs (updatedTopDocs , sort ),
121
+ maxScoreForShard (updatedTopDocs , sort != null )
122
+ );
123
+ querySearchResult .topDocs (updatedTopDocsAndMaxScore , querySearchResult .sortValueFormats ());
124
+ }
125
+ }
126
+
127
+ private List <CompoundTopDocs > getCompoundTopDocs (CombineScoresDto combineScoresDTO , List <QuerySearchResult > querySearchResults ) {
128
+ final List <CompoundTopDocs > queryTopDocs = combineScoresDTO .getQueryTopDocs ();
100
129
if (querySearchResults .size () != queryTopDocs .size ()) {
101
130
throw new IllegalStateException (
102
131
String .format (
@@ -107,17 +136,42 @@ private void updateOriginalQueryResults(final List<QuerySearchResult> querySearc
107
136
)
108
137
);
109
138
}
110
- for (int index = 0 ; index < querySearchResults .size (); index ++) {
111
- QuerySearchResult querySearchResult = querySearchResults .get (index );
112
- CompoundTopDocs updatedTopDocs = queryTopDocs .get (index );
113
- float maxScore = updatedTopDocs .getTotalHits ().value > 0 ? updatedTopDocs .getScoreDocs ().get (0 ).score : 0.0f ;
139
+ return queryTopDocs ;
140
+ }
114
141
115
- // create final version of top docs with all updated values
116
- TopDocs topDocs = new TopDocs (updatedTopDocs .getTotalHits (), updatedTopDocs .getScoreDocs ().toArray (new ScoreDoc [0 ]));
142
+ /**
143
+ * Get Max score on Shard
144
+ * @param updatedTopDocs updatedTopDocs compound top docs on a shard
145
+ * @param isSortEnabled if sort is enabled or disabled
146
+ * @return max score
147
+ */
148
+ private float maxScoreForShard (CompoundTopDocs updatedTopDocs , boolean isSortEnabled ) {
149
+ if (updatedTopDocs .getTotalHits ().value == 0 || updatedTopDocs .getScoreDocs ().isEmpty ()) {
150
+ return MAX_SCORE_WHEN_NO_HITS_FOUND ;
151
+ }
152
+ if (isSortEnabled ) {
153
+ float maxScore = MAX_SCORE_WHEN_NO_HITS_FOUND ;
154
+ // In case of sorting iterate over score docs and deduce the max score
155
+ for (ScoreDoc scoreDoc : updatedTopDocs .getScoreDocs ()) {
156
+ maxScore = Math .max (maxScore , scoreDoc .score );
157
+ }
158
+ return maxScore ;
159
+ }
160
+ // If it is a normal hybrid query then first entry of score doc will have max score
161
+ return updatedTopDocs .getScoreDocs ().get (0 ).score ;
162
+ }
117
163
118
- TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore (topDocs , maxScore );
119
- querySearchResult .topDocs (updatedTopDocsAndMaxScore , null );
164
+ /**
165
+ * Get Top Docs on Shard
166
+ * @param updatedTopDocs compound top docs on a shard
167
+ * @param sort sort criteria
168
+ * @return TopDocs which will be instance of TopFieldDocs if sort is enabled.
169
+ */
170
+ private TopDocs buildTopDocs (CompoundTopDocs updatedTopDocs , Sort sort ) {
171
+ if (sort != null ) {
172
+ return new TopFieldDocs (updatedTopDocs .getTotalHits (), updatedTopDocs .getScoreDocs ().toArray (new FieldDoc [0 ]), sort .getSort ());
120
173
}
174
+ return new TopDocs (updatedTopDocs .getTotalHits (), updatedTopDocs .getScoreDocs ().toArray (new ScoreDoc [0 ]));
121
175
}
122
176
123
177
/**
0 commit comments