From ed09464667f7d83e25b4eb377839b873dbc5e925 Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Mon, 15 Nov 2021 17:09:22 +0000 Subject: [PATCH] Fix `airflow db check-migrations` This command broke after "Define datetime and StringID column types centrally in migrations" was merged, but we didn't notice as Github styles timedout checks badly. --- airflow/utils/db.py | 14 +++++++++----- tests/utils/test_db.py | 6 +++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index b09429d0dab42..86687494128a4 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -627,13 +627,17 @@ def check_migrations(timeout): :param timeout: Timeout for the migration in seconds :return: None """ - from alembic.runtime.migration import MigrationContext + from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory config = _get_alembic_config() script_ = ScriptDirectory.from_config(config) - with settings.engine.connect() as connection: - context = MigrationContext.configure(connection) + with EnvironmentContext( + config, + script_, + ) as env, settings.engine.connect() as connection: + env.configure(connection) + context = env.get_context() ticker = 0 while True: source_heads = set(script_.get_heads()) @@ -642,8 +646,8 @@ def check_migrations(timeout): break if ticker >= timeout: raise TimeoutError( - f"There are still unapplied migrations after {ticker} seconds. " - f"Migration Head(s) in DB: {db_heads} | Migration Head(s) in Source Code: {source_heads}" + f"There are still unapplied migrations after {ticker} seconds. Migration" + f"Head(s) in DB: {db_heads} | Migration Head(s) in Source Code: {source_heads}" ) ticker += 1 time.sleep(1) diff --git a/tests/utils/test_db.py b/tests/utils/test_db.py index d0c82a10a9011..39ca43b43a9c0 100644 --- a/tests/utils/test_db.py +++ b/tests/utils/test_db.py @@ -29,7 +29,7 @@ from airflow.models import Base as airflow_base from airflow.settings import engine -from airflow.utils.db import create_default_connections +from airflow.utils.db import check_migrations, create_default_connections class TestDb(unittest.TestCase): @@ -103,3 +103,7 @@ def test_default_connections_sort(self): source = inspect.getsource(create_default_connections) src = pattern.findall(source) assert sorted(src) == src + + def test_check_migrations(self): + # Should run without error. Can't easily test the behaviour, but we can check it works + check_migrations(1)