-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Validate truncated images and return better error (#797)
Return 4xx for truncated images (within an add docs response) rather than a 500 for the whole request
- Loading branch information
Showing
3 changed files
with
99 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
__version__ = "2.4.0" | ||
__version__ = "2.4.1" | ||
|
||
|
||
def get_version() -> str: | ||
|
93 changes: 93 additions & 0 deletions
93
tests/tensor_search/integ_tests/test_add_documents_combined.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import functools | ||
import math | ||
import os | ||
import uuid | ||
from unittest import mock | ||
from unittest.mock import patch | ||
|
||
import PIL | ||
import pytest | ||
import requests | ||
|
||
from marqo.api.exceptions import IndexNotFoundError, BadRequestError | ||
from marqo.core.models.marqo_index import * | ||
from marqo.core.models.marqo_index_request import FieldRequest | ||
from marqo.s2_inference import types | ||
from marqo.tensor_search import add_docs | ||
from marqo.tensor_search import enums | ||
from marqo.tensor_search import tensor_search | ||
from marqo.tensor_search.models.add_docs_objects import AddDocsParams | ||
from tests.marqo_test import MarqoTestCase | ||
|
||
|
||
class TestAddDocumentsStructured(MarqoTestCase):
    """Integration tests for ``tensor_search.add_documents`` with image fields.

    Despite the class name, the tests run against both a structured and an
    unstructured index; both are created once in ``setUpClass``.
    """

    @classmethod
    def setUpClass(cls) -> None:
        """Create one structured and one unstructured image index for the class."""
        super().setUpClass()

        # A random suffix keeps index names unique across test runs.
        structured_image_index_request = cls.structured_marqo_index_request(
            name="structured_image_index" + str(uuid.uuid4()).replace('-', ''),
            fields=[
                FieldRequest(name="image_field_1", type=FieldType.ImagePointer),
                FieldRequest(name="text_field_1", type=FieldType.Text,
                             features=[FieldFeature.Filter, FieldFeature.LexicalSearch])
            ],
            model=Model(name="open_clip/ViT-B-32/laion2b_s34b_b79k"),
            tensor_fields=["image_field_1", "text_field_1"]
        )

        unstructured_image_index_request = cls.unstructured_marqo_index_request(
            name="unstructured_image_index" + str(uuid.uuid4()).replace('-', ''),
            model=Model(name="open_clip/ViT-B-32/laion2b_s34b_b79k"),
            treat_urls_and_pointers_as_images=True
        )

        cls.indexes = cls.create_indexes([
            structured_image_index_request,
            unstructured_image_index_request
        ])

        cls.structured_marqo_index_name = structured_image_index_request.name
        cls.unstructured_marqo_index_name = unstructured_image_index_request.name

    def setUp(self) -> None:
        """Patch the environment so the tests run on the CPU device."""
        super().setUp()

        # Any tests that call add_documents, search, bulk_search need this env var
        self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"})
        self.device_patcher.start()
        # Register the stop as a cleanup instead of calling it from tearDown:
        # cleanups run after tearDown (same ordering as before) but are executed
        # even when tearDown raises, so the patched environment cannot leak
        # into later tests.
        self.addCleanup(self.device_patcher.stop)

    def test_add_documents_with_truncated_image(self):
        """Test to ensure that the add_documents API can properly return 400 for the document with a truncated image."""
        truncated_image_url = "https://marqo-assets.s3.amazonaws.com/tests/images/truncated_image.jpg"

        documents = [
            {
                "image_field_1": "https://marqo-assets.s3.amazonaws.com/tests/images/ai_hippo_statue.png",
                "text_field_1": "This is a valid image",
                "_id": "1"
            },
            {
                "image_field_1": truncated_image_url,
                "text_field_1": "This is a truncated image",
                "_id": "2"
            }
        ]

        # tensor_fields is passed explicitly only for the unstructured index;
        # the structured index declares its tensor fields at creation time.
        cases = [
            (self.structured_marqo_index_name, None),
            (self.unstructured_marqo_index_name, ["image_field_1", "text_field_1"]),
        ]

        for index_name, tensor_fields in cases:
            with self.subTest(f"test add documents with truncated image for {index_name}"):
                r = tensor_search.add_documents(
                    config=self.config,
                    add_docs_params=AddDocsParams(index_name=index_name,
                                                  docs=documents,
                                                  tensor_fields=tensor_fields))
                # The batch as a whole reports errors, but only the truncated
                # image document fails; the valid document is still indexed.
                self.assertEqual(True, r["errors"])
                self.assertEqual(2, len(r["items"]))
                self.assertEqual(200, r["items"][0]["status"])
                self.assertEqual(400, r["items"][1]["status"])
                self.assertIn("image file is truncated", r["items"][1]["error"])