Add query_namespaces #409

Merged (7 commits, Nov 13, 2024)
63 changes: 44 additions & 19 deletions pinecone/core/openapi/shared/api_client.py
@@ -8,6 +8,24 @@
import typing
from urllib.parse import quote
from urllib3.fields import RequestField
import time
import random

Review comment (Contributor):
Is this just manually stubbing out retries in the generated code for now? Just curious, also regarding the print statement down there and whether it should be uncommented.

Reply (Collaborator, Author):
This was just me throwing in something basic to get started. This is used to wrap __call_api, but I need to investigate the tuning of the constants and such to get the sleep intervals to sensible levels.

Also re: this being generated code, very soon I will be moving these elsewhere and not generating them, since the ApiClient and a couple of other classes don't contain any generated content.

def retry_api_call(

func, args=(), kwargs={}, retries=3, backoff=1, jitter=0.5
):
attempts = 0
while attempts < retries:
try:
return func(*args, **kwargs) # Attempt to call __call_api
except Exception as e:
attempts += 1
if attempts >= retries:
print(f"API call failed after {attempts} attempts: {e}")
raise # Re-raise exception if retries are exhausted
sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter)
# print(f"Retrying ({attempts}/{retries}) in {sleep_time:.2f} seconds after error: {e}")
time.sleep(sleep_time)


from pinecone.core.openapi.shared import rest
@@ -397,25 +415,32 @@ def call_api(
)

return self.pool.apply_async(
-            self.__call_api,
-            (
-                resource_path,
-                method,
-                path_params,
-                query_params,
-                header_params,
-                body,
-                post_params,
-                files,
-                response_type,
-                auth_settings,
-                _return_http_data_only,
-                collection_formats,
-                _preload_content,
-                _request_timeout,
-                _host,
-                _check_type,
-            ),
+            retry_api_call,
+            args=(
+                self.__call_api,  # Pass the API call function as the first argument
+                (
+                    resource_path,
+                    method,
+                    path_params,
+                    query_params,
+                    header_params,
+                    body,
+                    post_params,
+                    files,
+                    response_type,
+                    auth_settings,
+                    _return_http_data_only,
+                    collection_formats,
+                    _preload_content,
+                    _request_timeout,
+                    _host,
+                    _check_type,
+                ),
+                {},  # empty kwargs dictionary
+                3,  # retries
+                1,  # backoff time
+                0.5  # jitter
+            )
)

def request(
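Side note, not part of the diff: a minimal sketch of the sleep schedule retry_api_call produces with the constants hard-coded in call_api above (retries=3, backoff=1, jitter=0.5), to make the tuning discussion in the review thread concrete:

import random

# Standalone illustration only; mirrors the formula inside retry_api_call.
retries, backoff, jitter = 3, 1, 0.5

# A sleep happens after each failed attempt except the last, which re-raises.
for attempts in range(1, retries):
    sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter)
    print(f"after failure {attempts}: sleep {sleep_time:.2f}s")

# Roughly 1.0 to 1.5 seconds after the first failure and 2.0 to 2.5 seconds
# after the second; a third failure exhausts the retries and re-raises.

Each retry doubles the base delay and adds up to jitter seconds of random noise, which helps spread out retries from concurrent workers.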
84 changes: 82 additions & 2 deletions pinecone/data/index.py
@@ -33,6 +33,8 @@
)
from .features.bulk_import import ImportFeatureMixin
from .vector_factory import VectorFactory
from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults
from multiprocessing.pool import ApplyResult

from pinecone_plugin_interface import load_and_install as install_plugins

@@ -387,7 +389,7 @@ def query(
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
-    ) -> QueryResponse:
+    ) -> Union[QueryResponse, ApplyResult]:
"""
The Query operation searches a namespace, using a query vector.
It retrieves the ids of the most similar items in a namespace, along with their similarity scores.
Expand Down Expand Up @@ -429,6 +431,39 @@ def query(
and namespace name.
"""

response = self._query(
*args,
top_k=top_k,
vector=vector,
id=id,
namespace=namespace,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
**kwargs,
)

if kwargs.get("async_req", False):
return response
else:
return parse_query_response(response)

def _query(
self,
*args,
top_k: int,
vector: Optional[List[float]] = None,
id: Optional[str] = None,
namespace: Optional[str] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryResponse:
if len(args) > 0:
raise ValueError(
"The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')"
Expand Down Expand Up @@ -461,7 +496,52 @@ def query(
),
**{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS},
)
-        return parse_query_response(response)
+        return response

@validate_and_convert_errors
def query_namespaces(
self,
vector: List[float],
namespaces: List[str],
top_k: Optional[int] = None,
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
**kwargs,
) -> QueryNamespacesResults:
if namespaces is None or len(namespaces) == 0:
raise ValueError("At least one namespace must be specified")
if len(vector) == 0:
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
aggregator = QueryResultsAggregator(top_k=overall_topk)

target_namespaces = set(namespaces) # dedup namespaces
async_results = [
self.query(
vector=vector,
namespace=ns,
top_k=overall_topk,
filter=filter,
include_values=include_values,
include_metadata=include_metadata,
sparse_vector=sparse_vector,
async_req=True,
**kwargs,
)
for ns in target_namespaces
]

for result in async_results:
response = result.get()
aggregator.add_results(response)

final_results = aggregator.get_results()
return final_results

@validate_and_convert_errors
def update(
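A minimal usage sketch of the new method (not part of the diff; the index name, namespaces, and vector values are invented):

from pinecone import Pinecone

pc = Pinecone(api_key="YOUR_API_KEY")
index = pc.Index("example-index")  # hypothetical index name

results = index.query_namespaces(
    vector=[0.1, 0.2, 0.3],            # must be non-empty
    namespaces=["ns1", "ns2", "ns3"],  # at least one namespace is required
    top_k=5,
    include_metadata=True,
)
for match in results.matches:
    print(match.namespace, match.id, match.score)
print(results.usage)  # read units summed across all per-namespace queries

Each per-namespace query is issued with async_req=True, so the requests run concurrently on the client's thread pool and are joined with result.get() before aggregation.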
193 changes: 193 additions & 0 deletions pinecone/data/query_results_aggregator.py
@@ -0,0 +1,193 @@
from typing import List, Tuple, Optional, Any, Dict
import json
import heapq
from pinecone.core.openapi.data.models import Usage
from pinecone.core.openapi.data.models import QueryResponse as OpenAPIQueryResponse

from dataclasses import dataclass, asdict


@dataclass
class ScoredVectorWithNamespace:
namespace: str
score: float
id: str
values: List[float]
sparse_values: dict
metadata: dict

def __init__(self, aggregate_results_heap_tuple: Tuple[float, int, object, str]):
json_vector = aggregate_results_heap_tuple[2]
self.namespace = aggregate_results_heap_tuple[3]
self.id = json_vector.get("id") # type: ignore
self.score = json_vector.get("score") # type: ignore
self.values = json_vector.get("values") # type: ignore
self.sparse_values = json_vector.get("sparse_values", None) # type: ignore
self.metadata = json_vector.get("metadata", None) # type: ignore

def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(f"'{key}' not found in ScoredVectorWithNamespace")

def get(self, key, default=None):
return getattr(self, key, default)

def __repr__(self):
return json.dumps(self._truncate(asdict(self)), indent=4)

def __json__(self):
return self._truncate(asdict(self))

def _truncate(self, obj, max_items=2):
"""
Recursively traverse and truncate lists that exceed max_items length.
Only display the "... X more" message if at least 2 elements are hidden.
"""
if obj is None:
return None # Skip None values
elif isinstance(obj, list):
filtered_list = [self._truncate(i, max_items) for i in obj if i is not None]
if len(filtered_list) > max_items:
# Show the truncation message only if more than 1 item is hidden
remaining_items = len(filtered_list) - max_items
if remaining_items > 1:
return filtered_list[:max_items] + [f"... {remaining_items} more"]
else:
# If only 1 item remains, show it
return filtered_list
return filtered_list
elif isinstance(obj, dict):
# Recursively process dictionaries, omitting None values
return {k: self._truncate(v, max_items) for k, v in obj.items() if v is not None}
return obj


@dataclass
class QueryNamespacesResults:
usage: Usage
matches: List[ScoredVectorWithNamespace]

def __getitem__(self, key):
if hasattr(self, key):
return getattr(self, key)
else:
raise KeyError(f"'{key}' not found in QueryNamespacesResults")

def get(self, key, default=None):
return getattr(self, key, default)

def __repr__(self):
return json.dumps(
{
"usage": self.usage.to_dict(),
"matches": [match.__json__() for match in self.matches],
},
indent=4,
)


class QueryResultsAggregregatorNotEnoughResultsError(Exception):
def __init__(self):
super().__init__(
"Cannot interpret results without at least two matches. In order to aggregate results from multiple queries, top_k must be greater than 1 in order to correctly infer the similarity metric from scores."
)


class QueryResultsAggregatorInvalidTopKError(Exception):
def __init__(self, top_k: int):
super().__init__(
f"Invalid top_k value {top_k}. To aggregate results from multiple queries the top_k must be at least 2."
)


class QueryResultsAggregator:
def __init__(self, top_k: int):
if top_k < 2:
raise QueryResultsAggregatorInvalidTopKError(top_k)
self.top_k = top_k
self.usage_read_units = 0
self.heap: List[Tuple[float, int, object, str]] = []
self.insertion_counter = 0
self.is_dotproduct = None
self.read = False
self.final_results: Optional[QueryNamespacesResults] = None

def _is_dotproduct_index(self, matches):
        # The interpretation of the score depends on the similarity metric used.
# Unlike other index types, in indexes configured for dotproduct,
# a higher score is better. We have to infer this is the case by inspecting
# the order of the scores in the results.
for i in range(1, len(matches)):
if matches[i].get("score") > matches[i - 1].get("score"): # Found an increase
return False
return True

def _dotproduct_heap_item(self, match, ns):
return (match.get("score"), -self.insertion_counter, match, ns)

def _non_dotproduct_heap_item(self, match, ns):
return (-match.get("score"), -self.insertion_counter, match, ns)

def _process_matches(self, matches, ns, heap_item_fn):
for match in matches:
self.insertion_counter += 1
if len(self.heap) < self.top_k:
heapq.heappush(self.heap, heap_item_fn(match, ns))
else:
                # Matches within a response arrive sorted by score, so once a
                # match cannot beat the worst entry in the heap, no later match can.
if self.is_dotproduct and match["score"] < self.heap[0][0]:
# No further matches can improve the top-K heap
break
elif not self.is_dotproduct and match["score"] > -self.heap[0][0]:
# No further matches can improve the top-K heap
break
heapq.heappushpop(self.heap, heap_item_fn(match, ns))

def add_results(self, results: Dict[str, Any]):
if self.read:
# This is mainly just to sanity check in test cases which get quite confusing
# if you read results twice due to the heap being emptied when constructing
# the ordered results.
raise ValueError("Results have already been read. Cannot add more results.")

matches = results.get("matches", [])
ns: str = results.get("namespace", "")
if isinstance(results, OpenAPIQueryResponse):
self.usage_read_units += results.usage.read_units
else:
self.usage_read_units += results.get("usage", {}).get("readUnits", 0)

if len(matches) == 0:
return

if self.is_dotproduct is None:
if len(matches) == 1:
                # A response with a single match is not enough to infer the
                # similarity metric; we need at least two scores in one response
                # to compare their ordering.
raise QueryResultsAggregregatorNotEnoughResultsError()
self.is_dotproduct = self._is_dotproduct_index(matches)

if self.is_dotproduct:
self._process_matches(matches, ns, self._dotproduct_heap_item)
else:
self._process_matches(matches, ns, self._non_dotproduct_heap_item)

def get_results(self) -> QueryNamespacesResults:
if self.read:
if self.final_results is not None:
return self.final_results
else:
# I don't think this branch can ever actually be reached, but the type checker disagrees
raise ValueError("Results have already been read. Cannot get results again.")
self.read = True

self.final_results = QueryNamespacesResults(
usage=Usage(read_units=self.usage_read_units),
matches=[
ScoredVectorWithNamespace(heapq.heappop(self.heap)) for _ in range(len(self.heap))
][::-1],
)
return self.final_results
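To make the metric inference and the early-exit heap behavior concrete, here is a small sketch (not from the PR) that drives the aggregator directly with dict-shaped responses; all ids and scores are invented:

from pinecone.data.query_results_aggregator import QueryResultsAggregator

aggregator = QueryResultsAggregator(top_k=2)

# Scores descend within each response, so _is_dotproduct_index returns True
# on the first add_results call and higher scores are treated as better.
aggregator.add_results({
    "namespace": "ns1",
    "matches": [{"id": "a", "score": 0.9}, {"id": "b", "score": 0.5}],
    "usage": {"readUnits": 1},
})
aggregator.add_results({
    "namespace": "ns2",
    "matches": [{"id": "c", "score": 0.7}, {"id": "d", "score": 0.1}],
    "usage": {"readUnits": 1},
})

results = aggregator.get_results()
print([m.id for m in results.matches])  # ['a', 'c']: the top 2 overall
print(results.usage.read_units)         # 2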