Skip to content

Commit

Permalink
Recommend endpoint (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
farshidz authored Apr 23, 2024
1 parent a059583 commit d2b96af
Show file tree
Hide file tree
Showing 20 changed files with 1,066 additions and 48 deletions.
2 changes: 1 addition & 1 deletion requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# tensor search, idential to requirements.txt:
# tensor search, identical to requirements.txt:
requests==2.28.1
anyio==3.7.1
fastapi==0.86.0
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ urllib3==1.26.0
pydantic==1.10.11
httpx==0.25.0
semver==3.0.2
scipy==1.10.1
memory-profiler==0.61.0
cachetools==5.3.1
pynvml==11.5.0 # For cuda utilization
readerwriterlock==1.0.9
readerwriterlock==1.0.9
22 changes: 22 additions & 0 deletions src/marqo/api/models/recommend_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Dict, List, Union, Optional

from marqo.core.models.interpolation_method import InterpolationMethod
from marqo.tensor_search.models.api_models import BaseMarqoModel
from marqo.tensor_search.models.score_modifiers_object import ScoreModifier


class RecommendQuery(BaseMarqoModel):
    """
    Request body model for a recommend query.

    Field names are camelCase; they presumably mirror the snake_case arguments of
    ``Recommender.recommend`` (verify against the API route that consumes this model).
    """
    # Either a list of document IDs, or a mapping of document ID -> weight.
    documents: Union[List[str], Dict[str, float]]
    tensorFields: Optional[List[str]] = None
    # None leaves the choice of interpolation method to the consumer of this model.
    interpolationMethod: Optional[InterpolationMethod] = None
    excludeInputDocuments: bool = True
    limit: int = 10
    offset: int = 0
    efSearch: Optional[int] = None
    approximate: Optional[bool] = None
    searchableAttributes: Optional[List[str]] = None
    showHighlights: bool = True
    # Fields that default to None are declared Optional explicitly rather than relying on
    # implicit-Optional handling of ``str = None``.
    reRanker: Optional[str] = None
    filter: Optional[str] = None
    attributesToRetrieve: Optional[List[str]] = None
    scoreModifiers: Optional[ScoreModifier] = None
12 changes: 12 additions & 0 deletions src/marqo/case_insensitive_enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from enum import Enum
from typing import Optional


class CaseInsensitiveEnum(Enum):
    """Enum base class whose string values are also matched case-insensitively on lookup."""

    @classmethod
    def _missing_(cls, value: object) -> Optional['CaseInsensitiveEnum']:
        """
        Fallback lookup invoked by ``Enum`` when ``value`` does not match any member exactly.

        Args:
            value: The value that failed the exact lookup.

        Returns:
            The member whose string value equals ``value`` ignoring case, or None if there is
            no such member (``Enum`` then raises its standard ``ValueError``).
        """
        # Only strings can be compared case-insensitively. Returning None here lets Enum raise
        # ValueError for e.g. ints or None, instead of leaking an AttributeError from .lower().
        if not isinstance(value, str):
            return None

        lowered = value.lower()
        for member in cls:
            member_value = member.value
            # Skip non-string member values; they cannot match a string lookup.
            if isinstance(member_value, str) and member_value.lower() == lowered:
                return member
        return None
18 changes: 9 additions & 9 deletions src/marqo/config.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
from typing import Optional, Union

from marqo.core.document.document import Document
from marqo.core.embed.embed import Embed
from marqo.core.index_management.index_management import IndexManagement
from marqo.core.monitoring.monitoring import Monitoring
from marqo.core.search.recommender import Recommender
from marqo.tensor_search import enums
from marqo.vespa.vespa_client import VespaClient
from marqo.core.embed.embed import Embed


class Config:
def __init__(
self,
vespa_client: VespaClient,
index_management: IndexManagement,
default_device: str,
timeout: Optional[int] = None,
backend: Optional[Union[enums.SearchDb, str]] = None
backend: Optional[Union[enums.SearchDb, str]] = None,
) -> None:
"""
Parameters
----------
url:
The url to the S2Search API (ex: http://localhost:9200)
"""
self.default_device = default_device
self.vespa_client = vespa_client
self.set_is_remote(vespa_client)
self.timeout = timeout
self.backend = backend if backend is not None else enums.SearchDb.vespa

# Initialize Core layer dependencies
self.index_management = index_management
self.monitoring = Monitoring(vespa_client, index_management)
self.document = Document(vespa_client, index_management)
self.embed = Embed(vespa_client, index_management, default_device)
self.index_management = IndexManagement(vespa_client)
self.monitoring = Monitoring(vespa_client, self.index_management)
self.document = Document(vespa_client, self.index_management)
self.recommender = Recommender(vespa_client, self.index_management)
self.embed = Embed(vespa_client, self.index_management, default_device)

def set_is_remote(self, vespa_client: VespaClient):
local_host_markers = ["localhost", "0.0.0.0", "127.0.0.1"]
Expand All @@ -45,4 +45,4 @@ def set_is_remote(self, vespa_client: VespaClient):
for url in [vespa_client.config_url, vespa_client.query_url, vespa_client.document_url]
]
):
self.cluster_is_remote = False
self.cluster_is_remote = False
7 changes: 7 additions & 0 deletions src/marqo/core/models/interpolation_method.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from marqo.case_insensitive_enum import CaseInsensitiveEnum


class InterpolationMethod(str, CaseInsensitiveEnum):
    """
    Names of the supported vector interpolation methods.

    Members are plain strings (``str`` mixin) and, through ``CaseInsensitiveEnum``, can be
    looked up by value regardless of case.
    """
    # The names presumably follow the usual abbreviations: linear, normalized linear and
    # spherical linear interpolation -- confirm against the interpolation implementations.
    LERP = "lerp"
    NLERP = "nlerp"
    SLERP = "slerp"
178 changes: 178 additions & 0 deletions src/marqo/core/search/recommender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from timeit import default_timer as timer
from typing import Dict, List, Union, Optional

from marqo.core.exceptions import InvalidFieldNameError
from marqo.core.index_management.index_management import IndexManagement
from marqo.core.models import MarqoIndex
from marqo.core.models.interpolation_method import InterpolationMethod
from marqo.core.models.marqo_index import IndexType
from marqo.core.utils.vector_interpolation import from_interpolation_method
from marqo.exceptions import InvalidArgumentError
from marqo.tensor_search.models.score_modifiers_object import ScoreModifier
from marqo.tensor_search.models.search import SearchContext, SearchContextTensor
from marqo.vespa.vespa_client import VespaClient


class Recommender:
    """
    Recommends documents that are similar to documents already present in an index.

    The embeddings of the input documents are interpolated into a single vector, which is then
    used as the query vector of a regular tensor search.
    """

    def __init__(self, vespa_client: VespaClient, index_management: IndexManagement):
        self.vespa_client = vespa_client
        self.index_management = index_management

    def recommend(self,
                  index_name: str,
                  documents: Union[List[str], Dict[str, float]],
                  tensor_fields: Optional[List[str]] = None,
                  interpolation_method: Optional[InterpolationMethod] = None,
                  exclude_input_documents: bool = True,
                  result_count: int = 3,
                  offset: int = 0,
                  highlights: bool = True,
                  ef_search: Optional[int] = None,
                  approximate: Optional[bool] = None,
                  searchable_attributes: Optional[List[str]] = None,
                  verbose: int = 0,
                  reranker: Union[str, Dict] = None,
                  filter: str = None,
                  attributes_to_retrieve: Optional[List[str]] = None,
                  score_modifiers: Optional[ScoreModifier] = None
                  ):
        """
        Recommend documents similar to the provided documents.

        Args:
            index_name: Name of the index to search
            documents: A list of document IDs or a dictionary where the keys are document IDs and the values are weights
            tensor_fields: List of tensor fields to use for recommendation. If None, all tensor fields are used
            interpolation_method: Interpolation method to use for combining vectors. If None, a default is chosen
                based on whether the index normalizes embeddings
            exclude_input_documents: Whether to exclude the input documents from the search results
            result_count: Number of results to return
            offset: Offset of the first result
            highlights: Whether to include highlights in the results
            ef_search: ef_search parameter for HNSW search
            approximate: Whether to use approximate search
            searchable_attributes: List of attributes to search in
            verbose: Verbosity level
            reranker: Reranker to use
            filter: Filter string
            attributes_to_retrieve: List of attributes to retrieve
            score_modifiers: Score modifiers to apply

        Returns:
            The value returned by ``tensor_search.search`` for the interpolated vector.

        Raises:
            InvalidArgumentError: If no documents are provided, if any document ID is not found, or if any
                document has no embeddings in the selected tensor fields.
            InvalidFieldNameError: If the index is structured and a requested tensor field does not exist.
        """
        # TODO - Extract search and get_docs from tensor_search and refactor this
        # TODO - The dependence on Config in tensor_search is bad design. Refactor to require specific dependencies
        from marqo import config
        from marqo.tensor_search import tensor_search
        from marqo.tensor_search import index_meta_cache

        if not documents:
            raise InvalidArgumentError('No document IDs provided')

        # Build the Config once and reuse it for every tensor_search call below, instead of
        # constructing a fresh Config (and all of its dependencies) for each call.
        marqo_config = config.Config(self.vespa_client)

        marqo_index = index_meta_cache.get_index(marqo_config, index_name=index_name)

        if interpolation_method is None:
            interpolation_method = self._get_default_interpolation_method(marqo_index)

        vector_interpolation = from_interpolation_method(interpolation_method)

        # Tensor field names are only validated for structured indexes.
        if marqo_index.type == IndexType.Structured and tensor_fields is not None:
            self._validate_tensor_fields(marqo_index, index_name, tensor_fields)

        document_ids = list(documents.keys()) if isinstance(documents, dict) else documents

        # Processing time reported by search includes fetching the input documents.
        t0 = timer()

        marqo_documents = tensor_search.get_documents_by_ids(
            marqo_config,
            index_name, document_ids, show_vectors=True
        )

        doc_vectors = self._get_document_vectors(marqo_documents['results'], tensor_fields)
        vectors, weights = self._get_weighted_vectors(documents, doc_vectors)

        interpolated_vector = vector_interpolation.interpolate(
            vectors, weights
        )

        if exclude_input_documents:
            recommend_filter = self._get_exclusion_filter(document_ids, filter)
        else:
            recommend_filter = filter

        return tensor_search.search(
            marqo_config,
            index_name,
            text=None,
            context=SearchContext(tensor=[SearchContextTensor(vector=interpolated_vector, weight=1)]),
            result_count=result_count,
            offset=offset,
            highlights=highlights,
            ef_search=ef_search,
            approximate=approximate,
            searchable_attributes=searchable_attributes,
            verbose=verbose,
            reranker=reranker,
            filter=recommend_filter,
            attributes_to_retrieve=attributes_to_retrieve,
            score_modifiers=score_modifiers,
            processing_start=t0
        )

    @staticmethod
    def _validate_tensor_fields(marqo_index: MarqoIndex, index_name: str, tensor_fields: List[str]) -> None:
        """Raise InvalidFieldNameError on the first requested tensor field that the index does not define."""
        valid_tensor_fields = marqo_index.tensor_field_map.keys()
        for tensor_field in tensor_fields:
            if tensor_field not in valid_tensor_fields:
                raise InvalidFieldNameError(f'Tensor field "{tensor_field}" not found in index "{index_name}". '
                                            f'Available tensor fields: {", ".join(valid_tensor_fields)}')

    @staticmethod
    def _get_document_vectors(results: List[Dict],
                              tensor_fields: Optional[List[str]]) -> Dict[str, List[List[float]]]:
        """
        Map each fetched document ID to its embeddings in the selected tensor fields.

        Raises InvalidArgumentError if any document was not found, or if any document has no
        embeddings in the selected tensor fields.
        """
        # Make sure all documents were found
        not_found = [document['_id'] for document in results if not document['_found']]
        if not_found:
            raise InvalidArgumentError(f'The following document IDs were not found: {", ".join(not_found)}')

        doc_vectors: Dict[str, List[List[float]]] = {}
        docs_without_vectors = []
        for document in results:
            # The first key of a tensor facet is taken to be the name of the field it belongs to.
            vectors = [
                tensor_facet['_embedding']
                for tensor_facet in document['_tensor_facets']
                if tensor_fields is None or next(iter(tensor_facet)) in tensor_fields
            ]
            doc_vectors[document['_id']] = vectors

            if not vectors:
                docs_without_vectors.append(document['_id'])

        if docs_without_vectors:
            raise InvalidArgumentError(
                f'The following documents do not have embeddings: {", ".join(docs_without_vectors)}'
            )

        return doc_vectors

    @staticmethod
    def _get_weighted_vectors(documents: Union[List[str], Dict[str, float]],
                              doc_vectors: Dict[str, List[List[float]]]) -> tuple:
        """
        Flatten per-document vectors into parallel ``(vectors, weights)`` lists.

        Every vector of a document gets that document's weight; the weight is 1 when
        ``documents`` is a plain list of IDs.
        """
        vectors: List[List[float]] = []
        weights: List[float] = []
        for document_id, vector_list in doc_vectors.items():
            weight = documents[document_id] if isinstance(documents, dict) else 1
            vectors.extend(vector_list)
            weights.extend([weight] * len(vector_list))
        return vectors, weights

    def _get_default_interpolation_method(self, marqo_index: MarqoIndex) -> InterpolationMethod:
        """Return SLERP for indexes that normalize embeddings and LERP otherwise."""
        if marqo_index.normalize_embeddings:
            return InterpolationMethod.SLERP
        return InterpolationMethod.LERP

    def _get_exclusion_filter(self, documents: List[str], user_filter: Optional[str]) -> str:
        """
        Build a filter string that excludes the given document IDs.

        If ``user_filter`` is non-blank, the result requires both the user filter and the exclusion.
        """
        # NOTE(review): document IDs are interpolated verbatim. IDs containing characters that are
        # special in the filter syntax (e.g. parentheses) may need escaping -- confirm against the
        # filter parser.
        id_terms = ' OR '.join(f'_id:({doc})' for doc in documents)
        not_in = f'NOT ({id_terms})'

        if user_filter is not None and user_filter.strip() != '':
            return f'({user_filter}) AND {not_in}'
        return not_in
Empty file.
Loading

0 comments on commit d2b96af

Please sign in to comment.