Skip to content

Commit

Permalink
Distribution-based score fusion in local mode (#703)
Browse files Browse the repository at this point in the history
* pre-implement dbsf

* add dbsf congruence tests

* mypy lints

* add conversions

* tests: add test for dbsf conversion

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
  • Loading branch information
coszio and joein authored Aug 8, 2024
1 parent 6f68b49 commit 76a30d2
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 24 deletions.
16 changes: 14 additions & 2 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,9 @@ def convert_fusion(cls, model: grpc.Fusion) -> rest.Fusion:
if model == grpc.Fusion.RRF:
return rest.Fusion.RRF

if model == grpc.Fusion.DBSF:
return rest.Fusion.DBSF

raise ValueError(f"invalid Fusion model: {model}") # pragma: no cover

@classmethod
Expand Down Expand Up @@ -2522,8 +2525,12 @@ def convert_vector_input(cls, model: rest.VectorInput) -> grpc.VectorInput:
@classmethod
def convert_recommend_input(cls, model: rest.RecommendInput) -> grpc.RecommendInput:
return grpc.RecommendInput(
positive=[cls.convert_vector_input(vector) for vector in model.positive] if model.positive is not None else None,
negative=[cls.convert_vector_input(vector) for vector in model.negative] if model.negative is not None else None,
positive=[cls.convert_vector_input(vector) for vector in model.positive]
if model.positive is not None
else None,
negative=[cls.convert_vector_input(vector) for vector in model.negative]
if model.negative is not None
else None,
strategy=cls.convert_recommend_strategy(model.strategy)
if model.strategy is not None
else None,
Expand Down Expand Up @@ -2559,6 +2566,11 @@ def convert_fusion(cls, model: rest.Fusion) -> grpc.Fusion:
if model == rest.Fusion.RRF:
return grpc.Fusion.RRF

if model == rest.Fusion.DBSF:
return grpc.Fusion.DBSF

raise ValueError(f"invalid Fusion model: {model}")

@classmethod
def convert_query(cls, model: rest.Query) -> grpc.Query:
if isinstance(model, rest.NearestQuery):
Expand Down
32 changes: 32 additions & 0 deletions qdrant_client/hybrid/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,35 @@ def compute_score(pos: int) -> float:
point.score = score
sorted_points.append(point)
return sorted_points


def distribution_based_score_fusion(
responses: List[List[models.ScoredPoint]], limit: int
) -> List[models.ScoredPoint]:
def normalize(response: List[models.ScoredPoint]) -> List[models.ScoredPoint]:
total = sum([point.score for point in response])
mean = total / len(response)
variance = sum([(point.score - mean) ** 2 for point in response]) / (len(response) - 1)
std_dev = variance**0.5

low = mean - 3 * std_dev
high = mean + 3 * std_dev

for point in response:
point.score = (point.score - low) / (high - low)

return response

points_map: dict[models.ExtendedPointId, models.ScoredPoint] = {}
for response in responses:
normalized = normalize(response)
for point in normalized:
entry = points_map.get(point.id)
if entry is None:
points_map[point.id] = point
else:
entry.score += point.score

sorted_points = sorted(points_map.values(), key=lambda item: item.score, reverse=True)

return sorted_points[:limit]
30 changes: 16 additions & 14 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from qdrant_client.conversions.conversion import GrpcToRest
from qdrant_client.http import models
from qdrant_client.http.models.models import Distance, ExtendedPointId, SparseVector, OrderValue
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion
from qdrant_client.hybrid.fusion import reciprocal_rank_fusion, distribution_based_score_fusion
from qdrant_client.local.distances import (
ContextPair,
ContextQuery,
Expand Down Expand Up @@ -759,21 +759,23 @@ def _merge_sources(
# Fuse results
if query.fusion == models.Fusion.RRF:
# RRF: Reciprocal Rank Fusion
rrf_results = reciprocal_rank_fusion(responses=sources, limit=limit + offset)

# Fetch payload and vectors
ids = [point.id for point in rrf_results]
fetched_points = self.retrieve(
ids, with_payload=with_payload, with_vectors=with_vectors
)
for fetched, scored in zip(fetched_points, rrf_results):
scored.payload = fetched.payload
scored.vector = fetched.vector

return rrf_results[offset:]

fused = reciprocal_rank_fusion(responses=sources, limit=limit + offset)
elif query.fusion == models.Fusion.DBSF:
# DBSF: Distribution-Based Score Fusion
fused = distribution_based_score_fusion(responses=sources, limit=limit + offset)
else:
raise ValueError(f"Fusion method {query.fusion} does not exist")

# Fetch payload and vectors
ids = [point.id for point in fused]
fetched_points = self.retrieve(
ids, with_payload=with_payload, with_vectors=with_vectors
)
for fetched, scored in zip(fetched_points, fused):
scored.payload = fetched.payload
scored.vector = fetched.vector

return fused[offset:]
else:
# Re-score
sources_ids = set()
Expand Down
63 changes: 55 additions & 8 deletions tests/congruence_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,7 @@ def dense_query_group_with_lookup(self, client: QdrantBase) -> GroupsResult:
)

def filter_dense_query_group(
self,
client: QdrantBase,
query_filter: models.Filter
self, client: QdrantBase, query_filter: models.Filter
) -> GroupsResult:
return client.query_points_groups(
collection_name=COLLECTION_NAME,
Expand Down Expand Up @@ -264,7 +262,9 @@ def dense_queries_rescore_group(self, client: QdrantBase) -> GroupsResult:
limit=self.limit,
)

def dense_query_lookup_from_group(self, client: QdrantBase, lookup_from: models.LookupLocation) -> GroupsResult:
def dense_query_lookup_from_group(
self, client: QdrantBase, lookup_from: models.LookupLocation
) -> GroupsResult:
return client.query_points_groups(
collection_name=COLLECTION_NAME,
query=models.RecommendQuery(
Expand Down Expand Up @@ -314,7 +314,7 @@ def filter_query_scroll(
limit=10,
)

def dense_query_fusion(self, client: QdrantBase) -> models.QueryResponse:
def dense_query_rrf(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
Expand All @@ -328,7 +328,22 @@ def dense_query_fusion(self, client: QdrantBase) -> models.QueryResponse:
limit=10,
)

def deep_dense_queries_fusion(self, client: QdrantBase) -> models.QueryResponse:
def dense_query_dbsf(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
models.Prefetch(
query=self.dense_vector_query_text,
using="text",
),
models.Prefetch(query=self.dense_vector_query_code, using="code"),
],
query=models.FusionQuery(fusion=models.Fusion.DBSF),
with_payload=True,
limit=10,
)

def deep_dense_queries_rrf(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
Expand Down Expand Up @@ -357,6 +372,35 @@ def deep_dense_queries_fusion(self, client: QdrantBase) -> models.QueryResponse:
limit=10,
)

def deep_dense_queries_dbsf(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
prefetch=[
models.Prefetch(
query=self.dense_vector_query_code,
using="code",
limit=30,
prefetch=[
models.Prefetch(
query=self.dense_vector_query_image,
using="image",
limit=40,
prefetch=[
models.Prefetch(
query=self.dense_vector_query_text,
using="text",
limit=50,
)
],
)
],
)
],
query=models.FusionQuery(fusion=models.Fusion.DBSF),
with_payload=True,
limit=10,
)

def dense_queries_rescore(self, client: QdrantBase) -> models.QueryResponse:
return client.query_points(
collection_name=COLLECTION_NAME,
Expand Down Expand Up @@ -631,6 +675,7 @@ def group_by_keys():

# ---- TESTS ---- #


def test_dense_query_lookup_from_another_collection():
fixture_points = generate_fixtures(10)

Expand Down Expand Up @@ -972,8 +1017,10 @@ def test_dense_query_fusion():
remote_client = init_remote()
init_client(remote_client, fixture_points)

compare_client_results(local_client, remote_client, searcher.dense_query_fusion)
compare_client_results(local_client, remote_client, searcher.deep_dense_queries_fusion)
compare_client_results(local_client, remote_client, searcher.dense_query_rrf)
compare_client_results(local_client, remote_client, searcher.dense_query_dbsf)
compare_client_results(local_client, remote_client, searcher.deep_dense_queries_rrf)
compare_client_results(local_client, remote_client, searcher.deep_dense_queries_dbsf)


def test_dense_query_discovery_context():
Expand Down
2 changes: 2 additions & 0 deletions tests/conversions/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,11 +1076,13 @@
query_context = grpc.Query(context=context_input)
query_order_by = grpc.Query(order_by=order_by)
query_fusion = grpc.Query(fusion=grpc.Fusion.RRF)
query_fusion_dbsf = grpc.Query(fusion=grpc.Fusion.DBSF)

deep_prefetch_query = grpc.PrefetchQuery(query=query_recommend)
prefetch_query = grpc.PrefetchQuery(
prefetch=[deep_prefetch_query],
filter=filter_,
query=query_fusion_dbsf,
)
prefetch_full_query = grpc.PrefetchQuery(
prefetch=[prefetch_query],
Expand Down

0 comments on commit 76a30d2

Please sign in to comment.