Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GITHUB-12252: Add function queries for computing similarity scores between knn vectors #12253

Merged
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
7431f86
Implementation of function values for dense vector
eliaporciani Apr 28, 2023
ffe7bab
tidy
eliaporciani Apr 28, 2023
92b9651
Add toString for DenseVectorSimilarityFunction
eliaporciani Apr 28, 2023
36d6824
minor fix
eliaporciani Apr 28, 2023
533c6fc
Use HashMap (was TreeMap) for OnHeapHnswGraph neighbors
jbellis Apr 27, 2023
9cb17b0
Fix SynonymQuery equals implementation (#12260)
javanna May 3, 2023
ca5b831
Fix MMapDirectory documentation for Java 20 (#12265)
uschindler May 5, 2023
6f8edc2
GITHUB-12224: remove KnnGraphTester (moved to luceneutil) (#12238)
msokolov May 8, 2023
ab643b7
allocate one NeighborQueue per search for results (#12255)
jbellis May 8, 2023
815e840
Don't generate stacktrace in CollectionTerminatedException (#12270)
original-brownbear May 9, 2023
2937c8d
Move changes entry for #12270 to 9.7.0 section
javanna May 9, 2023
8befea5
add missing changelog entry for #12260
javanna May 9, 2023
3f93aa6
add missing changelog entry for #12220
javanna May 9, 2023
3a6fa03
Make query timeout members final in ExitableDirectoryReader (#12274)
javanna May 9, 2023
2951138
Update javadocs for QueryTimeout (#12272)
javanna May 9, 2023
5f9dfd4
Make TimeExceededException members final (#12271)
javanna May 9, 2023
88856f7
DOAP changes for release 9.6.0
romseygeek May 10, 2023
ec5f3d4
reasoning about thread safety
alessandrobenedetti May 12, 2023
b916d84
minor rename
alessandrobenedetti May 16, 2023
a034425
Merge remote-tracking branch 'upstream/main' into dense_vector_simila…
alessandrobenedetti May 16, 2023
a48dc2e
minor rename
alessandrobenedetti May 16, 2023
01497a6
moved to assertion
alessandrobenedetti May 16, 2023
5a2eabf
tydy and checks
alessandrobenedetti May 16, 2023
7a63933
Managed case of field not indexed for vector valuesource
eliaporciani May 25, 2023
755df0e
Merge branch 'main' of github.com:apache/lucene into dense_vector_sim…
eliaporciani May 25, 2023
363aea0
spotless apply
eliaporciani May 25, 2023
108c4a8
fix typo
eliaporciani May 25, 2023
0e6aaa7
Merge branch 'main' of github.com:apache/lucene into dense_vector_sim…
eliaporciani May 25, 2023
af748e3
Merge remote-tracking branch 'upstream/main' into dense_vector_simila…
alessandrobenedetti May 30, 2023
9e835df
Merge branch 'main' of github.com:apache/lucene into dense_vector_sim…
eliaporciani Jun 13, 2023
d95c6c1
Renaming classes
eliaporciani Jun 13, 2023
69ce94d
code refactoring
eliaporciani Jun 13, 2023
72dd3fc
rename variable
eliaporciani Jun 13, 2023
481f2a0
Addressing review
eliaporciani Jun 13, 2023
901febe
tidy
eliaporciani Jun 13, 2023
7219987
Addressing review
eliaporciani Jun 13, 2023
bc50429
Addressing review: fixed hash computations
eliaporciani Jun 13, 2023
1b69551
Changed default similarity from NaN to 0.f
eliaporciani Jun 13, 2023
8ece485
Changed constructors using arrays instead of lists
eliaporciani Jun 14, 2023
66e1987
Merge branch 'main' of github.com:apache/lucene into dense_vector_sim…
eliaporciani Jun 14, 2023
cf1e31c
Updated CHANGES.txt
eliaporciani Jun 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ public boolean boolVal(int doc) throws IOException {
return intVal(doc) != 0;
}

public float[] floatVectorVal(int doc) throws IOException {
throw new UnsupportedOperationException();
}

public byte[] byteVectorVal(int doc) throws IOException {
throw new UnsupportedOperationException();
}

/**
* returns the bytes representation of the string val - TODO: should this return the indexed raw
* bytes not?
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;

/**
* An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields.
*/
public class ByteKnnVectorFieldSource extends ValueSource {
private final String fieldName;

public ByteKnnVectorFieldSource(String fieldName) {
this.fieldName = fieldName;
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {

final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName);

if (vectorValues == null) {
throw new IllegalArgumentException(
"no byte vector value is indexed for field '" + fieldName + "'");
}

return new VectorFieldFunction(this) {

@Override
public byte[] byteVectorVal(int doc) throws IOException {
if (exists(doc)) {
return vectorValues.vectorValue();
} else {
return null;
}
}

@Override
protected DocIdSetIterator getVectorIterator() {
return vectorValues;
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o;
return Objects.equals(fieldName, other.fieldName);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), fieldName);
}

@Override
public String description() {
return "ByteKnnVectorFieldSource(" + fieldName + ")";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;

/**
* <code>ByteVectorSimilarityFunction</code> returns a similarity function between two knn vectors
* with byte elements.
*/
public class ByteVectorSimilarityFunction extends VectorSimilarityFunction {
public ByteVectorSimilarityFunction(
org.apache.lucene.index.VectorSimilarityFunction similarityFunction,
ValueSource vector1,
ValueSource vector2) {
super(similarityFunction, vector1, vector2);
}

@Override
protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException {

var v1 = f1.byteVectorVal(doc);
var v2 = f2.byteVectorVal(doc);

if (v1 == null || v2 == null) {
uschindler marked this conversation as resolved.
Show resolved Hide resolved
return 0.f;
}

assert v1.length == v2.length : "Vectors must have the same length";

return similarityFunction.compare(v1, v2);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;

/** Function that returns a constant byte vector value for every document. */
public class ConstKnnByteVectorValueSource extends ValueSource {
private final byte[] vector;

public ConstKnnByteVectorValueSource(List<Number> constVector) {
uschindler marked this conversation as resolved.
Show resolved Hide resolved
this.vector = new byte[constVector.size()];
for (int i = 0; i < constVector.size(); i++) {
vector[i] = constVector.get(i).byteValue();
}
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {
return new FunctionValues() {
@Override
public byte[] byteVectorVal(int doc) {
return vector;
}

@Override
public String strVal(int doc) {
return Arrays.toString(vector);
}

@Override
public String toString(int doc) throws IOException {
return description() + '=' + strVal(doc);
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o;
return Arrays.equals(vector, other.vector);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector));
}

@Override
public String description() {
return "ConstKnnByteVectorValueSource(" + Arrays.toString(vector) + ')';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;

/** Function that returns a constant float vector value for every document. */
public class ConstKnnFloatValueSource extends ValueSource {
private final float[] vector;

public ConstKnnFloatValueSource(List<Number> constVector) {
uschindler marked this conversation as resolved.
Show resolved Hide resolved
this.vector = new float[constVector.size()];
for (int i = 0; i < constVector.size(); i++) {
vector[i] = constVector.get(i).floatValue();
}
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {
return new FunctionValues() {
@Override
public float[] floatVectorVal(int doc) {
return vector;
}

@Override
public String strVal(int doc) {
return Arrays.toString(vector);
}

@Override
public String toString(int doc) throws IOException {
return description() + '=' + strVal(doc);
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o;
return Arrays.equals(vector, other.vector);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector));
}

@Override
public String description() {
return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;

/**
* An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields.
*/
public class FloatKnnVectorFieldSource extends ValueSource {
private final String fieldName;

public FloatKnnVectorFieldSource(String fieldName) {
this.fieldName = fieldName;
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {

final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName);

if (vectorValues == null) {
throw new IllegalArgumentException(
"no float vector value is indexed for field '" + fieldName + "'");
}
return new VectorFieldFunction(this) {

@Override
public float[] floatVectorVal(int doc) throws IOException {
if (exists(doc)) {
return vectorValues.vectorValue();
} else {
return null;
}
}

@Override
protected DocIdSetIterator getVectorIterator() {
return vectorValues;
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o;
return Objects.equals(fieldName, other.fieldName);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), fieldName);
}

@Override
public String description() {
return "FloatKnnVectorFieldSource(" + fieldName + ")";
}
}
Loading