Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi queries #307

Merged
merged 10 commits into from
Feb 9, 2023
4 changes: 3 additions & 1 deletion src/marqo/s2_inference/processing/custom_clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
import os
import urllib
from tqdm import tqdm
from src.marqo.s2_inference.configs import ModelCache
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mentioned this bug to li, i think he might have made the same change. looks good.

from marqo.s2_inference.configs import ModelCache


def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
try:
vectorised = available_models[model_cache_key].encode(content, normalize=normalize_embeddings, **kwargs)
except UnidentifiedImageError as e:
raise VectoriseError from e
raise VectoriseError(str(e)) from e
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This adds the UnidentifiedImageError message to e


return _convert_vectorized_output(vectorised)

Expand Down
2 changes: 1 addition & 1 deletion src/marqo/tensor_search/models/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class SearchQuery(BaseModel):
q: str
q: Union[str, dict]
searchableAttributes: Union[None, List[str]] = None
searchMethod: Union[None, str] = "TENSOR"
limit: int = 10
Expand Down
59 changes: 46 additions & 13 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import typing
import uuid
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any
import numpy as np
from PIL import Image
from marqo.tensor_search.enums import (
MediaType, MlModel, TensorField, SearchMethod, OpenSearchDataType,
Expand Down Expand Up @@ -765,7 +766,8 @@ def refresh_index(config: Config, index_name: str):
return HttpRequests(config).post(path=F"{index_name}/_refresh")


def search(config: Config, index_name: str, text: str, result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True,
def search(config: Config, index_name: str, text: Union[str, dict],
result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True,
search_method: Union[str, SearchMethod, None] = SearchMethod.TENSOR,
searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3,
reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None,
Expand Down Expand Up @@ -802,12 +804,15 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of
if offset < 0:
raise errors.IllegalRequestedDocCount("search result offset cannot be less than 0!")

# validate query
validation.validate_query(q=text, search_method=search_method)

# Validate result_count + offset <= int(max_docs_limit)
max_docs_limit = utils.read_env_vars_and_defaults(EnvVars.MARQO_MAX_RETRIEVABLE_DOCS)
check_upper = True if max_docs_limit is None else result_count + offset <= int(max_docs_limit)
if not check_upper:
upper_bound_explanation = ("The search result limit + offset must be less than or equal to the "
f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ")
f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ")

raise errors.IllegalRequestedDocCount(f"{upper_bound_explanation} Marqo received search result limit of `{result_count}` "
f"and offset of `{offset}`.")
Expand Down Expand Up @@ -835,7 +840,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of

if search_method.upper() == SearchMethod.TENSOR:
search_result = _vector_text_search(
config=config, index_name=index_name, text=text, result_count=result_count, offset=offset,
config=config, index_name=index_name, query=text, result_count=result_count, offset=offset,
return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose,
number_of_highlights=num_highlights, simplified_format=simplified_format,
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve
Expand Down Expand Up @@ -1001,16 +1006,17 @@ def _lexical_search(


def _vector_text_search(
config: Config, index_name: str, text: str, result_count: int = 5, offset: int = 0, return_doc_ids=False,
searchable_attributes: Iterable[str] = None, number_of_highlights=3,
config: Config, index_name: str, query: Union[str, dict], result_count: int = 5, offset: int = 0,
return_doc_ids=False, searchable_attributes: Iterable[str] = None, number_of_highlights=3,
verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500,
simplified_format=True, filter_string: str = None, device=None,
attributes_to_retrieve: Optional[List[str]] = None):
"""
Args:
config:
index_name:
text:
query: either a string query (which can be a URL or natural language text), or a dict of
<query string>:<weight float> pairs.
result_count:
offset:
return_doc_ids: if True adds doc _id to the docs. Otherwise just returns the docs as-is
Expand Down Expand Up @@ -1049,12 +1055,38 @@ def _vector_text_search(
raise errors.IndexNotFoundError(message="Tried to search a non-existent index: {}".format(index_name))
selected_device = config.indexing_device if device is None else device

# TODO average over vectorized inputs with weights
# query, weight pairs, if query is a dict:
ordered_queries = None

if isinstance(query, str):
to_be_vectorised = [query, ]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this just be to_be_vectorised = [query]?

else: # is dict:
ordered_queries = list(query.items())
if index_info.index_settings[NsField.index_defaults][NsField.treat_urls_and_pointers_as_images]:
text_queries = [k for k, _ in ordered_queries if _is_image(k)]
image_queries = [k for k, _ in ordered_queries if not _is_image(k)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could these be extracted in 1 loop to make it faster? like

for k, _ in ordered_queries:
   if _is_image(k):
      text_queries.append(k)
   else:
      image_queries.append(k)

to_be_vectorised = [batch for batch in [text_queries, image_queries] if batch]
else:
to_be_vectorised = [[k for k, _ in ordered_queries], ]
try:
vectorised_text = s2_inference.vectorise(
model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=text,
device=selected_device,
normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])[0]
vectorised_text = functools.reduce(lambda x, y: x + y,
[ s2_inference.vectorise(
model_name=index_info.model_name, model_properties=_get_model_properties(index_info),
content=batch, device=selected_device,
normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])
for batch in to_be_vectorised]
)
if ordered_queries:
# multiple queries. We have to weight and combine them:
weighted_vectors = [np.asarray(vec) * weight for vec, weight in zip(vectorised_text, [w for _, w in ordered_queries])]
vectorised_text = np.mean(weighted_vectors, axis=0)
if index_info.index_settings['index_defaults']['normalize_embeddings']:
norm = np.linalg.norm(vectorised_text, axis=-1, keepdims=True)
if norm > 0:
vectorised_text /= np.linalg.norm(vectorised_text, axis=-1, keepdims=True)
vectorised_text = list(vectorised_text)
else:
vectorised_text = vectorised_text[0]
except (s2_inference_errors.UnknownModelError,
s2_inference_errors.InvalidModelPropertiesError,
s2_inference_errors.ModelLoadError) as model_error:
Expand Down Expand Up @@ -1182,9 +1214,10 @@ def _vector_text_search(
"Try reducing the query's limit parameter") from e
elif 'parse_exception' in response["responses"][0]["error"]["root_cause"][0]["reason"]:
raise errors.InvalidArgError("Syntax error, could not parse filter string") from e
elif contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]:
elif (contextualised_filter
and contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]):
raise errors.InvalidArgError("Syntax error, could not parse filter string") from e
raise e
raise errors.BackendCommunicationError(f"Error communicating with Marqo-OS backend:\n{response}")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We pass through the Marqo-os error as fallback, if not yet handled

except (KeyError, IndexError) as e2:
raise e

Expand Down
39 changes: 37 additions & 2 deletions src/marqo/tensor_search/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,52 @@
import typing
from marqo.tensor_search import constants
from marqo.tensor_search import enums, utils
from typing import Iterable, Container
from typing import Iterable, Container, Union
from marqo.errors import (
MarqoError, InvalidFieldNameError, InvalidArgError, InternalError,
InvalidDocumentIdError, DocTooLargeError, InvalidIndexNameError)
from marqo.tensor_search.enums import TensorField
from marqo.tensor_search.enums import TensorField, SearchMethod
from marqo.tensor_search import constants
from typing import Any, Type
import inspect
from enum import Enum


def validate_query(q: Union[dict, str], search_method: Union[str, SearchMethod]):
"""
Returns q if an error is not raised"""
usage_ref = "\nSee query reference here: https://docs.marqo.ai/0.0.13/API-Reference/search/#query-q"
if isinstance(q, dict):
if search_method.upper() != SearchMethod.TENSOR:
raise InvalidArgError(
'Multi-query search is currently only supported for search_method="TENSOR" '
f"\nReceived search_method `{search_method}`. {usage_ref}")
if not len(q):
raise InvalidArgError(
"Multi-query search requires at least one query! Received empty dictionary. "
f"{usage_ref}"
)
for k, v in q.items():
base_invalid_kv_message = "Multi queries dictionaries must be <string>:<float> pairs. "
if not isinstance(k, str):
raise InvalidArgError(
f"{base_invalid_kv_message}Found key of type `{type(k)}` instead of string. Key=`{k}`"
f"{usage_ref}"
)
if not isinstance(v, (int, float)):
raise InvalidArgError(
f"{base_invalid_kv_message}Found value of type `{type(v)}` instead of float. Value=`{v}`"
f" {usage_ref}"
)
elif not isinstance(q, str):
raise InvalidArgError(
f"q must be a string or dict! Received q of type `{type(q)}`. "
f"\nq=`{q}`"
f"{usage_ref}"
)
return q


def validate_str_against_enum(value: Any, enum_class: Type[Enum], case_sensitive: bool = True):
"""Checks whether a value is found as the value of a str attribute of the
given enum_class.
Expand Down
Loading