Skip to content

Commit

Permalink
Remove AIP-44 from Job (#44493)
Browse files Browse the repository at this point in the history
Part of #44436
  • Loading branch information
potiuk authored Nov 30, 2024
1 parent 84907f1 commit 55e419e
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 106 deletions.
128 changes: 26 additions & 102 deletions airflow/jobs/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,19 @@
from sqlalchemy.orm import backref, foreign, relationship
from sqlalchemy.orm.session import make_transient

from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.executors.executor_loader import ExecutorLoader
from airflow.listeners.listener import get_listener_manager
from airflow.models.base import ID_LEN, Base
from airflow.serialization.pydantic.job import JobPydantic
from airflow.stats import Stats
from airflow.traces.tracer import Trace, add_span
from airflow.utils import timezone
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.platform import getuser
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import JobState

Expand Down Expand Up @@ -168,7 +165,10 @@ def kill(self, session: Session = NEW_SESSION) -> NoReturn:
except Exception as e:
self.log.error("on_kill() method failed: %s", e)

Job._kill(job_id=self.id, session=session)
job = session.scalar(select(Job).where(Job.id == self.id, session=session).limit(1))
job.end_date = timezone.utcnow()
session.merge(job)
session.commit()
raise AirflowException("Job shut down externally.")

def on_kill(self):
Expand Down Expand Up @@ -201,7 +201,7 @@ def heartbeat(
try:
span.set_attribute("heartbeat", str(self.latest_heartbeat))
# This will cause it to load from the db
self._merge_from(Job._fetch_from_db(self, session))
session.merge(self)
previous_heartbeat = self.latest_heartbeat

if self.state == JobState.RESTARTING:
Expand All @@ -217,17 +217,19 @@ def heartbeat(
if span.is_recording():
span.add_event(name="sleep", attributes={"sleep_for": sleep_for})
sleep(sleep_for)

job = Job._update_heartbeat(job=self, session=session)
self._merge_from(job)
time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
if time_since_last_heartbeat > health_check_threshold_value:
self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
# At this point, the DB has updated.
previous_heartbeat = self.latest_heartbeat

heartbeat_callback(session)
# Update last heartbeat time
with create_session() as session:
# Make the session aware of this object
session.merge(self)
self.latest_heartbeat = timezone.utcnow()
session.commit()
time_since_last_heartbeat = (timezone.utcnow() - previous_heartbeat).total_seconds()
health_check_threshold_value = health_check_threshold(self.job_type, self.heartrate)
if time_since_last_heartbeat > health_check_threshold_value:
self.log.info("Heartbeat recovered after %.2f seconds", time_since_last_heartbeat)
# At this point, the DB has updated.
previous_heartbeat = self.latest_heartbeat
heartbeat_callback(session)
self.log.debug("[heartbeat]")
self.heartbeat_failed = False
except OperationalError:
Expand Down Expand Up @@ -260,36 +262,23 @@ def prepare_for_execution(self, session: Session = NEW_SESSION):
Stats.incr(self.__class__.__name__.lower() + "_start", 1, 1)
self.state = JobState.RUNNING
self.start_date = timezone.utcnow()
self._merge_from(Job._add_to_db(job=self, session=session))
session.add(self)
session.commit()
make_transient(self)

@provide_session
def complete_execution(self, session: Session = NEW_SESSION):
get_listener_manager().hook.before_stopping(component=self)
self.end_date = timezone.utcnow()
Job._update_in_db(job=self, session=session)
session.merge(self)
session.commit()
Stats.incr(self.__class__.__name__.lower() + "_end", 1, 1)

@provide_session
def most_recent_job(self, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
def most_recent_job(self, session: Session = NEW_SESSION) -> Job | None:
"""Return the most recent job of this type, if any, based on last heartbeat received."""
return most_recent_job(self.job_type, session=session)

def _merge_from(self, job: Job | JobPydantic | None):
if job is None:
self.log.error("Job is empty: %s", self.id)
return
self.id = job.id
self.dag_id = job.dag_id
self.state = job.state
self.job_type = job.job_type
self.start_date = job.start_date
self.end_date = job.end_date
self.latest_heartbeat = job.latest_heartbeat
self.executor_class = job.executor_class
self.hostname = job.hostname
self.unixname = job.unixname

@staticmethod
def _heartrate(job_type: str) -> float:
if job_type == "TriggererJob":
Expand All @@ -312,74 +301,9 @@ def _is_alive(
and (timezone.utcnow() - latest_heartbeat).total_seconds() < health_check_threshold_value
)

@staticmethod
@internal_api_call
@provide_session
def _kill(job_id: str, session: Session = NEW_SESSION) -> Job | JobPydantic:
job = session.scalar(select(Job).where(Job.id == job_id).limit(1))
job.end_date = timezone.utcnow()
session.merge(job)
session.commit()
return job

@staticmethod
@internal_api_call
@provide_session
@retry_db_transaction
def _fetch_from_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
if isinstance(job, Job):
# not Internal API
session.merge(job)
return job
# Internal API,
return session.scalar(select(Job).where(Job.id == job.id).limit(1))

@staticmethod
@internal_api_call
@provide_session
def _add_to_db(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic:
if isinstance(job, JobPydantic):
orm_job = Job()
orm_job._merge_from(job)
else:
orm_job = job
session.add(orm_job)
session.commit()
return orm_job

@staticmethod
@internal_api_call
@provide_session
def _update_in_db(job: Job | JobPydantic, session: Session = NEW_SESSION):
if isinstance(job, Job):
# not Internal API
session.merge(job)
session.commit()
# Internal API.
orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1))
if orm_job is None:
return
orm_job._merge_from(job)
session.merge(orm_job)
session.commit()

@staticmethod
@internal_api_call
@provide_session
@retry_db_transaction
def _update_heartbeat(job: Job | JobPydantic, session: Session = NEW_SESSION) -> Job | JobPydantic:
orm_job: Job | None = session.scalar(select(Job).where(Job.id == job.id).limit(1))
if orm_job is None:
return job
orm_job.latest_heartbeat = timezone.utcnow()
session.merge(orm_job)
session.commit()
return orm_job


@internal_api_call
@provide_session
def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | JobPydantic | None:
def most_recent_job(job_type: str, session: Session = NEW_SESSION) -> Job | None:
"""
Return the most recent job of this type, if any, based on last heartbeat received.
Expand Down Expand Up @@ -434,7 +358,7 @@ def execute_job(job: Job, execute_callable: Callable[[], int | None]) -> int | N
which happens in the "complete_execution" step (which again can be executed locally in case of
database operations or over the Internal API call).
:param job: Job to execute - it can be either DB job or its Pydantic serialized version. It does
:param job: Job to execute - DB job. It does
not really matter, because apart from running the heartbeat and setting the state,
the runner should not modify the job state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def _initialize_method_map() -> dict[str, Callable]:
expand_alias_to_assets,
FileTaskHandler._render_filename_db_access,
Job._add_to_db,
Job._fetch_from_db,
Job._kill,
Job._update_heartbeat,
Job._update_in_db,
Expand Down
9 changes: 6 additions & 3 deletions tests/jobs/test_base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,13 @@ def test_is_alive_scheduler(self, job_type):
job.latest_heartbeat = timezone.utcnow() - datetime.timedelta(seconds=10)
assert job.is_alive() is False, "Completed jobs even with recent heartbeat should not be alive"

def test_heartbeat_failed(self, caplog):
@patch("airflow.jobs.job.create_session")
def test_heartbeat_failed(self, mock_create_session, caplog):
when = timezone.utcnow() - datetime.timedelta(seconds=60)
mock_session = Mock(name="MockSession")
mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
with create_session() as session:
mock_session = Mock(spec_set=session, name="MockSession")
mock_create_session.return_value.__enter__.return_value = mock_session
mock_session.commit.side_effect = OperationalError("Force fail", {}, None)
job = Job(heartrate=10, state=State.RUNNING)
job.latest_heartbeat = when
with caplog.at_level(logging.ERROR):
Expand Down

0 comments on commit 55e419e

Please sign in to comment.