Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add numberOfVectors to get_stats api #553

Merged
merged 14 commits into from
Jul 28, 2023
48 changes: 42 additions & 6 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@
from marqo.tensor_search.models.index_info import IndexInfo, get_model_properties_from_index_defaults
from marqo.tensor_search.models.external_apis.abstract_classes import ExternalAuth
from marqo.tensor_search.telemetry import RequestMetricsStore
from marqo.tensor_search.health import generate_heath_check_response
from marqo.tensor_search.utils import add_timing
from marqo.tensor_search import delete_docs
from marqo.s2_inference.processing import text as text_processor
from marqo.s2_inference.processing import image as image_processor
from marqo.s2_inference.clip_utils import _is_image
from marqo.s2_inference.reranking import rerank
from marqo.s2_inference import s2_inference
from marqo.tensor_search.health import generate_heath_check_response
import torch.cuda
import psutil
# We depend on _httprequests.py for now, but this may be replaced in the future, as
Expand Down Expand Up @@ -269,9 +269,43 @@ def _autofill_index_settings(index_settings: dict):


def get_stats(config: Config, index_name: str):
doc_count = HttpRequests(config).post(path=F"{index_name}/_count")["count"]
"""Returns the number of documents and vectors in the index."""

body = {
"size": 0,
"aggs": {
"nested_chunks": {
"nested": {
"path": "__chunks"
},
"aggs": {
"marqo_vector_count": {
"value_count": {
# This is a key_word field, so it is fast in value_count
"field": "__chunks.__field_name"
}
}
}
}
}
}

try:
doc_count = HttpRequests(config).post(path=F"{index_name}/_count")["count"]
vector_count = HttpRequests(config).get(path=f"{index_name}/_search", body=body) \
["aggregations"]["nested_chunks"]["marqo_vector_count"]["value"]
except (KeyError, TypeError) as e:
raise errors.InternalError(f"Marqo received an unexpected response from Marqo-OS. "
f"The expected fields do not exist in the response. Original error message = {e}")
except (errors.IndexNotFoundError, errors.InvalidIndexNameError):
raise
except errors.MarqoWebError as e:
raise errors.InternalError(f"Marqo encountered an error while communicating with Marqo-OS. "
f"Original error message: {e.message}")

return {
"numberOfDocuments": doc_count
"numberOfDocuments": doc_count,
"numberOfVectors": vector_count,
}


Expand All @@ -289,8 +323,9 @@ def _infer_opensearch_data_type(
to_check = sample_field_content

if isinstance(to_check, dict):
raise errors.MarqoError("Field content can't be an object. An object should not be passed into _infer_opensearch_data_type"
"to check.")
raise errors.MarqoError(
"Field content can't be an object. An object should not be passed into _infer_opensearch_data_type"
"to check.")
elif isinstance(to_check, str):
return OpenSearchDataType.text
else:
Expand Down Expand Up @@ -1861,6 +1896,7 @@ def _select_model_from_media_type(media_type: Union[MediaType, str]) -> Union[Ml
raise ValueError("_select_model_from_media_type(): "
"Received unknown media type: {}".format(media_type))


def get_loaded_models() -> dict:
available_models = s2_inference.get_available_models()
message = {"models": []}
Expand Down Expand Up @@ -2141,4 +2177,4 @@ def delete_documents(config: Config, index_name: str, doc_ids: List[str], auto_r
index_name=index_name,
document_ids=doc_ids,
auto_refresh=auto_refresh)
)
)
211 changes: 205 additions & 6 deletions tests/tensor_search/test_get_stats.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's get some edge case and HTTP error tests

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also get some cases where the HTTP response isn't expected (like an unknown dict type because of an error)

Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from unittest.mock import MagicMock, patch

from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from marqo import errors
from marqo.errors import IndexNotFoundError, MarqoError
from marqo.tensor_search import tensor_search, constants, index_meta_cache
from tests.marqo_test import MarqoTestCase
Expand All @@ -14,25 +17,221 @@ def setUp(self) -> None:
except IndexNotFoundError as s:
pass

def test_get_stats_empty(self):
def tearDown(self) -> None:
try:
tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
except IndexNotFoundError as s:
pass

def test_get_stats_empty(self):
tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1)
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] == 0

def test_get_stats_non_empty(self):
try:
tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
except IndexNotFoundError as s:
pass
tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1)
tensor_search.add_documents(
config=self.config, add_docs_params=AddDocsParams(
docs=[{"1": "2"},{"134": "2"},{"14": "62"}],
docs=[{"1": "2"}, {"134": "2"}, {"14": "62"}],
index_name=self.index_name_1,
auto_refresh=True, device="cpu"
)
)
assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] == 3

def test_get_stats_number_of_vectors_unified(self):
"""Tests the 'get_stats' function by checking if it correctly returns the
number of vectors and documents for all the documents indexed at once."""
tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,
index_settings={'index_defaults': {"model": "random/small"}})
expected_number_of_vectors = 7
expected_number_of_documents = 5
tensor_search.add_documents(
config=self.config, add_docs_params=AddDocsParams(
docs=[
{"description_1": "test-2", "description_2": "test"}, # 2 vectors
{"description_1": "test-2", "description_2": "test", "description_3": "test"}, # 3 vectors
{"description_2": "test"}, # 1 vector
{"my_multi_modal_field": {
"text_1": "test", "text_2": "test"}}, # 1 vector
{"non_tensor_field": "test"} # 0 vectors
],
index_name=self.index_name_1,
auto_refresh=True, device="cpu",
non_tensor_fields=["non_tensor_field"],
mappings={"my_multi_modal_field": {"type": "multimodal_combination", "weights": {
"text_1": 0.5, "text_2": 0.8}}}
)
)

res = tensor_search.get_stats(config=self.config, index_name=self.index_name_1)

assert res["numberOfDocuments"] == expected_number_of_documents
assert res["numberOfVectors"] == expected_number_of_vectors

def test_get_stats_number_of_vectors_separated(self):
"""Tests the 'get_stats' function by checking if it correctly returns the
number of vectors and documents for documents indexed one by one ."""
testing_list = [
{
"expected_number_of_vectors": 2,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"description_1": "test-2", "description_2": "test"}]
}
},
{
"expected_number_of_vectors": 3,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"description_1": "test-2", "description_2": "test", "description_3": "test"}]
}
},
{
"expected_number_of_vectors": 1,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"description_2": "test"}]
}
},
{
"expected_number_of_vectors": 3,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"description_1": "test-2", "description_2": "test", "description_3": "test"}],
"mappings": {
"my_multi_modal_field": {
"type": "multimodal_combination",
"weights": {"text_1": 0.5, "text_2": 0.8}
}
}
}
},
{
"expected_number_of_vectors": 0,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"non_tensor_field": "test"}],
"non_tensor_fields": ["non_tensor_field"]
}
},
{
"expected_number_of_vectors": 0,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"list_field": ["this", "that"]}],
"non_tensor_fields": ["list_field"]
}
}
]
for test_case in testing_list:
try:
tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
except IndexNotFoundError:
pass

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,
index_settings={'index_defaults': {"model": "random/small"}})
tensor_search.add_documents(
config=self.config, add_docs_params=AddDocsParams(
index_name=self.index_name_1,
auto_refresh=True, device="cpu",
**test_case["add_docs_kwargs"],
)
)

res = tensor_search.get_stats(config=self.config, index_name=self.index_name_1)

assert res["numberOfDocuments"] == test_case["expected_number_of_documents"]
assert res["numberOfVectors"] == test_case["expected_number_of_vectors"]

def test_long_text_splitting_vectors_count(self):

number_of_words = 55

test_case = {
"expected_number_of_vectors": 3,
"expected_number_of_documents": 1,
"add_docs_kwargs": {
"docs": [{"55_words_field": "test " * number_of_words}],
}
}

index_settings = {
"index_defaults": {
"normalize_embeddings": True,
"model": "random/small",
"text_preprocessing": {
"split_length": 20,
"split_overlap": 1,
"split_method": "word"
},
}
}

try:
tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
except IndexNotFoundError:
pass

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,
index_settings=index_settings)

tensor_search.add_documents(
config=self.config, add_docs_params=AddDocsParams(
index_name=self.index_name_1,
auto_refresh=True, device="cpu",
**test_case["add_docs_kwargs"],
)
)

res = tensor_search.get_stats(config=self.config, index_name=self.index_name_1)

assert res["numberOfDocuments"] == test_case["expected_number_of_documents"]
assert res["numberOfVectors"] == test_case["expected_number_of_vectors"]

def test_key_error(self):
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
mock_get.return_value = {"aggregations": {"nested_chunks": {"chunk_count": {"not_value": 200}}}}

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,)
with self.assertRaises(errors.InternalError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

self.assertIn("The expected fields do not exist in the response", str(e.exception))

def test_type_error(self):
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
mock_get.return_value = None

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,)
with self.assertRaises(errors.InternalError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

self.assertIn("The expected fields do not exist in the response", str(e.exception))

def test_IndexNotFoundError_error(self):
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
mock_get.side_effect = errors.IndexNotFoundError("IndexNotFoundError")

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,)
with self.assertRaises(errors.IndexNotFoundError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

def test_InvalidIndexNameError_error(self):
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
mock_get.side_effect = errors.InvalidIndexNameError("InvalidIndexNameError")

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,)
with self.assertRaises(errors.InvalidIndexNameError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

def test_generic_MarqoWebError_error(self):
# Generic MarqoWebError should be caught and raised as InternalError
with patch("marqo.tensor_search.tensor_search.HttpRequests.get") as mock_get:
mock_get.side_effect = errors.MarqoWebError("test-test")

tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1,)
with self.assertRaises(errors.InternalError) as e:
res = tensor_search.get_stats(self.config, self.index_name_1)

self.assertIn("Marqo encountered an error while communicating with Marqo-OS", str(e.exception))