diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py index 52cd8a15..de240d39 100644 --- a/dbt/adapters/databricks/global_state.py +++ b/dbt/adapters/databricks/global_state.py @@ -13,7 +13,7 @@ class GlobalState: def get_use_long_sessions(cls) -> bool: if cls.__use_long_sessions is None: cls.__use_long_sessions = ( - os.environ.get("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" + os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" ) return cls.__use_long_sessions @@ -23,7 +23,7 @@ def get_use_long_sessions(cls) -> bool: @classmethod def get_invocation_env(cls) -> Optional[str]: if not cls.__invocation_env_set: - cls.__invocation_env = os.environ.get("DBT_DATABRICKS_INVOCATION_ENV") + cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV") cls.__invocation_env_set = True return cls.__invocation_env @@ -33,7 +33,7 @@ def get_invocation_env(cls) -> Optional[str]: @classmethod def get_http_session_headers(cls) -> Optional[str]: if not cls.__session_headers_set: - cls.__session_headers = os.environ.get("DBT_DATABRICKS_HTTP_SESSION_HEADERS") + cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS") cls.__session_headers_set = True return cls.__session_headers @@ -43,7 +43,7 @@ def get_http_session_headers(cls) -> Optional[str]: def get_char_limit_bypass(cls) -> bool: if cls.__describe_char_bypass is None: cls.__describe_char_bypass = ( - os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE" + os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE" ) return cls.__describe_char_bypass @@ -52,7 +52,7 @@ def get_char_limit_bypass(cls) -> bool: @classmethod def get_connector_log_level(cls) -> str: if cls.__connector_log_level is None: - cls.__connector_log_level = os.environ.get( + cls.__connector_log_level = os.getenv( "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN" ).upper() return cls.__connector_log_level diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 4ac564be..d42fa5e1 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -112,7 +112,10 @@ def test_invalid_custom_user_agent(self): with pytest.raises(DbtValidationError) as excinfo: config = self._get_config() adapter = DatabricksAdapter(config, get_context("spawn")) - with patch.dict("os.environ", **{"DBT_DATABRICKS_INVOCATION_ENV": "(Some-thing)"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env", + return_value="(Some-thing)", + ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -126,8 +129,9 @@ def test_custom_user_agent(self): "dbt.adapters.databricks.connections.dbsql.connect", new=self._connect_func(expected_invocation_env="databricks-workflows"), ): - with patch.dict( - "os.environ", **{"DBT_DATABRICKS_INVOCATION_ENV": "databricks-workflows"} + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env", + return_value="databricks-workflows", ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -188,9 +192,9 @@ def _test_environment_http_headers( "dbt.adapters.databricks.connections.dbsql.connect", new=self._connect_func(expected_http_headers=expected_http_headers), ): - with patch.dict( - "os.environ", - **{"DBT_DATABRICKS_HTTP_SESSION_HEADERS": http_headers_str}, + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers", + return_value=http_headers_str, ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -910,7 +914,10 @@ def test_describe_table_extended_2048_char_limit(self): assert get_identifier_list_string(table_names) == "|".join(table_names) # If environment variable is set, then limit the number of characters - with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", + return_value="true", + ): # Long list of table names is capped assert get_identifier_list_string(table_names) == "*" @@ -939,7 +946,10 @@ def test_describe_table_extended_should_limit(self): table_names = set([f"customers_{i}" for i in range(200)]) # If environment variable is set, then limit the number of characters - with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", + return_value="true", + ): # Long list of table names is capped assert get_identifier_list_string(table_names) == "*" @@ -952,7 +962,10 @@ def test_describe_table_extended_may_limit(self): table_names = set([f"customers_{i}" for i in range(200)]) # If environment variable is set, then we may limit the number of characters - with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): + with patch( + "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", + return_value="true", + ): # But a short list of table names is not capped assert get_identifier_list_string(list(table_names)[:5]) == "|".join( list(table_names)[:5]