From ce2e1abd995c8cf2456c97fbad4b58e0835349ff Mon Sep 17 00:00:00 2001 From: Yan Wong Date: Mon, 9 Jun 2025 18:33:21 +0100 Subject: [PATCH] Make random-binary non-recursive --- python/tskit/combinatorics.py | 50 ++++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/python/tskit/combinatorics.py b/python/tskit/combinatorics.py index 880ec73675..bcbb8accb7 100644 --- a/python/tskit/combinatorics.py +++ b/python/tskit/combinatorics.py @@ -150,23 +150,43 @@ def random_binary_tree(leaf_labels, rng): while root.parent is not None: root = root.parent - # Canonicalise the order of the children within a node. This - # is given by (num_leaves, min_label). See also the - # RankTree.canonical_order function for the definition of - # how these are ordered during rank/unrank. + # Canonicalise the order of the children within a node using iterative approach + # This replaces the recursive reorder_children function - def reorder_children(node): - if len(node.children) == 0: - return 1, node.label - keys = [reorder_children(child) for child in node.children] - if keys[0] > keys[1]: - node.children = node.children[::-1] - return ( - sum(leaf_count for leaf_count, _ in keys), - min(min_label for _, min_label in keys), - ) + # Use post-order traversal to process leaf nodes first, then internal nodes + # We need to visit children before parents to compute (num_leaves, min_label) + + # First, collect all nodes in post-order using two stacks + stack1 = [root] + stack2 = [] + + while stack1: + node = stack1.pop() + stack2.append(node) + for child in node.children: + stack1.append(child) + + # Now process nodes in reverse order (post-order) + # Store results for each node as we compute them + node_keys = {} - reorder_children(root) + while stack2: + node = stack2.pop() + + if len(node.children) == 0: + # Leaf node + node_keys[node] = (1, node.label) + else: + # Internal node - get keys from children + keys = [node_keys[child] for child in node.children] + if keys[0] > keys[1]: + node.children = node.children[::-1] + keys = keys[::-1] # Update keys to match new order + + node_keys[node] = ( + sum(leaf_count for leaf_count, _ in keys), + min(min_label for _, min_label in keys), + ) return root @classmethod