Skip to content

Commit

Permalink
Switch AzureDataLakeStorageV2Hook to use DefaultAzureCredential for managed identity/workload auth (apache#38497)
Browse files Browse the repository at this point in the history
  • Loading branch information
TJaniF authored and fdemiane committed Jun 6, 2024
1 parent c70d842 commit 6ad9255
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
7 changes: 4 additions & 3 deletions airflow/providers/microsoft/azure/hooks/data_lake.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.datalake.store import core, lib, multithread
from azure.identity import ClientSecretCredential
from azure.identity import ClientSecretCredential, DefaultAzureCredential
from azure.storage.filedatalake import (
DataLakeDirectoryClient,
DataLakeFileClient,
Expand All @@ -38,9 +38,10 @@
AzureIdentityCredentialAdapter,
add_managed_identity_connection_widgets,
get_field,
get_sync_default_azure_credential,
)

Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter]
Credentials = Union[ClientSecretCredential, AzureIdentityCredentialAdapter, DefaultAzureCredential]


class AzureDataLakeHook(BaseHook):
Expand Down Expand Up @@ -358,7 +359,7 @@ def get_conn(self) -> DataLakeServiceClient: # type: ignore[override]
else:
managed_identity_client_id = self._get_field(extra, "managed_identity_client_id")
workload_identity_tenant_id = self._get_field(extra, "workload_identity_tenant_id")
credential = AzureIdentityCredentialAdapter(
credential = get_sync_default_azure_credential(
managed_identity_client_id=managed_identity_client_id,
workload_identity_tenant_id=workload_identity_tenant_id,
)
Expand Down
29 changes: 28 additions & 1 deletion tests/providers/microsoft/azure/hooks/test_data_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ def setup_connections(create_mock_connections):
"factory_name": DEFAULT_FACTORY,
},
),
# connection_missing_subscription_id
Connection(
# connection_missing_subscription_id
conn_id="azure_data_factory_missing_subscription_id",
conn_type="azure_data_factory",
login="clientId",
Expand All @@ -110,6 +110,18 @@ def setup_connections(create_mock_connections):
"factory_name": DEFAULT_FACTORY,
},
),
# connection_workload_identity
Connection(
conn_id="azure_data_factory_workload_identity",
conn_type="azure_data_factory",
extra={
"subscriptionId": "subscriptionId",
"resource_group_name": DEFAULT_RESOURCE_GROUP,
"factory_name": DEFAULT_FACTORY,
"workload_identity_tenant_id": "workload_tenant_id",
"managed_identity_client_id": "workload_client_id",
},
),
)


Expand Down Expand Up @@ -198,6 +210,21 @@ def test_get_conn_by_default_azure_credential(mock_credential):
mock_create_client.assert_called_with(mock_credential(), "subscriptionId")


@mock.patch(f"{MODULE}.get_sync_default_azure_credential")
def test_get_conn_with_workload_identity(mock_credential):
    """Verify ``get_conn`` forwards the workload-identity extras to the credential factory.

    The ``azure_data_factory_workload_identity`` connection carries
    ``workload_identity_tenant_id`` and ``managed_identity_client_id`` in its
    extras. With ``get_sync_default_azure_credential`` patched in the hook
    module, the test expects exactly one call to it with those two values, and
    expects the resulting credential to be handed to ``_create_client``
    together with the subscription id.
    """
    workload_hook = AzureDataFactoryHook("azure_data_factory_workload_identity")

    # Stub out client construction so no real Azure client is created.
    with patch.object(workload_hook, "_create_client") as create_client_mock:
        create_client_mock.return_value = MagicMock()

        client = workload_hook.get_conn()
        assert client is not None

        # The credential factory must see both identity settings from the extras.
        mock_credential.assert_called_once_with(
            workload_identity_tenant_id="workload_tenant_id",
            managed_identity_client_id="workload_client_id",
        )

        # Evaluated only after the call-count assertion above, so the extra
        # mock invocation here cannot disturb it; it just yields return_value.
        create_client_mock.assert_called_with(mock_credential(), "subscriptionId")


def test_get_factory(hook: AzureDataFactoryHook):
hook.get_factory(RESOURCE_GROUP, FACTORY)

Expand Down

0 comments on commit 6ad9255

Please sign in to comment.