Skip to content

Commit

Permalink
Random sample in local mode (#705)
Browse files Browse the repository at this point in the history
* pre-implement random sampling

* generate models

* add conversions and tests

* fix mypy lints

* tests: add test for sample random conversion

* use camelcase Sample.Random

* review fixes

* fix mypy

---------

Co-authored-by: George Panchuk <george.panchuk@qdrant.tech>
  • Loading branch information
coszio and joein committed Aug 8, 2024
1 parent 5f11d89 commit dc585b3
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 139 deletions.
20 changes: 20 additions & 0 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,13 @@ def convert_fusion(cls, model: grpc.Fusion) -> rest.Fusion:

raise ValueError(f"invalid Fusion model: {model}") # pragma: no cover

@classmethod
def convert_sample(cls, model: grpc.Sample) -> rest.Sample:
if model == grpc.Sample.Random:
return rest.Sample.RANDOM

raise ValueError(f"invalid Sample model: {model}") # pragma: no cover

@classmethod
def convert_query(cls, model: grpc.Query) -> rest.Query:
name = model.WhichOneof("variant")
Expand All @@ -967,6 +974,9 @@ def convert_query(cls, model: grpc.Query) -> rest.Query:
if name == "fusion":
return rest.FusionQuery(fusion=cls.convert_fusion(val))

if name == "sample":
return rest.SampleQuery(sample=cls.convert_sample(val))

raise ValueError(f"invalid Query model: {model}")

@classmethod
Expand Down Expand Up @@ -2579,6 +2589,13 @@ def convert_fusion(cls, model: rest.Fusion) -> grpc.Fusion:

raise ValueError(f"invalid Fusion model: {model}")

@classmethod
def convert_sample(cls, model: rest.Sample) -> grpc.Sample:
if model == rest.Sample.RANDOM:
return grpc.Sample.Random

raise ValueError(f"invalid Sample model: {model}")

@classmethod
def convert_query(cls, model: rest.Query) -> grpc.Query:
if isinstance(model, rest.NearestQuery):
Expand All @@ -2599,6 +2616,9 @@ def convert_query(cls, model: rest.Query) -> grpc.Query:
if isinstance(model, rest.FusionQuery):
return grpc.Query(fusion=cls.convert_fusion(model.fusion))

if isinstance(model, rest.SampleQuery):
return grpc.Query(sample=cls.convert_sample(model.sample))

raise ValueError(f"invalid Query model: {model}") # pragma: no cover

@classmethod
Expand Down
281 changes: 143 additions & 138 deletions qdrant_client/grpc/points_pb2.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions qdrant_client/http/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1843,6 +1843,14 @@ class RunningEnvironmentTelemetry(BaseModel):
cpu_flags: str = Field(..., description="")


class Sample(str, Enum):
RANDOM = "random"


class SampleQuery(BaseModel, extra="forbid"):
sample: "Sample" = Field(..., description="")


class ScalarQuantization(BaseModel, extra="forbid"):
scalar: "ScalarQuantizationConfig" = Field(..., description="")

Expand Down Expand Up @@ -2691,6 +2699,7 @@ def __str__(self) -> str:
ContextQuery,
OrderByQuery,
FusionQuery,
SampleQuery,
]
RangeInterface = Union[
Range,
Expand Down
53 changes: 53 additions & 0 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,16 @@ def _query_collection(
with_vectors=with_vectors,
)
return [record_to_scored_point(record) for record in records[offset:]]
elif isinstance(query, models.SampleQuery):
if query.sample == models.Sample.RANDOM:
return self._sample_randomly(
limit=limit + offset,
query_filter=query_filter,
with_payload=with_payload,
with_vectors=with_vectors,
)
else:
raise ValueError(f"Unknown Sample variant: {query.sample}")
elif isinstance(query, models.FusionQuery):
raise AssertionError("Cannot perform fusion without prefetches")
else:
Expand Down Expand Up @@ -1762,6 +1772,49 @@ def _scroll_by_value(

return result, None

def _sample_randomly(
self,
limit: int,
query_filter: Optional[types.Filter],
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
) -> List[types.ScoredPoint]:
payload_mask = calculate_payload_mask(
payloads=self.payload,
payload_filter=query_filter,
ids_inv=self.ids_inv,
)
# in deleted: 1 - deleted, 0 - not deleted
# in payload_mask: 1 - accepted, 0 - rejected
# in mask: 1 - ok, 0 - rejected
mask = payload_mask & ~self.deleted

random_scores = np.random.rand(len(self.ids))
random_order = np.argsort(random_scores)

result: list[types.ScoredPoint] = []
for idx in random_order:
if len(result) >= limit:
break

if not mask[idx]:
continue

point_id = self.ids_inv[idx]

scored_point = construct(
models.ScoredPoint,
id=point_id,
score=float(0),
version=0,
payload=self._get_payload(idx, with_payload),
vector=self._get_vectors(idx, with_vectors),
)

result.append(scored_point)

return result

def _update_point(self, point: models.PointStruct) -> None:
idx = self.ids[point.id]
self.payload[idx] = deepcopy(
Expand Down
10 changes: 10 additions & 0 deletions qdrant_client/proto/points.proto
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,15 @@ enum Fusion {
DBSF = 1; // Distribution-Based Score Fusion
}

/// Sample points from the collection
///
/// Available sampling methods:
///
/// * `random` - Random sampling
enum Sample {
Random = 0;
}

message Query {
oneof variant {
VectorInput nearest = 1; // Find the nearest neighbors to this vector.
Expand All @@ -516,6 +525,7 @@ message Query {
ContextInput context = 4; // Return points that live in positive areas.
OrderBy order_by = 5; // Order the points by a payload field.
Fusion fusion = 6; // Fuse the results of multiple prefetches.
Sample sample = 7; // Sample points from the collection.
}
}

Expand Down
28 changes: 28 additions & 0 deletions tests/congruence_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,19 @@ def dense_query_lookup_from(
def no_query_no_prefetch(cls, client: QdrantBase) -> models.QueryResponse:
return client.query_points(collection_name=COLLECTION_NAME, limit=10)

@classmethod
def random_query(cls, client: QdrantBase) -> models.QueryResponse:
result = client.query_points(
collection_name=COLLECTION_NAME,
query=models.SampleQuery(sample=models.Sample.RANDOM),
limit=100,
)

# sort to be able to compare
result.points.sort(key=lambda point: point.id)

return result


def group_by_keys():
return ["maybe", "rand_digit", "two_words", "city.name", "maybe_null", "id"]
Expand Down Expand Up @@ -1411,3 +1424,18 @@ def test_query_group(prefer_grpc):
except AssertionError as e:
print(f"\nFailed with filter {query_filter}")
raise e


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_random_sampling(prefer_grpc):
fixture_points = generate_fixtures(100)

searcher = TestSimpleSearcher()

local_client = init_local()
init_client(local_client, fixture_points)

remote_client = init_remote(prefer_grpc=prefer_grpc)
init_client(remote_client, fixture_points)

compare_client_results(local_client, remote_client, searcher.random_query)
8 changes: 7 additions & 1 deletion tests/conversions/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,13 +1077,19 @@
query_order_by = grpc.Query(order_by=order_by)
query_fusion = grpc.Query(fusion=grpc.Fusion.RRF)
query_fusion_dbsf = grpc.Query(fusion=grpc.Fusion.DBSF)
query_sample = grpc.Query(sample=grpc.Sample.Random)

deep_prefetch_query = grpc.PrefetchQuery(query=query_recommend)
prefetch_query = grpc.PrefetchQuery(
prefetch=[deep_prefetch_query],
filter=filter_,
query=query_fusion_dbsf,
)
prefetch_random_sample = grpc.PrefetchQuery(
prefetch=[deep_prefetch_query],
filter=filter_,
query=query_sample
)
prefetch_full_query = grpc.PrefetchQuery(
prefetch=[prefetch_query],
query=query_fusion,
Expand All @@ -1094,7 +1100,7 @@
lookup_from=lookup_location_1,
)
prefetch_many = grpc.PrefetchQuery(
prefetch=[prefetch_query, prefetch_full_query],
prefetch=[prefetch_query, prefetch_full_query, prefetch_random_sample],
)

health_check_reply = grpc.HealthCheckReply(
Expand Down

0 comments on commit dc585b3

Please sign in to comment.