diff --git a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py index 7fc388c9608a8..ba3af24415a0d 100644 --- a/providers/celery/src/airflow/providers/celery/executors/celery_executor.py +++ b/providers/celery/src/airflow/providers/celery/executors/celery_executor.py @@ -44,6 +44,7 @@ ) from airflow.providers.celery.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_2_PLUS from airflow.providers.common.compat.sdk import AirflowTaskTimeout, Stats +from airflow.utils.helpers import prune_dict from airflow.utils.state import TaskInstanceState log = logging.getLogger(__name__) @@ -210,7 +211,7 @@ def _send_workloads(self, workload_tuples_to_send: Sequence[WorkloadInCelery]): ): retries = self.workload_publish_retries[key] if retries < self.workload_publish_max_retries: - Stats.incr("celery.task_timeout_error") + Stats.incr("celery.task_timeout_error", tags=prune_dict({"team_name": self.team_name})) self.log.info( "[Try %s of %s] Celery Task Timeout Error for Workload: (%s).", self.workload_publish_retries[key] + 1, diff --git a/providers/celery/tests/unit/celery/executors/test_celery_executor.py b/providers/celery/tests/unit/celery/executors/test_celery_executor.py index c349efa32cb4e..f66c8e81ab3bf 100644 --- a/providers/celery/tests/unit/celery/executors/test_celery_executor.py +++ b/providers/celery/tests/unit/celery/executors/test_celery_executor.py @@ -39,7 +39,7 @@ from airflow.models.taskinstance import TaskInstance, TaskInstanceKey from airflow.providers.celery.executors import celery_executor, celery_executor_utils, default_celery from airflow.providers.celery.executors.celery_executor import CeleryExecutor -from airflow.providers.common.compat.sdk import conf +from airflow.providers.common.compat.sdk import AirflowTaskTimeout, conf from airflow.utils.state import State from tests_common.test_utils import db @@ -216,6 +216,64 @@ def 
@pytest.mark.backend("mysql", "postgres")
@pytest.mark.parametrize(
    ("team_name", "tags"),
    [
        # No team configured: the metric is expected to carry empty tags.
        pytest.param(None, {}, id="without_team"),
        # Team configured: the metric is expected to carry the team name.
        # NOTE(review): AIRFLOW_V_3_1_PLUS must be importable in this test
        # module; its import is not visible in the hunks shown here - confirm.
        pytest.param(
            "team_a",
            {"team_name": "team_a"},
            marks=pytest.mark.skipif(
                not AIRFLOW_V_3_1_PLUS,
                reason="team_name metrics require Airflow 3.1+",
            ),
            id="with_team",
        ),
    ],
)
@mock.patch("airflow.providers.celery.executors.celery_executor.Stats")
def test_send_workloads_emits_task_timeout_metric(self, mock_stats, team_name, tags):
    """
    Check the timeout metric emitted by ``CeleryExecutor._send_workloads``.

    The publish step is replaced by a stub that reports an ``AirflowTaskTimeout``
    (wrapped in ``ExceptionWithTraceback``) for a single task instance key.  The
    test then asserts that ``Stats.incr`` was called exactly once with
    ``"celery.task_timeout_error"`` and the parametrized ``tags``, and that the
    publish-retry counter for that key is ``1`` afterwards.

    :param mock_stats: mock standing in for the ``Stats`` name in the executor module.
    :param team_name: value assigned to ``executor.team_name`` before the call.
    :param tags: tags the metric call is expected to receive.
    """
    with _prepare_app():
        executor = celery_executor.CeleryExecutor()
        executor.team_name = team_name

        ti_key = TaskInstanceKey(
            dag_id="dag",
            task_id="task",
            run_id="run",
            try_number=1,
        )

        # Retry bookkeeping: first attempt, with room left for further retries.
        executor.workload_publish_max_retries = 3
        executor.workload_publish_retries[ti_key] = 0
        executor.queued_tasks[ti_key] = mock.Mock()

        # Result the stubbed publish step hands back: a timeout for ``ti_key``.
        publish_failure = celery_executor_utils.ExceptionWithTraceback(
            AirflowTaskTimeout(), "traceback"
        )
        stub_publish = mock.patch.object(
            executor,
            "_send_workloads_to_celery",
            return_value=[(ti_key, None, publish_failure)],
        )
        with stub_publish:
            executor._send_workloads([mock.Mock()])

        mock_stats.incr.assert_called_once_with(
            "celery.task_timeout_error",
            tags=tags,
        )

        assert executor.workload_publish_retries[ti_key] == 1