Skip to content

Commit

Permalink
Add GAPIC support for face detection.
Browse files Browse the repository at this point in the history
  • Loading branch information
daspecster committed Jan 6, 2017
1 parent 83207a1 commit b5bf40f
Show file tree
Hide file tree
Showing 11 changed files with 331 additions and 70 deletions.
19 changes: 5 additions & 14 deletions system_tests/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def _assert_coordinate(self, coordinate):
if coordinate is None:
return
self.assertIsInstance(coordinate, (int, float))
self.assertNotEqual(coordinate, 0.0)

def _assert_likelihood(self, likelihood):
from google.cloud.vision.likelihood import Likelihood
Expand Down Expand Up @@ -133,6 +132,7 @@ def test_detect_logos_gcs(self):

class TestVisionClientFace(BaseVisionTestCase):
def setUp(self):
Config.CLIENT = vision.Client(use_gax=True)
self.to_delete_by_case = []

def tearDown(self):
Expand All @@ -146,7 +146,7 @@ def _assert_landmarks(self, landmarks):

for landmark in LandmarkTypes:
if landmark is not LandmarkTypes.UNKNOWN_LANDMARK:
feature = getattr(landmarks, landmark.value.lower())
feature = getattr(landmarks, landmark.name.lower())
self.assertIsInstance(feature, Landmark)
self.assertIsInstance(feature.position, Position)
self._assert_coordinate(feature.position.x_coordinate)
Expand Down Expand Up @@ -190,7 +190,6 @@ def _assert_face(self, face):

def test_detect_faces_content(self):
client = Config.CLIENT
client._use_gax = False
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
faces = image.detect_faces()
Expand All @@ -209,7 +208,6 @@ def test_detect_faces_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = False
image = client.image(source_uri=source_uri)
faces = image.detect_faces()
self.assertEqual(len(faces), 5)
Expand All @@ -218,7 +216,6 @@ def test_detect_faces_gcs(self):

def test_detect_faces_filename(self):
client = Config.CLIENT
client._use_gax = False
image = client.image(filename=FACE_FILE)
faces = image.detect_faces()
self.assertEqual(len(faces), 5)
Expand Down Expand Up @@ -292,6 +289,7 @@ class TestVisionClientLandmark(BaseVisionTestCase):
DESCRIPTIONS = ('Mount Rushmore',)

def setUp(self):
Config.CLIENT = vision.Client(use_gax=True)
self.to_delete_by_case = []

def tearDown(self):
Expand All @@ -313,7 +311,6 @@ def _assert_landmark(self, landmark):

def test_detect_landmark_content(self):
client = Config.CLIENT
client._use_gax = True
with open(LANDMARK_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
landmarks = image.detect_landmarks()
Expand All @@ -332,7 +329,6 @@ def test_detect_landmark_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = True
image = client.image(source_uri=source_uri)
landmarks = image.detect_landmarks()
self.assertEqual(len(landmarks), 1)
Expand All @@ -341,7 +337,6 @@ def test_detect_landmark_gcs(self):

def test_detect_landmark_filename(self):
client = Config.CLIENT
client._use_gax = True
image = client.image(filename=LANDMARK_FILE)
landmarks = image.detect_landmarks()
self.assertEqual(len(landmarks), 1)
Expand All @@ -351,6 +346,7 @@ def test_detect_landmark_filename(self):

class TestVisionClientSafeSearch(BaseVisionTestCase):
def setUp(self):
Config.CLIENT = vision.Client(use_gax=False)
self.to_delete_by_case = []

def tearDown(self):
Expand All @@ -368,7 +364,6 @@ def _assert_safe_search(self, safe_search):

def test_detect_safe_search_content(self):
client = Config.CLIENT
client._use_gax = False
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
safe_searches = image.detect_safe_search()
Expand All @@ -387,7 +382,6 @@ def test_detect_safe_search_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = False
image = client.image(source_uri=source_uri)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
Expand All @@ -396,7 +390,6 @@ def test_detect_safe_search_gcs(self):

def test_detect_safe_search_filename(self):
client = Config.CLIENT
client._use_gax = False
image = client.image(filename=FACE_FILE)
safe_searches = image.detect_safe_search()
self.assertEqual(len(safe_searches), 1)
Expand Down Expand Up @@ -470,6 +463,7 @@ def test_detect_text_filename(self):

class TestVisionClientImageProperties(BaseVisionTestCase):
def setUp(self):
Config.CLIENT = vision.Client(use_gax=False)
self.to_delete_by_case = []

def tearDown(self):
Expand Down Expand Up @@ -497,7 +491,6 @@ def _assert_properties(self, image_property):

def test_detect_properties_content(self):
client = Config.CLIENT
client._use_gax = False
with open(FACE_FILE, 'rb') as image_file:
image = client.image(content=image_file.read())
properties = image.detect_properties()
Expand All @@ -516,7 +509,6 @@ def test_detect_properties_gcs(self):
source_uri = 'gs://%s/%s' % (bucket_name, blob_name)

client = Config.CLIENT
client._use_gax = False
image = client.image(source_uri=source_uri)
properties = image.detect_properties()
self.assertEqual(len(properties), 1)
Expand All @@ -525,7 +517,6 @@ def test_detect_properties_gcs(self):

def test_detect_properties_filename(self):
client = Config.CLIENT
client._use_gax = False
image = client.image(filename=FACE_FILE)
properties = image.detect_properties()
self.assertEqual(len(properties), 1)
Expand Down
4 changes: 1 addition & 3 deletions vision/google/cloud/vision/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
from google.cloud.gapic.vision.v1 import image_annotator_client
from google.cloud.grpc.vision.v1 import image_annotator_pb2

from google.cloud._helpers import _to_bytes

from google.cloud.vision.annotations import Annotations


Expand Down Expand Up @@ -81,7 +79,7 @@ def _to_gapic_image(image):
:class:`~google.cloud.vision.image.Image`.
"""
if image.content is not None:
return image_annotator_pb2.Image(content=_to_bytes(image.content))
return image_annotator_pb2.Image(content=image.content)
if image.source is not None:
return image_annotator_pb2.Image(
source=image_annotator_pb2.ImageSource(
Expand Down
17 changes: 17 additions & 0 deletions vision/google/cloud/vision/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def _process_image_annotations(image):
:returns: Dictionary populated with entities from response.
"""
annotations = {}
annotations['faces'] = _make_faces_from_pb(image.face_annotations)
annotations['labels'] = _make_entity_from_pb(image.label_annotations)
annotations['landmarks'] = _make_entity_from_pb(image.landmark_annotations)
annotations['logos'] = _make_entity_from_pb(image.logo_annotations)
Expand All @@ -143,6 +144,22 @@ def _make_entity_from_pb(annotations):
return entities


def _make_faces_from_pb(annotations):
"""Create face objects from a gRPC response.
:type annotations:
:class:`~google.cloud.grpc.vision.v1.image_annotator_pb2.FaceAnnotation`
:param annotations: gRPC instance of ``FaceAnnotation``.
:rtype: list
:returns: List of ``Face``.
"""
faces = []
for annotation in annotations:
faces.append(Face.from_pb(annotation))
return faces


def _entity_from_response_type(feature_type, results):
"""Convert a JSON result to an entity type based on the feature.
Expand Down
Loading

0 comments on commit b5bf40f

Please sign in to comment.