Skip to content

Commit

Permalink
insulate env vars
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Dec 19, 2024
1 parent f2445f2 commit f3a83e5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 14 deletions.
10 changes: 5 additions & 5 deletions dbt/adapters/databricks/global_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class GlobalState:
def get_use_long_sessions(cls) -> bool:
if cls.__use_long_sessions is None:
cls.__use_long_sessions = (
os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE"
)
return cls.__use_long_sessions

Expand All @@ -23,7 +23,7 @@ def get_use_long_sessions(cls) -> bool:
@classmethod
def get_invocation_env(cls) -> Optional[str]:
if not cls.__invocation_env_set:
cls.__invocation_env = os.environ.get("DBT_DATABRICKS_INVOCATION_ENV")
cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV")
cls.__invocation_env_set = True
return cls.__invocation_env

Expand All @@ -33,7 +33,7 @@ def get_invocation_env(cls) -> Optional[str]:
@classmethod
def get_http_session_headers(cls) -> Optional[str]:
if not cls.__session_headers_set:
cls.__session_headers = os.environ.get("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS")
cls.__session_headers_set = True
return cls.__session_headers

Expand All @@ -43,7 +43,7 @@ def get_http_session_headers(cls) -> Optional[str]:
def get_char_limit_bypass(cls) -> bool:
if cls.__describe_char_bypass is None:
cls.__describe_char_bypass = (
os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE"
os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE"
)
return cls.__describe_char_bypass

Expand All @@ -52,7 +52,7 @@ def get_char_limit_bypass(cls) -> bool:
@classmethod
def get_connector_log_level(cls) -> str:
if cls.__connector_log_level is None:
cls.__connector_log_level = os.environ.get(
cls.__connector_log_level = os.getenv(
"DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN"
).upper()
return cls.__connector_log_level
31 changes: 22 additions & 9 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,10 @@ def test_invalid_custom_user_agent(self):
with pytest.raises(DbtValidationError) as excinfo:
config = self._get_config()
adapter = DatabricksAdapter(config, get_context("spawn"))
with patch.dict("os.environ", **{"DBT_DATABRICKS_INVOCATION_ENV": "(Some-thing)"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
return_value="(Some-thing)",
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load

Expand All @@ -126,8 +129,9 @@ def test_custom_user_agent(self):
"dbt.adapters.databricks.connections.dbsql.connect",
new=self._connect_func(expected_invocation_env="databricks-workflows"),
):
with patch.dict(
"os.environ", **{"DBT_DATABRICKS_INVOCATION_ENV": "databricks-workflows"}
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_invocation_env",
return_value="databricks-workflows",
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
Expand Down Expand Up @@ -188,9 +192,9 @@ def _test_environment_http_headers(
"dbt.adapters.databricks.connections.dbsql.connect",
new=self._connect_func(expected_http_headers=expected_http_headers),
):
with patch.dict(
"os.environ",
**{"DBT_DATABRICKS_HTTP_SESSION_HEADERS": http_headers_str},
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers",
return_value=http_headers_str,
):
connection = adapter.acquire_connection("dummy")
connection.handle # trigger lazy-load
Expand Down Expand Up @@ -910,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self):
assert get_identifier_list_string(table_names) == "|".join(table_names)

# If environment variable is set, then limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# Long list of table names is capped
assert get_identifier_list_string(table_names) == "*"

Expand Down Expand Up @@ -939,7 +946,10 @@ def test_describe_table_extended_should_limit(self):
table_names = set([f"customers_{i}" for i in range(200)])

# If environment variable is set, then limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# Long list of table names is capped
assert get_identifier_list_string(table_names) == "*"

Expand All @@ -952,7 +962,10 @@ def test_describe_table_extended_may_limit(self):
table_names = set([f"customers_{i}" for i in range(200)])

# If environment variable is set, then we may limit the number of characters
with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}):
with patch(
"dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass",
return_value="true",
):
# But a short list of table names is not capped
assert get_identifier_list_string(list(table_names)[:5]) == "|".join(
list(table_names)[:5]
Expand Down

0 comments on commit f3a83e5

Please sign in to comment.