Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Boost tensor fields #300

Merged
merged 15 commits into from
Feb 10, 2023
2 changes: 1 addition & 1 deletion src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def search(search_query: SearchQuery, index_name: str, device: str = Depends(api
result_count=search_query.limit, offset=search_query.offset,
reranker=search_query.reRanker,
filter=search_query.filter, device=device,
attributes_to_retrieve=search_query.attributesToRetrieve
attributes_to_retrieve=search_query.attributesToRetrieve, boost=search_query.boost
)


Expand Down
3 changes: 2 additions & 1 deletion src/marqo/tensor_search/models/api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"""
import pydantic
from pydantic import BaseModel
from typing import Union, List, Dict
from typing import Union, List, Dict, Optional
from marqo.tensor_search.enums import SearchMethod, Device
from marqo.tensor_search import validation
from marqo.tensor_search import configs
Expand All @@ -21,6 +21,7 @@ class SearchQuery(BaseModel):
reRanker: str = None
filter: str = None
attributesToRetrieve: List[str] = None
boost: Optional[Dict] = None

@pydantic.validator('searchMethod')
def validate_search_method(cls, value):
Expand Down
39 changes: 34 additions & 5 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of
searchable_attributes: Iterable[str] = None, verbose: int = 0, num_highlights: int = 3,
reranker: Union[str, Dict] = None, simplified_format: bool = True, filter: str = None,
attributes_to_retrieve: Optional[List[str]] = None,
device=None) -> Dict:
device=None, boost: Optional[Dict] = None) -> Dict:
"""The root search method. Calls the specific search method

Validation should go here. Validations include:
Expand All @@ -790,6 +790,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of
searchable_attributes:
verbose:
num_highlights: number of highlights to return for each doc
boost: boosters to re-weight the scores of individual fields

Returns:

Expand Down Expand Up @@ -837,7 +838,7 @@ def search(config: Config, index_name: str, text: str, result_count: int = 3, of
config=config, index_name=index_name, text=text, result_count=result_count, offset=offset,
return_doc_ids=return_doc_ids, searchable_attributes=searchable_attributes, verbose=verbose,
number_of_highlights=num_highlights, simplified_format=simplified_format,
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve
filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve, boost=boost
)
elif search_method.upper() == SearchMethod.LEXICAL:
search_result = _lexical_search(
Expand Down Expand Up @@ -987,7 +988,7 @@ def _vector_text_search(
searchable_attributes: Iterable[str] = None, number_of_highlights=3,
verbose=0, raise_on_searchable_attribs=False, hide_vectors=True, k=500,
simplified_format=True, filter_string: str = None, device=None,
attributes_to_retrieve: Optional[List[str]] = None):
attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None):
"""
Args:
config:
Expand Down Expand Up @@ -1204,16 +1205,44 @@ def _vector_text_search(
if not gathered_docs[doc_id]["chunks"]:
del gathered_docs[doc_id]

# SORT THE DOCS HERE

def boost_score(docs: dict, boosters: dict) -> dict:
    """Re-weight the chunk scores of the fields named in `boosters`.

    Every chunk whose `_source.__field_name` is a key of `boosters` gets its
    `_score` replaced by `score * weight + bias`.

    Args:
        docs: mapping of doc id to a doc dict holding a "chunks" list. Each
            chunk has a `_score` and a `_source` with a `__field_name`.
        boosters: {'field_to_be_boosted': (weight, bias)}. Weight and bias
            may be ints or floats. The bias is optional: a one-element
            booster such as `(weight,)` only multiplies the score.

    Returns:
        A shallow copy of `docs`. Only the top-level mapping is copied; the
        chunk dicts are shared with `docs`, so their scores are updated in
        place.

    Raises:
        errors.InvalidArgError: if a field in `boosters` is not the field of
            any chunk in `docs`.
    """
    to_be_boosted = docs.copy()
    boosted_fields = set()

    for doc in to_be_boosted.values():
        for chunk in doc["chunks"]:
            field_name = chunk['_source']['__field_name']
            if field_name not in boosters:
                continue
            booster = boosters[field_name]
            boosted = chunk['_score'] * booster[0]
            # The bias is optional; without it the score is only scaled.
            if len(booster) > 1:
                boosted = boosted + booster[1]
            chunk['_score'] = boosted
            boosted_fields.add(field_name)

    # boosted_fields is always a subset of the booster keys, so a non-empty
    # difference means some requested field never appeared in the chunks.
    unboosted_fields = set(boosters) - boosted_fields
    if unboosted_fields:
        raise errors.InvalidArgError(f"Could not boost field(s): {unboosted_fields}. "
                                     f"Please check if the indexed documents contain the specified field(s).")

    return to_be_boosted

# SORT THE DOCS HERE
def sort_chunks(docs: dict) -> dict:
    """Order each doc's chunks from highest to lowest `_score`.

    Returns a shallow copy of `docs`. The doc dicts are shared with the
    input, so their "chunks" entry is rebound to the newly sorted list.
    """
    ranked = docs.copy()
    for doc in ranked.values():
        best_first = sorted(doc["chunks"], key=lambda chunk: chunk["_score"], reverse=True)
        doc["chunks"] = best_first
    return ranked

docs_chunks_sorted = sort_chunks(gathered_docs)
# When boosters were supplied, re-weight the chunk scores first so that the
# chunk ordering is computed from the boosted scores.
if boost is not None:
docs_chunk_boosted = boost_score(gathered_docs, boost)
docs_chunks_sorted = sort_chunks(docs_chunk_boosted)
else:
# No boosters: order the chunks by the scores already in gathered_docs.
docs_chunks_sorted = sort_chunks(gathered_docs)

def sort_docs(docs: dict) -> List[dict]:
as_list = list(docs.values())
Expand Down
95 changes: 95 additions & 0 deletions tests/tensor_search/test_boost_field_scores.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from marqo.errors import IndexNotFoundError, InvalidArgError
from marqo.tensor_search import tensor_search

from tests.marqo_test import MarqoTestCase


class TestBoostFieldScores(MarqoTestCase):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test with more documents (10 or so) with varying fields, with varying other params (like pagination)?


def setUp(self):
    """Recreate the test index from scratch and add two sample documents."""
    self.index_name_1 = "my-test-index-1"
    # Start from a clean slate; a missing index only means there is nothing
    # to remove, so that case is deliberately ignored.
    try:
        tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
    except IndexNotFoundError:
        pass

    # Created outside try/finally so that an unexpected deletion failure
    # surfaces on its own rather than being followed by a create attempt.
    tensor_search.create_vector_index(
        index_name=self.index_name_1, config=self.config)

    tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[
        {
            "Title": "The Travels of Marco Polo",
            "Description": "A 13th-century travelogue describing Polo's travels",
            "_id": "article_590"
        },
        {
            "Title": "Extravehicular Mobility Unit (EMU)",
            "Description": "The EMU is a spacesuit that provides environmental protection, "
                           "mobility, life support, and communications for astronauts",
            "_id": "article_591"
        }
    ], auto_refresh=True)

def tearDown(self) -> None:
    """Delete the test index so it does not outlive the test."""
    try:
        tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
    except IndexNotFoundError:
        # The index is already gone; nothing left to clean up.
        pass

def test_score_is_boosted(self):
    """The top hit's score with a 'Title' booster exceeds the unboosted one."""
    query = "What is the best outfit to wear on the moon?"
    common_args = dict(config=self.config, index_name=self.index_name_1, text=query)

    plain_result = tensor_search.search(**common_args)
    boosted_result = tensor_search.search(boost={'Title': (5, 1)}, **common_args)

    plain_top = plain_result['hits'][0]['_score']
    boosted_top = boosted_result['hits'][0]['_score']

    self.assertGreater(boosted_top, plain_top)

def test_boost_empty_dict(self):
    """An empty boost dict must leave the top hit's score unchanged."""
    query = "What is the best outfit to wear on the moon?"
    common_args = dict(config=self.config, index_name=self.index_name_1, text=query)

    plain_result = tensor_search.search(**common_args)
    boosted_result = tensor_search.search(boost={}, **common_args)

    plain_top = plain_result['hits'][0]['_score']
    boosted_top = boosted_result['hits'][0]['_score']

    self.assertEqual(boosted_top, plain_top)

def test_different_attributes_searched_and_boosted(self):
    """An error should be raised if the user tries to
    boost a field which is not being searched.
    The error message should name the field that could not be boosted.
    """
    q = "What is the best outfit to wear on the moon?"

    with self.assertRaises(InvalidArgError) as ctx:
        tensor_search.search(
            config=self.config, index_name=self.index_name_1, text=q,
            searchable_attributes=['Description'], boost={'Title': (0.5, 1)}
        )

    # Checked after the with-block: code placed after the raising call
    # inside it would never run. assertIn reports both operands on failure.
    self.assertIn('Title', str(ctx.exception))

def test_boost_invalid_fields(self):
    """An error should be raised if the user tries to boost a non-existent field.
    The error message should tell the user which field(s) were unable to be boosted.
    """
    q = "What is the best outfit to wear on the moon?"

    with self.assertRaises(InvalidArgError) as ctx:
        tensor_search.search(
            config=self.config, index_name=self.index_name_1, text=q,
            boost={'Title': (0.2, 1), 'invalid_field_name': (0.5, 1)}
        )

    # Checked after the with-block: code placed after the raising call
    # inside it would never run. assertIn reports both operands on failure.
    self.assertIn('invalid_field_name', str(ctx.exception))