diff --git a/airflow/providers/oracle/hooks/oracle.py b/airflow/providers/oracle/hooks/oracle.py index 95b5fdca9faee..9ebed94587826 100644 --- a/airflow/providers/oracle/hooks/oracle.py +++ b/airflow/providers/oracle/hooks/oracle.py @@ -16,6 +16,7 @@ # specific language governing permissions and limitations # under the License. +import warnings from datetime import datetime from typing import Dict, List, Optional, Union @@ -87,6 +88,7 @@ def get_conn(self) -> 'OracleHook': conn_config = {'user': conn.login, 'password': conn.password} sid = conn.extra_dejson.get('sid') mod = conn.extra_dejson.get('module') + schema = conn.schema service_name = conn.extra_dejson.get('service_name') port = conn.port if conn.port else 1521 @@ -100,8 +102,16 @@ def get_conn(self) -> 'OracleHook': dsn = conn.host if conn.port is not None: dsn += ":" + str(conn.port) - if service_name or conn.schema: - dsn += "/" + (service_name or conn.schema) + if service_name: + dsn += "/" + service_name + elif conn.schema: + warnings.warn( + """Using conn.schema to pass the Oracle Service Name is deprecated. + Please use conn.extra.service_name instead.""", + DeprecationWarning, + stacklevel=2, + ) + dsn += "/" + conn.schema conn_config['dsn'] = dsn if 'encoding' in conn.extra_dejson: @@ -146,6 +156,12 @@ def get_conn(self) -> 'OracleHook': if mod is not None: conn.module = mod + # if Connection.schema is defined, set schema after connecting successfully + # cannot be part of conn_config + # https://cx-oracle.readthedocs.io/en/latest/api_manual/connection.html?highlight=schema#Connection.current_schema + if schema is not None: + conn.current_schema = schema + return conn def insert_rows( diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 3eae248107fd1..9217fec881c16 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -176,6 +176,10 @@ def test_get_conn_purity(self, mock_connect): assert args == () assert kwargs['purity'] == purity.get(pur) + @mock.patch('airflow.providers.oracle.hooks.oracle.cx_Oracle.connect') + def test_set_current_schema(self, mock_connect): + assert self.db_hook.get_conn().current_schema == self.connection.schema + @unittest.skipIf(cx_Oracle is None, 'cx_Oracle package not present') class TestOracleHook(unittest.TestCase):