From 3797ee29a8e45f7c1e442c3f13f92b3b8543b1fd Mon Sep 17 00:00:00 2001 From: Hsin-Fang Chiang Date: Thu, 25 Jun 2026 13:23:57 -0700 Subject: [PATCH] Add prune_unanchored_quanta to QuantumGraphBuilder This allows a pruning after the graph and cross-task reachability is fully settled, after all per-task _adjust_task_quanta. --- doc/changes/DM-55320.feature.md | 1 + .../lsst/pipe/base/quantum_graph_builder.py | 36 +++++++++ .../lsst/pipe/base/quantum_graph_skeleton.py | 59 ++++++++++++++ tests/test_graphBuilder.py | 46 +++++++++++ tests/test_graphSkeleton.py | 81 +++++++++++++++++++ 5 files changed, 223 insertions(+) create mode 100644 doc/changes/DM-55320.feature.md create mode 100644 tests/test_graphSkeleton.py diff --git a/doc/changes/DM-55320.feature.md b/doc/changes/DM-55320.feature.md new file mode 100644 index 000000000..65605d54a --- /dev/null +++ b/doc/changes/DM-55320.feature.md @@ -0,0 +1 @@ +Add prune_unanchored_quanta parameter to QuantumGraphBuilder diff --git a/python/lsst/pipe/base/quantum_graph_builder.py b/python/lsst/pipe/base/quantum_graph_builder.py index 5e7c4c8dc..eb2cf8692 100644 --- a/python/lsst/pipe/base/quantum_graph_builder.py +++ b/python/lsst/pipe/base/quantum_graph_builder.py @@ -138,6 +138,12 @@ class QuantumGraphBuilder(ABC): the upstream quanta that need to regenerate those intermediates to also run. Has no effect without ``skip_existing_in``. ``["*"]`` means retaining all datasets, equivalent to not providing this option. + prune_unanchored_quanta : `tuple` [ `str`, `str` ], optional + A ``(source_label, anchor_label)`` pair of task labels triggering + unanchored-quantum pruning after the skeleton is assembled. A + ``source_label`` quantum is removed along with its entire downstream + chain if no ``anchor_label`` quantum is reachable from it along + directed graph edges. 
clobber : `bool`, optional Whether to raise if predicted outputs already exist in ``output_run`` (not including those quanta that would be skipped because they've @@ -182,6 +188,7 @@ def __init__( output_run: str | None = None, skip_existing_in: Sequence[str] = (), retained_dataset_types: Sequence[str] | None = None, + prune_unanchored_quanta: tuple[str, str] | None = None, clobber: bool = False, ): self.log = getLogger(__name__) @@ -206,6 +213,7 @@ def __init__( raise ValueError("retained_dataset_types has no effect without skip_existing_in.") self.empty_data_id = DataCoordinate.make_empty(butler.dimensions) self.clobber = clobber + self._prune_unanchored_quanta = prune_unanchored_quanta # See whether the output run already exists. self.output_run_exists = False try: @@ -249,6 +257,21 @@ def __init__( task_node.label: PrerequisiteInfo(task_node, self._pipeline_graph) for task_node in pipeline_graph.tasks.values() } + if self._prune_unanchored_quanta is not None: + source_label, anchor_label = self._prune_unanchored_quanta + if source_label not in self._pipeline_graph.tasks: + self.log.warning( + "prune_unanchored_quanta source label %r is not present in the pipeline; " + "pruning will have no effect.", + source_label, + ) + elif anchor_label not in self._pipeline_graph.tasks: + self.log.warning( + "prune_unanchored_quanta anchor label %r is not present in the pipeline; " + "all %r quanta will be treated as unanchored and removed.", + anchor_label, + source_label, + ) log: LsstLogAdapter """Logger to use for all quantum-graph generation messages. @@ -470,6 +493,19 @@ def _build_skeleton(self, attach_datastore_records: bool = True) -> QuantumGraph # with the quanta because no quantum knows if its the only # consumer). 
full_skeleton.remove_orphan_datasets() + if self._prune_unanchored_quanta is not None: + source_label, anchor_label = self._prune_unanchored_quanta + n_source, n_downstream = full_skeleton.remove_unanchored_quanta(source_label, anchor_label) + if n_source: + self.log.info( + "Pruned %d unanchored %r quanta and %d downstream quanta (%d total) based on %r.", + n_source, + source_label, + n_downstream, + n_source + n_downstream, + anchor_label, + ) + full_skeleton.remove_orphan_datasets() if attach_datastore_records: self._attach_datastore_records(full_skeleton) return full_skeleton diff --git a/python/lsst/pipe/base/quantum_graph_skeleton.py b/python/lsst/pipe/base/quantum_graph_skeleton.py index 386b3f7c3..b859b8bed 100644 --- a/python/lsst/pipe/base/quantum_graph_skeleton.py +++ b/python/lsst/pipe/base/quantum_graph_skeleton.py @@ -562,6 +562,65 @@ def remove_orphan_datasets(self) -> None: if not orphan.is_task and orphan not in self._global_init_outputs: self._xgraph.remove_node(orphan) + def remove_unanchored_quanta(self, source_label: str, anchor_label: str) -> tuple[int, int]: + """Remove unanchored source quanta and their entire downstream chain. + + A source quantum is considered unanchored if no quantum with + ``anchor_label`` is reachable along directed edges from it. Unanchored + source quanta and every descendant reachable from them are removed. + + Parameters + ---------- + source_label : `str` + Task label of the source task whose unanchored quanta to remove. + anchor_label : `str` + Task label that must appear downstream of a source quantum for that + quantum to be considered anchored. + + Returns + ------- + n_source : `int` + Number of unanchored source quanta removed. + n_downstream : `int` + Number of additional downstream quanta removed (not counting + ``source_label`` quanta or any dataset nodes). 
+ """ + if not self.has_task(source_label): + return 0, 0 + source_quanta = set(self.get_quanta(source_label)) + if not source_quanta: + return 0, 0 + + anchor_quanta = set(self.get_quanta(anchor_label)) if self.has_task(anchor_label) else set() + reachable: set[QuantumKey] = set() + for quantum in anchor_quanta: + reachable.update(networkx.ancestors(self._xgraph, quantum)) + + unanchored = source_quanta - reachable + if not unanchored: + return 0, 0 + + to_remove: set = set(unanchored) + for quantum in unanchored: + to_remove.update(networkx.descendants(self._xgraph, quantum)) + + n_downstream = sum(1 for n in to_remove if isinstance(n, QuantumKey) and n not in unanchored) + # to_remove has both QuantumKey and DatasetKey nodes + affected_labels = {n.task_label for n in to_remove if isinstance(n, QuantumKey)} + for node in to_remove: + if isinstance(node, QuantumKey): + _, quanta = self._tasks[node.task_label] + quanta.remove(node) + self._xgraph.remove_nodes_from(to_remove) + # For any task with no quanta remaining, remove its TaskInitKey and + # any init-output dataset nodes attached to it, then drop the task. + for label in affected_labels: + task_init_key, remaining = self._tasks[label] + if not remaining: + self._xgraph.remove_nodes_from(list(self._xgraph.successors(task_init_key))) + self.remove_task(label) + return len(unanchored), n_downstream + def extract_overall_inputs(self) -> dict[DatasetKey | PrerequisiteDatasetKey, DatasetRef]: """Find overall input datasets. diff --git a/tests/test_graphBuilder.py b/tests/test_graphBuilder.py index 6bc764af4..c65cfb5d2 100644 --- a/tests/test_graphBuilder.py +++ b/tests/test_graphBuilder.py @@ -465,6 +465,52 @@ def test_full_chain_unskipped_when_none_retained(self): self.assertEqual(len(qgraph), 3) +class PruneUnanchoredQuantaTestCase(unittest.TestCase): + """Tests for the prune_unanchored_quanta behavior of QuantumGraphBuilder. 
+ + Pipeline: auto0 -> source -> auto1 -> anchor -> auto2 + + All tasks are dimensionless so each is one quantum. + """ + + def setUp(self): + self.helper = InMemoryRepo() + self.enterContext(self.helper) + self.helper.add_task("source") + self.helper.add_task("anchor") + self.helper.make_quantum_graph_builder(output_run="output_run") + + def _build(self, **kwargs): + return AllDimensionsQuantumGraphBuilder( + self.helper.pipeline_graph, + self.helper.butler, + input_collections=[self.helper.input_chain], + output_run="output_run", + **kwargs, + ).build(attach_datastore_records=False) + + def test_no_effect_without_parameter(self): + """Without prune_unanchored_quanta, all quanta are kept.""" + qg = self._build() + self.assertEqual(len(qg), 2) + + def test_no_pruning_when_anchor_reachable(self): + """Anchor reachable from source quantum: nothing is pruned.""" + qg = self._build(prune_unanchored_quanta=("source", "anchor")) + self.assertEqual(len(qg), 2) + + def test_all_pruned_when_anchor_label_absent(self): + """Anchor is absent: all source quanta and task removed.""" + qg = self._build(prune_unanchored_quanta=("source", "no_such_task")) + self.assertEqual(len(qg), 0) + self.assertNotIn("source", {td.label for td in qg.iterTaskGraph()}) + + def test_noop_when_source_label_absent(self): + """source_label not in pipeline: nothing happens.""" + qg = self._build(prune_unanchored_quanta=("no_such_task", "anchor")) + self.assertEqual(len(qg), 2) + + if __name__ == "__main__": lsst.utils.tests.init() unittest.main() diff --git a/tests/test_graphSkeleton.py b/tests/test_graphSkeleton.py new file mode 100644 index 000000000..3692a9526 --- /dev/null +++ b/tests/test_graphSkeleton.py @@ -0,0 +1,81 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). 
+# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see <https://www.gnu.org/licenses/>. + +"""Unit tests for QuantumGraphSkeleton.""" + +import unittest + +from lsst.daf.butler import DataCoordinate, DimensionUniverse +from lsst.pipe.base.quantum_graph_skeleton import DatasetKey, QuantumGraphSkeleton + + +class RemoveUnanchoredQuantaTestCase(unittest.TestCase): + """Tests for ``QuantumGraphSkeleton.remove_unanchored_quanta``. 
+ + Graph: + source1 -> d1 -> anchor1 (band 1: anchored) + source2 -> d2 -> anchor2 (band 2: anchored) + source3 -> d3 (band 3: unanchored) + source4 -> d4 (band 4: unanchored) + """ + + def setUp(self): + universe = DimensionUniverse() + self.skeleton = QuantumGraphSkeleton(["source", "anchor"]) + for i in range(1, 5): + data_id = DataCoordinate.standardize({"band": i}, universe=universe) + source_quantum = self.skeleton.add_quantum_node("source", data_id) + dataset = self.skeleton.add_dataset_node(f"d{i}", data_id) + self.skeleton.add_output_edge(source_quantum, dataset) + if i <= 2: + anchor_quantum = self.skeleton.add_quantum_node("anchor", data_id) + self.skeleton.add_input_edge(anchor_quantum, dataset) + + def test_remove_unanchored(self): + """Unanchored source quanta and their descendants are removed.""" + n_source, n_downstream = self.skeleton.remove_unanchored_quanta("source", "anchor") + self.assertEqual(n_source, 2) + self.assertEqual(n_downstream, 0) + self.assertEqual(len(self.skeleton.get_quanta("source")), 2) + self.assertEqual(len(self.skeleton.get_quanta("anchor")), 2) + self.assertNotIn(DatasetKey("d3", (3,)), self.skeleton) + self.assertNotIn(DatasetKey("d4", (4,)), self.skeleton) + + # Second call on an already-pruned skeleton is a no-op. + n_source, n_downstream = self.skeleton.remove_unanchored_quanta("source", "anchor") + self.assertEqual(n_source, 0) + self.assertEqual(n_downstream, 0) + + def test_task_dropped_when_all_unanchored(self): + """Task is dropped when all its quanta are removed.""" + n_source, _ = self.skeleton.remove_unanchored_quanta("source", "nonexistent") + self.assertEqual(n_source, 4) + self.assertFalse(self.skeleton.has_task("source")) + + +if __name__ == "__main__": + unittest.main()