diff --git a/airflow/providers/microsoft/azure/sensors/data_factory.py b/airflow/providers/microsoft/azure/sensors/data_factory.py index e98bb9caeea9c..5a18bca6dbe10 100644 --- a/airflow/providers/microsoft/azure/sensors/data_factory.py +++ b/airflow/providers/microsoft/azure/sensors/data_factory.py @@ -94,17 +94,18 @@ def execute(self, context: Context) -> None: if not self.deferrable: super().execute(context=context) else: - self.defer( - timeout=timedelta(seconds=self.timeout), - trigger=ADFPipelineRunStatusSensorTrigger( - run_id=self.run_id, - azure_data_factory_conn_id=self.azure_data_factory_conn_id, - resource_group_name=self.resource_group_name, - factory_name=self.factory_name, - poke_interval=self.poke_interval, - ), - method_name="execute_complete", - ) + if not self.poke(context=context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=ADFPipelineRunStatusSensorTrigger( + run_id=self.run_id, + azure_data_factory_conn_id=self.azure_data_factory_conn_id, + resource_group_name=self.resource_group_name, + factory_name=self.factory_name, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) def execute_complete(self, context: Context, event: dict[str, str]) -> None: """ diff --git a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py index 21451775bd91d..ecc3a7c7b90c4 100644 --- a/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py +++ b/tests/providers/microsoft/azure/sensors/test_azure_data_factory.py @@ -84,15 +84,28 @@ def test_poke(self, mock_pipeline_run, pipeline_run_status, expected_status): with pytest.raises(AzureDataFactoryPipelineRunException, match=error_message): self.sensor.poke({}) - def test_adf_pipeline_status_sensor_async(self): + @mock.patch("airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryHook") + def test_adf_pipeline_status_sensor_async(self, mock_hook): """Assert execute method defer for Azure Data factory pipeline run status sensor""" - + mock_hook.return_value.get_pipeline_run_status.return_value = AzureDataFactoryPipelineRunStatus.QUEUED with pytest.raises(TaskDeferred) as exc: - self.defered_sensor.execute({}) + self.defered_sensor.execute(mock.MagicMock()) assert isinstance( exc.value.trigger, ADFPipelineRunStatusSensorTrigger ), "Trigger is not a ADFPipelineRunStatusSensorTrigger" + @mock.patch("airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryHook") + @mock.patch( + "airflow.providers.microsoft.azure.sensors.data_factory" + ".AzureDataFactoryPipelineRunStatusSensor.defer" + ) + def test_adf_pipeline_status_sensor_finish_before_deferred(self, mock_defer, mock_hook): + mock_hook.return_value.get_pipeline_run_status.return_value = ( + AzureDataFactoryPipelineRunStatus.SUCCEEDED + ) + self.defered_sensor.execute(mock.MagicMock()) + assert not mock_defer.called + def test_adf_pipeline_status_sensor_execute_complete_success(self): """Assert execute_complete log success message when trigger fire with target status""" @@ -115,9 +128,10 @@ class TestAzureDataFactoryPipelineRunStatusAsyncSensor: run_id=RUN_ID, ) - def test_adf_pipeline_status_sensor_async(self): + @mock.patch("airflow.providers.microsoft.azure.sensors.data_factory.AzureDataFactoryHook") + def test_adf_pipeline_status_sensor_async(self, mock_hook): """Assert execute method defer for Azure Data factory pipeline run status sensor""" - + mock_hook.return_value.get_pipeline_run_status.return_value = AzureDataFactoryPipelineRunStatus.QUEUED with pytest.raises(TaskDeferred) as exc: self.SENSOR.execute({}) assert isinstance(