Skip to content

Commit

Permalink
Remove Pydantic classes from models/dag (#44509)
Browse files Browse the repository at this point in the history
* Fix: Remove Pydantic classes from models/dag

* Fix mypy error
  • Loading branch information
jason810496 authored Nov 30, 2024
1 parent 0c354e7 commit f58bd73
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
10 changes: 4 additions & 6 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@
from airflow.models.abstractoperator import TaskStateChangeCallback
from airflow.models.dagbag import DagBag
from airflow.models.operator import Operator
from airflow.serialization.pydantic.dag import DagModelPydantic
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.typing_compat import Literal

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -513,7 +511,7 @@ def get_next_data_interval(self, dag_model: DagModel) -> DataInterval | None:
# infer from the logical date.
return self.infer_automated_data_interval(dag_model.next_dagrun)

def get_run_data_interval(self, run: DagRun | DagRunPydantic) -> DataInterval:
def get_run_data_interval(self, run: DagRun) -> DataInterval:
"""
Get the data interval of this run.
Expand Down Expand Up @@ -873,7 +871,7 @@ def get_active_runs(self):

@staticmethod
@provide_session
def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> DagRun | DagRunPydantic:
def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> DagRun:
"""
Return the dag run for a given run_id if it exists, otherwise none.
Expand All @@ -885,7 +883,7 @@ def fetch_dagrun(dag_id: str, run_id: str, session: Session = NEW_SESSION) -> Da
return session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.run_id == run_id))

@provide_session
def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) -> DagRun | DagRunPydantic:
def get_dagrun(self, run_id: str, session: Session = NEW_SESSION) -> DagRun:
return DAG.fetch_dagrun(dag_id=self.dag_id, run_id=run_id, session=session)

@provide_session
Expand Down Expand Up @@ -2139,7 +2137,7 @@ def get_dagmodel(dag_id: str, session: Session = NEW_SESSION) -> DagModel | None

@classmethod
@provide_session
def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel | DagModelPydantic:
def get_current(cls, dag_id: str, session=NEW_SESSION) -> DagModel:
return session.scalar(select(cls).where(cls.dag_id == dag_id))

@provide_session
Expand Down
3 changes: 1 addition & 2 deletions airflow/utils/log/file_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
from airflow.models import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -271,7 +270,7 @@ def close(self):
@provide_session
def _render_filename_db_access(
*, ti: TaskInstance | TaskInstancePydantic, try_number: int, session=None
) -> tuple[DagRun | DagRunPydantic, TaskInstance | TaskInstancePydantic, str | None, str | None]:
) -> tuple[DagRun, TaskInstance | TaskInstancePydantic, str | None, str | None]:
ti = _ensure_ti(ti, session)
dag_run = ti.get_dagrun(session=session)
template = dag_run.get_log_template(session=session).filename
Expand Down

0 comments on commit f58bd73

Please sign in to comment.