Streamline annotations and filtering responses.
daspecster committed Nov 29, 2016
1 parent 74eef27 commit 2da9e85
Showing 3 changed files with 64 additions and 74 deletions.
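
The annotations.py change below drops the `_REVERSE_TYPES` lookup keyed by `FeatureTypes` and instead maps API response keys directly to the keyword names that `Annotations()` accepts, accumulating results in a single pass over `response.items()`. A minimal sketch of that pattern, using a hypothetical, trimmed-down response payload:

# Sketch of the key_map/setdefault pattern introduced in annotations.py below.
# The response dict here is hypothetical and trimmed to two feature types.
key_map = {
    'faceAnnotations': 'faces',
    'labelAnnotations': 'labels',
}

response = {
    'labelAnnotations': [{'description': 'automobile', 'score': 0.97}],
    'faceAnnotations': [],
}

annotations = {}
for feature_type, annotation in response.items():
    # Translate the API key to the constructor keyword and accumulate entries
    # under it; the real code parses each entry into an entity object first.
    curr_feature = annotations.setdefault(key_map[feature_type], [])
    curr_feature.extend(annotation)

print(annotations)
# {'labels': [{'description': 'automobile', 'score': 0.97}], 'faces': []}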
3 changes: 0 additions & 3 deletions docs/vision-annotations.rst
@@ -1,9 +1,6 @@
Vision Annotations
==================

Image Annotations
~~~~~~~~~~~~~~~~~

.. automodule:: google.cloud.vision.annotations
:members:
:undoc-members:
82 changes: 33 additions & 49 deletions vision/google/cloud/vision/annotations.py
@@ -15,26 +15,14 @@
"""Annotations management for Vision API responses."""


from google.cloud.vision.color import ImagePropertiesAnnotation
from google.cloud.vision.entity import EntityAnnotation
from google.cloud.vision.face import Face
from google.cloud.vision.feature import FeatureTypes
from google.cloud.vision.color import ImagePropertiesAnnotation
from google.cloud.vision.safe import SafeSearchAnnotation


_REVERSE_TYPES = {
FeatureTypes.FACE_DETECTION: 'faceAnnotations',
FeatureTypes.IMAGE_PROPERTIES: 'imagePropertiesAnnotation',
FeatureTypes.LABEL_DETECTION: 'labelAnnotations',
FeatureTypes.LANDMARK_DETECTION: 'landmarkAnnotations',
FeatureTypes.LOGO_DETECTION: 'logoAnnotations',
FeatureTypes.SAFE_SEARCH_DETECTION: 'safeSearchAnnotation',
FeatureTypes.TEXT_DETECTION: 'textAnnotations',
}


class Annotations(object):
"""Annotation class for managing responses.
"""Helper class to bundle annotation responses.
:type faces: list
:param faces: List of :class:`~google.cloud.vision.face.Face`.
@@ -65,13 +53,13 @@ class Annotations(object):
"""
def __init__(self, faces=None, properties=None, labels=None,
landmarks=None, logos=None, safe_searches=None, texts=None):
self.faces = faces or []
self.properties = properties or []
self.labels = labels or []
self.landmarks = landmarks or []
self.logos = logos or []
self.safe_searches = safe_searches or []
self.texts = texts or []
self.faces = faces or ()
self.properties = properties or ()
self.labels = labels or ()
self.landmarks = landmarks or ()
self.logos = logos or ()
self.safe_searches = safe_searches or ()
self.texts = texts or ()

@classmethod
def from_api_repr(cls, response):
@@ -84,45 +72,41 @@ def from_api_repr(cls, response):
:returns: An instance of ``Annotations`` with detection types loaded.
"""
annotations = {}

for feature_type in response.keys():
annotations[feature_type] = []
key_map = {
'faceAnnotations': 'faces',
'imagePropertiesAnnotation': 'properties',
'labelAnnotations': 'labels',
'landmarkAnnotations': 'landmarks',
'logoAnnotations': 'logos',
'safeSearchAnnotation': 'safe_searches',
'textAnnotations': 'texts'
}

for feature_type, annotation in response.items():
annotations[feature_type].extend(
curr_feature = annotations.setdefault(key_map[feature_type], [])
curr_feature.extend(
_entity_from_response_type(feature_type, annotation))

faces = annotations.get(
_REVERSE_TYPES[FeatureTypes.FACE_DETECTION], [])
properties = annotations.get(
_REVERSE_TYPES[FeatureTypes.IMAGE_PROPERTIES], [])
labels = annotations.get(
_REVERSE_TYPES[FeatureTypes.LABEL_DETECTION], [])
landmarks = annotations.get(
_REVERSE_TYPES[FeatureTypes.LANDMARK_DETECTION], [])
logos = annotations.get(
_REVERSE_TYPES[FeatureTypes.LOGO_DETECTION], [])
safe_searches = annotations.get(
_REVERSE_TYPES[FeatureTypes.SAFE_SEARCH_DETECTION], [])
texts = annotations.get(
_REVERSE_TYPES[FeatureTypes.TEXT_DETECTION], [])

return cls(faces=faces, properties=properties, labels=labels,
landmarks=landmarks, logos=logos,
safe_searches=safe_searches, texts=texts)
return cls(**annotations)


def _entity_from_response_type(feature_type, results):
"""Convert a JSON result to an entity type based on the feature."""

"""Convert a JSON result to an entity type based on the feature.
:rtype: list
:returns: List containing any of
:class:`~google.cloud.vision.color.ImagePropertiesAnnotation`,
:class:`~google.cloud.vision.entity.EntityAnnotation`,
:class:`~google.cloud.vision.face.Face`,
:class:`~google.cloud.vision.safe.SafeSearchAnnotation`.
"""
detected_objects = []
if feature_type == _REVERSE_TYPES[FeatureTypes.FACE_DETECTION]:
if feature_type == 'faceAnnotations':
detected_objects.extend(
Face.from_api_repr(face) for face in results)
elif feature_type == _REVERSE_TYPES[FeatureTypes.IMAGE_PROPERTIES]:
elif feature_type == 'imagePropertiesAnnotation':
detected_objects.append(
ImagePropertiesAnnotation.from_api_repr(results))
elif feature_type == _REVERSE_TYPES[FeatureTypes.SAFE_SEARCH_DETECTION]:
elif feature_type == 'safeSearchAnnotation':
detected_objects.append(SafeSearchAnnotation.from_api_repr(results))
else:
for result in results:
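Taken together with the constructor change above (`faces or ()` and friends), the refactored factory can simply forward whatever groups the response contained and let the tuple defaults cover the rest, which is also why the "no results" tests below now compare against `()` instead of `[]`. A simplified stand-in, not the library class itself, illustrating both pieces:

# Simplified stand-in for the refactored Annotations class (not the library
# code itself), showing the tuple defaults and the cls(**annotations) call.
class AnnotationsSketch(object):
    def __init__(self, faces=None, labels=None):
        # Missing groups become immutable empty tuples instead of fresh lists.
        self.faces = faces or ()
        self.labels = labels or ()

    @classmethod
    def from_api_repr(cls, response):
        key_map = {'faceAnnotations': 'faces', 'labelAnnotations': 'labels'}
        annotations = {}
        for feature_type, annotation in response.items():
            annotations.setdefault(key_map[feature_type], []).extend(annotation)
        # Only keys present in the response are passed; absent groups fall
        # back to the () defaults in __init__.
        return cls(**annotations)


sketch = AnnotationsSketch.from_api_repr(
    {'labelAnnotations': [{'description': 'automobile'}]})
print(sketch.labels)  # [{'description': 'automobile'}]
print(sketch.faces)   # ()
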
53 changes: 31 additions & 22 deletions vision/unit_tests/test_client.py
@@ -82,17 +82,19 @@ def test_image_with_client(self):
self.assertIsInstance(image, Image)

def test_multiple_detection_from_content(self):
import copy
from google.cloud.vision.feature import Feature
from google.cloud.vision.feature import FeatureTypes
from unit_tests._fixtures import LABEL_DETECTION_RESPONSE
from unit_tests._fixtures import LOGO_DETECTION_RESPONSE
RETURNED = LABEL_DETECTION_RESPONSE
LOGOS = LOGO_DETECTION_RESPONSE['responses'][0]['logoAnnotations']
RETURNED['responses'][0]['logoAnnotations'] = LOGOS

returned = copy.deepcopy(LABEL_DETECTION_RESPONSE)
logos = copy.deepcopy(LOGO_DETECTION_RESPONSE['responses'][0])
returned['responses'][0]['logoAnnotations'] = logos['logoAnnotations']

credentials = _Credentials()
client = self._make_one(project=PROJECT, credentials=credentials)
client._connection = _Connection(RETURNED)
client._connection = _Connection(returned)

limit = 2
label_feature = Feature(FeatureTypes.LABEL_DETECTION, limit)
@@ -103,19 +105,26 @@ def test_multiple_detection_from_content(self):

self.assertEqual(len(items.logos), 2)
self.assertEqual(len(items.labels), 3)
self.assertEqual(items.logos[0].description, 'Brand1')
self.assertEqual(items.logos[0].score, 0.63192177)
self.assertEqual(items.logos[1].description, 'Brand2')
self.assertEqual(items.logos[1].score, 0.5492993)

self.assertEqual(items.labels[0].description, 'automobile')
self.assertEqual(items.labels[0].score, 0.9776855)
self.assertEqual(items.labels[1].description, 'vehicle')
self.assertEqual(items.labels[1].score, 0.947987)
self.assertEqual(items.labels[2].description, 'truck')
self.assertEqual(items.labels[2].score, 0.88429511)

image_request = client._connection._requested[0]['data']['requests'][0]
first_logo = items.logos[0]
second_logo = items.logos[1]
self.assertEqual(first_logo.description, 'Brand1')
self.assertEqual(first_logo.score, 0.63192177)
self.assertEqual(second_logo.description, 'Brand2')
self.assertEqual(second_logo.score, 0.5492993)

first_label = items.labels[0]
second_label = items.labels[1]
third_label = items.labels[2]
self.assertEqual(first_label.description, 'automobile')
self.assertEqual(first_label.score, 0.9776855)
self.assertEqual(second_label.description, 'vehicle')
self.assertEqual(second_label.score, 0.947987)
self.assertEqual(third_label.description, 'truck')
self.assertEqual(third_label.score, 0.88429511)

requested = client._connection._requested
requests = requested[0]['data']['requests']
image_request = requests[0]
label_request = image_request['features'][0]
logo_request = image_request['features'][1]

@@ -171,7 +180,7 @@ def test_face_detection_from_content_no_results(self):

image = client.image(content=IMAGE_CONTENT)
faces = image.detect_faces(limit=5)
self.assertEqual(faces, [])
self.assertEqual(faces, ())
self.assertEqual(len(faces), 0)
image_request = client._connection._requested[0]['data']['requests'][0]

@@ -211,7 +220,7 @@ def test_label_detection_no_results(self):

image = client.image(content=IMAGE_CONTENT)
labels = image.detect_labels()
self.assertEqual(labels, [])
self.assertEqual(labels, ())
self.assertEqual(len(labels), 0)

def test_landmark_detection_from_source(self):
@@ -264,7 +273,7 @@ def test_landmark_detection_no_results(self):

image = client.image(content=IMAGE_CONTENT)
landmarks = image.detect_landmarks()
self.assertEqual(landmarks, [])
self.assertEqual(landmarks, ())
self.assertEqual(len(landmarks), 0)

def test_logo_detection_from_source(self):
@@ -353,7 +362,7 @@ def test_safe_search_no_results(self):

image = client.image(content=IMAGE_CONTENT)
safe_search = image.detect_safe_search()
self.assertEqual(safe_search, [])
self.assertEqual(safe_search, ())
self.assertEqual(len(safe_search), 0)

def test_image_properties_detection_from_source(self):
@@ -389,7 +398,7 @@ def test_image_properties_no_results(self):

image = client.image(content=IMAGE_CONTENT)
image_properties = image.detect_properties()
self.assertEqual(image_properties, [])
self.assertEqual(image_properties, ())
self.assertEqual(len(image_properties), 0)


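The test_multiple_detection_from_content change above replaces in-place mutation of the shared LABEL_DETECTION_RESPONSE fixture with copy.deepcopy, so grafting logoAnnotations onto the copy cannot leak into other tests that import the same fixture. A minimal illustration of the hazard, using a hypothetical fixture dict:

import copy

# Hypothetical module-level fixture, standing in for LABEL_DETECTION_RESPONSE.
LABEL_FIXTURE = {'responses': [{'labelAnnotations': [{'description': 'car'}]}]}

# Binding a new name does not copy: mutating "returned" mutates the fixture.
returned = LABEL_FIXTURE
returned['responses'][0]['logoAnnotations'] = [{'description': 'Brand1'}]
print('logoAnnotations' in LABEL_FIXTURE['responses'][0])  # True -- leaked

# Deep-copying first leaves the shared fixture untouched.
LABEL_FIXTURE['responses'][0].pop('logoAnnotations')  # undo the leak for the demo
returned = copy.deepcopy(LABEL_FIXTURE)
returned['responses'][0]['logoAnnotations'] = [{'description': 'Brand1'}]
print('logoAnnotations' in LABEL_FIXTURE['responses'][0])  # False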
