From 66d9ca06927e67aa5947ef7822631303f2f453b7 Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Mon, 14 Aug 2023 20:28:02 +0200
Subject: [PATCH 1/5] Make notebook-native auth work with more configurations
 of the Databricks Runtime

This PR adds additional logging and hardening to `auth_type='runtime'`,
which performs credential lookup in the following order:

1. `init_runtime_native_auth` for the newest DBR versions.
2. `init_runtime_repl_auth` via the Databricks REPL context and its
   `workspaceUrl`.
3. `init_runtime_legacy_auth` via the IPython context, for legacy runtimes
   and execution modes.

Every detection step logs its outcome at the `DEBUG` level.

Based on publicly accessible code in
https://github.com/mlflow/mlflow/blame/6bd97bde24d78bcfbf6d50c1dd0f4fac2ed6987b/mlflow/utils/databricks_utils.py
---
 databricks/sdk/core.py             | 31 ++++++++----------
 databricks/sdk/runtime/__init__.py | 50 +++++++++++++++++++++++++++++-
 2 files changed, 62 insertions(+), 19 deletions(-)

diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py
index 4127a9853..e992f3f4f 100644
--- a/databricks/sdk/core.py
+++ b/databricks/sdk/core.py
@@ -91,26 +91,21 @@ def inner() -> Dict[str, str]:
 
 @credentials_provider('runtime', [])
 def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
-    from databricks.sdk.runtime import init_runtime_native_auth
-    if init_runtime_native_auth is not None:
-        host, inner = init_runtime_native_auth()
+    from databricks.sdk.runtime import (init_runtime_native_auth,
+                                        init_runtime_repl_auth,
+                                        init_runtime_legacy_auth)
+    for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]:
+        if init is None:
+            logger.debug(f'[{init.__name__}] not applicable')
+            continue
+        host, inner = init()
+        if host is None:
+            logger.debug(f'[{init.__name__}] no host detected')
+            continue
         cfg.host = host
+        logger.debug(f'[{init.__name__}] runtime native auth configured')
         return inner
-    try:
-        from dbruntime.databricks_repl_context import get_context
-        ctx = get_context()
-        if ctx is None:
-            logger.debug('Empty REPL context returned, skipping runtime auth')
-            return None
-        cfg.host = f'https://{ctx.workspaceUrl}'
-
-        def inner() -> Dict[str, str]:
-            ctx = get_context()
-            return {'Authorization': f'Bearer {ctx.apiToken}'}
-
-        return inner
-    except ImportError:
-        return None
+    return None
 
 
 @credentials_provider('oauth-m2m', ['is_aws', 'host', 'client_id', 'client_secret'])
diff --git a/databricks/sdk/runtime/__init__.py b/databricks/sdk/runtime/__init__.py
index 5120c235f..f4c852944 100644
--- a/databricks/sdk/runtime/__init__.py
+++ b/databricks/sdk/runtime/__init__.py
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
-from typing import Union
+import logging
+from typing import Dict, Union
 
+logger = logging.getLogger('databricks.sdk')
 is_local_implementation = True
 
 # All objects that are injected into the Notebook's user namespace should also be made
@@ -16,12 +18,58 @@
     # We don't want to expose additional entity to user namespace, so
     # a workaround here for exposing required information in notebook environment
     from dbruntime.sdk_credential_provider import init_runtime_native_auth
+    logger.debug('runtime SDK credential provider available')
     dbruntime_objects.append("init_runtime_native_auth")
 except ImportError:
     init_runtime_native_auth = None
 
 globals()["init_runtime_native_auth"] = init_runtime_native_auth
 
+
+def init_runtime_repl_auth():
+    try:
+        from dbruntime.databricks_repl_context import get_context
+        ctx = get_context()
+        if ctx is None:
+            logger.debug('Empty REPL context returned, skipping runtime auth')
+            return None, None
+        host = f'https://{ctx.workspaceUrl}'
+
+        def inner() -> Dict[str, str]:
+            ctx = get_context()
+            return {'Authorization': f'Bearer {ctx.apiToken}'}
+
+        return host, inner
+    except ImportError:
+        return None, None
+
+
+def init_runtime_legacy_auth():
+    try:
+        import IPython
+        ip_shell = IPython.get_ipython()
+        if ip_shell is None:
+            return None, None
+        global_ns = ip_shell.ns_table["user_global"]
+        if 'dbutils' not in global_ns:
+            return None, None
+        dbutils = global_ns["dbutils"].notebook.entry_point.getDbutils()
+        if dbutils is None:
+            return None, None
+        ctx = dbutils.notebook().getContext()
+        if ctx is None:
+            return None, None
+        host = getattr(ctx, 'apiUrl')().get()
+
+        def inner() -> Dict[str, str]:
+            ctx = dbutils.notebook().getContext()
+            return {'Authorization': f'Bearer {getattr(ctx, "apiToken")().get()}'}
+
+        return host, inner
+    except ImportError:
+        return None, None
+
+
 try:
     # Internal implementation
     # Separated from above for backward compatibility
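Note: with this first commit in place, the probe order can be observed from a
notebook by enabling debug logging before the client is constructed. A minimal
sketch; the `databricks.sdk` logger name comes from the hunk above, everything
else is standard library, and the user lookup mirrors the notebook used in the
integration tests further down:

    import logging

    from databricks.sdk import WorkspaceClient

    # Surface the '[init_runtime_*] ...' DEBUG lines added in this patch.
    logging.basicConfig(level=logging.DEBUG)

    w = WorkspaceClient()  # on DBR, the 'runtime' provider tries the three init_* hooks in order
    print(w.current_user.me().user_name)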
From bc6cb45f62ae96cf2ff20c3ecae28c1bae5078e4 Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Mon, 14 Aug 2023 20:35:05 +0200
Subject: [PATCH 2/5] fix tests

---
 databricks/sdk/core.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py
index e992f3f4f..79c89b78f 100644
--- a/databricks/sdk/core.py
+++ b/databricks/sdk/core.py
@@ -91,12 +91,11 @@ def inner() -> Dict[str, str]:
 
 @credentials_provider('runtime', [])
 def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
-    from databricks.sdk.runtime import (init_runtime_native_auth,
-                                        init_runtime_repl_auth,
-                                        init_runtime_legacy_auth)
+    from databricks.sdk.runtime import (init_runtime_legacy_auth,
+                                        init_runtime_native_auth,
+                                        init_runtime_repl_auth)
     for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]:
         if init is None:
-            logger.debug(f'[{init.__name__}] not applicable')
             continue
         host, inner = init()
         if host is None:

From 5f1b201536b9e4e338e6dfae94339c00ff6d55a9 Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Mon, 14 Aug 2023 20:54:08 +0200
Subject: [PATCH 3/5] quick path to exit

---
 databricks/sdk/core.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py
index 79c89b78f..508ada430 100644
--- a/databricks/sdk/core.py
+++ b/databricks/sdk/core.py
@@ -94,6 +94,8 @@ def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]:
     from databricks.sdk.runtime import (init_runtime_legacy_auth,
                                         init_runtime_native_auth,
                                         init_runtime_repl_auth)
+    if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
+        return None
     for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]:
         if init is None:
             continue
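Note on the two commits above: the debug call removed in PATCH 2 evaluated
`init.__name__` exactly when `init is None`, so it raised `AttributeError`
instead of logging; dropping it lets the loop skip missing hooks silently.
PATCH 3 then short-circuits the provider off-platform by checking
`DATABRICKS_RUNTIME_VERSION`, an environment variable set on Databricks
Runtime nodes. A standalone sketch of the resulting control flow; the function
and variable names here are illustrative, not the SDK's:

    import os

    def detect(probes):
        # Quick path to exit: DATABRICKS_RUNTIME_VERSION is only present on DBR nodes.
        if 'DATABRICKS_RUNTIME_VERSION' not in os.environ:
            return None
        for probe in probes:
            if probe is None:
                # Reading probe.__name__ here would raise AttributeError on None.
                continue
            host, inner = probe()
            if host is None:
                continue
            return host, inner
        return None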
From f94ca5b4e2ffe20f3fd78da6ef91ec255d0fbc1c Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Thu, 17 Aug 2023 12:30:55 +0200
Subject: [PATCH 4/5] wip

---
 tests/integration/test_auth.py | 92 ++++++++++++++++++++++++++++++++++
 1 file changed, 92 insertions(+)
 create mode 100644 tests/integration/test_auth.py

diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py
new file mode 100644
index 000000000..2928eb66b
--- /dev/null
+++ b/tests/integration/test_auth.py
@@ -0,0 +1,92 @@
+import base64
+import json
+import shutil
+import subprocess
+import sys
+import urllib.parse
+from pathlib import Path
+
+import io
+import pytest
+
+from databricks.sdk.service.compute import ClusterSpec, Library
+from databricks.sdk.service.workspace import Language
+from databricks.sdk.service.jobs import Task, NotebookTask, ViewType
+
+
+@pytest.fixture
+def fresh_wheel_file(tmp_path) -> Path:
+    this_file = Path(__file__)
+    project_root = this_file.parent.parent.parent.absolute()
+    build_root = tmp_path / 'databricks-sdk-py'
+    shutil.copytree(project_root, build_root)
+    try:
+        completed_process = subprocess.run(
+            [sys.executable, 'setup.py', 'bdist_wheel'],
+            capture_output=True,
+            cwd=build_root)
+        if completed_process.returncode != 0:
+            raise RuntimeError(completed_process.stderr)
+
+        from databricks.sdk.version import __version__
+        filename = f'databricks_sdk-{__version__}-py3-none-any.whl'
+        wheel_file = build_root / 'dist' / filename
+
+        return wheel_file
+    except subprocess.CalledProcessError as e:
+        raise RuntimeError(e.stderr)
+
+
+def test_runtime_auth(w, fresh_wheel_file, env_or_skip, random):
+    instance_pool_id = env_or_skip('TEST_INSTANCE_POOL_ID')
+
+    v = w.clusters.spark_versions()
+    lts_runtimes = [x for x in v.versions if 'LTS' in x.name
+                    and '-ml' not in x.key
+                    and '-photon' not in x.key]
+
+    dbfs_wheel = f'/tmp/wheels/{random(10)}/{fresh_wheel_file.name}'
+    with fresh_wheel_file.open('rb') as f:
+        w.dbfs.upload(dbfs_wheel, f)
+
+    notebook_path = f'/Users/{w.current_user.me().user_name}/notebook-native-auth'
+    notebook_content = io.BytesIO(b'''
+from databricks.sdk import WorkspaceClient
+w = WorkspaceClient()
+me = w.current_user.me()
+print(me.user_name)''')
+    w.workspace.upload(notebook_path, notebook_content,
+                       language=Language.PYTHON,
+                       overwrite=True)
+
+    tasks = []
+    for v in lts_runtimes:
+        t = Task(task_key=f'test_{v.key.replace(".", "_")}',
+                 notebook_task=NotebookTask(notebook_path=notebook_path),
+                 new_cluster=ClusterSpec(spark_version=v.key,
+                                         num_workers=1,
+                                         instance_pool_id=instance_pool_id),
+                 libraries=[Library(whl=f'dbfs:{dbfs_wheel}')])
+        tasks.append(t)
+    w.jobs.create(tasks=tasks, name=f'Runtime Native Auth {random(10)}')
+
+    print(v)
+
+
+def test_job_output(w):
+    # workflow_runs = w.jobs.list_runs(job_id=133270013770420)
+    this_run = w.jobs.get_run(20504473)
+
+    import re
+    notebook_model = re.compile(r"var __DATABRICKS_NOTEBOOK_MODEL = '(.*)';", re.MULTILINE)
+
+    for task_run in this_run.tasks:
+        print(task_run.task_key)
+        run_output = w.jobs.export_run(task_run.run_id)
+        for view in run_output.views:
+            if view.type != ViewType.NOTEBOOK:
+                continue
+            for b64 in notebook_model.findall(view.content):
+                url_encoded: bytes = base64.b64decode(b64)
+                json_encoded = urllib.parse.unquote(str(url_encoded))
+                x = json.loads(json_encoded)
+                print(x)
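Note: `test_job_output` above decodes the `__DATABRICKS_NOTEBOOK_MODEL`
payload with `str(url_encoded)`, which renders the bytes as a `b'...'` literal
and breaks `json.loads`; PATCH 5 below switches to
`url_encoded.decode('utf-8')` and folds the logic into a `_task_outputs`
helper. The difference can be reproduced locally, assuming the payload is
URL-encoded JSON wrapped in base64, which is what the decode chain implies:

    import base64
    import json
    import urllib.parse

    # Encode a toy notebook model the way the export appears to
    # (the inverse of the test's decode chain).
    model = {'commands': [{'results': {'data': 'some-user@example.com\n'}}]}
    blob = base64.b64encode(urllib.parse.quote(json.dumps(model)).encode('utf-8'))

    raw: bytes = base64.b64decode(blob)
    # PATCH 5 variant round-trips cleanly:
    assert json.loads(urllib.parse.unquote(raw.decode('utf-8'))) == model
    # PATCH 4 variant keeps the b'...' literal around the payload, so json.loads would fail:
    assert str(raw) != raw.decode('utf-8')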
From 7a122fcaf16d874b1202ea0a40197c3c815a4cf1 Mon Sep 17 00:00:00 2001
From: Serge Smertin
Date: Thu, 17 Aug 2023 15:17:24 +0200
Subject: [PATCH 5/5] Add integration tests

---
 tests/integration/test_auth.py | 113 ++++++++++++++++++++++++---------
 1 file changed, 84 insertions(+), 29 deletions(-)

diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py
index 2928eb66b..6458056b7 100644
--- a/tests/integration/test_auth.py
+++ b/tests/integration/test_auth.py
@@ -1,17 +1,20 @@
 import base64
+import io
 import json
+import re
 import shutil
 import subprocess
 import sys
 import urllib.parse
+from functools import partial
 from pathlib import Path
 
-import io
 import pytest
 
-from databricks.sdk.service.compute import ClusterSpec, Library
-from databricks.sdk.service.workspace import Language
-from databricks.sdk.service.jobs import Task, NotebookTask, ViewType
+from databricks.sdk.service.compute import (ClusterSpec, DataSecurityMode,
+                                            Library, ResultType)
+from databricks.sdk.service.jobs import NotebookTask, Task, ViewType
+from databricks.sdk.service.workspace import ImportFormat
 
 
 @pytest.fixture
@@ -21,10 +24,9 @@ def fresh_wheel_file(tmp_path) -> Path:
     build_root = tmp_path / 'databricks-sdk-py'
     shutil.copytree(project_root, build_root)
     try:
-        completed_process = subprocess.run(
-            [sys.executable, 'setup.py', 'bdist_wheel'],
-            capture_output=True,
-            cwd=build_root)
+        completed_process = subprocess.run([sys.executable, 'setup.py', 'bdist_wheel'],
+                                           capture_output=True,
+                                           cwd=build_root)
         if completed_process.returncode != 0:
             raise RuntimeError(completed_process.stderr)
 
@@ -37,27 +39,73 @@ def fresh_wheel_file(tmp_path) -> Path:
         raise RuntimeError(e.stderr)
 
 
-def test_runtime_auth(w, fresh_wheel_file, env_or_skip, random):
+@pytest.mark.parametrize("mode", [DataSecurityMode.SINGLE_USER, DataSecurityMode.USER_ISOLATION])
+def test_runtime_auth_from_interactive_on_uc(ucws, fresh_wheel_file, env_or_skip, random, mode):
+    instance_pool_id = env_or_skip('TEST_INSTANCE_POOL_ID')
+    latest = ucws.clusters.select_spark_version(latest=True, beta=True)
+
+    my_user = ucws.current_user.me().user_name
+
+    workspace_location = f'/Users/{my_user}/wheels/{random(10)}'
+    ucws.workspace.mkdirs(workspace_location)
+
+    wsfs_wheel = f'{workspace_location}/{fresh_wheel_file.name}'
+    with fresh_wheel_file.open('rb') as f:
+        ucws.workspace.upload(wsfs_wheel, f, format=ImportFormat.AUTO)
+
+    from databricks.sdk.service.compute import Language
+    interactive_cluster = ucws.clusters.create(cluster_name=f'native-auth-on-{mode.name}',
+                                               spark_version=latest,
+                                               instance_pool_id=instance_pool_id,
+                                               autotermination_minutes=10,
+                                               num_workers=1,
+                                               data_security_mode=mode).result()
+    ctx = ucws.command_execution.create(cluster_id=interactive_cluster.cluster_id,
+                                        language=Language.PYTHON).result()
+    run = partial(ucws.command_execution.execute,
+                  cluster_id=interactive_cluster.cluster_id,
+                  context_id=ctx.id,
+                  language=Language.PYTHON)
+    try:
+        res = run(command=f"%pip install /Workspace{wsfs_wheel}\ndbutils.library.restartPython()").result()
+        results = res.results
+        if results.result_type != ResultType.TEXT:
+            msg = f'({mode}) unexpected result type: {results.result_type}: {results.summary}\n{results.cause}'
+            raise RuntimeError(msg)
+
+        res = run(command="\n".join([
+            'from databricks.sdk import WorkspaceClient', 'w = WorkspaceClient()', 'me = w.current_user.me()',
+            'print(me.user_name)'
+        ])).result()
+        assert res.results.result_type == ResultType.TEXT, f'unexpected result type: {res.results.result_type}'
+
+        assert my_user == res.results.data, f'unexpected user: {res.results.data}'
+    finally:
+        ucws.clusters.permanent_delete(interactive_cluster.cluster_id)
+
+
+def test_runtime_auth_from_jobs(w, fresh_wheel_file, env_or_skip, random):
     instance_pool_id = env_or_skip('TEST_INSTANCE_POOL_ID')
 
     v = w.clusters.spark_versions()
-    lts_runtimes = [x for x in v.versions if 'LTS' in x.name
-                    and '-ml' not in x.key
-                    and '-photon' not in x.key]
+    lts_runtimes = [
+        x for x in v.versions if 'LTS' in x.name and '-ml' not in x.key and '-photon' not in x.key
+    ]
 
     dbfs_wheel = f'/tmp/wheels/{random(10)}/{fresh_wheel_file.name}'
     with fresh_wheel_file.open('rb') as f:
         w.dbfs.upload(dbfs_wheel, f)
 
-    notebook_path = f'/Users/{w.current_user.me().user_name}/notebook-native-auth'
+    my_name = w.current_user.me().user_name
+    notebook_path = f'/Users/{my_name}/notebook-native-auth'
     notebook_content = io.BytesIO(b'''
 from databricks.sdk import WorkspaceClient
 w = WorkspaceClient()
 me = w.current_user.me()
 print(me.user_name)''')
-    w.workspace.upload(notebook_path, notebook_content,
-                       language=Language.PYTHON,
-                       overwrite=True)
+
+    from databricks.sdk.service.workspace import Language
+    w.workspace.upload(notebook_path, notebook_content, language=Language.PYTHON, overwrite=True)
 
     tasks = []
     for v in lts_runtimes:
@@ -68,25 +116,32 @@ def test_runtime_auth(w, fresh_wheel_file, env_or_skip, random):
                                          instance_pool_id=instance_pool_id),
                  libraries=[Library(whl=f'dbfs:{dbfs_wheel}')])
         tasks.append(t)
-    w.jobs.create(tasks=tasks, name=f'Runtime Native Auth {random(10)}')
 
-    print(v)
+    run = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks).result()
+    for task_key, output in _task_outputs(w, run).items():
+        assert my_name in output, f'{task_key} does not work with notebook native auth'
 
-def test_job_output(w):
-    # workflow_runs = w.jobs.list_runs(job_id=133270013770420)
-    this_run = w.jobs.get_run(20504473)
 
-    import re
-    notebook_model = re.compile(r"var __DATABRICKS_NOTEBOOK_MODEL = '(.*)';", re.MULTILINE)
+def _task_outputs(w, run):
+    notebook_model_re = re.compile(r"var __DATABRICKS_NOTEBOOK_MODEL = '(.*)';", re.MULTILINE)
 
-    for task_run in this_run.tasks:
-        print(task_run.task_key)
+    task_outputs = {}
+    for task_run in run.tasks:
+        output = ''
         run_output = w.jobs.export_run(task_run.run_id)
         for view in run_output.views:
             if view.type != ViewType.NOTEBOOK:
                 continue
-            for b64 in notebook_model.findall(view.content):
+            for b64 in notebook_model_re.findall(view.content):
                 url_encoded: bytes = base64.b64decode(b64)
-                json_encoded = urllib.parse.unquote(str(url_encoded))
-                x = json.loads(json_encoded)
-                print(x)
+                json_encoded = urllib.parse.unquote(url_encoded.decode('utf-8'))
+                notebook_model = json.loads(json_encoded)
+                for command in notebook_model['commands']:
+                    results_data = command['results']['data']
+                    if isinstance(results_data, str):
+                        output += results_data
+                    else:
+                        for data in results_data:
+                            output += data['data']
+        task_outputs[task_run.task_key] = output
+    return task_outputs
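Note: these tests lean on `w`, `ucws`, `env_or_skip` and `random` pytest
fixtures defined outside this diff. A minimal conftest sketch under which they
would resolve; every name and behavior below is an assumption for
illustration, not the repository's actual conftest:

    # tests/integration/conftest.py (hypothetical)
    import os
    import random as rng
    import string

    import pytest

    from databricks.sdk import WorkspaceClient

    @pytest.fixture
    def w() -> WorkspaceClient:
        return WorkspaceClient()  # credentials resolved from the environment

    @pytest.fixture
    def ucws(w) -> WorkspaceClient:
        # Assumed to point at a Unity Catalog-enabled workspace.
        return w

    @pytest.fixture
    def env_or_skip():
        def inner(name: str) -> str:
            if name not in os.environ:
                pytest.skip(f'{name} environment variable is not set')
            return os.environ[name]
        return inner

    @pytest.fixture
    def random():
        def inner(k: int = 16) -> str:
            return ''.join(rng.choices(string.ascii_lowercase, k=k))
        return inner

Under this sketch, `env_or_skip('TEST_INSTANCE_POOL_ID')` skips rather than
fails the runtime-auth tests when no instance pool is configured.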