Skip to content

Commit

Permalink
Include changes to airflow cli
Browse files Browse the repository at this point in the history
  • Loading branch information
phanikumv committed Jul 7, 2023
1 parent 846f2c1 commit 79efd74
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
14 changes: 8 additions & 6 deletions airflow/cli/commands/connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Any
from urllib.parse import urlsplit, urlunsplit

from sqlalchemy import select
from sqlalchemy.orm import exc

from airflow.cli.simple_table import AirflowConsole
Expand Down Expand Up @@ -77,9 +78,10 @@ def connections_get(args):
def connections_list(args):
"""Lists all connections at the command line."""
with create_session() as session:
query = session.query(Connection)
query = select(Connection)
if args.conn_id:
query = query.filter(Connection.conn_id == args.conn_id)
query = query.where(Connection.conn_id == args.conn_id)
query = session.scalars(query)
conns = query.all()

AirflowConsole().print_as(
Expand Down Expand Up @@ -177,7 +179,7 @@ def connections_export(args):
raise SystemExit("Option `--serialization-format` may only be used with file type `env`.")

with create_session() as session:
connections = session.query(Connection).order_by(Connection.conn_id).all()
connections = session.scalars(select(Connection).order_by(Connection.conn_id)).all()

msg = _format_connections(
conns=connections,
Expand Down Expand Up @@ -265,7 +267,7 @@ def connections_add(args):
new_conn.set_extra(args.conn_extra)

with create_session() as session:
if not session.query(Connection).filter(Connection.conn_id == new_conn.conn_id).first():
if not session.scalar(select(Connection).where(Connection.conn_id == new_conn.conn_id).limit(1)):
session.add(new_conn)
msg = "Successfully added `conn_id`={conn_id} : {uri}"
msg = msg.format(
Expand Down Expand Up @@ -293,7 +295,7 @@ def connections_delete(args):
"""Deletes connection from DB."""
with create_session() as session:
try:
to_delete = session.query(Connection).filter(Connection.conn_id == args.conn_id).one()
to_delete = session.scalars(select(Connection).where(Connection.conn_id == args.conn_id)).one()
except exc.NoResultFound:
raise SystemExit(f"Did not find a connection with `conn_id`={args.conn_id}")
except exc.MultipleResultsFound:
Expand Down Expand Up @@ -326,7 +328,7 @@ def _import_helper(file_path: str, overwrite: bool) -> None:
print(f"Could not import connection. {e}")
continue

existing_conn_id = session.query(Connection.id).filter(Connection.conn_id == conn_id).scalar()
existing_conn_id = session.scalar(select(Connection.id).where(Connection.conn_id == conn_id))
if existing_conn_id is not None:
if not overwrite:
print(f"Could not import connection {conn_id}: connection already exists.")
Expand Down
15 changes: 7 additions & 8 deletions tests/cli/commands/test_connection_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from unittest import mock

import pytest
from sqlalchemy import select

from airflow.cli import cli_parser
from airflow.cli.commands import connection_command
Expand Down Expand Up @@ -167,9 +168,7 @@ def test_cli_connections_export_should_raise_error_if_fetching_connections_fails
def my_side_effect(_):
raise Exception("dummy exception")

mock_session.return_value.__enter__.return_value.query.return_value.order_by.side_effect = (
my_side_effect
)
mock_session.return_value.__enter__.return_value.scalars.side_effect = my_side_effect
args = self.parser.parse_args(["connections", "export", output_filepath.as_posix()])
with pytest.raises(Exception, match=r"dummy exception"):
connection_command.connections_export(args)
Expand Down Expand Up @@ -566,7 +565,7 @@ def test_cli_connection_add(self, cmd, expected_output, expected_conn):
"port",
"schema",
]
current_conn = session.query(Connection).filter(Connection.conn_id == conn_id).first()
current_conn = session.scalars(select(Connection).where(Connection.conn_id == conn_id)).first()
assert expected_conn == {attr: getattr(current_conn, attr) for attr in comparable_attrs}

def test_cli_connections_add_duplicate(self):
Expand Down Expand Up @@ -676,7 +675,7 @@ def test_cli_delete_connections(self, session=None):
assert "Successfully deleted connection with `conn_id`=new1" in stdout

# Check deletions
result = session.query(Connection).filter(Connection.conn_id == "new1").first()
result = session.scalars(select(Connection).filter(Connection.conn_id == "new1")).first()

assert result is None

Expand Down Expand Up @@ -752,7 +751,7 @@ def test_cli_connections_import_should_load_connections(self, mock_exists, mock_

# Verify that the imported connections match the expected, sample connections
with create_session() as session:
current_conns = session.query(Connection).all()
current_conns = session.scalars(select(Connection)).all()

comparable_attrs = [
"conn_id",
Expand Down Expand Up @@ -829,7 +828,7 @@ def test_cli_connections_import_should_not_overwrite_existing_connections(
assert "Could not import connection new3: connection already exists." in stdout.getvalue()

# Verify that the imported connections match the expected, sample connections
current_conns = session.query(Connection).all()
current_conns = session.scalars(select(Connection)).all()

comparable_attrs = [
"conn_id",
Expand Down Expand Up @@ -909,7 +908,7 @@ def test_cli_connections_import_should_overwrite_existing_connections(
assert "Could not import connection new3: connection already exists." not in stdout.getvalue()

# Verify that the imported connections match the expected, sample connections
current_conns = session.query(Connection).all()
current_conns = session.scalars(select(Connection)).all()

comparable_attrs = [
"conn_id",
Expand Down

0 comments on commit 79efd74

Please sign in to comment.