Add deferrable param in SageMakerTransformOperator (#31063)
This allows running SageMakerTransformOperator in an asynchronous
fashion: the worker only submits the transform job and then defers
to the trigger, which polls the job status until it reaches a
terminal state. This way, the worker slot is not occupied for the
whole duration of task execution.
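
For illustration, here is a minimal sketch of how the new parameter might be used in a DAG. The DAG id, job and model names, S3 paths, and instance settings below are placeholders, not part of this commit.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

# Placeholder transform-job config; names, paths, and sizes are illustrative only.
TRANSFORM_CONFIG = {
    "TransformJobName": "example-transform-job",
    "ModelName": "example-model",
    "TransformInput": {
        "DataSource": {
            "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": "s3://example-bucket/input/"}
        }
    },
    "TransformOutput": {"S3OutputPath": "s3://example-bucket/output/"},
    "TransformResources": {"InstanceType": "ml.m5.large", "InstanceCount": 1},
}

with DAG("example_sagemaker_transform_deferrable", start_date=datetime(2023, 5, 1), schedule=None):
    transform = SageMakerTransformOperator(
        task_id="transform",
        config=TRANSFORM_CONFIG,
        # Submit the job, release the worker slot, and let the trigger poll
        # until the transform job reaches a terminal state.
        deferrable=True,
        wait_for_completion=True,  # deferral only takes effect when this is True
        check_interval=30,  # seconds between status polls
        max_attempts=120,  # number of polls before the trigger gives up
    )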
pankajkoti authored May 9, 2023
1 parent b8f7376 commit 4c9b5fe
Showing 2 changed files with 62 additions and 8 deletions.
53 changes: 46 additions & 7 deletions airflow/providers/amazon/aws/operators/sagemaker.py
@@ -493,6 +493,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     :param wait_for_completion: Set to True to wait until the transform job finishes.
     :param check_interval: If wait is set to True, the time interval, in seconds,
         that this operation waits to check the status of the transform job.
+    :param max_attempts: Number of times to poll for query state before returning the current state,
+        defaults to None.
     :param max_ingestion_time: If wait is set to True, the operation fails
         if the transform job doesn't finish within max_ingestion_time seconds. If you
         set this parameter to None, the operation does not timeout.
@@ -511,14 +513,17 @@ def __init__(
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        max_attempts: int | None = None,
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.check_if_job_exists = check_if_job_exists
         if action_if_job_exists in ("increment", "fail", "timestamp"):
@@ -535,6 +540,7 @@ def __init__(
                 f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
+        self.deferrable = deferrable
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -573,21 +579,54 @@ def execute(self, context: Context) -> dict:
             self.hook.create_model(model_config)
 
         self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"])
+
+        if self.deferrable and not self.wait_for_completion:
+            self.log.warning(
+                "Setting deferrable to True does not have effect when wait_for_completion is set to False."
+            )
+
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            # Set wait_for_completion to False so that it waits for the status in the deferred task.
+            wait_for_completion = False
+
         response = self.hook.create_transform_job(
             transform_config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
-        else:
-            return {
-                "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
-                "Transform": serialize(
-                    self.hook.describe_transform_job(transform_config["TransformJobName"])
+
+        if self.deferrable and self.wait_for_completion:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=SageMakerTrigger(
+                    job_name=transform_config["TransformJobName"],
+                    job_type="Transform",
+                    poke_interval=self.check_interval,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
                 ),
-            }
+                method_name="execute_complete",
+            )
+
+        return {
+            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
+            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+        }
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info(event["message"])
+        transform_config = self.config.get("Transform", self.config)
+        return {
+            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
+            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+        }
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
17 changes: 16 additions & 1 deletion tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -23,10 +23,11 @@
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -163,3 +164,17 @@ def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, _
             check_interval=5,
             max_ingestion_time=None,
         )
+
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    def test_operator_defer(self, _, mock_transform):
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.sagemaker.execute(context=None)
+        assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"