Skip to content

Commit

Permalink
SOLR-16675: dense vector function queries (#1750)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: Alessandro Benedetti <a.benedetti@sease.io>
  • Loading branch information
eliaporciani and alessandrobenedetti committed Jul 9, 2023
1 parent 8b66cbc commit 8b320c4
Showing 10 changed files with 458 additions and 18 deletions.
2 changes: 2 additions & 0 deletions solr/CHANGES.txt
Original file line number Diff line number Diff line change
@@ -46,6 +46,8 @@ New Features

* SOLR-16717: {!join} can join collections with multiple shards on both sides. (Mikhail Khludnev)

* SOLR-16675: Added function queries for dense vector similarity. (Elia Porciani, Alessandro Benedetti)

Improvements
---------------------

13 changes: 11 additions & 2 deletions solr/core/src/java/org/apache/solr/schema/DenseVectorField.java
Original file line number Diff line number Diff line change
@@ -35,6 +35,8 @@
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.ByteKnnVectorFieldSource;
import org.apache.lucene.queries.function.valuesource.FloatKnnVectorFieldSource;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
@@ -343,9 +345,16 @@ public UninvertingReader.Type getUninversionType(SchemaField sf) {

@Override
public ValueSource getValueSource(SchemaField field, QParser parser) {

switch (vectorEncoding) {
case FLOAT32:
return new FloatKnnVectorFieldSource(field.getName());
case BYTE:
return new ByteKnnVectorFieldSource(field.getName());
}

throw new SolrException(
SolrException.ErrorCode.BAD_REQUEST,
"Function queries are not supported for Dense Vector fields.");
SolrException.ErrorCode.BAD_REQUEST, "Vector encoding not supported for function queries.");
}

public Query getKnnVectorQuery(
77 changes: 77 additions & 0 deletions solr/core/src/java/org/apache/solr/search/FunctionQParser.java
Original file line number Diff line number Diff line change
@@ -18,8 +18,11 @@

import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.queries.function.FunctionQuery;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.valuesource.ConstKnnByteVectorValueSource;
import org.apache.lucene.queries.function.valuesource.ConstKnnFloatValueSource;
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.queries.function.valuesource.LiteralValueSource;
@@ -40,6 +43,9 @@ public class FunctionQParser extends QParser {
// When a field name is encountered, use the placeholder FieldNameValueSource instead of resolving
// to a real ValueSource
public static final int FLAG_USE_FIELDNAME_SOURCE = 0x04;

// When this flag is set, vector parsing uses byte encoding; otherwise float encoding is used
public static final int FLAG_PARSE_VECTOR_BYTE_ENCODING = 0x08;
public static final int FLAG_DEFAULT = FLAG_CONSUME_DELIMITER;

/**
@@ -243,6 +249,49 @@ public String parseArg() throws SyntaxError {
return val;
}

/**
 * Parses a vector literal of the form {@code [v1, v2, ...]} starting at the current position of
 * the underlying string parser, and advances the position past the closing {@code ']'}.
 *
 * <p>Elements are separated by commas; whitespace between tokens is skipped. Each element is read
 * with {@code sp.getByte()} for {@code BYTE} encoding or {@code sp.getFloat()} for {@code
 * FLOAT32}. An empty vector, a leading or doubled comma, and a trailing comma are all rejected.
 *
 * @param encoding the encoding that decides how each element is read
 * @return the parsed elements, in the order they appear
 * @throws SyntaxError if the opening or closing bracket is missing, an element or separator is
 *     misplaced, an unexpected character is found, or the encoding is not supported
 */
public List<Number> parseVector(VectorEncoding encoding) throws SyntaxError {
  List<Number> values = new ArrayList<>();
  // Guard the bounds first so that exhausted input is reported as a syntax error rather than
  // surfacing as an index exception from charAt.
  if (sp.pos >= sp.end || sp.val.charAt(sp.pos) != '[') {
    throw new SyntaxError("Missing parenthesis at the beginning of vector ");
  }
  sp.pos += 1;
  // True while the next meaningful token must be a number: right after '[' or after a ','.
  boolean valueExpected = true;
  while (sp.pos < sp.end) {
    char ch = sp.val.charAt(sp.pos);
    if (Character.isWhitespace(ch)) {
      sp.pos++;
    } else if ((ch >= '0' && ch <= '9') || ch == '.' || ch == '+' || ch == '-') {
      switch (encoding) {
        case BYTE:
          values.add(sp.getByte());
          break;
        case FLOAT32:
          values.add(sp.getFloat());
          break;
        default:
          throw new SyntaxError("Unexpected vector encoding: " + encoding);
      }
      valueExpected = false;
    } else if (ch == ',') {
      if (valueExpected) {
        // A comma where an element should be (e.g. "[,1]" or "[1,,2]") is a misplaced
        // separator, not an encoding problem; report it like any other unexpected character.
        throw new SyntaxError("Unexpected " + ch + " at position " + sp.pos);
      }
      sp.pos++;
      valueExpected = true;
    } else if (ch == ']' && !valueExpected) {
      break;
    } else {
      throw new SyntaxError("Unexpected " + ch + " at position " + sp.pos);
    }
  }
  // The loop only exits without a break when the input ran out before ']' was seen.
  if (sp.pos >= sp.end) {
    throw new SyntaxError("Missing parenthesis at the end of vector");
  }
  sp.pos++;
  return values;
}

/**
* Parse a list of ValueSource. Must be the final set of arguments to a ValueSource.
*
@@ -363,6 +412,8 @@ protected ValueSource parseValueSource(int flags) throws SyntaxError {
}
} else if (ch == '"' || ch == '\'') {
valueSource = new LiteralValueSource(sp.getQuotedString());
} else if (ch == '[') {
valueSource = parseConstVector(flags);
} else if (ch == '$') {
sp.pos++;
String param = sp.getId();
@@ -457,6 +508,32 @@ protected ValueSource parseValueSource(int flags) throws SyntaxError {
return valueSource;
}

/**
 * Parses a constant vector literal at the current position and wraps it in a constant vector
 * {@link ValueSource}.
 *
 * <p>The element encoding is BYTE when {@link #FLAG_PARSE_VECTOR_BYTE_ENCODING} is present in
 * {@code flags}, and FLOAT32 otherwise.
 *
 * @param flags parsing flags; only the byte-encoding bit is inspected here
 * @return a {@link ConstKnnByteVectorValueSource} or {@link ConstKnnFloatValueSource} holding the
 *     parsed elements
 * @throws SyntaxError if the vector literal cannot be parsed
 */
public ValueSource parseConstVector(int flags) throws SyntaxError {
  final boolean byteEncoded = (flags & FLAG_PARSE_VECTOR_BYTE_ENCODING) != 0;
  final VectorEncoding encoding = byteEncoded ? VectorEncoding.BYTE : VectorEncoding.FLOAT32;

  final List<Number> elements = parseVector(encoding);
  final int dimension = elements.size();

  if (encoding == VectorEncoding.BYTE) {
    final byte[] bytes = new byte[dimension];
    int idx = 0;
    for (Number element : elements) {
      bytes[idx++] = element.byteValue();
    }
    return new ConstKnnByteVectorValueSource(bytes);
  }

  if (encoding == VectorEncoding.FLOAT32) {
    final float[] floats = new float[dimension];
    int idx = 0;
    for (Number element : elements) {
      floats[idx++] = element.floatValue();
    }
    return new ConstKnnFloatValueSource(floats);
  }

  throw new SyntaxError("wrong vector encoding:" + encoding);
}

/**
* @lucene.experimental
*/
17 changes: 17 additions & 0 deletions solr/core/src/java/org/apache/solr/search/StrParser.java
Original file line number Diff line number Diff line change
@@ -164,6 +164,23 @@ public int getInt() {
return Integer.parseInt(new String(arr, 0, i));
}

/**
 * Skips leading whitespace, then consumes the longest run of digit, {@code '+'} and {@code '-'}
 * characters at the current position and parses it as a {@code byte}.
 *
 * <p>The position is left on the first character that is not part of that run.
 *
 * @return the parsed byte value
 * @throws NumberFormatException if the consumed text is empty or is not a valid byte
 */
public byte getByte() {
  eatws();
  final int start = pos;
  while (pos < end) {
    final char c = val.charAt(pos);
    final boolean partOfNumber = (c >= '0' && c <= '9') || c == '+' || c == '-';
    if (!partOfNumber) {
      break;
    }
    pos++;
  }

  return Byte.parseByte(val.substring(start, pos));
}

public String getId() throws SyntaxError {
return getId("Expected identifier");
}
38 changes: 38 additions & 0 deletions solr/core/src/java/org/apache/solr/search/ValueSourceParser.java
Original file line number Diff line number Diff line change
@@ -26,19 +26,23 @@
import java.util.Map;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.queries.function.FunctionScoreQuery;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.queries.function.docvalues.BoolDocValues;
import org.apache.lucene.queries.function.docvalues.DoubleDocValues;
import org.apache.lucene.queries.function.docvalues.LongDocValues;
import org.apache.lucene.queries.function.valuesource.ByteVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.ConstNumberSource;
import org.apache.lucene.queries.function.valuesource.ConstValueSource;
import org.apache.lucene.queries.function.valuesource.DefFunction;
import org.apache.lucene.queries.function.valuesource.DivFloatFunction;
import org.apache.lucene.queries.function.valuesource.DocFreqValueSource;
import org.apache.lucene.queries.function.valuesource.DoubleConstValueSource;
import org.apache.lucene.queries.function.valuesource.DualFloatFunction;
import org.apache.lucene.queries.function.valuesource.FloatVectorSimilarityFunction;
import org.apache.lucene.queries.function.valuesource.IDFValueSource;
import org.apache.lucene.queries.function.valuesource.IfFunction;
import org.apache.lucene.queries.function.valuesource.JoinDocFreqValueSource;
@@ -340,6 +344,40 @@ public ValueSource parse(FunctionQParser fp) throws SyntaxError {
}
});
alias("sum", "add");
// vectorSimilarity(<encoding>, <similarityFunction>, <vector1>, <vector2>)
// Computes the similarity between two vectors. Each vector argument is parsed as a value source,
// so it may be a field or a constant literal such as [1,2,3]; constant literals are read with
// the element encoding named by the first argument.
addParser(
    "vectorSimilarity",
    new ValueSourceParser() {
      @Override
      public ValueSource parse(FunctionQParser fp) throws SyntaxError {

        // Enum.valueOf throws unchecked exceptions on unknown or missing names; the names come
        // straight from the request, so report them as syntax errors instead.
        final String encodingArg = fp.parseArg();
        if (encodingArg == null) {
          throw new SyntaxError("vectorSimilarity requires a vector encoding as first argument");
        }
        final VectorEncoding vectorEncoding;
        try {
          vectorEncoding = VectorEncoding.valueOf(encodingArg);
        } catch (IllegalArgumentException e) {
          throw new SyntaxError("Invalid vector encoding: " + encodingArg, e);
        }

        final String similarityArg = fp.parseArg();
        if (similarityArg == null) {
          throw new SyntaxError(
              "vectorSimilarity requires a similarity function as second argument");
        }
        final VectorSimilarityFunction functionName;
        try {
          functionName = VectorSimilarityFunction.valueOf(similarityArg);
        } catch (IllegalArgumentException e) {
          throw new SyntaxError("Invalid vector similarity function: " + similarityArg, e);
        }

        // Both vector arguments are parsed with the same flags: the byte-encoding bit tells the
        // function parser how to read constant vector literals.
        int vectorEncodingFlag =
            vectorEncoding.equals(VectorEncoding.BYTE)
                ? FunctionQParser.FLAG_PARSE_VECTOR_BYTE_ENCODING
                : 0;
        int parseFlags =
            FunctionQParser.FLAG_DEFAULT
                | FunctionQParser.FLAG_CONSUME_DELIMITER
                | vectorEncodingFlag;
        ValueSource v1 = fp.parseValueSource(parseFlags);
        ValueSource v2 = fp.parseValueSource(parseFlags);

        switch (vectorEncoding) {
          case FLOAT32:
            return new FloatVectorSimilarityFunction(functionName, v1, v2);
          case BYTE:
            return new ByteVectorSimilarityFunction(functionName, v1, v2);
          default:
            throw new SyntaxError("Invalid vector encoding: " + vectorEncoding);
        }
      }
    });

addParser(
"product",
Original file line number Diff line number Diff line change
@@ -603,22 +603,6 @@ public void query_sortByVectorField_shouldThrowException() throws Exception {
}
}

/**
 * Expects a BAD_REQUEST error when a dense vector field is referenced through {@code
 * field(vector)} in the {@code fl} parameter.
 */
@Test
public void query_functionQueryUsage_shouldThrowException() throws Exception {
  final String failureDescription =
      "Running Function queries on a dense vector field should raise an Exception";
  final String expectedErrorMessage =
      "Function queries are not supported for Dense Vector fields.";

  try {
    initCore("solrconfig-basic.xml", "schema-densevector.xml");

    assertQEx(
        failureDescription,
        expectedErrorMessage,
        req("q", "*:*", "fl", "id,field(vector)"),
        SolrException.ErrorCode.BAD_REQUEST);
  } finally {
    // Always release the core, even when the assertion above fails.
    deleteCore();
  }
}

@Test
public void denseVectorField_shouldBePresentAfterAtomicUpdate() throws Exception {
try {
10 changes: 10 additions & 0 deletions solr/core/src/test/org/apache/solr/search/QueryEqualityTest.java
Original file line number Diff line number Diff line change
@@ -908,6 +908,16 @@ public void testFuncVector() throws Exception {
assertFuncEquals("vector(foo_i,sum(4,bar_i))", "vector(foo_i, sum(4,bar_i))");
}

/**
 * Checks query equality for {@code vectorSimilarity}: whitespace inside the argument list and
 * inside vector literals must not matter, and a bare field name must equal its {@code field()}
 * form.
 */
public void testFuncKnnVector() throws Exception {
  // Two constant float vectors, written compactly and with extra whitespace.
  final String floatCompact = "vectorSimilarity(FLOAT32,COSINE,[1,2,3],[4,5,6])";
  final String floatSpaced = "vectorSimilarity(FLOAT32, COSINE, [1, 2, 3], [4, 5, 6])";
  assertFuncEquals(floatCompact, floatSpaced);

  // A field argument against a constant byte vector, bare and wrapped in field().
  final String byteBareField = "vectorSimilarity(BYTE, EUCLIDEAN, bar_i, [4,5,6])";
  final String byteWrappedField = "vectorSimilarity(BYTE, EUCLIDEAN, field(bar_i), [4, 5, 6])";
  assertFuncEquals(byteBareField, byteWrappedField);
}

public void testFuncQuery() throws Exception {
SolrQueryRequest req = req("myQ", "asdf");
try {
Loading

0 comments on commit 8b320c4

Please sign in to comment.