Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numberOfVectors to get_stats api #553

Merged
merged 14 commits into from
Jul 28, 2023
27 changes: 9 additions & 18 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,7 @@ def _autofill_index_settings(index_settings: dict):


def get_stats(config: Config, index_name: str):
"""Returns the number of documents and vectors in the index.

The _count API counts top-level documents.
For numberOfVectors, we count the number of `__chunks.__field_name` fields because it is a key_word field, which
is known to be fast.

Difference between the two gives the numberOfVectors."""
"""Returns the number of documents and vectors in the index."""

body = {
"size": 0,
Expand All @@ -285,6 +279,7 @@ def get_stats(config: Config, index_name: str):
"aggs": {
"marqo_vector_count": {
"value_count": {
# This is a key_word field, so it is fast in value_count
"field": "__chunks.__field_name"
}
}
Expand All @@ -298,18 +293,14 @@ def get_stats(config: Config, index_name: str):
vector_count = HttpRequests(config).get(path=f"{index_name}/_search", body=body) \
["aggregations"]["nested_chunks"]["marqo_vector_count"]["value"]
except (KeyError, TypeError) as e:
error_message = (f"Marqo received an unexpected response from Marqo-os during execution of `get_stats()` APIs. "
f"The expected fields do not exist in the response. Original error message = {e}")
logger.error(error_message)
raise errors.InternalError(error_message)
except (errors.IndexNotFoundError, errors.InvalidIndexNameError):
raise
raise errors.InternalError(f"Marqo received an unexpected response from Marqo-OS. "
f"The expected fields do not exist in the response. Original error message = {e}")
except errors.MarqoWebError as e:
error_message = (
f"Marqo encountered an error while communicating with Marqo-os during execution of `get_stats()` APIs. "
f"Original error message: {e.message}")
logger.error(error_message)
raise errors.InternalError(error_message)
if isinstance(e, (errors.IndexNotFoundError, errors.InvalidIndexNameError)):
raise e
else:
raise errors.InternalError(f"Marqo encountered an error while communicating with Marqo-OS. "
f"Original error message: {e.message}")

return {
"numberOfDocuments": doc_count,
Expand Down
36 changes: 18 additions & 18 deletions tests/tensor_search/test_get_stats.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's get some edge case and HTTP error tests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also get some cases where the HTTP response isn't expected (like an unknown dict type because of an error)

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_get_stats_non_empty(self):
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] == 3

def test_get_stats_number_of_vectors_unified(self):

"""Tests the 'get_stats' function by checking if it correctly returns the
number of vectors and documents for all the documents indexed at once."""
tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,
index_settings={'index_defaults': {"model": "random/small"}})
expected_number_of_vectors = 7
Expand All @@ -62,12 +63,14 @@ def test_get_stats_number_of_vectors_unified(self):
)
)

assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] \
== expected_number_of_documents
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfVectors"] \
== expected_number_of_vectors
res = tensor_search.get_stats(config=self.config, index_name=self.index_name_1)

assert res["numberOfDocuments"] == expected_number_of_documents
assert res["numberOfVectors"] == expected_number_of_vectors

def test_get_stats_number_of_vectors_separated(self):
"""Tests the 'get_stats' function by checking if it correctly returns the
number of vectors and documents for documents indexed one by one ."""
testing_list = [
{
"expected_number_of_vectors": 2,
Expand Down Expand Up @@ -135,10 +138,11 @@ def test_get_stats_number_of_vectors_separated(self):
**test_case["add_docs_kwargs"],
)
)
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] \
== test_case["expected_number_of_documents"]
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfVectors"] \
== test_case["expected_number_of_vectors"]

res = tensor_search.get_stats(config=self.config, index_name=self.index_name_1)

assert res["numberOfDocuments"] == test_case["expected_number_of_documents"]
assert res["numberOfVectors"] == test_case["expected_number_of_vectors"]

def test_long_text_splitting_vectors_count(self):

Expand Down Expand Up @@ -180,10 +184,10 @@ def test_long_text_splitting_vectors_count(self):
)
)

assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] \
== test_case["expected_number_of_documents"]
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfVectors"] \
== test_case["expected_number_of_vectors"]
res = tensor_search.get_stats(config=self.config, index_name=self.index_name_1)

assert res["numberOfDocuments"] == test_case["expected_number_of_documents"]
assert res["numberOfVectors"] == test_case["expected_number_of_vectors"]

def test_key_error(self):
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
Expand Down Expand Up @@ -213,8 +217,6 @@ def test_IndexNotFoundError_error(self):
with self.assertRaises(errors.IndexNotFoundError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

self.assertIn("IndexNotFoundError", str(e.exception))

def test_InvalidIndexNameError_error(self):
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
mock_get.side_effect = errors.InvalidIndexNameError("InvalidIndexNameError")
Expand All @@ -223,8 +225,6 @@ def test_InvalidIndexNameError_error(self):
with self.assertRaises(errors.InvalidIndexNameError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

self.assertIn("InvalidIndexNameError", str(e.exception))

def test_generic_MarqoWebError_error(self):
# Generic MarqoWebError should be caught and raised as InternalError
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
Expand All @@ -234,4 +234,4 @@ def test_generic_MarqoWebError_error(self):
with self.assertRaises(errors.InternalError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

self.assertIn("Marqo encountered an error while communicating with Marqo-os during execution of", str(e.exception))
self.assertIn("Marqo encountered an error while communicating with Marqo-OS", str(e.exception))