Add query_namespaces
#409
Merged
7 commits (jhamon):
- ac06d9f Add query_namespaces
- 103b744 Add retries for threadpool requests
- 8cdee1a Add get() accessors
- 4fb3047 Add grpc implementation
- ea54772 Add integration tests for query_namespaces
- 561ba06 Remove changes to index grpc
- 5da2610 Run int tests for rest only
@@ -0,0 +1,193 @@
from typing import List, Tuple, Optional, Any, Dict
import json
import heapq
from pinecone.core.openapi.data.models import Usage
from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse

from dataclasses import dataclass, asdict


@dataclass
class ScoredVectorWithNamespace:
    namespace: str
    score: float
    id: str
    values: List[float]
    sparse_values: dict
    metadata: dict

    def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
        json_vector = aggregate_results_heap_tuple[2]
        self.namespace = aggregate_results_heap_tuple[3]
        self.id = json_vector.get("id")  # type: ignore
        self.score = json_vector.get("score")  # type: ignore
        self.values = json_vector.get("values")  # type: ignore
        self.sparse_values = json_vector.get("sparse_values", None)  # type: ignore
        self.metadata = json_vector.get("metadata", None)  # type: ignore

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(self._truncate(asdict(self)), indent=4)

    def __json__(self):
        return self._truncate(asdict(self))

    def _truncate(self, obj, max_items=2):
        """
        Recursively traverse and truncate lists that exceed max_items length.
        Only display the "... X more" message if at least 2 elements are hidden.
        """
        if obj is None:
            return None  # Skip None values
        elif isinstance(obj, list):
            filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
            if len(filtered_list) > max_items:
                # Show the truncation message only if more than 1 item is hidden
                remaining_items = len(filtered_list) - max_items
                if remaining_items > 1:
                    return filtered_list[:max_items] + [f"... {remaining_items} more"]
                else:
                    # If only 1 item remains, show it
                    return filtered_list
            return filtered_list
        elif isinstance(obj, dict):
            # Recursively process dictionaries, omitting None values
            return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
        return obj
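The truncation logic in `_truncate` can be sketched as a standalone helper (a hypothetical function for illustration, not part of the PR), which makes its behavior easy to verify in isolation:

```python
def truncate_lists(obj, max_items=2):
    # Recursively shorten lists longer than max_items, appending a
    # "... X more" marker only when at least 2 items would be hidden;
    # None values are dropped from lists and dicts along the way.
    if obj is None:
        return None
    if isinstance(obj, list):
        filtered = [truncate_lists(i, max_items) for i in obj if i is not None]
        hidden = len(filtered) - max_items
        if hidden > 1:
            return filtered[:max_items] + [f"... {hidden} more"]
        return filtered
    if isinstance(obj, dict):
        return {k: truncate_lists(v, max_items) for k, v in obj.items() if v is not None}
    return obj
```

For example, a 10-element `values` list renders as two elements plus `"... 8 more"`, while a 3-element list is shown in full because hiding a single element would not save any space.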

@dataclass
class QueryNamespacesResults:
    usage: Usage
    matches: List[ScoredVectorWithNamespace]

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        else:
            raise KeyError(f"'{key}' not found in QueryNamespacesResults")

    def get(self, key, default=None):
        return getattr(self, key, default)

    def __repr__(self):
        return json.dumps(
            {
                "usage": self.usage.to_dict(),
                "matches": [match.__json__() for match in self.matches],
            },
            indent=4,
        )
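Both dataclasses in this file support attribute access and dict-style access via the same `__getitem__`/`get` pattern. A minimal sketch of that pattern (with a hypothetical `HybridAccess` class, not part of the PR):

```python
from dataclasses import dataclass


@dataclass
class HybridAccess:
    # Illustrative stand-in for the response dataclasses above, which
    # allow both obj.field and obj["field"] access to the same data.
    namespace: str
    score: float

    def __getitem__(self, key):
        if hasattr(self, key):
            return getattr(self, key)
        raise KeyError(f"'{key}' not found in HybridAccess")

    def get(self, key, default=None):
        return getattr(self, key, default)
```

This keeps the new dataclass-based responses drop-in compatible with code that previously treated query responses as dicts.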


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
    def __init__(self):
        super().__init__(
            "Cannot interpret results without at least two matches. To aggregate results from multiple queries, top_k must be greater than 1 so the similarity metric can be inferred from the scores."
        )


class QueryResultsAggregatorInvalidTopKError(Exception):
    def __init__(self, top_k: int):
        super().__init__(
            f"Invalid top_k value {top_k}. To aggregate results from multiple queries, top_k must be at least 2."
        )


class QueryResultsAggregator:
    def __init__(self, top_k: int):
        if top_k < 2:
            raise QueryResultsAggregatorInvalidTopKError(top_k)
        self.top_k = top_k
        self.usage_read_units = 0
        self.heap: List[Tuple[float, int, object, str]] = []
        self.insertion_counter = 0
        self.is_dotproduct = None
        self.read = False
        self.final_results: Optional[QueryNamespacesResults] = None

    def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
        # Unlike other index types, in indexes configured for dotproduct,
        # a higher score is better. We have to infer this is the case by inspecting
        # the order of the scores in the results.
        for i in range(1, len(matches)):
            if matches[i].get("score") > matches[i - 1].get("score"):  # Found an increase
                return False
        return True
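The inference trick here relies on matches within a single response arriving best-first: for dotproduct indexes higher scores are better, so scores should be non-increasing, and any increase implies a distance-like (lower-is-better) metric. A minimal sketch of the same check (hypothetical helper, not part of the PR):

```python
def looks_like_dotproduct(scores):
    # Matches in a query response are ordered best-first. If scores are
    # non-increasing, a higher score is better (dotproduct-style); any
    # increase implies a lower-is-better (distance-like) metric.
    return all(b <= a for a, b in zip(scores, scores[1:]))
```

This is also why the aggregator requires at least two matches in a response before it commits to a metric: with one match there is no ordering to inspect (and with all-equal scores the check defaults to the dotproduct interpretation).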

    def _dotproduct_heap_item(self, match, ns):
        return (match.get("score"), -self.insertion_counter, match, ns)

    def _non_dotproduct_heap_item(self, match, ns):
        return (-match.get("score"), -self.insertion_counter, match, ns)

    def _process_matches(self, matches, ns, heap_item_fn):
        for match in matches:
            self.insertion_counter += 1
            if len(self.heap) < self.top_k:
                heapq.heappush(self.heap, heap_item_fn(match, ns))
            else:
                # Assume we have dotproduct scores sorted in descending order
                if self.is_dotproduct and match["score"] < self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
                    # No further matches can improve the top-K heap
                    break
                heapq.heappushpop(self.heap, heap_item_fn(match, ns))
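The core idea in `_process_matches` is a bounded min-heap of size `top_k` with an early exit: because each response's matches arrive best-first, the first candidate that cannot beat the heap root ends the scan of that response. A simplified sketch for plain descending scores (hypothetical helper, not part of the PR):

```python
import heapq


def top_k_descending(scores, k):
    # scores must be sorted best-first (descending, higher is better).
    # Keep a min-heap of size k; once full, stop as soon as a score
    # cannot beat the smallest value currently kept.
    heap = []
    for s in scores:
        if len(heap) < k:
            heapq.heappush(heap, s)
        elif s <= heap[0]:
            break  # no later score in this sorted stream can improve the heap
        else:
            heapq.heappushpop(heap, s)
    return sorted(heap, reverse=True)
```

The real implementation negates scores for lower-is-better metrics so the same min-heap works for both orderings, and adds `-insertion_counter` as a tiebreaker so equal scores never force Python to compare the raw match dicts.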

    def add_results(self, results: Dict[str, Any]):
        if self.read:
            # This is mainly a sanity check for test cases, which get quite confusing
            # if you read results twice, because the heap is emptied when constructing
            # the ordered results.
            raise ValueError("Results have already been read. Cannot add more results.")

        matches = results.get("matches", [])
        ns: str = results.get("namespace", "")
        if isinstance(results, OpenAPIQueryResponse):
            self.usage_read_units += results.usage.read_units
        else:
            self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

        if len(matches) == 0:
            return

        if self.is_dotproduct is None:
            if len(matches) == 1:
                # We need at least two matches in a single response in order
                # to infer the similarity metric from the score ordering.
                raise QueryResultsAggregregatorNotEnoughResultsError()
            self.is_dotproduct = self._is_dotproduct_index(matches)

        if self.is_dotproduct:
            self._process_matches(matches, ns, self._dotproduct_heap_item)
        else:
            self._process_matches(matches, ns, self._non_dotproduct_heap_item)

    def get_results(self) -> QueryNamespacesResults:
        if self.read:
            if self.final_results is not None:
                return self.final_results
            else:
                # I don't think this branch can ever actually be reached, but the type checker disagrees
                raise ValueError("Results have already been read. Cannot get results again.")
        self.read = True

        self.final_results = QueryNamespacesResults(
            usage=Usage(read_units=self.usage_read_units),
            matches=[
                ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
            ][::-1],
        )
        return self.final_results
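Taken as a whole, the aggregator's effect is equivalent to tagging each match with its source namespace, merging the per-namespace result lists, and keeping the global top-k by score. A compact sketch of that end-to-end behavior for a higher-is-better metric, using `heapq.nlargest` and made-up response data (hypothetical helper and inputs, not part of the PR):

```python
import heapq


def merge_namespace_results(responses, top_k):
    # responses: dicts shaped like query responses, each with a
    # "namespace" and a best-first "matches" list (higher score = better).
    # Tag every match with its namespace, then keep the global top_k.
    tagged = [
        {**m, "namespace": r["namespace"]}
        for r in responses
        for m in r["matches"]
    ]
    return heapq.nlargest(top_k, tagged, key=lambda m: m["score"])


responses = [
    {"namespace": "ns1", "matches": [{"id": "a", "score": 0.9}, {"id": "b", "score": 0.4}]},
    {"namespace": "ns2", "matches": [{"id": "c", "score": 0.7}, {"id": "d", "score": 0.2}]},
]
top = merge_namespace_results(responses, top_k=3)
```

The streaming heap in `QueryResultsAggregator` achieves the same result without materializing all matches at once, which matters when many namespaces are queried in parallel from a threadpool.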
Is this just manually stubbing out retries in the generated code for now? Just curious, also regarding the print statement down there and whether it should be uncommented.
This was just me throwing in something basic to get started. This is used to wrap __call_api, but I need to investigate the tuning of the constants and stuff to get the sleep intervals to sensible levels.
Also re: this being generated code, very soon I will be moving these elsewhere and not generating them, since the ApiClient and a couple other classes don't contain any generated content.