Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/69075.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix mapped task group return value handed to a downstream task being a bare value instead of a one-element list when the group expanded only once (and ``None`` instead of an empty list when every expansion returned ``None``).
7 changes: 5 additions & 2 deletions task-sdk/src/airflow/sdk/definitions/xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,11 @@ def resolve(self, context: Mapping[str, Any]) -> Any:
ti_count=ti_count,
session=None, # Not used in SDK implementation
)
# None means "no filtering needed" -> use NOTSET to pull all values
map_indexes = NOTSET if computed is None else computed
if computed is None:
# Aggregate the mapped task group as a list, even for a single expansion (#69036) or all-None values (#48005)
# Materialise eagerly (one slice request) so a task returning the value unchanged can still serialize it
return LazyXComSequence(xcom_arg=self, ti=ti)[:]
map_indexes = computed
Comment on lines +358 to +361

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, this has the potential to kill performance if the map size is large. I wonder if it would be appropriate to return the lazy sequence directly, and handle it on XCom serialization if it’s actually returned instead.

result = ti.xcom_pull(
task_ids=task_id,
key=self.key,
Expand Down
66 changes: 66 additions & 0 deletions task-sdk/tests/task_sdk/definitions/test_xcom_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,69 @@ def xcom_get(msg):
states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)]
assert states == [TaskInstanceState.SUCCESS] * 5
assert agg_results == {"a", "b", "c", 1, 2}


class TestPlainXComArgResolveMappedGroup:
    """Resolve a task that lives in a mapped task group from outside that group.

    Covers the regressions reported in #69036 and #48005. The aggregated
    return value of a mapped task group has to be an eager list holding one
    element per expansion, including when the group expanded a single time
    or when every expansion returned ``None``. Before the fix this path went
    through ``xcom_pull`` with all map indexes, which turned a lone value
    into a bare scalar and an empty set of values into ``None``. Eager
    materialisation is required so that a task returning the value
    unchanged can still push it to XCom.
    """

    @staticmethod
    def _make_ti(*, computed):
        """Return a mocked task instance whose upstream map-index lookup yields ``computed``."""
        task_instance = mock.MagicMock(
            _upstream_map_indexes=None,
            _cached_template_context={"expanded_ti_count": 1},
            run_id="run-1",
        )
        task_instance.get_relevant_upstream_map_indexes.return_value = computed
        return task_instance

    @staticmethod
    def _make_arg():
        """Return a ``PlainXComArg`` with key ``"test"`` wrapping a mocked, unmapped operator."""
        from airflow.sdk.definitions.xcom_arg import PlainXComArg

        upstream_op = mock.MagicMock(is_mapped=False, task_id="do_something", dag_id="test_dag")
        # The operator reports a (mocked) closest mapped task group.
        upstream_op.get_closest_mapped_task_group.return_value = mock.MagicMock()
        return PlainXComArg(operator=upstream_op, key="test")

    @pytest.mark.parametrize(
        ("root", "expected"),
        [
            pytest.param(["14"], ["14"], id="single-expansion-stays-a-list"),
            pytest.param([], [], id="all-none-expansions-give-empty-list"),
            pytest.param(["a", "b"], ["a", "b"], id="multiple-expansions"),
        ],
    )
    def test_resolve_aggregates_mapped_group_as_eager_list(self, root, expected, mock_supervisor_comms):
        """Check the no-specific-index case.

        When the upstream map-index lookup yields ``None``, ``resolve`` must
        hand back a real ``list`` equal to ``expected`` and must not call
        ``xcom_pull``.
        """
        from airflow.sdk.execution_time.comms import XComSequenceSliceResult

        # The supervisor answers the request with the parametrized slice payload.
        mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=root)

        task_instance = self._make_ti(computed=None)
        xcom_arg = self._make_arg()

        result = xcom_arg.resolve({"ti": task_instance})

        assert result == expected
        assert isinstance(result, list)
        task_instance.xcom_pull.assert_not_called()

    def test_resolve_uses_xcom_pull_for_specific_index(self):
        """Check that a concrete map index is passed to ``xcom_pull`` and its value returned."""
        task_instance = self._make_ti(computed=0)
        task_instance.xcom_pull.return_value = "value-0"

        result = self._make_arg().resolve({"ti": task_instance})

        assert result == "value-0"
        task_instance.xcom_pull.assert_called_once()
        pull_kwargs = task_instance.xcom_pull.call_args.kwargs
        assert pull_kwargs["map_indexes"] == 0
Loading