Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Rename schema to database in PostgresHook (#26436)" #26734

Merged
merged 1 commit into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 4 additions & 15 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from __future__ import annotations

import os
import warnings
from contextlib import closing
from copy import deepcopy
from typing import Iterable, Union
Expand Down Expand Up @@ -68,18 +67,10 @@ class PostgresHook(DbApiHook):
supports_autocommit = True

def __init__(self, *args, **kwargs) -> None:
if 'schema' in kwargs:
warnings.warn(
'The "schema" arg has been renamed to "database" as it contained the database name.'
'Please use "database" to set the database name.',
DeprecationWarning,
stacklevel=2,
)
kwargs['database'] = kwargs['schema']
super().__init__(*args, **kwargs)
self.connection: Connection | None = kwargs.pop("connection", None)
self.conn: connection = None
self.database: str | None = kwargs.pop("database", None)
self.schema: str | None = kwargs.pop("schema", None)

def _get_cursor(self, raw_cursor: str) -> CursorType:
_cursor = raw_cursor.lower()
Expand All @@ -104,7 +95,7 @@ def get_conn(self) -> connection:
host=conn.host,
user=conn.login,
password=conn.password,
dbname=self.database or conn.schema,
dbname=self.schema or conn.schema,
port=conn.port,
)
raw_cursor = conn.extra_dejson.get('cursor', False)
Expand Down Expand Up @@ -152,9 +143,7 @@ def get_uri(self) -> str:
Extract the URI from the connection.
:return: the extracted uri.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
conn.schema = self.database or conn.schema
uri = conn.get_uri().replace("postgres://", "postgresql://")
uri = super().get_uri().replace("postgres://", "postgresql://")
return uri

def bulk_load(self, table: str, tmp_file: str) -> None:
Expand Down Expand Up @@ -208,7 +197,7 @@ def get_iam_token(self, conn: Connection) -> tuple[str, str, int]:
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/redshift.html#Redshift.Client.get_cluster_credentials
cluster_creds = redshift_client.get_cluster_credentials(
DbUser=login,
DbName=self.database or conn.schema,
DbName=self.schema or conn.schema,
ClusterIdentifier=cluster_identifier,
AutoCreate=False,
)
Expand Down
4 changes: 1 addition & 3 deletions airflow/providers/postgres/operators/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ class PostgresOperator(BaseOperator):
(default value: False)
:param parameters: (optional) the parameters to render the SQL query with.
:param database: name of database which overwrite defined one in connection
:param runtime_parameters: a mapping of runtime params added to the final sql being executed.
For example, you could set the schema via `{"search_path": "CUSTOM_SCHEMA"}`.
"""

template_fields: Sequence[str] = ('sql',)
Expand Down Expand Up @@ -75,7 +73,7 @@ def __init__(
self.hook: PostgresHook | None = None

def execute(self, context: Context):
self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, database=self.database)
self.hook = PostgresHook(postgres_conn_id=self.postgres_conn_id, schema=self.database)
if self.runtime_parameters:
final_sql = []
sql_param = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,7 @@ Host (required)
The host to connect to.

Schema (optional)
Specify the name of the database to connect to.

.. note::

If you want to define a default database schema:

* using ``PostgresOperator`` see :ref:`Passing Server Configuration Parameters into PostgresOperator <howto/operators:postgres>`
* using ``PostgresHook`` see `search_path <https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-PATH>_`
Specify the schema name to be used in the database.

Login (required)
Specify the user name to connect.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
specific language governing permissions and limitations
under the License.

.. _howto/operators:postgres:

How-to Guide for PostgresOperator
=================================

Expand Down
43 changes: 19 additions & 24 deletions tests/providers/postgres/hooks/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
class TestPostgresHookConn:
@pytest.fixture(autouse=True)
def setup(self):
self.connection = Connection(login='login', password='password', host='host', schema='database')
self.connection = Connection(login='login', password='password', host='host', schema='schema')

class UnitTestPostgresHook(PostgresHook):
conn_name_attr = 'test_conn_id'
Expand All @@ -47,15 +47,15 @@ def test_get_conn_non_default_id(self, mock_connect):
self.db_hook.test_conn_id = 'non_default'
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
user='login', password='password', host='host', dbname='database', port=None
user='login', password='password', host='host', dbname='schema', port=None
)
self.db_hook.get_connection.assert_called_once_with('non_default')

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn(self, mock_connect):
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
user='login', password='password', host='host', dbname='database', port=None
user='login', password='password', host='host', dbname='schema', port=None
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
Expand All @@ -64,7 +64,7 @@ def test_get_uri(self, mock_connect):
self.connection.conn_type = 'postgres'
self.db_hook.get_conn()
assert mock_connect.call_count == 1
assert self.db_hook.get_uri() == "postgresql://login:password@host/database?client_encoding=utf-8"
assert self.db_hook.get_uri() == "postgresql://login:password@host/schema?client_encoding=utf-8"

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_cursor(self, mock_connect):
Expand All @@ -75,7 +75,7 @@ def test_get_conn_cursor(self, mock_connect):
user='login',
password='password',
host='host',
dbname='database',
dbname='schema',
port=None,
)

Expand All @@ -87,20 +87,20 @@ def test_get_conn_with_invalid_cursor(self, mock_connect):

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_from_connection(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema='database')
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
hook = PostgresHook(connection=conn)
hook.get_conn()
mock_connect.assert_called_once_with(
user='login-conn', password='password-conn', host='host', dbname='database', port=None
user='login-conn', password='password-conn', host='host', dbname='schema', port=None
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
def test_get_conn_from_connection_with_database(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema="database")
hook = PostgresHook(connection=conn, database='database-override')
def test_get_conn_from_connection_with_schema(self, mock_connect):
conn = Connection(login='login-conn', password='password-conn', host='host', schema='schema')
hook = PostgresHook(connection=conn, schema='schema-override')
hook.get_conn()
mock_connect.assert_called_once_with(
user='login-conn', password='password-conn', host='host', dbname='database-override', port=None
user='login-conn', password='password-conn', host='host', dbname='schema-override', port=None
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_get_conn_extra(self, mock_connect):
self.connection.extra = '{"connect_timeout": 3}'
self.db_hook.get_conn()
mock_connect.assert_called_once_with(
user='login', password='password', host='host', dbname='database', port=None, connect_timeout=3
user='login', password='password', host='host', dbname='schema', port=None, connect_timeout=3
)

@mock.patch('airflow.providers.postgres.hooks.postgres.psycopg2.connect')
Expand Down Expand Up @@ -225,37 +225,32 @@ def test_get_conn_rds_iam_redshift(
port=(port or 5439),
)

def test_get_uri_from_connection_without_database_override(self):
def test_get_uri_from_connection_without_schema_override(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="postgres",
host="host",
login="login",
password="password",
schema="database",
schema="schema",
port=1,
)
)
assert "postgresql://login:password@host:1/database" == self.db_hook.get_uri()
assert "postgresql://login:password@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_from_connection_with_database_override(self):
hook = PostgresHook(database='database-override')
def test_get_uri_from_connection_with_schema_override(self):
hook = PostgresHook(schema='schema-override')
hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="postgres",
host="host",
login="login",
password="password",
schema="database",
schema="schema",
port=1,
)
)
assert "postgresql://login:password@host:1/database-override" == hook.get_uri()

def test_schema_kwarg_database_kwarg_compatibility(self):
database = 'database-override'
hook = PostgresHook(schema=database)
assert hook.database == database
assert "postgresql://login:password@host:1/schema-override" == hook.get_uri()


class TestPostgresHook(unittest.TestCase):
Expand Down
6 changes: 3 additions & 3 deletions tests/providers/postgres/operators/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ def test_vacuum(self):
op = PostgresOperator(task_id='postgres_operator_test_vacuum', sql=sql, dag=self.dag, autocommit=True)
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

def test_overwrite_database(self):
def test_overwrite_schema(self):
"""
Verifies option to overwrite connection database
Verifies option to overwrite connection schema
"""

sql = "SELECT 1;"
op = PostgresOperator(
task_id='postgres_operator_test_database_overwrite',
task_id='postgres_operator_test_schema_overwrite',
sql=sql,
dag=self.dag,
autocommit=True,
Expand Down