From 79efd74392d9d9cff5a456ab8ffb2d74a5971dee Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Thu, 6 Jul 2023 16:26:17 +0530 Subject: [PATCH] Include changes to airflow cli --- airflow/cli/commands/connection_command.py | 14 ++++++++------ tests/cli/commands/test_connection_command.py | 15 +++++++-------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/airflow/cli/commands/connection_command.py b/airflow/cli/commands/connection_command.py index 63888c0ec96ed..e7b83e342e298 100644 --- a/airflow/cli/commands/connection_command.py +++ b/airflow/cli/commands/connection_command.py @@ -26,6 +26,7 @@ from typing import Any from urllib.parse import urlsplit, urlunsplit +from sqlalchemy import select from sqlalchemy.orm import exc from airflow.cli.simple_table import AirflowConsole @@ -77,9 +78,10 @@ def connections_get(args): def connections_list(args): """Lists all connections at the command line.""" with create_session() as session: - query = session.query(Connection) + query = select(Connection) if args.conn_id: - query = query.filter(Connection.conn_id == args.conn_id) + query = query.where(Connection.conn_id == args.conn_id) + query = session.scalars(query) conns = query.all() AirflowConsole().print_as( @@ -177,7 +179,7 @@ def connections_export(args): raise SystemExit("Option `--serialization-format` may only be used with file type `env`.") with create_session() as session: - connections = session.query(Connection).order_by(Connection.conn_id).all() + connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all() msg = _format_connections( conns=connections, @@ -265,7 +267,7 @@ def connections_add(args): new_conn.set_extra(args.conn_extra) with create_session() as session: - if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first(): + if not session.scalar(select(Connection).where(Connection.conn_id == new_conn.conn_id).limit(1)): session.add(new_conn) msg = "Successfully added `conn_id`={conn_id} : {uri}" msg = msg.format( @@ -293,7 +295,7 @@ def connections_delete(args): """Deletes connection from DB.""" with create_session() as session: try: - to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one() + to_delete = session.scalars(select(Connection).where(Connection.conn_id == args.conn_id)).one() except exc.NoResultFound: raise SystemExit(f"Did not find a connection with `conn_id`={args.conn_id}") except exc.MultipleResultsFound: @@ -326,7 +328,7 @@ def _import_helper(file_path: str, overwrite: bool) -> None: print(f"Could not import connection. {e}") continue - existing_conn_id = session.query(Connection.id).filter(Connection.conn_id == conn_id).scalar() + existing_conn_id = session.scalar(select(Connection.id).where(Connection.conn_id == conn_id)) if existing_conn_id is not None: if not overwrite: print(f"Could not import connection {conn_id}: connection already exists.") diff --git a/tests/cli/commands/test_connection_command.py b/tests/cli/commands/test_connection_command.py index 12b3de6170c0b..7cc19af529d2a 100644 --- a/tests/cli/commands/test_connection_command.py +++ b/tests/cli/commands/test_connection_command.py @@ -26,6 +26,7 @@ from unittest import mock import pytest +from sqlalchemy import select from airflow.cli import cli_parser from airflow.cli.commands import connection_command @@ -167,9 +168,7 @@ def test_cli_connections_export_should_raise_error_if_fetching_connections_fails def my_side_effect(_): raise Exception("dummy exception") - mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = ( - my_side_effect - ) + mock_session.return_value.__enter__.return_value.scalars.side_effect = my_side_effect args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()]) with pytest.raises(Exception, match=r"dummy exception"): connection_command.connections_export(args) @@ -566,7 +565,7 @@ def test_cli_connection_add(self, cmd, expected_output, expected_conn): "port", "schema", ] - current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first() + current_conn = session.scalars(select(Connection).where(Connection.conn_id == conn_id)).first() assert expected_conn == {attr: getattr(current_conn, attr) for attr in comparable_attrs} def test_cli_connections_add_duplicate(self): @@ -676,7 +675,7 @@ def test_cli_delete_connections(self, session=None): assert "Successfully deleted connection with `conn_id`=new1" in stdout # Check deletions - result = session.query(Connection).filter(Connection.conn_id == "new1").first() + result = session.scalars(select(Connection).filter(Connection.conn_id == "new1")).first() assert result is None @@ -752,7 +751,7 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_ # Verify that the imported connections match the expected, sample connections with create_session() as session: - current_conns = session.query(Connection).all() + current_conns = session.scalars(select(Connection)).all() comparable_attrs = [ "conn_id", @@ -829,7 +828,7 @@ def test_cli_connections_import_should_not_overwrite_existing_connections( assert "Could not import connection new3: connection already exists." in stdout.getvalue() # Verify that the imported connections match the expected, sample connections - current_conns = session.query(Connection).all() + current_conns = session.scalars(select(Connection)).all() comparable_attrs = [ "conn_id", @@ -909,7 +908,7 @@ def test_cli_connections_import_should_overwrite_existing_connections( assert "Could not import connection new3: connection already exists." not in stdout.getvalue() # Verify that the imported connections match the expected, sample connections - current_conns = session.query(Connection).all() + current_conns = session.scalars(select(Connection)).all() comparable_attrs = [ "conn_id",