Skip to content

Commit

Permalink
Define datetime and StringID column types centrally in migrations (#1…
Browse files Browse the repository at this point in the history
…9408)

We have various flavours of the code all over the place in many
migration files -- which leads to duplication and things not being in
sync.

This pulls them once in to a central location.
  • Loading branch information
ashb authored Nov 11, 2021
1 parent 2bd4b55 commit 7622f5e
Show file tree
Hide file tree
Showing 24 changed files with 238 additions and 264 deletions.
86 changes: 86 additions & 0 deletions airflow/migrations/db_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import sys

import sqlalchemy as sa
from alembic import context
from lazy_object_proxy import Proxy

######################################
# Note about this module:
#
# It loads the specific type dynamically at runtime. For IDE/typing support
# there is an associated db_types.pyi. If you add a new type in here, add a
# simple version in there too.
######################################


def _mssql_use_date_time2():
conn = context.get_bind()
result = conn.execute(
"""SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
).fetchone()
mssql_version = result[0]
return mssql_version not in ("2000", "2005")


MSSQL_USE_DATE_TIME2 = Proxy(_mssql_use_date_time2)


def _mssql_TIMESTAMP():
from sqlalchemy.dialects import mssql

return mssql.DATETIME2(precision=6) if MSSQL_USE_DATE_TIME2 else mssql.DATETIME


def _mysql_TIMESTAMP():
from sqlalchemy.dialects import mysql

return mysql.TIMESTAMP(fsp=6, timezone=True)


def _sa_TIMESTAMP():
return sa.TIMESTAMP(timezone=True)


def _sa_StringID():
from airflow.models.base import StringID

return StringID


def __getattr__(name):
if name in ["TIMESTAMP", "StringID"]:
dialect = context.get_bind().dialect.name
module = globals()

# Lookup the type based on the dialect specific type, or fallback to the generic type
type_ = module.get(f'_{dialect}_{name}', None) or module.get(f'_sa_{name}')
val = module[name] = type_()
return val

raise AttributeError(f"module {__name__} has no attribute {name}")


if sys.version_info < (3, 7):
from pep562 import Pep562

Pep562(__name__)
28 changes: 28 additions & 0 deletions airflow/migrations/db_types.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import sqlalchemy as sa

TIMESTAMP = sa.TIMESTAMP
"""Database specific timestamp with timezone"""

StringID = sa.String
"""String column type with correct DB collation applied"""

MSSQL_USE_DATE_TIME2: bool
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from alembic import op
from sqlalchemy.engine.reflection import Inspector

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '03afc6b6f902'
Expand Down Expand Up @@ -63,7 +63,7 @@ def upgrade():
op.alter_column(
table_name='ab_view_menu',
column_name='name',
type_=sa.String(length=250, **COLLATION_ARGS),
type_=StringID(length=250),
nullable=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import mysql

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import TIMESTAMP, StringID

# revision identifiers, used by Alembic.
revision = '0a2a5b66e19d'
Expand All @@ -38,43 +37,25 @@
INDEX_NAME = 'idx_' + TABLE_NAME + '_dag_task_date'


# For Microsoft SQL Server, TIMESTAMP is a row-id type,
# having nothing to do with date-time. DateTime() will
# be sufficient.
def mssql_timestamp():
return sa.DateTime()


def mysql_timestamp():
return mysql.TIMESTAMP(fsp=6)


def sa_timestamp():
return sa.TIMESTAMP(timezone=True)


def upgrade():
# See 0e2a74e0fc9f_add_time_zone_awareness
conn = op.get_bind()
if conn.dialect.name == 'mysql':
timestamp = mysql_timestamp
elif conn.dialect.name == 'mssql':
timestamp = mssql_timestamp
else:
timestamp = sa_timestamp
timestamp = TIMESTAMP
if op.get_bind().dialect.name == 'mssql':
# We need to keep this as it was for this old migration on mssql
timestamp = sa.DateTime()

op.create_table(
TABLE_NAME,
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('task_id', StringID(), nullable=False),
sa.Column('dag_id', StringID(), nullable=False),
# use explicit server_default=None otherwise mysql implies defaults for first timestamp column
sa.Column('execution_date', timestamp(), nullable=False, server_default=None),
sa.Column('execution_date', timestamp, nullable=False, server_default=None),
sa.Column('try_number', sa.Integer(), nullable=False),
sa.Column('start_date', timestamp(), nullable=False),
sa.Column('end_date', timestamp(), nullable=False),
sa.Column('start_date', timestamp, nullable=False),
sa.Column('end_date', timestamp, nullable=False),
sa.Column('duration', sa.Integer(), nullable=False),
sa.Column('reschedule_date', timestamp(), nullable=False),
sa.Column('reschedule_date', timestamp, nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.ForeignKeyConstraint(
['task_id', 'dag_id', 'execution_date'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
"""

from alembic import op
from sqlalchemy import TIMESTAMP, Column
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy import Column

from airflow.migrations.db_types import TIMESTAMP

# Revision identifiers, used by Alembic.
revision = "142555e44c17"
Expand All @@ -35,36 +36,14 @@
depends_on = None


def _use_date_time2(conn):
result = conn.execute(
"""SELECT CASE WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '8%' THEN '2000' WHEN CONVERT(VARCHAR(128), SERVERPROPERTY ('productversion'))
like '9%' THEN '2005' ELSE '2005Plus' END AS MajorVersion"""
).fetchone()
mssql_version = result[0]
return mssql_version not in ("2000", "2005")


def _get_timestamp(conn):
dialect_name = conn.dialect.name
if dialect_name == "mysql":
return mysql.TIMESTAMP(fsp=6, timezone=True)
if dialect_name != "mssql":
return TIMESTAMP(timezone=True)
if _use_date_time2(conn):
return mssql.DATETIME2(precision=6)
return mssql.DATETIME


def upgrade():
"""Apply data_interval fields to DagModel and DagRun."""
column_type = _get_timestamp(op.get_bind())
with op.batch_alter_table("dag_run") as batch_op:
batch_op.add_column(Column("data_interval_start", column_type))
batch_op.add_column(Column("data_interval_end", column_type))
batch_op.add_column(Column("data_interval_start", TIMESTAMP))
batch_op.add_column(Column("data_interval_end", TIMESTAMP))
with op.batch_alter_table("dag") as batch_op:
batch_op.add_column(Column("next_dagrun_data_interval_start", column_type))
batch_op.add_column(Column("next_dagrun_data_interval_end", column_type))
batch_op.add_column(Column("next_dagrun_data_interval_start", TIMESTAMP))
batch_op.add_column(Column("next_dagrun_data_interval_end", TIMESTAMP))


def downgrade():
Expand Down
6 changes: 3 additions & 3 deletions airflow/migrations/versions/1b38cef5b76e_add_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import sqlalchemy as sa
from alembic import op

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '1b38cef5b76e'
Expand All @@ -40,10 +40,10 @@ def upgrade():
op.create_table(
'dag_run',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=True),
sa.Column('dag_id', StringID(), nullable=True),
sa.Column('execution_date', sa.DateTime(), nullable=True),
sa.Column('state', sa.String(length=50), nullable=True),
sa.Column('run_id', sa.String(length=250, **COLLATION_ARGS), nullable=True),
sa.Column('run_id', StringID(), nullable=True),
sa.Column('external_trigger', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('dag_id', 'execution_date'),
Expand Down
10 changes: 6 additions & 4 deletions airflow/migrations/versions/3c20cacc0044_add_dagrun_run_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.ext.declarative import declarative_base

from airflow.models.base import ID_LEN
from airflow.migrations.db_types import StringID
from airflow.utils import timezone
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
Expand All @@ -55,12 +55,12 @@ class DagRun(Base): # type: ignore
__tablename__ = "dag_run"

id = Column(Integer, primary_key=True)
dag_id = Column(String(ID_LEN))
dag_id = Column(StringID())
execution_date = Column(UtcDateTime, default=timezone.utcnow)
start_date = Column(UtcDateTime, default=timezone.utcnow)
end_date = Column(UtcDateTime)
_state = Column('state', String(50), default=State.RUNNING)
run_id = Column(String(ID_LEN))
run_id = Column(StringID())
external_trigger = Column(Boolean, default=True)
run_type = Column(String(50), nullable=False)
conf = Column(PickleType)
Expand Down Expand Up @@ -96,7 +96,9 @@ def upgrade():

# Make run_type not nullable
with op.batch_alter_table("dag_run") as batch_op:
batch_op.alter_column("run_type", type_=run_type_col_type, nullable=False)
batch_op.alter_column(
"run_type", existing_type=run_type_col_type, type_=run_type_col_type, nullable=False
)


def downgrade():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import sqlalchemy as sa
from alembic import op

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '64de9cddf6c9'
Expand All @@ -39,8 +39,8 @@ def upgrade():
op.create_table(
'task_fail',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('task_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('task_id', StringID(), nullable=False),
sa.Column('dag_id', StringID(), nullable=False),
sa.Column('execution_date', sa.DateTime(), nullable=False),
sa.Column('start_date', sa.DateTime(), nullable=True),
sa.Column('end_date', sa.DateTime(), nullable=True),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from sqlalchemy import Column, Float, Integer, PickleType, String
from sqlalchemy.ext.declarative import declarative_base

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID
from airflow.utils.session import create_session
from airflow.utils.sqlalchemy import UtcDateTime

Expand Down Expand Up @@ -60,8 +60,8 @@ class TaskInstance(Base): # type: ignore

__tablename__ = "task_instance"

task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
task_id = Column(StringID(), primary_key=True)
dag_id = Column(StringID(), primary_key=True)
execution_date = Column(UtcDateTime, primary_key=True)
start_date = Column(UtcDateTime)
end_date = Column(UtcDateTime)
Expand Down
4 changes: 2 additions & 2 deletions airflow/migrations/versions/7939bcff74ba_add_dagtags_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import sqlalchemy as sa
from alembic import op

from airflow.models.base import COLLATION_ARGS
from airflow.migrations.db_types import StringID

# revision identifiers, used by Alembic.
revision = '7939bcff74ba'
Expand All @@ -41,7 +41,7 @@ def upgrade():
op.create_table(
'dag_tag',
sa.Column('name', sa.String(length=100), nullable=False),
sa.Column('dag_id', sa.String(length=250, **COLLATION_ARGS), nullable=False),
sa.Column('dag_id', StringID(), nullable=False),
sa.ForeignKeyConstraint(
['dag_id'],
['dag.dag_id'],
Expand Down
Loading

0 comments on commit 7622f5e

Please sign in to comment.