Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Add support for L1 distance in AKNN, custom scoring and painless scripting #310

Merged
merged 4 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
*/
public enum SpaceTypes {
l2("l2"),
cosinesimil("cosinesimil");
cosinesimil("cosinesimil"),
l1("l1");

private String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class KNNConstants {
public static final String HNSW_ALGO_EF_SEARCH = "efSearch";
public static final String HNSW_ALGO_INDEX_THREAD_QTY = "indexThreadQty";
public static final String L2 = "l2";
public static final String L1 = "l1";
public static final String COSINESIMIL = "cosinesimil";
public static final String HAMMING_BIT = "hammingbit";
public static final String DIMENSION = "dimension";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,33 @@ public ScoreScript getScoreScript(Map<String, Object> params, String field, Sear
(BiFunction<BigInteger, BigInteger, Float>) this.scoringMethod, lookup, ctx);
}
}

class L1 implements KNNScoringSpace {

    float[] processedQuery;
    BiFunction<float[], float[], Float> scoringMethod;

    /**
     * Creates the L1 scoring space. The query is converted to a float[] sized by the field's dimension,
     * and each document is scored as {@code 1 / (1 + l1distance(query, docVector))}.
     *
     * @param query Query object that is parsed into the float[] compared against the doc values
     * @param fieldType FieldType of the doc values; must be a knn_vector field type
     * @throws IllegalArgumentException if fieldType is not a knn_vector field type
     */
    public L1(Object query, MappedFieldType fieldType) {
        if (!isKNNVectorFieldType(fieldType)) {
            throw new IllegalArgumentException("Incompatible field_type for l1 space. The field type must " +
                    "be knn_vector.");
        }

        // The cast is safe only after the field type check above.
        final int dimension = ((KNNVectorFieldMapper.KNNVectorFieldType) fieldType).getDimension();
        this.processedQuery = parseToFloatArray(query, dimension);

        // Adding 1 to the distance keeps the denominator non-zero when the vectors are identical.
        this.scoringMethod = (queryVector, docVector) ->
                1 / (1 + KNNScoringUtil.l1distance(queryVector, docVector));
    }

    /**
     * Builds the score script that applies the L1 scoring method to the given field.
     *
     * @param params script parameters passed through to the score script
     * @param field name of the field whose doc values are scored
     * @param lookup search lookup passed through to the score script
     * @param ctx leaf reader context passed through to the score script
     * @return a KNNScoreScript.KNNVectorType bound to the processed query and scoring method
     * @throws IOException declared by the interface method; propagated from script construction
     */
    public ScoreScript getScoreScript(Map<String, Object> params, String field, SearchLookup lookup,
                                      LeafReaderContext ctx) throws IOException {
        return new KNNScoreScript.KNNVectorType(
                params, this.processedQuery, field, this.scoringMethod, lookup, ctx);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ public static KNNScoringSpace create(String spaceType, Object query, MappedField
if (KNNConstants.L2.equalsIgnoreCase(spaceType)) {
return new KNNScoringSpace.L2(query, mappedFieldType);
}
if (KNNConstants.L1.equalsIgnoreCase(spaceType)) {
return new KNNScoringSpace.L1(query, mappedFieldType);
}

if (KNNConstants.COSINESIMIL.equalsIgnoreCase(spaceType)) {
return new KNNScoringSpace.CosineSimilarity(query, mappedFieldType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.logging.log4j.Logger;

import java.math.BigInteger;
import java.lang.Math;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -207,4 +208,42 @@ public static float calculateHammingBit(BigInteger queryBigInteger, BigInteger i
public static float calculateHammingBit(Long queryLong, Long inputLong) {
    // XOR leaves a 1 in every bit position where the two values differ; counting them gives the distance.
    final long differingBits = queryLong ^ inputLong;
    return Long.bitCount(differingBits);
}

/**
* This method calculates L1 distance between query vector
* and input vector
*
* @param queryVector query vector
* @param inputVector input vector
* @return L1 score
*/
public static float l1distance(float[] queryVector, float[] inputVector) {
elb3k marked this conversation as resolved.
Show resolved Hide resolved
requireEqualDimension(queryVector, inputVector);
float distance = 0;
for (int i = 0; i < inputVector.length; i++) {
float diff = queryVector[i] - inputVector[i];
distance += Math.abs(diff);
}
return distance;
}

/**
 * Whitelisted l1distance overload that lets scripts compute the L1 distance between a query vector
 * and a document's vector.
 * Example
 *  "script": {
 *         "source": "1/(1 + l1distance(params.query_vector, doc[params.field]))",
 *         "params": {
 *           "query_vector": [1, 2, 3.4],
 *           "field": "my_dense_vector"
 *         }
 *       }
 *
 * @param queryVector query vector, converted to a float[] before the distance is computed
 * @param docValues script doc values supplying the document vector
 * @return L1 distance between the query vector and the document vector
 */
public static float l1distance(List<Number> queryVector, KNNVectorScriptDocValues docValues) {
    final float[] query = toFloat(queryVector);
    final float[] documentVector = docValues.getValue();
    return l1distance(query, documentVector);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues {
}
static_import {
float l2Squared(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float l1distance(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float cosineSimilarity(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
float cosineSimilarity(List, com.amazon.opendistroforelasticsearch.knn.index.KNNVectorScriptDocValues, Number) from_class com.amazon.opendistroforelasticsearch.knn.plugin.script.KNNScoringUtil
}
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,55 @@ public Void run() {
dir.close();
}

/*
 * Saves three 4-dimensional vectors to an HNSW index in the "l1" space, reloads the index, queries it
 * and checks that every returned score equals the Manhattan distance to the query vector.
 */
public void testQueryHnswIndexl1() throws Exception {
    int[] docs = {0, 1, 2};

    float[][] vectors = {
            {5.0f, 6.0f, 7.0f, 8.0f},
            {1.0f, 2.0f, 3.0f, 4.0f},
            {9.0f, 10.0f, 11.0f, 12.0f}
    };

    Directory dir = newFSDirectory(createTempDir());
    String segmentName = "_dummy1";
    String indexPath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(dir))).getDirectory().toString(),
            String.format("%s.hnsw", segmentName)).toString();

    String[] algoParams = {};
    AccessController.doPrivileged(
            new PrivilegedAction<Void>() {
                public Void run() {
                    KNNIndex.saveIndex(docs, vectors, indexPath, algoParams, "l1");
                    return null;
                }
            }
    );

    // The saved index file must be visible in the test directory before it is loaded.
    assertTrue(Arrays.asList(dir.listAll()).contains("_dummy1.hnsw"));

    float[] queryVector = {1.0f, 1.0f, 1.0f, 1.0f};
    String[] algoQueryParams = {"efSearch=20"};

    final KNNIndex knnIndex = KNNIndex.loadIndex(indexPath, algoQueryParams, "l1");
    final KNNQueryResult[] results = knnIndex.queryIndex(queryVector, 30);

    Map<Integer, Float> scores = Arrays.stream(results).collect(
            Collectors.toMap(result -> result.getId(), result -> result.getScore()));
    logger.info(scores);

    // Expected value goes first so a failure reports "expected 3 but was N" rather than the reverse.
    assertEquals(3, results.length);
    /*
     * scores are evaluated using Manhattan distance. Distance of the documents with
     * respect to query vector are as follows
     * doc0 = 22, doc1 = 6, doc2 = 38
     * Nearest neighbor is doc1 then doc0 then doc2
     */
    assertEquals(22.0, scores.get(0), 0.001);
    assertEquals(6.0, scores.get(1), 0.001);
    assertEquals(38.0, scores.get(2), 0.001);
    dir.close();
}

public void testQueryHnswIndexWithValidAlgoParams() throws Exception {
int[] docs = {0, 1, 2};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,61 @@ public void testKNNL2ScriptScore() throws Exception {
assertEquals("1", results.get(3).getDocId());
}

/*
 * Indexes four 2-dimensional vectors, runs the knn score script in the l1 space against the query
 * vector [1.0, 1.0] and checks that the documents come back ordered by increasing L1 distance.
 */
public void testKNNL1ScriptScore() throws Exception {
    /*
     * Create knn index and populate data
     */
    createKnnIndex(INDEX_NAME, createKnnIndexMapping(FIELD_NAME, 2));
    Float[] f1 = {6.0f, 6.0f};
    addKnnDoc(INDEX_NAME, "1", FIELD_NAME, f1);

    Float[] f2 = {4.0f, 1.0f};
    addKnnDoc(INDEX_NAME, "2", FIELD_NAME, f2);

    Float[] f3 = {3.0f, 3.0f};
    addKnnDoc(INDEX_NAME, "3", FIELD_NAME, f3);

    Float[] f4 = {5.0f, 5.0f};
    addKnnDoc(INDEX_NAME, "4", FIELD_NAME, f4);

    /*
     * Construct Search Request
     */
    QueryBuilder qb = new MatchAllQueryBuilder();
    Map<String, Object> params = new HashMap<>();
    /*
     *   "params": {
     *       "field": FIELD_NAME,
     *       "query_value": [1.0, 1.0],
     *       "space_type": "l1"
     *   }
     */
    float[] queryVector = {1.0f, 1.0f};
    params.put("field", FIELD_NAME);
    params.put("query_value", queryVector);
    params.put("space_type", KNNConstants.L1);
    Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params);
    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK,
            RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    List<KNNResult> results = parseSearchResponse(EntityUtils.toString(response.getEntity()), FIELD_NAME);

    /*
     * L1 distances from the query vector [1.0, 1.0]:
     * doc2 = 3, doc3 = 4, doc4 = 8, doc1 = 10
     * so the nearest-first order is 2, 3, 4, 1.
     */
    List<String> expectedDocids = Arrays.asList("2", "3", "4", "1");

    List<String> actualDocids = new ArrayList<>();
    for (KNNResult result : results) {
        actualDocids.add(result.getDocId());
    }

    assertEquals(4, results.size());

    // assert document order
    assertEquals(expectedDocids, actualDocids);
}

public void testKNNCosineScriptScore() throws Exception {
/*
* Create knn index and populate data
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ private Map<String, Float[]> getL2TestData() {
data.put("4", new Float[]{3.0f, 3.0f});
return data;
}
/*
 * Fixture for the L1 script tests: maps doc ids "1".."4" to their 2-dimensional vectors.
 */
private Map<String, Float[]> getL1TestData() {
    // Position i holds the vector for doc id (i + 1).
    final Float[][] vectorsInIdOrder = {
            {6.0f, 6.0f},
            {4.0f, 1.0f},
            {3.0f, 3.0f},
            {5.0f, 5.0f}
    };

    final Map<String, Float[]> vectorsByDocId = new HashMap<>();
    for (int i = 0; i < vectorsInIdOrder.length; i++) {
        vectorsByDocId.put(String.valueOf(i + 1), vectorsInIdOrder[i]);
    }
    return vectorsByDocId;
}

private Map<String, Float[]> getCosineTestData() {
Map<String, Float[]> data = new HashMap<>();
Expand Down Expand Up @@ -246,6 +254,56 @@ public void testCosineSimilarityNormalizedScriptScoreWithNumericField() throws E
deleteKNNIndex(INDEX_NAME);
}

// L1 tests
/*
 * Expects the search to be rejected when the unguarded l1distance script runs over an index that
 * also holds a document with only a numeric field.
 * NOTE(review): presumably the failure comes from reading doc[FIELD_NAME] on the document that has
 * no vector value - confirm.
 */
public void testL1ScriptScoreFails() throws Exception {
    final String script = String.format("1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
    final Request searchRequest = buildPainlessScriptRequest(script, 3, getL1TestData());

    // Added after the request is built and before it is executed.
    addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000);

    expectThrows(ResponseException.class, () -> client().performRequest(searchRequest));
    deleteKNNIndex(INDEX_NAME);
}
/*
 * Runs the Painless script 1/(1 + l1distance(...)) against the L1 test data and checks that the
 * three returned documents are ordered by increasing L1 distance from [1.0, 1.0].
 */
public void testL1ScriptScore() throws Exception {
    final String script = String.format("1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME);
    // NOTE(review): the 3 is presumably the requested result size - confirm against buildPainlessScriptRequest.
    final Request searchRequest = buildPainlessScriptRequest(script, 3, getL1TestData());

    final Response searchResponse = client().performRequest(searchRequest);
    final String failureMessage = searchRequest.getEndpoint() + ": failed";
    final RestStatus status = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
    assertEquals(failureMessage, RestStatus.OK, status);

    final String responseBody = EntityUtils.toString(searchResponse.getEntity());
    final List<KNNResult> hits = parseSearchResponse(responseBody, FIELD_NAME);
    assertEquals(3, hits.size());

    // Full nearest-first ranking; only as many entries as there are hits get compared.
    final String[] rankedDocIds = {"2", "3", "4", "1"};
    int rank = 0;
    for (KNNResult hit : hits) {
        assertEquals(rankedDocIds[rank], hit.getDocId());
        rank++;
    }
    deleteKNNIndex(INDEX_NAME);
}

/*
 * Same ranking check as the plain L1 script test, but the script first tests
 * doc[FIELD_NAME].size() == 0 and scores such documents 0, and an extra document holding only a
 * numeric field is indexed before the search runs.
 */
public void testL1ScriptScoreWithNumericField() throws Exception {
    final String guardedScript = String.format(
            "doc['%s'].size() == 0 ? 0 : 1/(1 + l1distance([1.0f, 1.0f], doc['%s']))", FIELD_NAME, FIELD_NAME);
    final Request searchRequest = buildPainlessScriptRequest(guardedScript, 3, getL1TestData());

    // Added after the request is built and before it is executed.
    addDocWithNumericField(INDEX_NAME, "100", NUMERIC_INDEX_FIELD_NAME, 1000);

    final Response searchResponse = client().performRequest(searchRequest);
    final String failureMessage = searchRequest.getEndpoint() + ": failed";
    final RestStatus status = RestStatus.fromCode(searchResponse.getStatusLine().getStatusCode());
    assertEquals(failureMessage, RestStatus.OK, status);

    final String responseBody = EntityUtils.toString(searchResponse.getEntity());
    final List<KNNResult> hits = parseSearchResponse(responseBody, FIELD_NAME);
    assertEquals(3, hits.size());

    // Full nearest-first ranking; only as many entries as there are hits get compared.
    final String[] rankedDocIds = {"2", "3", "4", "1"};
    int rank = 0;
    for (KNNResult hit : hits) {
        assertEquals(rankedDocIds[rank], hit.getDocId());
        rank++;
    }
    deleteKNNIndex(INDEX_NAME);
}


class MappingProperty {

private String name;
Expand Down