Skip to content

Commit

Permalink
GITHUB-12252: Add function queries for computing similarity scores be…
Browse files Browse the repository at this point in the history
…tween knn vectors (#12253)

Co-authored-by: Alessandro Benedetti <a.benedetti@sease.io>
  • Loading branch information
eliaporciani and alessandrobenedetti committed Jun 14, 2023
1 parent 7f46d58 commit 433aa49
Show file tree
Hide file tree
Showing 11 changed files with 833 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ New Features
pass the following sysprop on Java command line:
"-Dorg.apache.lucene.store.MMapDirectory.enableMemorySegments=false" (Uwe Schindler)

* GITHUB#12252 Add function queries for computing similarity scores between knn vectors. (Elia Porciani, Alessandro Benedetti)

Improvements
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ public boolean boolVal(int doc) throws IOException {
return intVal(doc) != 0;
}

public float[] floatVectorVal(int doc) throws IOException {
throw new UnsupportedOperationException();
}

public byte[] byteVectorVal(int doc) throws IOException {
throw new UnsupportedOperationException();
}

/**
* returns the bytes representation of the string val - TODO: should this return the indexed raw
* bytes not?
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;

/**
* An implementation for retrieving {@link FunctionValues} instances for byte knn vectors fields.
*/
public class ByteKnnVectorFieldSource extends ValueSource {
private final String fieldName;

public ByteKnnVectorFieldSource(String fieldName) {
this.fieldName = fieldName;
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {

final ByteVectorValues vectorValues = readerContext.reader().getByteVectorValues(fieldName);

if (vectorValues == null) {
throw new IllegalArgumentException(
"no byte vector value is indexed for field '" + fieldName + "'");
}

return new VectorFieldFunction(this) {

@Override
public byte[] byteVectorVal(int doc) throws IOException {
if (exists(doc)) {
return vectorValues.vectorValue();
} else {
return null;
}
}

@Override
protected DocIdSetIterator getVectorIterator() {
return vectorValues;
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ByteKnnVectorFieldSource other = (ByteKnnVectorFieldSource) o;
return Objects.equals(fieldName, other.fieldName);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), fieldName);
}

@Override
public String description() {
return "ByteKnnVectorFieldSource(" + fieldName + ")";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;

/**
* <code>ByteVectorSimilarityFunction</code> returns a similarity function between two knn vectors
* with byte elements.
*/
public class ByteVectorSimilarityFunction extends VectorSimilarityFunction {
public ByteVectorSimilarityFunction(
org.apache.lucene.index.VectorSimilarityFunction similarityFunction,
ValueSource vector1,
ValueSource vector2) {
super(similarityFunction, vector1, vector2);
}

@Override
protected float func(int doc, FunctionValues f1, FunctionValues f2) throws IOException {

var v1 = f1.byteVectorVal(doc);
var v2 = f2.byteVectorVal(doc);

if (v1 == null || v2 == null) {
return 0.f;
}

assert v1.length == v2.length : "Vectors must have the same length";

return similarityFunction.compare(v1, v2);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;

/** Function that returns a constant byte vector value for every document. */
public class ConstKnnByteVectorValueSource extends ValueSource {
private final byte[] vector;

public ConstKnnByteVectorValueSource(byte[] constVector) {
this.vector = Objects.requireNonNull(constVector, "constVector");
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {
return new FunctionValues() {
@Override
public byte[] byteVectorVal(int doc) {
return vector;
}

@Override
public String strVal(int doc) {
return Arrays.toString(vector);
}

@Override
public String toString(int doc) throws IOException {
return description() + '=' + strVal(doc);
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ConstKnnByteVectorValueSource other = (ConstKnnByteVectorValueSource) o;
return Arrays.equals(vector, other.vector);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector));
}

@Override
public String description() {
return "ConstKnnByteVectorValueSource(" + Arrays.toString(vector) + ')';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.util.VectorUtil;

/** Function that returns a constant float vector value for every document. */
public class ConstKnnFloatValueSource extends ValueSource {
private final float[] vector;

public ConstKnnFloatValueSource(float[] constVector) {
this.vector = VectorUtil.checkFinite(Objects.requireNonNull(constVector, "constVector"));
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {
return new FunctionValues() {
@Override
public float[] floatVectorVal(int doc) {
return vector;
}

@Override
public String strVal(int doc) {
return Arrays.toString(vector);
}

@Override
public String toString(int doc) throws IOException {
return description() + '=' + strVal(doc);
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ConstKnnFloatValueSource other = (ConstKnnFloatValueSource) o;
return Arrays.equals(vector, other.vector);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), Arrays.hashCode(vector));
}

@Override
public String description() {
return "ConstKnnFloatValueSource(" + Arrays.toString(vector) + ')';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.queries.function.valuesource;

import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.queries.function.FunctionValues;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.DocIdSetIterator;

/**
* An implementation for retrieving {@link FunctionValues} instances for float knn vectors fields.
*/
public class FloatKnnVectorFieldSource extends ValueSource {
private final String fieldName;

public FloatKnnVectorFieldSource(String fieldName) {
this.fieldName = fieldName;
}

@Override
public FunctionValues getValues(Map<Object, Object> context, LeafReaderContext readerContext)
throws IOException {

final FloatVectorValues vectorValues = readerContext.reader().getFloatVectorValues(fieldName);

if (vectorValues == null) {
throw new IllegalArgumentException(
"no float vector value is indexed for field '" + fieldName + "'");
}
return new VectorFieldFunction(this) {

@Override
public float[] floatVectorVal(int doc) throws IOException {
if (exists(doc)) {
return vectorValues.vectorValue();
} else {
return null;
}
}

@Override
protected DocIdSetIterator getVectorIterator() {
return vectorValues;
}
};
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
FloatKnnVectorFieldSource other = (FloatKnnVectorFieldSource) o;
return Objects.equals(fieldName, other.fieldName);
}

@Override
public int hashCode() {
return Objects.hash(getClass().hashCode(), fieldName);
}

@Override
public String description() {
return "FloatKnnVectorFieldSource(" + fieldName + ")";
}
}
Loading

0 comments on commit 433aa49

Please sign in to comment.