Skip to content

Commit

Permalink
[form recognizer] Remove unnecessary code (#14257)
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft authored Oct 5, 2020
1 parent 0c2e175 commit 4cabc99
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
TYPE_CHECKING
)
from azure.core.tracing.decorator import distributed_trace
from azure.core.polling import LROPoller
from azure.core.polling.base_polling import LROBasePolling

from ._response_handlers import (
prepare_receipt,
prepare_content_result,
prepare_form_result
)
from ._helpers import get_content_type, error_map
from ._helpers import get_content_type
from ._form_base_client import FormRecognizerClientBase
from ._polling import AnalyzePolling
if TYPE_CHECKING:
from azure.core.polling import LROPoller
from ._models import FormPage, RecognizedForm


Expand Down Expand Up @@ -104,14 +105,11 @@ def begin_recognize_receipts(self, receipt, **kwargs):
:caption: Recognize US sales receipt fields.
"""
locale = kwargs.pop("locale", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
include_field_elements = kwargs.pop("include_field_elements", False)
if content_type == "application/json":
raise TypeError("Call begin_recognize_receipts_from_url() to analyze a receipt from a URL.")
cls = kwargs.pop("cls", self._receipt_callback)
polling = LROBasePolling(timeout=polling_interval, **kwargs)
if content_type is None:
content_type = get_content_type(receipt)

Expand All @@ -123,9 +121,7 @@ def begin_recognize_receipts(self, receipt, **kwargs):
content_type=content_type,
include_text_details=include_field_elements,
cls=cls,
polling=polling,
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -161,20 +157,15 @@ def begin_recognize_receipts_from_url(self, receipt_url, **kwargs):
:caption: Recognize US sales receipt fields from a URL.
"""
locale = kwargs.pop("locale", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
include_field_elements = kwargs.pop("include_field_elements", False)
cls = kwargs.pop("cls", self._receipt_callback)
polling = LROBasePolling(timeout=polling_interval, **kwargs)
if self.api_version == "2.1-preview.1" and locale:
kwargs.update({"locale": locale})
return self._client.begin_analyze_receipt_async( # type: ignore
file_stream={"source": receipt_url},
include_text_details=include_field_elements,
cls=cls,
polling=polling,
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -212,9 +203,6 @@ def begin_recognize_content(self, form, **kwargs):
:dedent: 8
:caption: Recognize text and content/layout information from a form.
"""

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)
content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
raise TypeError("Call begin_recognize_content_from_url() to analyze a document from a URL.")
Expand All @@ -226,9 +214,7 @@ def begin_recognize_content(self, form, **kwargs):
file_stream=form,
content_type=content_type,
cls=kwargs.pop("cls", self._content_callback),
polling=LROBasePolling(timeout=polling_interval, **kwargs),
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand All @@ -249,15 +235,10 @@ def begin_recognize_content_from_url(self, form_url, **kwargs):
:raises ~azure.core.exceptions.HttpResponseError:
"""

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

return self._client.begin_analyze_layout_async( # type: ignore
file_stream={"source": form_url},
cls=kwargs.pop("cls", self._content_callback),
polling=LROBasePolling(timeout=polling_interval, **kwargs),
error_map=error_map,
continuation_token=continuation_token,
polling=True,
**kwargs
)

Expand Down Expand Up @@ -299,9 +280,8 @@ def begin_recognize_custom_forms(self, model_id, form, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

content_type = kwargs.pop("content_type", None)
if content_type == "application/json":
raise TypeError("Call begin_recognize_custom_forms_from_url() to analyze a document from a URL.")
Expand All @@ -314,16 +294,13 @@ def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argume
analyze_result = self._deserialize(self._generated_models.AnalyzeOperationResult, raw_response)
return prepare_form_result(analyze_result, model_id)

deserialization_callback = cls if cls else analyze_callback
return self._client.begin_analyze_with_custom_model( # type: ignore
file_stream=form,
model_id=model_id,
include_text_details=include_field_elements,
content_type=content_type,
cls=deserialization_callback,
cls=kwargs.pop("cls", analyze_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down Expand Up @@ -351,24 +328,20 @@ def begin_recognize_custom_forms_from_url(self, model_id, form_url, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

cls = kwargs.pop("cls", None)
polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

include_field_elements = kwargs.pop("include_field_elements", False)

def analyze_callback(raw_response, _, headers): # pylint: disable=unused-argument
analyze_result = self._deserialize(self._generated_models.AnalyzeOperationResult, raw_response)
return prepare_form_result(analyze_result, model_id)

deserialization_callback = cls if cls else analyze_callback
return self._client.begin_analyze_with_custom_model( # type: ignore
file_stream={"source": form_url},
model_id=model_id,
include_text_details=include_field_elements,
cls=deserialization_callback,
cls=kwargs.pop("cls", analyze_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[AnalyzePolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@
CopyRequest,
CopyAuthorizationResult
)
from ._helpers import (
error_map,
TransportWrapper
)
from ._helpers import TransportWrapper

from ._models import (
CustomFormModelInfo,
AccountProperties,
Expand Down Expand Up @@ -152,7 +150,6 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
)
),
cls=lambda pipeline_response, _, response_headers: pipeline_response,
error_map=error_map,
**kwargs
) # type: PipelineResponseType

Expand All @@ -176,7 +173,6 @@ def callback_v2_1(raw_response, _, headers): # pylint: disable=unused-argument
cls=deserialization_callback,
continuation_token=continuation_token,
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[TrainingPolling()], **kwargs),
error_map=error_map,
**kwargs
)

Expand Down Expand Up @@ -204,11 +200,7 @@ def delete_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

self._client.delete_custom_model(
model_id=model_id,
error_map=error_map,
**kwargs
)
self._client.delete_custom_model(model_id=model_id, **kwargs)

@distributed_trace
def list_custom_models(self, **kwargs):
Expand All @@ -231,7 +223,6 @@ def list_custom_models(self, **kwargs):
"""
return self._client.list_custom_models( # type: ignore
cls=kwargs.pop("cls", lambda objs: [CustomFormModelInfo._from_generated(x) for x in objs]),
error_map=error_map,
**kwargs
)

Expand All @@ -254,7 +245,7 @@ def get_account_properties(self, **kwargs):
:dedent: 8
:caption: Get properties for the form recognizer account.
"""
response = self._client.get_custom_models(error_map=error_map, **kwargs)
response = self._client.get_custom_models(**kwargs)
return AccountProperties._from_generated(response.summary)

@distributed_trace
Expand All @@ -281,7 +272,7 @@ def get_custom_model(self, model_id, **kwargs):
if not model_id:
raise ValueError("model_id cannot be None or empty.")

response = self._client.get_custom_model(model_id=model_id, include_keys=True, error_map=error_map, **kwargs)
response = self._client.get_custom_model(model_id=model_id, include_keys=True, **kwargs)
return CustomFormModel._from_generated(response)

@distributed_trace
Expand Down Expand Up @@ -314,7 +305,6 @@ def get_copy_authorization(self, resource_id, resource_region, **kwargs):

response = self._client.generate_model_copy_authorization( # type: ignore
cls=lambda pipeline_response, deserialized, response_headers: pipeline_response,
error_map=error_map,
**kwargs
) # type: PipelineResponse
target = json.loads(response.http_response.text())
Expand Down Expand Up @@ -359,9 +349,7 @@ def begin_copy_model(

if not model_id:
raise ValueError("model_id cannot be None or empty.")

polling_interval = kwargs.pop("polling_interval", self._client._config.polling_interval)
continuation_token = kwargs.pop("continuation_token", None)

def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
copy_result = self._deserialize(self._generated_models.CopyOperationResult, raw_response)
Expand All @@ -380,8 +368,6 @@ def _copy_callback(raw_response, _, headers): # pylint: disable=unused-argument
),
cls=kwargs.pop("cls", _copy_callback),
polling=LROBasePolling(timeout=polling_interval, lro_algorithms=[CopyPolling()], **kwargs),
error_map=error_map,
continuation_token=continuation_token,
**kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,10 @@
from azure.core.credentials import AzureKeyCredential
from azure.core.pipeline.policies import AzureKeyCredentialPolicy
from azure.core.pipeline.transport import HttpTransport
from azure.core.exceptions import (
ResourceNotFoundError,
ResourceExistsError,
ClientAuthenticationError
)

POLLING_INTERVAL = 5
COGNITIVE_KEY_HEADER = "Ocp-Apim-Subscription-Key"


error_map = {
404: ResourceNotFoundError,
409: ResourceExistsError,
401: ClientAuthenticationError
}

def _get_deserialize():
from ._generated.v2_1_preview_1 import FormRecognizerClient
return FormRecognizerClient("dummy", "dummy")._deserialize # pylint: disable=protected-access
Expand Down
Loading

0 comments on commit 4cabc99

Please sign in to comment.