diff --git a/vision/google/cloud/vision/image.py b/vision/google/cloud/vision/image.py index 2f89d317ed54..48d49bf65983 100644 --- a/vision/google/cloud/vision/image.py +++ b/vision/google/cloud/vision/image.py @@ -217,20 +217,22 @@ def detect_text(self, limit=10): def _entity_from_response_type(feature_type, results): """Convert a JSON result to an entity type based on the feature.""" - - detected_objects = [] feature_key = _REVERSE_TYPES[feature_type] + annotations = results.get(feature_key, ()) + if not annotations: + return [] + detected_objects = [] if feature_type == _FACE_DETECTION: detected_objects.extend( - Face.from_api_repr(face) for face in results[feature_key]) + Face.from_api_repr(face) for face in annotations) elif feature_type == _IMAGE_PROPERTIES: detected_objects.append( - ImagePropertiesAnnotation.from_api_repr(results[feature_key])) + ImagePropertiesAnnotation.from_api_repr(annotations)) elif feature_type == _SAFE_SEARCH_DETECTION: - result = results[feature_key] - detected_objects.append(SafeSearchAnnotation.from_api_repr(result)) + detected_objects.append( + SafeSearchAnnotation.from_api_repr(annotations)) else: - for result in results[feature_key]: + for result in annotations: detected_objects.append(EntityAnnotation.from_api_repr(result)) return detected_objects diff --git a/vision/unit_tests/test_client.py b/vision/unit_tests/test_client.py index 374fc352a1d0..be3bda73f020 100644 --- a/vision/unit_tests/test_client.py +++ b/vision/unit_tests/test_client.py @@ -116,6 +116,24 @@ def test_face_detection_from_content(self): image_request['image']['content']) self.assertEqual(5, image_request['features'][0]['maxResults']) + def test_face_detection_from_content_no_results(self): + RETURNED = { + 'responses': [{}] + } + credentials = _Credentials() + client = self._make_one(project=PROJECT, credentials=credentials) + client._connection = _Connection(RETURNED) + + image = client.image(content=IMAGE_CONTENT) + faces = image.detect_faces(limit=5) + self.assertEqual(faces, []) + self.assertEqual(len(faces), 0) + image_request = client._connection._requested[0]['data']['requests'][0] + + self.assertEqual(B64_IMAGE_CONTENT, + image_request['image']['content']) + self.assertEqual(5, image_request['features'][0]['maxResults']) + def test_label_detection_from_source(self): from google.cloud.vision.entity import EntityAnnotation from unit_tests._fixtures import ( @@ -138,6 +156,19 @@ def test_label_detection_from_source(self): self.assertEqual('/m/0k4j', labels[0].mid) self.assertEqual('/m/07yv9', labels[1].mid) + def test_label_detection_no_results(self): + RETURNED = { + 'responses': [{}] + } + credentials = _Credentials() + client = self._make_one(project=PROJECT, credentials=credentials) + client._connection = _Connection(RETURNED) + + image = client.image(content=IMAGE_CONTENT) + labels = image.detect_labels() + self.assertEqual(labels, []) + self.assertEqual(len(labels), 0) + def test_landmark_detection_from_source(self): from google.cloud.vision.entity import EntityAnnotation from unit_tests._fixtures import ( @@ -178,6 +209,19 @@ def test_landmark_detection_from_content(self): image_request['image']['content']) self.assertEqual(5, image_request['features'][0]['maxResults']) + def test_landmark_detection_no_results(self): + RETURNED = { + 'responses': [{}] + } + credentials = _Credentials() + client = self._make_one(project=PROJECT, credentials=credentials) + client._connection = _Connection(RETURNED) + + image = client.image(content=IMAGE_CONTENT) + landmarks = image.detect_landmarks() + self.assertEqual(landmarks, []) + self.assertEqual(len(landmarks), 0) + def test_logo_detection_from_source(self): from google.cloud.vision.entity import EntityAnnotation from unit_tests._fixtures import LOGO_DETECTION_RESPONSE @@ -254,6 +298,19 @@ def test_safe_search_detection_from_source(self): self.assertEqual('POSSIBLE', safe_search.medical) self.assertEqual('VERY_UNLIKELY', safe_search.violence) + def test_safe_search_no_results(self): + RETURNED = { + 'responses': [{}] + } + credentials = _Credentials() + client = self._make_one(project=PROJECT, credentials=credentials) + client._connection = _Connection(RETURNED) + + image = client.image(content=IMAGE_CONTENT) + safe_search = image.detect_safe_search() + self.assertEqual(safe_search, []) + self.assertEqual(len(safe_search), 0) + def test_image_properties_detection_from_source(self): from google.cloud.vision.color import ImagePropertiesAnnotation from unit_tests._fixtures import IMAGE_PROPERTIES_RESPONSE @@ -277,6 +334,19 @@ def test_image_properties_detection_from_source(self): self.assertEqual(65, image_properties.colors[0].color.blue) self.assertEqual(0.0, image_properties.colors[0].color.alpha) + def test_image_properties_no_results(self): + RETURNED = { + 'responses': [{}] + } + credentials = _Credentials() + client = self._make_one(project=PROJECT, credentials=credentials) + client._connection = _Connection(RETURNED) + + image = client.image(content=IMAGE_CONTENT) + image_properties = image.detect_properties() + self.assertEqual(image_properties, []) + self.assertEqual(len(image_properties), 0) + class TestVisionRequest(unittest.TestCase): @staticmethod