diff --git a/airflow/providers/amazon/aws/operators/sagemaker.py b/airflow/providers/amazon/aws/operators/sagemaker.py
index 587bf5b0a16ac..f4041b465a1ed 100644
--- a/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -493,6 +493,8 @@ class SageMakerTransformOperator(SageMakerBaseOperator):
     :param wait_for_completion: Set to True to wait until the transform job finishes.
     :param check_interval: If wait is set to True, the time interval, in seconds,
         that this operation waits to check the status of the transform job.
+    :param max_attempts: Number of times to poll for query state before returning the current state,
+        defaults to None.
     :param max_ingestion_time: If wait is set to True, the operation fails if the transform
         job doesn't finish within max_ingestion_time seconds. If you set this parameter to None,
         the operation does not timeout.
@@ -511,14 +513,17 @@ def __init__(
         aws_conn_id: str = DEFAULT_CONN_ID,
         wait_for_completion: bool = True,
         check_interval: int = CHECK_INTERVAL_SECOND,
+        max_attempts: int | None = None,
         max_ingestion_time: int | None = None,
         check_if_job_exists: bool = True,
         action_if_job_exists: str = "timestamp",
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
         self.wait_for_completion = wait_for_completion
         self.check_interval = check_interval
+        self.max_attempts = max_attempts or 60
         self.max_ingestion_time = max_ingestion_time
         self.check_if_job_exists = check_if_job_exists
         if action_if_job_exists in ("increment", "fail", "timestamp"):
@@ -535,6 +540,7 @@ def __init__(
                 f"Argument action_if_job_exists accepts only 'timestamp', 'increment' and 'fail'. \
                 Provided value: '{action_if_job_exists}'."
             )
+        self.deferrable = deferrable
 
     def _create_integer_fields(self) -> None:
         """Set fields which should be cast to integers."""
@@ -573,21 +579,54 @@ def execute(self, context: Context) -> dict:
                 self.hook.create_model(model_config)
 
         self.log.info("Creating SageMaker transform Job %s.", transform_config["TransformJobName"])
+
+        if self.deferrable and not self.wait_for_completion:
+            self.log.warning(
+                "Setting deferrable to True does not have effect when wait_for_completion is set to False."
+            )
+
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            # Set wait_for_completion to False so that it waits for the status in the deferred task.
+            wait_for_completion = False
+
         response = self.hook.create_transform_job(
             transform_config,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             check_interval=self.check_interval,
             max_ingestion_time=self.max_ingestion_time,
         )
         if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
             raise AirflowException(f"Sagemaker transform Job creation failed: {response}")
-        else:
-            return {
-                "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
-                "Transform": serialize(
-                    self.hook.describe_transform_job(transform_config["TransformJobName"])
+
+        if self.deferrable and self.wait_for_completion:
+            self.defer(
+                timeout=self.execution_timeout,
+                trigger=SageMakerTrigger(
+                    job_name=transform_config["TransformJobName"],
+                    job_type="Transform",
+                    poke_interval=self.check_interval,
+                    max_attempts=self.max_attempts,
+                    aws_conn_id=self.aws_conn_id,
                 ),
-            }
+                method_name="execute_complete",
+            )
+
+        return {
+            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
+            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+        }
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error while running job: {event}")
+        else:
+            self.log.info(event["message"])
+        transform_config = self.config.get("Transform", self.config)
+        return {
+            "Model": serialize(self.hook.describe_model(transform_config["ModelName"])),
+            "Transform": serialize(self.hook.describe_transform_job(transform_config["TransformJobName"])),
+        }
 
 
 class SageMakerTuningOperator(SageMakerBaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
index 482c7201ccc41..76a4d877b6545 100644
--- a/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
+++ b/tests/providers/amazon/aws/operators/test_sagemaker_transform.py
@@ -23,10 +23,11 @@
 import pytest
 from botocore.exceptions import ClientError
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
 from airflow.providers.amazon.aws.operators import sagemaker
 from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator
+from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger
 
 EXPECTED_INTEGER_FIELDS: list[list[str]] = [
     ["Transform", "TransformResources", "InstanceCount"],
@@ -163,3 +164,17 @@ def test_execute_without_check_if_job_exists(self, _, __, ___, mock_transform, _
             check_interval=5,
             max_ingestion_time=None,
         )
+
+    @mock.patch.object(SageMakerHook, "create_transform_job")
+    @mock.patch.object(SageMakerHook, "create_model")
+    def test_operator_defer(self, _, mock_transform):
+        mock_transform.return_value = {
+            "TransformJobArn": "test_arn",
+            "ResponseMetadata": {"HTTPStatusCode": 200},
+        }
+        self.sagemaker.deferrable = True
+        self.sagemaker.wait_for_completion = True
+        self.sagemaker.check_if_job_exists = False
+        with pytest.raises(TaskDeferred) as exc:
+            self.sagemaker.execute(context=None)
+        assert isinstance(exc.value.trigger, SageMakerTrigger), "Trigger is not a SagemakerTrigger"
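Not part of the diff itself, but as a quick illustration of the new flag: a minimal usage sketch of the operator in deferrable mode. The task id, job name, and interval values below are hypothetical placeholders; only the parameters shown in the diff are assumed to exist.

# Usage sketch: run the transform job in deferrable mode so the worker slot is
# released while SageMakerTrigger polls the job status.
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTransformOperator

transform = SageMakerTransformOperator(
    task_id="transform",  # hypothetical task id
    config={
        # Placeholder: a real config needs the full CreateTransformJob payload
        # (ModelName, TransformInput, TransformOutput, TransformResources, ...)
        # and may include a "Model" section so the operator creates the model first.
        "TransformJobName": "example-transform-job",
    },
    wait_for_completion=True,  # deferral only takes effect when this is True
    deferrable=True,           # hand polling off to SageMakerTrigger instead of blocking a worker
    check_interval=30,         # passed to the trigger as its poke interval, in seconds
    max_attempts=120,          # number of trigger polls; falls back to 60 when left as None
)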