Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random sample in local mode #705

Merged
merged 8 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions qdrant_client/conversions/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,6 +944,13 @@ def convert_fusion(cls, model: grpc.Fusion) -> rest.Fusion:

raise ValueError(f"invalid Fusion model: {model}") # pragma: no cover

@classmethod
def convert_sample(cls, model: grpc.Sample) -> rest.Sample:
if model == grpc.Sample.Random:
return rest.Sample.RANDOM

raise ValueError(f"invalid Sample model: {model}") # pragma: no cover

@classmethod
def convert_query(cls, model: grpc.Query) -> rest.Query:
name = model.WhichOneof("variant")
Expand All @@ -967,6 +974,9 @@ def convert_query(cls, model: grpc.Query) -> rest.Query:
if name == "fusion":
return rest.FusionQuery(fusion=cls.convert_fusion(val))

if name == "sample":
return rest.SampleQuery(sample=cls.convert_sample(val))

raise ValueError(f"invalid Query model: {model}")

@classmethod
Expand Down Expand Up @@ -2571,6 +2581,13 @@ def convert_fusion(cls, model: rest.Fusion) -> grpc.Fusion:

raise ValueError(f"invalid Fusion model: {model}")

@classmethod
def convert_sample(cls, model: rest.Sample) -> grpc.Sample:
if model == rest.Sample.RANDOM:
return grpc.Sample.Random

raise ValueError(f"invalid Sample model: {model}")

@classmethod
def convert_query(cls, model: rest.Query) -> grpc.Query:
if isinstance(model, rest.NearestQuery):
Expand All @@ -2591,6 +2608,9 @@ def convert_query(cls, model: rest.Query) -> grpc.Query:
if isinstance(model, rest.FusionQuery):
return grpc.Query(fusion=cls.convert_fusion(model.fusion))

if isinstance(model, rest.SampleQuery):
return grpc.Query(sample=cls.convert_sample(model.sample))

raise ValueError(f"invalid Query model: {model}") # pragma: no cover

@classmethod
Expand Down
281 changes: 143 additions & 138 deletions qdrant_client/grpc/points_pb2.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions qdrant_client/http/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1838,6 +1838,14 @@ class RunningEnvironmentTelemetry(BaseModel):
cpu_flags: str = Field(..., description="")


class Sample(str, Enum):
RANDOM = "random"


class SampleQuery(BaseModel, extra="forbid"):
sample: "Sample" = Field(..., description="")


class ScalarQuantization(BaseModel, extra="forbid"):
scalar: "ScalarQuantizationConfig" = Field(..., description="")

Expand Down Expand Up @@ -2686,6 +2694,7 @@ def __str__(self) -> str:
ContextQuery,
OrderByQuery,
FusionQuery,
SampleQuery,
]
RangeInterface = Union[
Range,
Expand Down
56 changes: 54 additions & 2 deletions qdrant_client/local/local_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,6 @@ def query_points(

Assumes all vectors have been homogenized so that there are no ids in the inputs
"""
scored_points = []

prefetches = []
if prefetch is not None:
Expand Down Expand Up @@ -784,7 +783,7 @@ def _merge_sources(
sources_ids.add(point.id)

if len(sources_ids) == 0:
# no need to perform a query if there no matches for the sources
# no need to perform a query if there are no matches for the sources
return []
else:
filter_with_sources = _include_ids_in_filter(query_filter, list(sources_ids))
Expand Down Expand Up @@ -881,6 +880,16 @@ def _query_collection(
with_vectors=with_vectors,
)
return [record_to_scored_point(record) for record in records[offset:]]
elif isinstance(query, models.SampleQuery):
if query.sample == models.Sample.RANDOM:
return self._sample_randomly(
limit=limit + offset,
query_filter=query_filter,
with_payload=with_payload,
with_vectors=with_vectors,
)
else:
raise ValueError(f"Unknown Sample variant: {query.sample}")
elif isinstance(query, models.FusionQuery):
raise AssertionError("Cannot perform fusion without prefetches")
else:
Expand Down Expand Up @@ -1769,6 +1778,49 @@ def _scroll_by_value(

return result, None

def _sample_randomly(
self,
limit: int,
query_filter: Optional[types.Filter],
with_payload: Union[bool, Sequence[str], types.PayloadSelector] = True,
with_vectors: Union[bool, Sequence[str]] = False,
) -> List[types.ScoredPoint]:
payload_mask = calculate_payload_mask(
payloads=self.payload,
payload_filter=query_filter,
ids_inv=self.ids_inv,
)
# in deleted: 1 - deleted, 0 - not deleted
# in payload_mask: 1 - accepted, 0 - rejected
# in mask: 1 - ok, 0 - rejected
mask = payload_mask & ~self.deleted

random_scores = np.random.rand(len(self.ids))
random_order = np.argsort(random_scores)
Comment on lines +1798 to +1799
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: np.random.permutation(self.inv_ids)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but we need the internal ids to filter against the mask, not the external ones


result: list[types.ScoredPoint] = []
for idx in random_order:
if len(result) >= limit:
break

if not mask[idx]:
continue

point_id = self.ids_inv[idx]

scored_point = construct(
models.ScoredPoint,
id=point_id,
score=float(0),
version=0,
payload=self._get_payload(idx, with_payload),
vector=self._get_vectors(idx, with_vectors),
)

result.append(scored_point)

return result

def _update_point(self, point: models.PointStruct) -> None:
idx = self.ids[point.id]
self.payload[idx] = deepcopy(
Expand Down
28 changes: 19 additions & 9 deletions qdrant_client/proto/points.proto
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ message SearchParams {
optional bool exact = 2;

/*
If set to true, search will ignore quantized vector data
If set to true, search will ignore quantized vector data
*/
optional QuantizationSearchParams quantization = 3;
/*
Expand Down Expand Up @@ -363,12 +363,12 @@ message ScrollPoints {

// How to use positive and negative vectors to find the results, default is `AverageVector`.
enum RecommendStrategy {
// Average positive and negative vectors and create a single query with the formula
// Average positive and negative vectors and create a single query with the formula
// `query = avg_pos + avg_pos - avg_neg`. Then performs normal search.
AverageVector = 0;

// Uses custom search objective. Each candidate is compared against all
// examples, its score is then chosen from the `max(max_pos_score, max_neg_score)`.
// Uses custom search objective. Each candidate is compared against all
// examples, its score is then chosen from the `max(max_pos_score, max_neg_score)`.
// If the `max_neg_score` is chosen then it is squared and negated.
BestScore = 1;
}
Expand Down Expand Up @@ -434,7 +434,7 @@ message RecommendPointGroups {
message TargetVector {
oneof target {
VectorExample single = 1;

// leaving extensibility for possibly adding multi-target
}
}
Expand Down Expand Up @@ -508,6 +508,15 @@ enum Fusion {
DBSF = 1; // Distribution-Based Score Fusion
}

/// Sample points from the collection
///
/// Available sampling methods:
///
/// * `random` - Random sampling
enum Sample {
Random = 0;
}

message Query {
oneof variant {
VectorInput nearest = 1; // Find the nearest neighbors to this vector.
Expand All @@ -516,6 +525,7 @@ message Query {
ContextInput context = 4; // Return points that live in positive areas.
OrderBy order_by = 5; // Order the points by a payload field.
Fusion fusion = 6; // Fuse the results of multiple prefetches.
Sample sample = 7; // Sample points from the collection.
}
}

Expand Down Expand Up @@ -688,7 +698,7 @@ message GroupId {

message PointGroup {
GroupId id = 1; // Group id
repeated ScoredPoint hits = 2; // Points in the group
repeated ScoredPoint hits = 2; // Points in the group
RetrievedPoint lookup = 3; // Point(s) from the lookup collection that matches the group id
}

Expand Down Expand Up @@ -797,12 +807,12 @@ message Filter {
repeated Condition should = 1; // At least one of those conditions should match
repeated Condition must = 2; // All conditions must match
repeated Condition must_not = 3; // All conditions must NOT match
optional MinShould min_should = 4; // At least minimum amount of given conditions should match
optional MinShould min_should = 4; // At least minimum amount of given conditions should match
}

message MinShould {
repeated Condition conditions = 1;
uint64 min_count = 2;
repeated Condition conditions = 1;
uint64 min_count = 2;
}

message Condition {
Expand Down
28 changes: 28 additions & 0 deletions tests/congruence_tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,19 @@ def dense_query_lookup_from(
def no_query_no_prefetch(cls, client: QdrantBase) -> models.QueryResponse:
return client.query_points(collection_name=COLLECTION_NAME, limit=10)

@classmethod
def random_query(cls, client: QdrantBase) -> models.QueryResponse:
result = client.query_points(
collection_name=COLLECTION_NAME,
query=models.SampleQuery(sample=models.Sample.RANDOM),
limit=100,
)

# sort to be able to compare
result.points.sort(key=lambda point: point.id)

return result


def group_by_keys():
return ["maybe", "rand_digit", "two_words", "city.name", "maybe_null", "id"]
Expand Down Expand Up @@ -1341,3 +1354,18 @@ def test_query_group(prefer_grpc):
except AssertionError as e:
print(f"\nFailed with filter {query_filter}")
raise e


@pytest.mark.parametrize("prefer_grpc", [False, True])
def test_random_sampling(prefer_grpc):
fixture_points = generate_fixtures(100)

searcher = TestSimpleSearcher()

local_client = init_local()
init_client(local_client, fixture_points)

remote_client = init_remote(prefer_grpc=prefer_grpc)
init_client(remote_client, fixture_points)

compare_client_results(local_client, remote_client, searcher.random_query)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the random implementations between local and server equivalent?
I would expect those to return different values.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this test we basically sample all the points available in the collection
In random_query we use limit=100 and we generate 100 points in fixture_points = generate_fixtures(100)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We return all points postprocessed to be sorted by ID. Just to make sure we can return all of them

8 changes: 7 additions & 1 deletion tests/conversions/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,13 +1077,19 @@
query_order_by = grpc.Query(order_by=order_by)
query_fusion = grpc.Query(fusion=grpc.Fusion.RRF)
query_fusion_dbsf = grpc.Query(fusion=grpc.Fusion.DBSF)
query_sample = grpc.Query(sample=grpc.Sample.Random)

deep_prefetch_query = grpc.PrefetchQuery(query=query_recommend)
prefetch_query = grpc.PrefetchQuery(
prefetch=[deep_prefetch_query],
filter=filter_,
query=query_fusion_dbsf,
)
prefetch_random_sample = grpc.PrefetchQuery(
prefetch=[deep_prefetch_query],
filter=filter_,
query=query_sample
)
prefetch_full_query = grpc.PrefetchQuery(
prefetch=[prefetch_query],
query=query_fusion,
Expand All @@ -1094,7 +1100,7 @@
lookup_from=lookup_location_1,
)
prefetch_many = grpc.PrefetchQuery(
prefetch=[prefetch_query, prefetch_full_query],
prefetch=[prefetch_query, prefetch_full_query, prefetch_random_sample],
)

fixtures = {
Expand Down
Loading