
[Feature branch] Lower bounds for min-max normalization in hybrid query #1195
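For context on the feature itself: standard min-max normalization maps each subquery's scores onto [0, 1], which forces the weakest hit to exactly 0 even when it is a reasonable match. A lower bound lets the pipeline keep a floor under normalized scores. A minimal sketch of the idea, assuming a clip-style bound (class, method, and parameter names here are illustrative, not the PR's exact implementation):

```java
// Minimal sketch (not the PR's implementation): min-max normalization
// with a clip-style lower bound on the normalized score.
public class MinMaxLowerBoundSketch {
    static float normalize(float score, float minScore, float maxScore, float lowerBound) {
        if (maxScore == minScore) {
            // Degenerate range: all scores identical; fall back to the floor (assumption).
            return lowerBound;
        }
        float normalized = (score - minScore) / (maxScore - minScore);
        // Clip mode: never let a normalized score drop below the configured floor.
        return Math.max(normalized, lowerBound);
    }
}
```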

CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD)
### Features
- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195))
### Enhancements
- Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838))
### Bug Fixes
src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
@@ -13,6 +13,7 @@
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;

@@ -150,4 +151,78 @@
FieldDoc fieldDoc = (FieldDoc) scoreDoc;
return new FieldDoc(fieldDoc.doc, fieldDoc.score, fieldDoc.fields, fieldDoc.shardIndex);
}

    @Override
    public boolean equals(Object other) {
        if (this == other) return true;
        if (other == null || getClass() != other.getClass()) return false;
        CompoundTopDocs that = (CompoundTopDocs) other;

        if (this.topDocs.size() != that.topDocs.size()) {
            return false;
        }
        for (int i = 0; i < topDocs.size(); i++) {
            TopDocs thisTopDoc = this.topDocs.get(i);
            TopDocs thatTopDoc = that.topDocs.get(i);
            if ((thisTopDoc == null) != (thatTopDoc == null)) {
                return false;
            }
            if (thisTopDoc == null) {
                continue;
            }
            if (Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits) == false) {
                return false;
            }
            if (compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs) == false) {
                return false;
            }
        }
        return Objects.equals(totalHits, that.totalHits) && Objects.equals(searchShard, that.searchShard);
    }

    private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) {
        if (first.length != second.length) {
            return false;
        }

        for (int i = 0; i < first.length; i++) {
            ScoreDoc firstDoc = first[i];
            ScoreDoc secondDoc = second[i];
            if ((firstDoc == null) != (secondDoc == null)) {
                return false;
            }
            if (firstDoc == null) {
                continue;
            }
            if (firstDoc.doc != secondDoc.doc || Float.compare(firstDoc.score, secondDoc.score) != 0) {
                return false;
            }
            if (firstDoc instanceof FieldDoc != secondDoc instanceof FieldDoc) {

Member: Shall we extract the FieldDoc validation separately into a compareFieldDocs method, and at the root level itself check for instanceof and call the respective method?

Member (author): I'm not sure why that's needed.

                return false;
            }
            if (firstDoc instanceof FieldDoc firstFieldDoc) {
                FieldDoc secondFieldDoc = (FieldDoc) secondDoc;
                if (Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields) == false) {
                    return false;
                }
            }
        }
        return true;
    }

    @Override
    public int hashCode() {
        int result = Objects.hash(totalHits, searchShard);
        for (TopDocs topDoc : topDocs) {
            if (topDoc == null) {
                // equals() tolerates null entries, so hashCode() must not NPE on them
                continue;
            }
            result = 31 * result + topDoc.totalHits.hashCode();
            for (ScoreDoc scoreDoc : topDoc.scoreDocs) {
                result = 31 * result + Float.floatToIntBits(scoreDoc.score);
                result = 31 * result + scoreDoc.doc;
                if (scoreDoc instanceof FieldDoc fieldDoc && fieldDoc.fields != null) {
                    result = 31 * result + Arrays.deepHashCode(fieldDoc.fields);
                }
            }
        }
        return result;
    }
}
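For readers unfamiliar with the Lucene types involved: FieldDoc extends ScoreDoc with a fields array holding the sort values, which is why the comparison and hashing above special-case it. A small self-contained illustration (the scenario is hypothetical, but the ScoreDoc/FieldDoc constructors and fields are standard Lucene API):

```java
import java.util.Arrays;
import org.apache.lucene.search.FieldDoc;

public class FieldDocComparisonDemo {
    public static void main(String[] args) {
        // Same doc id and score, but different sort values in the fields array:
        FieldDoc a = new FieldDoc(1, 0.5f, new Object[] { 10L });
        FieldDoc b = new FieldDoc(1, 0.5f, new Object[] { 20L });

        // A doc-and-score comparison alone would call these equal...
        System.out.println(a.doc == b.doc && Float.compare(a.score, b.score) == 0); // true
        // ...so Arrays.equals on fields is needed to tell them apart.
        System.out.println(Arrays.equals(a.fields, b.fields)); // false
    }
}
```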
src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java
@@ -58,7 +58,8 @@ public SearchPhaseResultsProcessor create(
                TECHNIQUE,
                MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME
            );
            Map<String, Object> normalizationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, normalizationClause, PARAMETERS);
            normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName, normalizationParams);
        }

        Map<String, Object> combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE);
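This factory change is what threads the new technique parameters through: the normalization clause may now carry an optional parameters map that is forwarded to createNormalization instead of the old single-argument call. A hedged sketch of what such a map might look like for the min_max technique (the "lower_bounds" key and its fields are assumptions based on the PR title, not confirmed API):

```java
import java.util.List;
import java.util.Map;

class LowerBoundsParamsSketch {
    // Hypothetical parameters map for min-max normalization with per-subquery
    // lower bounds -- one entry per subquery in the hybrid query. Key names
    // ("lower_bounds", "mode", "min_score") are assumptions for illustration.
    static final Map<String, Object> NORMALIZATION_PARAMS = Map.of(
        "lower_bounds", List.of(
            Map.of("mode", "apply", "min_score", 0.1),
            Map.of("mode", "clip", "min_score", 0.0)
        )
    );
    // The factory would then forward these params:
    // scoreNormalizationFactory.createNormalization("min_max", NORMALIZATION_PARAMS);
}
```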
src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java
@@ -10,6 +10,7 @@
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
@@ -32,6 +33,14 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique {
    public static final String TECHNIQUE_NAME = "l2";
    private static final float MIN_SCORE = 0.0f;

    public L2ScoreNormalizationTechnique() {
        this(Map.of(), new ScoreNormalizationUtil());
    }

    public L2ScoreNormalizationTechnique(final Map<String, Object> params, final ScoreNormalizationUtil scoreNormalizationUtil) {
        scoreNormalizationUtil.validateParameters(params, Set.of(), Map.of());
    }
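Note the empty Set.of() passed here: l2 accepts no technique-level parameters, so any user-supplied parameters (for example the new lower_bounds) should be rejected up front. A sketch of the kind of check a validator like this plausibly performs (the real ScoreNormalizationUtil.validateParameters may differ; this is an assumption for illustration):

```java
import java.util.Locale;
import java.util.Map;
import java.util.Set;

class ParameterValidationSketch {
    // Reject any parameter key the technique does not support (assumed behavior).
    static void validateParameters(Map<String, Object> params, Set<String> supportedParams) {
        for (String key : params.keySet()) {
            if (!supportedParams.contains(key)) {
                throw new IllegalArgumentException(
                    String.format(Locale.ROOT, "parameter [%s] is not supported by this normalization technique", key)
                );
            }
        }
    }
}
```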

    /**
     * L2 normalization method.
     * n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2)
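As a quick sanity check of the l2 formula in the javadoc above (a standalone worked example, not code from the PR): scores {3, 4} have norm sqrt(3^2 + 4^2) = 5, so they normalize to 0.6 and 0.8.

```java
public class L2NormalizationDemo {
    public static void main(String[] args) {
        float[] scores = { 3.0f, 4.0f };
        double sumOfSquares = 0.0;
        for (float s : scores) {
            sumOfSquares += s * s;
        }
        double norm = Math.sqrt(sumOfSquares); // 5.0
        for (float s : scores) {
            System.out.println(s / norm); // prints 0.6, then 0.8
        }
    }
}
```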