Boost tensor fields #300

Merged: 15 commits, Feb 10, 2023
11 changes: 5 additions & 6 deletions src/marqo/errors.py
@@ -64,14 +64,13 @@ class MarqoWebError(Exception):
message: str = None
code: str = None
link: str = ""

base_message = ("Please create an issue on Marqo's GitHub repo"
" (https://github.com/marqo-ai/marqo/issues) "
"if this problem persists.")
def __init__(self, message: str, status_code: int = None,
error_type: str = None, code: str = None,
link: str = None) -> None:
base_message = ("Please create an issue on Marqo's GitHub repo"
" (https://github.com/marqo-ai/marqo/issues) "
"if this problem persists.")
self.message = f"{message}\n{base_message}"
self.message = f"{message}\n{self.base_message}"

if self.status_code is None:
self.status_code = status_code
@@ -87,7 +86,7 @@ def __init__(self, message: str, status_code: int = None,
super().__init__(self.message)

def __str__(self) -> str:
return f'{self.__class__.__name__} Message: {self.message}'
return f'{self.__class__.__name__}: {self.message}\n{self.base_message}'

# ---MARQO USER ERRORS---

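The hunk above promotes base_message from a local variable inside __init__ to a class attribute, so both __init__ and __str__ reach the same text through self.base_message. A minimal, self-contained sketch of that pattern (simplified names, not Marqo's actual class hierarchy):

class WebError(Exception):
    # Shared by every instance; a subclass could override it.
    base_message = ("Please create an issue on the project's GitHub repo "
                    "if this problem persists.")

    def __init__(self, message: str) -> None:
        # Class attributes resolve through the instance, so self.base_message works here.
        self.message = f"{message}\n{self.base_message}"
        super().__init__(self.message)

    def __str__(self) -> str:
        return f"{self.__class__.__name__}: {self.message}\n{self.base_message}"

print(WebError("Index not found"))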
2 changes: 1 addition & 1 deletion src/marqo/tensor_search/api.py
@@ -119,7 +119,7 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
result_count=search_query.limit, offset=search_query.offset,
reranker=search_query.reRanker,
filter=search_query.filter, device=device,
attributes_to_retrieve=search_query.attributesToRetrieve
attributes_to_retrieve=search_query.attributesToRetrieve, boost=search_query.boost
)


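With boost now forwarded from the API layer into tensor_search.search, a search request can carry per-field boosters in its JSON body. A hedged illustration using requests; the base URL, port, and the q field name are assumptions taken from Marqo's search API docs rather than from this diff:

import requests

MARQO_URL = "http://localhost:8882"   # hypothetical local Marqo instance
index_name = "my-index"               # hypothetical index

body = {
    "q": "smartphones with good cameras",
    "searchMethod": "TENSOR",                             # boosting is tensor-search only
    "attributesToRetrieve": ["Title", "Description"],
    # boost maps a field name to [weight] or [weight, bias], applied to that field's chunk scores
    "boost": {"Title": [1.5, 2.0], "Description": [0.8]},
}

response = requests.post(f"{MARQO_URL}/indexes/{index_name}/search", json=body)
print(response.json())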
3 changes: 2 additions & 1 deletion src/marqo/tensor_search/models/api_models.py
@@ -5,7 +5,7 @@
"""
import pydantic
from pydantic import BaseModel
from typing import Union, List, Dict
from typing import Union, List, Dict, Optional
from marqo.tensor_search.enums import SearchMethod, Device
from marqo.tensor_search import validation
from marqo.tensor_search import configs
@@ -21,6 +21,7 @@ class SearchQuery(BaseModel):
reRanker: str = None
filter: str = None
attributesToRetrieve: List[str] = None
boost: Optional[Dict] = None

@pydantic.validator('searchMethod')
def validate_search_method(cls, value):
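Because boost is Optional with a None default, clients that never send the field are unaffected; deeper validation happens later in validate_boost. A small self-contained sketch of how such an optional pydantic field behaves (an illustrative model, not Marqo's full SearchQuery):

from typing import Dict, Optional
from pydantic import BaseModel

class MiniSearchQuery(BaseModel):
    q: str
    searchMethod: str = "TENSOR"
    boost: Optional[Dict] = None   # omitted by the client -> stays None

print(MiniSearchQuery(q="red shoes").boost)                            # None
print(MiniSearchQuery(q="red shoes", boost={"Title": [2, 1]}).boost)   # {'Title': [2, 1]}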
48 changes: 41 additions & 7 deletions src/marqo/tensor_search/tensor_search.py
@@ -772,7 +772,7 @@ def search(config: Config, index_name: str, text: Union[str, dict],
searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3,
reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None,
attributes_to_retrieve: Optional[List[str]] = None,
device=None) -> Dict:
device=None, boost: Optional[Dict] = None) -> Dict:
"""The root search method. Calls the specific search method

Validation should go here. Validations include:
@@ -793,6 +793,7 @@ def search(config: Config, index_name: str, text: Union[str, dict],
searchable_attributes:
verbose:
num_highlights: number of highlights to return for each doc
boost: boosters to re-weight the scores of individual fields

Returns:

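The new boost argument is a plain dict mapping a tensor field name to a weight and an optional bias; each matching chunk's score is later re-weighted as score * weight + bias (see boost_score further down). A short worked example of the two accepted shapes:

# Example boost argument for search(): field name -> [weight] or [weight, bias].
boost = {
    "Title": [1.5, 2.0],    # new_score = old_score * 1.5 + 2.0
    "Description": [0.8],   # new_score = old_score * 0.8 (no bias term)
}

# A "Title" chunk that originally scored 0.5:
weight, bias = boost["Title"]
print(0.5 * weight + bias)   # 2.75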
@@ -816,10 +817,9 @@

raise errors.IllegalRequestedDocCount(f"{upper_bound_explanation} Marqo received search result limit of `{result_count}` "
f"and offset of `{offset}`.")


t0 = timer()

validation.validate_boost(boost=boost, search_method=search_method)
if searchable_attributes is not None:
[validation.validate_field_name(attribute) for attribute in searchable_attributes]
if attributes_to_retrieve is not None:
@@ -843,7 +843,7 @@ def search(config: Config, index_name: str, text: Union[str, dict],
config=config, index_name=index_name, query=text, result_count=result_count, offset=offset,
return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose,
number_of_highlights=num_highlights, simplified_format=simplified_format,
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve, boost=boost
)
elif search_method.upper() == SearchMethod.LEXICAL:
search_result = _lexical_search(
@@ -1010,7 +1010,7 @@ def _vector_text_search(
return_doc_ids=False, searchable_attributes: Iterable[str] = None, number_of_highlights=3,
verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500,
simplified_format=True, filter_string: str = None, device=None,
attributes_to_retrieve: Optional[List[str]] = None):
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None):
"""
Args:
config:
@@ -1149,6 +1149,9 @@ def _vector_text_search(
},
"score_mode": "max"
}
},
"_source": {
"exclude": ["__chunks.__vector_*"]
}
}

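The added _source block tells OpenSearch to drop the stored chunk vectors from each hit's _source, so the vectors still drive scoring but never travel back in the response. A hedged before/after sketch of a single hit's _source; the exact internal key names (other than __field_name, which appears in the code below) are illustrative only:

# Without the exclusion, every hit would ship its raw vectors back to Marqo:
hit_source_before = {
    "__chunks": [{
        "__field_name": "Title",
        "__vector_Title": [0.013, -0.421, 0.094],   # truncated; real vectors are hundreds of floats
        "Title": "Red running shoes",
    }]
}

# With "_source": {"exclude": ["__chunks.__vector_*"]}, the same hit comes back vector-free:
hit_source_after = {
    "__chunks": [{
        "__field_name": "Title",
        "Title": "Red running shoes",
    }]
}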
@@ -1249,16 +1252,47 @@ def _vector_text_search(
if not gathered_docs[doc_id]["chunks"]:
del gathered_docs[doc_id]

# SORT THE DOCS HERE
def boost_score(docs: dict, boosters: dict) -> dict:
""" re-weighs the scores of individual fields
Args:
docs:
boosters: {'field_to_be_boosted': (int, int)}
"""
to_be_boosted = docs.copy()
boosted_fields = set()
if searchable_attributes and boosters:
if not set(boosters).issubset(set(searchable_attributes)):
raise errors.InvalidArgError(
"Boost fieldnames must be a subset of searchable attributes. "
f"\nSearchable attributes: {searchable_attributes}"
f"\nBoost: {boosters}"
)

for doc_id in list(to_be_boosted.keys()):
for chunk in to_be_boosted[doc_id]["chunks"]:
field_name = chunk['_source']['__field_name']
if field_name in boosters.keys():
booster = boosters[field_name]
if len(booster) == 2:
chunk['_score'] = chunk['_score'] * booster[0] + booster[1]
else:
chunk['_score'] = chunk['_score'] * booster[0]
boosted_fields.add(field_name)
return to_be_boosted

# SORT THE DOCS HERE
def sort_chunks(docs: dict) -> dict:
to_be_sorted = docs.copy()
for doc_id in list(to_be_sorted.keys()):
to_be_sorted[doc_id]["chunks"] = sorted(
to_be_sorted[doc_id]["chunks"], key=lambda x: x["_score"], reverse=True)
return to_be_sorted

docs_chunks_sorted = sort_chunks(gathered_docs)
if boost is not None:
docs_chunk_boosted = boost_score(gathered_docs, boost)
docs_chunks_sorted = sort_chunks(docs_chunk_boosted)
else:
docs_chunks_sorted = sort_chunks(gathered_docs)

def sort_docs(docs: dict) -> List[dict]:
as_list = list(docs.values())
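A self-contained walk-through of what boost_score followed by sort_chunks does, using a toy gathered_docs with the same shape (doc id -> {"chunks": [...]}); the field names and scores are invented:

# Toy gathered docs: each chunk carries a score and the field it came from.
gathered_docs = {
    "doc1": {"chunks": [
        {"_score": 0.5,  "_source": {"__field_name": "Title"}},
        {"_score": 0.75, "_source": {"__field_name": "Description"}},
    ]},
}
boost = {"Title": [1.5, 2.0], "Description": [0.5]}

# Same rule as boost_score above: score * weight, plus bias when one is given.
for doc in gathered_docs.values():
    for chunk in doc["chunks"]:
        booster = boost.get(chunk["_source"]["__field_name"])
        if booster:
            chunk["_score"] = chunk["_score"] * booster[0] + (booster[1] if len(booster) == 2 else 0)
    # Then re-sort the chunks by the boosted score, highest first (sort_chunks).
    doc["chunks"].sort(key=lambda c: c["_score"], reverse=True)

print(gathered_docs["doc1"]["chunks"])
# The Title chunk now leads: 0.5 * 1.5 + 2.0 = 2.75 beats 0.75 * 0.5 = 0.375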
52 changes: 51 additions & 1 deletion src/marqo/tensor_search/validation.py
@@ -9,7 +9,7 @@
InvalidDocumentIdError, DocTooLargeError, InvalidIndexNameError)
from marqo.tensor_search.enums import TensorField, SearchMethod
from marqo.tensor_search import constants
from typing import Any, Type
from typing import Any, Type, Sequence
import inspect
from enum import Enum

@@ -87,6 +87,56 @@ def validate_field_content(field_content: typing.Any) -> typing.Any:
)


def validate_boost(boost: dict, search_method: typing.Union[str, SearchMethod]):
if boost is not None:
further_info_message = ("\nRead about boost usage here: "
"https://docs.marqo.ai/0.0.13/API-Reference/search/#boost")
for boost_attr in boost:
try:
validate_field_name(boost_attr)
except InvalidFieldNameError as e:
raise InvalidFieldNameError(f"Invalid boost dictionary. {e.message} {further_info_message}")
if search_method != SearchMethod.TENSOR:
# to be removed if boosting is implemented for lexical
raise InvalidArgError(
f'Boosting is only supported for search_method="TENSOR". '
f'Received search_method={search_method}'
f'{further_info_message}'
)
if not isinstance(boost, dict):
raise InvalidArgError(
f'Boost must be a dictionary. Instead received boost of value `{boost}`'
f'{further_info_message}'
)
for k, v in boost.items():
base_invalid_kv_message = (
"Boost dictionaries have structure <attribute (string)>: <[weight (float), bias (float)]>\n")
if not isinstance(k, str):
raise InvalidArgError(
f'{base_invalid_kv_message}Found key of type `{type(k)}` instead of string. Key=`{k}`'
f"{further_info_message}"
)
if not isinstance(v, Sequence):
raise InvalidArgError(
f'{base_invalid_kv_message}Found value of type `{type(v)}` instead of Array. Value=`{v}`'
f"{further_info_message}"
)
if len(v) not in [1, 2]:
raise InvalidArgError(
f'{base_invalid_kv_message}An attribute boost must have a weight float and optional bias float. '
f'Instead received invalid boost `{v}`'
f"{further_info_message}"
)
for wb in v:
if not isinstance(wb, (int, float)):
raise InvalidArgError(
f'{base_invalid_kv_message}An attribute boost must have a weight float and optional bias float. '
f'Instead received boost `{v}` with invalid member `{wb}` of type {type(wb)} '
f"{further_info_message}"
)
return boost


def validate_field_name(field_name) -> str:
"""TODO:
- length (remember the vector name will have the vector_prefix added to the front of field_name)
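To summarise the checks above: boosting is only accepted with search_method="TENSOR", the boost must be a dict whose keys pass validate_field_name and whose values are sequences of one or two numbers, and violations raise InvalidArgError (or InvalidFieldNameError for a bad key). Some examples written against those visible rules:

# Accepted boost dictionaries (with search_method="TENSOR"):
valid_boosts = [
    {"Title": [2, 1]},                        # weight and bias
    {"Title": [1.5]},                         # weight only
    {"Title": [2, 1], "Description": [0.5]},  # several fields
]

# Rejected boost dictionaries:
invalid_boosts = [
    {"Title": 2},          # value is not a sequence           -> InvalidArgError
    {"Title": [1, 2, 3]},  # more than a weight and a bias     -> InvalidArgError
    {"Title": [2, "1"]},   # weight/bias must be int or float  -> InvalidArgError
]
# Any boost sent with search_method="LEXICAL" is also rejected with InvalidArgError.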