diff --git a/airflow-core/src/airflow/models/deadline.py b/airflow-core/src/airflow/models/deadline.py index f596c66301498..7af0cca3fc79d 100644 --- a/airflow-core/src/airflow/models/deadline.py +++ b/airflow-core/src/airflow/models/deadline.py @@ -242,15 +242,12 @@ def get_simple_context(): return { "dag_run": DAGRunResponse.model_validate(dagrun).model_dump(mode="json"), - "deadline": {"id": self.id, "deadline_time": self.deadline_time}, + "deadline": {"id": str(self.id), "deadline_time": self.deadline_time}, } if isinstance(self.callback, TriggererCallback): - # Update the callback with context before queuing - if "kwargs" not in self.callback.data: - self.callback.data["kwargs"] = {} - self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { - "context": get_simple_context() + self.callback.data = self.callback.data | { + "kwargs": (self.callback.data.get("kwargs") or {}) | {"context": get_simple_context()}, } self.callback.queue(session=session) @@ -258,14 +255,12 @@ def get_simple_context(): session.flush() elif isinstance(self.callback, ExecutorCallback): - if "kwargs" not in self.callback.data: - self.callback.data["kwargs"] = {} - self.callback.data["kwargs"] = (self.callback.data.get("kwargs") or {}) | { - "context": get_simple_context() + self.callback.data = self.callback.data | { + "kwargs": (self.callback.data.get("kwargs") or {}) | {"context": get_simple_context()}, + "deadline_id": str(self.id), + "dag_run_id": str(self.dagrun.id), + "dag_id": self.dagrun.dag_id, } - self.callback.data["deadline_id"] = str(self.id) - self.callback.data["dag_run_id"] = str(self.dagrun.id) - self.callback.data["dag_id"] = self.dagrun.dag_id self.callback.state = CallbackState.PENDING session.add(self.callback) diff --git a/airflow-core/tests/unit/models/test_deadline.py b/airflow-core/tests/unit/models/test_deadline.py index 4fc1aada1cc02..99f3d9b127b1b 100644 --- a/airflow-core/tests/unit/models/test_deadline.py +++ b/airflow-core/tests/unit/models/test_deadline.py @@ -252,10 +252,73 @@ def test_handle_miss(self, dagrun, session): context = callback_kwargs.pop("context") assert callback_kwargs == TEST_CALLBACK_KWARGS - assert context["deadline"]["id"] == deadline_orm.id + assert context["deadline"]["id"] == str(deadline_orm.id) assert context["deadline"]["deadline_time"].timestamp() == deadline_orm.deadline_time.timestamp() assert context["dag_run"] == DAGRunResponse.model_validate(dagrun).model_dump(mode="json") + @pytest.mark.db_test + def test_handle_miss_persists_triggerer_callback_context(self, dagrun, session): + deadline_orm = Deadline( + deadline_time=DEFAULT_DATE, + callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + dagrun_id=dagrun.id, + dag_id=dagrun.dag_id, + deadline_alert_id=None, + ) + session.add(deadline_orm) + session.flush() + + callback_id = deadline_orm.callback.id + deadline_id = deadline_orm.id + deadline_time = deadline_orm.deadline_time + expected_dag_run = DAGRunResponse.model_validate(dagrun).model_dump(mode="json") + + deadline_orm.handle_miss(session) + session.commit() + session.expunge_all() + + callback = session.scalar(select(Deadline).where(Deadline.id == deadline_id)).callback + assert callback.id == callback_id + + callback_kwargs = callback.data["kwargs"] + context = callback_kwargs["context"] + assert { + key: value for key, value in callback_kwargs.items() if key != "context" + } == TEST_CALLBACK_KWARGS + assert context["deadline"]["id"] == str(deadline_id) + assert context["deadline"]["deadline_time"].timestamp() == deadline_time.timestamp() + assert context["dag_run"] == expected_dag_run + + callback.trigger = None + session.commit() + + @pytest.mark.db_test + def test_handle_miss_persists_executor_callback_routing_data(self, dagrun, session): + deadline_orm = Deadline( + deadline_time=DEFAULT_DATE, + callback=SyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS), + dagrun_id=dagrun.id, + dag_id=dagrun.dag_id, + deadline_alert_id=None, + ) + session.add(deadline_orm) + session.flush() + + callback_id = deadline_orm.callback.id + deadline_id = deadline_orm.id + dagrun_id = dagrun.id + dag_id = dagrun.dag_id + + deadline_orm.handle_miss(session) + session.commit() + session.expunge_all() + + callback = session.scalar(select(Deadline).where(Deadline.id == deadline_id)).callback + assert callback.id == callback_id + assert callback.data["dag_run_id"] == str(dagrun_id) + assert callback.data["dag_id"] == dag_id + assert callback.data["deadline_id"] == str(deadline_id) + @pytest.mark.db_test class TestCalculatedDeadlineDatabaseCalls: