Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Speech LRO handling and add HTTP side for multiple results. #2965

Merged
merged 3 commits into from
Feb 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions docs/speech-usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ See: `Speech Asynchronous Recognize`_
>>> operation.complete
True
>>> for result in operation.results:
... print('=' * 20)
... print(result.transcript)
... print(result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print(alternative.transcript)
... print(alternative.confidence)
====================
'how old is the Brooklyn Bridge'
0.98267895
Expand All @@ -93,9 +94,10 @@ Great Britian.
... source_uri='gs://my-bucket/recording.flac', language_code='en-GB',
... max_alternatives=2)
>>> for result in results:
... print('=' * 20)
... print('transcript: ' + result.transcript)
... print('confidence: ' + result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print('transcript: ' + alternative.transcript)
... print('confidence: ' + alternative.confidence)
====================
transcript: Hello, this is a test
confidence: 0.81
Expand All @@ -115,9 +117,10 @@ Example of using the profanity filter.
>>> results = sample.sync_recognize(max_alternatives=1,
... profanity_filter=True)
>>> for result in results:
... print('=' * 20)
... print('transcript: ' + result.transcript)
... print('confidence: ' + result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print('transcript: ' + alternative.transcript)
... print('confidence: ' + alternative.confidence)
====================
transcript: Hello, this is a f****** test
confidence: 0.81
Expand All @@ -137,9 +140,10 @@ words to the vocabulary of the recognizer.
>>> results = sample.sync_recognize(max_alternatives=2,
... speech_context=hints)
>>> for result in results:
... print('=' * 20)
... print('transcript: ' + result.transcript)
... print('confidence: ' + result.confidence)
... for alternative in result.alternatives:
... print('=' * 20)
... print('transcript: ' + alternative.transcript)
... print('confidence: ' + alternative.confidence)
====================
transcript: Hello, this is a test
confidence: 0.81
Expand Down
4 changes: 2 additions & 2 deletions speech/google/cloud/speech/_gax.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ def async_recognize(self, sample, language_code=None,
audio = RecognitionAudio(content=sample.content,
uri=sample.source_uri)
api = self._gapic_api
response = api.async_recognize(config=config, audio=audio)
operation_future = api.async_recognize(config=config, audio=audio)

return Operation.from_pb(response, self)
return Operation.from_pb(operation_future.last_operation_data(), self)

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.

This comment was marked as spam.


def streaming_recognize(self, sample, language_code=None,
max_alternatives=None, profanity_filter=None,
Expand Down
11 changes: 5 additions & 6 deletions speech/google/cloud/speech/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from google.cloud.environment_vars import DISABLE_GRPC

from google.cloud.speech._gax import GAPICSpeechAPI
from google.cloud.speech.alternative import Alternative
from google.cloud.speech.result import Result
from google.cloud.speech.connection import Connection
from google.cloud.speech.operation import Operation
from google.cloud.speech.sample import Sample
Expand Down Expand Up @@ -235,12 +235,11 @@ def sync_recognize(self, sample, language_code=None, max_alternatives=None,
api_response = self._connection.api_request(
method='POST', path='speech:syncrecognize', data=data)

if len(api_response['results']) == 1:
result = api_response['results'][0]
return [Alternative.from_api_repr(alternative)
for alternative in result['alternatives']]
if len(api_response['results']) > 0:
results = api_response['results']
return [Result.from_api_repr(result) for result in results]
else:
raise ValueError('More than one result or none returned from API.')
raise ValueError('No results were returned from the API')


def _build_request_data(sample, language_code=None, max_alternatives=None,
Expand Down
27 changes: 21 additions & 6 deletions speech/google/cloud/speech/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,34 @@ def __init__(self, alternatives):

@classmethod
def from_pb(cls, result):
"""Factory: construct instance of ``SpeechRecognitionResult``.
"""Factory: construct instance of ``Result``.

:type result: :class:`~google.cloud.grpc.speech.v1beta1\
.cloud_speech_pb2.StreamingRecognizeResult`
:param result: Instance of ``StreamingRecognizeResult`` protobuf.
.cloud_speech_pb2.SpeechRecognitionResult`
:param result: Instance of ``SpeechRecognitionResult`` protobuf.

:rtype: :class:`~google.cloud.speech.result.SpeechRecognitionResult`
:returns: Instance of ``SpeechRecognitionResult``.
:rtype: :class:`~google.cloud.speech.result.Result`
:returns: Instance of ``Result``.
"""
alternatives = [Alternative.from_pb(result) for result
alternatives = [Alternative.from_pb(alternative) for alternative
in result.alternatives]
return cls(alternatives=alternatives)

@classmethod
def from_api_repr(cls, result):
"""Factory: construct instance of ``Result``.

:type result: dict
:param result: Dictionary of a :class:`~google.cloud.grpc.speech.\
v1beta1.cloud_speech_pb2.SpeechRecognitionResult`

:rtype: :class:`~google.cloud.speech.result.Result`
:returns: Instance of ``Result``.
"""
alternatives = [Alternative.from_api_repr(alternative) for alternative
in result['alternatives']]
return cls(alternatives=alternatives)

@property
def confidence(self):
"""Return the confidence for the most probable alternative.
Expand Down
30 changes: 22 additions & 8 deletions speech/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def test_sync_recognize_content_with_optional_params_no_gax(self):

from google.cloud import speech
from google.cloud.speech.alternative import Alternative
from google.cloud.speech.result import Result
from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE

_B64_AUDIO_CONTENT = _bytes_to_unicode(b64encode(self.AUDIO_CONTENT))
Expand Down Expand Up @@ -174,13 +175,16 @@ def test_sync_recognize_content_with_optional_params_no_gax(self):
alternative = SYNC_RECOGNIZE_RESPONSE['results'][0]['alternatives'][0]
expected = Alternative.from_api_repr(alternative)
self.assertEqual(len(response), 1)
self.assertIsInstance(response[0], Alternative)
self.assertEqual(response[0].transcript, expected.transcript)
self.assertEqual(response[0].confidence, expected.confidence)
self.assertIsInstance(response[0], Result)
self.assertEqual(len(response[0].alternatives), 1)
alternative = response[0].alternatives[0]
self.assertEqual(alternative.transcript, expected.transcript)
self.assertEqual(alternative.confidence, expected.confidence)

def test_sync_recognize_source_uri_without_optional_params_no_gax(self):
from google.cloud import speech
from google.cloud.speech.alternative import Alternative
from google.cloud.speech.result import Result
from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE

RETURNED = SYNC_RECOGNIZE_RESPONSE
Expand Down Expand Up @@ -214,9 +218,12 @@ def test_sync_recognize_source_uri_without_optional_params_no_gax(self):
expected = Alternative.from_api_repr(
SYNC_RECOGNIZE_RESPONSE['results'][0]['alternatives'][0])
self.assertEqual(len(response), 1)
self.assertIsInstance(response[0], Alternative)
self.assertEqual(response[0].transcript, expected.transcript)
self.assertEqual(response[0].confidence, expected.confidence)
self.assertIsInstance(response[0], Result)
self.assertEqual(len(response[0].alternatives), 1)
alternative = response[0].alternatives[0]

self.assertEqual(alternative.transcript, expected.transcript)
self.assertEqual(alternative.confidence, expected.confidence)

def test_sync_recognize_with_empty_results_no_gax(self):
from google.cloud import speech
Expand Down Expand Up @@ -710,19 +717,26 @@ class _MockGAPICSpeechAPI(object):
_requests = None
_response = None
_results = None

SERVICE_ADDRESS = 'foo.apis.invalid'

def __init__(self, response=None, channel=None):
self._response = response
self._channel = channel

def async_recognize(self, config, audio):
from google.gapic.longrunning.operations_client import OperationsClient
from google.gax import _OperationFuture
from google.longrunning.operations_pb2 import Operation
from google.cloud.proto.speech.v1beta1.cloud_speech_pb2 import (
AsyncRecognizeResponse)

self.config = config
self.audio = audio
operation = Operation()
return operation
operations_client = mock.Mock(spec=OperationsClient)
operation_future = _OperationFuture(Operation(), operations_client,
AsyncRecognizeResponse, {})
return operation_future

def sync_recognize(self, config, audio):
self.config = config
Expand Down
22 changes: 16 additions & 6 deletions system_tests/speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ def test_sync_recognize_local_file(self):

results = self._make_sync_request(content=content,
max_alternatives=2)
self._check_results(results, 2)
self.assertEqual(len(results), 1)
alternatives = results[0].alternatives
self.assertEqual(len(alternatives), 2)
self._check_results(alternatives, 2)

def test_sync_recognize_gcs_file(self):
bucket_name = Config.TEST_BUCKET.name
Expand All @@ -155,9 +158,10 @@ def test_sync_recognize_gcs_file(self):
blob.upload_from_file(file_obj)

source_uri = 'gs://%s/%s' % (bucket_name, blob_name)
result = self._make_sync_request(source_uri=source_uri,
max_alternatives=1)
self._check_results(result)
results = self._make_sync_request(source_uri=source_uri,
max_alternatives=1)
self.assertEqual(len(results), 1)
self._check_results(results[0].alternatives)

def test_async_recognize_local_file(self):
with open(AUDIO_FILE, 'rb') as file_obj:
Expand All @@ -167,7 +171,10 @@ def test_async_recognize_local_file(self):
max_alternatives=2)

_wait_until_complete(operation)
self._check_results(operation.results, 2)
self.assertEqual(len(operation.results), 1)
alternatives = operation.results[0].alternatives
self.assertEqual(len(alternatives), 2)
self._check_results(alternatives, 2)

def test_async_recognize_gcs_file(self):
bucket_name = Config.TEST_BUCKET.name
Expand All @@ -182,7 +189,10 @@ def test_async_recognize_gcs_file(self):
max_alternatives=2)

_wait_until_complete(operation)
self._check_results(operation.results, 2)
self.assertEqual(len(operation.results), 1)
alternatives = operation.results[0].alternatives
self.assertEqual(len(alternatives), 2)
self._check_results(alternatives, 2)

def test_stream_recognize(self):
if not Config.USE_GAX:
Expand Down