Skip to content

Commit

Permalink
add support for heterogeneous list of good and bad responses
Browse files Browse the repository at this point in the history
  • Loading branch information
iscai-msft committed Feb 4, 2021
1 parent c1f7bdc commit 20eee41
Show file tree
Hide file tree
Showing 7 changed files with 694 additions and 517 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
AnalyzeBatchActionsResult,
RequestStatistics,
AnalyzeBatchActionsType,
AnalyzeBatchActionsError,
)
from._paging import AnalyzeHealthcareResult

Expand Down Expand Up @@ -83,6 +84,7 @@
'AnalyzeBatchActionsResult',
'RequestStatistics',
'AnalyzeBatchActionsType',
"AnalyzeBatchActionsError",
]

__version__ = VERSION
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,7 @@ class AnalyzeBatchActionsType(str, Enum):
RECOGNIZE_PII_ENTITIES = "recognize_pii_entities" #: PII Entities Recognition action.
EXTRACT_KEY_PHRASES = "extract_key_phrases" #: Key Phrase Extraction action.


class AnalyzeBatchActionsResult(DictMixin):
"""AnalyzeBatchActionsResult contains the results of a recognize entities action
on a list of documents. Returned by `begin_analyze_batch_actions`
Expand Down Expand Up @@ -1200,6 +1201,31 @@ def __repr__(self):
self.completed_on
)[:1024]

class AnalyzeBatchActionsError(DictMixin):
    """AnalyzeBatchActionsError is an error object which represents an
    error response for an action.

    :ivar error: The action result error.
    :vartype error: ~azure.ai.textanalytics.TextAnalyticsError
    :ivar bool is_error: Boolean check for error item when iterating over list of
        results. Always True for an instance of an AnalyzeBatchActionsError.
    """

    def __init__(self, **kwargs):
        self.error = kwargs.get("error")
        # Mirrors DocumentError.is_error so callers can filter mixed result lists.
        self.is_error = True

    def __repr__(self):
        # Close the parenthesis and truncate to 1024 chars, consistent with
        # the other __repr__ implementations in this module.
        return "AnalyzeBatchActionsError(error={}, is_error={})".format(
            self.error, self.is_error
        )[:1024]

    @classmethod
    def _from_generated(cls, error):
        # Convert a generated-layer error into the public TextAnalyticsError shape.
        return cls(
            error=TextAnalyticsError(code=error.code, message=error.message, target=error.target)
        )


class RecognizeEntitiesAction(DictMixin):
"""RecognizeEntitiesAction encapsulates the parameters for starting a long-running Entities Recognition operation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import json
import functools
from collections import defaultdict
from six.moves.urllib.parse import urlparse, parse_qsl
from azure.core.exceptions import (
HttpResponseError,
Expand Down Expand Up @@ -33,7 +34,9 @@
AnalyzeBatchActionsResult,
RequestStatistics,
AnalyzeBatchActionsType,
AnalyzeBatchActionsError,
TextDocumentBatchStatistics,
_get_indices,
)
from ._paging import AnalyzeHealthcareResult, AnalyzeResult

Expand Down Expand Up @@ -217,23 +220,61 @@ def _num_tasks_in_current_page(returned_tasks_object):
len(returned_tasks_object.key_phrase_extraction_tasks or [])
)

def _get_task_type_from_error(error):
    """Infer which action type an error belongs to from its target string.

    The service encodes the task name in ``error.target``; we match on
    substrings, defaulting to key phrase extraction when neither PII nor
    entity recognition is indicated.
    """
    target = error.target.lower()
    if "pii" in target:
        return AnalyzeBatchActionsType.RECOGNIZE_PII_ENTITIES
    if "entity" in target:
        return AnalyzeBatchActionsType.RECOGNIZE_ENTITIES
    return AnalyzeBatchActionsType.EXTRACT_KEY_PHRASES

def _get_mapped_errors(analyze_job_state):
    """Group job-level errors by action type.

    Returns a defaultdict mapping each AnalyzeBatchActionsType to a list of
    ``(task_index, error)`` tuples; empty when the job reports no errors.
    """
    errors_by_task_type = defaultdict(list)
    for error in analyze_job_state.errors or []:
        task_type = _get_task_type_from_error(error)
        errors_by_task_type[task_type].append((_get_error_index(error), error))
    return errors_by_task_type

def _get_error_index(error):
    """Return the task index embedded in the error's target string (its last numeric index)."""
    indices = _get_indices(error.target)
    return indices[-1]

def _get_good_result(current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object):
    """Deserialize one successful task response into an AnalyzeBatchActionsResult.

    Looks up the task payload by its per-type index on the returned tasks
    object, runs the type-specific deserialization callback over it, and
    wraps the documents with the action type and completion timestamp.
    """
    property_name = _get_property_name_from_task_type(current_task_type)
    task_response = getattr(returned_tasks_object, property_name)[index_of_task_result]
    deserialize = _get_deserialization_callback_from_task_type(current_task_type)
    documents = deserialize(doc_id_order, task_response.results, response_headers, lro=True)
    return AnalyzeBatchActionsResult(
        document_results=documents,
        action_type=current_task_type,
        completed_on=task_response.last_update_date_time,
    )

def get_iter_items(doc_id_order, task_order, response_headers, analyze_job_state):
    """Build the ordered list of per-action results, mixing good and bad responses.

    For each action type in ``task_order`` we either emit an
    AnalyzeBatchActionsError (when the job reported an error for that task
    index) or deserialize the successful task response. The scraped diff had
    interleaved the removed and added lines here; this is the coherent new
    implementation.
    """
    iter_items = []
    task_type_to_index = defaultdict(int)  # need to keep track of how many of each type of tasks we've seen
    returned_tasks_object = analyze_job_state.tasks
    mapped_errors = _get_mapped_errors(analyze_job_state)
    for current_task_type in task_order:
        index_of_task_result = task_type_to_index[current_task_type]

        try:
            # Try to match this task index against the reported errors first.
            # If no error matches (StopIteration), we know the response is good.
            current_task_type_errors = mapped_errors[current_task_type]
            error = next(err for err in current_task_type_errors if err[0] == index_of_task_result)
            result = AnalyzeBatchActionsError._from_generated(error[1])  # pylint: disable=protected-access
        except StopIteration:
            result = _get_good_result(
                current_task_type, index_of_task_result, doc_id_order, response_headers, returned_tasks_object
            )
        iter_items.append(result)
        task_type_to_index[current_task_type] += 1
    return iter_items

def analyze_extract_page_data(doc_id_order, task_order, response_headers, analyze_job_state):
Expand Down
Loading

0 comments on commit 20eee41

Please sign in to comment.