From cc36f55540ba82a2cdf4a874868beed28327b6e1 Mon Sep 17 00:00:00 2001 From: Jed Cunningham Date: Tue, 27 Sep 2022 16:27:12 -0600 Subject: [PATCH] Revert "Rename schema to database in `PostgresHook` (#26436)" This reverts commit 642375f97de133edba1a6c1fa9397d840e8b5936. --- airflow/providers/postgres/hooks/postgres.py | 19 ++------ .../providers/postgres/operators/postgres.py | 4 +- .../connections/postgres.rst | 9 +--- .../postgres_operator_howto_guide.rst | 2 - .../providers/postgres/hooks/test_postgres.py | 43 ++++++++----------- .../postgres/operators/test_postgres.py | 6 +-- 6 files changed, 28 insertions(+), 55 deletions(-) diff --git a/airflow/providers/postgres/hooks/postgres.py b/airflow/providers/postgres/hooks/postgres.py index 471b293d2bb01..25a5eed4284cd 100644 --- a/airflow/providers/postgres/hooks/postgres.py +++ b/airflow/providers/postgres/hooks/postgres.py @@ -18,7 +18,6 @@ from __future__ import annotations import os -import warnings from contextlib import closing from copy import deepcopy from typing import Iterable, Union @@ -68,18 +67,10 @@ class PostgresHook(DbApiHook): supports_autocommit = True def __init__(self, *args, **kwargs) -> None: - if 'schema' in kwargs: - warnings.warn( - 'The "schema" arg has been renamed to "database" as it contained the database name.' - 'Please use "database" to set the database name.', - DeprecationWarning, - stacklevel=2, - ) - kwargs['database'] = kwargs['schema'] super().__init__(*args, **kwargs) self.connection: Connection | None = kwargs.pop("connection", None) self.conn: connection = None - self.database: str | None = kwargs.pop("database", None) + self.schema: str | None = kwargs.pop("schema", None) def _get_cursor(self, raw_cursor: str) -> CursorType: _cursor = raw_cursor.lower() @@ -104,7 +95,7 @@ def get_conn(self) -> connection: host=conn.host, user=conn.login, password=conn.password, - dbname=self.database or conn.schema, + dbname=self.schema or conn.schema, port=conn.port, ) raw_cursor = conn.extra_dejson.get('cursor', False) @@ -152,9 +143,7 @@ def get_uri(self) -> str: Extract the URI from the connection. :return: the extracted uri. """ - conn = self.get_connection(getattr(self, self.conn_name_attr)) - conn.schema = self.database or conn.schema - uri = conn.get_uri().replace("postgres://", "postgresql://") + uri = super().get_uri().replace("postgres://", "postgresql://") return uri def bulk_load(self, table: str, tmp_file: str) -> None: @@ -208,7 +197,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]: # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials cluster_creds = redshift_client.get_cluster_credentials( DbUser=login, - DbName=self.database or conn.schema, + DbName=self.schema or conn.schema, ClusterIdentifier=cluster_identifier, AutoCreate=False, ) diff --git a/airflow/providers/postgres/operators/postgres.py b/airflow/providers/postgres/operators/postgres.py index f85e655e0f64a..e12ee58d8ce45 100644 --- a/airflow/providers/postgres/operators/postgres.py +++ b/airflow/providers/postgres/operators/postgres.py @@ -42,8 +42,6 @@ class PostgresOperator(BaseOperator): (default value: False) :param parameters: (optional) the parameters to render the SQL query with. :param database: name of database which overwrite defined one in connection - :param runtime_parameters: a mapping of runtime params added to the final sql being executed. - For example, you could set the schema via `{"search_path": "CUSTOM_SCHEMA"}`. """ template_fields: Sequence[str] = ('sql',) @@ -75,7 +73,7 @@ def __init__( self.hook: PostgresHook | None = None def execute(self, context: Context): - self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, database=self.database) + self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database) if self.runtime_parameters: final_sql = [] sql_param = {} diff --git a/docs/apache-airflow-providers-postgres/connections/postgres.rst b/docs/apache-airflow-providers-postgres/connections/postgres.rst index 68966dc9260e7..f97e99af84aa7 100644 --- a/docs/apache-airflow-providers-postgres/connections/postgres.rst +++ b/docs/apache-airflow-providers-postgres/connections/postgres.rst @@ -29,14 +29,7 @@ Host (required) The host to connect to. Schema (optional) - Specify the name of the database to connect to. - - .. note:: - - If you want to define a default database schema: - - * using ``PostgresOperator`` see :ref:`Passing Server Configuration Parameters into PostgresOperator ` - * using ``PostgresHook`` see `search_path _` + Specify the schema name to be used in the database. Login (required) Specify the user name to connect. diff --git a/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst b/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst index 648a6c75e00e7..790d02caecd87 100644 --- a/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst +++ b/docs/apache-airflow-providers-postgres/operators/postgres_operator_howto_guide.rst @@ -15,8 +15,6 @@ specific language governing permissions and limitations under the License. -.. _howto/operators:postgres: - How-to Guide for PostgresOperator ================================= diff --git a/tests/providers/postgres/hooks/test_postgres.py b/tests/providers/postgres/hooks/test_postgres.py index 8e7ee68c1d3c4..f8853b14c7ffc 100644 --- a/tests/providers/postgres/hooks/test_postgres.py +++ b/tests/providers/postgres/hooks/test_postgres.py @@ -33,7 +33,7 @@ class TestPostgresHookConn: @pytest.fixture(autouse=True) def setup(self): - self.connection = Connection(login='login', password='password', host='host', schema='database') + self.connection = Connection(login='login', password='password', host='host', schema='schema') class UnitTestPostgresHook(PostgresHook): conn_name_attr = 'test_conn_id' @@ -47,7 +47,7 @@ def test_get_conn_non_default_id(self, mock_connect): self.db_hook.test_conn_id = 'non_default' self.db_hook.get_conn() mock_connect.assert_called_once_with( - user='login', password='password', host='host', dbname='database', port=None + user='login', password='password', host='host', dbname='schema', port=None ) self.db_hook.get_connection.assert_called_once_with('non_default') @@ -55,7 +55,7 @@ def test_get_conn_non_default_id(self, mock_connect): def test_get_conn(self, mock_connect): self.db_hook.get_conn() mock_connect.assert_called_once_with( - user='login', password='password', host='host', dbname='database', port=None + user='login', password='password', host='host', dbname='schema', port=None ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') @@ -64,7 +64,7 @@ def test_get_uri(self, mock_connect): self.connection.conn_type = 'postgres' self.db_hook.get_conn() assert mock_connect.call_count == 1 - assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8" + assert self.db_hook.get_uri() == "postgresql://login:password@host/schema?client_encoding=utf-8" @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') def test_get_conn_cursor(self, mock_connect): @@ -75,7 +75,7 @@ def test_get_conn_cursor(self, mock_connect): user='login', password='password', host='host', - dbname='database', + dbname='schema', port=None, ) @@ -87,20 +87,20 @@ def test_get_conn_with_invalid_cursor(self, mock_connect): @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') def test_get_conn_from_connection(self, mock_connect): - conn = Connection(login='login-conn', password='password-conn', host='host', schema='database') + conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema') hook = PostgresHook(connection=conn) hook.get_conn() mock_connect.assert_called_once_with( - user='login-conn', password='password-conn', host='host', dbname='database', port=None + user='login-conn', password='password-conn', host='host', dbname='schema', port=None ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') - def test_get_conn_from_connection_with_database(self, mock_connect): - conn = Connection(login='login-conn', password='password-conn', host='host', schema="database") - hook = PostgresHook(connection=conn, database='database-override') + def test_get_conn_from_connection_with_schema(self, mock_connect): + conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema') + hook = PostgresHook(connection=conn, schema='schema-override') hook.get_conn() mock_connect.assert_called_once_with( - user='login-conn', password='password-conn', host='host', dbname='database-override', port=None + user='login-conn', password='password-conn', host='host', dbname='schema-override', port=None ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') @@ -146,7 +146,7 @@ def test_get_conn_extra(self, mock_connect): self.connection.extra = '{"connect_timeout": 3}' self.db_hook.get_conn() mock_connect.assert_called_once_with( - user='login', password='password', host='host', dbname='database', port=None, connect_timeout=3 + user='login', password='password', host='host', dbname='schema', port=None, connect_timeout=3 ) @mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect') @@ -225,37 +225,32 @@ def test_get_conn_rds_iam_redshift( port=(port or 5439), ) - def test_get_uri_from_connection_without_database_override(self): + def test_get_uri_from_connection_without_schema_override(self): self.db_hook.get_connection = mock.MagicMock( return_value=Connection( conn_type="postgres", host="host", login="login", password="password", - schema="database", + schema="schema", port=1, ) ) - assert "postgresql://login:password@host:1/database" == self.db_hook.get_uri() + assert "postgresql://login:password@host:1/schema" == self.db_hook.get_uri() - def test_get_uri_from_connection_with_database_override(self): - hook = PostgresHook(database='database-override') + def test_get_uri_from_connection_with_schema_override(self): + hook = PostgresHook(schema='schema-override') hook.get_connection = mock.MagicMock( return_value=Connection( conn_type="postgres", host="host", login="login", password="password", - schema="database", + schema="schema", port=1, ) ) - assert "postgresql://login:password@host:1/database-override" == hook.get_uri() - - def test_schema_kwarg_database_kwarg_compatibility(self): - database = 'database-override' - hook = PostgresHook(schema=database) - assert hook.database == database + assert "postgresql://login:password@host:1/schema-override" == hook.get_uri() class TestPostgresHook(unittest.TestCase): diff --git a/tests/providers/postgres/operators/test_postgres.py b/tests/providers/postgres/operators/test_postgres.py index 48ba1b61161fd..57fa871e8f048 100644 --- a/tests/providers/postgres/operators/test_postgres.py +++ b/tests/providers/postgres/operators/test_postgres.py @@ -79,14 +79,14 @@ def test_vacuum(self): op = PostgresOperator(task_id='postgres_operator_test_vacuum', sql=sql, dag=self.dag, autocommit=True) op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - def test_overwrite_database(self): + def test_overwrite_schema(self): """ - Verifies option to overwrite connection database + Verifies option to overwrite connection schema """ sql = "SELECT 1;" op = PostgresOperator( - task_id='postgres_operator_test_database_overwrite', + task_id='postgres_operator_test_schema_overwrite', sql=sql, dag=self.dag, autocommit=True,