Skip to content

Commit

Permalink
Embed method for index (#227)
Browse files Browse the repository at this point in the history
  • Loading branch information
vicilliar authored Apr 23, 2024
1 parent 10aefd4 commit 5a30a06
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 0 deletions.
55 changes: 55 additions & 0 deletions src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,61 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
mq_logger.debug(search_time_log)
return res

def embed(self, content: Union[Union[str, Dict[str, float]], List[Union[str, Dict[str, float]]]],
          device: Optional[str] = None, image_download_headers: Optional[Dict] = None,
          model_auth: Optional[dict] = None) -> Dict[str, Any]:
    """Retrieve embeddings for content or list of content.

    Args:
        content: string, dictionary of weighted strings, or list of either. Strings
            to embed are text or a pointer/url to an image if the index
            has treat_urls_and_pointers_as_images set to True.
            If content is weighted, each weight acts as a (possibly negative)
            multiplier for that string, relative to the other strings.
        device: the device the embed request should run on. Examples include "cpu",
            "cuda" and "cuda:2". When None, no device is sent with the request.
        image_download_headers: a dictionary of headers to be passed while downloading images,
            for URLs found in the content. Only sent when not None.
        model_auth: authorisation that lets Marqo download a private model, if required.
            Only sent when not None.

    Returns:
        Dictionary of content, embeddings, and processingTimeMs.
    """
    start_time_client_request = timer()

    path_with_query_str = f"indexes/{self.index_name}/embed"
    if device is not None:
        # The "?&device=" prefix is kept exactly as it was so the request URL is unchanged.
        path_with_query_str += f"?&device={utils.translate_device_string_for_url(device)}"

    # Optional fields are only added to the body when the caller supplied them.
    body = {"content": content}
    if image_download_headers is not None:
        body["image_download_headers"] = image_download_headers
    if model_auth is not None:
        body["modelAuth"] = model_auth

    res = self.http.post(
        path=path_with_query_str,
        body=body,
        index_name=self.index_name,
    )

    num_results = len(res["embeddings"])
    total_client_request_time = timer() - start_time_client_request

    # Note the trailing space after "content": the two pieces are concatenated verbatim.
    embed_time_log = (
        f"embed: took {total_client_request_time:.3f}s to embed content "
        f"and received {num_results} embeddings from Marqo (roundtrip)."
    )
    if 'processingTimeMs' in res:
        # processingTimeMs is converted from milliseconds to seconds for the log line.
        embed_time_log += (
            f" Marqo itself took {(res['processingTimeMs'] * 0.001):.3f}s to execute the embed request."
        )

    mq_logger.debug(embed_time_log)
    return res

def get_document(self, document_id: str, expose_facets=None) -> Dict[str, Any]:
"""Get one document with given an ID.
Expand Down
162 changes: 162 additions & 0 deletions tests/v2_tests/test_embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import copy
import marqo
from marqo import enums
from unittest import mock
import requests
import random
import math
import time
from tests.marqo_test import MarqoTestCase, CloudTestIndex
from marqo.errors import MarqoWebError
from pytest import mark
import numpy as np


@mark.fixed
class TestEmbed(MarqoTestCase):
    """Tests for the client's ``embed`` method.

    The happy-path tests index a document, fetch it back with tensor facets exposed,
    and check that ``embed`` returns the same vector for the same text. The failure
    tests check that invalid content is rejected with a ``MarqoWebError``.
    """

    def _resolved_index_names(self):
        """Yield the resolved test index name for every entry in ``self.test_cases``."""
        for cloud_test_index_to_use, open_source_test_index_name in self.test_cases:
            yield self.get_test_index_name(
                cloud_test_index_to_use=cloud_test_index_to_use,
                open_source_test_index_name=open_source_test_index_name,
            )

    def _add_docs(self, test_index_name, docs):
        """Add ``docs`` to the index, passing tensor_fields only for "unstr" index names."""
        tensor_fields = ["text_field_1"] if "unstr" in test_index_name else None
        self.client.index(test_index_name).add_documents(docs, tensor_fields=tensor_fields)

    def _index_doc1_and_get_embedding(self, test_index_name):
        """Index the reference document and return its first tensor facet embedding."""
        d1 = {
            "_id": "doc1",
            "text_field_1": "Jimmy Butler is the GOAT."
        }
        self._add_docs(test_index_name, [d1])

        # Get doc with tensor facets (for reference vector)
        retrieved_d1 = self.client.index(test_index_name).get_document(
            document_id="doc1", expose_facets=True)
        return retrieved_d1["_tensor_facets"][0]["_embedding"]

    def test_embed_single_string(self):
        """Embeds a string. Use add docs and get docs with tensor facets to ensure the vector is correct.
        Checks the basic functionality and response structure"""
        for test_index_name in self._resolved_index_names():
            with self.subTest(test_index_name):
                reference_embedding = self._index_doc1_and_get_embedding(test_index_name)

                # Call embed (content passed positionally on purpose)
                embed_res = self.client.index(test_index_name).embed("Jimmy Butler is the GOAT.")

                self.assertIn("processingTimeMs", embed_res)
                self.assertEqual(embed_res["content"], "Jimmy Butler is the GOAT.")
                self.assertTrue(np.allclose(embed_res["embeddings"][0], reference_embedding))

    def test_embed_with_device(self):
        """Embeds a string with device parameter. Use add docs and get docs with tensor facets to ensure
        the vector is correct. Checks the basic functionality and response structure"""
        for test_index_name in self._resolved_index_names():
            with self.subTest(test_index_name):
                reference_embedding = self._index_doc1_and_get_embedding(test_index_name)

                # Call embed
                embed_res = self.client.index(test_index_name).embed(
                    content="Jimmy Butler is the GOAT.", device="cpu")

                self.assertIn("processingTimeMs", embed_res)
                self.assertEqual(embed_res["content"], "Jimmy Butler is the GOAT.")
                self.assertTrue(np.allclose(embed_res["embeddings"][0], reference_embedding))

    def test_embed_single_dict(self):
        """Embeds a dict. Use add docs and get docs with tensor facets to ensure the vector is correct.
        Checks the basic functionality and response structure"""
        for test_index_name in self._resolved_index_names():
            with self.subTest(test_index_name):
                reference_embedding = self._index_doc1_and_get_embedding(test_index_name)

                # Call embed
                embed_res = self.client.index(test_index_name).embed(
                    content={"Jimmy Butler is the GOAT.": 1})

                self.assertIn("processingTimeMs", embed_res)
                self.assertEqual(embed_res["content"], {"Jimmy Butler is the GOAT.": 1})
                self.assertTrue(np.allclose(embed_res["embeddings"][0], reference_embedding))

    def test_embed_list_content(self):
        """Embeds a list with string and dict. Use add docs and get docs with tensor facets to ensure
        the vector is correct. Checks the basic functionality and response structure"""
        for test_index_name in self._resolved_index_names():
            with self.subTest(test_index_name):
                # Add documents
                d1 = {
                    "_id": "doc1",
                    "text_field_1": "Jimmy Butler is the GOAT."
                }
                d2 = {
                    "_id": "doc2",
                    "text_field_1": "Alex Caruso is the GOAT."
                }
                self._add_docs(test_index_name, [d1, d2])

                # Get docs with tensor facets (for reference vectors)
                retrieved_docs = self.client.index(test_index_name).get_documents(
                    document_ids=["doc1", "doc2"], expose_facets=True)

                # Call embed
                content = [{"Jimmy Butler is the GOAT.": 1}, "Alex Caruso is the GOAT."]
                embed_res = self.client.index(test_index_name).embed(content=content)

                self.assertIn("processingTimeMs", embed_res)
                self.assertEqual(
                    embed_res["content"],
                    [{"Jimmy Butler is the GOAT.": 1}, "Alex Caruso is the GOAT."])
                # NOTE(review): relies on results coming back in the order the ids were
                # requested (doc1 then doc2) - confirm against get_documents.
                for position in (0, 1):
                    self.assertTrue(
                        np.allclose(
                            embed_res["embeddings"][position],
                            retrieved_docs["results"][position]["_tensor_facets"][0]["_embedding"]))

    def test_embed_non_numeric_weight_fails(self):
        """Embedding a dict whose weight is not a number raises MarqoWebError."""
        for test_index_name in self._resolved_index_names():
            with self.subTest(test_index_name):
                with self.assertRaises(MarqoWebError) as e:
                    self.client.index(test_index_name).embed(content={"text to embed": "not a number"})

                self.assertIn("not a valid float", str(e.exception))

    def test_embed_empty_content(self):
        """Embedding an empty list raises MarqoWebError."""
        for test_index_name in self._resolved_index_names():
            with self.subTest(test_index_name):
                with self.assertRaises(MarqoWebError) as e:
                    self.client.index(test_index_name).embed(content=[])

                # Match the generic phrase only, not an informal leading word, so the
                # check does not depend on the exact wording in front of it.
                # NOTE(review): confirm the phrase against the server's error message.
                self.assertIn("should not be empty", str(e.exception))

0 comments on commit 5a30a06

Please sign in to comment.