Skip to content

Commit

Permalink
Multi queries (#307)
Browse files Browse the repository at this point in the history
* Added multiple search queries

* bug fix

* Made unit test for ordering

* Fixed bug for norm of 0 vector

* Added validation test

* Separating images from text on search. added tests

* added extra tests, removed print statements

* Added documentation link to validation
  • Loading branch information
pandu-k authored Feb 9, 2023
1 parent ddec690 commit d6584bf
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 31 deletions.
2 changes: 2 additions & 0 deletions src/marqo/s2_inference/processing/custom_clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import urllib
from tqdm import tqdm
from marqo.s2_inference.configs import ModelCache


def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
try:
vectorised = available_models[model_cache_key].encode(content, normalize=normalize_embeddings, **kwargs)
except UnidentifiedImageError as e:
raise VectoriseError from e
raise VectoriseError(str(e)) from e

return _convert_vectorized_output(vectorised)

Expand Down
2 changes: 1 addition & 1 deletion src/marqo/tensor_search/models/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


class SearchQuery(BaseModel):
q: str
q: Union[str, dict]
searchableAttributes: Union[None, List[str]] = None
searchMethod: Union[None, str] = "TENSOR"
limit: int = 10
Expand Down
59 changes: 46 additions & 13 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import typing
import uuid
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any
import numpy as np
from PIL import Image
from marqo.tensor_search.enums import (
MediaType, MlModel, TensorField, SearchMethod, OpenSearchDataType,
Expand Down Expand Up @@ -765,7 +766,8 @@ def refresh_index(config: Config, index_name: str):
return HttpRequests(config).post(path=F"{index_name}/_refresh")


def search(config: Config, index_name: str, text: str, result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True,
def search(config: Config, index_name: str, text: Union[str, dict],
result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True,
search_method: Union[str, SearchMethod, None] = SearchMethod.TENSOR,
searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3,
reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None,
Expand Down Expand Up @@ -802,12 +804,15 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of
if offset < 0:
raise errors.IllegalRequestedDocCount("search result offset cannot be less than 0!")

# validate query
validation.validate_query(q=text, search_method=search_method)

# Validate result_count + offset <= int(max_docs_limit)
max_docs_limit = utils.read_env_vars_and_defaults(EnvVars.MARQO_MAX_RETRIEVABLE_DOCS)
check_upper = True if max_docs_limit is None else result_count + offset <= int(max_docs_limit)
if not check_upper:
upper_bound_explanation = ("The search result limit + offset must be less than or equal to the "
f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ")
f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ")

raise errors.IllegalRequestedDocCount(f"{upper_bound_explanation} Marqo received search result limit of `{result_count}` "
f"and offset of `{offset}`.")
Expand Down Expand Up @@ -835,7 +840,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of

if search_method.upper() == SearchMethod.TENSOR:
search_result = _vector_text_search(
config=config, index_name=index_name, text=text, result_count=result_count, offset=offset,
config=config, index_name=index_name, query=text, result_count=result_count, offset=offset,
return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose,
number_of_highlights=num_highlights, simplified_format=simplified_format,
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve
Expand Down Expand Up @@ -1001,16 +1006,17 @@ def _lexical_search(


def _vector_text_search(
config: Config, index_name: str, text: str, result_count: int = 5, offset: int = 0, return_doc_ids=False,
searchable_attributes: Iterable[str] = None, number_of_highlights=3,
config: Config, index_name: str, query: Union[str, dict], result_count: int = 5, offset: int = 0,
return_doc_ids=False, searchable_attributes: Iterable[str] = None, number_of_highlights=3,
verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500,
simplified_format=True, filter_string: str = None, device=None,
attributes_to_retrieve: Optional[List[str]] = None):
"""
Args:
config:
index_name:
text:
query: either a string query (which can be a URL or natural language text), or a dict of
<query string>:<weight float> pairs.
result_count:
offset:
return_doc_ids: if True adds doc _id to the docs. Otherwise just returns the docs as-is
Expand Down Expand Up @@ -1049,12 +1055,38 @@ def _vector_text_search(
raise errors.IndexNotFoundError(message="Tried to search a non-existent index: {}".format(index_name))
selected_device = config.indexing_device if device is None else device

# TODO average over vectorized inputs with weights
# query, weight pairs, if query is a dict:
ordered_queries = None

if isinstance(query, str):
to_be_vectorised = [query, ]
else: # is dict:
ordered_queries = list(query.items())
if index_info.index_settings[NsField.index_defaults][NsField.treat_urls_and_pointers_as_images]:
text_queries = [k for k, _ in ordered_queries if _is_image(k)]
image_queries = [k for k, _ in ordered_queries if not _is_image(k)]
to_be_vectorised = [batch for batch in [text_queries, image_queries] if batch]
else:
to_be_vectorised = [[k for k, _ in ordered_queries], ]
try:
vectorised_text = s2_inference.vectorise(
model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=text,
device=selected_device,
normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])[0]
vectorised_text = functools.reduce(lambda x, y: x + y,
[ s2_inference.vectorise(
model_name=index_info.model_name, model_properties=_get_model_properties(index_info),
content=batch, device=selected_device,
normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])
for batch in to_be_vectorised]
)
if ordered_queries:
# multiple queries. We have to weight and combine them:
weighted_vectors = [np.asarray(vec) * weight for vec, weight in zip(vectorised_text, [w for _, w in ordered_queries])]
vectorised_text = np.mean(weighted_vectors, axis=0)
if index_info.index_settings['index_defaults']['normalize_embeddings']:
norm = np.linalg.norm(vectorised_text, axis=-1, keepdims=True)
if norm > 0:
vectorised_text /= np.linalg.norm(vectorised_text, axis=-1, keepdims=True)
vectorised_text = list(vectorised_text)
else:
vectorised_text = vectorised_text[0]
except (s2_inference_errors.UnknownModelError,
s2_inference_errors.InvalidModelPropertiesError,
s2_inference_errors.ModelLoadError) as model_error:
Expand Down Expand Up @@ -1182,9 +1214,10 @@ def _vector_text_search(
"Try reducing the query's limit parameter") from e
elif 'parse_exception' in response["responses"][0]["error"]["root_cause"][0]["reason"]:
raise errors.InvalidArgError("Syntax error, could not parse filter string") from e
elif contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]:
elif (contextualised_filter
and contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]):
raise errors.InvalidArgError("Syntax error, could not parse filter string") from e
raise e
raise errors.BackendCommunicationError(f"Error communicating with Marqo-OS backend:\n{response}")
except (KeyError, IndexError) as e2:
raise e

Expand Down
39 changes: 37 additions & 2 deletions src/marqo/tensor_search/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,52 @@
import typing
from marqo.tensor_search import constants
from marqo.tensor_search import enums, utils
from typing import Iterable, Container
from typing import Iterable, Container, Union
from marqo.errors import (
MarqoError, InvalidFieldNameError, InvalidArgError, InternalError,
InvalidDocumentIdError, DocTooLargeError, InvalidIndexNameError)
from marqo.tensor_search.enums import TensorField
from marqo.tensor_search.enums import TensorField, SearchMethod
from marqo.tensor_search import constants
from typing import Any, Type
import inspect
from enum import Enum


def validate_query(q: Union[dict, str], search_method: Union[str, SearchMethod]):
"""
Returns q if an error is not raised"""
usage_ref = "\nSee query reference here: https://docs.marqo.ai/0.0.13/API-Reference/search/#query-q"
if isinstance(q, dict):
if search_method.upper() != SearchMethod.TENSOR:
raise InvalidArgError(
'Multi-query search is currently only supported for search_method="TENSOR" '
f"\nReceived search_method `{search_method}`. {usage_ref}")
if not len(q):
raise InvalidArgError(
"Multi-query search requires at least one query! Received empty dictionary. "
f"{usage_ref}"
)
for k, v in q.items():
base_invalid_kv_message = "Multi queries dictionaries must be <string>:<float> pairs. "
if not isinstance(k, str):
raise InvalidArgError(
f"{base_invalid_kv_message}Found key of type `{type(k)}` instead of string. Key=`{k}`"
f"{usage_ref}"
)
if not isinstance(v, (int, float)):
raise InvalidArgError(
f"{base_invalid_kv_message}Found value of type `{type(v)}` instead of float. Value=`{v}`"
f" {usage_ref}"
)
elif not isinstance(q, str):
raise InvalidArgError(
f"q must be a string or dict! Received q of type `{type(q)}`. "
f"\nq=`{q}`"
f"{usage_ref}"
)
return q


def validate_str_against_enum(value: Any, enum_class: Type[Enum], case_sensitive: bool = True):
"""Checks whether a value is found as the value of a str attribute of the
given enum_class.
Expand Down
Loading

0 comments on commit d6584bf

Please sign in to comment.