diff --git a/packages/google-cloud-vision/google/cloud/vision/annotations.py b/packages/google-cloud-vision/google/cloud/vision/annotations.py new file mode 100644 index 000000000000..43b828194b5b --- /dev/null +++ b/packages/google-cloud-vision/google/cloud/vision/annotations.py @@ -0,0 +1,119 @@ +# Copyright 2016 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Annotations management for Vision API responses.""" + + +from google.cloud.vision.color import ImagePropertiesAnnotation +from google.cloud.vision.entity import EntityAnnotation +from google.cloud.vision.face import Face +from google.cloud.vision.safe import SafeSearchAnnotation + + +FACE_ANNOTATIONS = 'faceAnnotations' +IMAGE_PROPERTIES_ANNOTATION = 'imagePropertiesAnnotation' +SAFE_SEARCH_ANNOTATION = 'safeSearchAnnotation' + +_KEY_MAP = { + FACE_ANNOTATIONS: 'faces', + IMAGE_PROPERTIES_ANNOTATION: 'properties', + 'labelAnnotations': 'labels', + 'landmarkAnnotations': 'landmarks', + 'logoAnnotations': 'logos', + SAFE_SEARCH_ANNOTATION: 'safe_searches', + 'textAnnotations': 'texts' +} + + +class Annotations(object): + """Helper class to bundle annotation responses. + + :type faces: list + :param faces: List of :class:`~google.cloud.vision.face.Face`. + + :type properties: list + :param properties: + List of :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`. + + :type labels: list + :param labels: List of + :class:`~google.cloud.vision.entity.EntityAnnotation`. + + :type landmarks: list + :param landmarks: List of + :class:`~google.cloud.vision.entity.EntityAnnotation.` + + :type logos: list + :param logos: List of + :class:`~google.cloud.vision.entity.EntityAnnotation`. + + :type safe_searches: list + :param safe_searches: + List of :class:`~google.cloud.vision.safe.SafeSearchAnnotation` + + :type texts: list + :param texts: List of + :class:`~google.cloud.vision.entity.EntityAnnotation`. + """ + def __init__(self, faces=(), properties=(), labels=(), landmarks=(), + logos=(), safe_searches=(), texts=()): + self.faces = faces + self.properties = properties + self.labels = labels + self.landmarks = landmarks + self.logos = logos + self.safe_searches = safe_searches + self.texts = texts + + @classmethod + def from_api_repr(cls, response): + """Factory: construct an instance of ``Annotations`` from a response. + + :type response: dict + :param response: Vision API response object. + + :rtype: :class:`~google.cloud.vision.annotations.Annotations` + :returns: An instance of ``Annotations`` with detection types loaded. + """ + annotations = {} + for feature_type, annotation in response.items(): + curr_feature = annotations.setdefault(_KEY_MAP[feature_type], []) + curr_feature.extend( + _entity_from_response_type(feature_type, annotation)) + return cls(**annotations) + + +def _entity_from_response_type(feature_type, results): + """Convert a JSON result to an entity type based on the feature. + + :rtype: list + :returns: List containing any of + :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`, + :class:`~google.cloud.vision.entity.EntityAnnotation`, + :class:`~google.cloud.vision.face.Face`, + :class:`~google.cloud.vision.safe.SafeSearchAnnotation`. + """ + detected_objects = [] + if feature_type == FACE_ANNOTATIONS: + detected_objects.extend( + Face.from_api_repr(face) for face in results) + elif feature_type == IMAGE_PROPERTIES_ANNOTATION: + detected_objects.append( + ImagePropertiesAnnotation.from_api_repr(results)) + elif feature_type == SAFE_SEARCH_ANNOTATION: + detected_objects.append(SafeSearchAnnotation.from_api_repr(results)) + else: + for result in results: + detected_objects.append(EntityAnnotation.from_api_repr(result)) + return detected_objects diff --git a/packages/google-cloud-vision/google/cloud/vision/image.py b/packages/google-cloud-vision/google/cloud/vision/image.py index 48d49bf65983..9f7a2afcdaee 100644 --- a/packages/google-cloud-vision/google/cloud/vision/image.py +++ b/packages/google-cloud-vision/google/cloud/vision/image.py @@ -19,31 +19,9 @@ from google.cloud._helpers import _to_bytes from google.cloud._helpers import _bytes_to_unicode -from google.cloud.vision.entity import EntityAnnotation -from google.cloud.vision.face import Face +from google.cloud.vision.annotations import Annotations from google.cloud.vision.feature import Feature from google.cloud.vision.feature import FeatureTypes -from google.cloud.vision.color import ImagePropertiesAnnotation -from google.cloud.vision.safe import SafeSearchAnnotation - - -_FACE_DETECTION = 'FACE_DETECTION' -_IMAGE_PROPERTIES = 'IMAGE_PROPERTIES' -_LABEL_DETECTION = 'LABEL_DETECTION' -_LANDMARK_DETECTION = 'LANDMARK_DETECTION' -_LOGO_DETECTION = 'LOGO_DETECTION' -_SAFE_SEARCH_DETECTION = 'SAFE_SEARCH_DETECTION' -_TEXT_DETECTION = 'TEXT_DETECTION' - -_REVERSE_TYPES = { - _FACE_DETECTION: 'faceAnnotations', - _IMAGE_PROPERTIES: 'imagePropertiesAnnotation', - _LABEL_DETECTION: 'labelAnnotations', - _LANDMARK_DETECTION: 'landmarkAnnotations', - _LOGO_DETECTION: 'logoAnnotations', - _SAFE_SEARCH_DETECTION: 'safeSearchAnnotation', - _TEXT_DETECTION: 'textAnnotations', -} class Image(object): @@ -105,7 +83,7 @@ def source(self): return self._source def _detect_annotation(self, features): - """Generic method for detecting a single annotation. + """Generic method for detecting annotations. :type features: list :param features: List of :class:`~google.cloud.vision.feature.Feature` @@ -118,12 +96,21 @@ def _detect_annotation(self, features): :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`, :class:`~google.cloud.vision.sage.SafeSearchAnnotation`, """ - detected_objects = [] results = self.client.annotate(self, features) - for feature in features: - detected_objects.extend( - _entity_from_response_type(feature.feature_type, results)) - return detected_objects + return Annotations.from_api_repr(results) + + def detect(self, features): + """Detect multiple feature types. + + :type features: list of :class:`~google.cloud.vision.feature.Feature` + :param features: List of the ``Feature`` indication the type of + annotation to perform. + + :rtype: list + :returns: List of + :class:`~google.cloud.vision.entity.EntityAnnotation`. + """ + return self._detect_annotation(features) def detect_faces(self, limit=10): """Detect faces in image. @@ -135,7 +122,8 @@ def detect_faces(self, limit=10): :returns: List of :class:`~google.cloud.vision.face.Face`. """ features = [Feature(FeatureTypes.FACE_DETECTION, limit)] - return self._detect_annotation(features) + annotations = self._detect_annotation(features) + return annotations.faces def detect_labels(self, limit=10): """Detect labels that describe objects in an image. @@ -147,7 +135,8 @@ def detect_labels(self, limit=10): :returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation` """ features = [Feature(FeatureTypes.LABEL_DETECTION, limit)] - return self._detect_annotation(features) + annotations = self._detect_annotation(features) + return annotations.labels def detect_landmarks(self, limit=10): """Detect landmarks in an image. @@ -160,7 +149,8 @@ def detect_landmarks(self, limit=10): :class:`~google.cloud.vision.entity.EntityAnnotation`. """ features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)] - return self._detect_annotation(features) + annotations = self._detect_annotation(features) + return annotations.landmarks def detect_logos(self, limit=10): """Detect logos in an image. @@ -173,7 +163,8 @@ def detect_logos(self, limit=10): :class:`~google.cloud.vision.entity.EntityAnnotation`. """ features = [Feature(FeatureTypes.LOGO_DETECTION, limit)] - return self._detect_annotation(features) + annotations = self._detect_annotation(features) + return annotations.logos def detect_properties(self, limit=10): """Detect the color properties of an image. @@ -186,7 +177,8 @@ def detect_properties(self, limit=10): :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`. """ features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)] - return self._detect_annotation(features) + annotations = self._detect_annotation(features) + return annotations.properties def detect_safe_search(self, limit=10): """Retreive safe search properties from an image. @@ -199,7 +191,8 @@ def detect_safe_search(self, limit=10): :class:`~google.cloud.vision.sage.SafeSearchAnnotation`. """ features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)] - return self._detect_annotation(features) + annotations = self._detect_annotation(features) + return annotations.safe_searches def detect_text(self, limit=10): """Detect text in an image. @@ -212,27 +205,5 @@ def detect_text(self, limit=10): :class:`~google.cloud.vision.entity.EntityAnnotation`. """ features = [Feature(FeatureTypes.TEXT_DETECTION, limit)] - return self._detect_annotation(features) - - -def _entity_from_response_type(feature_type, results): - """Convert a JSON result to an entity type based on the feature.""" - feature_key = _REVERSE_TYPES[feature_type] - annotations = results.get(feature_key, ()) - if not annotations: - return [] - - detected_objects = [] - if feature_type == _FACE_DETECTION: - detected_objects.extend( - Face.from_api_repr(face) for face in annotations) - elif feature_type == _IMAGE_PROPERTIES: - detected_objects.append( - ImagePropertiesAnnotation.from_api_repr(annotations)) - elif feature_type == _SAFE_SEARCH_DETECTION: - detected_objects.append( - SafeSearchAnnotation.from_api_repr(annotations)) - else: - for result in annotations: - detected_objects.append(EntityAnnotation.from_api_repr(result)) - return detected_objects + annotations = self._detect_annotation(features) + return annotations.texts diff --git a/packages/google-cloud-vision/unit_tests/test_client.py b/packages/google-cloud-vision/unit_tests/test_client.py index be3bda73f020..a7ff1cccd81f 100644 --- a/packages/google-cloud-vision/unit_tests/test_client.py +++ b/packages/google-cloud-vision/unit_tests/test_client.py @@ -81,6 +81,60 @@ def test_image_with_client(self): image = client.image(source_uri=IMAGE_SOURCE) self.assertIsInstance(image, Image) + def test_multiple_detection_from_content(self): + import copy + from google.cloud.vision.feature import Feature + from google.cloud.vision.feature import FeatureTypes + from unit_tests._fixtures import LABEL_DETECTION_RESPONSE + from unit_tests._fixtures import LOGO_DETECTION_RESPONSE + + returned = copy.deepcopy(LABEL_DETECTION_RESPONSE) + logos = copy.deepcopy(LOGO_DETECTION_RESPONSE['responses'][0]) + returned['responses'][0]['logoAnnotations'] = logos['logoAnnotations'] + + credentials = _Credentials() + client = self._make_one(project=PROJECT, credentials=credentials) + client._connection = _Connection(returned) + + limit = 2 + label_feature = Feature(FeatureTypes.LABEL_DETECTION, limit) + logo_feature = Feature(FeatureTypes.LOGO_DETECTION, limit) + features = [label_feature, logo_feature] + image = client.image(content=IMAGE_CONTENT) + items = image.detect(features) + + self.assertEqual(len(items.logos), 2) + self.assertEqual(len(items.labels), 3) + first_logo = items.logos[0] + second_logo = items.logos[1] + self.assertEqual(first_logo.description, 'Brand1') + self.assertEqual(first_logo.score, 0.63192177) + self.assertEqual(second_logo.description, 'Brand2') + self.assertEqual(second_logo.score, 0.5492993) + + first_label = items.labels[0] + second_label = items.labels[1] + third_label = items.labels[2] + self.assertEqual(first_label.description, 'automobile') + self.assertEqual(first_label.score, 0.9776855) + self.assertEqual(second_label.description, 'vehicle') + self.assertEqual(second_label.score, 0.947987) + self.assertEqual(third_label.description, 'truck') + self.assertEqual(third_label.score, 0.88429511) + + requested = client._connection._requested + requests = requested[0]['data']['requests'] + image_request = requests[0] + label_request = image_request['features'][0] + logo_request = image_request['features'][1] + + self.assertEqual(B64_IMAGE_CONTENT, + image_request['image']['content']) + self.assertEqual(label_request['maxResults'], 2) + self.assertEqual(label_request['type'], 'LABEL_DETECTION') + self.assertEqual(logo_request['maxResults'], 2) + self.assertEqual(logo_request['type'], 'LOGO_DETECTION') + def test_face_detection_from_source(self): from google.cloud.vision.face import Face from unit_tests._fixtures import FACE_DETECTION_RESPONSE @@ -126,7 +180,7 @@ def test_face_detection_from_content_no_results(self): image = client.image(content=IMAGE_CONTENT) faces = image.detect_faces(limit=5) - self.assertEqual(faces, []) + self.assertEqual(faces, ()) self.assertEqual(len(faces), 0) image_request = client._connection._requested[0]['data']['requests'][0] @@ -166,7 +220,7 @@ def test_label_detection_no_results(self): image = client.image(content=IMAGE_CONTENT) labels = image.detect_labels() - self.assertEqual(labels, []) + self.assertEqual(labels, ()) self.assertEqual(len(labels), 0) def test_landmark_detection_from_source(self): @@ -219,7 +273,7 @@ def test_landmark_detection_no_results(self): image = client.image(content=IMAGE_CONTENT) landmarks = image.detect_landmarks() - self.assertEqual(landmarks, []) + self.assertEqual(landmarks, ()) self.assertEqual(len(landmarks), 0) def test_logo_detection_from_source(self): @@ -308,7 +362,7 @@ def test_safe_search_no_results(self): image = client.image(content=IMAGE_CONTENT) safe_search = image.detect_safe_search() - self.assertEqual(safe_search, []) + self.assertEqual(safe_search, ()) self.assertEqual(len(safe_search), 0) def test_image_properties_detection_from_source(self): @@ -344,7 +398,7 @@ def test_image_properties_no_results(self): image = client.image(content=IMAGE_CONTENT) image_properties = image.detect_properties() - self.assertEqual(image_properties, []) + self.assertEqual(image_properties, ()) self.assertEqual(len(image_properties), 0)