diff --git a/docs/speech-usage.rst b/docs/speech-usage.rst index 5ba566432fdd..f82236fb8092 100644 --- a/docs/speech-usage.rst +++ b/docs/speech-usage.rst @@ -64,9 +64,10 @@ See: `Speech Asynchronous Recognize`_ >>> operation.complete True >>> for result in operation.results: - ... print('=' * 20) - ... print(result.transcript) - ... print(result.confidence) + ... for alternative in result.alternatives: + ... print('=' * 20) + ... print(alternative.transcript) + ... print(alternative.confidence) ==================== 'how old is the Brooklyn Bridge' 0.98267895 @@ -93,9 +94,10 @@ Great Britian. ... source_uri='gs://my-bucket/recording.flac', language_code='en-GB', ... max_alternatives=2) >>> for result in results: - ... print('=' * 20) - ... print('transcript: ' + result.transcript) - ... print('confidence: ' + result.confidence) + ... for alternative in result.alternatives: + ... print('=' * 20) + ... print('transcript: ' + alternative.transcript) + ... print('confidence: ' + str(alternative.confidence)) ==================== transcript: Hello, this is a test confidence: 0.81 @@ -115,9 +117,10 @@ Example of using the profanity filter. >>> results = sample.sync_recognize(max_alternatives=1, ... profanity_filter=True) >>> for result in results: - ... print('=' * 20) - ... print('transcript: ' + result.transcript) - ... print('confidence: ' + result.confidence) + ... for alternative in result.alternatives: + ... print('=' * 20) + ... print('transcript: ' + alternative.transcript) + ... print('confidence: ' + str(alternative.confidence)) ==================== transcript: Hello, this is a f****** test confidence: 0.81 @@ -137,9 +140,10 @@ words to the vocabulary of the recognizer. >>> results = sample.sync_recognize(max_alternatives=2, ... speech_context=hints) >>> for result in results: - ... print('=' * 20) - ... print('transcript: ' + result.transcript) - ... print('confidence: ' + result.confidence) + ... for alternative in result.alternatives: + ... print('=' * 20) + ... 
print('transcript: ' + alternative.transcript) + ... print('confidence: ' + str(alternative.confidence)) ==================== transcript: Hello, this is a test confidence: 0.81 diff --git a/speech/google/cloud/speech/_gax.py b/speech/google/cloud/speech/_gax.py index 2465ec0e50bd..af07cb6b45c2 100644 --- a/speech/google/cloud/speech/_gax.py +++ b/speech/google/cloud/speech/_gax.py @@ -104,9 +104,9 @@ def async_recognize(self, sample, language_code=None, audio = RecognitionAudio(content=sample.content, uri=sample.source_uri) api = self._gapic_api - response = api.async_recognize(config=config, audio=audio) + operation_future = api.async_recognize(config=config, audio=audio) - return Operation.from_pb(response, self) + return Operation.from_pb(operation_future.last_operation_data(), self) def streaming_recognize(self, sample, language_code=None, max_alternatives=None, profanity_filter=None, diff --git a/speech/google/cloud/speech/client.py b/speech/google/cloud/speech/client.py index 0bf96c68e100..f3a44659dab4 100644 --- a/speech/google/cloud/speech/client.py +++ b/speech/google/cloud/speech/client.py @@ -23,7 +23,7 @@ from google.cloud.environment_vars import DISABLE_GRPC from google.cloud.speech._gax import GAPICSpeechAPI -from google.cloud.speech.alternative import Alternative +from google.cloud.speech.result import Result from google.cloud.speech.connection import Connection from google.cloud.speech.operation import Operation from google.cloud.speech.sample import Sample @@ -235,12 +235,11 @@ def sync_recognize(self, sample, language_code=None, max_alternatives=None, api_response = self._connection.api_request( method='POST', path='speech:syncrecognize', data=data) - if len(api_response['results']) == 1: - result = api_response['results'][0] - return [Alternative.from_api_repr(alternative) - for alternative in result['alternatives']] + if len(api_response['results']) > 0: + results = api_response['results'] + return [Result.from_api_repr(result) for result in results] 
else: - raise ValueError('More than one result or none returned from API.') + raise ValueError('No results were returned from the API') def _build_request_data(sample, language_code=None, max_alternatives=None, diff --git a/speech/google/cloud/speech/result.py b/speech/google/cloud/speech/result.py index bba01a047c5d..82ae472a4f39 100644 --- a/speech/google/cloud/speech/result.py +++ b/speech/google/cloud/speech/result.py @@ -32,19 +32,34 @@ def __init__(self, alternatives): @classmethod def from_pb(cls, result): - """Factory: construct instance of ``SpeechRecognitionResult``. + """Factory: construct instance of ``Result``. :type result: :class:`~google.cloud.grpc.speech.v1beta1\ - .cloud_speech_pb2.StreamingRecognizeResult` - :param result: Instance of ``StreamingRecognizeResult`` protobuf. + .cloud_speech_pb2.SpeechRecognitionResult` + :param result: Instance of ``SpeechRecognitionResult`` protobuf. - :rtype: :class:`~google.cloud.speech.result.SpeechRecognitionResult` - :returns: Instance of ``SpeechRecognitionResult``. + :rtype: :class:`~google.cloud.speech.result.Result` + :returns: Instance of ``Result``. """ - alternatives = [Alternative.from_pb(result) for result + alternatives = [Alternative.from_pb(alternative) for alternative in result.alternatives] return cls(alternatives=alternatives) + @classmethod + def from_api_repr(cls, result): + """Factory: construct instance of ``Result``. + + :type result: dict + :param result: Dictionary of a :class:`~google.cloud.grpc.speech.\ + v1beta1.cloud_speech_pb2.SpeechRecognitionResult` + + :rtype: :class:`~google.cloud.speech.result.Result` + :returns: Instance of ``Result``. + """ + alternatives = [Alternative.from_api_repr(alternative) for alternative + in result['alternatives']] + return cls(alternatives=alternatives) + @property def confidence(self): """Return the confidence for the most probable alternative. 
diff --git a/speech/unit_tests/test_client.py b/speech/unit_tests/test_client.py index 43d3527ec339..9ec6a0148531 100644 --- a/speech/unit_tests/test_client.py +++ b/speech/unit_tests/test_client.py @@ -129,6 +129,7 @@ def test_sync_recognize_content_with_optional_params_no_gax(self): from google.cloud import speech from google.cloud.speech.alternative import Alternative + from google.cloud.speech.result import Result from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE _B64_AUDIO_CONTENT = _bytes_to_unicode(b64encode(self.AUDIO_CONTENT)) @@ -174,13 +175,16 @@ def test_sync_recognize_content_with_optional_params_no_gax(self): alternative = SYNC_RECOGNIZE_RESPONSE['results'][0]['alternatives'][0] expected = Alternative.from_api_repr(alternative) self.assertEqual(len(response), 1) - self.assertIsInstance(response[0], Alternative) - self.assertEqual(response[0].transcript, expected.transcript) - self.assertEqual(response[0].confidence, expected.confidence) + self.assertIsInstance(response[0], Result) + self.assertEqual(len(response[0].alternatives), 1) + alternative = response[0].alternatives[0] + self.assertEqual(alternative.transcript, expected.transcript) + self.assertEqual(alternative.confidence, expected.confidence) def test_sync_recognize_source_uri_without_optional_params_no_gax(self): from google.cloud import speech from google.cloud.speech.alternative import Alternative + from google.cloud.speech.result import Result from unit_tests._fixtures import SYNC_RECOGNIZE_RESPONSE RETURNED = SYNC_RECOGNIZE_RESPONSE @@ -214,9 +218,12 @@ def test_sync_recognize_source_uri_without_optional_params_no_gax(self): expected = Alternative.from_api_repr( SYNC_RECOGNIZE_RESPONSE['results'][0]['alternatives'][0]) self.assertEqual(len(response), 1) - self.assertIsInstance(response[0], Alternative) - self.assertEqual(response[0].transcript, expected.transcript) - self.assertEqual(response[0].confidence, expected.confidence) + self.assertIsInstance(response[0], Result) + 
self.assertEqual(len(response[0].alternatives), 1) + alternative = response[0].alternatives[0] + + self.assertEqual(alternative.transcript, expected.transcript) + self.assertEqual(alternative.confidence, expected.confidence) def test_sync_recognize_with_empty_results_no_gax(self): from google.cloud import speech @@ -710,6 +717,7 @@ class _MockGAPICSpeechAPI(object): _requests = None _response = None _results = None + SERVICE_ADDRESS = 'foo.apis.invalid' def __init__(self, response=None, channel=None): @@ -717,12 +725,18 @@ def __init__(self, response=None, channel=None): self._channel = channel def async_recognize(self, config, audio): + from google.gapic.longrunning.operations_client import OperationsClient + from google.gax import _OperationFuture from google.longrunning.operations_pb2 import Operation + from google.cloud.proto.speech.v1beta1.cloud_speech_pb2 import ( + AsyncRecognizeResponse) self.config = config self.audio = audio - operation = Operation() - return operation + operations_client = mock.Mock(spec=OperationsClient) + operation_future = _OperationFuture(Operation(), operations_client, + AsyncRecognizeResponse, {}) + return operation_future def sync_recognize(self, config, audio): self.config = config diff --git a/system_tests/speech.py b/system_tests/speech.py index b970350d1fa8..14e80b7cfc1d 100644 --- a/system_tests/speech.py +++ b/system_tests/speech.py @@ -144,7 +144,10 @@ def test_sync_recognize_local_file(self): results = self._make_sync_request(content=content, max_alternatives=2) - self._check_results(results, 2) + self.assertEqual(len(results), 1) + alternatives = results[0].alternatives + self.assertEqual(len(alternatives), 2) + self._check_results(alternatives, 2) def test_sync_recognize_gcs_file(self): bucket_name = Config.TEST_BUCKET.name @@ -155,9 +158,10 @@ def test_sync_recognize_gcs_file(self): blob.upload_from_file(file_obj) source_uri = 'gs://%s/%s' % (bucket_name, blob_name) - result = 
self._make_sync_request(source_uri=source_uri, - max_alternatives=1) - self._check_results(result) + results = self._make_sync_request(source_uri=source_uri, + max_alternatives=1) + self.assertEqual(len(results), 1) + self._check_results(results[0].alternatives) def test_async_recognize_local_file(self): with open(AUDIO_FILE, 'rb') as file_obj: @@ -167,7 +171,10 @@ def test_async_recognize_local_file(self): max_alternatives=2) _wait_until_complete(operation) - self._check_results(operation.results, 2) + self.assertEqual(len(operation.results), 1) + alternatives = operation.results[0].alternatives + self.assertEqual(len(alternatives), 2) + self._check_results(alternatives, 2) def test_async_recognize_gcs_file(self): bucket_name = Config.TEST_BUCKET.name @@ -182,7 +189,10 @@ def test_async_recognize_gcs_file(self): max_alternatives=2) _wait_until_complete(operation) - self._check_results(operation.results, 2) + self.assertEqual(len(operation.results), 1) + alternatives = operation.results[0].alternatives + self.assertEqual(len(alternatives), 2) + self._check_results(alternatives, 2) def test_stream_recognize(self): if not Config.USE_GAX: