Make a best effort attempt to initialise all Databricks globals (#562)
## Changes
So far, we have only initialised `dbutils` locally when using `from
databricks.sdk.runtime import *`. Users in the web UI are guided to use
this import in all library code.
<img width="812" alt="image"
src="https://github.com/databricks/databricks-sdk-py/assets/88345179/3f28bc3c-0ba9-41ac-b990-3c4f5bf138aa">

The local solution so far (for people outside DBR) has been to initialise
`spark` manually. But this can be tedious for deeply nested libraries,
which is the reason this import was introduced in the first place.

Now we make a best-effort attempt to initialise as many globals as
possible locally, so that users can build and debug libraries using
Databricks Connect.
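
For example, with this change a library module can rely on the runtime globals and still be exercised locally, as long as a Databricks Connect configuration is available in the environment. A minimal sketch (the module and table names are illustrative, not part of this change):

```python
# my_library.py: illustrative only
from databricks.sdk.runtime import *


def row_count(table_name: str) -> int:
    # `spark` is now initialised locally on a best-effort basis via
    # Databricks Connect, so this function can be debugged outside DBR
    # without wiring up a session by hand.
    return spark.table(table_name).count()
```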

## Tests
* integration test

- [ ] `make test` run locally
- [x] `make fmt` applied
- [ ] relevant integration tests applied

---------

Signed-off-by: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com>
Co-authored-by: Miles Yucht <miles@databricks.com>
kartikgupta-db and mgyucht authored Feb 28, 2024
1 parent 5255760 commit f0fe023
Showing 9 changed files with 176 additions and 66 deletions.
6 changes: 5 additions & 1 deletion .vscode/settings.json
@@ -3,5 +3,9 @@
"tests"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"python.envFile": "${workspaceFolder}/.databricks/.databricks.env",
"databricks.python.envFile": "${workspaceFolder}/.env",
"jupyter.interactiveWindow.cellMarker.codeRegex": "^# COMMAND ----------|^# Databricks notebook source|^(#\\s*%%|#\\s*\\<codecell\\>|#\\s*In\\[\\d*?\\]|#\\s*In\\[ \\])",
"jupyter.interactiveWindow.cellMarker.default": "# COMMAND ----------"
}
2 changes: 1 addition & 1 deletion Makefile
@@ -24,7 +24,7 @@ test:
pytest -m 'not integration and not benchmark' --cov=databricks --cov-report html tests

integration:
pytest -n auto -m 'integration and not benchmark' --cov=databricks --cov-report html tests
pytest -n auto -m 'integration and not benchmark' --dist loadgroup --cov=databricks --cov-report html tests

benchmark:
pytest -m 'benchmark' tests
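
The new `--dist loadgroup` flag makes pytest-xdist schedule every test that shares an `xdist_group` mark on the same worker. The Databricks Connect tests added below need this because their fixture pip-installs and uninstalls `databricks-connect`. A minimal sketch of such a grouped test (the test name is illustrative):

```python
import pytest


@pytest.mark.xdist_group(name="databricks-connect")
def test_runs_within_its_group():
    # With `pytest -n auto --dist loadgroup`, all tests in the
    # "databricks-connect" group run on a single xdist worker, so the
    # install/uninstall done by shared fixtures cannot race across workers.
    assert True
```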
4 changes: 2 additions & 2 deletions databricks/sdk/_widgets/__init__.py
@@ -13,11 +13,11 @@ def get(self, name: str):
def _get(self, name: str) -> str:
pass

def getArgument(self, name: str, default_value: typing.Optional[str] = None):
def getArgument(self, name: str, defaultValue: typing.Optional[str] = None):
try:
return self.get(name)
except Exception:
return default_value
return defaultValue

def remove(self, name: str):
self._remove(name)
96 changes: 85 additions & 11 deletions databricks/sdk/runtime/__init__.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import Dict, Union
from typing import Dict, Optional, Union, cast

logger = logging.getLogger('databricks.sdk')
is_local_implementation = True
@@ -86,23 +86,97 @@ def inner() -> Dict[str, str]:
_globals[var] = userNamespaceGlobals[var]
is_local_implementation = False
except ImportError:
from typing import cast

# OSS implementation
is_local_implementation = True

from databricks.sdk.dbutils import RemoteDbUtils
for var in dbruntime_objects:
globals()[var] = None

from . import dbutils_stub
# The next few try-except blocks are for initialising globals in a best-effort
# manner. We separate them to try to get as many of them working as possible.
try:
# We expect this to fail and only do this for providing types
from pyspark.sql.context import SQLContext
sqlContext: SQLContext = None # type: ignore
table = sqlContext.table
except Exception as e:
logging.debug(f"Failed to initialize globals 'sqlContext' and 'table', continuing. Cause: {e}")

dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils]
try:
from pyspark.sql.functions import udf # type: ignore
except ImportError as e:
logging.debug(f"Failed to initialise udf global: {e}")

try:
from .stub import *
except (ImportError, NameError):
# this assumes that all environment variables are set
dbutils = RemoteDbUtils()
from databricks.connect import DatabricksSession # type: ignore
spark = DatabricksSession.builder.getOrCreate()
sql = spark.sql # type: ignore
except Exception as e:
# We are ignoring all failures here because user might want to initialize
# spark session themselves and we don't want to interfere with that
logging.debug(f"Failed to initialize globals 'spark' and 'sql', continuing. Cause: {e}")

try:
# We expect this to fail locally since dbconnect does not support sparkcontext. This is just for typing
sc = spark.sparkContext
except Exception as e:
logging.debug(f"Failed to initialize global 'sc', continuing. Cause: {e}")

def display(input=None, *args, **kwargs) -> None: # type: ignore
"""
Display plots or data.
Display plot:
- display() # no-op
- display(matplotlib.figure.Figure)
Display dataset:
- display(spark.DataFrame)
- display(list) # if list can be converted to DataFrame, e.g., list of named tuples
- display(pandas.DataFrame)
- display(koalas.DataFrame)
- display(pyspark.pandas.DataFrame)
Display any other value that has a _repr_html_() method
For Spark 2.0 and 2.1:
- display(DataFrame, streamName='optional', trigger=optional pyspark.sql.streaming.Trigger,
checkpointLocation='optional')
For Spark 2.2+:
- display(DataFrame, streamName='optional', trigger=optional interval like '1 second',
checkpointLocation='optional')
"""
# Import inside the function so that imports are only triggered on usage.
from IPython import display as IPDisplay
return IPDisplay.display(input, *args, **kwargs) # type: ignore

def displayHTML(html) -> None: # type: ignore
"""
Display HTML data.
Parameters
----------
data : URL or HTML string
If data is a URL, display the resource at that URL; the resource is loaded dynamically by the browser.
Otherwise data should be the HTML to be displayed.
See also:
IPython.display.HTML
IPython.display.display_html
"""
# Import inside the function so that imports are only triggered on usage.
from IPython import display as IPDisplay
return IPDisplay.display_html(html, raw=True) # type: ignore

# We want to propagate the error in initialising dbutils because this is core
# functionality of the SDK.
from databricks.sdk.dbutils import RemoteDbUtils

from . import dbutils_stub
dbutils_type = Union[dbutils_stub.dbutils, RemoteDbUtils]

dbutils = RemoteDbUtils()
dbutils = cast(dbutils_type, dbutils)

__all__ = ['dbutils'] if is_local_implementation else dbruntime_objects
# We do this to prevent importing the widgets implementation prematurely.
# The widget import should prompt users to use the implementation
# that has ipywidgets support.
def getArgument(name: str, defaultValue: Optional[str] = None):
return dbutils.widgets.getArgument(name, defaultValue)


__all__ = dbruntime_objects
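
With initialisation now best-effort, any of these globals can remain `None` locally when its setup fails (for example, when no Databricks Connect configuration is present; see `test_dbconnect_runtime_import_no_error_if_doesnt_exist` below). A hedged sketch of defensive library code, not part of this commit:

```python
from databricks.sdk.runtime import spark

if spark is None:
    # Best-effort initialisation did not succeed locally; fail with a clear
    # message instead of an AttributeError deeper in the call stack.
    raise RuntimeError("No Spark session available: configure Databricks Connect "
                       "or run this code inside a Databricks runtime.")

print(spark.sql("SELECT 1").collect()[0][0])
```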
2 changes: 1 addition & 1 deletion databricks/sdk/runtime/dbutils_stub.py
@@ -288,7 +288,7 @@ def get(name: str) -> str:
...

@staticmethod
def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> str:
def getArgument(name: str, defaultValue: typing.Optional[str] = None) -> str | None:
"""Returns the current value of a widget with give name.
:param name: Name of the argument to be accessed
:param defaultValue: (Deprecated) default value
48 changes: 0 additions & 48 deletions databricks/sdk/runtime/stub.py

This file was deleted.

3 changes: 2 additions & 1 deletion setup.py
@@ -16,7 +16,8 @@
install_requires=["requests>=2.28.1,<3", "google-auth~=2.0"],
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
"ipython", "ipywidgets", "requests-mock", "pyfakefs"],
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
"databricks-connect"],
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]},
author="Serge Smertin",
author_email="serge.smertin@databricks.com",
2 changes: 1 addition & 1 deletion tests/integration/conftest.py
@@ -81,7 +81,7 @@ def ucws(env_or_skip) -> WorkspaceClient:
@pytest.fixture(scope='session')
def env_or_skip():

def inner(var) -> str:
def inner(var: str) -> str:
if var not in os.environ:
pytest.skip(f'Environment variable {var} is missing')
return os.environ[var]
79 changes: 79 additions & 0 deletions tests/integration/test_dbconnect.py
@@ -0,0 +1,79 @@
import pytest

DBCONNECT_DBR_CLIENT = {"13.3": "13.3.3", "14.3": "14.3.1", }


def reload_modules(name: str):
"""
Reloads the specified module. This is useful when testing Databricks Connect, since both
the `databricks.connect` and `databricks.sdk.runtime` modules are stateful, and we need
to reload these modules to reset the state cache between test runs.
"""

import importlib
import sys

v = sys.modules.get(name)
if v is None:
return
try:
print(f"Reloading {name}")
importlib.reload(v)
except Exception as e:
print(f"Failed to reload {name}: {e}")


@pytest.fixture(scope="function")
def restorable_env():
import os
current_env = os.environ.copy()
yield
for k, v in os.environ.items():
if k not in current_env:
del os.environ[k]
elif v != current_env[k]:
os.environ[k] = current_env[k]


@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys()))
def setup_dbconnect_test(request, env_or_skip, restorable_env):
dbr = request.param
assert dbr in DBCONNECT_DBR_CLIENT, f"Unsupported Databricks Runtime version {dbr}. Please update DBCONNECT_DBR_CLIENT."

import os
os.environ["DATABRICKS_CLUSTER_ID"] = env_or_skip(
f"TEST_DBR_{dbr.replace('.', '_')}_DBCONNECT_CLUSTER_ID")

import subprocess
import sys
lib = f"databricks-connect=={DBCONNECT_DBR_CLIENT[dbr]}"
subprocess.check_call([sys.executable, "-m", "pip", "install", lib])

yield

subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "databricks-connect"])


@pytest.mark.xdist_group(name="databricks-connect")
def test_dbconnect_initialisation(w, setup_dbconnect_test):
reload_modules("databricks.connect")
from databricks.connect import DatabricksSession

spark = DatabricksSession.builder.getOrCreate()
assert spark.sql("SELECT 1").collect()[0][0] == 1


@pytest.mark.xdist_group(name="databricks-connect")
def test_dbconnect_runtime_import(w, setup_dbconnect_test):
reload_modules("databricks.sdk.runtime")
from databricks.sdk.runtime import spark

assert spark.sql("SELECT 1").collect()[0][0] == 1


@pytest.mark.xdist_group(name="databricks-connect")
def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w):
reload_modules("databricks.sdk.runtime")
from databricks.sdk.runtime import spark

assert spark is None
