Skip to content

Commit

Permalink
Add hybrid search (#235)
Browse files Browse the repository at this point in the history
  • Loading branch information
vicilliar authored Jul 10, 2024
1 parent 5ee1ec3 commit d45985c
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
context: Optional[dict] = None, score_modifiers: Optional[dict] = None,
model_auth: Optional[dict] = None,
ef_search: Optional[int] = None, approximate: Optional[bool] = None,
text_query_prefix: Optional[str] = None,
text_query_prefix: Optional[str] = None, hybrid_parameters: Optional[dict] = None
) -> Dict[str, Any]:
"""Search the index.
Expand Down Expand Up @@ -273,6 +273,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
"reRanker": reranker,
"boost": boost,
"textQueryPrefix": text_query_prefix,
"hybridParameters": hybrid_parameters
}

body = {k: v for k, v in body.items() if v is not None}
Expand Down
16 changes: 16 additions & 0 deletions src/marqo/models/search_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Dict, List, Optional, Union
from marqo.models.marqo_models import StrictBaseModel
from abc import ABC
from enum import Enum

from pydantic import validator, BaseModel, root_validator


class SearchBody(StrictBaseModel):
Expand All @@ -26,3 +30,15 @@ class BulkSearchBody(SearchBody):
class BulkSearchQuery(StrictBaseModel):
    """Request model that groups several bulk search bodies into one query."""

    # Each entry is one BulkSearchBody.
    queries: List[BulkSearchBody]


class RetrievalMethod(str, Enum):
    """String-valued names of the retrieval methods.

    Because the enum also derives from ``str``, each member compares equal to
    its plain string value (for example ``RetrievalMethod.Tensor == "tensor"``).
    """

    Disjunction = "disjunction"
    Tensor = "tensor"
    Lexical = "lexical"


class RankingMethod(str, Enum):
    """String-valued names of the ranking methods.

    Because the enum also derives from ``str``, each member compares equal to
    its plain string value (for example ``RankingMethod.RRF == "rrf"``).
    """

    RRF = "rrf"
    NormalizeLinear = "normalize_linear"
    Tensor = "tensor"
    Lexical = "lexical"
199 changes: 199 additions & 0 deletions tests/v2_tests/test_hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import copy
import marqo
from marqo import enums
from unittest import mock
import requests
import random
import math
import time
from tests.marqo_test import MarqoTestCase, CloudTestIndex
from marqo.errors import MarqoWebError
from pytest import mark


@mark.fixed
class TestHybridSearch(MarqoTestCase):
    """Client-level tests for searches issued with ``search_method="HYBRID"``.

    Each test adds ``self.docs_list`` to a structured text index and then checks
    the hits returned for different ``retrievalMethod`` / ``rankingMethod``
    combinations passed through ``hybrid_parameters``.
    """

    @staticmethod
    def strip_marqo_fields(doc, strip_id=True):
        """Strips Marqo fields from a returned doc to get the original doc.

        Args:
            doc: A single hit (dict) as returned by a search call.
            strip_id: When True (the default), ``_id`` is removed as well.

        Returns:
            A deep copy of ``doc`` without ``_highlights``, ``_score`` and,
            optionally, ``_id``. The input dict is left untouched.
        """
        copied = copy.deepcopy(doc)

        strip_fields = ["_highlights", "_score"]
        if strip_id:
            strip_fields.append("_id")

        for to_strip in strip_fields:
            # pop() with a default so that a hit lacking one of these fields
            # does not abort the test with a KeyError.
            copied.pop(to_strip, None)

        return copied

    def setUp(self):
        """Run the base-class setup and build the documents shared by all tests."""
        super().setUp()
        self.docs_list = [
            # TODO: add score modifiers
            # similar semantics to dogs
            {"_id": "doc1", "text_field_1": "dogs"},
            {"_id": "doc2", "text_field_1": "puppies"},
            {"_id": "doc3", "text_field_1": "canines"},
            {"_id": "doc4", "text_field_1": "huskies"},
            {"_id": "doc5", "text_field_1": "four-legged animals"},

            # shares lexical token with dogs
            {"_id": "doc6", "text_field_1": "hot dogs"},
            {"_id": "doc7", "text_field_1": "dogs is a word"},
            {"_id": "doc8", "text_field_1": "something something dogs"},
            {"_id": "doc9", "text_field_1": "dogs random words"},
            {"_id": "doc10", "text_field_1": "dogs dogs dogs"},

            {"_id": "doc11", "text_field_2": "dogs but wrong field"},
            {"_id": "doc12", "text_field_2": "puppies puppies"},
            {"_id": "doc13", "text_field_2": "canines canines"},
        ]

    def _index_docs(self, cloud_test_index_to_use, open_source_test_index_name):
        """Resolve the test index name, add ``self.docs_list`` to it and return the name."""
        test_index_name = self.get_test_index_name(
            cloud_test_index_to_use=cloud_test_index_to_use,
            open_source_test_index_name=open_source_test_index_name
        )
        self.client.index(test_index_name).add_documents(self.docs_list)
        return test_index_name

    def test_hybrid_search_searchable_attributes(self):
        """
        Tests that searchable attributes work as expected for all methods
        """

        index_test_cases = [
            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
        ]
        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
            test_index_name = self._index_docs(cloud_test_index_to_use, open_source_test_index_name)

            with self.subTest("retrieval: disjunction, ranking: rrf"):
                hybrid_res = self.client.index(test_index_name).search(
                    "puppies",
                    search_method="HYBRID",
                    hybrid_parameters={
                        "retrievalMethod": "disjunction",
                        "rankingMethod": "rrf",
                        "alpha": 0.5,
                        "searchableAttributesLexical": ["text_field_2"],
                        "searchableAttributesTensor": ["text_field_2"]
                    },
                    limit=10
                )
                hits = hybrid_res["hits"]
                # Only 3 documents have text_field_2 at all
                self.assertEqual(len(hits), 3)
                # doc12 ("puppies puppies" in text_field_2) is expected first
                self.assertEqual([hit["_id"] for hit in hits], ["doc12", "doc13", "doc11"])

            with self.subTest("retrieval: lexical, ranking: tensor"):
                hybrid_res = self.client.index(test_index_name).search(
                    "puppies",
                    search_method="HYBRID",
                    hybrid_parameters={
                        "retrievalMethod": "lexical",
                        "rankingMethod": "tensor",
                        "searchableAttributesLexical": ["text_field_2"]
                    },
                    limit=10
                )
                hits = hybrid_res["hits"]
                # Only 1 document has puppies in text_field_2. Lexical retrieval will only get this one.
                self.assertEqual(len(hits), 1)
                self.assertEqual(hits[0]["_id"], "doc12")

            with self.subTest("retrieval: tensor, ranking: lexical"):
                hybrid_res = self.client.index(test_index_name).search(
                    "puppies",
                    search_method="HYBRID",
                    hybrid_parameters={
                        "retrievalMethod": "tensor",
                        "rankingMethod": "lexical",
                        "searchableAttributesTensor": ["text_field_2"]
                    },
                    limit=10
                )
                hits = hybrid_res["hits"]
                # Only 3 documents have text field 2. Tensor retrieval will get them all.
                self.assertEqual(len(hits), 3)
                self.assertEqual([hit["_id"] for hit in hits], ["doc12", "doc11", "doc13"])

    def test_hybrid_search_same_retrieval_and_ranking_matches_original_method(self):
        """
        Tests that hybrid search with:
        retrievalMethod = "lexical", rankingMethod = "lexical" and
        retrievalMethod = "tensor", rankingMethod = "tensor"
        Results must be the same as lexical search and tensor search respectively.
        """

        index_test_cases = [
            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
        ]
        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
            test_index_name = self._index_docs(cloud_test_index_to_use, open_source_test_index_name)

            test_cases = [
                ("lexical", "lexical"),
                ("tensor", "tensor")
            ]

            for retrieval_method, ranking_method in test_cases:
                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
                    hybrid_res = self.client.index(test_index_name).search(
                        "dogs",
                        search_method="HYBRID",
                        hybrid_parameters={
                            "retrievalMethod": retrieval_method,
                            "rankingMethod": ranking_method
                        },
                        limit=10
                    )

                    base_res = self.client.index(test_index_name).search(
                        "dogs",
                        search_method=retrieval_method,  # will be either lexical or tensor
                        limit=10
                    )

                    hybrid_ids = [hit["_id"] for hit in hybrid_res["hits"]]
                    base_ids = [hit["_id"] for hit in base_res["hits"]]

                    # Same number of hits, in the same order.
                    self.assertEqual(len(hybrid_ids), len(base_ids))
                    self.assertEqual(hybrid_ids, base_ids)

    def test_hybrid_search_with_filter(self):
        """
        Tests that filter is applied correctly in hybrid search.
        """

        index_test_cases = [
            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
        ]
        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
            test_index_name = self._index_docs(cloud_test_index_to_use, open_source_test_index_name)

            test_cases = [
                ("disjunction", "rrf"),
                ("lexical", "lexical"),
                ("tensor", "tensor")
            ]

            for retrieval_method, ranking_method in test_cases:
                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
                    hybrid_res = self.client.index(test_index_name).search(
                        "dogs",
                        search_method="HYBRID",
                        filter_string="text_field_1:(something something dogs)",
                        hybrid_parameters={
                            "retrievalMethod": retrieval_method,
                            "rankingMethod": ranking_method
                        },
                        limit=10
                    )

                    # The filter is expected to leave only doc8.
                    self.assertEqual(len(hybrid_res["hits"]), 1)
                    self.assertEqual(hybrid_res["hits"][0]["_id"], "doc8")

0 comments on commit d45985c

Please sign in to comment.