Skip to content

Commit f6d8a12

Browse files
authored
Adding Reciprocal Rank Fusion (RRF) in hybrid query (opensearch-project#1086)
* Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (opensearch-project#874) --------- Signed-off-by: Isaac Johnson <isaacnj@amazon.com> Signed-off-by: Ryan Bogan <rbogan@amazon.com> Signed-off-by: Martin Gaievski <gaievski@amazon.com>
1 parent b084838 commit f6d8a12

File tree

46 files changed

+2256
-265
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+2256
-265
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1717
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
1818
### Features
1919
- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048))
20+
- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874))
2021
### Enhancements
2122
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
2223
- Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013))

qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java

-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ private void validateIndexQuery(final String modelId) {
6969
.modelId(modelId)
7070
.maxDistance(100000f)
7171
.build();
72-
7372
Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
7473
assertNotNull(responseWithMaxDistanceQuery);
7574
}

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

+7-3
Original file line numberDiff line numberDiff line change
@@ -30,22 +30,24 @@
3030
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
3131
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
3232
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
33-
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
3433
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
3534
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
3635
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
3736
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
3837
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
3938
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
39+
import org.opensearch.neuralsearch.processor.RRFProcessor;
40+
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
4041
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
4142
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
4243
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
4344
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
44-
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
4545
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
4646
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
4747
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
4848
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
49+
import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
50+
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
4951
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
5052
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
5153
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
@@ -157,7 +159,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR
157159
) {
158160
return Map.of(
159161
NormalizationProcessor.TYPE,
160-
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
162+
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory),
163+
RRFProcessor.TYPE,
164+
new RRFProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
161165
);
162166
}
163167

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import org.opensearch.action.search.SearchPhaseContext;
8+
import org.opensearch.action.search.SearchPhaseResults;
9+
import org.opensearch.search.SearchPhaseResult;
10+
import org.opensearch.search.internal.SearchContext;
11+
import org.opensearch.search.pipeline.PipelineProcessingContext;
12+
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
13+
14+
import java.util.Optional;
15+
16+
/**
17+
* Base class for all score hybridization processors. This class is responsible for executing the score hybridization process.
18+
* It is a pipeline processor that is executed after the query phase and before the fetch phase.
19+
*/
20+
public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor {
21+
/**
22+
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
23+
* are set as part of class constructor. This method is called when there is no pipeline context
24+
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
25+
* @param searchPhaseContext {@link SearchPhaseContext}
26+
*/
27+
@Override
28+
public <Result extends SearchPhaseResult> void process(
29+
final SearchPhaseResults<Result> searchPhaseResult,
30+
final SearchPhaseContext searchPhaseContext
31+
) {
32+
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty());
33+
}
34+
35+
/**
36+
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
37+
* are set as part of class constructor. This method is called when there is pipeline context
38+
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
39+
* @param searchPhaseContext {@link SearchPhaseContext}
40+
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
41+
* @param <Result> type of the search phase result, a subtype of {@link SearchPhaseResult}
42+
*/
43+
@Override
44+
public <Result extends SearchPhaseResult> void process(
45+
final SearchPhaseResults<Result> searchPhaseResult,
46+
final SearchPhaseContext searchPhaseContext,
47+
final PipelineProcessingContext requestContext
48+
) {
49+
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
50+
}
51+
52+
/**
53+
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
54+
* are set as part of class constructor
55+
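* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results
56+
* @param searchPhaseContext {@link SearchPhaseContext} passed to the processor
57+
* @param requestContextOptional optional {@link PipelineProcessingContext}; empty when the processor is called without a pipeline context
58+
* @param <Result> type of the search phase result, a subtype of {@link SearchPhaseResult}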
59+
*/
60+
abstract <Result extends SearchPhaseResult> void hybridizeScores(
61+
SearchPhaseResults<Result> searchPhaseResult,
62+
SearchPhaseContext searchPhaseContext,
63+
Optional<PipelineProcessingContext> requestContextOptional
64+
);
65+
}

src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ public SearchResponse processResponse(
111111
);
112112
}
113113
// Create and set final explanation combining all components
114+
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
114115
Explanation finalExplanation = Explanation.match(
115-
searchHit.getScore(),
116+
finalScore,
116117
// combination level explanation is always a single detail
117118
combinationExplanation.getScoreDetails().get(0).getValue(),
118119
normalizedExplanation

src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java

+3-4
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
319319
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
320320
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
321321
} else {
322-
mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey)));
322+
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
323323
}
324324
}
325325
return mapWithProcessorKeys;
@@ -357,9 +357,8 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
357357
}
358358
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
359359
} else {
360-
Object parentValue = sourceAndMetadataMap.get(parentKey);
361360
String key = String.valueOf(processorKey);
362-
treeRes.put(key, normalizeSourceValue(parentValue));
361+
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
363362
}
364363
}
365364

@@ -404,7 +403,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
404403
indexName,
405404
clusterService,
406405
environment,
407-
true
406+
false
408407
);
409408
}
410409

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import lombok.AllArgsConstructor;
8+
import lombok.Builder;
9+
import lombok.Getter;
10+
import lombok.NonNull;
11+
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
12+
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
13+
import org.opensearch.search.fetch.FetchSearchResult;
14+
import org.opensearch.search.query.QuerySearchResult;
15+
16+
import java.util.List;
17+
import java.util.Optional;
18+
19+
/**
20+
* DTO object that holds the input data (query search results, optional fetch search result,
21+
* normalization technique and combination technique) for NormalizationProcessorWorkflow.
22+
*/
23+
@AllArgsConstructor
24+
@Builder
25+
@Getter
26+
public class NormalizationExecuteDTO {
27+
@NonNull
28+
private List<QuerySearchResult> querySearchResults;
29+
@NonNull
30+
private Optional<FetchSearchResult> fetchSearchResultOptional;
31+
@NonNull
32+
private ScoreNormalizationTechnique normalizationTechnique;
33+
@NonNull
34+
private ScoreCombinationTechnique combinationTechnique;
35+
}

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java

+2-34
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@
1919
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
2020
import org.opensearch.search.SearchPhaseResult;
2121
import org.opensearch.search.fetch.FetchSearchResult;
22-
import org.opensearch.search.internal.SearchContext;
2322
import org.opensearch.search.pipeline.PipelineProcessingContext;
24-
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
2523
import org.opensearch.search.query.QuerySearchResult;
2624

2725
import lombok.AllArgsConstructor;
@@ -33,7 +31,7 @@
3331
*/
3432
@Log4j2
3533
@AllArgsConstructor
36-
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
34+
public class NormalizationProcessor extends AbstractScoreHybridizationProcessor {
3735
public static final String TYPE = "normalization-processor";
3836

3937
private final String tag;
@@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
4240
private final ScoreCombinationTechnique combinationTechnique;
4341
private final NormalizationProcessorWorkflow normalizationWorkflow;
4442

45-
/**
46-
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
47-
* are set as part of class constructor. This method is called when there is no pipeline context
48-
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
49-
* @param searchPhaseContext {@link SearchContext}
50-
*/
5143
@Override
52-
public <Result extends SearchPhaseResult> void process(
53-
final SearchPhaseResults<Result> searchPhaseResult,
54-
final SearchPhaseContext searchPhaseContext
55-
) {
56-
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
57-
}
58-
59-
/**
60-
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
61-
* are set as part of class constructor
62-
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
63-
* @param searchPhaseContext {@link SearchContext}
64-
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
65-
* @param <Result>
66-
*/
67-
@Override
68-
public <Result extends SearchPhaseResult> void process(
69-
final SearchPhaseResults<Result> searchPhaseResult,
70-
final SearchPhaseContext searchPhaseContext,
71-
final PipelineProcessingContext requestContext
72-
) {
73-
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
74-
}
75-
76-
private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
44+
<Result extends SearchPhaseResult> void hybridizeScores(
7745
SearchPhaseResults<Result> searchPhaseResult,
7846
SearchPhaseContext searchPhaseContext,
7947
Optional<PipelineProcessingContext> requestContextOptional

src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java

+11-27
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,12 @@
2222
import org.opensearch.action.search.SearchPhaseContext;
2323
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
2424
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
25-
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
2625
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
2726
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
2827
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
2928
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
3029
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
3130
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
32-
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
3331
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
3432
import org.opensearch.search.SearchHit;
3533
import org.opensearch.search.SearchHits;
@@ -57,44 +55,30 @@ public class NormalizationProcessorWorkflow {
5755

5856
/**
5957
* Start execution of this workflow
60-
* @param querySearchResults input data with QuerySearchResult from multiple shards
61-
* @param normalizationTechnique technique for score normalization
62-
* @param combinationTechnique technique for score combination
58+
* @param request contains querySearchResults (input data with QuerySearchResult
59+
* from multiple shards), fetchSearchResultOptional, normalizationTechnique (technique for score normalization),
60+
* combinationTechnique (technique for score combination), and nullable rankConstant only used in RRF technique
6361
*/
64-
public void execute(
65-
final List<QuerySearchResult> querySearchResults,
66-
final Optional<FetchSearchResult> fetchSearchResultOptional,
67-
final ScoreNormalizationTechnique normalizationTechnique,
68-
final ScoreCombinationTechnique combinationTechnique,
69-
final SearchPhaseContext searchPhaseContext
70-
) {
71-
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
72-
.querySearchResults(querySearchResults)
73-
.fetchSearchResultOptional(fetchSearchResultOptional)
74-
.normalizationTechnique(normalizationTechnique)
75-
.combinationTechnique(combinationTechnique)
76-
.explain(false)
77-
.searchPhaseContext(searchPhaseContext)
78-
.build();
79-
execute(request);
80-
}
81-
8262
public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
8363
List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
8464
Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();
85-
86-
// save original state
87-
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
65+
List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
8866

8967
// pre-process data
9068
log.debug("Pre-process query results");
9169
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
9270

9371
explain(request, queryTopDocs);
9472

73+
// Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF
74+
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
75+
.queryTopDocs(queryTopDocs)
76+
.normalizationTechnique(request.getNormalizationTechnique())
77+
.build();
78+
9579
// normalize
9680
log.debug("Do score normalization");
97-
scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique());
81+
scoreNormalizer.normalizeScores(normalizeScoresDTO);
9882

9983
CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
10084
.queryTopDocs(queryTopDocs)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.neuralsearch.processor;
6+
7+
import lombok.AllArgsConstructor;
8+
import lombok.Builder;
9+
import lombok.Getter;
10+
import lombok.NonNull;
11+
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
12+
13+
import java.util.List;
14+
15+
/**
16+
* DTO object to hold data required for score normalization.
17+
*/
18+
@AllArgsConstructor
19+
@Builder
20+
@Getter
21+
public class NormalizeScoresDTO {
22+
@NonNull
23+
private List<CompoundTopDocs> queryTopDocs;
24+
@NonNull
25+
private ScoreNormalizationTechnique normalizationTechnique;
26+
}

0 commit comments

Comments
 (0)