From 88d6ec45f7d7183558a8fd561af943a5938338a3 Mon Sep 17 00:00:00 2001 From: Elbek1997 Date: Tue, 26 Jan 2021 13:42:56 +0900 Subject: [PATCH 1/4] Added L1 index and scoring --- .../knn/index/SpaceTypes.java | 3 +- .../knn/plugin/script/KNNScoringSpace.java | 29 +++++++++++ .../knn/plugin/script/KNNScoringUtil.java | 39 +++++++++++++++ .../knn/index/KNNJNITests.java | 49 +++++++++++++++++++ 4 files changed, 119 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/SpaceTypes.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/SpaceTypes.java index 73044d1b..c61f2f52 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/SpaceTypes.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/SpaceTypes.java @@ -23,7 +23,8 @@ */ public enum SpaceTypes { l2("l2"), - cosinesimil("cosinesimil"); + cosinesimil("cosinesimil"), + l1("l1"); private String value; diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpace.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpace.java index 47006677..47ce435f 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpace.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpace.java @@ -148,4 +148,33 @@ public ScoreScript getScoreScript(Map params, String field, Sear (BiFunction) this.scoringMethod, lookup, ctx); } } + + class L1 implements KNNScoringSpace { + + float[] processedQuery; + BiFunction scoringMethod; + + /** + * Constructor for L1 scoring space. L1 scoring space expects values to be of type float[]. + * + * @param query Query object that, along with the doc values, will be used to compute L1 score + * @param fieldType FieldType for the doc values that will be used + */ + public L1(Object query, MappedFieldType fieldType) { + if (!isKNNVectorFieldType(fieldType)) { + throw new IllegalArgumentException("Incompatible field_type for l1 space. The field type must " + + "be knn_vector."); + } + + this.processedQuery = parseToFloatArray(query, + ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension()); + this.scoringMethod = (float[] q, float[] v) -> 1 / (1 + KNNScoringUtil.l1distance(q, v)); + } + + public ScoreScript getScoreScript(Map params, String field, SearchLookup lookup, + LeafReaderContext ctx) throws IOException { + return new KNNScoreScript.KNNVectorType(params, this.processedQuery, field, this.scoringMethod, lookup, + ctx); + } + } } diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java index a16ff7ef..ff058c9c 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java @@ -20,6 +20,7 @@ import org.apache.logging.log4j.Logger; import java.math.BigInteger; +import java.lang.Math; import java.util.List; import java.util.Objects; @@ -207,4 +208,42 @@ public static float calculateHammingBit(BigInteger queryBigInteger, BigInteger i public static float calculateHammingBit(Long queryLong, Long inputLong) { return Long.bitCount(queryLong ^ inputLong); } + + /** + * This method calculates L1 squared distance between query vector + * and input vector + * + * @param queryVector query vector + * @param inputVector input vector + * @return L1 score + */ + public static float l1distance(float[] queryVector, float[] inputVector) { + requireEqualDimension(queryVector, inputVector); + float distance = 0; + for (int i = 0; i < inputVector.length; i++) { + float diff = queryVector[i] - inputVector[i]; + distance += Math.abs(diff); + } + return distance; + } + + /** + * Whitelisted l1distance method for users to calculate L1 distance between query vector + * and document vectors + * Example + * "script": { + * "source": "1/(1 + l1distance(params.query_vector, doc[params.field]))", + * "params": { + * "query_vector": [1, 2, 3.4], + * "field": "my_dense_vector" + * } + * } + * + * @param queryVector query vector + * @param docValues script doc values + * @return L1 score + */ + public static float l1distance(List queryVector, KNNVectorScriptDocValues docValues) { + return l1distance(toFloat(queryVector), docValues.getValue()); + } } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java index c8fa0235..d72e831e 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/index/KNNJNITests.java @@ -183,6 +183,55 @@ public Void run() { dir.close(); } + public void testQueryHnswIndexl1() throws Exception { + int[] docs = {0, 1, 2}; + + float[][] vectors = { + {5.0f, 6.0f, 7.0f, 8.0f}, + {1.0f, 2.0f, 3.0f, 4.0f}, + {9.0f, 10.0f, 11.0f, 12.0f} + }; + + Directory dir = newFSDirectory(createTempDir()); + String segmentName = "_dummy1"; + String indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(dir))).getDirectory().toString(), + String.format("%s.hnsw", segmentName)).toString(); + + String[] algoParams = {}; + AccessController.doPrivileged( + new PrivilegedAction() { + public Void run() { + KNNIndex.saveIndex(docs, vectors, indexPath, algoParams, "l1"); + return null; + } + } + ); + + assertTrue(Arrays.asList(dir.listAll()).contains("_dummy1.hnsw")); + + float[] queryVector = {1.0f, 1.0f, 1.0f, 1.0f}; + String[] algoQueryParams = {"efSearch=20"}; + + final KNNIndex knnIndex = KNNIndex.loadIndex(indexPath, algoQueryParams, "l1"); + final KNNQueryResult[] results = knnIndex.queryIndex(queryVector, 30); + + Map scores = Arrays.stream(results).collect( + Collectors.toMap(result -> result.getId(), result -> result.getScore())); + logger.info(scores); + + assertEquals(results.length, 3); + /* + * scores are evaluated using Manhattan distance. Distance of the documents with + * respect to query vector are as follows + * doc0 = 22, doc1 = 6, doc2 = 38 + * Nearest neighbor is doc1 then doc0 then doc2 + */ + assertEquals(22.0, scores.get(0), 0.001); + assertEquals(6.0, scores.get(1), 0.001); + assertEquals(38.0, scores.get(2), 0.001); + dir.close(); + } + public void testQueryHnswIndexWithValidAlgoParams() throws Exception { int[] docs = {0, 1, 2}; From ff224ac7476867d3639a2dffae438251c81bc821 Mon Sep 17 00:00:00 2001 From: Elbek1997 Date: Sat, 30 Jan 2021 17:02:19 +0900 Subject: [PATCH 2/4] L1 utils and scoring test --- .../knn/index/util/KNNConstants.java | 1 + .../plugin/script/KNNScoringSpaceFactory.java | 3 + .../knn/plugin/script/KNNScriptScoringIT.java | 55 +++++++++++++++++++ 3 files changed, 59 insertions(+) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/util/KNNConstants.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/util/KNNConstants.java index 0ff7c741..70e3ae99 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/util/KNNConstants.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/index/util/KNNConstants.java @@ -22,6 +22,7 @@ public class KNNConstants { public static final String HNSW_ALGO_EF_SEARCH = "efSearch"; public static final String HNSW_ALGO_INDEX_THREAD_QTY = "indexThreadQty"; public static final String L2 = "l2"; + public static final String L1 = "l1"; public static final String COSINESIMIL = "cosinesimil"; public static final String HAMMING_BIT = "hammingbit"; public static final String DIMENSION = "dimension"; diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpaceFactory.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpaceFactory.java index 5e9f5faa..b16b3a64 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpaceFactory.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringSpaceFactory.java @@ -32,6 +32,9 @@ public static KNNScoringSpace create(String spaceType, Object query, MappedField if (KNNConstants.L2.equalsIgnoreCase(spaceType)) { return new KNNScoringSpace.L2(query, mappedFieldType); } + if (KNNConstants.L1.equalsIgnoreCase(spaceType)) { + return new KNNScoringSpace.L1(query, mappedFieldType); + } if (KNNConstants.COSINESIMIL.equalsIgnoreCase(spaceType)) { return new KNNScoringSpace.CosineSimilarity(query, mappedFieldType); diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScriptScoringIT.java index 8a386a4a..04f85b9c 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScriptScoringIT.java @@ -100,6 +100,61 @@ public void testKNNL2ScriptScore() throws Exception { assertEquals("1", results.get(3).getDocId()); } + public void testKNNL1ScriptScore() throws Exception { + /* + * Create knn index and populate data + */ + createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2)); + Float[] f1 = {6.0f, 6.0f}; + addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1); + + Float[] f2 = {4.0f, 1.0f}; + addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2); + + Float[] f3 = {3.0f, 3.0f}; + addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3); + + Float[] f4 = {5.0f, 5.0f}; + addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4); + + + /** + * Construct Search Request + */ + QueryBuilder qb = new MatchAllQueryBuilder(); + Map params = new HashMap<>(); + /* + * params": { + * "field": "my_dense_vector", + * "vector": [1.0, 1.0] + * } + */ + float[] queryVector = {1.0f, 1.0f}; + params.put("field", FIELD_NAME); + params.put("query_value", queryVector); + params.put("space_type", KNNConstants.L1); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + List expectedDocids = Arrays.asList("2", "4", "3", "1"); + + List actualDocids = new ArrayList<>(); + for(KNNResult result : results) { + actualDocids.add(result.getDocId()); + } + + assertEquals(4, results.size()); + + // assert document order + assertEquals("2", results.get(0).getDocId()); + assertEquals("3", results.get(1).getDocId()); + assertEquals("4", results.get(2).getDocId()); + assertEquals("1", results.get(3).getDocId()); + } + public void testKNNCosineScriptScore() throws Exception { /* * Create knn index and populate data From 14512ff2fa8b9bdb719f3bf6753aa2f02326e687 Mon Sep 17 00:00:00 2001 From: Elbek1997 Date: Mon, 1 Feb 2021 13:46:07 +0900 Subject: [PATCH 3/4] Comment fix --- .../knn/plugin/script/KNNScoringUtil.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java index ff058c9c..07bd5fbd 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/KNNScoringUtil.java @@ -210,7 +210,7 @@ public static float calculateHammingBit(Long queryLong, Long inputLong) { } /** - * This method calculates L1 squared distance between query vector + * This method calculates L1 distance between query vector * and input vector * * @param queryVector query vector From 22d466235685ff56ed5e7b8596f5886db041e120 Mon Sep 17 00:00:00 2001 From: Elbek1997 Date: Tue, 2 Feb 2021 21:21:28 +0900 Subject: [PATCH 4/4] L1 painless script --- .../knn/plugin/script/knn_whitelist.txt | 1 + .../script/PainlessScriptScoringIT.java | 58 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/main/resources/com/amazon/opendistroforelasticsearch/knn/plugin/script/knn_whitelist.txt b/src/main/resources/com/amazon/opendistroforelasticsearch/knn/plugin/script/knn_whitelist.txt index 54dedd22..0ef0b012 100644 --- a/src/main/resources/com/amazon/opendistroforelasticsearch/knn/plugin/script/knn_whitelist.txt +++ b/src/main/resources/com/amazon/opendistroforelasticsearch/knn/plugin/script/knn_whitelist.txt @@ -17,6 +17,7 @@ class com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues { } static_import { float l2Squared(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil + float l1distance(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil float cosineSimilarity(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil float cosineSimilarity(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues, Number) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil } diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/PainlessScriptScoringIT.java b/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/PainlessScriptScoringIT.java index b0a2db0a..c700d4d3 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/PainlessScriptScoringIT.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/knn/plugin/script/PainlessScriptScoringIT.java @@ -80,6 +80,14 @@ private Map getL2TestData() { data.put("4", new Float[]{3.0f, 3.0f}); return data; } + private Map getL1TestData() { + Map data = new HashMap<>(); + data.put("1", new Float[]{6.0f, 6.0f}); + data.put("2", new Float[]{4.0f, 1.0f}); + data.put("3", new Float[]{3.0f, 3.0f}); + data.put("4", new Float[]{5.0f, 5.0f}); + return data; + } private Map getCosineTestData() { Map data = new HashMap<>(); @@ -246,6 +254,56 @@ public void testCosineSimilarityNormalizedScriptScoreWithNumericField() throws E deleteKNNIndex(INDEX_NAME); } + // L1 tests + public void testL1ScriptScoreFails() throws Exception { + String source = String.format("1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME); + Request request = buildPainlessScriptRequest(source, 3, getL1TestData()); + addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); + expectThrows(ResponseException.class, () -> client().performRequest(request)); + deleteKNNIndex(INDEX_NAME); + } + public void testL1ScriptScore() throws Exception { + + String source = String.format("1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME); + Request request = buildPainlessScriptRequest(source, 3, getL1TestData()); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertEquals(3, results.size()); + + + String[] expectedDocIDs = {"2", "3", "4", "1"}; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + deleteKNNIndex(INDEX_NAME); + } + + public void testL1ScriptScoreWithNumericField() throws Exception { + + String source = String.format( + "doc['%s'].size() == 0 ? 0 : 1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME); + Request request = buildPainlessScriptRequest(source, 3, getL1TestData()); + addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, + RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + List results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME); + assertEquals(3, results.size()); + + + String[] expectedDocIDs = {"2", "3", "4", "1"}; + for (int i = 0; i < results.size(); i++) { + assertEquals(expectedDocIDs[i], results.get(i).getDocId()); + } + deleteKNNIndex(INDEX_NAME); + } + + class MappingProperty { private String name;