diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 056151cefa1c1..b78ea5ef124bf 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -57,7 +57,16 @@
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
-from sqlalchemy.orm import Mapped, declared_attr, joinedload, mapped_column, relationship, synonym, validates
+from sqlalchemy.orm import (
+    Mapped,
+    declared_attr,
+    joinedload,
+    mapped_column,
+    relationship,
+    selectinload,
+    synonym,
+    validates,
+)
 from sqlalchemy.orm.exc import StaleDataError
 from sqlalchemy.sql.expression import false, select
 from sqlalchemy.sql.functions import coalesce
@@ -1164,6 +1173,36 @@ def _emit_dagrun_span(self, state: DagRunState):
         span.set_status(status_code)
         span.end()
 
+    def _handle_missed_deadlines(self, *, session: Session) -> None:
+        """
+        Call ``handle_miss`` on every deadline of this run that is not yet marked missed.
+
+        The deadlines are selected with row locks (``skip_locked=True``). Any exception raised
+        while querying or handling is logged as a warning and swallowed, so a failure here does
+        not propagate to the caller.
+
+        :param session: SQLAlchemy session used for the query and passed to ``handle_miss``.
+        """
+        deadline_query = (
+            select(Deadline)
+            .where(Deadline.dagrun_id == self.id)
+            .where(~Deadline.missed)
+            .options(selectinload(Deadline.callback), selectinload(Deadline.dagrun))
+        )
+        try:
+            for deadline in session.scalars(
+                with_row_locks(
+                    deadline_query,
+                    of=Deadline,
+                    session=session,
+                    skip_locked=True,
+                    key_share=False,
+                )
+            ):
+                deadline.handle_miss(session)
+        except Exception:
+            self.log.warning("Failed to handle missed deadlines for %s", self, exc_info=True)
+
     @provide_session
     def update_state(
         self, *, session: Session = NEW_SESSION, execute_callbacks: bool = True
@@ -1264,6 +1303,9 @@ def recalculate(self) -> _UnfinishedStates:
                 )
                 self._check_last_n_dagruns_failed(dag.dag_id, dag.max_consecutive_failed_dag_runs, session)
 
+            if dag.deadline:
+                self._handle_missed_deadlines(session=session)
+
         # if all leaves succeeded and no unfinished tasks, the run succeeded
         elif not unfinished.tis and all(x.state in State.success_states for x in tis_for_dagrun_state):
             self.log.info("Marking run %s successful", self)
@@ -1321,6 +1363,9 @@ def recalculate(self) -> _UnfinishedStates:
                 execute=execute_callbacks,
             )
 
+            if dag.deadline:
+                self._handle_missed_deadlines(session=session)
+
         # finally, if the leaves aren't done, the dag is still running
         else:
             self.set_state(DagRunState.RUNNING)
diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py
index f596c66301498..7af0cca3fc79d 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -242,15 +242,13 @@ def get_simple_context():
 
             return {
                 "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"),
-                "deadline": {"id": self.id, "deadline_time": self.deadline_time},
+                "deadline": {"id": str(self.id), "deadline_time": self.deadline_time},
             }
 
         if isinstance(self.callback, TriggererCallback):
-            # Update the callback with context before queuing
-            if "kwargs" not in self.callback.data:
-                self.callback.data["kwargs"] = {}
-            self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | {
-                "context": get_simple_context()
+            # Rebind ``data`` (no in-place mutation) -- presumably so the ORM sees the change; TODO confirm.
+            self.callback.data = self.callback.data | {
+                "kwargs": (self.callback.data.get("kwargs") or {}) | {"context": get_simple_context()},
             }
 
             self.callback.queue(session=session)
@@ -258,14 +256,13 @@ def get_simple_context():
             session.flush()
 
         elif isinstance(self.callback, ExecutorCallback):
-            if "kwargs" not in self.callback.data:
-                self.callback.data["kwargs"] = {}
-            self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | {
-                "context": get_simple_context()
+            # Rebind ``data`` (no in-place mutation) -- presumably so the ORM sees the change; TODO confirm.
+            self.callback.data = self.callback.data | {
+                "kwargs": (self.callback.data.get("kwargs") or {}) | {"context": get_simple_context()},
+                "deadline_id": str(self.id),
+                "dag_run_id": str(self.dagrun.id),
+                "dag_id": self.dagrun.dag_id,
             }
-            self.callback.data["deadline_id"] = str(self.id)
-            self.callback.data["dag_run_id"] = str(self.dagrun.id)
-            self.callback.data["dag_id"] = self.dagrun.dag_id
 
             self.callback.state = CallbackState.PENDING
             session.add(self.callback)
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 4ec048a7f4bcc..4242ed72f54c9 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -1505,6 +1505,76 @@ def test_dagrun_success_handles_empty_deadline_list(self, mock_prune, dag_maker,
         mock_prune.assert_not_called()
         assert dag_run.state == DagRunState.SUCCESS
 
+    @mock.patch.object(Deadline, "handle_miss")
+    @mock.patch.object(Deadline, "prune_deadlines")
+    def test_dagrun_failure_handles_pending_deadline(
+        self, mock_prune, mock_handle_miss, session, deadline_test_dag
+    ):
+        scheduler_dag = deadline_test_dag(
+            deadline=DeadlineAlert(
+                reference=DeadlineReference.DAGRUN_QUEUED_AT,
+                interval=datetime.timedelta(hours=3),
+                callback=AsyncCallback(empty_callback_for_deadline),
+            )
+        )
+
+        dag_run = self.create_dag_run(
+            dag=scheduler_dag,
+            task_states={"task_1": TaskInstanceState.SUCCESS, "task_2": TaskInstanceState.FAILED},
+            session=session,
+        )
+        dag_run.dag = scheduler_dag
+        session.add(
+            Deadline(
+                deadline_time=timezone.utcnow() + datetime.timedelta(hours=3),
+                callback=AsyncCallback(empty_callback_for_deadline),
+                dagrun_id=dag_run.id,
+                dag_id=dag_run.dag_id,
+                deadline_alert_id=None,
+            )
+        )
+        session.flush()
+
+        dag_run.update_state(session=session)
+
+        mock_handle_miss.assert_called_once()
+        mock_prune.assert_not_called()
+        assert dag_run.state == DagRunState.FAILED
+
+    @mock.patch.object(Deadline, "handle_miss", side_effect=RuntimeError("deadline failure"))
+    def test_dagrun_failure_ignores_missed_deadline_handling_error(
+        self, mock_handle_miss, session, deadline_test_dag
+    ):
+        scheduler_dag = deadline_test_dag(
+            deadline=DeadlineAlert(
+                reference=DeadlineReference.DAGRUN_QUEUED_AT,
+                interval=datetime.timedelta(hours=3),
+                callback=AsyncCallback(empty_callback_for_deadline),
+            )
+        )
+
+        dag_run = self.create_dag_run(
+            dag=scheduler_dag,
+            task_states={"task_1": TaskInstanceState.SUCCESS, "task_2": TaskInstanceState.FAILED},
+            session=session,
+        )
+        dag_run.dag = scheduler_dag
+        session.add(
+            Deadline(
+                deadline_time=timezone.utcnow() + datetime.timedelta(hours=3),
+                callback=AsyncCallback(empty_callback_for_deadline),
+                dagrun_id=dag_run.id,
+                dag_id=dag_run.dag_id,
+                deadline_alert_id=None,
+            )
+        )
+        session.flush()
+
+        dag_run.update_state(session=session)
+
+        mock_handle_miss.assert_called_once()
+        assert dag_run.state == DagRunState.FAILED
+
     @mock.patch.object(Variable, "get")
     @mock.patch.object(Deadline, "prune_deadlines")
     def test_dagrun_deadline_variable_interval_stable(self, _, mock_get, session, deadline_test_dag):
diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py
index 4fc1aada1cc02..99f3d9b127b1b 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -252,10 +252,73 @@ def test_handle_miss(self, dagrun, session):
         context = callback_kwargs.pop("context")
 
         assert callback_kwargs == TEST_CALLBACK_KWARGS
-        assert context["deadline"]["id"] == deadline_orm.id
+        assert context["deadline"]["id"] == str(deadline_orm.id)
         assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp()
         assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
 
+    @pytest.mark.db_test
+    def test_handle_miss_persists_triggerer_callback_context(self, dagrun, session):
+        deadline_orm = Deadline(
+            deadline_time=DEFAULT_DATE,
+            callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+            dagrun_id=dagrun.id,
+            dag_id=dagrun.dag_id,
+            deadline_alert_id=None,
+        )
+        session.add(deadline_orm)
+        session.flush()
+
+        callback_id = deadline_orm.callback.id
+        deadline_id = deadline_orm.id
+        deadline_time = deadline_orm.deadline_time
+        expected_dag_run = DAGRunResponse.model_validate(dagrun).model_dump(mode="json")
+
+        deadline_orm.handle_miss(session)
+        session.commit()
+        session.expunge_all()
+
+        callback = session.scalar(select(Deadline).where(Deadline.id == deadline_id)).callback
+        assert callback.id == callback_id
+
+        callback_kwargs = callback.data["kwargs"]
+        context = callback_kwargs["context"]
+        assert {
+            key: value for key, value in callback_kwargs.items() if key != "context"
+        } == TEST_CALLBACK_KWARGS
+        assert context["deadline"]["id"] == str(deadline_id)
+        assert context["deadline"]["deadline_time"].timestamp() == deadline_time.timestamp()
+        assert context["dag_run"] == expected_dag_run
+
+        callback.trigger = None
+        session.commit()
+
+    @pytest.mark.db_test
+    def test_handle_miss_persists_executor_callback_routing_data(self, dagrun, session):
+        deadline_orm = Deadline(
+            deadline_time=DEFAULT_DATE,
+            callback=SyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+            dagrun_id=dagrun.id,
+            dag_id=dagrun.dag_id,
+            deadline_alert_id=None,
+        )
+        session.add(deadline_orm)
+        session.flush()
+
+        callback_id = deadline_orm.callback.id
+        deadline_id = deadline_orm.id
+        dagrun_id = dagrun.id
+        dag_id = dagrun.dag_id
+
+        deadline_orm.handle_miss(session)
+        session.commit()
+        session.expunge_all()
+
+        callback = session.scalar(select(Deadline).where(Deadline.id == deadline_id)).callback
+        assert callback.id == callback_id
+        assert callback.data["dag_run_id"] == str(dagrun_id)
+        assert callback.data["dag_id"] == dag_id
+        assert callback.data["deadline_id"] == str(deadline_id)
+
 
 @pytest.mark.db_test
 class TestCalculatedDeadlineDatabaseCalls: