From 95407992feb109b676a9feef77d78f483b51d78c Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 17 Feb 2025 14:30:08 -0800 Subject: [PATCH 1/8] Working draft with unit tests Signed-off-by: Martin Gaievski --- .../NormalizationProcessorFactory.java | 3 +- .../MinMaxScoreNormalizationTechnique.java | 193 +++++++++++++++++- .../ScoreNormalizationFactory.java | 7 +- ...inMaxScoreNormalizationTechniqueTests.java | 80 ++++++++ .../neuralsearch/BaseNeuralSearchIT.java | 5 +- 5 files changed, 280 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java index 0af46a4a4..d799651b0 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactory.java @@ -58,7 +58,8 @@ public SearchPhaseResultsProcessor create( TECHNIQUE, MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME ); - normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName); + Map normalizationParams = readOptionalMap(NormalizationProcessor.TYPE, tag, normalizationClause, PARAMETERS); + normalizationTechnique = scoreNormalizationFactory.createNormalization(normalizationTechniqueName, normalizationParams); } Map combinationClause = readOptionalMap(NormalizationProcessor.TYPE, tag, config, COMBINATION_CLAUSE); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 8da996b41..424b87a0b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -4,15 +4,20 @@ */ package org.opensearch.neuralsearch.processor.normalization; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import lombok.AllArgsConstructor; import lombok.Getter; +import org.apache.commons.lang3.Validate; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; @@ -34,8 +39,17 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "min_max"; - private static final float MIN_SCORE = 0.001f; + protected static final float MIN_SCORE = 0.001f; private static final float SINGLE_RESULT_SCORE = 1.0f; + private final List> lowerBounds; + + public MinMaxScoreNormalizationTechnique() { + this(Map.of(), new ScoreNormalizationUtil()); + } + + public MinMaxScoreNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + lowerBounds = getLowerBounds(params); + } /** * Min-max normalization method. 
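For orientation, a minimal sketch of how parameters are expected to reach the new constructor (the map shape mirrors the lower_bounds lists used in the integration tests later in this series; the names and values here are illustrative, not part of the patch):

    // hypothetical wiring, following the NormalizationProcessorFactory change above
    Map<String, Object> params = Map.of(
        "lower_bounds",
        List.of(
            Map.of("mode", "apply", "min_score", 0.01f),
            Map.of("mode", "clip", "min_score", 0.0f)
        )
    );
    ScoreNormalizationTechnique technique = new ScoreNormalizationFactory().createNormalization("min_max", params);
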
@@ -54,19 +68,34 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { continue; } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + if (Objects.nonNull(lowerBounds) && !lowerBounds.isEmpty() && lowerBounds.size() != topDocsPerSubQuery.size()) { + throw new IllegalArgumentException("lower bounds size should be same as number of sub queries"); + } for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + LowerBound lowerBound = getLowerBound(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { scoreDoc.score = normalizeSingleScore( scoreDoc.score, minMaxScores.getMinScoresPerSubquery()[j], - minMaxScores.getMaxScoresPerSubquery()[j] + minMaxScores.getMaxScoresPerSubquery()[j], + lowerBound ); } } } } + private LowerBound getLowerBound(int j) { + LowerBound lowerBound; + if (Objects.isNull(lowerBounds) || lowerBounds.isEmpty()) { + lowerBound = new LowerBound(); + } else { + lowerBound = new LowerBound(true, lowerBounds.get(j).getLeft(), lowerBounds.get(j).getRight()); + } + return lowerBound; + } + private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { int numOfSubqueries = getNumOfSubqueries(queryTopDocs); // get min scores for each sub query @@ -96,10 +125,12 @@ public Map explain(final List queryTopDocs, final int } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + // LowerBound lowerBound = getLowerBound(j); + // we need to compute actual min score for everything except clipping. For clipping we have to use + // lower bound min_score, it's passed as parameter. If we skip for clipping we can save some CPU cycles. + // if (!lowerBound.isEnabled() || lowerBound.getMode() != Mode.CLIP) { minScores[j] = Math.min( minScores[j], Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) @@ -162,20 +197,56 @@ private float[] getMinScores(final List queryTopDocs, final int .min(Float::compare) .orElse(Float.MAX_VALUE) ); + // } } } return minScores; } - private float normalizeSingleScore(final float score, final float minScore, final float maxScore) { + private float normalizeSingleScore(final float score, final float minScore, final float maxScore, LowerBound lowerBound) { // edge case when there is only one score and min and max scores are same if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) { return SINGLE_RESULT_SCORE; } + if (!lowerBound.isEnabled()) { + return Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore()); + } + + return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore()); + } + + private boolean shouldIgnoreLowerBound(LowerBound lowerBound) { + return !lowerBound.isEnabled() || lowerBound.getMode() == Mode.IGNORE; + } + + private float normalizeWithoutLowerBound(float score, float minScore, float maxScore) { float normalizedScore = (score - minScore) / (maxScore - minScore); return normalizedScore == 0.0f ? 
MIN_SCORE : normalizedScore; } + private float normalizeWithLowerBound(float score, float minScore, float maxScore, LowerBound lowerBound) { + if (lowerBound.getMode() == Mode.APPLY) { + return normalizeWithApplyMode(score, maxScore, lowerBound); + } else if (lowerBound.getMode() == Mode.CLIP) { + return normalizeWithClipMode(score, minScore, maxScore, lowerBound); + } + return (score - minScore) / (maxScore - minScore); + } + + private float normalizeWithApplyMode(float score, float maxScore, LowerBound lowerBound) { + if (score < lowerBound.getMinScore()) { + return score / (maxScore - score); + } + return (score - lowerBound.getMinScore()) / (maxScore - lowerBound.getMinScore()); + } + + private float normalizeWithClipMode(float score, float minScore, float maxScore, LowerBound lowerBound) { + if (score < minScore) { + return lowerBound.getMinScore() / (maxScore - lowerBound.getMinScore()); + } + return (score - lowerBound.getMinScore()) / (maxScore - lowerBound.getMinScore()); + } + /** * Result class to hold min and max scores for each sub query */ @@ -185,4 +256,118 @@ private class MinMaxScores { float[] minScoresPerSubquery; float[] maxScoresPerSubquery; } + + private List> getLowerBounds(final Map params) { + List> lowerBounds = new ArrayList<>(); + + // Early return if params is null or doesn't contain lower_bounds + if (Objects.isNull(params) || !params.containsKey("lower_bounds")) { + return lowerBounds; + } + + Object lowerBoundsObj = params.get("lower_bounds"); + if (!(lowerBoundsObj instanceof List lowerBoundsParams)) { + throw new IllegalArgumentException("lower_bounds must be a List"); + } + + for (Object boundObj : lowerBoundsParams) { + if (!(boundObj instanceof Map)) { + throw new IllegalArgumentException("each lower bound must be a map"); + } + + @SuppressWarnings("unchecked") + Map lowerBound = (Map) boundObj; + + try { + Mode mode = Mode.fromString(lowerBound.get("mode").toString()); + float minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score"))); + + Validate.isTrue( + minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE, + "min_score must be a valid finite number between %f and %f", + LowerBound.MIN_LOWER_BOUND_SCORE, + LowerBound.MAX_LOWER_BOUND_SCORE + ); + + lowerBounds.add(ImmutablePair.of(mode, minScore)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid format for min_score: must be a valid float value", e); + } + } + + return lowerBounds; + } + + /** + * Result class to hold lower bound for each sub query + */ + @Getter + private class LowerBound { + static final float MIN_LOWER_BOUND_SCORE = -10_000f; + static final float MAX_LOWER_BOUND_SCORE = 10_000f; + static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f; + + boolean enabled; + Mode mode; + float minScore; + + LowerBound() { + this(false, Mode.DEFAULT, DEFAULT_LOWER_BOUND_SCORE); + } + + LowerBound(boolean enabled, Mode mode, float minScore) { + this.enabled = enabled; + this.mode = mode; + this.minScore = minScore; + } + } + + protected enum Mode { + APPLY { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + if (score < lowerBoundScore) { + return score / (maxScore - score); + } + return (score - lowerBoundScore) / (maxScore - lowerBoundScore); + } + }, + CLIP { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + if (score < minScore) { + return lowerBoundScore / (maxScore - lowerBoundScore); + } + 
return (score - lowerBoundScore) / (maxScore - lowerBoundScore); + } + }, + IGNORE { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + float normalizedScore = (score - minScore) / (maxScore - minScore); + return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore; + } + }; + + public static final Mode DEFAULT = APPLY; + public static final String VALID_VALUES = Arrays.stream(values()) + .map(mode -> mode.name().toLowerCase(Locale.ROOT)) + .collect(Collectors.joining(", ")); + + public static Mode fromString(String value) { + if (value == null || value.trim().isEmpty()) { + throw new IllegalArgumentException("mode value cannot be null or empty"); + } + + try { + return valueOf(value.toUpperCase(Locale.ROOT)); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "invalid mode: %s, valid values are: %s", value, VALID_VALUES) + ); + } + } + + public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index 7c62893a5..190c8f8c4 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -15,11 +15,14 @@ public class ScoreNormalizationFactory { private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); - public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique( + Map.of(), + scoreNormalizationUtil + ); private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - params -> new MinMaxScoreNormalizationTechnique(), + params -> new MinMaxScoreNormalizationTechnique(params, scoreNormalizationUtil), L2ScoreNormalizationTechnique.TECHNIQUE_NAME, params -> new L2ScoreNormalizationTechnique(), RRFNormalizationTechnique.TECHNIQUE_NAME, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index ea2f80842..682a47050 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -22,6 +22,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import org.opensearch.search.SearchShardTarget; +import static org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique.MIN_SCORE; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; /** @@ -269,6 +270,85 @@ public void testNormalizedScoresAreSetAtCorrectIndices() { assertEquals(1.0f, topDocs3.scoreDocs[0].score, DELTA_FOR_SCORE_ASSERTION); // doc1 in third subquery } + public void testMode_fromString_validValues() { + assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.fromString("apply")); + 
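+ // each enum constant resolves from its lower-case name; case-insensitive parsing is asserted below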
assertEquals(MinMaxScoreNormalizationTechnique.Mode.CLIP, MinMaxScoreNormalizationTechnique.Mode.fromString("clip")); + assertEquals(MinMaxScoreNormalizationTechnique.Mode.IGNORE, MinMaxScoreNormalizationTechnique.Mode.fromString("ignore")); + // Case insensitive check + assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.fromString("APPLY")); + } + + public void testMode_fromString_invalidValues() { + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> MinMaxScoreNormalizationTechnique.Mode.fromString("invalid") + ); + assertEquals("invalid mode: invalid, valid values are: apply, clip, ignore", exception.getMessage()); + } + + public void testMode_fromString_nullOrEmpty() { + IllegalArgumentException nullException = expectThrows( + IllegalArgumentException.class, + () -> MinMaxScoreNormalizationTechnique.Mode.fromString(null) + ); + assertEquals("mode value cannot be null or empty", nullException.getMessage()); + + IllegalArgumentException emptyException = expectThrows( + IllegalArgumentException.class, + () -> MinMaxScoreNormalizationTechnique.Mode.fromString("") + ); + assertEquals("mode value cannot be null or empty", emptyException.getMessage()); + } + + public void testMode_normalize_apply() { + float score = 0.5f; + float minScore = 0.2f; + float maxScore = 0.8f; + float lowerBoundScore = 0.3f; + + float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.APPLY.normalize(score, minScore, maxScore, lowerBoundScore); + assertEquals(0.4f, normalizedScore, DELTA_FOR_SCORE_ASSERTION); + + // Test when score is below lower bound + float lowScore = 0.1f; + float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.APPLY.normalize(lowScore, minScore, maxScore, lowerBoundScore); + assertEquals(0.143f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); + } + + public void testMode_normalize_clip() { + float score = 0.5f; + float minScore = 0.2f; + float maxScore = 0.8f; + float lowerBoundScore = 0.3f; + + float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.CLIP.normalize(score, minScore, maxScore, lowerBoundScore); + assertEquals(0.4f, normalizedScore, DELTA_FOR_SCORE_ASSERTION); + + // Test when score is below min score + float lowScore = 0.1f; + float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.CLIP.normalize(lowScore, minScore, maxScore, lowerBoundScore); + assertEquals(0.6f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); + } + + public void testMode_normalize_ignore() { + float score = 0.5f; + float minScore = 0.2f; + float maxScore = 0.8f; + float lowerBoundScore = 0.3f; + + float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBoundScore); + assertEquals(0.5f, normalizedScore, DELTA_FOR_SCORE_ASSERTION); + + // Test when normalized score would be 0 + float lowScore = 0.2f; + float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.IGNORE.normalize(lowScore, minScore, maxScore, lowerBoundScore); + assertEquals(MIN_SCORE, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION); + } + + public void testMode_defaultValue() { + assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.DEFAULT); + } + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { assertEquals(expected.totalHits.value(), actual.totalHits.value()); assertEquals(expected.totalHits.relation(), actual.totalHits.relation()); diff --git 
a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 72619aba5..79c92cca0 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -248,7 +248,10 @@ protected void loadModel(final String modelId) throws Exception { isComplete = checkComplete(taskQueryResult); Thread.sleep(DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND); } - assertTrue(String.format(Locale.ROOT, "failed to load the model, last task finished with status %s", taskQueryResult.get("state")), isComplete); + assertTrue( + String.format(Locale.ROOT, "failed to load the model, last task finished with status %s", taskQueryResult.get("state")), + isComplete + ); } /** From b63f34fd94b45593d3cd79c24b7dcc9c30624ff1 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 17 Feb 2025 17:16:16 -0800 Subject: [PATCH 2/8] Added integ test, adjust some calculations Signed-off-by: Martin Gaievski --- .../MinMaxScoreNormalizationTechnique.java | 38 +----- .../processor/NormalizationProcessorIT.java | 121 ++++++++++++++++++ .../query/HybridQueryExplainIT.java | 37 +++++- .../neuralsearch/query/HybridQuerySortIT.java | 2 +- .../neuralsearch/BaseNeuralSearchIT.java | 32 ++++- .../neuralsearch/util/TestUtils.java | 1 + 6 files changed, 188 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 424b87a0b..72871ab5b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -215,38 +215,6 @@ private float normalizeSingleScore(final float score, final float minScore, fina return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore()); } - private boolean shouldIgnoreLowerBound(LowerBound lowerBound) { - return !lowerBound.isEnabled() || lowerBound.getMode() == Mode.IGNORE; - } - - private float normalizeWithoutLowerBound(float score, float minScore, float maxScore) { - float normalizedScore = (score - minScore) / (maxScore - minScore); - return normalizedScore == 0.0f ? 
MIN_SCORE : normalizedScore; - } - - private float normalizeWithLowerBound(float score, float minScore, float maxScore, LowerBound lowerBound) { - if (lowerBound.getMode() == Mode.APPLY) { - return normalizeWithApplyMode(score, maxScore, lowerBound); - } else if (lowerBound.getMode() == Mode.CLIP) { - return normalizeWithClipMode(score, minScore, maxScore, lowerBound); - } - return (score - minScore) / (maxScore - minScore); - } - - private float normalizeWithApplyMode(float score, float maxScore, LowerBound lowerBound) { - if (score < lowerBound.getMinScore()) { - return score / (maxScore - score); - } - return (score - lowerBound.getMinScore()) / (maxScore - lowerBound.getMinScore()); - } - - private float normalizeWithClipMode(float score, float minScore, float maxScore, LowerBound lowerBound) { - if (score < minScore) { - return lowerBound.getMinScore() / (maxScore - lowerBound.getMinScore()); - } - return (score - lowerBound.getMinScore()) / (maxScore - lowerBound.getMinScore()); - } - /** * Result class to hold min and max scores for each sub query */ @@ -302,7 +270,7 @@ private List> getLowerBounds(final Map params) * Result class to hold lower bound for each sub query */ @Getter - private class LowerBound { + private static class LowerBound { static final float MIN_LOWER_BOUND_SCORE = -10_000f; static final float MAX_LOWER_BOUND_SCORE = 10_000f; static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f; @@ -326,7 +294,9 @@ protected enum Mode { APPLY { @Override public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { - if (score < lowerBoundScore) { + if (maxScore < lowerBoundScore) { + return (score - minScore) / (maxScore - minScore); + } else if (score < lowerBoundScore) { return score / (maxScore - score); } return (score - lowerBoundScore) / (maxScore - lowerBoundScore); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index e083475dc..64bf4573c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -4,6 +4,8 @@ */ package org.opensearch.neuralsearch.processor; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; @@ -48,6 +50,8 @@ public class NormalizationProcessorIT extends BaseNeuralSearchIT { private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; private static final String TEST_TEXT_FIELD_NAME_2 = "test-text-field-2"; private static final String SEARCH_PIPELINE = "phase-results-normalization-processor-pipeline"; + private static final String SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES = "normalization-processor-with-lower-bounds-two-queries"; + private static final String SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES = "normalization-processor-with-lower-bounds-three-queries"; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -239,6 +243,123 @@ public void 
testQueryMatches_whenMultipleShards_thenSuccessful() { assertQueryResults(searchResponseAsMapNoMatches, 0, true); } + @SneakyThrows + public void testMinMaxLowerBounds_whenMultipleShards_thenSuccessful() { + String modelId = null; + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME); + modelId = prepareModel(); + createSearchPipeline( + SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES, + DEFAULT_NORMALIZATION_METHOD, + Map.of( + "lower_bounds", + List.of( + Map.of("mode", "apply", "min_score", Float.toString(0.01f)), + Map.of("mode", "clip", "min_score", Float.toString(0.0f)) + ) + ), + DEFAULT_COMBINATION_METHOD, + Map.of(), + false + ); + int totalExpectedDocQty = 6; + + NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .queryText(TEST_DOC_TEXT1) + .modelId(modelId) + .k(6) + .build(); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(neuralQueryBuilder); + hybridQueryBuilder.add(termQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilder, + null, + 6, + Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES) + ); + + assertNotNull(searchResponseAsMap); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(totalExpectedDocQty, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertTrue(Range.between(.5f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + List> hitsNestedList = getNestedHits(searchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + // verify scores order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + + // verify the scores are normalized. we need special assert logic because combined score may vary as neural search query + // based on random vectors and return results for every doc. In some cases that may affect 1.0 score from term query and make it + // lower. 
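+ // Range assertions on the extremes are therefore more stable here than exact score checks.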
+ float highestScore = scores.stream().max(Double::compare).get().floatValue(); + assertTrue(Range.between(.5f, 1.0f).contains(highestScore)); + float lowestScore = scores.stream().min(Double::compare).get().floatValue(); + assertTrue(Range.between(.0f, .5f).contains(lowestScore)); + + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + createSearchPipeline( + SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES, + DEFAULT_NORMALIZATION_METHOD, + Map.of( + "lower_bounds", + List.of( + Map.of("mode", "apply", "min_score", Float.toString(0.01f)), + Map.of("mode", "clip", "min_score", Float.toString(0.0f)), + Map.of("mode", "ignore", "min_score", Float.toString(0.0f)) + ) + ), + DEFAULT_COMBINATION_METHOD, + Map.of(), + false + ); + + // verify case when there are partial match + HybridQueryBuilder hybridQueryBuilderPartialMatch = new HybridQueryBuilder(); + hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); + hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4)); + hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseAsMapPartialMatch = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilderPartialMatch, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES) + ); + assertQueryResults(searchResponseAsMapPartialMatch, 4, false, Range.between(0.33f, 1.0f)); + + // verify case when query doesn't have a match + HybridQueryBuilder hybridQueryBuilderNoMatches = new HybridQueryBuilder(); + hybridQueryBuilderNoMatches.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT6)); + hybridQueryBuilderNoMatches.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7)); + + Map searchResponseAsMapNoMatches = search( + TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, + hybridQueryBuilderNoMatches, + null, + 5, + Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES) + ); + assertQueryResults(searchResponseAsMapNoMatches, 0, true); + } + private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) { prepareKnnIndex( diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index 5a71cac22..3fe39554e 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -79,7 +79,14 @@ protected boolean preserveClusterUponCompletion() { public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline( + NORMALIZATION_SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + Map.of(), + DEFAULT_COMBINATION_METHOD, + Map.of(), + true + ); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -195,6 +202,7 @@ public void 
testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() createSearchPipeline( NORMALIZATION_SEARCH_PIPELINE, NORMALIZATION_TECHNIQUE_L2, + Map.of(), DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), true @@ -324,7 +332,14 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenResponseHasQueryExplanations() { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with normalization processor, no explanation response processor - createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + createSearchPipeline( + NORMALIZATION_SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + Map.of(), + DEFAULT_COMBINATION_METHOD, + Map.of(), + false + ); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -472,7 +487,14 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline( + NORMALIZATION_SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + Map.of(), + DEFAULT_COMBINATION_METHOD, + Map.of(), + true + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -526,7 +548,14 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline( + NORMALIZATION_SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + Map.of(), + DEFAULT_COMBINATION_METHOD, + Map.of(), + true + ); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2)); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index 10b09c78a..40b3d1bba 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -490,7 +490,7 @@ public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, Map.of(), DEFAULT_COMBINATION_METHOD, Map.of(), true); 
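+ // the added Map.of() argument is the new (empty) normalization parameters slot; no lower bounds are configured here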
// Assert // scores for search hits HybridQueryBuilder hybridQueryBuilder = createHybridQueryBuilderWithMatchTermAndRangeQuery( diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 79c92cca0..428893311 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -82,6 +82,7 @@ import static org.opensearch.neuralsearch.util.TestUtils.ML_PLUGIN_SYSTEM_INDEX_PREFIX; import static org.opensearch.neuralsearch.util.TestUtils.OPENDISTRO_SECURITY; import static org.opensearch.neuralsearch.util.TestUtils.OPENSEARCH_SYSTEM_INDEX_PREFIX; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS; import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; import static org.opensearch.neuralsearch.util.TestUtils.MAX_RETRY; import static org.opensearch.neuralsearch.util.TestUtils.MAX_TIME_OUT_INTERVAL; @@ -1217,13 +1218,14 @@ protected void createSearchPipeline( String combinationMethod, final Map combinationParams ) { - createSearchPipeline(pipelineId, normalizationMethod, combinationMethod, combinationParams, false); + createSearchPipeline(pipelineId, normalizationMethod, Map.of(), combinationMethod, combinationParams, false); } @SneakyThrows protected void createSearchPipeline( final String pipelineId, final String normalizationMethod, + final Map normalizationParams, final String combinationMethod, final Map combinationParams, boolean addExplainResponseProcessor @@ -1235,10 +1237,32 @@ protected void createSearchPipeline( .append(NormalizationProcessor.TYPE) .append("\": {") .append("\"normalization\": {") - .append("\"technique\": \"%s\"") - .append("},") - .append("\"combination\": {") .append("\"technique\": \"%s\""); + if (Objects.nonNull(normalizationParams) && !normalizationParams.isEmpty()) { + stringBuilderForContentBody.append(", \"parameters\": {"); + if (normalizationParams.containsKey(PARAM_NAME_LOWER_BOUNDS)) { + stringBuilderForContentBody.append("\"lower_bounds\": ["); + List lowerBounds = (List) normalizationParams.get(PARAM_NAME_LOWER_BOUNDS); + for (int i = 0; i < lowerBounds.size(); i++) { + Map lowerBound = lowerBounds.get(i); + stringBuilderForContentBody.append("{ ") + .append("\"mode\"") + .append(": \"") + .append(lowerBound.get("mode")) + .append("\",") + .append("\"min_score\"") + .append(": ") + .append(lowerBound.get("min_score")) + .append(" }"); + if (i < lowerBounds.size() - 1) { + stringBuilderForContentBody.append(", "); + } + } + stringBuilderForContentBody.append("]"); + } + stringBuilderForContentBody.append(" }"); + } + stringBuilderForContentBody.append("},").append("\"combination\": {").append("\"technique\": \"%s\""); if (Objects.nonNull(combinationParams) && !combinationParams.isEmpty()) { stringBuilderForContentBody.append(", \"parameters\": {"); if (combinationParams.containsKey(PARAM_NAME_WEIGHTS)) { diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index 7fe7d5825..0a5a9bf33 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -65,6 +65,7 @@ public class TestUtils { public static final String DEFAULT_NORMALIZATION_METHOD = "min_max"; public static final String 
DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; public static final String PARAM_NAME_WEIGHTS = "weights"; + public static final String PARAM_NAME_LOWER_BOUNDS = "lower_bounds"; public static final String SPARSE_ENCODING_PROCESSOR = "sparse_encoding"; public static final int MAX_TIME_OUT_INTERVAL = 3000; public static final int MAX_RETRY = 5; From a983cfd893df5bf8f8129a1ec144c97201c8dc95 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 19 Feb 2025 17:06:58 -0800 Subject: [PATCH 3/8] Added check for number of elements in lower_bounds array Signed-off-by: Martin Gaievski --- .../MinMaxScoreNormalizationTechnique.java | 16 +++++++- .../ScoreNormalizationFactory.java | 7 +--- .../query/HybridQueryBuilder.java | 2 +- ...inMaxScoreNormalizationTechniqueTests.java | 40 +++++++++++++++++++ 4 files changed, 57 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 72871ab5b..d9a009a87 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -31,6 +31,7 @@ import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization; +import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; /** * Abstracts normalization of scores based on min-max method @@ -44,10 +45,10 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech private final List> lowerBounds; public MinMaxScoreNormalizationTechnique() { - this(Map.of(), new ScoreNormalizationUtil()); + this(Map.of()); } - public MinMaxScoreNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + public MinMaxScoreNormalizationTechnique(final Map params) { lowerBounds = getLowerBounds(params); } @@ -238,6 +239,17 @@ private List> getLowerBounds(final Map params) throw new IllegalArgumentException("lower_bounds must be a List"); } + if (lowerBoundsParams.size() > MAX_NUMBER_OF_SUB_QUERIES) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "lower_bounds size %d should be less than or equal to %d", + lowerBoundsParams.size(), + MAX_NUMBER_OF_SUB_QUERIES + ) + ); + } + for (Object boundObj : lowerBoundsParams) { if (!(boundObj instanceof Map)) { throw new IllegalArgumentException("each lower bound must be a map"); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index 190c8f8c4..9ad64da15 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -15,14 +15,11 @@ public class ScoreNormalizationFactory { private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); - public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique( - Map.of(), - scoreNormalizationUtil - ); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new 
MinMaxScoreNormalizationTechnique(Map.of()); private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - params -> new MinMaxScoreNormalizationTechnique(params, scoreNormalizationUtil), + MinMaxScoreNormalizationTechnique::new, L2ScoreNormalizationTechnique.TECHNIQUE_NAME, params -> new L2ScoreNormalizationTechnique(), RRFNormalizationTechnique.TECHNIQUE_NAME, diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java index c8737b94c..c04fbda09 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryBuilder.java @@ -57,7 +57,7 @@ public final class HybridQueryBuilder extends AbstractQueryBuilder> lowerBounds = new ArrayList<>(); + + for (int i = 0; i <= 100; i++) { + Map bound = new HashMap<>(); + if (i % 3 == 0) { + bound.put("mode", "apply"); + bound.put("min_score", 0.1f); + } else if (i % 3 == 1) { + bound.put("mode", "clip"); + bound.put("min_score", 0.1f); + } else { + bound.put("mode", "ignore"); + } + lowerBounds.add(bound); + } + + Map parameters = new HashMap<>(); + parameters.put("lower_bounds", lowerBounds); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new MinMaxScoreNormalizationTechnique(parameters) + ); + + assertEquals( + String.format( + Locale.ROOT, + "lower_bounds size %d should be less than or equal to %d", + lowerBounds.size(), + MAX_NUMBER_OF_SUB_QUERIES + ), + exception.getMessage() + ); + } + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { assertEquals(expected.totalHits.value(), actual.totalHits.value()); assertEquals(expected.totalHits.relation(), actual.totalHits.relation()); From 70a59c0ea75dd3ab9cd459d36c9a16d149cfff01 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 21 Feb 2025 17:42:59 -0800 Subject: [PATCH 4/8] Added more validations and unit tests Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../processor/CompoundTopDocs.java | 75 +++++ .../L2ScoreNormalizationTechnique.java | 9 + .../MinMaxScoreNormalizationTechnique.java | 233 ++++++++------ .../RRFNormalizationTechnique.java | 2 +- .../ScoreNormalizationFactory.java | 6 +- .../normalization/ScoreNormalizationUtil.java | 74 +++++ .../processor/CompoundTopDocsTests.java | 150 +++++++++ .../L2ScoreNormalizationTechniqueTests.java | 29 ++ ...inMaxScoreNormalizationTechniqueTests.java | 296 ++++++++++++++++-- .../ScoreNormalizationFactoryTests.java | 34 ++ .../ScoreNormalizationUtilTests.java | 99 ++++++ .../query/HybridQueryExplainIT.java | 214 +++++++++---- 13 files changed, 1032 insertions(+), 190 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5fc16369c..180989f05 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 3.0](https://github.com/opensearch-project/neural-search/compare/2.x...HEAD) ### Features +- Lower bound for min-max normalization technique in hybrid query ([#1195](https://github.com/opensearch-project/neural-search/pull/1195)) ### Enhancements - Set neural-search plugin 3.0.0 baseline JDK version to JDK-21 ([#838](https://github.com/opensearch-project/neural-search/pull/838)) ### Bug Fixes diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java 
b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java index 11a0c7ee0..986a8f261 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java @@ -13,6 +13,7 @@ import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; @@ -150,4 +151,78 @@ private ScoreDoc deepCopyScoreDoc(final ScoreDoc scoreDoc, final boolean isSortE FieldDoc fieldDoc = (FieldDoc) scoreDoc; return new FieldDoc(fieldDoc.doc, fieldDoc.score, fieldDoc.fields, fieldDoc.shardIndex); } + + @Override + public boolean equals(Object other) { + if (this == other) return true; + if (other == null || getClass() != other.getClass()) return false; + CompoundTopDocs that = (CompoundTopDocs) other; + + if (this.topDocs.size() != that.topDocs.size()) { + return false; + } + for (int i = 0; i < topDocs.size(); i++) { + TopDocs thisTopDoc = this.topDocs.get(i); + TopDocs thatTopDoc = that.topDocs.get(i); + if ((thisTopDoc == null) != (thatTopDoc == null)) { + return false; + } + if (thisTopDoc == null) { + continue; + } + if (!Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits)) { + return false; + } + if (!compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs)) { + return false; + } + } + return Objects.equals(totalHits, that.totalHits) && Objects.equals(searchShard, that.searchShard); + } + + private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) { + if (first.length != second.length) { + return false; + } + + for (int i = 0; i < first.length; i++) { + ScoreDoc firstDoc = first[i]; + ScoreDoc secondDoc = second[i]; + if ((firstDoc == null) != (secondDoc == null)) { + return false; + } + if (firstDoc == null) { + continue; + } + if (firstDoc.doc != secondDoc.doc || Float.compare(firstDoc.score, secondDoc.score) != 0) { + return false; + } + if (firstDoc instanceof FieldDoc != secondDoc instanceof FieldDoc) { + return false; + } + if (firstDoc instanceof FieldDoc firstFieldDoc) { + FieldDoc secondFieldDoc = (FieldDoc) secondDoc; + if (!Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields)) { + return false; + } + } + } + return true; + } + + @Override + public int hashCode() { + int result = Objects.hash(totalHits, searchShard); + for (TopDocs topDoc : topDocs) { + result = 31 * result + topDoc.totalHits.hashCode(); + for (ScoreDoc scoreDoc : topDoc.scoreDocs) { + result = 31 * result + Float.floatToIntBits(scoreDoc.score); + result = 31 * result + scoreDoc.doc; + if (scoreDoc instanceof FieldDoc fieldDoc && fieldDoc.fields != null) { + result = 31 * result + Arrays.deepHashCode(fieldDoc.fields); + } + } + } + return result; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 1208ffe77..88c4ce816 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -10,6 +10,7 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -32,6 +33,14 @@ public class 
L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu public static final String TECHNIQUE_NAME = "l2"; private static final float MIN_SCORE = 0.0f; + public L2ScoreNormalizationTechnique() { + this(Map.of(), new ScoreNormalizationUtil()); + } + + public L2ScoreNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + scoreNormalizationUtil.validateParameters(params, Set.of(), Map.of()); + } + /** * L2 normalization method. * n_score_i = score_i/sqrt(score1^2 + score2^2 + ... + scoren^2) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index d9a009a87..3bc138040 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -11,6 +11,8 @@ import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; import lombok.AllArgsConstructor; @@ -42,14 +44,25 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech public static final String TECHNIQUE_NAME = "min_max"; protected static final float MIN_SCORE = 0.001f; private static final float SINGLE_RESULT_SCORE = 1.0f; - private final List> lowerBounds; + private static final String PARAM_NAME_LOWER_BOUNDS = "lower_bounds"; + private static final String PARAM_NAME_LOWER_BOUND_MODE = "mode"; + private static final String PARAM_NAME_LOWER_BOUND_MIN_SCORE = "min_score"; + + private static final Set SUPPORTED_PARAMETERS = Set.of(PARAM_NAME_LOWER_BOUNDS); + private static final Map> NESTED_PARAMETERS = Map.of( + PARAM_NAME_LOWER_BOUNDS, + Set.of(PARAM_NAME_LOWER_BOUND_MODE, PARAM_NAME_LOWER_BOUND_MIN_SCORE) + ); + + private final Optional>> lowerBoundsOptional; public MinMaxScoreNormalizationTechnique() { - this(Map.of()); + this(Map.of(), new ScoreNormalizationUtil()); } - public MinMaxScoreNormalizationTechnique(final Map params) { - lowerBounds = getLowerBounds(params); + public MinMaxScoreNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + scoreNormalizationUtil.validateParameters(params, SUPPORTED_PARAMETERS, NESTED_PARAMETERS); + lowerBoundsOptional = getLowerBounds(params); } /** @@ -69,8 +82,14 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { continue; } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - if (Objects.nonNull(lowerBounds) && !lowerBounds.isEmpty() && lowerBounds.size() != topDocsPerSubQuery.size()) { - throw new IllegalArgumentException("lower bounds size should be same as number of sub queries"); + if (isLowerBoundsAndSubQueriesCountMismatched(topDocsPerSubQuery)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "expected lower bounds array to contain %d elements matching the number of sub-queries, but found a mismatch", + topDocsPerSubQuery.size() + ) + ); } for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); @@ -87,14 +106,16 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { } } - private LowerBound getLowerBound(int j) { - LowerBound lowerBound; - if (Objects.isNull(lowerBounds) || lowerBounds.isEmpty()) { - lowerBound = new 
LowerBound(); - } else { - lowerBound = new LowerBound(true, lowerBounds.get(j).getLeft(), lowerBounds.get(j).getRight()); - } - return lowerBound; + private boolean isLowerBoundsAndSubQueriesCountMismatched(List topDocsPerSubQuery) { + return lowerBoundsOptional.isPresent() + && !topDocsPerSubQuery.isEmpty() + && lowerBoundsOptional.get().size() != topDocsPerSubQuery.size(); + } + + private LowerBound getLowerBound(int subQueryIndex) { + return lowerBoundsOptional.map( + pairs -> new LowerBound(true, pairs.get(subQueryIndex).getLeft(), pairs.get(subQueryIndex).getRight()) + ).orElseGet(LowerBound::new); } private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { @@ -108,7 +129,12 @@ private MinMaxScores getMinMaxScoresResult(final List queryTopD @Override public String describe() { - return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME); + return lowerBoundsOptional.map(lb -> { + String lowerBounds = lb.stream() + .map(pair -> String.format(Locale.ROOT, "(%s, %s)", pair.getLeft(), pair.getRight())) + .collect(Collectors.joining(", ", "[", "]")); + return String.format(Locale.ROOT, "%s, lower bounds %s", TECHNIQUE_NAME, lowerBounds); + }).orElse(String.format(Locale.ROOT, "%s", TECHNIQUE_NAME)); } @Override @@ -187,10 +213,6 @@ private float[] getMinScores(final List queryTopDocs, final int } List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { - // LowerBound lowerBound = getLowerBound(j); - // we need to compute actual min score for everything except clipping. For clipping we have to use - // lower bound min_score, it's passed as parameter. If we skip for clipping we can save some CPU cycles. - // if (!lowerBound.isEnabled() || lowerBound.getMode() != Mode.CLIP) { minScores[j] = Math.min( minScores[j], Arrays.stream(topDocsPerSubQuery.get(j).scoreDocs) @@ -198,43 +220,30 @@ private float[] getMinScores(final List queryTopDocs, final int .min(Float::compare) .orElse(Float.MAX_VALUE) ); - // } } } return minScores; } - private float normalizeSingleScore(final float score, final float minScore, final float maxScore, LowerBound lowerBound) { + private float normalizeSingleScore(final float score, final float minScore, final float maxScore, final LowerBound lowerBound) { // edge case when there is only one score and min and max scores are same if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) { return SINGLE_RESULT_SCORE; } if (!lowerBound.isEnabled()) { - return Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore()); + return LowerBound.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore()); } - return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore()); } - /** - * Result class to hold min and max scores for each sub query - */ - @AllArgsConstructor - @Getter - private class MinMaxScores { - float[] minScoresPerSubquery; - float[] maxScoresPerSubquery; - } - - private List> getLowerBounds(final Map params) { - List> lowerBounds = new ArrayList<>(); - - // Early return if params is null or doesn't contain lower_bounds - if (Objects.isNull(params) || !params.containsKey("lower_bounds")) { - return lowerBounds; + private Optional>> getLowerBounds(final Map params) { + if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_LOWER_BOUNDS)) { + return Optional.empty(); } - Object lowerBoundsObj = params.get("lower_bounds"); + List> lowerBounds = new ArrayList<>(); + + Object lowerBoundsObj = 
params.get(PARAM_NAME_LOWER_BOUNDS); if (!(lowerBoundsObj instanceof List lowerBoundsParams)) { throw new IllegalArgumentException("lower_bounds must be a List"); } @@ -259,8 +268,17 @@ private List> getLowerBounds(final Map params) Map lowerBound = (Map) boundObj; try { - Mode mode = Mode.fromString(lowerBound.get("mode").toString()); - float minScore = Float.parseFloat(String.valueOf(lowerBound.get("min_score"))); + LowerBound.Mode mode = LowerBound.Mode.fromString( + Objects.isNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE)) + ? "" + : lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE).toString() + ); + float minScore; + if (Objects.isNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))) { + minScore = LowerBound.DEFAULT_LOWER_BOUND_SCORE; + } else { + minScore = Float.parseFloat(String.valueOf(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))); + } Validate.isTrue( minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE, @@ -271,25 +289,35 @@ private List> getLowerBounds(final Map params) lowerBounds.add(ImmutablePair.of(mode, minScore)); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Invalid format for min_score: must be a valid float value", e); + throw new IllegalArgumentException("invalid format for min_score: must be a valid float value", e); } } - return lowerBounds; + return Optional.of(lowerBounds); + } + + /** + * Result class to hold min and max scores for each sub query + */ + @AllArgsConstructor + @Getter + private static class MinMaxScores { + float[] minScoresPerSubquery; + float[] maxScoresPerSubquery; } /** * Result class to hold lower bound for each sub query */ @Getter - private static class LowerBound { + public static class LowerBound { static final float MIN_LOWER_BOUND_SCORE = -10_000f; static final float MAX_LOWER_BOUND_SCORE = 10_000f; static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f; - boolean enabled; - Mode mode; - float minScore; + private final boolean enabled; + private final Mode mode; + private final float minScore; LowerBound() { this(false, Mode.DEFAULT, DEFAULT_LOWER_BOUND_SCORE); @@ -300,56 +328,79 @@ private static class LowerBound { this.mode = mode; this.minScore = minScore; } - } - protected enum Mode { - APPLY { - @Override - public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { - if (maxScore < lowerBoundScore) { - return (score - minScore) / (maxScore - minScore); - } else if (score < lowerBoundScore) { - return score / (maxScore - score); + /** + * Enum for normalization mode + */ + protected enum Mode { + APPLY { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + // if we apply the lower bound this mean we use actual score in case it's less then the lower bound min score + // same applied to case when actual max_score is less than lower bound min score + if (maxScore < lowerBoundScore || score < lowerBoundScore) { + return (score - minScore) / (maxScore - minScore); + } + return (score - lowerBoundScore) / (maxScore - lowerBoundScore); } - return (score - lowerBoundScore) / (maxScore - lowerBoundScore); - } - }, - CLIP { - @Override - public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { - if (score < minScore) { - return lowerBoundScore / (maxScore - lowerBoundScore); + }, + CLIP { + @Override + public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) { + // apply clipping, return lower bound min score if score 
+            CLIP {
+                @Override
+                public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
+                    // apply clipping: a score below the actual min score normalizes to 0
+                    if (score < minScore) {
+                        return 0.0f;
+                    }
+                    if (maxScore < lowerBoundScore) {
+                        return (score - minScore) / (maxScore - minScore);
+                    }
+                    return (score - lowerBoundScore) / (maxScore - lowerBoundScore);
+                }
+            },
+            IGNORE {
+                @Override
+                public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
+                    // ignore lower bound logic and do raw min-max normalization using actual scores
+                    float normalizedScore = (score - minScore) / (maxScore - minScore);
+                    return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore;
+                }
+            };
+
+            public static final Mode DEFAULT = APPLY;
+            // set of all valid values for mode
+            public static final String VALID_VALUES = Arrays.stream(values())
+                .map(mode -> mode.name().toLowerCase(Locale.ROOT))
+                .collect(Collectors.joining(", "));
+
+            /**
+             * Get mode from string value
+             * @param value string value of mode
+             * @return mode
+             * @throws IllegalArgumentException if mode is not valid
+             */
+            public static Mode fromString(String value) {
+                if (Objects.isNull(value)) {
+                    throw new IllegalArgumentException("mode value cannot be null or empty");
+                }
+                if (value.trim().isEmpty()) {
+                    return DEFAULT;
+                }
+                try {
+                    return valueOf(value.toUpperCase(Locale.ROOT));
+                } catch (IllegalArgumentException e) {
+                    throw new IllegalArgumentException(
+                        String.format(Locale.ROOT, "invalid mode: %s, valid values are: %s", value, VALID_VALUES)
+                    );
                }
-                return (score - lowerBoundScore) / (maxScore - lowerBoundScore);
-            }
-        },
-        IGNORE {
-            @Override
-            public float normalize(float score, float minScore, float maxScore, float lowerBoundScore) {
-                float normalizedScore = (score - minScore) / (maxScore - minScore);
-                return normalizedScore == 0.0f ?
MIN_SCORE : normalizedScore; } - }; - - public static final Mode DEFAULT = APPLY; - public static final String VALID_VALUES = Arrays.stream(values()) - .map(mode -> mode.name().toLowerCase(Locale.ROOT)) - .collect(Collectors.joining(", ")); - public static Mode fromString(String value) { - if (value == null || value.trim().isEmpty()) { - throw new IllegalArgumentException("mode value cannot be null or empty"); - } + public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore); - try { - return valueOf(value.toUpperCase(Locale.ROOT)); - } catch (IllegalArgumentException e) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "invalid mode: %s, valid values are: %s", value, VALID_VALUES) - ); + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); } } - - public abstract float normalize(float score, float minScore, float maxScore, float lowerBoundScore); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index d920e26c8..1ff4ccdd5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -48,7 +48,7 @@ public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, E private final int rankConstant; public RRFNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { - scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS); + scoreNormalizationUtil.validateParameters(params, SUPPORTED_PARAMS, Map.of()); rankConstant = getRankConstant(params); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index 9ad64da15..797336789 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -15,13 +15,13 @@ public class ScoreNormalizationFactory { private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); - public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(Map.of()); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - MinMaxScoreNormalizationTechnique::new, + params -> new MinMaxScoreNormalizationTechnique(params, scoreNormalizationUtil), L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - params -> new L2ScoreNormalizationTechnique(), + params -> new L2ScoreNormalizationTechnique(params, scoreNormalizationUtil), RRFNormalizationTechnique.TECHNIQUE_NAME, params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil) ); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java index 3b625dfd8..7e086abd1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java
@@ -82,4 +82,78 @@ public static void setNormalizedScore(
         }
         scores.set(subQueryIndex, normalizedScore);
     }
+
+    /**
+     * Validate parameters for this technique. The following is an example of the structured parameters that we validate
+     * {
+     *     "technique": "arithmetic_mean",  // top-level parameter 1
+     *     "details": {                     // top-level parameter 2
+     *         "weights": [1, 2, 3],        // nested parameter 1
+     *         "color": "green"             // nested parameter 2
+     *     }
+     * }
+     * for this example the client should pass:
+     * top-level parameters: ["technique", "details"]
+     * nested parameters: ["details" -> ["weights", "color"]]
+     * @param actualParameters map of actual parameters in the form of name-value pairs
+     * @param supportedParametersTopLevel collection of top-level parameters that we should validate against
+     * @param supportedParametersNested map of nested parameters that we should validate against, key is one of the top-level parameters and value is the set of allowed nested params
+     */
+    public void validateParameters(
+        final Map<String, Object> actualParameters,
+        final Set<String> supportedParametersTopLevel,
+        final Map<String, Set<String>> supportedParametersNested
+    ) {
+        if (Objects.isNull(actualParameters) || actualParameters.isEmpty()) {
+            return;
+        }
+        boolean hasUnknownParameters = false;
+        for (Map.Entry<String, Object> entry : actualParameters.entrySet()) {
+            String paramName = entry.getKey();
+            Object paramValue = entry.getValue();
+
+            if (!supportedParametersTopLevel.contains(paramName)) {
+                hasUnknownParameters = true;
+                continue;
+            }
+            if (paramValue instanceof Map) {
+                Map<String, Object> nestedParams = (Map<String, Object>) paramValue;
+                validateNestedParameters(nestedParams, supportedParametersNested.get(paramName));
+            } else if (paramValue instanceof List) {
+                for (Object item : (List<?>) paramValue) {
+                    if (item instanceof Map) {
+                        validateNestedParameters((Map<String, Object>) item, supportedParametersNested.get(paramName));
+                    } else {
+                        hasUnknownParameters = true;
+                    }
+                }
+            } else {
+                if (supportedParametersNested.isEmpty()) {
+                    continue;
+                }
+                hasUnknownParameters = true;
+            }
+        }
+        if (hasUnknownParameters) {
+            throw new IllegalArgumentException("unrecognized parameters in normalization technique");
+        }
+    }
+
+    private void validateNestedParameters(Map<String, Object> parameters, Set<String> supportedNestedParams) {
+        if (Objects.isNull(parameters) || parameters.isEmpty()) {
+            return;
+        }
+        boolean hasUnknownParameters = false;
+        for (Map.Entry<String, Object> entry : parameters.entrySet()) {
+            String paramName = entry.getKey();
+
+            if (Objects.nonNull(supportedNestedParams) && !supportedNestedParams.contains(paramName)) {
+                hasUnknownParameters = true;
+                break;
+            }
+        }
+        if (hasUnknownParameters) {
+            throw new IllegalArgumentException("unrecognized parameters in normalization technique");
+        }
+    }
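+
+    // usage sketch (illustrative only; the exact call sites may differ from this patch): a technique
+    // accepting the min_max lower_bounds parameter could validate its input as
+    //     validateParameters(params, Set.of("lower_bounds"), Map.of("lower_bounds", Set.of("mode", "min_score")));
+    // so that unknown top-level keys and unknown keys inside each lower_bounds entry are rejected
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java
index eabc69894..f9e93415e 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java
@@ -5,9 +5,11 @@
 package org.opensearch.neuralsearch.processor;

 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;

 import org.apache.commons.lang3.RandomUtils;
+import org.apache.lucene.search.FieldDoc;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.TopDocs;
 import org.apache.lucene.search.TotalHits;
@@ -87,4 +89,152 @@ public void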
testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { assertNotNull(compoundTopDocsWithNullArray.getScoreDocs()); assertEquals(0, compoundTopDocsWithNullArray.getScoreDocs().size()); } + + public void testEqualsWithIdenticalCompoundTopDocs() { + TopDocs topDocs1 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + TopDocs topDocs2 = new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 2.0f) }); + List topDocsList = Arrays.asList(topDocs1, topDocs2); + + CompoundTopDocs first = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocsList, false, SEARCH_SHARD); + CompoundTopDocs second = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocsList, false, SEARCH_SHARD); + + assertTrue(first.equals(second)); + assertTrue(second.equals(first)); + assertEquals(first.hashCode(), second.hashCode()); + } + + public void testEqualsWithDifferentScoreDocs() { + TopDocs topDocs1 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + TopDocs topDocs2 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs1), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs2), + false, + SEARCH_SHARD + ); + + assertFalse(first.equals(second)); + assertFalse(second.equals(first)); + } + + public void testEqualsWithDifferentTotalHits() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + + assertFalse(first.equals(second)); + assertFalse(second.equals(first)); + } + + public void testEqualsWithDifferentSortEnabled() { + Object[] fields = new Object[] { "value1" }; + ScoreDoc scoreDoc = new FieldDoc(1, 1.0f, fields); // use FieldDoc when sort is enabled + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { scoreDoc }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + true, + SEARCH_SHARD + ); + + // non-sorted case, use regular ScoreDoc + TopDocs topDocs2 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs2), + false, + SEARCH_SHARD + ); + + assertNotEquals(first, second); + assertNotEquals(second, first); + } + + public void testEqualsWithDifferentSearchShards() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + 
Collections.singletonList(topDocs), + false, + new SearchShard("my_index", 1, "23456789") + ); + + assertNotEquals(first, second); + assertNotEquals(second, first); + } + + public void testEqualsWithFieldDocs() { + Object[] fields1 = new Object[] { "value1" }; + Object[] fields2 = new Object[] { "value1" }; + TopDocs topDocs1 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new FieldDoc[] { new FieldDoc(1, 1.0f, fields1) }); + TopDocs topDocs2 = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new FieldDoc[] { new FieldDoc(1, 1.0f, fields2) }); + + CompoundTopDocs first = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs1), + false, + SEARCH_SHARD + ); + CompoundTopDocs second = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs2), + false, + SEARCH_SHARD + ); + + assertEquals(first, second); + assertEquals(second, first); + assertEquals(first.hashCode(), second.hashCode()); + } + + public void testEqualsWithNull() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + + assertNotEquals(null, compoundTopDocs); + } + + public void testEqualsWithDifferentClass() { + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 1.0f) }); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs( + new TotalHits(1, TotalHits.Relation.EQUAL_TO), + Collections.singletonList(topDocs), + false, + SEARCH_SHARD + ); + + assertNotEquals("not a CompoundTopDocs", compoundTopDocs); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index 30a9d1b89..c8bf370c2 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -6,6 +6,7 @@ import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -320,6 +321,34 @@ public void testNormalizedScoresAreSetAtCorrectIndices() { assertTrue(doc1Scores.get(2).getValue().contains("l2 normalization")); } + public void testInvalidParameters() { + Map parameters = new HashMap<>(); + List> lowerBoundsList = List.of( + Map.of("min_score", 0.1, "mode", "clip"), + Map.of("mode", "ignore", "invalid_param", "value") + ); + parameters.put("lower_bounds", lowerBoundsList); + + try { + new L2ScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil()); + fail("expected IllegalArgumentException was not thrown"); + } catch (IllegalArgumentException e) { + assertEquals("unrecognized parameters in normalization technique", e.getMessage()); + } + } + + public void testUnsupportedTopLevelParameter() { + Map parameters = new HashMap<>(); + parameters.put("invalid_top_level_param", "value"); // Adding an invalid top-level parameter + + try { + new L2ScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil()); + fail("expected IllegalArgumentException was not thrown"); + } catch (IllegalArgumentException e) { + assertEquals("unrecognized 
parameters in normalization technique", e.getMessage()); + } + } + private float l2Norm(float score, List scores) { return score / (float) Math.sqrt(scores.stream().map(Float::doubleValue).map(s -> s * s).mapToDouble(Double::doubleValue).sum()); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index 840c19394..d9856f8f8 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -28,6 +28,7 @@ import static org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique.MIN_SCORE; import static org.opensearch.neuralsearch.query.HybridQueryBuilder.MAX_NUMBER_OF_SUB_QUERIES; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS; /** * Abstracts normalization of scores based on min-max method @@ -274,86 +275,124 @@ public void testNormalizedScoresAreSetAtCorrectIndices() { assertEquals(1.0f, topDocs3.scoreDocs[0].score, DELTA_FOR_SCORE_ASSERTION); // doc1 in third subquery } - public void testMode_fromString_validValues() { - assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.fromString("apply")); - assertEquals(MinMaxScoreNormalizationTechnique.Mode.CLIP, MinMaxScoreNormalizationTechnique.Mode.fromString("clip")); - assertEquals(MinMaxScoreNormalizationTechnique.Mode.IGNORE, MinMaxScoreNormalizationTechnique.Mode.fromString("ignore")); + public void testLowerBoundsModeFromString_whenValidValues_thenSuccessful() { + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("apply") + ); + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.CLIP, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("clip") + ); + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.IGNORE, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("ignore") + ); // Case insensitive check - assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.fromString("APPLY")); + assertEquals( + MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY, + MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("APPLY") + ); } public void testMode_fromString_invalidValues() { IllegalArgumentException exception = expectThrows( IllegalArgumentException.class, - () -> MinMaxScoreNormalizationTechnique.Mode.fromString("invalid") + () -> MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString("invalid") ); assertEquals("invalid mode: invalid, valid values are: apply, clip, ignore", exception.getMessage()); } - public void testMode_fromString_nullOrEmpty() { + public void testLowerBoundsModeFromString_whenNullOrEmpty_thenFail() { IllegalArgumentException nullException = expectThrows( IllegalArgumentException.class, - () -> MinMaxScoreNormalizationTechnique.Mode.fromString(null) + () -> MinMaxScoreNormalizationTechnique.LowerBound.Mode.fromString(null) ); assertEquals("mode value cannot be null or empty", nullException.getMessage()); - - IllegalArgumentException emptyException = expectThrows( - 
IllegalArgumentException.class,
-            () -> MinMaxScoreNormalizationTechnique.Mode.fromString("")
-        );
-        assertEquals("mode value cannot be null or empty", emptyException.getMessage());
     }

-    public void testMode_normalize_apply() {
+    public void testLowerBounds_whenModeIsApply_thenSuccessful() {
         float score = 0.5f;
-        float minScore = 0.2f;
+        float minScore = 0.1f;
         float maxScore = 0.8f;
         float lowerBoundScore = 0.3f;
-        float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.APPLY.normalize(score, minScore, maxScore, lowerBoundScore);
+        float normalizedScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY.normalize(
+            score,
+            minScore,
+            maxScore,
+            lowerBoundScore
+        );
+        // we expect the score to be (0.5 - 0.3) / (0.8 - 0.3) = 0.2 / 0.5 = 0.4
         assertEquals(0.4f, normalizedScore, DELTA_FOR_SCORE_ASSERTION);

         // Test when score is below lower bound
-        float lowScore = 0.1f;
-        float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.APPLY.normalize(lowScore, minScore, maxScore, lowerBoundScore);
+        float lowScore = 0.2f;
+        float normalizedLowScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY.normalize(
+            lowScore,
+            minScore,
+            maxScore,
+            lowerBoundScore
+        );
+        // we expect the score to be (0.2 - 0.1) / (0.8 - 0.1) = 0.1 / 0.7 ≈ 0.143
         assertEquals(0.143f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION);
     }

-    public void testMode_normalize_clip() {
+    public void testLowerBounds_whenModeIsClip_thenSuccessful() {
         float score = 0.5f;
         float minScore = 0.2f;
         float maxScore = 0.8f;
         float lowerBoundScore = 0.3f;
-        float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.CLIP.normalize(score, minScore, maxScore, lowerBoundScore);
+        float normalizedScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.CLIP.normalize(
+            score,
+            minScore,
+            maxScore,
+            lowerBoundScore
+        );
         assertEquals(0.4f, normalizedScore, DELTA_FOR_SCORE_ASSERTION);

         // Test when score is below min score
         float lowScore = 0.1f;
-        float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.CLIP.normalize(lowScore, minScore, maxScore, lowerBoundScore);
-        assertEquals(0.6f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION);
+        float normalizedLowScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.CLIP.normalize(
+            lowScore,
+            minScore,
+            maxScore,
+            lowerBoundScore
+        );
+        assertEquals(0.0f, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION);
     }

-    public void testMode_normalize_ignore() {
+    public void testLowerBounds_whenModeIsIgnore_thenSuccessful() {
         float score = 0.5f;
         float minScore = 0.2f;
         float maxScore = 0.8f;
         float lowerBoundScore = 0.3f;
-        float normalizedScore = MinMaxScoreNormalizationTechnique.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBoundScore);
+        float normalizedScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.IGNORE.normalize(
+            score,
+            minScore,
+            maxScore,
+            lowerBoundScore
+        );
         assertEquals(0.5f, normalizedScore, DELTA_FOR_SCORE_ASSERTION);

         // Test when normalized score would be 0
         float lowScore = 0.2f;
-        float normalizedLowScore = MinMaxScoreNormalizationTechnique.Mode.IGNORE.normalize(lowScore, minScore, maxScore, lowerBoundScore);
+        float normalizedLowScore = MinMaxScoreNormalizationTechnique.LowerBound.Mode.IGNORE.normalize(
+            lowScore,
+            minScore,
+            maxScore,
+            lowerBoundScore
+        );
         assertEquals(MIN_SCORE, normalizedLowScore, DELTA_FOR_SCORE_ASSERTION);
     }

-    public void testMode_defaultValue() {
-        assertEquals(MinMaxScoreNormalizationTechnique.Mode.APPLY, MinMaxScoreNormalizationTechnique.Mode.DEFAULT);
+    public void testLowerBoundsMode_whenDefaultValue_thenSuccessful() {
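+        // APPLY is the documented default, so lower_bounds entries that omit "mode" fall back to it via fromString("")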
assertEquals(MinMaxScoreNormalizationTechnique.LowerBound.Mode.APPLY, MinMaxScoreNormalizationTechnique.LowerBound.Mode.DEFAULT);
     }

-    public void testLowerBoundsExceedsMaxSubQueries() {
+    public void testLowerBounds_whenExceedsMaxSubQueries_thenFail() {
         List<Map<String, Object>> lowerBounds = new ArrayList<>();

         for (int i = 0; i <= 100; i++) {
@@ -375,7 +414,7 @@ public void testLowerBoundsExceedsMaxSubQueries() {

         IllegalArgumentException exception = expectThrows(
             IllegalArgumentException.class,
-            () -> new MinMaxScoreNormalizationTechnique(parameters)
+            () -> new MinMaxScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil())
         );

         assertEquals(
@@ -389,6 +428,203 @@ public void testLowerBoundsExceedsMaxSubQueries() {
         );
     }

+    public void testDescribe_whenLowerBoundsArePresent_thenSuccessful() {
+        // Test case 1: with lower bounds
+        Map<String, Object> parameters = new HashMap<>();
+        List<Map<String, Object>> lowerBounds = Arrays.asList(
+            Map.of("mode", "apply", "min_score", 0.2),
+            Map.of("mode", "clip", "min_score", 0.1)
+        );
+        parameters.put("lower_bounds", lowerBounds);
+        MinMaxScoreNormalizationTechnique techniqueWithBounds = new MinMaxScoreNormalizationTechnique(
+            parameters,
+            new ScoreNormalizationUtil()
+        );
+        assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueWithBounds.describe());
+
+        // Test case 2: without lower bounds
+        Map<String, Object> emptyParameters = new HashMap<>();
+        MinMaxScoreNormalizationTechnique techniqueWithoutBounds = new MinMaxScoreNormalizationTechnique(
+            emptyParameters,
+            new ScoreNormalizationUtil()
+        );
+        assertEquals("min_max", techniqueWithoutBounds.describe());
+
+        // Test case 3: missing mode defaults to apply
+        Map<String, Object> parametersMissingMode = new HashMap<>();
+        List<Map<String, Object>> lowerBoundsMissingMode = Arrays.asList(
+            Map.of("min_score", 0.2),
+            Map.of("mode", "clip", "min_score", 0.1)
+        );
+        parametersMissingMode.put("lower_bounds", lowerBoundsMissingMode);
+        MinMaxScoreNormalizationTechnique techniqueMissingMode = new MinMaxScoreNormalizationTechnique(
+            parametersMissingMode,
+            new ScoreNormalizationUtil()
+        );
+        assertEquals("min_max, lower bounds [(apply, 0.2), (clip, 0.1)]", techniqueMissingMode.describe());
+
+        // Test case 4: missing min_score defaults to 0.0
+        Map<String, Object> parametersMissingScore = new HashMap<>();
+        List<Map<String, Object>> lowerBoundsMissingScore = Arrays.asList(
+            Map.of("mode", "apply"),
+            Map.of("mode", "clip", "min_score", 0.1)
+        );
+        parametersMissingScore.put("lower_bounds", lowerBoundsMissingScore);
+        MinMaxScoreNormalizationTechnique techniqueMissingScore = new MinMaxScoreNormalizationTechnique(
+            parametersMissingScore,
+            new ScoreNormalizationUtil()
+        );
+        assertEquals("min_max, lower bounds [(apply, 0.0), (clip, 0.1)]", techniqueMissingScore.describe());
+    }
+
+    public void testLowerBounds_whenInvalidInput_thenFail() {
+        // Test case 1: Invalid mode value
+        Map<String, Object> parametersInvalidMode = new HashMap<>();
+        List<Map<String, Object>> lowerBoundsInvalidMode = Arrays.asList(
+            Map.of("mode", "invalid_mode", "min_score", 0.2),
+            Map.of("mode", "clip", "min_score", 0.1)
+        );
+        parametersInvalidMode.put("lower_bounds", lowerBoundsInvalidMode);
+        IllegalArgumentException invalidModeException = expectThrows(
+            IllegalArgumentException.class,
+            () -> new MinMaxScoreNormalizationTechnique(parametersInvalidMode, new ScoreNormalizationUtil())
+        );
+        assertEquals("invalid mode: invalid_mode, valid values are: apply, clip, ignore", invalidModeException.getMessage());
+
+        // Test case 2: Invalid min_score type
+        Map<String, Object> parametersInvalidScore = new HashMap<>();
+        List<Map<String, Object>> lowerBoundsInvalidScore = Arrays.asList(
+            Map.of("mode", "apply", "min_score", "not_a_number"),
+            Map.of("mode", "clip", "min_score", 0.1)
+        );
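+        // "not_a_number" cannot be parsed as a float, so the technique constructor is expected to reject it
+        parametersInvalidScore.put("lower_bounds",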
lowerBoundsInvalidScore); + IllegalArgumentException invalidScoreException = expectThrows( + IllegalArgumentException.class, + () -> new MinMaxScoreNormalizationTechnique(parametersInvalidScore, new ScoreNormalizationUtil()) + ); + assertEquals("invalid format for min_score: must be a valid float value", invalidScoreException.getMessage()); + } + + public void testLowerBoundsValidation_whenLowerBoundsAndSubQueriesCountMismatch_thenFail() { + Map parameters = new HashMap<>(); + List> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.1f) }) + ), + false, + SEARCH_SHARD + ) + ); + ScoreNormalizationTechnique minMaxTechnique = new MinMaxScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil()); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(minMaxTechnique) + .build(); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> minMaxTechnique.normalize(normalizeScoresDTO) + ); + + assertEquals( + "expected lower bounds array to contain 2 elements matching the number of sub-queries, but found a mismatch", + exception.getMessage() + ); + } + + public void testLowerBoundsValidation_whenTopDocsIsEmpty_thenSuccessful() { + Map parameters = new HashMap<>(); + List> lowerBounds = Arrays.asList( + Map.of("mode", "clip", "min_score", 0.1), + Map.of("mode", "apply", "min_score", 0.0) + ); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + List compoundTopDocs = List.of( + new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), List.of(), false, SEARCH_SHARD), + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 0.1f) }) + ), + false, + SEARCH_SHARD + ) + ); + ScoreNormalizationTechnique minMaxTechnique = new MinMaxScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil()); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(minMaxTechnique) + .build(); + + minMaxTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocsZero = new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + List.of(), + false, + SEARCH_SHARD + ); + CompoundTopDocs expectedCompoundDocsOne = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.25f) } + ), + new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(3, 1.0f) }) + ), + false, + SEARCH_SHARD + ); + expectedCompoundDocsOne.setScoreDocs(List.of(new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f))); + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + CompoundTopDocs compoundTopDocsZero = compoundTopDocs.get(0); + 
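+        // for reference: clip with min_score 0.1 over raw scores {0.5, 0.2} yields (0.5 - 0.1) / (0.5 - 0.1) = 1.0
+        // and (0.2 - 0.1) / (0.5 - 0.1) = 0.25, while the single-hit sub-query collapses to the fixed score of 1.0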
assertEquals(expectedCompoundDocsZero, compoundTopDocsZero); + CompoundTopDocs compoundTopDocsOne = compoundTopDocs.get(1); + assertEquals(expectedCompoundDocsOne, compoundTopDocsOne); + } + + public void testInvalidParameters() { + Map parameters = new HashMap<>(); + List> lowerBoundsList = List.of( + Map.of("min_score", 0.1, "mode", "clip"), + Map.of("mode", "ignore", "invalid_param", "value") // Adding an invalid nested parameter + ); + parameters.put("lower_bounds", lowerBoundsList); + + try { + new MinMaxScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil()); + fail("Expected IllegalArgumentException was not thrown"); + } catch (IllegalArgumentException e) { + assertEquals("unrecognized parameters in normalization technique", e.getMessage()); + } + } + + public void testUnsupportedTopLevelParameter() { + Map parameters = new HashMap<>(); + parameters.put("invalid_top_level_param", "value"); // Adding an invalid top-level parameter + + try { + new MinMaxScoreNormalizationTechnique(parameters, new ScoreNormalizationUtil()); + fail("Expected IllegalArgumentException was not thrown"); + } catch (IllegalArgumentException e) { + assertEquals("unrecognized parameters in normalization technique", e.getMessage()); + } + } + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { assertEquals(expected.totalHits.value(), actual.totalHits.value()); assertEquals(expected.totalHits.relation(), actual.totalHits.relation()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java index cecdf8779..1adaa89d5 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ -5,9 +5,15 @@ package org.opensearch.neuralsearch.processor.normalization; import static org.hamcrest.Matchers.containsString; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_LOWER_BOUNDS; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + public class ScoreNormalizationFactoryTests extends OpenSearchQueryTestCase { public void testMinMaxNorm_whenCreatingByName_thenReturnCorrectInstance() { @@ -42,4 +48,32 @@ public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ); assertThat(illegalArgumentException.getMessage(), containsString("provided normalization technique is not supported")); } + + public void testCreateMinMaxNormalizationWithParameters() { + Map parameters = new HashMap<>(); + + List> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization("min_max", parameters); + + assertNotNull(normalizationTechnique); + assertTrue(normalizationTechnique instanceof MinMaxScoreNormalizationTechnique); + } + + public void testThrowsExceptionForInvalidTechniqueWithParameters() { + Map parameters = new HashMap<>(); + + List> lowerBounds = Arrays.asList(Map.of("mode", "clip", "min_score", 0.1)); + parameters.put(PARAM_NAME_LOWER_BOUNDS, lowerBounds); + + ScoreNormalizationFactory 
scoreNormalizationFactory = new ScoreNormalizationFactory(); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationFactory.createNormalization(L2ScoreNormalizationTechnique.TECHNIQUE_NAME, parameters) + ); + assertEquals("unrecognized parameters in normalization technique", exception.getMessage()); + } + } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtilTests.java index e61ba0b10..1beea1d8b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtilTests.java @@ -8,6 +8,7 @@ import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.test.OpenSearchTestCase; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -19,11 +20,16 @@ public class ScoreNormalizationUtilTests extends OpenSearchTestCase { private ScoreNormalizationUtil scoreNormalizationUtil; + private Set supportedTopLevelParams; + private Map> supportedNestedParams; @Override public void setUp() throws Exception { super.setUp(); scoreNormalizationUtil = new ScoreNormalizationUtil(); + supportedTopLevelParams = new HashSet<>(Arrays.asList("method", "parameters")); + supportedNestedParams = new HashMap<>(); + supportedNestedParams.put("parameters", new HashSet<>(Arrays.asList("factor", "offset"))); } public void testValidateParamsWithUnsupportedParameter() { @@ -67,4 +73,97 @@ public void testSetNormalizedScore() { assertEquals(0.0f, scores.get(0), DELTA_FOR_FLOATS_ASSERTION); assertEquals(0.0f, scores.get(2), DELTA_FOR_FLOATS_ASSERTION); } + + public void testValidateParametersWithNullParameters() { + scoreNormalizationUtil.validateParameters(null, supportedTopLevelParams, supportedNestedParams); + } + + public void testValidateParametersWithEmptyParameters() { + scoreNormalizationUtil.validateParameters(new HashMap<>(), supportedTopLevelParams, supportedNestedParams); + } + + public void testValidateParametersWithValidTopLevelParameters() { + Map params = new HashMap<>(); + Map nestedParams = new HashMap<>(); + nestedParams.put("factor", 1.0); + params.put("parameters", nestedParams); + + scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams); + } + + public void testValidateParametersWithValidNestedParameters() { + Map nestedParams = new HashMap<>(); + nestedParams.put("factor", 1.0); + nestedParams.put("offset", 0.0); + + Map params = new HashMap<>(); + params.put("parameters", nestedParams); + + scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams); + } + + public void testValidateParametersWithInvalidTopLevelParameter() { + Map params = new HashMap<>(); + params.put("invalid_param", "value"); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams) + ); + assertEquals("unrecognized parameters in normalization technique", exception.getMessage()); + } + + public void testValidateParametersWithInvalidNestedParameter() { + Map nestedParams = new HashMap<>(); + nestedParams.put("invalid_nested", "value"); + + Map params = new HashMap<>(); + params.put("parameters", 
nestedParams); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams) + ); + assertEquals("unrecognized parameters in normalization technique", exception.getMessage()); + } + + public void testValidateParametersWithListOfMaps() { + Map validNestedParams1 = new HashMap<>(); + validNestedParams1.put("factor", 1.0); + Map validNestedParams2 = new HashMap<>(); + validNestedParams2.put("offset", 0.0); + + Map params = new HashMap<>(); + params.put("parameters", Arrays.asList(validNestedParams1, validNestedParams2)); + + scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams); + } + + public void testValidateParametersWithInvalidListContent() { + Map params = new HashMap<>(); + params.put("parameters", Arrays.asList("invalid", "content")); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams) + ); + assertEquals("unrecognized parameters in normalization technique", exception.getMessage()); + } + + public void testValidateParametersWithMixedValidAndInvalidParameters() { + Map nestedParams = new HashMap<>(); + nestedParams.put("factor", 1.0); + nestedParams.put("invalid_nested", "value"); + + Map params = new HashMap<>(); + params.put("method", "min_max"); + params.put("parameters", nestedParams); + params.put("invalid_top", "value"); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> scoreNormalizationUtil.validateParameters(params, supportedTopLevelParams, supportedNestedParams) + ); + assertEquals("unrecognized parameters in normalization technique", exception.getMessage()); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index 3fe39554e..25b96838b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -129,71 +129,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { // explain Map searchHit1 = hitsNestedList.get(0); Map topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); - assertNotNull(topLevelExplanationsHit1); - assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - String expectedTopLevelDescription = "arithmetic_mean combination of:"; - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); - List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); - assertEquals(1, normalizationExplanationHit1.size()); - Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); - assertEquals(1.0, hit1DetailsForHit1.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); - assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); - - Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); - assertEquals("sum of:", explanationsHit1.get("description")); - assertEquals(0.343f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals(1, ((List) explanationsHit1.get("details")).size()); - - // search hit 2 - Map searchHit2 = 
hitsNestedList.get(1); - Map topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); - assertNotNull(topLevelExplanationsHit2); - assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); - List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); - assertEquals(1, normalizationExplanationHit2.size()); - - Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); - assertEquals(1.0, hit1DetailsForHit2.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); - assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); - - Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); - assertEquals(0.13f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); - assertEquals(1, getListOfValues(explanationsHit2, "details").size()); - - Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); - assertEquals(0.13f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2Details.get("description")); - assertEquals(2, getListOfValues(explanationsHit2Details, "details").size()); - - // search hit 3 - Map searchHit3 = hitsNestedList.get(1); - Map topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation"); - assertNotNull(topLevelExplanationsHit3); - assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - - assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); - List> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); - assertEquals(1, normalizationExplanationHit3.size()); - - Map hit1DetailsForHit3 = normalizationExplanationHit3.get(0); - assertEquals(1.0, hit1DetailsForHit3.get("value")); - assertEquals("min_max normalization of:", hit1DetailsForHit3.get("description")); - assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); - - Map explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); - assertEquals(0.13f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit3.get("description")); - assertEquals(1, getListOfValues(explanationsHit3, "details").size()); - - Map explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); - assertEquals(0.13f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); - assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); - assertEquals(2, getListOfValues(explanationsHit3Details, "details").size()); + assertExplanation(topLevelExplanationsHit1, searchHit1, hitsNestedList, false); } @SneakyThrows @@ -732,6 +668,154 @@ public void testExplain_whenRRFProcessor_thenSuccessful() { assertTrue((double) explanationsHit4.get("value") > 0.0f); } + @SneakyThrows + public void testExplain_whenMinMaxNormalizationWithLowerBounds_thenSuccessful() { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with 
both normalization processor and explain response processor + createSearchPipeline( + NORMALIZATION_SEARCH_PIPELINE, + DEFAULT_NORMALIZATION_METHOD, + Map.of( + "lower_bounds", + List.of( + Map.of("mode", "apply", "min_score", Float.toString(0.01f)), + Map.of("mode", "clip", "min_score", Float.toString(0.0f)) + ) + ), + DEFAULT_COMBINATION_METHOD, + Map.of(), + true + ); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", "true") + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hitsNestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map searchHit1 = hitsNestedList.get(0); + Map topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); + assertExplanation(topLevelExplanationsHit1, searchHit1, hitsNestedList, true); + } + + private void assertExplanation( + Map topLevelExplanationsHit1, + Map searchHit1, + List> hitsNestedList, + boolean withLowerBounds + ) { + assertNotNull(topLevelExplanationsHit1); + assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "arithmetic_mean combination of:"; + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); + List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); + assertEquals(1, normalizationExplanationHit1.size()); + Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + if (withLowerBounds) { + assertEquals("min_max, lower bounds [(apply, 0.01), (clip, 0.0)] normalization of:", hit1DetailsForHit1.get("description")); + } else { + assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); + } + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("sum of:", explanationsHit1.get("description")); + assertEquals(0.343f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, ((List) 
explanationsHit1.get("details")).size()); + + // search hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(topLevelExplanationsHit2); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); + List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); + assertEquals(1, normalizationExplanationHit2.size()); + + Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); + assertEquals(1.0, hit1DetailsForHit2.get("value")); + if (withLowerBounds) { + assertEquals("min_max, lower bounds [(apply, 0.01), (clip, 0.0)] normalization of:", hit1DetailsForHit2.get("description")); + } else { + assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); + } + assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals(0.13f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); + assertEquals(1, getListOfValues(explanationsHit2, "details").size()); + + Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); + assertEquals(0.13f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2Details.get("description")); + assertEquals(2, getListOfValues(explanationsHit2Details, "details").size()); + + // search hit 3 + Map searchHit3 = hitsNestedList.get(1); + Map topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(topLevelExplanationsHit3); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); + List> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); + assertEquals(1, normalizationExplanationHit3.size()); + + Map hit1DetailsForHit3 = normalizationExplanationHit3.get(0); + assertEquals(1.0, hit1DetailsForHit3.get("value")); + if (withLowerBounds) { + assertEquals("min_max, lower bounds [(apply, 0.01), (clip, 0.0)] normalization of:", hit1DetailsForHit3.get("description")); + } else { + assertEquals("min_max normalization of:", hit1DetailsForHit3.get("description")); + } + assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); + + Map explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); + assertEquals(0.13f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit3.get("description")); + assertEquals(1, getListOfValues(explanationsHit3, "details").size()); + + Map explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); + assertEquals(0.13f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); + assertEquals(2, getListOfValues(explanationsHit3Details, "details").size()); + } + @SneakyThrows private 
void initializeIndexIfNotExist(String indexName) { if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) { From de3f9d48bcbf3a84255ddd6d4eda21645a47dacb Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 26 Feb 2025 18:29:04 -0800 Subject: [PATCH 5/8] Refactor getter for lowerBounds after code review comments Signed-off-by: Martin Gaievski --- .../MinMaxScoreNormalizationTechnique.java | 79 +++++++++++-------- 1 file changed, 46 insertions(+), 33 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 3bc138040..892236278 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -4,7 +4,6 @@ */ package org.opensearch.neuralsearch.processor.normalization; -import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -236,18 +235,22 @@ private float normalizeSingleScore(final float score, final float minScore, fina return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore()); } + /** + * Get lower bounds from input parameters + * @param params user provided input parameters for this technique + * @return optional list of lower bounds. Can be empty in case lower bounds are not provided + */ private Optional>> getLowerBounds(final Map params) { + // validate that the input parameters are in correct format if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_LOWER_BOUNDS)) { return Optional.empty(); } - List> lowerBounds = new ArrayList<>(); - Object lowerBoundsObj = params.get(PARAM_NAME_LOWER_BOUNDS); if (!(lowerBoundsObj instanceof List lowerBoundsParams)) { throw new IllegalArgumentException("lower_bounds must be a List"); } - + // number of lower bounds must match the number of sub-queries in a hybrid query if (lowerBoundsParams.size() > MAX_NUMBER_OF_SUB_QUERIES) { throw new IllegalArgumentException( String.format( @@ -258,42 +261,52 @@ private Optional>> getLowerBounds(final Map> lowerBounds = lowerBoundsParams.stream().map(this::parseLowerBound).collect(Collectors.toList()); - for (Object boundObj : lowerBoundsParams) { - if (!(boundObj instanceof Map)) { - throw new IllegalArgumentException("each lower bound must be a map"); - } + return Optional.of(lowerBounds); + } - @SuppressWarnings("unchecked") - Map lowerBound = (Map) boundObj; + @SuppressWarnings("unchecked") + /** + * Parse each lower bound item and return a pair of mode and min score + * @param boundObj lower bound item provided by the client + * @return a single pair of mode and min score + */ + private Pair parseLowerBound(Object boundObj) { + if (!(boundObj instanceof Map)) { + throw new IllegalArgumentException("each lower bound must be a map"); + } - try { - LowerBound.Mode mode = LowerBound.Mode.fromString( - Objects.isNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE)) - ? 
"" - : lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE).toString() - ); - float minScore; - if (Objects.isNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))) { - minScore = LowerBound.DEFAULT_LOWER_BOUND_SCORE; - } else { - minScore = Float.parseFloat(String.valueOf(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))); - } + Map lowerBound = (Map) boundObj; - Validate.isTrue( - minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE, - "min_score must be a valid finite number between %f and %f", - LowerBound.MIN_LOWER_BOUND_SCORE, - LowerBound.MAX_LOWER_BOUND_SCORE - ); + String lowerBoundModeValue = Objects.toString(lowerBound.get(PARAM_NAME_LOWER_BOUND_MODE), ""); + LowerBound.Mode mode = LowerBound.Mode.fromString(lowerBoundModeValue); + float minScore = extractAndValidateMinScore(lowerBound); + + return ImmutablePair.of(mode, minScore); + } - lowerBounds.add(ImmutablePair.of(mode, minScore)); - } catch (NumberFormatException e) { - throw new IllegalArgumentException("invalid format for min_score: must be a valid float value", e); + private float extractAndValidateMinScore(Map lowerBound) { + Object minScoreObj = lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE); + if (minScoreObj == null) { + return LowerBound.DEFAULT_LOWER_BOUND_SCORE; + } + try { + float minScore = LowerBound.DEFAULT_LOWER_BOUND_SCORE; + if (Objects.nonNull(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))) { + minScore = Float.parseFloat(String.valueOf(lowerBound.get(PARAM_NAME_LOWER_BOUND_MIN_SCORE))); } + Validate.isTrue( + minScore >= LowerBound.MIN_LOWER_BOUND_SCORE && minScore <= LowerBound.MAX_LOWER_BOUND_SCORE, + "min_score must be a valid finite number between %f and %f", + LowerBound.MIN_LOWER_BOUND_SCORE, + LowerBound.MAX_LOWER_BOUND_SCORE + ); + return minScore; + } catch (NumberFormatException e) { + throw new IllegalArgumentException("invalid format for min_score: must be a valid float value", e); } - - return Optional.of(lowerBounds); } /** From adaf0eb059d99a23843c6a04168b4b5c06f18a75 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 27 Feb 2025 09:16:52 -0800 Subject: [PATCH 6/8] Changed syntax from negation to explicit comparision with false in if conditions Signed-off-by: Martin Gaievski --- .../neuralsearch/processor/CompoundTopDocs.java | 6 +++--- .../MinMaxScoreNormalizationTechnique.java | 13 +++++++------ .../normalization/ScoreNormalizationUtil.java | 4 ++-- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java index 986a8f261..787f9d05a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java @@ -170,10 +170,10 @@ public boolean equals(Object other) { if (thisTopDoc == null) { continue; } - if (!Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits)) { + if (Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits) == false) { return false; } - if (!compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs)) { + if (compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs) == false) { return false; } } @@ -202,7 +202,7 @@ private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) { } if (firstDoc instanceof FieldDoc firstFieldDoc) { FieldDoc secondFieldDoc = (FieldDoc) secondDoc; - if (!Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields)) { + if 
From adaf0eb059d99a23843c6a04168b4b5c06f18a75 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Thu, 27 Feb 2025 09:16:52 -0800
Subject: [PATCH 6/8] Changed syntax from negation to explicit comparison with
 false in if conditions

Signed-off-by: Martin Gaievski
---
 .../neuralsearch/processor/CompoundTopDocs.java   |  6 +++---
 .../MinMaxScoreNormalizationTechnique.java        | 13 +++++++------
 .../normalization/ScoreNormalizationUtil.java     |  4 ++--
 3 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
index 986a8f261..787f9d05a 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
@@ -170,10 +170,10 @@ public boolean equals(Object other) {
             if (thisTopDoc == null) {
                 continue;
             }
-            if (!Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits)) {
+            if (Objects.equals(thisTopDoc.totalHits, thatTopDoc.totalHits) == false) {
                 return false;
             }
-            if (!compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs)) {
+            if (compareScoreDocs(thisTopDoc.scoreDocs, thatTopDoc.scoreDocs) == false) {
                 return false;
             }
         }
@@ -202,7 +202,7 @@ private boolean compareScoreDocs(ScoreDoc[] first, ScoreDoc[] second) {
             }
             if (firstDoc instanceof FieldDoc firstFieldDoc) {
                 FieldDoc secondFieldDoc = (FieldDoc) secondDoc;
-                if (!Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields)) {
+                if (Arrays.equals(firstFieldDoc.fields, secondFieldDoc.fields) == false) {
                     return false;
                 }
             }

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
index 892236278..de9916d8a 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
@@ -107,7 +107,7 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
 
     private boolean isLowerBoundsAndSubQueriesCountMismatched(List<TopDocs> topDocsPerSubQuery) {
         return lowerBoundsOptional.isPresent()
-            && !topDocsPerSubQuery.isEmpty()
+            && topDocsPerSubQuery.isEmpty() == false
            && lowerBoundsOptional.get().size() != topDocsPerSubQuery.size();
     }
 
@@ -175,7 +175,7 @@ public Map explain(final List queryTopDocs) {
         return queryTopDocs.stream()
             .filter(Objects::nonNull)
-            .filter(topDocs -> !topDocs.getTopDocs().isEmpty())
+            .filter(topDocs -> topDocs.getTopDocs().isEmpty() == false)
             .findAny()
             .get()
             .getTopDocs()
@@ -229,7 +229,7 @@ private float normalizeSingleScore(final float score, final float minScore, fina
         if (Floats.compare(maxScore, minScore) == 0 && Floats.compare(maxScore, score) == 0) {
             return SINGLE_RESULT_SCORE;
         }
-        if (!lowerBound.isEnabled()) {
+        if (lowerBound.isEnabled() == false) {
             return LowerBound.Mode.IGNORE.normalize(score, minScore, maxScore, lowerBound.getMinScore());
         }
         return lowerBound.getMode().normalize(score, minScore, maxScore, lowerBound.getMinScore());
@@ -242,14 +242,15 @@ private float normalizeSingleScore(final float score, final float minScore, fina
      */
     private Optional<List<Pair<LowerBound.Mode, Float>>> getLowerBounds(final Map<String, Object> params) {
         // validate that the input parameters are in correct format
-        if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_LOWER_BOUNDS)) {
+        if (Objects.isNull(params) || params.containsKey(PARAM_NAME_LOWER_BOUNDS) == false) {
             return Optional.empty();
         }
         Object lowerBoundsObj = params.get(PARAM_NAME_LOWER_BOUNDS);
-        if (!(lowerBoundsObj instanceof List lowerBoundsParams)) {
+        if (lowerBoundsObj instanceof List == false) {
             throw new IllegalArgumentException("lower_bounds must be a List");
         }
+        List lowerBoundsParams = (List) lowerBoundsObj;
         // number of lower bounds cannot exceed the maximum number of sub-queries allowed in a hybrid query
         if (lowerBoundsParams.size() > MAX_NUMBER_OF_SUB_QUERIES) {
             throw new IllegalArgumentException(
@@ -274,7 +275,7 @@ private Optional<List<Pair<LowerBound.Mode, Float>>> getLowerBounds(final Map
      */
     private Pair<LowerBound.Mode, Float> parseLowerBound(Object boundObj) {
-        if (!(boundObj instanceof Map)) {
+        if ((boundObj instanceof Map) == false) {
             throw new IllegalArgumentException("each lower bound must be a map");
         }
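Worth spelling out why patch 6 touches so many call sites: OpenSearch code conventions prefer an explicit == false comparison over a leading !, which is easy to overlook in a long condition during review. A contrived before/after to illustrate the convention; the names here (shouldCreate, createIndex) are hypothetical and not from this diff:

    // easy to misread in review, the negation is a single character:
    if (!indexExists(indexName) && shouldCreate) {
        createIndex(indexName);
    }
    // the style this patch converts to, the negation stands out:
    if (indexExists(indexName) == false && shouldCreate) {
        createIndex(indexName);
    }
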
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java
index 7e086abd1..a27a9a718 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java
@@ -49,7 +49,7 @@ public void validateParams(final Map actualParams, final Set

From: Martin Gaievski
Date: Thu, 27 Feb 2025 12:28:32 -0800
Subject: [PATCH 7/8] Addressing review comments, run 2

Signed-off-by: Martin Gaievski
---
 .../MinMaxScoreNormalizationTechnique.java | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
index de9916d8a..802ba8c24 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java
@@ -246,11 +246,10 @@ private Optional<List<Pair<LowerBound.Mode, Float>>> getLowerBounds(final Map
-        Object lowerBoundsObj = params.get(PARAM_NAME_LOWER_BOUNDS);
-        if (lowerBoundsObj instanceof List == false) {
-            throw new IllegalArgumentException("lower_bounds must be a List");
-        }
-        List lowerBoundsParams = (List) lowerBoundsObj;
+        List lowerBoundsParams = Optional.ofNullable(params.get(PARAM_NAME_LOWER_BOUNDS))
+            .filter(List.class::isInstance)
+            .map(List.class::cast)
+            .orElseThrow(() -> new IllegalArgumentException("lower_bounds must be a List"));
         // number of lower bounds cannot exceed the maximum number of sub-queries allowed in a hybrid query
         if (lowerBoundsParams.size() > MAX_NUMBER_OF_SUB_QUERIES) {
             throw new IllegalArgumentException(
@@ -324,7 +323,7 @@ private static class MinMaxScores {
      * Result class to hold lower bound for each sub query
      */
     @Getter
-    public static class LowerBound {
+    static class LowerBound {
         static final float MIN_LOWER_BOUND_SCORE = -10_000f;
         static final float MAX_LOWER_BOUND_SCORE = 10_000f;
         static final float DEFAULT_LOWER_BOUND_SCORE = 0.0f;
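Patch 8 below exercises all three lower-bound modes with bounds far above any real score, so it helps to restate what the modes are meant to do. LowerBound.Mode#normalize itself never appears in this series; the sketch below is an assumed reading, for illustration only, not the authoritative implementation, and it omits guards for degenerate ranges such as max == min:

    // Assumed mode semantics, sketched for illustration; the real logic lives in LowerBound.Mode.
    static float normalizeSketch(LowerBound.Mode mode, float score, float min, float max, float bound) {
        if (mode == LowerBound.Mode.IGNORE) {
            return (score - min) / (max - min);            // plain min-max, bound unused
        }
        if (mode == LowerBound.Mode.APPLY) {
            return score < bound
                ? (score - min) / (max - min)              // assumed: bound replaces the actual min
                : (score - bound) / (max - bound);         // only for scores at or above it
        }
        // CLIP: assumed, scores below the bound are clipped to it before scaling
        return (Math.max(score, bound) - bound) / (max - bound);
    }

Under that reading, a min_score such as 100.0f that exceeds every actual score keeps apply on the plain min-max path while clip degenerates, which is why the test below asserts only wide, user-visible score ranges rather than exact values.
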
From c6d179e2211eedfbdd08b1faf9d0fa9b9a09654e Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Thu, 27 Feb 2025 16:15:52 -0800
Subject: [PATCH 8/8] Adding integ test for case when lower_bound score is
 greater than actual max score

Signed-off-by: Martin Gaievski
---
 .../processor/NormalizationProcessorIT.java | 103 ++++++++++++++++++
 1 file changed, 103 insertions(+)

diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
index 64bf4573c..3f722c263 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
@@ -360,6 +360,109 @@ public void testMinMaxLowerBounds_whenMultipleShards_thenSuccessful() {
         assertQueryResults(searchResponseAsMapNoMatches, 0, true);
     }
 
+    @SneakyThrows
+    public void testMinMaxLowerBounds_whenLowerBoundsIsGreaterThanActualMinScore_thenSuccessful() {
+        String modelId = null;
+        initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME);
+        modelId = prepareModel();
+        createSearchPipeline(
+            SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES,
+            DEFAULT_NORMALIZATION_METHOD,
+            Map.of(
+                "lower_bounds",
+                List.of(
+                    Map.of("mode", "apply", "min_score", Float.toString(100.0f)),
+                    Map.of("mode", "clip", "min_score", Float.toString(100.0f))
+                )
+            ),
+            DEFAULT_COMBINATION_METHOD,
+            Map.of(),
+            false
+        );
+        int totalExpectedDocQty = 6;
+
+        NeuralQueryBuilder neuralQueryBuilder = NeuralQueryBuilder.builder()
+            .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1)
+            .queryText(TEST_DOC_TEXT1)
+            .modelId(modelId)
+            .k(6)
+            .build();
+
+        TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+
+        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
+        hybridQueryBuilder.add(neuralQueryBuilder);
+        hybridQueryBuilder.add(termQueryBuilder);
+
+        Map<String, Object> searchResponseAsMap = search(
+            TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
+            hybridQueryBuilder,
+            null,
+            6,
+            Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_2_QUERIES)
+        );
+
+        assertNotNull(searchResponseAsMap);
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(totalExpectedDocQty, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+        assertTrue(getMaxScore(searchResponseAsMap).isPresent());
+        assertTrue(Range.between(.5f, 1.0f).contains(getMaxScore(searchResponseAsMap).get()));
+        List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
+        List<String> ids = new ArrayList<>();
+        List<Double> scores = new ArrayList<>();
+        for (Map<String, Object> oneHit : hitsNestedList) {
+            ids.add((String) oneHit.get("_id"));
+            scores.add((Double) oneHit.get("_score"));
+        }
+        // verify scores order
+        assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
+
+        // verify that the scores are normalized. We need special assert logic because the combined score may vary:
+        // the neural search query is based on random vectors and returns results for every doc, which in some cases
+        // can pull the term query's 1.0 score lower.
+        float highestScore = scores.stream().max(Double::compare).get().floatValue();
+        assertTrue(Range.between(.5f, 1.0f).contains(highestScore));
+        float lowestScore = scores.stream().min(Double::compare).get().floatValue();
+        assertTrue(Range.between(.0f, .5f).contains(lowestScore));
+
+        // verify that all ids are unique
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+
+        createSearchPipeline(
+            SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES,
+            DEFAULT_NORMALIZATION_METHOD,
+            Map.of(
+                "lower_bounds",
+                List.of(
+                    Map.of("mode", "apply", "min_score", Float.toString(100.01f)),
+                    Map.of("mode", "clip", "min_score", Float.toString(1000.0f)),
+                    Map.of("mode", "ignore")
+                )
+            ),
+            DEFAULT_COMBINATION_METHOD,
+            Map.of(),
+            false
+        );
+
+        // verify the case when there is a partial match
+        HybridQueryBuilder hybridQueryBuilderPartialMatch = new HybridQueryBuilder();
+        hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3));
+        hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4));
+        hybridQueryBuilderPartialMatch.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT7));
+
+        Map<String, Object> searchResponseAsMapPartialMatch = search(
+            TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME,
+            hybridQueryBuilderPartialMatch,
+            null,
+            5,
+            Map.of("search_pipeline", SEARCH_PIPELINE_LOWER_BOUNDS_3_QUERIES)
+        );
+        assertQueryResults(searchResponseAsMapPartialMatch, 4, false, Range.between(0.33f, 1.0f));
+    }
+
     private void initializeIndexIfNotExist(String indexName) throws IOException {
         if (TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME.equalsIgnoreCase(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_ONE_SHARD_NAME)) {
             prepareKnnIndex(