From 2438abdbf9fdffcea7224e20dffb9fb1d5178cc9 Mon Sep 17 00:00:00 2001
From: Miles Yucht <miles@databricks.com>
Date: Mon, 16 Sep 2024 08:52:51 -0400
Subject: [PATCH] [Fix] Do not specify --tenant flag when fetching managed
 identity access token from the CLI (#748)

## Changes
Ports https://github.com/databricks/databricks-sdk-go/pull/1021 to the
Python SDK.

The Azure CLI's az account get-access-token command does not allow
specifying --tenant flag if it is authenticated via the CLI.

Fixes #742.

## Tests
Unit tests ensure that all expected cases are treated as managed
identities.

- [ ] `make test` run locally
- [ ] `make fmt` applied
- [ ] relevant integration tests applied
---
 databricks/sdk/credentials_provider.py | 44 +++++++++++++++++++++++---
 tests/test_auth_manual_tests.py        | 12 +++++++
 tests/testdata/az                      | 32 +++++++++++++++++--
 tests/testdata/windows/az.ps1          | 28 ++++++++++++++++
 4 files changed, 109 insertions(+), 7 deletions(-)

diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py
index 860a06ce..8c1655af 100644
--- a/databricks/sdk/credentials_provider.py
+++ b/databricks/sdk/credentials_provider.py
@@ -411,10 +411,7 @@ def _parse_expiry(expiry: str) -> datetime:
 
     def refresh(self) -> Token:
         try:
-            is_windows = sys.platform.startswith('win')
-            # windows requires shell=True to be able to execute 'az login' or other commands
-            # cannot use shell=True all the time, as it breaks macOS
-            out = subprocess.run(self._cmd, capture_output=True, check=True, shell=is_windows)
+            out = _run_subprocess(self._cmd, capture_output=True, check=True)
             it = json.loads(out.stdout.decode())
             expires_on = self._parse_expiry(it[self._expiry_field])
             return Token(access_token=it[self._access_token_field],
@@ -429,6 +426,26 @@ def refresh(self) -> Token:
             raise IOError(f'cannot get access token: {message}') from e
 
 
+def _run_subprocess(popenargs,
+                    input=None,
+                    capture_output=True,
+                    timeout=None,
+                    check=False,
+                    **kwargs) -> subprocess.CompletedProcess:
+    """Runs subprocess with given arguments.
+    This handles OS-specific modifications that need to be made to the invocation of subprocess.run."""
+    kwargs['shell'] = sys.platform.startswith('win')
+    # windows requires shell=True to be able to execute 'az login' or other commands
+    # cannot use shell=True all the time, as it breaks macOS
+    logging.debug(f'Running command: {" ".join(popenargs)}')
+    return subprocess.run(popenargs,
+                          input=input,
+                          capture_output=capture_output,
+                          timeout=timeout,
+                          check=check,
+                          **kwargs)
+
+
 class AzureCliTokenSource(CliTokenSource):
     """ Obtain the token granted by `az login` CLI command """
 
@@ -437,13 +454,30 @@ def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Op
         if subscription is not None:
             cmd.append("--subscription")
             cmd.append(subscription)
-        if tenant:
+        if tenant and not self.__is_cli_using_managed_identity():
             cmd.extend(["--tenant", tenant])
         super().__init__(cmd=cmd,
                          token_type_field='tokenType',
                          access_token_field='accessToken',
                          expiry_field='expiresOn')
 
+    @staticmethod
+    def __is_cli_using_managed_identity() -> bool:
+        """Checks whether the current CLI session is authenticated using managed identity."""
+        try:
+            cmd = ["az", "account", "show", "--output", "json"]
+            out = _run_subprocess(cmd, capture_output=True, check=True)
+            account = json.loads(out.stdout.decode())
+            user = account.get("user")
+            if user is None:
+                return False
+            return user.get("type") == "servicePrincipal" and user.get("name") in [
+                'systemAssignedIdentity', 'userAssignedIdentity'
+            ]
+        except subprocess.CalledProcessError as e:
+            logger.debug("Failed to get account information from Azure CLI", exc_info=e)
+            return False
+
     def is_human_user(self) -> bool:
         """The UPN claim is the username of the user, but not the Service Principal.
 
diff --git a/tests/test_auth_manual_tests.py b/tests/test_auth_manual_tests.py
index 34aa3a9c..8c58dd6b 100644
--- a/tests/test_auth_manual_tests.py
+++ b/tests/test_auth_manual_tests.py
@@ -1,3 +1,5 @@
+import pytest
+
 from databricks.sdk.core import Config
 
 from .conftest import set_az_path, set_home
@@ -60,3 +62,13 @@ def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant):
                  host='https://adb-123.4.azuredatabricks.net',
                  azure_workspace_resource_id=resource_id)
     assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()
+
+
+@pytest.mark.parametrize('username', ['systemAssignedIdentity', 'userAssignedIdentity'])
+def test_azure_cli_does_not_specify_tenant_id_with_msi(monkeypatch, username):
+    set_home(monkeypatch, '/testdata/azure')
+    set_az_path(monkeypatch)
+    monkeypatch.setenv('FAIL_IF_TENANT_ID_SET', 'true')
+    monkeypatch.setenv('AZ_USER_NAME', username)
+    monkeypatch.setenv('AZ_USER_TYPE', 'servicePrincipal')
+    cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', azure_tenant_id='abc')
diff --git a/tests/testdata/az b/tests/testdata/az
index 5bf43a66..7437babc 100755
--- a/tests/testdata/az
+++ b/tests/testdata/az
@@ -1,7 +1,20 @@
 #!/bin/bash
 
-if [ -n "$WARN" ]; then
-    >&2 /bin/echo "WARNING: ${WARN}"
+# If the arguments are "account show", return the account details.
+if [ "$1" == "account" ] && [ "$2" == "show" ]; then
+    /bin/echo "{
+    \"environmentName\": \"AzureCloud\",
+    \"id\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\",
+    \"isDefault\": true,
+    \"name\": \"Pay-As-You-Go\",
+    \"state\": \"Enabled\",
+    \"tenantId\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\",
+    \"user\": {
+        \"name\": \"${AZ_USER_NAME:-testuser@databricks.com}\",
+        \"type\": \"${AZ_USER_TYPE:-user}\"
+    }
+}"
+    exit 0
 fi
 
 if [ "yes" == "$FAIL" ]; then
@@ -26,6 +39,21 @@ for arg in "$@"; do
     fi
 done
 
+# Add character to file at $COUNT if it is defined.
+if [ -n "$COUNT" ]; then
+    echo -n x >> "$COUNT"
+fi
+
+# If FAIL_IF_TENANT_ID_SET is set & --tenant-id is passed, fail.
+if [ -n "$FAIL_IF_TENANT_ID_SET" ]; then
+    for arg in "$@"; do
+        if [[ "$arg" == "--tenant" ]]; then
+            echo 1>&2 "ERROR: Tenant shouldn't be specified for managed identity account"
+            exit 1
+        fi
+    done
+fi
+
 # Macos
 EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)"
 if [ -z "${EXP}" ]; then
diff --git a/tests/testdata/windows/az.ps1 b/tests/testdata/windows/az.ps1
index 4aa96adf..97ecbca7 100644
--- a/tests/testdata/windows/az.ps1
+++ b/tests/testdata/windows/az.ps1
@@ -1,5 +1,23 @@
 #!/usr/bin/env pwsh
 
+# If the arguments are "account show", return the account details.
+if ($args[0] -eq "account" -and $args[1] -eq "show") {
+    $output = @{
+        environmentName = "AzureCloud"
+        id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
+        isDefault = $true
+        name = "Pay-As-You-Go"
+        state = "Enabled"
+        tenantId = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
+        user = @{
+            name = if ($env:AZ_USER_NAME) { $env:AZ_USER_NAME } else { "testuser@databricks.com" }
+            type = if ($env:AZ_USER_TYPE) { $env:AZ_USER_TYPE } else { "user" }
+        }
+    }
+    $output | ConvertTo-Json
+    exit 0
+}
+
 if ($env:WARN) {
     Write-Error "WARNING: $env:WARN"
 }
@@ -30,6 +48,16 @@ foreach ($arg in $Args) {
     }
 }
 
+# If FAIL_IF_TENANT_ID_SET is set & --tenant-id is passed, fail.
+if ($env:FAIL_IF_TENANT_ID_SET) {
+    foreach ($arg in $args) {
+        if ($arg -eq "--tenant-id" -or $arg -like "--tenant*") {
+            Write-Error "ERROR: Tenant shouldn't be specified for managed identity account"
+            exit 1
+        }
+    }
+}
+
 try {
     $EXP = (Get-Date).AddSeconds($env:EXPIRE -as [int])
 } catch {