Skip to content

Commit

Permalink
Merge pull request #2770 from daspecster/add-manual-detect-to-vision-…
Browse files Browse the repository at this point in the history
…2697

Add image.detect() for detecting multiple types.
  • Loading branch information
daspecster authored Nov 30, 2016
2 parents d11aa28 + fc55c18 commit e222714
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 64 deletions.
119 changes: 119 additions & 0 deletions packages/google-cloud-vision/google/cloud/vision/annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Annotations management for Vision API responses."""


from google.cloud.vision.color import ImagePropertiesAnnotation
from google.cloud.vision.entity import EntityAnnotation
from google.cloud.vision.face import Face
from google.cloud.vision.safe import SafeSearchAnnotation


# Response keys that need dedicated conversion logic; every other key in
# ``_KEY_MAP`` is converted as a list of ``EntityAnnotation``.
FACE_ANNOTATIONS = 'faceAnnotations'
IMAGE_PROPERTIES_ANNOTATION = 'imagePropertiesAnnotation'
SAFE_SEARCH_ANNOTATION = 'safeSearchAnnotation'

# Maps each annotation key of a Vision API response to the name of the
# ``Annotations`` attribute (and constructor keyword) that holds it.
_KEY_MAP = {
    FACE_ANNOTATIONS: 'faces',
    IMAGE_PROPERTIES_ANNOTATION: 'properties',
    'labelAnnotations': 'labels',
    'landmarkAnnotations': 'landmarks',
    'logoAnnotations': 'logos',
    SAFE_SEARCH_ANNOTATION: 'safe_searches',
    'textAnnotations': 'texts',
}


class Annotations(object):
    """Helper class to bundle annotation responses.

    Every argument defaults to an empty tuple, so an attribute whose
    feature was not part of the response is still an empty sequence.

    :type faces: list
    :param faces: List of :class:`~google.cloud.vision.face.Face`.

    :type properties: list
    :param properties:
        List of :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.

    :type labels: list
    :param labels: List of
                   :class:`~google.cloud.vision.entity.EntityAnnotation`.

    :type landmarks: list
    :param landmarks: List of
                      :class:`~google.cloud.vision.entity.EntityAnnotation`.

    :type logos: list
    :param logos: List of
                  :class:`~google.cloud.vision.entity.EntityAnnotation`.

    :type safe_searches: list
    :param safe_searches:
        List of :class:`~google.cloud.vision.safe.SafeSearchAnnotation`.

    :type texts: list
    :param texts: List of
                  :class:`~google.cloud.vision.entity.EntityAnnotation`.
    """
    def __init__(self, faces=(), properties=(), labels=(), landmarks=(),
                 logos=(), safe_searches=(), texts=()):
        self.faces = faces
        self.properties = properties
        self.labels = labels
        self.landmarks = landmarks
        self.logos = logos
        self.safe_searches = safe_searches
        self.texts = texts

    @classmethod
    def from_api_repr(cls, response):
        """Factory: construct an instance of ``Annotations`` from a response.

        Keys of ``response`` that have no entry in ``_KEY_MAP`` are
        ignored instead of raising :class:`KeyError`.

        :type response: dict
        :param response: Vision API response object.

        :rtype: :class:`~google.cloud.vision.annotations.Annotations`
        :returns: An instance of ``Annotations`` with detection types loaded.
        """
        annotations = {}
        for feature_type, annotation in response.items():
            attribute = _KEY_MAP.get(feature_type)
            if attribute is None:
                # No attribute exists for this key (presumably things such
                # as a per-image ``error`` status); skipping it keeps the
                # recognised annotations usable.
                continue
            curr_feature = annotations.setdefault(attribute, [])
            curr_feature.extend(
                _entity_from_response_type(feature_type, annotation))
        return cls(**annotations)


def _entity_from_response_type(feature_type, results):
    """Convert a JSON result to an entity type based on the feature.

    :type feature_type: str
    :param feature_type: Key of the annotation in the API response; it is
                         compared against ``FACE_ANNOTATIONS``,
                         ``IMAGE_PROPERTIES_ANNOTATION`` and
                         ``SAFE_SEARCH_ANNOTATION``.

    :type results: list or dict (presumably; depends on the feature)
    :param results: Value stored under ``feature_type``. It is iterated
                    item by item for faces and for every other entity
                    type, and handed over as a single object for image
                    properties and safe search.

    :rtype: list
    :returns: List containing any of
              :class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
              :class:`~google.cloud.vision.entity.EntityAnnotation`,
              :class:`~google.cloud.vision.face.Face`,
              :class:`~google.cloud.vision.safe.SafeSearchAnnotation`.
    """
    if feature_type == FACE_ANNOTATIONS:
        return [Face.from_api_repr(face) for face in results]

    # These two features yield one object per image, wrapped in a list so
    # that every feature produces the same return shape.
    if feature_type == IMAGE_PROPERTIES_ANNOTATION:
        return [ImagePropertiesAnnotation.from_api_repr(results)]
    if feature_type == SAFE_SEARCH_ANNOTATION:
        return [SafeSearchAnnotation.from_api_repr(results)]

    # Every remaining feature type is converted as an ``EntityAnnotation``.
    return [EntityAnnotation.from_api_repr(entity) for entity in results]
89 changes: 30 additions & 59 deletions packages/google-cloud-vision/google/cloud/vision/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,31 +19,9 @@

from google.cloud._helpers import _to_bytes
from google.cloud._helpers import _bytes_to_unicode
from google.cloud.vision.entity import EntityAnnotation
from google.cloud.vision.face import Face
from google.cloud.vision.annotations import Annotations
from google.cloud.vision.feature import Feature
from google.cloud.vision.feature import FeatureTypes
from google.cloud.vision.color import ImagePropertiesAnnotation
from google.cloud.vision.safe import SafeSearchAnnotation


_FACE_DETECTION = 'FACE_DETECTION'
_IMAGE_PROPERTIES = 'IMAGE_PROPERTIES'
_LABEL_DETECTION = 'LABEL_DETECTION'
_LANDMARK_DETECTION = 'LANDMARK_DETECTION'
_LOGO_DETECTION = 'LOGO_DETECTION'
_SAFE_SEARCH_DETECTION = 'SAFE_SEARCH_DETECTION'
_TEXT_DETECTION = 'TEXT_DETECTION'

_REVERSE_TYPES = {
_FACE_DETECTION: 'faceAnnotations',
_IMAGE_PROPERTIES: 'imagePropertiesAnnotation',
_LABEL_DETECTION: 'labelAnnotations',
_LANDMARK_DETECTION: 'landmarkAnnotations',
_LOGO_DETECTION: 'logoAnnotations',
_SAFE_SEARCH_DETECTION: 'safeSearchAnnotation',
_TEXT_DETECTION: 'textAnnotations',
}


class Image(object):
Expand Down Expand Up @@ -105,7 +83,7 @@ def source(self):
return self._source

def _detect_annotation(self, features):
"""Generic method for detecting a single annotation.
"""Generic method for detecting annotations.
:type features: list
:param features: List of :class:`~google.cloud.vision.feature.Feature`
Expand All @@ -118,12 +96,21 @@ def _detect_annotation(self, features):
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
:class:`~google.cloud.vision.safe.SafeSearchAnnotation`,
"""
detected_objects = []
results = self.client.annotate(self, features)
for feature in features:
detected_objects.extend(
_entity_from_response_type(feature.feature_type, results))
return detected_objects
return Annotations.from_api_repr(results)

def detect(self, features):
    """Detect multiple feature types.

    :type features: list of :class:`~google.cloud.vision.feature.Feature`
    :param features: List of the ``Feature`` indicating the type of
                     annotation to perform.

    :rtype: :class:`~google.cloud.vision.annotations.Annotations`
    :returns: The value returned by ``_detect_annotation()``: an
              ``Annotations`` instance bundling the results, one
              attribute per detection type (e.g. ``labels``, ``logos``).
    """
    return self._detect_annotation(features)

def detect_faces(self, limit=10):
"""Detect faces in image.
Expand All @@ -135,7 +122,8 @@ def detect_faces(self, limit=10):
:returns: List of :class:`~google.cloud.vision.face.Face`.
"""
features = [Feature(FeatureTypes.FACE_DETECTION, limit)]
return self._detect_annotation(features)
annotations = self._detect_annotation(features)
return annotations.faces

def detect_labels(self, limit=10):
"""Detect labels that describe objects in an image.
Expand All @@ -147,7 +135,8 @@ def detect_labels(self, limit=10):
:returns: List of :class:`~google.cloud.vision.entity.EntityAnnotation`
"""
features = [Feature(FeatureTypes.LABEL_DETECTION, limit)]
return self._detect_annotation(features)
annotations = self._detect_annotation(features)
return annotations.labels

def detect_landmarks(self, limit=10):
"""Detect landmarks in an image.
Expand All @@ -160,7 +149,8 @@ def detect_landmarks(self, limit=10):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
features = [Feature(FeatureTypes.LANDMARK_DETECTION, limit)]
return self._detect_annotation(features)
annotations = self._detect_annotation(features)
return annotations.landmarks

def detect_logos(self, limit=10):
"""Detect logos in an image.
Expand All @@ -173,7 +163,8 @@ def detect_logos(self, limit=10):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
features = [Feature(FeatureTypes.LOGO_DETECTION, limit)]
return self._detect_annotation(features)
annotations = self._detect_annotation(features)
return annotations.logos

def detect_properties(self, limit=10):
"""Detect the color properties of an image.
Expand All @@ -186,7 +177,8 @@ def detect_properties(self, limit=10):
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`.
"""
features = [Feature(FeatureTypes.IMAGE_PROPERTIES, limit)]
return self._detect_annotation(features)
annotations = self._detect_annotation(features)
return annotations.properties

def detect_safe_search(self, limit=10):
"""Retreive safe search properties from an image.
Expand All @@ -199,7 +191,8 @@ def detect_safe_search(self, limit=10):
:class:`~google.cloud.vision.safe.SafeSearchAnnotation`.
"""
features = [Feature(FeatureTypes.SAFE_SEARCH_DETECTION, limit)]
return self._detect_annotation(features)
annotations = self._detect_annotation(features)
return annotations.safe_searches

def detect_text(self, limit=10):
"""Detect text in an image.
Expand All @@ -212,27 +205,5 @@ def detect_text(self, limit=10):
:class:`~google.cloud.vision.entity.EntityAnnotation`.
"""
features = [Feature(FeatureTypes.TEXT_DETECTION, limit)]
return self._detect_annotation(features)


def _entity_from_response_type(feature_type, results):
    """Convert a JSON result to an entity type based on the feature.

    :type feature_type: str
    :param feature_type: Name of the requested feature; it is used as a
                         key into ``_REVERSE_TYPES`` to find the matching
                         response key.

    :type results: dict
    :param results: Response mapping that is queried with ``.get()`` for
                    the annotation belonging to ``feature_type``.

    :rtype: list
    :returns: Converted annotation objects, or an empty list when the
              response holds nothing (or an empty value) for the feature.
    """
    payload = results.get(_REVERSE_TYPES[feature_type], ())
    if not payload:
        return []

    if feature_type == _FACE_DETECTION:
        return [Face.from_api_repr(face) for face in payload]

    # Single-object annotations are wrapped in a list so that every
    # feature produces the same return shape.
    if feature_type == _IMAGE_PROPERTIES:
        return [ImagePropertiesAnnotation.from_api_repr(payload)]
    if feature_type == _SAFE_SEARCH_DETECTION:
        return [SafeSearchAnnotation.from_api_repr(payload)]

    return [EntityAnnotation.from_api_repr(entity) for entity in payload]
annotations = self._detect_annotation(features)
return annotations.texts
64 changes: 59 additions & 5 deletions packages/google-cloud-vision/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,60 @@ def test_image_with_client(self):
image = client.image(source_uri=IMAGE_SOURCE)
self.assertIsInstance(image, Image)

def test_multiple_detection_from_content(self):
    """``Image.detect()`` handles label and logo features in one call."""
    import copy
    from google.cloud.vision.feature import Feature
    from google.cloud.vision.feature import FeatureTypes
    from unit_tests._fixtures import LABEL_DETECTION_RESPONSE
    from unit_tests._fixtures import LOGO_DETECTION_RESPONSE

    # Graft the logo annotations onto a copy of the label response so the
    # fake connection answers with both annotation types at once.
    returned = copy.deepcopy(LABEL_DETECTION_RESPONSE)
    logo_response = copy.deepcopy(LOGO_DETECTION_RESPONSE['responses'][0])
    returned['responses'][0]['logoAnnotations'] = (
        logo_response['logoAnnotations'])

    client = self._make_one(project=PROJECT, credentials=_Credentials())
    client._connection = _Connection(returned)

    limit = 2
    features = [
        Feature(FeatureTypes.LABEL_DETECTION, limit),
        Feature(FeatureTypes.LOGO_DETECTION, limit),
    ]
    image = client.image(content=IMAGE_CONTENT)
    items = image.detect(features)

    self.assertEqual(len(items.logos), 2)
    self.assertEqual(len(items.labels), 3)

    expected_logos = [
        ('Brand1', 0.63192177),
        ('Brand2', 0.5492993),
    ]
    for position, (description, score) in enumerate(expected_logos):
        logo = items.logos[position]
        self.assertEqual(logo.description, description)
        self.assertEqual(logo.score, score)

    expected_labels = [
        ('automobile', 0.9776855),
        ('vehicle', 0.947987),
        ('truck', 0.88429511),
    ]
    for position, (description, score) in enumerate(expected_labels):
        label = items.labels[position]
        self.assertEqual(label.description, description)
        self.assertEqual(label.score, score)

    # The request must carry both features, in the order they were given.
    image_request = client._connection._requested[0]['data']['requests'][0]
    label_request = image_request['features'][0]
    logo_request = image_request['features'][1]

    self.assertEqual(B64_IMAGE_CONTENT,
                     image_request['image']['content'])
    self.assertEqual(label_request['maxResults'], 2)
    self.assertEqual(label_request['type'], 'LABEL_DETECTION')
    self.assertEqual(logo_request['maxResults'], 2)
    self.assertEqual(logo_request['type'], 'LOGO_DETECTION')

def test_face_detection_from_source(self):
from google.cloud.vision.face import Face
from unit_tests._fixtures import FACE_DETECTION_RESPONSE
Expand Down Expand Up @@ -126,7 +180,7 @@ def test_face_detection_from_content_no_results(self):

image = client.image(content=IMAGE_CONTENT)
faces = image.detect_faces(limit=5)
self.assertEqual(faces, [])
self.assertEqual(faces, ())
self.assertEqual(len(faces), 0)
image_request = client._connection._requested[0]['data']['requests'][0]

Expand Down Expand Up @@ -166,7 +220,7 @@ def test_label_detection_no_results(self):

image = client.image(content=IMAGE_CONTENT)
labels = image.detect_labels()
self.assertEqual(labels, [])
self.assertEqual(labels, ())
self.assertEqual(len(labels), 0)

def test_landmark_detection_from_source(self):
Expand Down Expand Up @@ -219,7 +273,7 @@ def test_landmark_detection_no_results(self):

image = client.image(content=IMAGE_CONTENT)
landmarks = image.detect_landmarks()
self.assertEqual(landmarks, [])
self.assertEqual(landmarks, ())
self.assertEqual(len(landmarks), 0)

def test_logo_detection_from_source(self):
Expand Down Expand Up @@ -308,7 +362,7 @@ def test_safe_search_no_results(self):

image = client.image(content=IMAGE_CONTENT)
safe_search = image.detect_safe_search()
self.assertEqual(safe_search, [])
self.assertEqual(safe_search, ())
self.assertEqual(len(safe_search), 0)

def test_image_properties_detection_from_source(self):
Expand Down Expand Up @@ -344,7 +398,7 @@ def test_image_properties_no_results(self):

image = client.image(content=IMAGE_CONTENT)
image_properties = image.detect_properties()
self.assertEqual(image_properties, [])
self.assertEqual(image_properties, ())
self.assertEqual(len(image_properties), 0)


Expand Down

0 comments on commit e222714

Please sign in to comment.