From 10fbf58e20eccbb899f0f4bfe28b9c089d5bd2e5 Mon Sep 17 00:00:00 2001 From: Li Wan <49334982+wanliAlex@users.noreply.github.com> Date: Thu, 16 Mar 2023 11:05:53 +1100 Subject: [PATCH] [features] bring your own vectors (#381) * updated * update mock * update mock * change sum to mean. * change sum to mean. * updated * updated * add query type check logic * add maximum number of vectors limit * add tests * updated * catch mainline * catch mainline * catch mainline * catch mainline * add test for vectors * add test for vectors * add test for vectors --- src/marqo/tensor_search/api.py | 3 +- src/marqo/tensor_search/models/api_models.py | 1 + .../tensor_search/models/context_object.py | 31 +++++ src/marqo/tensor_search/tensor_search.py | 52 ++++++-- src/marqo/tensor_search/validation.py | 18 +++ .../test_custom_vectors_search.py | 111 +++++++++++++++ tests/tensor_search/test_validation.py | 126 ++++++++++++++++++ 7 files changed, 327 insertions(+), 15 deletions(-) create mode 100644 src/marqo/tensor_search/models/context_object.py create mode 100644 tests/tensor_search/test_custom_vectors_search.py diff --git a/src/marqo/tensor_search/api.py b/src/marqo/tensor_search/api.py index f53dc4bde..76daabae4 100644 --- a/src/marqo/tensor_search/api.py +++ b/src/marqo/tensor_search/api.py @@ -129,7 +129,8 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api reranker=search_query.reRanker, filter=search_query.filter, device=device, attributes_to_retrieve=search_query.attributesToRetrieve, boost=search_query.boost, - image_download_headers=search_query.image_download_headers + image_download_headers=search_query.image_download_headers, + context=search_query.context ) diff --git a/src/marqo/tensor_search/models/api_models.py b/src/marqo/tensor_search/models/api_models.py index bd09d9d68..555af4110 100644 --- a/src/marqo/tensor_search/models/api_models.py +++ b/src/marqo/tensor_search/models/api_models.py @@ -22,6 +22,7 @@ class SearchQuery(BaseModel): attributesToRetrieve: Union[None, List[str]] = None boost: Optional[Dict] = None image_download_headers: Optional[Dict] = None + context: Optional[Dict] = None @pydantic.validator('searchMethod') def validate_search_method(cls, value): diff --git a/src/marqo/tensor_search/models/context_object.py b/src/marqo/tensor_search/models/context_object.py new file mode 100644 index 000000000..a15d40838 --- /dev/null +++ b/src/marqo/tensor_search/models/context_object.py @@ -0,0 +1,31 @@ +context_schema = { + "$schema": "http://json-schema.org/draft-04/schema#", + "type": "object", + "properties": { + "tensor": { + "type": "array", + "minItems":1, + "maxItems" : 64, + "items": + { + "type": "object", + "properties": { + "vector": { + "type": "array", + "items": {"type": "number"} + }, + "weight": { + "type": "number" + } + }, + "required": [ + "vector", + "weight" + ] + }, + } + }, + "required": [ + "tensor" + ] +} \ No newline at end of file diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py index 8573eb3a3..55acd8bd4 100644 --- a/src/marqo/tensor_search/tensor_search.py +++ b/src/marqo/tensor_search/tensor_search.py @@ -994,13 +994,13 @@ def bulk_search(query: BulkSearchQuery, marqo_config: config.Config, verbose: bo Args: query: Set of search queries marqo_config: - verbose: - device: + verbose: + device: Notes: Current limitations: - Lexical and tensor search done in serial. - - A single error (e.g. validation errors) on any one of the search queries returns an error and does not + - A single error (e.g. validation errors) on any one of the search queries returns an error and does not process non-erroring queries. """ # TODO: Let non-errored docs to propagate. @@ -1041,7 +1041,7 @@ def bulk_search(query: BulkSearchQuery, marqo_config: config.Config, verbose: bo s["limit"] = q.limit s["offset"] = q.offset - ## TODO: filter out highlights within `_lexical_search` + ## TODO: filter out highlights within `_lexical_search` if not q.showHighlights: for hit in s["hits"]: del hit["_highlights"] @@ -1090,7 +1090,8 @@ def search(config: Config, index_name: str, text: Union[str, dict], reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None, attributes_to_retrieve: Optional[List[str]] = None, device=None, boost: Optional[Dict] = None, - image_download_headers: Optional[Dict] = None) -> Dict: + image_download_headers: Optional[Dict] = None, + context: Optional[Dict] = None) -> Dict: """The root search method. Calls the specific search method Validation should go here. Validations include: @@ -1113,6 +1114,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], num_highlights: number of highlights to return for each doc boost: boosters to re-weight the scores of individual fields image_download_headers: headers for downloading images + context: a dictionary to allow custom vectors in search, for tensor search only Returns: """ @@ -1164,7 +1166,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose, number_of_highlights=num_highlights, simplified_format=simplified_format, filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve, boost=boost, - image_download_headers=image_download_headers + image_download_headers=image_download_headers, context=context ) elif search_method.upper() == SearchMethod.LEXICAL: search_result = _lexical_search( @@ -1582,7 +1584,7 @@ def get_query_vectors_from_jobs( Handles multi-modal queries, by weighting and combining queries into a single vector Args: - - queries: Original search queries. + - queries: Original search queries. - qidx_to_job: VectorisedJobPointer for each query - job_to_vectors: inference output from each VectorisedJob - config: standard Marqo config. @@ -1667,15 +1669,15 @@ def create_empty_query_response(queries: List[BulkSearchQueryEntity]) -> List[Di def _bulk_vector_text_search(config: Config, queries: List[BulkSearchQueryEntity], device=None) -> List[Dict]: - """Resolve a batch of search queries in parallel. + """Resolve a batch of search queries in parallel. Args: - - config: + - config: - queries: A list of independent search queries. Can be across multiple indexes, but are all expected to have `searchMethod = "TENSOR"` Returns: - A list of search query responses (see `_format_ordered_docs_simple` for structure of individual entities). + A list of search query responses (see `_format_ordered_docs_simple` for structure of individual entities). Note: - - Search results are in the same order as `queries`. + - Search results are in the same order as `queries`. """ if len(queries) == 0: return [] @@ -1765,7 +1767,8 @@ def _vector_text_search( verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500, simplified_format=True, filter_string: str = None, device=None, attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None, - image_download_headers: Optional[Dict] = None): + image_download_headers: Optional[Dict] = None, + context: Optional[Dict] = None,): """ Args: config: @@ -1783,12 +1786,15 @@ def _vector_text_search( objects are printed out attributes_to_retrieve: if set, only returns these fields image_download_headers: headers for downloading images + context: a dictionary to allow custom vectors in search Returns: + Note: - uses multisearch, which returns k results in each attribute. Not that much of a concern unless you have a ridiculous number of attributes - Should not be directly called by client - the search() method should be called. The search() method adds syncing + Output format: [ { @@ -1802,6 +1808,17 @@ def _vector_text_search( - searching a non existent index should return a HTTP-type error """ # SEARCH TIMER-LOGGER (pre-processing) + custom_tensors = None + if context is not None: + if isinstance(query, dict): + validation.validate_context_object(context_object=context) + custom_tensors = context.get("tensor", None) + elif isinstance(query, str): + raise errors.InvalidArgError(f"Marqo received a query = `{query}` with type =`{type(query).__name__}` " + f"and a context = `{context}`.\n" + f"This is not supported as the context only works when the query is a dictionary." + f"If you aim to search with your custom vectors, reformat the query as a dictionary.\n" + f"Please check `https://docs.marqo.ai/0.0.16/` for more information.") start_preprocess_time = timer() try: index_info = get_index_info(config=config, index_name=index_name) @@ -1843,8 +1860,15 @@ def _vector_text_search( if q in batch_dict: vec = batch_dict[q] weighted_vectors.append(np.asarray(vec) * weight) - - vectorised_text = np.mean(weighted_vectors, axis=0) + if custom_tensors: + weighted_vectors += [np.asarray(v["vector"]) * v["weight"]for v in custom_tensors] + try: + vectorised_text = np.mean(weighted_vectors, axis=0) + except ValueError as e: + raise errors.InvalidArgError(f"The provided vectors are not in the same dimension of the index." + f"This causes the error when we do `numpy.mean()` over all the vectors.\n" + f"The original error is `{e}`.\n" + f"Please check `https://docs.marqo.ai/0.0.15/API-Reference/search/`.") if index_info.index_settings['index_defaults']['normalize_embeddings']: norm = np.linalg.norm(vectorised_text, axis=-1, keepdims=True) if norm > 0: diff --git a/src/marqo/tensor_search/validation.py b/src/marqo/tensor_search/validation.py index 906ab4a61..3b13e3050 100644 --- a/src/marqo/tensor_search/validation.py +++ b/src/marqo/tensor_search/validation.py @@ -16,6 +16,7 @@ import jsonschema from marqo.tensor_search.models.settings_object import settings_schema from marqo.tensor_search.models.mappings_object import mappings_schema, multimodal_combination_schema +from marqo.tensor_search.models.context_object import context_schema def validate_query(q: Union[dict, str], search_method: Union[str, SearchMethod]): @@ -457,6 +458,23 @@ def validate_multimodal_combination(field_content, is_non_tensor_field, field_ma return True +def validate_context_object(context_object: dict): + """validates the mappings object. + Returns + the given context_object if passed the validation + + Raises an InvalidArgError if the context object is badly formatted + """ + try: + jsonschema.validate(instance=context_object, schema=context_schema) + return context_object + except jsonschema.ValidationError as e: + raise InvalidArgError( + f"Error validating mappings object. Reason: \n{str(e)}" + f"\nRead about the mappings object here: https://docs.marqo.ai/0.0.16" + ) + + def validate_mappings_object(mappings_object: dict): """validates the mappings object. Returns diff --git a/tests/tensor_search/test_custom_vectors_search.py b/tests/tensor_search/test_custom_vectors_search.py new file mode 100644 index 000000000..be882f88a --- /dev/null +++ b/tests/tensor_search/test_custom_vectors_search.py @@ -0,0 +1,111 @@ +import unittest.mock +import pprint + +import torch + +import marqo.tensor_search.backend +from marqo.errors import IndexNotFoundError, InvalidArgError +from marqo.tensor_search import tensor_search +from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod +from tests.marqo_test import MarqoTestCase +from unittest.mock import patch +import numpy as np + + +class TestMultimodalTensorCombination(MarqoTestCase): + + def setUp(self): + self.index_name_1 = "my-test-index-1" + self.endpoint = self.authorized_url + + try: + tensor_search.delete_index(config=self.config, index_name=self.index_name_1) + except IndexNotFoundError as e: + pass + + tensor_search.create_vector_index( + index_name=self.index_name_1, config=self.config, index_settings={ + IndexSettingsField.index_defaults: { + IndexSettingsField.model: "ViT-B/32", + IndexSettingsField.treat_urls_and_pointers_as_images: True, + IndexSettingsField.normalize_embeddings: True + } + }) + tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[ + { + "Title": "Horse rider", + "text_field": "A rider is riding a horse jumping over the barrier.", + "_id": "1" + }], auto_refresh=True) + + def tearDown(self) -> None: + try: + tensor_search.delete_index(config=self.config, index_name=self.index_name_1) + except: + pass + + def test_search(self): + query = { + "A rider is riding a horse jumping over the barrier": 1, + } + res = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query, context= + {"tensor": [{"vector": [1, ] * 512, "weight": 2}, {"vector": [2, ] * 512, "weight": -1}], }) + + def test_search_with_incorrect_tensor_dimension(self): + query = { + "A rider is riding a horse jumping over the barrier": 1, + } + try: + res = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query, context= + {"tensor": [{"vector": [1, ] * 3, "weight": 0}, {"vector": [2, ] * 512, "weight": 0}], }) + raise AssertionError + except InvalidArgError as e: + assert "This causes the error when we do `numpy.mean()` over" in e.message + + def test_search_with_incorrect_query_format(self): + query = "A rider is riding a horse jumping over the barrier" + try: + res = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query, context= + {"tensor": [{"vector": [1, ] * 512, "weight": 0}, {"vector": [2, ] * 512, "weight": 0}], }) + raise AssertionError + except InvalidArgError as e: + assert "This is not supported as the context only works when the query is a dictionary." in e.message + + def test_search_score(self): + query = { + "A rider is riding a horse jumping over the barrier": 1, + } + + res_1 = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query) + res_2 = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query, context= + {"tensor": [{"vector": [1, ] * 512, "weight": 0}, {"vector": [2, ] * 512, "weight": 0}], }) + res_3 = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query, context= + {"tensor": [{"vector": [1, ] * 512, "weight": -1}, {"vector": [1, ] * 512, "weight": 1}], }) + + assert res_1["hits"][0]["_score"] == res_2["hits"][0]["_score"] + assert res_1["hits"][0]["_score"] == res_3["hits"][0]["_score"] + + def test_search_vectors(self): + with patch("numpy.mean", wraps = np.mean) as mock_mean: + query = { + "A rider is riding a horse jumping over the barrier": 1, + } + res_1 = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query) + + weight_1, weight_2, weight_3 = 2.5, 3.4, -1.334 + vector_2 = [-1,] * 512 + vector_3 = [1.3,] * 512 + query = { + "A rider is riding a horse jumping over the barrier": weight_1, + } + + res_2 = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query, context= + {"tensor": [{"vector": vector_2, "weight": weight_2}, {"vector": vector_3, "weight": weight_3}], }) + + args_list = [args[0] for args in mock_mean.call_args_list] + vectorised_string = args_list[0][0][0] + weighted_vectors = args_list[1][0] + + assert np.allclose(vectorised_string * weight_1, weighted_vectors[0], atol=1e-9) + assert np.allclose(np.array(vector_2) * weight_2, weighted_vectors[1], atol=1e-9) + assert np.allclose(np.array(vector_3) * weight_3, weighted_vectors[2], atol=1e-9) \ No newline at end of file diff --git a/tests/tensor_search/test_validation.py b/tests/tensor_search/test_validation.py index f900fd99c..9fa6b1fa4 100644 --- a/tests/tensor_search/test_validation.py +++ b/tests/tensor_search/test_validation.py @@ -9,6 +9,7 @@ InvalidDocumentIdError, InvalidArgError, DocTooLargeError, InvalidIndexNameError ) +import pprint class TestValidation(unittest.TestCase): @@ -831,3 +832,128 @@ def test_validate_multimodal_combination_object_invalid(self): except InvalidArgError as e: pass + def test_validate_valid_context_object(self): + valid_context_list = [ + { + "tensor":[ + {"vector" : [0.2132] * 512, "weight" : 0.32}, + {"vector": [0.2132] * 512, "weight": 0.32}, + {"vector": [0.2132] * 512, "weight": 0.32}, + ] + }, + { + "tensor": [ + {"vector": [0.2132] * 512, "weight": 1}, + {"vector": [0.2132] * 512, "weight": 1}, + {"vector": [0.2132] * 512, "weight": 1}, + ] + }, + + { + # Note we are not validating the vector size here + "tensor": [ + {"vector": [0.2132] * 53, "weight": 1}, + {"vector": [23,], "weight": 1}, + {"vector": [0.2132] * 512, "weight": 1}, + ], + "addition_field": None + }, + { + "tensor": [ + {"vector": [0.2132] * 53, "weight": 1}, + {"vector": [23, ], "weight": 1}, + {"vector": [0.2132] * 512, "weight": 1}, + ], + "addition_field_1": None, + "addition_field_2": "random" + }, + { + "tensor": [ + {"vector": [0.2132] * 512, "weight": 0.32}, + ] * 64 + }, + ] + + for valid_context in valid_context_list: + assert valid_context == validation.validate_context_object(valid_context) + + def test_validate_invalid_context_object(self): + valid_context_list = [ + { + # Typo in tensor + "tensors": [ + {"vector" : [0.2132] * 512, "weight" : 0.32}, + {"vector": [0.2132] * 512, "weight": 0.32}, + {"vector": [0.2132] * 512, "weight": 0.32}, + ] + }, + { + # Typo in vector + "tensor": [ + {"vectors": [0.2132] * 512, "weight": 1}, + {"vector": [0.2132] * 512, "weight": 1}, + {"vector": [0.2132] * 512, "weight": 1}, + ] + }, + { + # Typo in weight + "tensor": [ + {"vector": [0.2132] * 53, "weight": 1}, + {"vector": [23,], "weight": 1}, + {"vector": [0.2132] * 512, "weights": 1}, + ], + "addition_field": None + }, + { + # Int instead of list + "tensor": [ + {"vector": [0.2132] * 53, "weight": 1}, + {"vector": [23, ], "weight": 1}, + {"vector": 3, "weight": 1}, + ], + "addition_field_1": None, + "addition_field_2": "random" + }, + { + # Str instead of list + "tensor": [ + {"vector" : str([0.2132] * 512), "weight": 0.32}, + {"vector": [0.2132] * 512, "weight": 0.32}, + {"vector": [0.2132] * 512, "weight": 0.32}, + ], + "addition_field_1": None, + "addition_field_2": "random" + }, + { + # None instead of list + "tensor": [ + {"vector": [0.2132] * 53, "weight": 1}, + {"vector": [23, ], "weight": 1}, + {"vectors": None, "weight": 1}, + ], + "addition_field_1": None, + "addition_field_2": "random" + }, + { + # too many vectors, maximum 64 + "tensor": [ + {"vector": [0.2132] * 512, "weight": 0.32}, + ] * 65 + }, + { + # None + "tensor": None, + }, + { + # Empty tensor + "tensor": [], + }, + ] + + for invalid_context in valid_context_list: + try: + validation.validate_context_object(invalid_context) + pprint.pprint(invalid_context) + raise AssertionError + except InvalidArgError: + pass \ No newline at end of file