Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import prune_dict

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -94,7 +95,9 @@ def __init__(self, *args, **kwargs):
from airflow.providers.common.compat.sdk import conf

self.conf = conf

# Backwards compatibility for Airflow versions that do not define team_name.
if not hasattr(self, "team_name"):
self.team_name = None
self.lambda_function_name = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.FUNCTION_NAME)
self.sqs_queue_url = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.QUEUE_URL)
self.dlq_url = self.conf.get(CONFIG_GROUP_NAME, AllLambdaConfigKeys.DLQ_URL)
Expand Down Expand Up @@ -558,7 +561,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

:param tis: The task instances to adopt.
"""
with Stats.timer("lambda_executor.adopt_task_instances.duration"):
with Stats.timer(
"lambda_executor.adopt_task_instances.duration", tags=prune_dict({"team_name": self.team_name})
):
adopted_tis: list[TaskInstance] = []

if serialized_workload_keys := [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from airflow.providers.amazon.aws.hooks.batch_client import BatchClientHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.helpers import merge_dicts, prune_dict

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -116,7 +116,9 @@ def __init__(self, *args, **kwargs):
from airflow.providers.common.compat.sdk import conf

self.conf = conf

# Backwards compatibility for Airflow versions that do not define team_name.
if not hasattr(self, "team_name"):
self.team_name = None
self.attempts_since_last_successful_connection = 0
self.load_batch_connection(check_connection=False)
self.IS_BOTO_CONNECTION_HEALTHY = False
Expand Down Expand Up @@ -540,7 +542,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
"""
with Stats.timer("batch_executor.adopt_task_instances.duration"):
with Stats.timer(
"batch_executor.adopt_task_instances.duration", tags=prune_dict({"team_name": self.team_name})
):
adopted_tis: list[TaskInstance] = []

if job_ids := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
from airflow.providers.amazon.aws.hooks.ecs import EcsHook
from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from airflow.providers.common.compat.sdk import AirflowException, Stats, timezone
from airflow.utils.helpers import merge_dicts
from airflow.utils.helpers import merge_dicts, prune_dict
from airflow.utils.state import State

if TYPE_CHECKING:
Expand Down Expand Up @@ -126,6 +126,9 @@ def __init__(self, *args, **kwargs):
from airflow.providers.common.compat.sdk import conf

self.conf = conf
# Backwards compatibility for Airflow versions that do not define team_name.
if not hasattr(self, "team_name"):
self.team_name = None

self.cluster = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
self.container_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
Expand Down Expand Up @@ -641,7 +644,9 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task

Anything that is not adopted will be cleared by the scheduler and becomes eligible for re-scheduling.
"""
with Stats.timer("ecs_executor.adopt_task_instances.duration"):
with Stats.timer(
"ecs_executor.adopt_task_instances.duration", tags=prune_dict({"team_name": self.team_name})
):
adopted_tis: list[TaskInstance] = []

if task_arns := [ti.external_executor_id for ti in tis if ti.external_executor_id]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,33 @@ def test_try_adopt_task_instances_callback(self, mock_executor):

assert len(not_adopted) == 0

@pytest.mark.parametrize(
("team_name", "expected_tags"),
[
pytest.param(None, {}, id="without_team"),
pytest.param(
"team_a",
{"team_name": "team_a"},
id="with_team",
marks=pytest.mark.skipif(
not AIRFLOW_V_3_1_PLUS, reason="Multi-team support requires Airflow 3.1+"
),
),
],
)
@mock.patch.object(lambda_executor.Stats, "timer")
def test_try_adopt_task_instances_emits_team_name_tag(
self, mock_timer, mock_executor, team_name, expected_tags
):
"""Test that the adopt task instances duration metric is tagged with the team name."""
mock_executor.team_name = team_name

mock_executor.try_adopt_task_instances([])

mock_timer.assert_called_once_with(
"lambda_executor.adopt_task_instances.duration", tags=expected_tags
)

@mock.patch("airflow.providers.amazon.aws.executors.aws_lambda.lambda_executor.timezone")
def test_end_timeout(self, mock_timezone, mock_executor, mock_airflow_key):
"""Test that executor can end successfully; waiting for all workloads to naturally exit."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,31 @@ def test_team_config(self):
assert submit_kwargs["jobDefinition"] == "some-job-def"
assert submit_kwargs["jobName"] == "some-job-name"

@pytest.mark.parametrize(
("team_name", "expected_tags"),
[
pytest.param(None, {}, id="without_team"),
pytest.param(
"team_a",
{"team_name": "team_a"},
id="with_team",
marks=pytest.mark.skipif(
not AIRFLOW_V_3_1_PLUS, reason="Multi-team support requires Airflow 3.1+"
),
),
],
)
@mock.patch.object(batch_executor.Stats, "timer")
def test_try_adopt_task_instances_emits_team_name_tag(
self, mock_timer, mock_executor, team_name, expected_tags
):
"""Test that the adopt task instances duration metric is tagged with the team name."""
mock_executor.team_name = team_name

mock_executor.try_adopt_task_instances([])

mock_timer.assert_called_once_with("batch_executor.adopt_task_instances.duration", tags=expected_tags)


class TestBatchExecutorConfig:
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@

from tests_common import RUNNING_TESTS_AGAINST_AIRFLOW_PACKAGES
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_3_PLUS
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_V_3_1_PLUS, AIRFLOW_V_3_3_PLUS

airflow_version = VersionInfo(*map(int, airflow_version_str.split(".")[:3]))

Expand Down Expand Up @@ -1334,6 +1334,31 @@ def test_try_adopt_task_instances(self, mock_executor):
# The remaining one task is unable to be adopted.
assert len(not_adopted_tasks) == 1

@pytest.mark.parametrize(
("team_name", "expected_tags"),
[
pytest.param(None, {}, id="without_team"),
pytest.param(
"team_a",
{"team_name": "team_a"},
id="with_team",
marks=pytest.mark.skipif(
not AIRFLOW_V_3_1_PLUS, reason="Multi-team support requires Airflow 3.1+"
),
),
],
)
@mock.patch.object(ecs_executor.Stats, "timer")
def test_try_adopt_task_instances_emits_team_name_tag(
self, mock_timer, mock_executor, team_name, expected_tags
):
"""Test that the adopt task instances duration metric is tagged with the team name."""
mock_executor.team_name = team_name

mock_executor.try_adopt_task_instances([])

mock_timer.assert_called_once_with("ecs_executor.adopt_task_instances.duration", tags=expected_tags)


class TestEcsExecutorConfig:
@pytest.fixture
Expand Down
Loading