diff --git a/graph/lowest_common_ancestor.cpp b/graph/lowest_common_ancestor.cpp index 7d5ab42b49..8c1801a74c 100644 --- a/graph/lowest_common_ancestor.cpp +++ b/graph/lowest_common_ancestor.cpp @@ -1,258 +1,327 @@ /** + * @file + * @brief Centroid Decomposition of a tree * - * \file + * @details + * Centroid Decomposition is a technique for efficiently solving + * path queries on trees. A centroid of a tree is a node whose + * removal results in no remaining subtree having more than n/2 nodes. * - * \brief Data structure for finding the lowest common ancestor - * of two vertices in a rooted tree using binary lifting. + * The decomposition works recursively: + * 1. Find the centroid of the current tree + * 2. Remove it and recurse on each subtree + * This creates a "centroid tree" of depth O(log N) * - * \details - * Algorithm: https://cp-algorithms.com/graph/lca_binary_lifting.html + * This implementation demonstrates centroid decomposition by + * counting the number of paths of a given length in a tree. + * + * Algorithm reference: https://cp-algorithms.com/graph/centroid_decomposition.html * * Complexity: - * - Precomputation: \f$O(N \log N)\f$ where \f$N\f$ is the number of vertices - * in the tree + * - Preprocessing: \f$O(N \log N)\f$ * - Query: \f$O(\log N)\f$ - * - Space: \f$O(N \log N)\f$ + * - Space: \f$O(N)\f$ * - * Example: - *
Tree: + * Example tree: *
- *             _  3  _
- *          /     |     \
- *        1       6       4
- *      / |     /   \       \
- *    7   5   2       8       0
- *            |
- *            9
+ *         1
+ *       / | \
+ *      2  3  4
+ *     / \
+ *    5   6
  * 
* - *
lowest_common_ancestor(7, 4) = 3 - *
lowest_common_ancestor(9, 6) = 6 - *
lowest_common_ancestor(0, 0) = 0 - *
lowest_common_ancestor(8, 2) = 6 - * - * The query is symmetrical, therefore - * lowest_common_ancestor(x, y) = lowest_common_ancestor(y, x) + * @author Your Name */ - + #include -#include -#include -#include +#include +#include #include - + /** - * \namespace graph - * \brief Graph algorithms + * @namespace graph + * @brief Graph algorithms */ namespace graph { + /** - * Class for representing a graph as an adjacency list. - * Its vertices are indexed 0, 1, ..., N - 1. + * @brief Class implementing Centroid Decomposition on a tree. + * + * Supports efficient path queries using the decomposition structure. */ -class Graph { +class CentroidDecomposition { public: + int n; ///< number of nodes (1-indexed) + std::vector> adj; ///< adjacency list + std::vector subtree_sz; ///< subtree size for each node + std::vector removed; ///< marks removed centroids + std::vector centroid_parent; ///< parent in centroid tree + /** - * \brief Populate the adjacency list for each vertex in the graph. - * Assumes that evey edge is a pair of valid vertex indices. - * - * @param N number of vertices in the graph - * @param undirected_edges list of graph's undirected edges - */ - Graph(size_t N, const std::vector > &undirected_edges) { - neighbors.resize(N); - for (auto &edge : undirected_edges) { - neighbors[edge.first].push_back(edge.second); - neighbors[edge.second].push_back(edge.first); - } - } - - /** - * Function to get the number of vertices in the graph - * @return the number of vertices in the graph. + * @brief Constructor: initializes the decomposition for n nodes. + * @param n number of nodes (1-indexed) */ - int number_of_vertices() const { return neighbors.size(); } - - /** \brief for each vertex it stores a list indicies of its neighbors */ - std::vector > neighbors; -}; - -/** - * Representation of a rooted tree. For every vertex its parent is - * precalculated. - */ -class RootedTree : public Graph { - public: + explicit CentroidDecomposition(int n) + : n(n), + adj(n + 1), + subtree_sz(n + 1, 0), + removed(n + 1, false), + centroid_parent(n + 1, -1) {} + /** - * \brief Constructs the tree by calculating parent for every vertex. - * Assumes a valid description of a tree is provided. - * - * @param undirected_edges list of graph's undirected edges - * @param root_ index of the root vertex + * @brief Adds an undirected edge between u and v. + * @param u first node + * @param v second node */ - RootedTree(const std::vector > &undirected_edges, - int root_) - : Graph(undirected_edges.size() + 1, undirected_edges), root(root_) { - populate_parents(); + void add_edge(int u, int v) { + adj[u].push_back(v); + adj[v].push_back(u); } - + /** - * \brief Stores parent of every vertex and for root its own index. - * The root is technically not its own parent, but it's very practical - * for the lowest common ancestor algorithm. + * @brief Computes subtree sizes via DFS. + * @param v current node + * @param parent parent of current node + * @return size of subtree rooted at v */ - std::vector parent; - /** \brief Stores the distance from the root. */ - std::vector level; - /** \brief Index of the root vertex. */ - int root; - - protected: + int get_subtree_size(int v, int parent) { + subtree_sz[v] = 1; + for (int u : adj[v]) { + if (u != parent && !removed[u]) { + subtree_sz[v] += get_subtree_size(u, v); + } + } + return subtree_sz[v]; + } + /** - * \brief Calculate the parents for all the vertices in the tree. - * Implements the breadth first search algorithm starting from the root - * vertex searching the entire tree and labeling parents for all vertices. - * @returns none + * @brief Finds the centroid of a subtree rooted at v. + * + * The centroid is a node where no subtree (after its removal) + * has more than tree_size/2 nodes. + * + * @param v current node + * @param parent parent of current node + * @param tree_size total size of the current tree component + * @return centroid node index */ - void populate_parents() { - // Initialize the vector with -1 which indicates the vertex - // wasn't yet visited. - parent = std::vector(number_of_vertices(), -1); - level = std::vector(number_of_vertices()); - parent[root] = root; - level[root] = 0; - std::queue queue_of_vertices; - queue_of_vertices.push(root); - while (!queue_of_vertices.empty()) { - int vertex = queue_of_vertices.front(); - queue_of_vertices.pop(); - for (int neighbor : neighbors[vertex]) { - // As long as the vertex was not yet visited. - if (parent[neighbor] == -1) { - parent[neighbor] = vertex; - level[neighbor] = level[vertex] + 1; - queue_of_vertices.push(neighbor); + int get_centroid(int v, int parent, int tree_size) { + for (int u : adj[v]) { + if (u != parent && !removed[u]) { + if (subtree_sz[u] > tree_size / 2) { + return get_centroid(u, v, tree_size); } } } + return v; } -}; - -/** - * A structure that holds a rooted tree and allow for effecient - * queries of the lowest common ancestor of two given vertices in the tree. - */ -class LowestCommonAncestor { - public: + /** - * \brief Stores the tree and precomputs "up lifts". - * @param tree_ rooted tree. + * @brief Recursively decomposes the tree and records centroid parents. + * + * For each component, finds the centroid, marks it as removed, + * and recurses on remaining subtrees. + * + * @param v any node in the current component + * @param parent_centroid centroid of the parent component (-1 for root) */ - explicit LowestCommonAncestor(const RootedTree &tree_) : tree(tree_) { - populate_up(); + void decompose(int v, int parent_centroid) { + int sz = get_subtree_size(v, -1); + int centroid = get_centroid(v, -1, sz); + + centroid_parent[centroid] = parent_centroid; + removed[centroid] = true; + + for (int u : adj[centroid]) { + if (!removed[u]) { + decompose(u, centroid); + } + } } - + /** - * \brief Query the structure to find the lowest common ancestor. - * Assumes that the provided numbers are valid indices of vertices. - * Iterativelly modifies ("lifts") u an v until it finnds their lowest - * common ancestor. - * @param u index of one of the queried vertex - * @param v index of the other queried vertex - * @return index of the vertex which is the lowet common ancestor of u and v + * @brief Builds the centroid decomposition starting from node 1. */ - int lowest_common_ancestor(int u, int v) const { - // Ensure u is the deeper (higher level) of the two vertices - if (tree.level[v] > tree.level[u]) { - std::swap(u, v); - } - - // "Lift" u to the same level as v. - int level_diff = tree.level[u] - tree.level[v]; - for (int i = 0; (1 << i) <= level_diff; ++i) { - if (level_diff & (1 << i)) { - u = up[u][i]; + void build() { decompose(1, -1); } + + /** + * @brief Returns the centroid parent of a node. + * @param v node index + * @return centroid parent of v, or -1 if v is the centroid root + */ + int get_centroid_parent(int v) const { return centroid_parent[v]; } +}; + +/** + * @brief Counts paths of a given target length in the tree using + * centroid decomposition. + * + * Uses the standard centroid decomposition approach: + * For each centroid c, count paths passing through c. + * A path passes through c if it combines a distance from c in one subtree + * with a distance from c in another subtree summing to target_len. + * + * @param n number of nodes + * @param edges list of edges as pairs (u, v) + * @param target_len target path length to count + * @return number of unordered pairs (u, v) with path length == target_len + */ +int count_paths_of_length(int n, + const std::vector>& edges, + int target_len) { + std::vector> adj(n + 1); + for (auto& e : edges) { + adj[e.first].push_back(e.second); + adj[e.second].push_back(e.first); + } + + std::vector subtree_sz(n + 1, 0); + std::vector removed(n + 1, false); + int result = 0; + + std::function calc_size = [&](int v, int p) -> int { + subtree_sz[v] = 1; + for (int u : adj[v]) { + if (u != p && !removed[u]) { + subtree_sz[v] += calc_size(u, v); } } - assert(tree.level[u] == tree.level[v]); - - if (u == v) { - return u; + return subtree_sz[v]; + }; + + std::function find_centroid = + [&](int v, int p, int sz) -> int { + for (int u : adj[v]) { + if (u != p && !removed[u] && subtree_sz[u] > sz / 2) { + return find_centroid(u, v, sz); + } } - - // "Lift" u and v to their 2^i th ancestor if they are different - for (int i = static_cast(up[u].size()) - 1; i >= 0; --i) { - if (up[u][i] != up[v][i]) { - u = up[u][i]; - v = up[v][i]; + return v; + }; + + // collect distances from a node via DFS + std::function&)> collect = + [&](int v, int p, int dist, std::vector& dists) { + dists.push_back(dist); + for (int u : adj[v]) { + if (u != p && !removed[u]) { + collect(u, v, dist + 1, dists); } } - - // As we regressed u an v such that they cannot further be lifted so - // that their ancestor would be different, the only logical - // consequence is that their parent is the sought answer. - assert(up[u][0] == up[v][0]); - return up[u][0]; - } - - /* \brief reference to the rooted tree this structure allows to query */ - const RootedTree &tree; - /** - * \brief for every vertex stores a list of its ancestors by powers of two - * For each vertex, the first element of the corresponding list contains - * the index of its parent. The i-th element of the list is an index of - * the (2^i)-th ancestor of the vertex. - */ - std::vector > up; - - protected: - /** - * Populate the "up" structure. See above. - */ - void populate_up() { - up.resize(tree.number_of_vertices()); - for (int vertex = 0; vertex < tree.number_of_vertices(); ++vertex) { - up[vertex].push_back(tree.parent[vertex]); + }; + + // count pairs from a centroid using frequency map + auto count_from_centroid = [&](int c) { + std::map freq; + freq[0] = 1; // centroid itself at distance 0 + + for (int nb : adj[c]) { + if (removed[nb]) continue; + std::vector dists; + collect(nb, c, 1, dists); + + // count valid pairs with distances already recorded + for (int d : dists) { + int need = target_len - d; + if (need >= 0 && freq.count(need)) { + result += freq[need]; + } + } + // add new distances to frequency map + for (int d : dists) { + freq[d]++; + } } - for (int level = 0; (1 << level) < tree.number_of_vertices(); ++level) { - for (int vertex = 0; vertex < tree.number_of_vertices(); ++vertex) { - // up[vertex][level + 1] = 2^(level + 1) th ancestor of vertex = - // = 2^level th ancestor of 2^level th ancestor of vertex = - // = 2^level th ancestor of up[vertex][level] - up[vertex].push_back(up[up[vertex][level]][level]); + }; + + std::function solve = [&](int v) { + int sz = calc_size(v, -1); + int c = find_centroid(v, -1, sz); + count_from_centroid(c); + removed[c] = true; + for (int u : adj[c]) { + if (!removed[u]) { + solve(u); } } - } -}; - + }; + + solve(1); + return result; +} + } // namespace graph - + /** - * Unit tests + * @brief Unit tests for centroid decomposition * @returns none */ static void tests() { - /** - * _ 3 _ - * / | \ - * 1 6 4 - * / | / \ \ - * 7 5 2 8 0 - * | - * 9 - */ - std::vector > edges = { - {7, 1}, {1, 5}, {1, 3}, {3, 6}, {6, 2}, {2, 9}, {6, 8}, {4, 3}, {0, 4}}; - graph::RootedTree t(edges, 3); - graph::LowestCommonAncestor lca(t); - assert(lca.lowest_common_ancestor(7, 4) == 3); - assert(lca.lowest_common_ancestor(9, 6) == 6); - assert(lca.lowest_common_ancestor(0, 0) == 0); - assert(lca.lowest_common_ancestor(8, 2) == 6); + // Test 1: basic centroid parent structure on a path graph 1-2-3-4-5 + { + graph::CentroidDecomposition cd(5); + cd.add_edge(1, 2); + cd.add_edge(2, 3); + cd.add_edge(3, 4); + cd.add_edge(4, 5); + cd.build(); + + // centroid of 1-2-3-4-5 is node 3 + // its centroid parent should be -1 (root of centroid tree) + assert(cd.get_centroid_parent(3) == -1); + // nodes in subtrees of 3 should have 3 as centroid parent + assert(cd.get_centroid_parent(2) == 3 || + cd.get_centroid_parent(4) == 3); + } + + // Test 2: count paths of length 2 in a star graph (center = 1) + { + /* + * 1 + * / | \ + * 2 3 4 + * paths of length 2: (2,3), (2,4), (3,4) => 3 paths + */ + std::vector> edges = {{1, 2}, {1, 3}, {1, 4}}; + int result = graph::count_paths_of_length(4, edges, 2); + assert(result == 3); + } + + // Test 3: count paths of length 2 in a path graph 1-2-3-4-5 + { + /* + * 1-2-3-4-5 + * paths of length 2: (1,3), (2,4), (3,5) => 3 paths + */ + std::vector> edges = { + {1, 2}, {2, 3}, {3, 4}, {4, 5}}; + int result = graph::count_paths_of_length(5, edges, 2); + assert(result == 3); + } + + // Test 4: count paths of length 1 (edges) in a path of 3 nodes + { + // 1-2-3: edges = (1,2), (2,3) => 2 paths of length 1 + std::vector> edges = {{1, 2}, {2, 3}}; + int result = graph::count_paths_of_length(3, edges, 1); + assert(result == 2); + } + + // Test 5: single node tree — no paths + { + std::vector> edges = {}; + int result = graph::count_paths_of_length(1, edges, 1); + assert(result == 0); + } } - -/** Main function */ + +/** + * @brief Main function + * @returns 0 on success + */ int main() { tests(); return 0; } +