# Make a best effort attempt to initialise all Databricks globals (#562)
## Changes

We currently only initialise `dbutils` locally when using `from databricks.sdk.runtime import *`. Users in the web UI are guided to use this import in all library code.

![image](https://github.com/databricks/databricks-sdk-py/assets/88345179/3f28bc3c-0ba9-41ac-b990-3c4f5bf138aa)

So far, the local solution (for people outside DBR) was to initialise `spark` manually. But this can be tedious for deeply nested libraries, which is the reason this import was introduced in the first place. Now, we make a best-effort attempt to initialise as many globals as possible locally, so that users can build and debug libraries using Databricks Connect (see the usage sketch below).

## Tests

* integration test

- [ ] `make test` run locally
- [x] `make fmt` applied
- [ ] relevant integration tests applied

---------

Signed-off-by: Kartik Gupta <88345179+kartikgupta-db@users.noreply.github.com>
Co-authored-by: Miles Yucht <miles@databricks.com>
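As an illustration of what this enables, here is a minimal sketch of library code that now works unchanged in both environments. Assumptions: the module and table names are hypothetical, and Databricks Connect is configured locally.

```python
# trips.py -- hypothetical library module.
# On DBR, the runtime injects `spark` and the other globals; locally, the SDK
# now makes a best-effort attempt to initialise them via Databricks Connect.
from databricks.sdk.runtime import *


def trip_count(table: str) -> int:
    # `spark` resolves to the runtime session on DBR, or to a session
    # created through Databricks Connect on a developer's machine.
    return spark.table(table).count()
```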
1 parent 5255760 · commit f0fe023 · 9 changed files with 176 additions and 66 deletions.
Among the changed files is a new integration test file (79 lines added, per the diff hunk `@@ -0,0 +1,79 @@`):

```python
import pytest

# Maps each supported Databricks Runtime version to the matching
# databricks-connect client version used in the tests below.
DBCONNECT_DBR_CLIENT = {"13.3": "13.3.3", "14.3": "14.3.1"}


def reload_modules(name: str):
    """
    Reloads the specified module. This is useful when testing Databricks Connect, since both
    the `databricks.connect` and `databricks.sdk.runtime` modules are stateful, and we need
    to reload these modules to reset the state cache between test runs.
    """

    import importlib
    import sys

    v = sys.modules.get(name)
    if v is None:
        return
    try:
        print(f"Reloading {name}")
        importlib.reload(v)
    except Exception as e:
        print(f"Failed to reload {name}: {e}")


@pytest.fixture(scope="function")
def restorable_env():
    import os
    current_env = os.environ.copy()
    yield
    # Iterate over a snapshot: os.environ must not be mutated while
    # it is being iterated directly.
    for k, v in list(os.environ.items()):
        if k not in current_env:
            del os.environ[k]
        elif v != current_env[k]:
            os.environ[k] = current_env[k]


@pytest.fixture(params=list(DBCONNECT_DBR_CLIENT.keys()))
def setup_dbconnect_test(request, env_or_skip, restorable_env):
    dbr = request.param
    assert dbr in DBCONNECT_DBR_CLIENT, f"Unsupported Databricks Runtime version {dbr}. Please update DBCONNECT_DBR_CLIENT."

    # Point the test at a cluster running the requested DBR version.
    import os
    os.environ["DATABRICKS_CLUSTER_ID"] = env_or_skip(
        f"TEST_DBR_{dbr.replace('.', '_')}_DBCONNECT_CLUSTER_ID")

    # Install the databricks-connect client matching the cluster's DBR version.
    import subprocess
    import sys
    lib = f"databricks-connect=={DBCONNECT_DBR_CLIENT[dbr]}"
    subprocess.check_call([sys.executable, "-m", "pip", "install", lib])

    yield

    subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "databricks-connect"])


@pytest.mark.xdist_group(name="databricks-connect")
def test_dbconnect_initialisation(w, setup_dbconnect_test):
    reload_modules("databricks.connect")
    from databricks.connect import DatabricksSession

    spark = DatabricksSession.builder.getOrCreate()
    assert spark.sql("SELECT 1").collect()[0][0] == 1


@pytest.mark.xdist_group(name="databricks-connect")
def test_dbconnect_runtime_import(w, setup_dbconnect_test):
    reload_modules("databricks.sdk.runtime")
    from databricks.sdk.runtime import spark

    assert spark.sql("SELECT 1").collect()[0][0] == 1


@pytest.mark.xdist_group(name="databricks-connect")
def test_dbconnect_runtime_import_no_error_if_doesnt_exist(w):
    reload_modules("databricks.sdk.runtime")
    from databricks.sdk.runtime import spark

    assert spark is None
```
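The last test pins down the fallback behaviour: outside DBR, with no usable Databricks Connect installation, importing `spark` from `databricks.sdk.runtime` should yield `None` rather than raise. A minimal sketch of that best-effort pattern, as exercised by the tests above (illustrative only, not the SDK's actual implementation):

```python
# Sketch of a best-effort global initialisation; not the SDK's real code.
try:
    from databricks.connect import DatabricksSession
    spark = DatabricksSession.builder.getOrCreate()
except Exception:
    # No DBR runtime and no working Databricks Connect setup:
    # leave the global as None instead of failing at import time.
    spark = None
```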