From 921e853149b9b03f0d731b721f8e94c926adb953 Mon Sep 17 00:00:00 2001
From: Li Wan <49334982+wanliAlex@users.noreply.github.com>
Date: Thu, 11 Apr 2024 14:27:09 +1000
Subject: [PATCH] Validate truncated images and return better error (#797)

Return 4xx for truncated images (within an add docs response) rather than a 500 for the whole request
---
Note (review): this patch was restored from a copy whose line breaks and
indentation had been collapsed. Line boundaries follow the +/- markers and
agree with the hunk headers and the diffstat. Leading whitespace is inferred
from Python block structure -- the context lines of the first hunk in
particular -- and should be confirmed against blobs 1e95ecfd8..6db6d872d
before applying.

 src/marqo/s2_inference/s2_inference.py        |  8 +-
 src/marqo/version.py                          |  2 +-
 .../test_add_documents_combined.py            | 93 +++++++++++++++++++
 3 files changed, 99 insertions(+), 4 deletions(-)
 create mode 100644 tests/tensor_search/integ_tests/test_add_documents_combined.py

diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py
index 1e95ecfd8..6db6d872d 100644
--- a/src/marqo/s2_inference/s2_inference.py
+++ b/src/marqo/s2_inference/s2_inference.py
@@ -80,9 +80,11 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
                 raise RuntimeError(f"Vectorise created an empty list of batches! Content: {content}")
             else:
                 vectorised = np.concatenate(vector_batches, axis=0)
-    except UnidentifiedImageError as e:
-        raise VectoriseError(f"Could not process given image: {content}") from e
-
+    except (UnidentifiedImageError, OSError) as e:
+        if isinstance(e, UnidentifiedImageError) or "image file is truncated" in str(e):
+            raise VectoriseError(f"Could not process given image: {content}. Original Error message: {e}") from e
+        else:
+            raise e
     return _convert_vectorized_output(vectorised)
 
 
diff --git a/src/marqo/version.py b/src/marqo/version.py
index 92c922f6d..e4eb9bb9a 100644
--- a/src/marqo/version.py
+++ b/src/marqo/version.py
@@ -1,4 +1,4 @@
-__version__ = "2.4.0"
+__version__ = "2.4.1"
 
 
 def get_version() -> str:
diff --git a/tests/tensor_search/integ_tests/test_add_documents_combined.py b/tests/tensor_search/integ_tests/test_add_documents_combined.py
new file mode 100644
index 000000000..8dd2d23af
--- /dev/null
+++ b/tests/tensor_search/integ_tests/test_add_documents_combined.py
@@ -0,0 +1,93 @@
+import functools
+import math
+import os
+import uuid
+from unittest import mock
+from unittest.mock import patch
+
+import PIL
+import pytest
+import requests
+
+from marqo.api.exceptions import IndexNotFoundError, BadRequestError
+from marqo.core.models.marqo_index import *
+from marqo.core.models.marqo_index_request import FieldRequest
+from marqo.s2_inference import types
+from marqo.tensor_search import add_docs
+from marqo.tensor_search import enums
+from marqo.tensor_search import tensor_search
+from marqo.tensor_search.models.add_docs_objects import AddDocsParams
+from tests.marqo_test import MarqoTestCase
+
+
+class TestAddDocumentsStructured(MarqoTestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        super().setUpClass()
+
+        structured_image_index_request = cls.structured_marqo_index_request(
+            name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
+            fields=[
+                FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
+                FieldRequest(name="text_field_1", type=FieldType.Text,
+                             features=[FieldFeature.Filter, FieldFeature.LexicalSearch])
+            ],
+            model=Model(name="open_clip/ViT-B-32/laion2b_s34b_b79k"),
+            tensor_fields=["image_field_1", "text_field_1"]
+        )
+
+        unstructured_image_index_request = cls.unstructured_marqo_index_request(
+            name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
+            model=Model(name="open_clip/ViT-B-32/laion2b_s34b_b79k"),
+            treat_urls_and_pointers_as_images=True
+        )
+
+        cls.indexes = cls.create_indexes([
+            structured_image_index_request,
+            unstructured_image_index_request
+        ])
+
+        cls.structured_marqo_index_name = structured_image_index_request.name
+        cls.unstructured_marqo_index_name = unstructured_image_index_request.name
+
+    def setUp(self) -> None:
+        super().setUp()
+
+        # Any tests that call add_documents, search, bulk_search need this env var
+        self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"})
+        self.device_patcher.start()
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        self.device_patcher.stop()
+
+    def test_add_documents_with_truncated_image(self):
+        """Test to ensure that the add_documents API can properly return 400 for the document with a truncated image."""
+        truncated_image_url = "https://marqo-assets.s3.amazonaws.com/tests/images/truncated_image.jpg"
+
+        documents = [
+            {
+                "image_field_1": "https://marqo-assets.s3.amazonaws.com/tests/images/ai_hippo_statue.png",
+                "text_field_1": "This is a valid image",
+                "_id": "1"
+            },
+            {
+                "image_field_1": truncated_image_url,
+                "text_field_1": "This is a truncated image",
+                "_id": "2"
+            }
+        ]
+
+        for index_name in [self.structured_marqo_index_name, self.unstructured_marqo_index_name]:
+            tensor_fields = ["image_field_1", "text_field_1"] if index_name == self.unstructured_marqo_index_name \
+                else None
+            with self.subTest(f"test add documents with truncated image for {index_name}"):
+                r = tensor_search.add_documents(config=self.config,
+                                                add_docs_params=AddDocsParams(index_name=index_name,
+                                                                              docs=documents,
+                                                                              tensor_fields=tensor_fields))
+                self.assertEqual(True, r["errors"])
+                self.assertEqual(2, len(r["items"]))
+                self.assertEqual(200, r["items"][0]["status"])
+                self.assertEqual(400, r["items"][1]["status"])
+                self.assertIn("image file is truncated", r["items"][1]["error"])
\ No newline at end of file