-
Notifications
You must be signed in to change notification settings - Fork 198
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi queries #307
Multi queries #307
Changes from 9 commits
5e5007b
8f29129
2c11ab2
959e6b8
18202a6
b06d830
fced1f5
283b175
96a23f6
78468f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,7 +45,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties: | |
try: | ||
vectorised = available_models[model_cache_key].encode(content, normalize=normalize_embeddings, **kwargs) | ||
except UnidentifiedImageError as e: | ||
raise VectoriseError from e | ||
raise VectoriseError(str(e)) from e | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This adds the UnidentifiedImageError message to e |
||
|
||
return _convert_vectorized_output(vectorised) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,7 @@ | |
import typing | ||
import uuid | ||
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any | ||
import numpy as np | ||
from PIL import Image | ||
from marqo.tensor_search.enums import ( | ||
MediaType, MlModel, TensorField, SearchMethod, OpenSearchDataType, | ||
|
@@ -765,7 +766,8 @@ def refresh_index(config: Config, index_name: str): | |
return HttpRequests(config).post(path=F"{index_name}/_refresh") | ||
|
||
|
||
def search(config: Config, index_name: str, text: str, result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True, | ||
def search(config: Config, index_name: str, text: Union[str, dict], | ||
result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True, | ||
search_method: Union[str, SearchMethod, None] = SearchMethod.TENSOR, | ||
searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3, | ||
reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None, | ||
|
@@ -802,12 +804,15 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of | |
if offset < 0: | ||
raise errors.IllegalRequestedDocCount("search result offset cannot be less than 0!") | ||
|
||
# validate query | ||
validation.validate_query(q=text, search_method=search_method) | ||
|
||
# Validate result_count + offset <= int(max_docs_limit) | ||
max_docs_limit = utils.read_env_vars_and_defaults(EnvVars.MARQO_MAX_RETRIEVABLE_DOCS) | ||
check_upper = True if max_docs_limit is None else result_count + offset <= int(max_docs_limit) | ||
if not check_upper: | ||
upper_bound_explanation = ("The search result limit + offset must be less than or equal to the " | ||
f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ") | ||
f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ") | ||
|
||
raise errors.IllegalRequestedDocCount(f"{upper_bound_explanation} Marqo received search result limit of `{result_count}` " | ||
f"and offset of `{offset}`.") | ||
|
@@ -835,7 +840,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of | |
|
||
if search_method.upper() == SearchMethod.TENSOR: | ||
search_result = _vector_text_search( | ||
config=config, index_name=index_name, text=text, result_count=result_count, offset=offset, | ||
config=config, index_name=index_name, query=text, result_count=result_count, offset=offset, | ||
return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose, | ||
number_of_highlights=num_highlights, simplified_format=simplified_format, | ||
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve | ||
|
@@ -1001,16 +1006,17 @@ def _lexical_search( | |
|
||
|
||
def _vector_text_search( | ||
config: Config, index_name: str, text: str, result_count: int = 5, offset: int = 0, return_doc_ids=False, | ||
searchable_attributes: Iterable[str] = None, number_of_highlights=3, | ||
config: Config, index_name: str, query: Union[str, dict], result_count: int = 5, offset: int = 0, | ||
return_doc_ids=False, searchable_attributes: Iterable[str] = None, number_of_highlights=3, | ||
verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500, | ||
simplified_format=True, filter_string: str = None, device=None, | ||
attributes_to_retrieve: Optional[List[str]] = None): | ||
""" | ||
Args: | ||
config: | ||
index_name: | ||
text: | ||
query: either a string query (which can be a URL or natural language text), or a dict of | ||
<query string>:<weight float> pairs. | ||
result_count: | ||
offset: | ||
return_doc_ids: if True adds doc _id to the docs. Otherwise just returns the docs as-is | ||
|
@@ -1049,12 +1055,38 @@ def _vector_text_search( | |
raise errors.IndexNotFoundError(message="Tried to search a non-existent index: {}".format(index_name)) | ||
selected_device = config.indexing_device if device is None else device | ||
|
||
# TODO average over vectorized inputs with weights | ||
# query, weight pairs, if query is a dict: | ||
ordered_queries = None | ||
|
||
if isinstance(query, str): | ||
to_be_vectorised = [query, ] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can this just be |
||
else: # is dict: | ||
ordered_queries = list(query.items()) | ||
if index_info.index_settings[NsField.index_defaults][NsField.treat_urls_and_pointers_as_images]: | ||
text_queries = [k for k, _ in ordered_queries if _is_image(k)] | ||
image_queries = [k for k, _ in ordered_queries if not _is_image(k)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could these be extracted in 1 loop to make it faster? like
|
||
to_be_vectorised = [batch for batch in [text_queries, image_queries] if batch] | ||
else: | ||
to_be_vectorised = [[k for k, _ in ordered_queries], ] | ||
try: | ||
vectorised_text = s2_inference.vectorise( | ||
model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=text, | ||
device=selected_device, | ||
normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])[0] | ||
vectorised_text = functools.reduce(lambda x, y: x + y, | ||
[ s2_inference.vectorise( | ||
model_name=index_info.model_name, model_properties=_get_model_properties(index_info), | ||
content=batch, device=selected_device, | ||
normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings']) | ||
for batch in to_be_vectorised] | ||
) | ||
if ordered_queries: | ||
# multiple queries. We have to weight and combine them: | ||
weighted_vectors = [np.asarray(vec) * weight for vec, weight in zip(vectorised_text, [w for _, w in ordered_queries])] | ||
vectorised_text = np.mean(weighted_vectors, axis=0) | ||
if index_info.index_settings['index_defaults']['normalize_embeddings']: | ||
norm = np.linalg.norm(vectorised_text, axis=-1, keepdims=True) | ||
if norm > 0: | ||
vectorised_text /= np.linalg.norm(vectorised_text, axis=-1, keepdims=True) | ||
vectorised_text = list(vectorised_text) | ||
else: | ||
vectorised_text = vectorised_text[0] | ||
except (s2_inference_errors.UnknownModelError, | ||
s2_inference_errors.InvalidModelPropertiesError, | ||
s2_inference_errors.ModelLoadError) as model_error: | ||
|
@@ -1182,9 +1214,10 @@ def _vector_text_search( | |
"Try reducing the query's limit parameter") from e | ||
elif 'parse_exception' in response["responses"][0]["error"]["root_cause"][0]["reason"]: | ||
raise errors.InvalidArgError("Syntax error, could not parse filter string") from e | ||
elif contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]: | ||
elif (contextualised_filter | ||
and contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]): | ||
raise errors.InvalidArgError("Syntax error, could not parse filter string") from e | ||
raise e | ||
raise errors.BackendCommunicationError(f"Error communicating with Marqo-OS backend:\n{response}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We pass through the Marqo-os error as fallback, if not yet handled |
||
except (KeyError, IndexError) as e2: | ||
raise e | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mentioned this bug to li, i think he might have made the same change. looks good.