Skip to content

Commit

Permalink
added content_type fields to embed function and test requests (#231)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaynorChavez authored May 9, 2024
1 parent b0cbce2 commit 66bc686
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 17 deletions.
4 changes: 4 additions & 0 deletions src/marqo/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ def create_index(
number_of_shards: Optional[int] = None,
number_of_replicas: Optional[int] = None,
number_of_inferences: Optional[int] = None,
text_query_prefix: Optional[str] = None,
text_chunk_prefix: Optional[str] = None,
) -> Dict[str, Any]:
"""Create the index. Please refer to the marqo cloud to see options for inference and storage node types.
Calls Index.create() with the same parameters.
Expand Down Expand Up @@ -144,6 +146,8 @@ def create_index(
number_of_shards=number_of_shards,
number_of_replicas=number_of_replicas,
number_of_inferences=number_of_inferences,
text_query_prefix=text_query_prefix,
text_chunk_prefix=text_chunk_prefix,
)

def delete_index(self, index_name: str, wait_for_readiness=True) -> Dict[str, Any]:
Expand Down
5 changes: 5 additions & 0 deletions src/marqo/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,8 @@ class InterpolationMethod(str, Enum):
LERP = "lerp"
NLERP = "nlerp"
SLERP = "slerp"


class EmbedContentType(str, Enum):
Query = "query"
Document = "document"
32 changes: 25 additions & 7 deletions src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from marqo._httprequests import HttpRequests
from marqo.cloud_helpers import cloud_wait_for_index_status
from marqo.config import Config
from marqo.enums import IndexStatus, InterpolationMethod
from marqo.enums import IndexStatus, InterpolationMethod, EmbedContentType
from marqo.enums import SearchMethods
from marqo.errors import MarqoWebError, UnsupportedOperationError, MarqoCloudIndexNotFoundError
from marqo.marqo_logging import mq_logger
Expand Down Expand Up @@ -96,6 +96,8 @@ def create(config: Config,
number_of_replicas: Optional[int] = None,
number_of_inferences: Optional[int] = None,
wait_for_readiness: bool = True,
text_chunk_prefix: Optional[str] = None,
text_query_prefix: Optional[str] = None,
) -> Dict[str, Any]:
"""Create the index. Please refer to the marqo cloud to see options for inference and storage node types.
Creates CreateIndexSettings object and then uses it to create the index.
Expand Down Expand Up @@ -154,7 +156,9 @@ def create(config: Config,
textPreprocessing=text_preprocessing,
imagePreprocessing=image_preprocessing,
vectorNumericType=vector_numeric_type,
annParameters=ann_parameters
annParameters=ann_parameters,
textChunkPrefix=text_chunk_prefix,
textQueryPrefix=text_query_prefix,
)

return req.post(f"indexes/{index_name}", body=local_create_index_settings.generate_request_body())
Expand All @@ -180,6 +184,8 @@ def create(config: Config,
numberOfShards=number_of_shards,
numberOfReplicas=number_of_replicas,
storageClass=storage_class,
textChunkPrefix=text_chunk_prefix,
textQueryPrefix=text_query_prefix,
)

response = req.post(f"indexes/{index_name}", body=cloud_index_settings.generate_request_body())
Expand All @@ -202,7 +208,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
boost: Optional[Dict[str, List[Union[float, int]]]] = None,
context: Optional[dict] = None, score_modifiers: Optional[dict] = None,
model_auth: Optional[dict] = None,
ef_search: Optional[int] = None, approximate: Optional[bool] = None
ef_search: Optional[int] = None, approximate: Optional[bool] = None,
text_query_prefix: Optional[str] = None,
) -> Dict[str, Any]:
"""Search the index.
Expand Down Expand Up @@ -265,6 +272,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
"showHighlights": show_highlights,
"reRanker": reranker,
"boost": boost,
"textQueryPrefix": text_query_prefix,
}

body = {k: v for k, v in body.items() if v is not None}
Expand Down Expand Up @@ -376,7 +384,7 @@ def recommend(self, documents: Union[List[str], Dict[str, float]],

def embed(self, content: Union[Union[str, Dict[str, float]], List[Union[str, Dict[str, float]]]],
device: Optional[str] = None, image_download_headers: Optional[Dict] = None,
model_auth: Optional[dict] = None):
model_auth: Optional[dict] = None, content_type: Optional[EmbedContentType] = EmbedContentType.Query):
"""Retrieve embeddings for content or list of content.
Args:
content: string, dictionary of weighted strings, or list of either. Strings
Expand All @@ -392,6 +400,7 @@ def embed(self, content: Union[Union[str, Dict[str, float]], List[Union[str, Dic
image_download_headers: a dictionary of headers to be passed while downloading images,
for URLs found in documents
model_auth: authorisation that lets Marqo download a private model, if required
content_type: the type of prefix the user wants. "query", "document", or None.
Returns:
Dictionary of content, embeddings, and processingTimeMs.
"""
Expand All @@ -404,12 +413,14 @@ def embed(self, content: Union[Union[str, Dict[str, float]], List[Union[str, Dic
)
body = {
"content": content,
"content_type": content_type,
}

if image_download_headers is not None:
body["image_download_headers"] = image_download_headers
if model_auth is not None:
body["modelAuth"] = model_auth


res = self.http.post(
path=path_with_query_str,
Expand Down Expand Up @@ -476,7 +487,8 @@ def add_documents(
use_existing_tensors: bool = False,
image_download_headers: dict = None,
mappings: dict = None,
model_auth: dict = None
model_auth: dict = None,
text_chunk_prefix: str = None,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
"""Add documents to this index. Does a partial update on existing documents,
based on their ID. Adds unseen documents to the index.
Expand All @@ -494,6 +506,7 @@ def add_documents(
for URLs found in documents
mappings: a dictionary to help handle the object fields. e.g., multimodal_combination field
model_auth: used to authorise a private model
text_chunk_prefix: the request level prefix for adding docs
Returns:
Response body outlining indexing result
"""
Expand All @@ -504,7 +517,8 @@ def add_documents(
documents=documents,
client_batch_size=client_batch_size, device=device, tensor_fields=tensor_fields,
use_existing_tensors=use_existing_tensors,
image_download_headers=image_download_headers, mappings=mappings, model_auth=model_auth
image_download_headers=image_download_headers, mappings=mappings, model_auth=model_auth,
text_chunk_prefix=text_chunk_prefix
)

def _add_docs_organiser(
Expand All @@ -516,7 +530,8 @@ def _add_docs_organiser(
use_existing_tensors: bool = False,
image_download_headers: dict = None,
mappings: dict = None,
model_auth: dict = None
model_auth: dict = None,
text_chunk_prefix: str = None,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
error_detected_message = ('Errors detected in add documents call. '
'Please examine the returned result object for more information.')
Expand All @@ -543,6 +558,9 @@ def _add_docs_organiser(
if tensor_fields is not None:
base_body['tensorFields'] = tensor_fields

if text_chunk_prefix is not None:
base_body['textChunkPrefix'] = text_chunk_prefix

end_time_client_process = timer()
total_client_process_time = end_time_client_process - start_time_client_process
mq_logger.debug(f"add_documents pre-processing: took {(total_client_process_time):.3f}s for {num_docs} docs.")
Expand Down
2 changes: 2 additions & 0 deletions src/marqo/models/create_index_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class IndexSettings(MarqoBaseModel):
imagePreprocessing: Optional[marqo_index.ImagePreProcessing] = None
vectorNumericType: Optional[marqo_index.VectorNumericType] = None
annParameters: Optional[marqo_index.AnnParameters] = None
textQueryPrefix: Optional[str] = None
textChunkPrefix: Optional[str] = None

def generate_request_body(self) -> dict:
"""A json encoded string of the request body"""
Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__marqo_version__ = "2.5.0"
__marqo_version__ = "2.6.0"
__marqo_release_page__ = f"https://github.com/marqo-ai/marqo/releases/tag/{__marqo_version__}"

__minimum_supported_marqo_version__ = "2.0"
Expand Down
51 changes: 51 additions & 0 deletions tests/v2_tests/test_create_index.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import uuid

from pytest import mark
import numpy as np

from marqo.errors import MarqoWebError
from tests.marqo_test import MarqoTestCase
Expand All @@ -10,11 +11,15 @@
@mark.ignore_during_cloud_tests
class TestCreateIndex(MarqoTestCase):
index_name = "test_create_index" + str(uuid.uuid4()).replace('-', '')
override_index_name = "override_prefix" + str(uuid.uuid4()).replace('-', '')
default_index_name = "default_prefix" + str(uuid.uuid4()).replace('-', '')

def tearDown(self):
super().tearDown()
try:
self.client.delete_index(index_name=self.index_name)
self.client.delete_index(index_name=self.override_index_name)
self.client.delete_index(index_name=self.default_index_name)
except MarqoWebError:
pass

Expand Down Expand Up @@ -45,6 +50,52 @@ def test_simple_index_creation(self):
}
self.assertEqual(expected_settings, index_settings)

def test_create_simple_index_creation_with_prefix(self):
# Create the indexes
self.client.create_index(
index_name=self.override_index_name,
model="test_prefix",
text_query_prefix="test: ",
text_chunk_prefix="test: ",
)
self.client.create_index(
index_name=self.default_index_name,
model="test_prefix",
)

d1 = {
"_id": "doc1",
"text_field_1": "hello document"
}
# Add documents to both
self.client.index("override_prefix").add_documents([d1], tensor_fields=["text_field_1"])
self.client.index("default_prefix").add_documents([d1], tensor_fields=["text_field_1"])

# Get override doc with tensor facets (for reference vector)
retrieved_override_doc = self.client.index("override_prefix").get_document(
document_id="doc1", expose_facets=True)

# Get default doc with tensor facets (for reference vector)
retrieved_default_doc = self.client.index("default_prefix").get_document(
document_id="doc1", expose_facets=True)

# Embed override
embed_res_override = self.client.index("override_prefix").embed("test: hello document", content_type=None)

# Embed default
embed_res_default = self.client.index("default_prefix").embed("test passage: hello document", content_type=None)

# Assert that the embeddings from override add docs and the embeddings from the embed call are the same
self.assertTrue(np.allclose(embed_res_override["embeddings"][0], retrieved_override_doc["_tensor_facets"][0]["_embedding"]))

# Assert that the embeddings from override add docs and the embeddings from the embed call are the same
self.assertTrue(np.allclose(embed_res_default["embeddings"][0], retrieved_default_doc["_tensor_facets"][0]["_embedding"]))






def test_create_unstructured_image_index(self):
self.client.create_index(index_name=self.index_name, type="unstructured",
treat_urls_and_pointers_as_images=True, model="open_clip/ViT-B-32/laion400m_e32")
Expand Down
2 changes: 1 addition & 1 deletion tests/v2_tests/test_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_readme_example_weighted_query(self):
q=query, searchable_attributes=["text_field_1", "text_field_2"]
)

self.assertEqual("Smartphone", r2["hits"][0]["text_field_1"])
self.assertEqual("Telephone", r2["hits"][0]["text_field_1"])

print("Query 1:")
pprint.pprint(r2)
Expand Down
46 changes: 38 additions & 8 deletions tests/v2_tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class TestEmbed(MarqoTestCase):
def test_embed_single_string(self):
"""Embeds a string. Use add docs and get docs with tensor facets to ensure the vector is correct.
Checks the basic functionality and response structure"""
Checks the basic functionality and response structure. Also checks that the request level prefix override works."""
for cloud_test_index_to_use, open_source_test_index_name in self.test_cases:
test_index_name = self.get_test_index_name(
cloud_test_index_to_use=cloud_test_index_to_use,
Expand All @@ -29,18 +29,48 @@ def test_embed_single_string(self):
"_id": "doc1",
"text_field_1": "Jimmy Butler is the GOAT."
}
res = self.client.index(test_index_name).add_documents([d1], tensor_fields=tensor_fields)

res_1 = self.client.index(test_index_name).add_documents([d1], tensor_fields=tensor_fields)

# Get doc with tensor facets (for reference vector)
retrieved_d1 = self.client.index(test_index_name).get_document(
document_id="doc1", expose_facets=True)

# Call embed
embed_res = self.client.index(test_index_name).embed("Jimmy Butler is the GOAT.")
embed_res_1 = self.client.index(test_index_name).embed("Jimmy Butler is the GOAT.", content_type="document")

# Assert that the
self.assertIn("processingTimeMs", embed_res_1)
self.assertEqual(embed_res_1["content"], "Jimmy Butler is the GOAT.")
self.assertTrue(np.allclose(embed_res_1["embeddings"][0], retrieved_d1["_tensor_facets"][0]["_embedding"]))


def test_request_level_prefix_override_embed_add_docs(self):
"""Checks that the request level prefix override works."""
for cloud_test_index_to_use, open_source_test_index_name in self.test_cases:
test_index_name = self.get_test_index_name(
cloud_test_index_to_use=cloud_test_index_to_use,
open_source_test_index_name=open_source_test_index_name
)
with (self.subTest(test_index_name)):
# Add document
tensor_fields = ["text_field_1"] if "unstr" in test_index_name else None
d1 = {
"_id": "doc1",
"text_field_1": "Jimmy Butler is the GOAT."
}
res = self.client.index(test_index_name).add_documents([d1], tensor_fields=tensor_fields, text_chunk_prefix="test query: ")

# Get doc with tensor facets (for reference vector)
retrieved_d1 = self.client.index(test_index_name).get_document(
document_id="doc1", expose_facets=True)

embed_res = self.client.index(test_index_name).embed("test query: Jimmy Butler is the GOAT.", content_type=None)

# Assert request level prefix override
self.assertIn("processingTimeMs", embed_res)
self.assertEqual(embed_res["content"], "Jimmy Butler is the GOAT.")
self.assertTrue(np.allclose(embed_res["embeddings"][0], retrieved_d1["_tensor_facets"][0] ["_embedding"]))
self.assertEqual(embed_res["content"], "test query: Jimmy Butler is the GOAT.")
self.assertTrue(np.allclose(embed_res["embeddings"][0], retrieved_d1["_tensor_facets"][0]["_embedding"]))


def test_embed_with_device(self):
Expand All @@ -65,7 +95,7 @@ def test_embed_with_device(self):
document_id="doc1", expose_facets=True)

# Call embed
embed_res = self.client.index(test_index_name).embed(content="Jimmy Butler is the GOAT.", device="cpu")
embed_res = self.client.index(test_index_name).embed(content="Jimmy Butler is the GOAT.", device="cpu", content_type="document")
self.assertIn("processingTimeMs", embed_res)
self.assertEqual(embed_res["content"], "Jimmy Butler is the GOAT.")
self.assertTrue(np.allclose(embed_res["embeddings"][0], retrieved_d1["_tensor_facets"][0] ["_embedding"]))
Expand All @@ -92,7 +122,7 @@ def test_embed_single_dict(self):
document_id="doc1", expose_facets=True)

# Call embed
embed_res = self.client.index(test_index_name).embed(content={"Jimmy Butler is the GOAT.": 1})
embed_res = self.client.index(test_index_name).embed(content={"Jimmy Butler is the GOAT.": 1}, content_type="document")

self.assertIn("processingTimeMs", embed_res)
self.assertEqual(embed_res["content"], {"Jimmy Butler is the GOAT.": 1})
Expand Down Expand Up @@ -125,7 +155,7 @@ def test_embed_list_content(self):

# Call embed
embed_res = self.client.index(test_index_name).embed(
content=[{"Jimmy Butler is the GOAT.": 1}, "Alex Caruso is the GOAT."]
content=[{"Jimmy Butler is the GOAT.": 1}, "Alex Caruso is the GOAT."], content_type="document"
)

self.assertIn("processingTimeMs", embed_res)
Expand Down

0 comments on commit 66bc686

Please sign in to comment.