Skip to content

Commit

Permalink
Validate truncated images and return better error (#797)
Browse files Browse the repository at this point in the history
Return 4xx for truncated images (within an add docs response) rather than a 500 for the whole request
  • Loading branch information
wanliAlex authored Apr 11, 2024
1 parent 3f78b6d commit 921e853
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 4 deletions.
8 changes: 5 additions & 3 deletions src/marqo/s2_inference/s2_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,11 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties:
raise RuntimeError(f"Vectorise created an empty list of batches! Content: {content}")
else:
vectorised = np.concatenate(vector_batches, axis=0)
except UnidentifiedImageError as e:
raise VectoriseError(f"Could not process given image: {content}") from e

except (UnidentifiedImageError, OSError) as e:
if isinstance(e, UnidentifiedImageError) or "image file is truncated" in str(e):
raise VectoriseError(f"Could not process given image: {content}. Original Error message: {e}") from e
else:
raise e
return _convert_vectorized_output(vectorised)


Expand Down
2 changes: 1 addition & 1 deletion src/marqo/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.4.0"
__version__ = "2.4.1"


def get_version() -> str:
Expand Down
93 changes: 93 additions & 0 deletions tests/tensor_search/integ_tests/test_add_documents_combined.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import functools
import math
import os
import uuid
from unittest import mock
from unittest.mock import patch

import PIL
import pytest
import requests

from marqo.api.exceptions import IndexNotFoundError, BadRequestError
from marqo.core.models.marqo_index import *
from marqo.core.models.marqo_index_request import FieldRequest
from marqo.s2_inference import types
from marqo.tensor_search import add_docs
from marqo.tensor_search import enums
from marqo.tensor_search import tensor_search
from marqo.tensor_search.models.add_docs_objects import AddDocsParams
from tests.marqo_test import MarqoTestCase


class TestAddDocumentsStructured(MarqoTestCase):
@classmethod
def setUpClass(cls) -> None:
super().setUpClass()

structured_image_index_request = cls.structured_marqo_index_request(
name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
fields=[
FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
FieldRequest(name="text_field_1", type=FieldType.Text,
features=[FieldFeature.Filter, FieldFeature.LexicalSearch])
],
model=Model(name="open_clip/ViT-B-32/laion2b_s34b_b79k"),
tensor_fields=["image_field_1", "text_field_1"]
)

unstructured_image_index_request = cls.unstructured_marqo_index_request(
name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
model=Model(name="open_clip/ViT-B-32/laion2b_s34b_b79k"),
treat_urls_and_pointers_as_images=True
)

cls.indexes = cls.create_indexes([
structured_image_index_request,
unstructured_image_index_request
])

cls.structured_marqo_index_name = structured_image_index_request.name
cls.unstructured_marqo_index_name = unstructured_image_index_request.name

def setUp(self) -> None:
super().setUp()

# Any tests that call add_documents, search, bulk_search need this env var
self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"})
self.device_patcher.start()

def tearDown(self) -> None:
super().tearDown()
self.device_patcher.stop()

def test_add_documents_with_truncated_image(self):
"""Test to ensure that the add_documents API can properly return 400 for the document with a truncated image."""
truncated_image_url = "https://marqo-assets.s3.amazonaws.com/tests/images/truncated_image.jpg"

documents = [
{
"image_field_1": "https://marqo-assets.s3.amazonaws.com/tests/images/ai_hippo_statue.png",
"text_field_1": "This is a valid image",
"_id": "1"
},
{
"image_field_1": truncated_image_url,
"text_field_1": "This is a truncated image",
"_id": "2"
}
]

for index_name in [self.structured_marqo_index_name, self.unstructured_marqo_index_name]:
tensor_fields = ["image_field_1", "text_field_1"] if index_name == self.unstructured_marqo_index_name \
else None
with self.subTest(f"test add documents with truncated image for {index_name}"):
r = tensor_search.add_documents(config=self.config,
add_docs_params=AddDocsParams(index_name=index_name,
docs=documents,
tensor_fields=tensor_fields))
self.assertEqual(True, r["errors"])
self.assertEqual(2, len(r["items"]))
self.assertEqual(200, r["items"][0]["status"])
self.assertEqual(400, r["items"][1]["status"])
self.assertIn("image file is truncated", r["items"][1]["error"])

0 comments on commit 921e853

Please sign in to comment.