Skip to content

Commit

Permalink
Boost tensor fields (#300)
Browse files Browse the repository at this point in the history
* added field score boosters

* add tests

* added error handling

* added tests

* Added validation to boost

* added exclude vectors to search

* add a test to test boost equation.

* add a test to test boost equation.

* add a test to test boost equation.

* add test to different scores

* add test to different scores

* delete print

* added extra image search test

---------

Co-authored-by: aryanagarwal9 <59826369+aryanagarwal9@users.noreply.github.com>
Co-authored-by: Li Wan <lwan3@student.unimelb.edu.au>
  • Loading branch information
3 people authored Feb 10, 2023
1 parent 908db08 commit 0a0cea4
Show file tree
Hide file tree
Showing 8 changed files with 534 additions and 16 deletions.
11 changes: 5 additions & 6 deletions src/marqo/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ class MarqoWebError(Exception):
message: str = None
code: str = None
link: str = ""

base_message = ("Please create an issue on Marqo's GitHub repo"
" (https://github.com/marqo-ai/marqo/issues) "
"if this problem persists.")
def __init__(self, message: str, status_code: int = None,
error_type: str = None, code: str = None,
link: str = None) -> None:
base_message = ("Please create an issue on Marqo's GitHub repo"
" (https://github.com/marqo-ai/marqo/issues) "
"if this problem persists.")
self.message = f"{message}\n{base_message}"
self.message = f"{message}\n{self.base_message}"

if self.status_code is None:
self.status_code = status_code
Expand All @@ -87,7 +86,7 @@ def __init__(self, message: str, status_code: int = None,
super().__init__(self.message)

def __str__(self) -> str:
return f'{self.__class__.__name__} Message: {self.message}'
return f'{self.__class__.__name__}: {self.message}\n{self.base_message}'

# ---MARQO USER ERRORS---

Expand Down
2 changes: 1 addition & 1 deletion src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
result_count=search_query.limit, offset=search_query.offset,
reranker=search_query.reRanker,
filter=search_query.filter, device=device,
attributes_to_retrieve=search_query.attributesToRetrieve
attributes_to_retrieve=search_query.attributesToRetrieve, boost=search_query.boost
)


Expand Down
3 changes: 2 additions & 1 deletion src/marqo/tensor_search/models/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import pydantic
from pydantic import BaseModel
from typing import Union, List, Dict
from typing import Union, List, Dict, Optional
from marqo.tensor_search.enums import SearchMethod, Device
from marqo.tensor_search import validation
from marqo.tensor_search import configs
Expand All @@ -21,6 +21,7 @@ class SearchQuery(BaseModel):
reRanker: str = None
filter: str = None
attributesToRetrieve: List[str] = None
boost: Optional[Dict] = None

@pydantic.validator('searchMethod')
def validate_search_method(cls, value):
Expand Down
48 changes: 41 additions & 7 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ def search(config: Config, index_name: str, text: Union[str, dict],
searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3,
reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None,
attributes_to_retrieve: Optional[List[str]] = None,
device=None) -> Dict:
device=None, boost: Optional[Dict] = None) -> Dict:
"""The root search method. Calls the specific search method
Validation should go here. Validations include:
Expand All @@ -815,6 +815,7 @@ def search(config: Config, index_name: str, text: Union[str, dict],
searchable_attributes:
verbose:
num_highlights: number of highlights to return for each doc
boost: boosters to re-weight the scores of individual fields
Returns:
Expand All @@ -838,10 +839,9 @@ def search(config: Config, index_name: str, text: Union[str, dict],

raise errors.IllegalRequestedDocCount(f"{upper_bound_explanation} Marqo received search result limit of `{result_count}` "
f"and offset of `{offset}`.")


t0 = timer()

validation.validate_boost(boost=boost, search_method=search_method)
if searchable_attributes is not None:
[validation.validate_field_name(attribute) for attribute in searchable_attributes]
if attributes_to_retrieve is not None:
Expand All @@ -865,7 +865,7 @@ def search(config: Config, index_name: str, text: Union[str, dict],
config=config, index_name=index_name, query=text, result_count=result_count, offset=offset,
return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose,
number_of_highlights=num_highlights, simplified_format=simplified_format,
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve, boost=boost
)
elif search_method.upper() == SearchMethod.LEXICAL:
search_result = _lexical_search(
Expand Down Expand Up @@ -1032,7 +1032,7 @@ def _vector_text_search(
return_doc_ids=False, searchable_attributes: Iterable[str] = None, number_of_highlights=3,
verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500,
simplified_format=True, filter_string: str = None, device=None,
attributes_to_retrieve: Optional[List[str]] = None):
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None):
"""
Args:
config:
Expand Down Expand Up @@ -1171,6 +1171,9 @@ def _vector_text_search(
},
"score_mode": "max"
}
},
"_source": {
"exclude": ["__chunks.__vector_*"]
}
}

Expand Down Expand Up @@ -1271,16 +1274,47 @@ def _vector_text_search(
if not gathered_docs[doc_id]["chunks"]:
del gathered_docs[doc_id]

# SORT THE DOCS HERE
def boost_score(docs: dict, boosters: dict) -> dict:
    """Re-weigh the chunk scores of boosted fields.

    For every chunk whose ``chunk['_source']['__field_name']`` has an entry
    in ``boosters``, the chunk's ``_score`` is replaced by
    ``_score * weight + bias`` when the booster has two members, or by
    ``_score * weight`` otherwise.

    Args:
        docs: mapping of doc id to a dict holding a "chunks" list. Each
            chunk must provide ``_score`` and ``_source.__field_name``.
        boosters: {'field_to_be_boosted': [weight] or [weight, bias]}

    Returns:
        A shallow copy of ``docs``. The chunk dicts themselves are updated
        in place, so the new scores are also visible through ``docs``.

    Raises:
        errors.InvalidArgError: if ``searchable_attributes`` (read from the
            enclosing scope) is set and ``boosters`` names a field that is
            not one of them.
    """
    to_be_boosted = docs.copy()
    if searchable_attributes and boosters:
        if not set(boosters).issubset(set(searchable_attributes)):
            raise errors.InvalidArgError(
                "Boost fieldnames must be a subset of searchable attributes. "
                f"\nSearchable attributes: {searchable_attributes}"
                f"\nBoost: {boosters}"
            )

    for doc in to_be_boosted.values():
        for chunk in doc["chunks"]:
            field_name = chunk['_source']['__field_name']
            if field_name not in boosters:
                continue
            booster = boosters[field_name]
            if len(booster) == 2:
                # [weight, bias]
                chunk['_score'] = chunk['_score'] * booster[0] + booster[1]
            else:
                # [weight] only
                chunk['_score'] = chunk['_score'] * booster[0]
    return to_be_boosted

# SORT THE DOCS HERE
def sort_chunks(docs: dict) -> dict:
    """Order each document's chunks from highest to lowest ``_score``.

    Args:
        docs: mapping of doc id to a dict holding a "chunks" list, where
            every chunk has a ``_score`` entry.

    Returns:
        A shallow copy of ``docs``. Each document dict is shared with the
        input and has its "chunks" entry rebound to a new sorted list.
    """
    by_doc_id = docs.copy()
    for doc in by_doc_id.values():
        # Sort a fresh list so the original chunk list object is untouched;
        # the sort is stable, so equal scores keep their relative order.
        ranked = list(doc["chunks"])
        ranked.sort(key=lambda chunk: chunk["_score"], reverse=True)
        doc["chunks"] = ranked
    return by_doc_id

docs_chunks_sorted = sort_chunks(gathered_docs)
if boost is not None:
docs_chunk_boosted = boost_score(gathered_docs, boost)
docs_chunks_sorted = sort_chunks(docs_chunk_boosted)
else:
docs_chunks_sorted = sort_chunks(gathered_docs)

def sort_docs(docs: dict) -> List[dict]:
as_list = list(docs.values())
Expand Down
52 changes: 51 additions & 1 deletion src/marqo/tensor_search/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
InvalidDocumentIdError, DocTooLargeError, InvalidIndexNameError)
from marqo.tensor_search.enums import TensorField, SearchMethod
from marqo.tensor_search import constants
from typing import Any, Type
from typing import Any, Type, Sequence
import inspect
from enum import Enum

Expand Down Expand Up @@ -87,6 +87,56 @@ def validate_field_content(field_content: typing.Any) -> typing.Any:
)


def validate_boost(boost: dict, search_method: typing.Union[str, SearchMethod]):
    """Validate the ``boost`` argument of a search request.

    A valid boost is ``None`` or a dictionary mapping field names (strings)
    to a sequence of one or two numbers: ``[weight]`` or ``[weight, bias]``.

    Args:
        boost: the boost dictionary to validate, or None for no boosting.
        search_method: the search method of the request; boosting is only
            accepted when it equals ``SearchMethod.TENSOR``.

    Returns:
        ``boost`` unchanged, if it is valid (including ``None``).

    Raises:
        InvalidArgError: if ``boost`` is not a dictionary, the search method
            is not TENSOR, a key is not a string, or a value is not a
            sequence of one or two int/float members.
        InvalidFieldNameError: if a key is rejected by ``validate_field_name``.
    """
    if boost is None:
        return boost

    further_info_message = ("\nRead about boost usage here: "
                            "https://docs.marqo.ai/0.0.13/API-Reference/search/#boost")

    # Check the container type before touching its contents: iterating a
    # non-dict would either raise a raw TypeError (e.g. an int) or validate
    # unrelated items (e.g. the characters of a string) as field names.
    if not isinstance(boost, dict):
        raise InvalidArgError(
            f'Boost must be a dictionary. Instead received boost of value `{boost}`'
            f'{further_info_message}'
        )

    for boost_attr in boost:
        try:
            validate_field_name(boost_attr)
        except InvalidFieldNameError as e:
            raise InvalidFieldNameError(
                f"Invalid boost dictionary. {e.message} {further_info_message}"
            ) from e

    if search_method != SearchMethod.TENSOR:
        # to be removed if boosting is implemented for lexical
        raise InvalidArgError(
            f'Boosting is only supported for search_method="TENSOR". '
            f'Received search_method={search_method}'
            f'{further_info_message}'
        )

    base_invalid_kv_message = (
        "Boost dictionaries have structure <attribute (string)>: <[weight (float), bias (float)]>\n")
    for k, v in boost.items():
        if not isinstance(k, str):
            raise InvalidArgError(
                f'{base_invalid_kv_message}Found key of type `{type(k)}` instead of string. Key=`{k}`'
                f"{further_info_message}"
            )
        # Strings are Sequences too; they are rejected further down by the
        # length check (empty string) or the per-member numeric type check.
        if not isinstance(v, Sequence):
            raise InvalidArgError(
                f'{base_invalid_kv_message}Found value of type `{type(v)}` instead of Array. Value=`{v}`'
                f"{further_info_message}"
            )
        if len(v) not in [1, 2]:
            raise InvalidArgError(
                f'{base_invalid_kv_message}An attribute boost must have a weight float and optional bias float. '
                f'Instead received invalid boost `{v}`'
                f"{further_info_message}"
            )
        for wb in v:
            if not isinstance(wb, (int, float)):
                raise InvalidArgError(
                    f'{base_invalid_kv_message}An attribute boost must have a weight float and optional bias float. '
                    f'Instead received boost `{v}` with invalid member `{wb}` of type {type(wb)} '
                    f"{further_info_message}"
                )
    return boost


def validate_field_name(field_name) -> str:
"""TODO:
- length (remember the vector name will have the vector_prefix added to the front of field_name)
Expand Down
Loading

0 comments on commit 0a0cea4

Please sign in to comment.