From d296cad70c7dbffc8a4e5e073188eb72034af1de Mon Sep 17 00:00:00 2001 From: Phong Bui Date: Thu, 20 Jun 2024 22:07:55 +0700 Subject: [PATCH] ISSUE-38834: Add tests and modify logic --- airflow/cli/commands/dag_command.py | 32 +++++++++++++------------- tests/cli/commands/test_dag_command.py | 13 +++++++++-- 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/airflow/cli/commands/dag_command.py b/airflow/cli/commands/dag_command.py index 66444a231d52b..c574d725e21b2 100644 --- a/airflow/cli/commands/dag_command.py +++ b/airflow/cli/commands/dag_command.py @@ -225,18 +225,24 @@ def dag_unpause(args) -> None: def set_is_paused(is_paused: bool, args) -> None: """Set is_paused for DAG by a given dag_id.""" should_apply = True - dags = [ - dag - for dag in get_dags(args.subdir, dag_id=args.dag_id, use_regex=args.treat_dag_id_as_regex) - if is_paused != dag.get_is_paused() - ] + with create_session() as session: + query = select(DagModel) + + if args.treat_dag_id_as_regex: + query = query.where(DagModel.dag_id.regexp_match(args.dag_id)) + else: + query = query.where(DagModel.dag_id == args.dag_id) + + query = query.where(DagModel.is_paused != is_paused) + + matched_dags = session.scalars(query).all() - if not dags: + if not matched_dags: print(f"No {'un' if is_paused else ''}paused DAGs were found") return if not args.yes and args.treat_dag_id_as_regex: - dags_ids = [dag.dag_id for dag in dags] + dags_ids = [dag.dag_id for dag in matched_dags] question = ( f"You are about to {'un' if not is_paused else ''}pause {len(dags_ids)} DAGs:\n" f"{','.join(dags_ids)}" @@ -245,17 +251,11 @@ def set_is_paused(is_paused: bool, args) -> None: should_apply = ask_yesno(question) if should_apply: - dags_models = [DagModel.get_dagmodel(dag.dag_id) for dag in dags] - for dag_model in dags_models: - if dag_model is not None: - dag_model.set_is_paused(is_paused=is_paused) + for dag_model in matched_dags: + dag_model.set_is_paused(is_paused=is_paused) AirflowConsole().print_as( - data=[ - {"dag_id": dag.dag_id, "is_paused": dag.get_is_paused()} - for dag in dags_models - if dag is not None - ], + data=[{"dag_id": dag.dag_id, "is_paused": not dag.get_is_paused()} for dag in matched_dags], output=args.output, ) else: diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index de8788ada8515..7ddda0fb862c8 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -717,10 +717,19 @@ def test_pause_regex_yes(self, mock_yesno): mock_yesno.assert_not_called() dag_command.dag_unpause(args) - def test_pause_non_existing_dag_error(self): + def test_pause_non_existing_dag_do_not_error(self): args = self.parser.parse_args(["dags", "pause", "non_existing_dag"]) - with pytest.raises(AirflowException): + with contextlib.redirect_stdout(StringIO()) as temp_stdout: dag_command.dag_pause(args) + out = temp_stdout.getvalue().strip().splitlines()[-1] + assert out == "No unpaused DAGs were found" + + def test_unpause_non_existing_dag_do_not_error(self): + args = self.parser.parse_args(["dags", "unpause", "non_existing_dag"]) + with contextlib.redirect_stdout(StringIO()) as temp_stdout: + dag_command.dag_unpause(args) + out = temp_stdout.getvalue().strip().splitlines()[-1] + assert out == "No paused DAGs were found" def test_unpause_already_unpaused_dag_do_not_error(self): args = self.parser.parse_args(["dags", "unpause", "example_bash_operator", "--yes"])