diff --git a/airflow-core/newsfragments/69075.bugfix.rst b/airflow-core/newsfragments/69075.bugfix.rst new file mode 100644 index 0000000000000..d992083c37820 --- /dev/null +++ b/airflow-core/newsfragments/69075.bugfix.rst @@ -0,0 +1 @@ +Fix mapped task group return value handed to a downstream task being a bare value instead of a one-element list when the group expanded only once (and ``None`` instead of an empty list when every expansion returned ``None``). diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 595490832597f..880517be4f498 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -354,8 +354,11 @@ def resolve(self, context: Mapping[str, Any]) -> Any: ti_count=ti_count, session=None, # Not used in SDK implementation ) - # None means "no filtering needed" -> use NOTSET to pull all values - map_indexes = NOTSET if computed is None else computed + if computed is None: + # Aggregate the mapped task group as a list, even for a single expansion (#69036) or all-None values (#48005) + # Materialise eagerly (one slice request) so a task returning the value unchanged can still serialize it + return LazyXComSequence(xcom_arg=self, ti=ti)[:] + map_indexes = computed result = ti.xcom_pull( task_ids=task_id, key=self.key, diff --git a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py index 6014d26f2208b..623c00ff8c98f 100644 --- a/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py +++ b/task-sdk/tests/task_sdk/definitions/test_xcom_arg.py @@ -350,3 +350,69 @@ def xcom_get(msg): states = [run_ti(dag, "pull_one", map_index) for map_index in range(5)] assert states == [TaskInstanceState.SUCCESS] * 5 assert agg_results == {"a", "b", "c", 1, 2} + + +class TestPlainXComArgResolveMappedGroup: + """Resolving a task inside a mapped task group from a task outside that group. + + Regression tests for #69036 and #48005: the combined return value of a + mapped task group must always be an eager list (one element per expansion), + even when the group expanded only once or every expansion returned ``None``. + Previously this case was routed through ``xcom_pull`` pulling all map + indexes, which collapsed a single value to a bare scalar and an empty set + of values to ``None``. The list must be materialised eagerly so a task that + returns the value unchanged can still push it to XCom. + """ + + @staticmethod + def _make_ti(*, computed): + ti = mock.MagicMock() + ti._upstream_map_indexes = None + ti._cached_template_context = {"expanded_ti_count": 1} + ti.run_id = "run-1" + ti.get_relevant_upstream_map_indexes.return_value = computed + return ti + + @staticmethod + def _make_arg(): + from airflow.sdk.definitions.xcom_arg import PlainXComArg + + operator = mock.MagicMock() + operator.is_mapped = False + operator.task_id = "do_something" + operator.dag_id = "test_dag" + operator.get_closest_mapped_task_group.return_value = mock.MagicMock() + return PlainXComArg(operator=operator, key="test") + + @pytest.mark.parametrize( + ("root", "expected"), + [ + pytest.param(["14"], ["14"], id="single-expansion-stays-a-list"), + pytest.param([], [], id="all-none-expansions-give-empty-list"), + pytest.param(["a", "b"], ["a", "b"], id="multiple-expansions"), + ], + ) + def test_resolve_aggregates_mapped_group_as_eager_list(self, root, expected, mock_supervisor_comms): + from airflow.sdk.execution_time.comms import XComSequenceSliceResult + + mock_supervisor_comms.send.return_value = XComSequenceSliceResult(root=root) + + arg = self._make_arg() + ti = self._make_ti(computed=None) + + resolved = arg.resolve({"ti": ti}) + + assert resolved == expected + assert isinstance(resolved, list) + ti.xcom_pull.assert_not_called() + + def test_resolve_uses_xcom_pull_for_specific_index(self): + arg = self._make_arg() + ti = self._make_ti(computed=0) + ti.xcom_pull.return_value = "value-0" + + resolved = arg.resolve({"ti": ti}) + + assert resolved == "value-0" + ti.xcom_pull.assert_called_once() + assert ti.xcom_pull.call_args.kwargs["map_indexes"] == 0