-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GITHUB-12252: Add function queries for computing similarity scores be…
…tween knn vectors (#12253) Co-authored-by: Alessandro Benedetti <a.benedetti@sease.io>
- Loading branch information
1 parent
7f46d58
commit 433aa49
Showing
11 changed files
with
833 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
84 changes: 84 additions & 0 deletions
84
...ies/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.queries.function.valuesource; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.ByteVectorValues; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.queries.function.FunctionValues; | ||
import org.apache.lucene.queries.function.ValueSource; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
|
||
/** | ||
* An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields. | ||
*/ | ||
public class ByteKnnVectorFieldSource extends ValueSource { | ||
private final String fieldName; | ||
|
||
public ByteKnnVectorFieldSource(String fieldName) { | ||
this.fieldName = fieldName; | ||
} | ||
|
||
@Override | ||
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext) | ||
throws IOException { | ||
|
||
final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName); | ||
|
||
if (vectorValues == null) { | ||
throw new IllegalArgumentException( | ||
"no byte vector value is indexed for field '" + fieldName + "'"); | ||
} | ||
|
||
return new VectorFieldFunction(this) { | ||
|
||
@Override | ||
public byte[] byteVectorVal(int doc) throws IOException { | ||
if (exists(doc)) { | ||
return vectorValues.vectorValue(); | ||
} else { | ||
return null; | ||
} | ||
} | ||
|
||
@Override | ||
protected DocIdSetIterator getVectorIterator() { | ||
return vectorValues; | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o; | ||
return Objects.equals(fieldName, other.fieldName); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(getClass().hashCode(), fieldName); | ||
} | ||
|
||
@Override | ||
public String description() { | ||
return "ByteKnnVectorFieldSource(" + fieldName + ")"; | ||
} | ||
} |
49 changes: 49 additions & 0 deletions
49
...src/java/org/apache/lucene/queries/function/valuesource/ByteVectorSimilarityFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.queries.function.valuesource; | ||
|
||
import java.io.IOException; | ||
import org.apache.lucene.queries.function.FunctionValues; | ||
import org.apache.lucene.queries.function.ValueSource; | ||
|
||
/** | ||
* <code>ByteVectorSimilarityFunction</code> returns a similarity function between two knn vectors | ||
* with byte elements. | ||
*/ | ||
public class ByteVectorSimilarityFunction extends VectorSimilarityFunction { | ||
public ByteVectorSimilarityFunction( | ||
org.apache.lucene.index.VectorSimilarityFunction similarityFunction, | ||
ValueSource vector1, | ||
ValueSource vector2) { | ||
super(similarityFunction, vector1, vector2); | ||
} | ||
|
||
@Override | ||
protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException { | ||
|
||
var v1 = f1.byteVectorVal(doc); | ||
var v2 = f2.byteVectorVal(doc); | ||
|
||
if (v1 == null || v2 == null) { | ||
return 0.f; | ||
} | ||
|
||
assert v1.length == v2.length : "Vectors must have the same length"; | ||
|
||
return similarityFunction.compare(v1, v2); | ||
} | ||
} |
73 changes: 73 additions & 0 deletions
73
...rc/java/org/apache/lucene/queries/function/valuesource/ConstKnnByteVectorValueSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.queries.function.valuesource; | ||
|
||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.queries.function.FunctionValues; | ||
import org.apache.lucene.queries.function.ValueSource; | ||
|
||
/** Function that returns a constant byte vector value for every document. */ | ||
public class ConstKnnByteVectorValueSource extends ValueSource { | ||
private final byte[] vector; | ||
|
||
public ConstKnnByteVectorValueSource(byte[] constVector) { | ||
this.vector = Objects.requireNonNull(constVector, "constVector"); | ||
} | ||
|
||
@Override | ||
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext) | ||
throws IOException { | ||
return new FunctionValues() { | ||
@Override | ||
public byte[] byteVectorVal(int doc) { | ||
return vector; | ||
} | ||
|
||
@Override | ||
public String strVal(int doc) { | ||
return Arrays.toString(vector); | ||
} | ||
|
||
@Override | ||
public String toString(int doc) throws IOException { | ||
return description() + '=' + strVal(doc); | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o; | ||
return Arrays.equals(vector, other.vector); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); | ||
} | ||
|
||
@Override | ||
public String description() { | ||
return "ConstKnnByteVectorValueSource(" + Arrays.toString(vector) + ')'; | ||
} | ||
} |
74 changes: 74 additions & 0 deletions
74
...ies/src/java/org/apache/lucene/queries/function/valuesource/ConstKnnFloatValueSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.queries.function.valuesource; | ||
|
||
import java.io.IOException; | ||
import java.util.Arrays; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.queries.function.FunctionValues; | ||
import org.apache.lucene.queries.function.ValueSource; | ||
import org.apache.lucene.util.VectorUtil; | ||
|
||
/** Function that returns a constant float vector value for every document. */ | ||
public class ConstKnnFloatValueSource extends ValueSource { | ||
private final float[] vector; | ||
|
||
public ConstKnnFloatValueSource(float[] constVector) { | ||
this.vector = VectorUtil.checkFinite(Objects.requireNonNull(constVector, "constVector")); | ||
} | ||
|
||
@Override | ||
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext) | ||
throws IOException { | ||
return new FunctionValues() { | ||
@Override | ||
public float[] floatVectorVal(int doc) { | ||
return vector; | ||
} | ||
|
||
@Override | ||
public String strVal(int doc) { | ||
return Arrays.toString(vector); | ||
} | ||
|
||
@Override | ||
public String toString(int doc) throws IOException { | ||
return description() + '=' + strVal(doc); | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o; | ||
return Arrays.equals(vector, other.vector); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector)); | ||
} | ||
|
||
@Override | ||
public String description() { | ||
return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')'; | ||
} | ||
} |
83 changes: 83 additions & 0 deletions
83
...es/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package org.apache.lucene.queries.function.valuesource; | ||
|
||
import java.io.IOException; | ||
import java.util.Map; | ||
import java.util.Objects; | ||
import org.apache.lucene.index.FloatVectorValues; | ||
import org.apache.lucene.index.LeafReaderContext; | ||
import org.apache.lucene.queries.function.FunctionValues; | ||
import org.apache.lucene.queries.function.ValueSource; | ||
import org.apache.lucene.search.DocIdSetIterator; | ||
|
||
/** | ||
* An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields. | ||
*/ | ||
public class FloatKnnVectorFieldSource extends ValueSource { | ||
private final String fieldName; | ||
|
||
public FloatKnnVectorFieldSource(String fieldName) { | ||
this.fieldName = fieldName; | ||
} | ||
|
||
@Override | ||
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext) | ||
throws IOException { | ||
|
||
final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName); | ||
|
||
if (vectorValues == null) { | ||
throw new IllegalArgumentException( | ||
"no float vector value is indexed for field '" + fieldName + "'"); | ||
} | ||
return new VectorFieldFunction(this) { | ||
|
||
@Override | ||
public float[] floatVectorVal(int doc) throws IOException { | ||
if (exists(doc)) { | ||
return vectorValues.vectorValue(); | ||
} else { | ||
return null; | ||
} | ||
} | ||
|
||
@Override | ||
protected DocIdSetIterator getVectorIterator() { | ||
return vectorValues; | ||
} | ||
}; | ||
} | ||
|
||
@Override | ||
public boolean equals(Object o) { | ||
if (this == o) return true; | ||
if (o == null || getClass() != o.getClass()) return false; | ||
FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o; | ||
return Objects.equals(fieldName, other.fieldName); | ||
} | ||
|
||
@Override | ||
public int hashCode() { | ||
return Objects.hash(getClass().hashCode(), fieldName); | ||
} | ||
|
||
@Override | ||
public String description() { | ||
return "FloatKnnVectorFieldSource(" + fieldName + ")"; | ||
} | ||
} |
Oops, something went wrong.