diff --git a/airflow/providers/microsoft/azure/hooks/data_lake.py b/airflow/providers/microsoft/azure/hooks/data_lake.py
index 054eda087e434..b2d9c5aafa941 100644
--- a/airflow/providers/microsoft/azure/hooks/data_lake.py
+++ b/airflow/providers/microsoft/azure/hooks/data_lake.py
@@ -22,7 +22,7 @@
 
 from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
 from azure.datalake.store import core, lib, multithread
-from azure.identity import ClientSecretCredential
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
 from azure.storage.filedatalake import (
     DataLakeDirectoryClient,
     DataLakeFileClient,
@@ -38,9 +38,10 @@
     AzureIdentityCredentialAdapter,
     add_managed_identity_connection_widgets,
     get_field,
+    get_sync_default_azure_credential,
 )
 
-Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
+Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter, DefaultAzureCredential]
 
 
 class AzureDataLakeHook(BaseHook):
@@ -358,7 +359,7 @@ def get_conn(self) -> DataLakeServiceClient:  # type: ignore[override]
         else:
             managed_identity_client_id = self._get_field(extra, "managed_identity_client_id")
             workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id")
-            credential = AzureIdentityCredentialAdapter(
+            credential = get_sync_default_azure_credential(
                 managed_identity_client_id=managed_identity_client_id,
                 workload_identity_tenant_id=workload_identity_tenant_id,
             )
diff --git a/tests/providers/microsoft/azure/hooks/test_data_factory.py b/tests/providers/microsoft/azure/hooks/test_data_factory.py
index 1ee77ad3aff10..a7d8786fd88c3 100644
--- a/tests/providers/microsoft/azure/hooks/test_data_factory.py
+++ b/tests/providers/microsoft/azure/hooks/test_data_factory.py
@@ -86,8 +86,8 @@ def setup_connections(create_mock_connections):
                 "factory_name": DEFAULT_FACTORY,
             },
         ),
+        # connection_missing_subscription_id
         Connection(
-            # connection_missing_subscription_id
             conn_id="azure_data_factory_missing_subscription_id",
             conn_type="azure_data_factory",
             login="clientId",
@@ -110,6 +110,18 @@ def setup_connections(create_mock_connections):
                 "factory_name": DEFAULT_FACTORY,
             },
         ),
+        # connection_workload_identity
+        Connection(
+            conn_id="azure_data_factory_workload_identity",
+            conn_type="azure_data_factory",
+            extra={
+                "subscriptionId": "subscriptionId",
+                "resource_group_name": DEFAULT_RESOURCE_GROUP,
+                "factory_name": DEFAULT_FACTORY,
+                "workload_identity_tenant_id": "workload_tenant_id",
+                "managed_identity_client_id": "workload_client_id",
+            },
+        ),
     )
 
 
@@ -198,6 +210,21 @@ def test_get_conn_by_default_azure_credential(mock_credential):
         mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
 
 
+@mock.patch(f"{MODULE}.get_sync_default_azure_credential")
+def test_get_conn_with_workload_identity(mock_credential):
+    hook = AzureDataFactoryHook("azure_data_factory_workload_identity")
+    with patch.object(hook, "_create_client") as mock_create_client:
+        mock_create_client.return_value = MagicMock()
+
+        connection = hook.get_conn()
+        assert connection is not None
+        mock_credential.assert_called_once_with(
+            managed_identity_client_id="workload_client_id",
+            workload_identity_tenant_id="workload_tenant_id",
+        )
+        mock_create_client.assert_called_with(mock_credential(), "subscriptionId")
+
+
 def test_get_factory(hook: AzureDataFactoryHook):
     hook.get_factory(RESOURCE_GROUP, FACTORY)
 