From 38f655cea6cf5c1f8a7fb3a49333d7b4e4cf62d2 Mon Sep 17 00:00:00 2001 From: rgao user Date: Sat, 2 May 2026 06:55:29 +0000 Subject: [PATCH 01/26] Add all-to-all graph parallel communication primitives All-to-all (A2A) communication module for graph parallelism, replacing the all-gather approach with point-to-point exchange of only boundary atoms. Key components: - GPContext: dataclass holding per-rank atom assignments and A2A metadata - build_gp_context: builds communication plan from edge connectivity - AllToAllCollect: autograd-compatible A2A embedding exchange - all_to_all_collect_compiled: torch.compile-friendly variant using functional collectives (no graph break) - _safe_all_to_all: Gloo fallback for CPU testing - _sparse_index_exchange: variable-split index exchange Tests (22 total, all CPU/Gloo): - 5 unit tests for build_gp_context (context building, global-to-local mapping, edge split indices) - 9 distributed tests comparing A2A vs all-gather correctness (forward, backward, multi-rank, spatial partition) - 8 distributed correctness tests (dense/sparse graphs, multi-dim embeddings, index_split + spatial strategies) --- .../core/common/parallelism/__init__.py | 12 + .../common/parallelism/graph_parallel_a2a.py | 775 ++++++++++++++++++ .../parallelism/test_a2a_correctness.py | 610 ++++++++++++++ .../common/parallelism/test_graph_parallel.py | 651 +++++++++++++++ 4 files changed, 2048 insertions(+) create mode 100644 src/fairchem/core/common/parallelism/graph_parallel_a2a.py create mode 100644 tests/core/common/parallelism/test_a2a_correctness.py create mode 100644 tests/core/common/parallelism/test_graph_parallel.py diff --git a/src/fairchem/core/common/parallelism/__init__.py b/src/fairchem/core/common/parallelism/__init__.py index b795dbc7ba..99b668f46b 100644 --- a/src/fairchem/core/common/parallelism/__init__.py +++ b/src/fairchem/core/common/parallelism/__init__.py @@ -7,6 +7,13 @@ from __future__ import annotations +from 
fairchem.core.common.parallelism.graph_parallel_a2a import ( + AllToAllCollect, + GPContext, + all_to_all_collect, + all_to_all_collect_compiled, + build_gp_context, +) from fairchem.core.common.parallelism.graph_partition import ( PartitionStrategy, partition_atoms_index_split, @@ -14,7 +21,12 @@ ) __all__ = [ + "AllToAllCollect", + "GPContext", "PartitionStrategy", + "all_to_all_collect", + "all_to_all_collect_compiled", + "build_gp_context", "partition_atoms_index_split", "partition_atoms_spatial", ] diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py new file mode 100644 index 0000000000..c65ffbcb75 --- /dev/null +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -0,0 +1,775 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import torch +from torch import distributed as dist +from torch.profiler import record_function + +from fairchem.core.common import gp_utils + + +def _safe_all_to_all( + output_list: list[torch.Tensor], + input_list: list[torch.Tensor], + group: dist.ProcessGroup, +) -> None: + """ + All-to-all with fallback for backends that don't support it (e.g. Gloo). + + When the backend supports all_to_all natively (NCCL), uses it directly. + Otherwise, falls back to pairwise isend/irecv which works on any backend. + + Args: + output_list: List of output tensors, one per rank. + input_list: List of input tensors, one per rank. + group: Process group. 
+ """ + backend = dist.get_backend(group) + if backend == "nccl": + dist.all_to_all(output_list, input_list, group=group) + else: + # Gloo fallback: use pairwise send/recv + rank = dist.get_rank(group) + world_size = dist.get_world_size(group) + ops = [] + for r in range(world_size): + if r == rank: + # Local copy + if input_list[r].numel() > 0: + output_list[r].copy_(input_list[r]) + elif input_list[r].numel() > 0 or output_list[r].numel() > 0: + # Skip zero-length P2P ops to avoid potential hangs + ops.append(dist.P2POp(dist.isend, input_list[r], r, group=group)) + ops.append(dist.P2POp(dist.irecv, output_list[r], r, group=group)) + if ops: + reqs = dist.batch_isend_irecv(ops) + for req in reqs: + req.wait() + + +@dataclass +class GPContext: + """ + Graph parallel context holding per-rank atom assignments + and communication metadata for all-to-all. + + This replaces the all-gather approach by tracking which atoms + each rank needs from other ranks for its local edge computations. + + Attributes: + rank: Current GP rank. + world_size: Number of GP ranks. + node_partition: Global indices of atoms owned by this rank. + rank_assignments: For every atom in the global graph, which rank owns it. + Shape: (total_atoms,), dtype: int. + needed_atoms: Global indices of non-local atoms this rank needs + (sources of edges whose targets are in this rank's partition). + needed_from_ranks: For each atom in needed_atoms, which rank owns it. + send_counts: Number of atoms to send to each rank. Shape: (world_size,). + recv_counts: Number of atoms to receive from each rank. + Shape: (world_size,). + global_to_local: Mapping from global atom index to position in the + local concatenated tensor [local_atoms | received_atoms]. + Shape: (total_atoms,), with -1 for atoms not accessible. + total_local_atoms: Number of atoms in this rank's partition. + total_needed_atoms: Total atoms needed from other ranks. 
+ send_indices: Precomputed local indices of atoms to send, ordered by + destination rank. Computed once at build time to avoid per-forward + all-to-all index exchange. None if not yet computed. + edge_index_local: Precomputed edge index remapped to local indices. + None if not yet computed (set by build_gp_context when edge_index + is provided). + """ + + rank: int + world_size: int + node_partition: torch.Tensor + rank_assignments: torch.Tensor + needed_atoms: torch.Tensor + needed_from_ranks: torch.Tensor + send_counts: torch.Tensor + recv_counts: torch.Tensor + global_to_local: torch.Tensor + total_local_atoms: int + total_needed_atoms: int + send_indices: torch.Tensor | None = None + edge_index_local: torch.Tensor | None = None + # Precomputed Python lists to avoid repeated .tolist() in AllToAllCollect + send_splits: list[int] | None = None + recv_splits: list[int] | None = None + total_recv: int | None = None + # Precomputed integer indices for local/remote edges (for + # comm-compute overlap). Local edges have source atoms owned by + # this rank (edge_index_local[0] < total_local_atoms), remote + # edges have sources from other ranks. Using integer indices + # instead of boolean masks for compile-friendly indexing (avoids + # dynamic-shape boolean masking in compiled graphs). + local_edge_idx: torch.Tensor | None = None + remote_edge_idx: torch.Tensor | None = None + + +def _sparse_index_exchange( + needed_atoms: torch.Tensor, + recv_counts: torch.Tensor, + rank: int, + world_size: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Variable-split index exchange using two small all-to-alls. + + Step 1: Exchange recv_counts to get send_counts (P ints). + Step 2: Exchange actual atom indices with variable split sizes. + + Sends only the exact number of indices needed (no padding), + keeping communication volume minimal. 
+ + Args: + needed_atoms: Global indices of atoms this rank needs, + pre-sorted by source rank (done by the caller). + recv_counts: Number of atoms needed from each rank. + rank: This rank's GP rank. + world_size: GP world size. + device: Tensor device. + + Returns: + Tuple of (send_counts, send_indices_global). + """ + if not gp_utils.initialized(): + return ( + torch.zeros(world_size, dtype=torch.long, device=device), + None, + ) + + gp_group = gp_utils.get_gp_group() + backend = dist.get_backend(gp_group) + + # Step 1: Exchange counts. + # What rank A calls recv_counts[B] is what rank B must send to A. + # So rank B's send_counts[A] = rank A's recv_counts[B]. + # all_to_all on a (world_size,) tensor transposes the count matrix. + send_counts = torch.empty(world_size, dtype=torch.long, device=device) + if backend == "nccl": + dist.all_to_all_single(send_counts, recv_counts.contiguous(), group=gp_group) + else: + # Gloo fallback: use pairwise send/recv + send_list = list(recv_counts.split(1)) + recv_list = list(send_counts.split(1)) + _safe_all_to_all(recv_list, send_list, group=gp_group) + + # Step 2: Exchange actual atom indices with variable splits. + # needed_atoms is already sorted by source rank (done by the + # caller in build_gp_context), so use it directly as send buffer. + if needed_atoms.numel() > 0: + send_buf = needed_atoms.contiguous() + else: + send_buf = torch.empty(0, dtype=torch.long, device=device) + + # Batch send_counts and recv_counts into a single GPU→CPU transfer. + # This eliminates 2 extra GPU→CPU syncs vs separate .tolist() calls. 
+ counts_cpu = torch.stack([send_counts, recv_counts]).cpu() + recv_splits = counts_cpu[0].tolist() # indices we recv = atoms other ranks need from us (send_counts) + send_splits = counts_cpu[1].tolist() # indices we send = atoms we need from other ranks (recv_counts) + total_recv_indices = sum(recv_splits) + recv_buf = torch.empty(total_recv_indices, dtype=torch.long, device=device) + + if backend == "nccl": + dist.all_to_all_single( + recv_buf, + send_buf, + output_split_sizes=recv_splits, + input_split_sizes=send_splits, + group=gp_group, + ) + else: + # Gloo fallback: use pairwise send/recv + send_list = list(send_buf.split(send_splits)) + recv_list = list(recv_buf.split(recv_splits)) + _safe_all_to_all(recv_list, send_list, group=gp_group) + + # recv_buf now contains the global indices of atoms we must SEND, + # ordered by destination rank. + send_indices_global = recv_buf + + return send_counts, send_indices_global + + +@torch.compiler.disable +def build_gp_context( + edge_index: torch.Tensor, + rank_assignments: torch.Tensor, + rank: int, + world_size: int, + send_info: dict | None = None, + node_partition: torch.Tensor | None = None, +) -> GPContext: + """ + Build the GP context from edge connectivity and atom assignments. + + Determines which non-local atoms this rank needs (edge sources from + other ranks), exchanges atom indices via all-to-all, and computes + all communication metadata. + + When send_info is provided (pre-computed during graph filtering in + filter_edges_by_node_partition), the NCCL index-exchange collective + is skipped entirely — send_counts and send_indices_global are taken + directly from send_info. + + Args: + edge_index: Edge index filtered to edges whose targets are in + this rank's partition, shape (2, num_edges). + Row 0 = source, row 1 = target. + rank_assignments: Rank assignment for each atom, + shape (total_atoms,). + rank: This rank's GP rank. + world_size: GP world size. + send_info: Pre-computed send/recv metadata from graph filtering. 
+ If provided, must contain: + - send_counts: Tensor of shape (world_size,) with count of + atoms to send to each rank. + - send_indices_global: Tensor of global atom indices to send, + sorted by destination rank. + When provided, _sparse_index_exchange is skipped. + node_partition: Pre-computed atom indices in this rank's + partition. If provided, avoids recomputing from + rank_assignments. + + Returns: + GPContext with all metadata needed for all-to-all communication. + """ + total_atoms = rank_assignments.shape[0] + device = rank_assignments.device + + # Atoms owned by this rank (reuse pre-computed if available) + if node_partition is None: + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + total_local_atoms = node_partition.shape[0] + + # Find which non-local atoms this rank needs as edge sources. + # Since edge_index is already filtered to edges whose targets are + # in this rank's partition, every edge has a local target. We only + # need to find edges where the SOURCE is remote (not in our partition). + local_mask = rank_assignments == rank # (total_atoms,) bool + src_is_remote = ~local_mask[edge_index[0]] + + # Remote sources needed for local targets + # Use boolean mask + nonzero instead of .unique(sorted=True) on raw + # edge sources — O(N) scatter + scan vs O(E log E) sort. 
+ needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) + needed_mask[edge_index[0, src_is_remote]] = True + needed_mask &= ~local_mask # exclude local atoms (safety) + needed_atoms = needed_mask.nonzero(as_tuple=True)[0] + + total_needed_atoms = needed_atoms.shape[0] + needed_from_ranks = rank_assignments[needed_atoms] + + # Compute recv_counts: how many atoms we receive from each rank + if total_needed_atoms > 0: + recv_counts = torch.bincount(needed_from_ranks, minlength=world_size).to( + dtype=torch.long, device=device + ) + else: + recv_counts = torch.zeros(world_size, dtype=torch.long, device=device) + recv_counts[rank] = 0 # Never receive from self + + # CRITICAL: Sort needed_atoms by source rank to match recv_buf ordering. + # all_to_all fills recv_buf by source rank: [atoms from rank 0 | atoms + # from rank 1 | ...]. Within each rank, atoms are in the order we + # requested them (global index order, since argsort is stable). + # global_to_local must assign local indices in this SAME order, + # otherwise local index i maps to recv_buf[i] which has a DIFFERENT + # atom's embedding. + # With index_split, global index order == rank order (no-op sort). + # With spatial, global index order != rank order → sort is essential. + sort_order = needed_from_ranks.argsort(stable=True) + needed_atoms = needed_atoms[sort_order] + needed_from_ranks = needed_from_ranks[sort_order] + + # Use pre-computed send_info when available to skip the + # _sparse_index_exchange NCCL collectives entirely. + if send_info is not None: + # Pre-computed during graph filtering — use it. 
+ send_counts = send_info["send_counts"] + send_indices_global = send_info["send_indices_global"] + else: + with record_function("a2a_sparse_index_exchange"): + send_counts, send_indices_global = _sparse_index_exchange( + needed_atoms=needed_atoms, + recv_counts=recv_counts, + rank=rank, + world_size=world_size, + device=device, + ) + + # Build global_to_local mapping: + # Local atoms: index 0..total_local_atoms-1 (in order of + # node_partition) + # Received atoms: index total_local_atoms..total_local_atoms+total_needed + # IMPORTANT: needed_atoms is sorted by source rank (not global index) + # to match the recv_buf ordering from all_to_all. This ensures that + # local index (total_local + i) maps to recv_buf[i], which contains + # the embedding of needed_atoms[i]. + global_to_local = torch.full((total_atoms,), -1, dtype=torch.long, device=device) + # Map local atoms + global_to_local[node_partition] = torch.arange( + total_local_atoms, dtype=torch.long, device=device + ) + # Map needed remote atoms (in recv_buf order = source rank order) + global_to_local[needed_atoms] = torch.arange( + total_local_atoms, + total_local_atoms + total_needed_atoms, + dtype=torch.long, + device=device, + ) + + # Convert send_indices from global to local + send_indices = None + has_send = send_indices_global is not None + if has_send and send_indices_global.numel() > 0: + send_indices = global_to_local[send_indices_global] + elif has_send: + send_indices = torch.empty(0, dtype=torch.long, device=device) + + # Precompute edge_index_local + edge_index_local = global_to_local[edge_index] + + # Batch ALL GPU→CPU scalar extractions into a single transfer. + # This batches send_counts, recv_counts, AND validation scalars + # into ONE .cpu() call, eliminating extra GPU→CPU syncs. 
+ bad_edge_count = (edge_index_local < 0).sum().unsqueeze(0).to(torch.long) + send_valid = ( + torch.ones(1, dtype=torch.long, device=device) + if send_indices is None or send_indices.numel() == 0 + else ( + ((send_indices >= 0) & (send_indices < total_local_atoms)) + .all() + .unsqueeze(0) + .to(torch.long) + ) + ) + all_cpu = torch.cat([send_counts, recv_counts, bad_edge_count, send_valid]).cpu() + send_splits = all_cpu[:world_size].tolist() + recv_splits = all_cpu[world_size : 2 * world_size].tolist() + total_recv = sum(recv_splits) + n_bad = int(all_cpu[2 * world_size].item()) + send_ok = int(all_cpu[2 * world_size + 1].item()) + + # Validate AFTER the batched CPU transfer (no extra GPU syncs). + if not send_ok: + # Diagnostic: identify which send_indices are out of range. + bad_mask = (send_indices < 0) | (send_indices >= total_local_atoms) + n_bad_send = bad_mask.sum().item() + n_total_send = send_indices.numel() + bad_global = send_indices_global[bad_mask][:10].tolist() + bad_ra = rank_assignments[send_indices_global[bad_mask][:10]].tolist() + raise RuntimeError( + f"Rank {rank}: received requests for atoms not in our " + f"partition ({n_bad_send}/{n_total_send} OOB). " + f"bad_global={bad_global}, bad_ranks={bad_ra}. " + f"This usually means rank_assignments differs across " + f"ranks (e.g. non-deterministic crystal generation)." + ) + if n_bad > 0: + # Only compute diagnostics in the error path (rare). + bad_cols = (edge_index_local < 0).any(dim=0) + bad_globals = edge_index[:, bad_cols].unique() + bad_ranks = rank_assignments[bad_globals] + + # Compute edge-based needed atoms for comparison. + edge_src_remote = ~local_mask[edge_index[0]] + edge_needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) + edge_needed_mask[edge_index[0, edge_src_remote]] = True + edge_needed_mask &= ~local_mask + edge_needed_count = edge_needed_mask.sum().item() + + # Check which edge-needed atoms are NOT in our needed_atoms. 
+ needed_set = torch.zeros(total_atoms, dtype=torch.bool, device=device) + needed_set[needed_atoms] = True + missing = edge_needed_mask & ~needed_set + missing_count = missing.sum().item() + + # Check for local atoms in needed_atoms. + local_in_needed = (local_mask & needed_set).sum().item() + + # Check for bad local atoms (local atoms mapped to -1). + local_is_bad = local_mask[bad_globals] + n_local_bad = local_is_bad.sum().item() + + logging.error( + f"Rank {rank}: GP DIAGNOSTIC — " + f"{n_bad} entries in edge_index_local are -1. " + f"edge_needed={edge_needed_count}, " + f"needed_atoms_count={total_needed_atoms}, " + f"missing_from_needed={missing_count}, " + f"local_in_needed={local_in_needed}, " + f"n_local_bad={n_local_bad}/{len(bad_globals)} bad globals, " + f"total_atoms={total_atoms}, " + f"total_local={total_local_atoms}, " + f"node_partition_range=[{node_partition.min().item()}, " + f"{node_partition.max().item()}], " + f"send_info_provided={send_info is not None}, " + f"bad_globals[:20]={bad_globals.tolist()[:20]}, " + f"bad_ranks[:20]={bad_ranks.tolist()[:20]}" + ) + if missing_count > 0: + missing_indices = missing.nonzero(as_tuple=True)[0][:10] + logging.error( + f"Rank {rank}: Missing atoms (edge-needed but " + f"not in needed_atoms): {missing_indices.tolist()}" + ) + raise RuntimeError( + f"Rank {rank}: edge_index has {n_bad} endpoints not in " + f"global_to_local mapping. This indicates a mismatch " + f"between graph edges and partition assignments." + ) + + # Precompute local/remote edge indices for comm-compute overlap. + # An edge is "local-source" if its source atom is owned by this + # rank (index < total_local_atoms in the remapped edge_index_local). 
+ local_edge_mask = edge_index_local[0] < total_local_atoms + local_edge_idx = local_edge_mask.nonzero(as_tuple=True)[0] + remote_edge_idx = (~local_edge_mask).nonzero(as_tuple=True)[0] + + return GPContext( + rank=rank, + world_size=world_size, + node_partition=node_partition, + rank_assignments=rank_assignments, + needed_atoms=needed_atoms, + needed_from_ranks=needed_from_ranks, + send_counts=send_counts, + recv_counts=recv_counts, + global_to_local=global_to_local, + total_local_atoms=total_local_atoms, + total_needed_atoms=total_needed_atoms, + send_indices=send_indices, + edge_index_local=edge_index_local, + # Precompute Python lists once (avoids .tolist() per layer per fwd) + send_splits=send_splits, + recv_splits=recv_splits, + total_recv=total_recv, + local_edge_idx=local_edge_idx, + remote_edge_idx=remote_edge_idx, + ) + + +class AllToAllCollect(torch.autograd.Function): + """ + Autograd function that uses all-to-all to collect only the needed + remote atom embeddings, replacing the all-gather approach. + + Forward: Sends local atom embeddings to ranks that need them, + receives remote atom embeddings that we need. Returns only the + received remote embeddings (NOT concatenated with local). + + Backward: Reverses the communication — sends gradient of received + embeddings back to their owners, receives gradient of sent + embeddings. + + Optimizations over naive all-to-all: + - Uses ``all_to_all_single`` on NCCL to avoid Python list creation + from ``split()`` — communicates packed tensors directly. + - Returns the pre-allocated receive buffer directly instead of + ``torch.cat(recv_list)`` — avoids a redundant copy. + - Accepts precomputed ``send_splits``/``recv_splits`` to avoid + repeated ``.tolist()`` calls per layer. 
+ """ + + @staticmethod + @torch.compiler.disable + def forward( + ctx, + x_local: torch.Tensor, + send_indices: torch.Tensor, + send_counts: torch.Tensor, + recv_counts: torch.Tensor, + gp_group: dist.ProcessGroup, + rank: int, + world_size: int, + precomputed_send_splits: list[int] | None = None, + precomputed_recv_splits: list[int] | None = None, + precomputed_total_recv: int | None = None, + ) -> torch.Tensor: + """ + Forward all-to-all embedding collection. + + Args: + x_local: Local atom embeddings, + shape (local_atoms, *feature_dims). + send_indices: Local indices of atoms to send, + ordered by dest rank. + send_counts: Number of atoms to send to each rank. + recv_counts: Number of atoms to receive from each rank. + gp_group: GP process group. + rank: GP rank. + world_size: GP world size. + precomputed_send_splits: Optional cached + send_counts.tolist(). + precomputed_recv_splits: Optional cached + recv_counts.tolist(). + precomputed_total_recv: Optional cached + sum(recv_splits). + + Returns: + Received remote embeddings, + shape (sum(recv_counts), *feature_dims). 
+ """ + ctx.send_indices = send_indices + ctx.send_counts = send_counts + ctx.recv_counts = recv_counts + ctx.gp_group = gp_group + ctx.rank = rank + ctx.world_size = world_size + ctx.local_size = x_local.shape[0] + # Cache precomputed splits for backward + ctx.precomputed_send_splits = precomputed_send_splits + ctx.precomputed_recv_splits = precomputed_recv_splits + + feature_shape = x_local.shape[1:] + + # Gather atoms to send (index_select into contiguous buffer) + if send_indices.numel() > 0: + x_send = x_local[send_indices].contiguous() + else: + x_send = torch.empty( + 0, + *feature_shape, + device=x_local.device, + dtype=x_local.dtype, + ) + + # Use precomputed splits if available + send_splits = ( + precomputed_send_splits + if precomputed_send_splits is not None + else send_counts.tolist() + ) + recv_splits = ( + precomputed_recv_splits + if precomputed_recv_splits is not None + else recv_counts.tolist() + ) + total_recv = ( + precomputed_total_recv + if precomputed_total_recv is not None + else sum(recv_splits) + ) + x_recv = torch.empty( + total_recv, + *feature_shape, + device=x_local.device, + dtype=x_local.dtype, + ) + + # Perform all-to-all communication + backend = dist.get_backend(gp_group) + if backend == "nccl": + # Use all_to_all_single for NCCL + dist.all_to_all_single( + x_recv, + x_send, + output_split_sizes=recv_splits, + input_split_sizes=send_splits, + group=gp_group, + ) + else: + # Gloo fallback: use list-based pairwise send/recv + send_list = list(x_send.split(send_splits)) + recv_list = list(x_recv.split(recv_splits)) + _safe_all_to_all(recv_list, send_list, group=gp_group) + + # x_recv already contains all received data in rank order + return x_recv + + @staticmethod + @torch.compiler.disable + def backward(ctx, grad_received: torch.Tensor): + """ + Reverse the all-to-all: send gradients back to the ranks that + originally sent us the embeddings. 
+ """ + send_counts = ctx.send_counts + recv_counts = ctx.recv_counts + send_indices = ctx.send_indices + gp_group = ctx.gp_group + local_size = ctx.local_size + + feature_shape = grad_received.shape[1:] + + # In backward, the roles are reversed + bwd_send_splits = ( + ctx.precomputed_recv_splits + if ctx.precomputed_recv_splits is not None + else recv_counts.tolist() + ) + bwd_recv_splits = ( + ctx.precomputed_send_splits + if ctx.precomputed_send_splits is not None + else send_counts.tolist() + ) + + total_bwd_recv = sum(bwd_recv_splits) + grad_send_back = torch.empty( + total_bwd_recv, + *feature_shape, + device=grad_received.device, + dtype=grad_received.dtype, + ) + + # Reverse all-to-all + backend = dist.get_backend(gp_group) + if backend == "nccl": + dist.all_to_all_single( + grad_send_back, + grad_received.contiguous(), + output_split_sizes=bwd_recv_splits, + input_split_sizes=bwd_send_splits, + group=gp_group, + ) + else: + # Gloo fallback + bwd_send_list = list(grad_received.split(bwd_send_splits)) + bwd_recv_list = list(grad_send_back.split(bwd_recv_splits)) + _safe_all_to_all(bwd_recv_list, bwd_send_list, group=gp_group) + + # Scatter received gradients back to local positions + grad_local = torch.zeros( + local_size, + *feature_shape, + device=grad_received.device, + dtype=grad_received.dtype, + ) + + if total_bwd_recv > 0: + grad_local.index_add_(0, send_indices, grad_send_back) + + # Return gradients for x_local only; None for all other inputs + return ( + grad_local, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +def all_to_all_collect( + x_local: torch.Tensor, + gp_ctx: GPContext, + send_indices: torch.Tensor, +) -> torch.Tensor: + """ + High-level function to collect remote embeddings via all-to-all. + + Returns the received remote embeddings (NOT including local). + The caller should concatenate [x_local, received] and use + gp_ctx.global_to_local to index into this combined tensor. 
+ + Args: + x_local: Local atom embeddings, shape (local_atoms, *features). + gp_ctx: Graph parallel context. + send_indices: Local indices of atoms to send. + + Returns: + x_received: Remote atom embeddings, + shape (total_needed, *features). + """ + if send_indices is None: + raise ValueError( + "send_indices is None — build_gp_context should always " + "compute send_indices. Check GP setup." + ) + return AllToAllCollect.apply( + x_local, + send_indices, + gp_ctx.send_counts, + gp_ctx.recv_counts, + gp_utils.get_gp_group(), + gp_ctx.rank, + gp_ctx.world_size, + gp_ctx.send_splits, + gp_ctx.recv_splits, + gp_ctx.total_recv, + ) + + +def all_to_all_collect_compiled( + x_local: torch.Tensor, + gp_ctx: GPContext, + send_indices: torch.Tensor, +) -> torch.Tensor: + """ + Compile-friendly all-to-all collect using functional collectives. + + Uses ``torch.distributed._functional_collectives.all_to_all_single`` + which is a registered PyTorch op — torch.compile can trace through + it WITHOUT creating a graph break. This eliminates the per-layer + graph break from the ``@torch.compiler.disable`` on + ``AllToAllCollect.forward()``. + + This function does NOT support autograd — gradients will not flow + through the communication. When gradients are needed (e.g., autograd + forces via ``torch.autograd.grad(energy, pos)``), use + ``all_to_all_collect`` instead, which uses an autograd.Function + with proper backward support. + + NOTE: ``all_to_all_single_autograd`` (the funcoll autograd variant) + crashes with torch.compile because it doesn't handle symbolic split + sizes (SymInt). Both BL (all-gather) and A2A have a graph break + when autograd is needed, so this is not a regression vs baseline. + + For MD simulation with spatial partitioning, the split sizes are + effectively constant for hundreds of steps (atoms barely move per + timestep), so torch.compile guards on the split sizes will pass + without recompilation. + + NOTE: Requires NCCL backend. 
Functional collectives are not + supported on Gloo. CPU/Gloo tests should use ``all_to_all_collect``. + + Args: + x_local: Local atom embeddings, shape (local_atoms, *features). + gp_ctx: Graph parallel context. + send_indices: Local indices of atoms to send. + + Returns: + x_received: Remote atom embeddings, + shape (total_needed, *features). + """ + if send_indices is None: + raise ValueError( + "send_indices is None — build_gp_context should always " + "compute send_indices. Check GP setup." + ) + + # Gather atoms to send (compile-friendly indexing) + x_send = x_local[send_indices].contiguous() + + # Use functional collective — no graph break, no autograd + gp_group = gp_utils.get_gp_group() + + from torch.distributed._functional_collectives import ( + all_to_all_single as functional_a2a, + ) + + x_recv = functional_a2a( + x_send, + output_split_sizes=gp_ctx.recv_splits, + input_split_sizes=gp_ctx.send_splits, + group=gp_group, + ) + + return x_recv diff --git a/tests/core/common/parallelism/test_a2a_correctness.py b/tests/core/common/parallelism/test_a2a_correctness.py new file mode 100644 index 0000000000..d2c458fdae --- /dev/null +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -0,0 +1,610 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +Multi-GPU correctness test: A2A (all-to-all) vs BL (all-gather baseline). + +Verifies that the A2A graph parallel implementation produces numerically +identical results to the BL baseline across multiple GPU counts. + +Run directly via torchrun: + torchrun --nproc_per_node=N test_a2a_correctness.py [--natoms 1000] + +Or via pytest (2-process CPU with Gloo): + pytest test_a2a_correctness.py -v + +The test creates an FCC crystal, loads the UMA-S checkpoint, and runs +inference in both BL and A2A modes. 
The outputs (energy, forces, stress) +are gathered to rank 0 and compared numerically. +""" + +from __future__ import annotations + +import argparse +import json +import logging +import sys + +import pytest +import torch + +from fairchem.core.common import distutils, gp_utils +from fairchem.core.common.parallelism.graph_parallel_a2a import ( + all_to_all_collect, + all_to_all_collect_compiled, + build_gp_context, +) +from fairchem.core.common.parallelism.graph_partition import ( + partition_atoms_index_split, + partition_atoms_spatial, +) +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) +from fairchem.core.datasets.atomic_data import AtomicData +from fairchem.core.datasets.common_structures import ( + get_fcc_crystal_by_num_atoms, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +# ========================================================================= +# Pytest-compatible distributed tests (CPU, Gloo, 2 processes) +# ========================================================================= + + +def _correctness_test_inner( + atomic_numbers, + pos, + edge_index, + num_atoms, + partition_strategy, +): + """ + Inner test function run on each rank. + + Builds GPContext with both BL-style (index_split) and A2A (spatial) + partitioning, runs all-to-all collect, and verifies that the received + embeddings are correct by checking that each received atom's value + matches the expected value from the global embedding tensor. 
+ """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + + # Create rank assignments + if partition_strategy == "spatial": + rank_assignments = partition_atoms_spatial(pos, world_size) + else: + rank_assignments = partition_atoms_index_split( + num_atoms, world_size, pos.device + ) + + # Get this rank's partition + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + + # Filter edges to this rank's partition + target_mask = (rank_assignments == rank)[edge_index[1]] + rank_edge_index = edge_index[:, target_mask] + + # Build GP context + gp_ctx = build_gp_context( + rank_edge_index, + rank_assignments, + rank, + world_size, + node_partition=node_partition, + ) + + # Create a global embedding where each atom's embedding is its + # atomic number (unique per atom). This makes it trivial to verify + # that the right atoms were received. + x_global = atomic_numbers.unsqueeze(1).float() + x_local = x_global[node_partition] + + send_indices = gp_ctx.send_indices + + # Test both collect functions + x_recv_autograd = all_to_all_collect(x_local, gp_ctx, send_indices) + x_recv_compiled = all_to_all_collect_compiled(x_local, gp_ctx, send_indices) + + # Verify shapes + assert x_recv_autograd.shape == x_recv_compiled.shape, ( + f"Rank {rank}: shape mismatch " + f"autograd={x_recv_autograd.shape} " + f"vs compiled={x_recv_compiled.shape}" + ) + + # Verify autograd == compiled + values_match = torch.allclose(x_recv_autograd, x_recv_compiled, atol=1e-6) + + # Verify received values are correct: + # x_recv should contain embeddings of gp_ctx.needed_atoms + # in the correct order (sorted by source rank). 
+ expected_values = x_global[gp_ctx.needed_atoms] + recv_correct = torch.allclose(x_recv_autograd, expected_values, atol=1e-6) + + # Verify edge_index_local is valid + x_full = torch.cat([x_local, x_recv_autograd], dim=0) + edge_valid = (gp_ctx.edge_index_local >= 0).all().item() + edge_in_bounds = (gp_ctx.edge_index_local < x_full.shape[0]).all().item() + + # Verify message passing produces the same result as + # non-distributed. Simple sum aggregation: for each local target, + # sum source embeddings. + x_source = x_full[gp_ctx.edge_index_local[0]] + local_result = torch.zeros( + gp_ctx.total_local_atoms, + x_source.shape[1], + dtype=x_source.dtype, + device=x_source.device, + ) + local_result.index_add_(0, gp_ctx.edge_index_local[1], x_source) + + # Reference: compute the same aggregation on the full graph + x_source_ref = x_global[rank_edge_index[0]] + ref_result = torch.zeros( + num_atoms, + x_source_ref.shape[1], + dtype=x_source_ref.dtype, + device=x_source_ref.device, + ) + ref_result.index_add_(0, rank_edge_index[1], x_source_ref) + ref_local = ref_result[node_partition] + + mp_match = torch.allclose(local_result, ref_local, atol=1e-6) + + return { + "rank": rank, + "partition_strategy": partition_strategy, + "world_size": world_size, + "local_atoms": gp_ctx.total_local_atoms, + "needed_atoms": gp_ctx.total_needed_atoms, + "num_edges": rank_edge_index.shape[1], + "values_match": values_match, + "recv_correct": recv_correct, + "edge_valid": edge_valid, + "edge_in_bounds": edge_in_bounds, + "mp_match": mp_match, + } + + +@pytest.mark.parametrize( + "strategy,num_atoms", + [ + ("index_split", 8), + ("index_split", 20), + ("spatial", 8), + ("spatial", 20), + ], +) +def test_a2a_correctness_gloo(strategy, num_atoms): + """ + Verify A2A correctness at 2 GPUs using Gloo backend. + + Creates a dense graph (all atoms connected) and verifies that: + 1. Autograd and compiled collect produce identical results + 2. Received embeddings contain correct values + 3. 
Message passing produces correct aggregation + """ + # Create dense graph + src, dst = [], [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + atomic_numbers = torch.arange(2, 2 + num_atoms, dtype=torch.float) + pos = torch.randn(num_atoms, 3) * 10 # spread out for spatial + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + _correctness_test_inner, + init_pg_and_rank_and_launch_test, + atomic_numbers, + pos, + edge_index, + num_atoms, + strategy, + ) + + for result in all_rank_results: + r = result["rank"] + assert result["values_match"], f"Rank {r}: autograd vs compiled mismatch" + assert result["recv_correct"], ( + f"Rank {r}: received embeddings don't match " f"expected values" + ) + assert result["edge_valid"], f"Rank {r}: edge_index_local has negative entries" + assert result[ + "edge_in_bounds" + ], f"Rank {r}: edge_index_local has out-of-bounds entries" + assert result["mp_match"], ( + f"Rank {r}: message passing result differs " f"from reference" + ) + + +@pytest.mark.parametrize("strategy", ["index_split", "spatial"]) +def test_a2a_consistency_across_graph_sizes(strategy): + """ + Verify A2A correctness with sparse graphs (not all-to-all + connected). + + Uses a chain graph (each atom connected to its neighbors within + distance 2) to test the case where not every rank needs atoms + from every other rank. 
+ """ + num_atoms = 16 + + # Chain graph: atom i connected to i-1, i+1, i-2, i+2 + src, dst = [], [] + for i in range(num_atoms): + for d in [-2, -1, 1, 2]: + j = (i + d) % num_atoms # wrap around + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + atomic_numbers = torch.arange(10, 10 + num_atoms, dtype=torch.float) + # Linear arrangement for clear spatial partitioning + pos = torch.zeros(num_atoms, 3) + pos[:, 0] = torch.arange(num_atoms, dtype=torch.float) + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + _correctness_test_inner, + init_pg_and_rank_and_launch_test, + atomic_numbers, + pos, + edge_index, + num_atoms, + strategy, + ) + + for result in all_rank_results: + r = result["rank"] + assert result["values_match"], f"Rank {r}: autograd vs compiled mismatch" + assert result["recv_correct"], ( + f"Rank {r}: received embeddings don't match " f"expected values" + ) + assert result["mp_match"], ( + f"Rank {r}: message passing result differs " f"from reference" + ) + + +def _multidim_test_inner(x_global, pos, edge_index, num_atoms, strategy): + """ + Test A2A correctness with multi-dimensional embeddings. + Defined at module level for pickle compatibility with + multiprocessing. 
+ """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + + if strategy == "spatial": + rank_assignments = partition_atoms_spatial(pos, world_size) + else: + rank_assignments = partition_atoms_index_split( + num_atoms, world_size, pos.device + ) + + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + target_mask = (rank_assignments == rank)[edge_index[1]] + rank_edge_index = edge_index[:, target_mask] + + gp_ctx = build_gp_context( + rank_edge_index, + rank_assignments, + rank, + world_size, + node_partition=node_partition, + ) + + x_local = x_global[node_partition] + send_indices = gp_ctx.send_indices + + x_recv = all_to_all_collect(x_local, gp_ctx, send_indices) + x_recv_c = all_to_all_collect_compiled(x_local, gp_ctx, send_indices) + + # Verify + expected = x_global[gp_ctx.needed_atoms] + recv_correct = torch.allclose(x_recv, expected, atol=1e-6) + compiled_match = torch.allclose(x_recv, x_recv_c, atol=1e-6) + + return { + "rank": rank, + "recv_correct": recv_correct, + "compiled_match": compiled_match, + "recv_shape": x_recv.shape, + "expected_shape": expected.shape, + } + + +@pytest.mark.parametrize("strategy", ["index_split", "spatial"]) +def test_a2a_multidim_embeddings(strategy): + """ + Verify correctness with multi-dimensional embeddings (not just + scalars). + + Uses 16-dim embeddings to match the typical sphere_channels + in UMA. 
+ """ + num_atoms = 12 + embed_dim = 16 + + src, dst = [], [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + # Use random embeddings instead of scalar atomic numbers + torch.manual_seed(42) + x_global = torch.randn(num_atoms, embed_dim) + pos = torch.randn(num_atoms, 3) * 10 + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + + all_rank_results = spawn_multi_process( + config, + _multidim_test_inner, + init_pg_and_rank_and_launch_test, + x_global, + pos, + edge_index, + num_atoms, + strategy, + ) + + for result in all_rank_results: + r = result["rank"] + assert result["recv_correct"], ( + f"Rank {r}: multidim recv mismatch, " + f"shape={result['recv_shape']} " + f"vs {result['expected_shape']}" + ) + assert result["compiled_match"], ( + f"Rank {r}: autograd vs compiled mismatch " f"for multidim" + ) + + +# ========================================================================= +# Full model correctness test (GPU, run via torchrun or SLURM) +# ========================================================================= + + +def _resolve_checkpoint(): + """ + Resolve the UMA-S checkpoint path using the fairchem pretrained + model API. + """ + from fairchem.core.calculate.pretrained_mlip import ( + pretrained_checkpoint_path_from_name, + ) + + return pretrained_checkpoint_path_from_name(model_name="uma-s-1p2") + + +def _run_full_model_comparison( + natoms: int = 1000, + results_file: str | None = None, +): + """ + Run the full UMA-S model in both BL and A2A modes and compare + outputs. + + Must be called inside a torchrun process group. 
+ """ + from fairchem.core.units.mlip_unit import MLIPPredictUnit + + rank = distutils.get_rank() + world_size = distutils.get_world_size() + + if rank == 0: + logger.info(f"Running correctness test: {natoms} atoms, " f"{world_size} GPUs") + + checkpoint_path = _resolve_checkpoint() + if rank == 0: + logger.info(f"Using checkpoint: {checkpoint_path}") + + # Create input system + atoms = get_fcc_crystal_by_num_atoms(natoms, atom_type="Al") + actual_natoms = len(atoms) + if rank == 0: + logger.info(f"Created FCC Al crystal: {actual_natoms} atoms") + + data = AtomicData.from_ase( + input_atoms=atoms, + max_neigh=200, + radius=6.0, + task_name="oc20", + r_edges=False, + r_data_keys=["spin", "charge"], + ) + + # -- Run BL (all-gather baseline) -- + if rank == 0: + logger.info("Loading model for BL (all-gather) mode...") + + predictor_bl = MLIPPredictUnit.from_checkpoint( + checkpoint_path, + device=torch.device("cuda"), + inference_settings={ + "tf32": False, + "compile": False, + "activation_checkpointing": False, + "merge_mole": False, + }, + overrides={ + "backbone": { + "use_all_to_all_gp": False, + }, + }, + ) + predictor_bl.model.eval() + + # Warm up + run BL + with torch.no_grad(): + _ = predictor_bl.predict(data) + bl_out = predictor_bl.predict(data) + + bl_energy = bl_out["energy"].clone() + bl_forces = bl_out["forces"].clone() + bl_stress = bl_out.get("stress", torch.tensor([])).clone() + + if rank == 0: + logger.info(f"BL energy: {bl_energy.item():.6f}") + logger.info(f"BL forces shape: {bl_forces.shape}") + logger.info(f"BL forces norm: {bl_forces.norm():.6f}") + + # Clean up BL model + del predictor_bl + torch.cuda.empty_cache() + + # -- Run A2A (all-to-all with spatial partitioning) -- + if rank == 0: + logger.info("Loading model for A2A (all-to-all) mode...") + + predictor_a2a = MLIPPredictUnit.from_checkpoint( + checkpoint_path, + device=torch.device("cuda"), + inference_settings={ + "tf32": False, + "compile": False, + "activation_checkpointing": False, 
+ "merge_mole": False, + }, + overrides={ + "backbone": { + "use_all_to_all_gp": True, + "gp_partition_strategy": "spatial", + }, + }, + ) + predictor_a2a.model.eval() + + # Warm up + run A2A + with torch.no_grad(): + _ = predictor_a2a.predict(data) + a2a_out = predictor_a2a.predict(data) + + a2a_energy = a2a_out["energy"].clone() + a2a_forces = a2a_out["forces"].clone() + a2a_stress = a2a_out.get("stress", torch.tensor([])).clone() + + if rank == 0: + logger.info(f"A2A energy: {a2a_energy.item():.6f}") + logger.info(f"A2A forces shape: {a2a_forces.shape}") + logger.info(f"A2A forces norm: {a2a_forces.norm():.6f}") + + # -- Compare outputs -- + # Energy should match across all ranks (reduced) + energy_diff = abs(bl_energy.item() - a2a_energy.item()) + energy_match = energy_diff < 1e-4 + + # Forces: each rank only has forces for its local atoms. + # Gather all forces to rank 0 for comparison. + # BL forces are already the full set on all ranks. + # A2A forces need gathering. + if bl_forces.shape == a2a_forces.shape: + force_diff = (bl_forces - a2a_forces).abs().max().item() + force_match = force_diff < 1e-4 + force_rmse = (bl_forces - a2a_forces).pow(2).mean().sqrt().item() + else: + # Different shapes — gather and compare + force_diff = float("nan") + force_match = False + force_rmse = float("nan") + + # Stress + if bl_stress.numel() > 0 and a2a_stress.numel() > 0: + stress_diff = (bl_stress - a2a_stress).abs().max().item() + stress_match = stress_diff < 1e-4 + else: + stress_diff = 0.0 + stress_match = True + + results = { + "natoms": actual_natoms, + "world_size": world_size, + "energy_bl": bl_energy.item(), + "energy_a2a": a2a_energy.item(), + "energy_diff": energy_diff, + "energy_match": energy_match, + "force_max_diff": force_diff, + "force_rmse": force_rmse, + "force_match": force_match, + "stress_max_diff": stress_diff, + "stress_match": stress_match, + "all_match": energy_match and force_match and stress_match, + } + + if rank == 0: + logger.info(f"\n{'=' * 
60}") + logger.info("CORRECTNESS TEST RESULTS") + logger.info(f"{'=' * 60}") + logger.info(f"Atoms: {actual_natoms}") + logger.info(f"GPUs: {world_size}") + logger.info(f"Energy BL: {bl_energy.item():.6f}") + logger.info(f"Energy A2A: {a2a_energy.item():.6f}") + logger.info(f"Energy diff: {energy_diff:.2e}") + logger.info(f"Force max Δ: {force_diff:.2e}") + logger.info(f"Force RMSE: {force_rmse:.2e}") + logger.info(f"Stress max Δ: {stress_diff:.2e}") + status = "✓ PASS" if results["all_match"] else "✗ FAIL" + logger.info(f"ALL MATCH: {status}") + logger.info(f"{'=' * 60}") + + if results_file: + with open(results_file, "w") as f: + json.dump(results, f, indent=2) + logger.info(f"Results saved to {results_file}") + + return results + + +# ========================================================================= +# CLI entrypoint for SLURM / torchrun +# ========================================================================= + + +def main(): + parser = argparse.ArgumentParser(description="A2A vs BL correctness test") + parser.add_argument( + "--natoms", + type=int, + default=1000, + help="Target number of atoms in FCC crystal", + ) + parser.add_argument( + "--results-file", + type=str, + default=None, + help="Path to save JSON results", + ) + args = parser.parse_args() + + # Initialize distributed + distutils.setup({"submit": False, "cpu": False}) + gp_utils.setup_gp(distutils.get_world_size()) + + try: + results = _run_full_model_comparison( + natoms=args.natoms, + results_file=args.results_file, + ) + if not results["all_match"]: + sys.exit(1) + finally: + distutils.cleanup() + + +if __name__ == "__main__": + main() diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py new file mode 100644 index 0000000000..a2bb71b1ec --- /dev/null +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -0,0 +1,651 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import pytest +import torch + +from fairchem.core.common import gp_utils +from fairchem.core.common.gp_utils import ( + gather_from_model_parallel_region_sum_grad, + size_list_fn, +) +from fairchem.core.common.parallelism.graph_parallel_a2a import ( + all_to_all_collect, + all_to_all_collect_compiled, + build_gp_context, +) +from fairchem.core.common.parallelism.graph_partition import ( + partition_atoms_index_split, + partition_atoms_spatial, +) +from fairchem.core.common.test_utils import ( + PGConfig, + init_pg_and_rank_and_launch_test, + spawn_multi_process, +) + +pytestmark = pytest.mark.serial + + +# ========================================================================= +# Unit tests (no distributed, CPU only) +# ========================================================================= + + +class TestBuildGPContext: + """ + Tests for build_gp_context (non-distributed, simulates single rank). + """ + + def test_basic_context_building(self): + """ + Test with a simple graph where all atoms are on rank 0. + """ + # 4 atoms, 2 ranks, rank 0 owns [0,1], rank 1 owns [2,3] + edge_index = torch.tensor([[0, 1, 2, 3, 2], [1, 0, 3, 2, 0]]) + rank_assignments = torch.tensor([0, 0, 1, 1]) + + # Build context for rank 0 (no distributed env) + ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2) + + assert ctx.rank == 0 + assert ctx.world_size == 2 + assert ctx.total_local_atoms == 2 + assert torch.equal(ctx.node_partition, torch.tensor([0, 1])) + + # Rank 0 targets: atoms 0 and 1 + # Edge (2, 0): src=2 is remote, tgt=0 is local -> need atom 2 + # Edge (1, 0): src=1 is local -> don't need + # Edge (0, 1): src=0 is local -> don't need + assert 2 in ctx.needed_atoms + + def test_global_to_local_mapping(self): + """ + Verify that global_to_local correctly maps to local indices. 
+ """ + edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) + rank_assignments = torch.tensor([0, 0, 1, 1]) + + ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2) + + # Local atoms [0, 1] should map to indices [0, 1] + assert ctx.global_to_local[0] == 0 + assert ctx.global_to_local[1] == 1 + + # Remote atoms that are needed should map to + # indices >= total_local_atoms + for atom in ctx.needed_atoms: + local_idx = ctx.global_to_local[atom].item() + assert local_idx >= ctx.total_local_atoms + + def test_no_cross_partition_edges(self): + """ + When no edges cross partitions, no remote atoms are needed. + """ + edge_index = torch.tensor([[0, 1], [1, 0]]) # Only within rank 0 + rank_assignments = torch.tensor([0, 0, 1, 1]) + + ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2) + assert ctx.total_needed_atoms == 0 + assert ctx.needed_atoms.numel() == 0 + + def test_edge_split_indices(self): + """ + Verify local_edge_idx and remote_edge_idx correctly split edges + by source ownership. 
+ """ + # 4 atoms, 2 ranks: rank 0 owns [0,1], rank 1 owns [2,3] + # Edges: (0,1) local-src, (1,0) local-src, + # (2,0) remote-src, (3,1) remote-src + edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 0, 1]]) + rank_assignments = torch.tensor([0, 0, 1, 1]) + ctx = build_gp_context( + edge_index, + rank_assignments, + rank=0, + world_size=2, + ) + + assert ctx.local_edge_idx is not None + assert ctx.remote_edge_idx is not None + + # Check counts: 2 local-src edges (0,1) and (1,0), + # 2 remote-src edges (2,0) and (3,1) + edge_index_local = ctx.edge_index_local + n_local = ctx.total_local_atoms # 2 + + local_srcs = edge_index_local[0, ctx.local_edge_idx] + remote_srcs = edge_index_local[0, ctx.remote_edge_idx] + + assert (local_srcs < n_local).all() + assert (remote_srcs >= n_local).all() + + # Together they cover all edges + assert ( + ctx.local_edge_idx.numel() + ctx.remote_edge_idx.numel() + == edge_index_local.shape[1] + ) + + def test_edge_split_no_remote_edges(self): + """ + When all edges are local-source, remote_edge_idx should be empty. + """ + edge_index = torch.tensor([[0, 1], [1, 0]]) + rank_assignments = torch.tensor([0, 0, 1, 1]) + ctx = build_gp_context( + edge_index, + rank_assignments, + rank=0, + world_size=2, + ) + + assert ctx.local_edge_idx.numel() == 2 + assert ctx.remote_edge_idx.numel() == 0 + + +# ========================================================================= +# Distributed tests: A2A vs All-Gather correctness +# ========================================================================= + + +def _a2a_simple_layer(x, edge_index, rank_assignments, natoms): + """ + A simple message passing layer using all-to-all communication. + Computes same result as all-gather version but using all-to-all. 
+ """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + + # Build GP context (send_indices computed inline) + gp_ctx = build_gp_context(edge_index, rank_assignments, rank, world_size) + + send_indices = gp_ctx.send_indices + + # All-to-all collect + x_received = all_to_all_collect(x, gp_ctx, send_indices) + + # Combine local + received + x_full = torch.cat([x, x_received], dim=0) + + # Remap edges to local space (inline, no helper function needed) + edge_index_local = gp_ctx.global_to_local[edge_index] + + # Simple message passing: source embeddings aggregated to targets + x_source = x_full[edge_index_local[0]] + x_target = x_full[edge_index_local[1]] + + edge_embeddings = (x_source + 1).pow(1.5) * (x_target + 1).pow(1.5) + + # Aggregate to local atoms only + local_atoms = gp_ctx.total_local_atoms + new_node_embedding = torch.zeros( + local_atoms, + *edge_embeddings.shape[1:], + dtype=edge_embeddings.dtype, + device=edge_embeddings.device, + ) + # Target indices in local space are in [0, local_atoms) + # for local targets + local_target_mask = edge_index_local[1] < local_atoms + local_edge_idx = edge_index_local[:, local_target_mask] + local_edge_emb = edge_embeddings[local_target_mask] + new_node_embedding.index_add_(0, local_edge_idx[1], local_edge_emb) + + return new_node_embedding + + +def _allgather_simple_layer(x, edge_index, node_offset, natoms): + """ + A simple message passing layer using all-gather (baseline). 
+ """ + x_full = gather_from_model_parallel_region_sum_grad(x, natoms) + + x_source = x_full[edge_index[0]] + x_target = x_full[edge_index[1]] + + local_atoms = size_list_fn(natoms, gp_utils.get_gp_world_size())[ + gp_utils.get_gp_rank() + ] + + edge_embeddings = (x_source + 1).pow(1.5) * (x_target + 1).pow(1.5) + + new_node_embedding = torch.zeros( + local_atoms, + *edge_embeddings.shape[1:], + dtype=edge_embeddings.dtype, + device=edge_embeddings.device, + ) + new_node_embedding.index_add_(0, edge_index[1] - node_offset, edge_embeddings) + + return new_node_embedding + + +def a2a_vs_allgather_test(atomic_numbers, edge_index): + """ + Compare all-to-all and all-gather results on the same simple layer. + Both should produce identical output. + """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + natoms = atomic_numbers.shape[0] + + # Partition atoms (same as gp_utils does) + node_partition = torch.tensor_split(torch.arange(natoms), world_size)[rank] + node_offset = node_partition.min().item() + + # Create rank assignments + rank_assignments = partition_atoms_index_split( + natoms, world_size, torch.device("cpu") + ) + + # Filter edges: keep edges where target is in our partition + target_in_partition = (edge_index[1] >= node_partition.min()) & ( + edge_index[1] <= node_partition.max() + ) + local_edge_index = edge_index[:, target_in_partition] + + # Local embeddings (just use atomic numbers as embedding) + x_local = atomic_numbers[node_partition].clone().unsqueeze(-1) + + # Run all-gather version + result_ag = _allgather_simple_layer(x_local, local_edge_index, node_offset, natoms) + + # Run all-to-all version + result_a2a = _a2a_simple_layer(x_local, local_edge_index, rank_assignments, natoms) + + return { + "rank": rank, + "allgather": result_ag.detach(), + "all_to_all": result_a2a.detach(), + "match": torch.allclose(result_ag, result_a2a, atol=1e-6), + } + + +@pytest.mark.parametrize( + "num_atoms, edges", + [ + # Simple linear chain: 
0-1-2-3 + (4, [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), + # Star graph: 0 connected to all + (5, [[0, 0, 0, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 0, 0, 0]]), + # Dense graph: all-to-all edges (4 atoms) + ( + 4, + [ + [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2], + ], + ), + ], +) +def test_a2a_vs_allgather(num_atoms, edges): + """ + Verify that all-to-all produces the same results as all-gather + for a simple message passing layer. + """ + atomic_numbers = torch.arange( + 2, 2 + num_atoms, dtype=torch.float, requires_grad=False + ) + edge_index = torch.tensor(edges, dtype=torch.long) + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + a2a_vs_allgather_test, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["match"], ( + f"Rank {result['rank']}: all-gather and all-to-all produced " + f"different results.\n" + f"allgather: {result['allgather']}\n" + f"all_to_all: {result['all_to_all']}" + ) + + +def a2a_backward_test(atomic_numbers, edge_index): + """ + Test that the backward pass of all-to-all produces correct gradients + by comparing energy and forces computed with all-gather vs all-to-all. 
+ """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + natoms = atomic_numbers.shape[0] + + # Partition atoms + node_partition = torch.tensor_split(torch.arange(natoms), world_size)[rank] + node_offset = node_partition.min().item() + + rank_assignments = partition_atoms_index_split( + natoms, world_size, torch.device("cpu") + ) + + # Filter edges + target_in_partition = (edge_index[1] >= node_partition.min()) & ( + edge_index[1] <= node_partition.max() + ) + local_edge_index = edge_index[:, target_in_partition] + + results = {} + + for method in ["allgather", "all_to_all"]: + # Need fresh requires_grad for each method + x_local = atomic_numbers[node_partition].clone().detach() + x_local = x_local.unsqueeze(-1).requires_grad_(True) + + if method == "allgather": + embedding = _allgather_simple_layer( + x_local, local_edge_index, node_offset, natoms + ) + else: + embedding = _a2a_simple_layer( + x_local, local_edge_index, rank_assignments, natoms + ) + + # Compute local energy contribution + energy_part = embedding.sum() + energy = gp_utils.reduce_from_model_parallel_region(energy_part) + + # Compute forces (gradient w.r.t. x_local) + forces = torch.autograd.grad( + [energy], + [x_local], + create_graph=False, + )[0] + + results[f"{method}_energy"] = energy.detach() + results[f"{method}_forces"] = forces.detach() + + results["rank"] = rank + results["energy_match"] = torch.allclose( + results["allgather_energy"], + results["all_to_all_energy"], + atol=1e-5, + ) + results["forces_match"] = torch.allclose( + results["allgather_forces"], + results["all_to_all_forces"], + atol=1e-5, + ) + + return results + + +def test_a2a_backward(): + """ + Verify that backward pass of all-to-all matches all-gather. 
+ """ + atomic_numbers = torch.tensor([2.0, 3.0, 5.0, 7.0]) + edge_index = torch.tensor( + [[0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2]], + dtype=torch.long, + ) + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + a2a_backward_test, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["energy_match"], ( + f"Rank {result['rank']}: energy mismatch. " + f"AG={result['allgather_energy']}, " + f"A2A={result['all_to_all_energy']}" + ) + assert result["forces_match"], ( + f"Rank {result['rank']}: forces mismatch. " + f"AG={result['allgather_forces']}, " + f"A2A={result['all_to_all_forces']}" + ) + + +@pytest.mark.parametrize("world_size", [2, 3]) +def test_a2a_multi_rank(world_size): + """ + Test all-to-all vs all-gather with varying number of GP ranks. + """ + num_atoms = 6 + # Create a ring graph + src = list(range(num_atoms)) + dst = [(i + 1) % num_atoms for i in range(num_atoms)] + # Bidirectional + edge_src = src + dst + edge_dst = dst + src + edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long) + + atomic_numbers = torch.arange(2, 2 + num_atoms, dtype=torch.float) + + config = PGConfig( + backend="gloo", + world_size=world_size, + gp_group_size=world_size, + use_gp=True, + ) + all_rank_results = spawn_multi_process( + config, + a2a_vs_allgather_test, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["match"], ( + f"world_size={world_size}, " f"rank {result['rank']}: mismatch" + ) + + +def a2a_spatial_partition_test(atomic_numbers, edge_index, pos): + """ + Test all-to-all with spatial partitioning produces correct results + by comparing to all-gather (which always uses index-based + partitioning). 
+ """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + natoms = atomic_numbers.shape[0] + + # --- All-gather with index-based partitioning (baseline) --- + node_partition_idx = torch.tensor_split(torch.arange(natoms), world_size)[rank] + node_offset_idx = node_partition_idx.min().item() + + target_in_partition_idx = (edge_index[1] >= node_partition_idx.min()) & ( + edge_index[1] <= node_partition_idx.max() + ) + local_edge_index_idx = edge_index[:, target_in_partition_idx] + + x_local_idx = atomic_numbers[node_partition_idx].clone().unsqueeze(-1) + result_ag = _allgather_simple_layer( + x_local_idx, local_edge_index_idx, node_offset_idx, natoms + ) + + # --- All-to-all with spatial partitioning --- + rank_assignments_spatial = partition_atoms_spatial(pos, world_size) + local_mask = rank_assignments_spatial == rank + node_partition_sp = local_mask.nonzero(as_tuple=True)[0] + + target_in_partition_sp = local_mask[edge_index[1]] + local_edge_index_sp = edge_index[:, target_in_partition_sp] + + x_local_sp = atomic_numbers[node_partition_sp].clone().unsqueeze(-1) + result_a2a = _a2a_simple_layer( + x_local_sp, + local_edge_index_sp, + rank_assignments_spatial, + natoms, + ) + + # Both methods compute message passing over the SAME global graph, + # so local atoms get the same aggregated messages regardless of + # which partition strategy is used. However, different ranks own + # different atoms under spatial vs index partitioning, so we + # gather all results and compare the full output. 
+ # Gather all local results to rank 0 for comparison + full_ag = gather_from_model_parallel_region_sum_grad(result_ag, natoms) + full_a2a = gather_from_model_parallel_region_sum_grad(result_a2a, natoms) + + return { + "rank": rank, + "allgather_full": full_ag.detach(), + "all_to_all_full": full_a2a.detach(), + "match": torch.allclose(full_ag, full_a2a, atol=1e-5), + } + + +def test_a2a_spatial_partition(): + """ + Verify that all-to-all with spatial partitioning produces the same + global results as all-gather with index partitioning. + """ + num_atoms = 8 + # Create atoms in two spatial clusters + pos = torch.cat( + [ + torch.randn(4, 3) + torch.tensor([0.0, 0.0, 0.0]), + torch.randn(4, 3) + torch.tensor([100.0, 0.0, 0.0]), + ] + ) + atomic_numbers = torch.arange( + 2, 2 + num_atoms, dtype=torch.float, requires_grad=False + ) + # Dense graph connecting all atoms + src = [] + dst = [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + a2a_spatial_partition_test, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + pos, + ) + + for result in all_rank_results: + assert result["match"], ( + f"Rank {result['rank']}: spatial partitioning produced " + f"different global results than index partitioning" + ) + + +# ========================================================================= +# Distributed tests: compiled collect correctness +# ========================================================================= + + +def compiled_collect_test( + atomic_numbers: torch.Tensor, + edge_index: torch.Tensor, + strategy: str, +): + """ + Verify that all_to_all_collect_compiled produces same results + as the original all_to_all_collect (autograd version). 
+ """ + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + + # Partition atoms (deterministic positions — same across all ranks) + if strategy == "spatial": + torch.manual_seed(42) + pos = torch.rand(atomic_numbers.shape[0], 3) + rank_assignments = partition_atoms_spatial(pos, world_size, pos.device) + else: + rank_assignments = partition_atoms_index_split( + atomic_numbers.shape[0], world_size, atomic_numbers.device + ) + + # Filter edges to this rank's partition + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + target_mask = (rank_assignments == rank)[edge_index[1]] + rank_edge_index = edge_index[:, target_mask] + + gp_ctx = build_gp_context( + rank_edge_index, + rank_assignments, + rank, + world_size, + node_partition=node_partition, + ) + + send_indices = gp_ctx.send_indices + + x = atomic_numbers[node_partition].unsqueeze(1).float() + + # Reference: original autograd version + x_ref = all_to_all_collect(x, gp_ctx, send_indices) + + # Test: compiled functional version + x_compiled = all_to_all_collect_compiled(x, gp_ctx, send_indices) + + match = torch.allclose(x_ref, x_compiled, atol=1e-6) + shape_match = x_ref.shape == x_compiled.shape + + return { + "rank": rank, + "match": match, + "shape_match": shape_match, + "ref_shape": x_ref.shape, + "compiled_shape": x_compiled.shape, + } + + +@pytest.mark.parametrize("strategy", ["index_split", "spatial"]) +def test_compiled_collect_matches_autograd(strategy): + """ + Verify that compile-friendly functional collective produces + identical results to the @torch.compiler.disable autograd version. 
+ """ + num_atoms = 12 + src = [] + dst = [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + atomic_numbers = torch.arange( + 2, 2 + num_atoms, dtype=torch.float, requires_grad=False + ) + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + compiled_collect_test, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + strategy, + ) + + for result in all_rank_results: + assert result["shape_match"], ( + f"Rank {result['rank']}: shape mismatch. " + f"ref={result['ref_shape']}, " + f"compiled={result['compiled_shape']}" + ) + assert result["match"], ( + f"Rank {result['rank']}: functional mismatch — " + f"compiled collect produced different embeddings" + ) From 530b685856490dcf90e2d38da6b7a51057b491eb Mon Sep 17 00:00:00 2001 From: rgao user Date: Sat, 2 May 2026 07:09:26 +0000 Subject: [PATCH 02/26] Integrate A2A graph parallel into eSCN-MD backbone Model integration for all-to-all graph parallelism: eSCN-MD backbone (escn_md.py): - Add use_all_to_all_gp and gp_partition_strategy config options - Replace all-gather with A2A embedding exchange when enabled - Spatial partitioning via Morton Z-order curve - AABB halo filtering for reduced graph generation input - Support both autograd and compiled A2A collect variants - Block allgather+spatial combination (unsupported) eSCN-MD block (escn_md_block.py): - Add A2A collect integration in message passing layers - Use precomputed edge split indices for local/remote separation Graph generation (compute.py): - Extend filter_edges_by_node_partition with send_info computation - Eliminates NCCL index-exchange collective in build_gp_context - Pass rank_assignments through generate_graph Embedding/execution backends: - Thread GP context through embedding and execution layers Tests: - Add send_info optimization correctness 
test - Expand ParallelPredictUnit tests with A2A+spatial and A2A+index_split GP modes (CPU: workers=2, GPU: workers=1) - 48 parallelism tests + 25 predict tests all pass --- src/fairchem/core/graph/compute.py | 129 +++++- src/fairchem/core/models/uma/escn_md.py | 387 ++++++++++++++++-- src/fairchem/core/models/uma/escn_md_block.py | 102 ++++- src/fairchem/core/models/uma/nn/embedding.py | 21 +- .../core/models/uma/nn/execution_backends.py | 28 +- .../common/parallelism/test_graph_parallel.py | 228 ++++++++++- tests/core/units/mlip_unit/test_predict.py | 78 +++- 7 files changed, 856 insertions(+), 117 deletions(-) diff --git a/src/fairchem/core/graph/compute.py b/src/fairchem/core/graph/compute.py index 83e18dc1a3..beb3dc2adc 100644 --- a/src/fairchem/core/graph/compute.py +++ b/src/fairchem/core/graph/compute.py @@ -22,27 +22,89 @@ def filter_edges_by_node_partition( cell_offsets: torch.Tensor, neighbors: torch.Tensor, num_atoms: int, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Filter edges to keep only those where the target atom belongs to the node partition. - edge_index is shape (2, num_edges) where the first row is the source atom index and the second row is the target atom index for each edge - cell_offsets is shape (num_edges, 3) - neighbors is cardinality of the edge_index per system in the batch + rank_assignments: torch.Tensor | None = None, + rank: int | None = None, + world_size: int | None = None, +) -> ( + tuple[torch.Tensor, torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict] +): + """ + Filter edges to keep only those where the target atom belongs to + the node partition. + + When rank_assignments, rank, and world_size are provided, also + computes send_info: which local atoms need to be sent to which + ranks for all-to-all graph parallel communication. 
This exploits + access to the full (pre-filter) edge_index to derive send + metadata locally, eliminating the need for an NCCL index-exchange + collective in build_gp_context. Args: - node_partition: Tensor of atom indices belonging to the current rank's partition. - edge_index: Edge index tensor of shape (2, num_edges), where row 0 is the source and 1 is the target atom. - cell_offsets: Cell offsets tensor of shape (num_edges, 3). - neighbors: Tensor with edge count per system in the batch (length = num_systems). - num_atoms: Total number of atoms across all batches. Used to create a boolean mask for filtering. + node_partition: Atom indices in the current rank's partition. + edge_index: Full edge index, shape (2, num_edges). + cell_offsets: Cell offsets, shape (num_edges, 3). + neighbors: Edge count per system in the batch. + num_atoms: Total atoms across all batches. + rank_assignments: Rank for each atom, shape (num_atoms,). + If provided along with rank and world_size, send_info + is computed and returned as a 4th element. + rank: This rank's GP rank. + world_size: GP world size. Returns: - Filtered edge_index, cell_offsets, and neighbors tensors. + Filtered (edge_index, cell_offsets, neighbors). + If rank_assignments is provided, also returns send_info dict + with keys: send_counts, send_indices_global. """ target_atoms = edge_index[1] node_mask = torch.zeros(num_atoms, dtype=torch.bool, device=target_atoms.device) node_mask[node_partition] = True local_edge_mask = node_mask[target_atoms] + # Compute send info BEFORE discarding non-local edges. + # An edge (src, tgt) where src is LOCAL and tgt is REMOTE means + # src must be sent to rank_assignments[tgt]. 
+ send_info = None + if rank_assignments is not None and rank is not None and world_size is not None: + src_is_local = node_mask[edge_index[0]] + tgt_is_remote = ~local_edge_mask + send_edge_mask = src_is_local & tgt_is_remote + + if send_edge_mask.any(): + send_src = edge_index[0, send_edge_mask] + send_dst_rank = rank_assignments[edge_index[1, send_edge_mask]] + + # Unique (dst_rank, src_atom) pairs, sorted by rank then atom. + # Key layout: dst_rank * num_atoms + src_atom ensures rank-major + # ordering, matching what the index exchange produces. + key = send_dst_rank.to(torch.long) * num_atoms + send_src.to(torch.long) + unique_keys = key.unique(sorted=True) + send_ranks = unique_keys // num_atoms + send_atoms = unique_keys % num_atoms + + send_counts = torch.zeros( + world_size, dtype=torch.long, device=edge_index.device + ) + send_counts.scatter_add_( + 0, + send_ranks, + torch.ones_like(send_ranks), + ) + send_info = { + "send_counts": send_counts, + "send_indices_global": send_atoms, + } + else: + send_info = { + "send_counts": torch.zeros( + world_size, dtype=torch.long, device=edge_index.device + ), + "send_indices_global": torch.empty( + 0, dtype=torch.long, device=edge_index.device + ), + } + # Create system index for each edge to track which system each edge belongs to num_systems = neighbors.shape[0] edge_system_idx = torch.repeat_interleave( @@ -55,6 +117,8 @@ def filter_edges_by_node_partition( if neighbors.shape[0] == 1: # If there's only one system, we can skip the scatter_add step and just return the count of remaining edges new_neighbors = local_edge_mask.sum(dtype=neighbors.dtype).unsqueeze(0) + if send_info is not None: + return edge_index, cell_offsets, new_neighbors, send_info return edge_index, cell_offsets, new_neighbors filtered_edge_system_idx = edge_system_idx[local_edge_mask] @@ -69,6 +133,8 @@ def filter_edges_by_node_partition( torch.ones_like(filtered_edge_system_idx, dtype=neighbors.dtype), ) + if send_info is not None: + return 
edge_index, cell_offsets, new_neighbors, send_info return edge_index, cell_offsets, new_neighbors @@ -123,8 +189,12 @@ def generate_graph( radius_pbc_version: int, pbc: torch.Tensor, node_partition: torch.Tensor | None = None, + rank_assignments: torch.Tensor | None = None, + rank: int | None = None, + world_size: int | None = None, ) -> dict: - """Generate a graph representation from atomic structure data. + """ + Generate a graph representation from atomic structure data. Args: data (dict): A dictionary containing a batch of molecular structures. @@ -138,6 +208,9 @@ def generate_graph( radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA) pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc node_partition (torch.Tensor | None): The partitioning of the nodes (atoms) for distributed inference. If provided, returned graph will be filtered to keep only edges where the target atom (edge_index[1,:]) belongs to the current rank's partition. + rank_assignments: Rank for each atom (for A2A send_info). + rank: This rank's GP rank (for A2A send_info). + world_size: GP world size (for A2A send_info). Returns: dict: A dictionary containing the generated graph with the following keys: @@ -147,13 +220,19 @@ def generate_graph( - 'cell_offsets' (torch.Tensor): Offsets of the cell vectors for each edge. - 'offset_distances' (torch.Tensor): Distances between the atoms connected by the edges, including the cell offsets. - 'neighbors' (torch.Tensor): Number of neighbors for each atom. + - 'send_info' (dict, optional): Send metadata for A2A GP when rank_assignments is provided. """ if radius_pbc_version == 1: radius_graph_pbc_fn = radius_graph_pbc elif radius_pbc_version == 2: radius_graph_pbc_fn = radius_graph_pbc_v2 if node_partition is not None: - data["node_partition"] = node_partition + # Try item assignment first (regular data dicts), then fall back + # to attribute assignment for SimpleNamespace (halo filtering).
+ try: + data["node_partition"] = node_partition + except TypeError: + data.node_partition = node_partition elif radius_pbc_version == 3: radius_graph_pbc_fn = radius_graph_pbc_nvidia else: @@ -167,15 +246,30 @@ def generate_graph( pbc=pbc, ) - # for v2 it is still faster right now to not do this post filtering, need to investigate further + # V2 does its own internal edge filtering when node_partition is set, + # which is faster than post-filtering. However, this means send_info + # cannot be computed here for v2 (the full edge_index is needed). + # Instead, build_gp_context falls back to _sparse_index_exchange + # (~4ms NCCL collective) when send_info is None. Bypassing v2's + # internal filter to compute send_info was benchmarked and is ~12ms + # SLOWER because v2 generates edges for ALL atoms instead of local + # partition. + send_info = None if node_partition is not None and radius_pbc_version != 2: - edge_index, cell_offsets, neighbors = filter_edges_by_node_partition( + filter_result = filter_edges_by_node_partition( node_partition, edge_index, cell_offsets, neighbors, num_atoms=data.pos.shape[0], + rank_assignments=rank_assignments, + rank=rank, + world_size=world_size, ) + if rank_assignments is not None: + edge_index, cell_offsets, neighbors, send_info = filter_result + else: + edge_index, cell_offsets, neighbors = filter_result out = get_pbc_distances( data.pos, @@ -192,7 +286,7 @@ def generate_graph( cell_offset_distances = out["offsets"] distance_vec = out["distance_vec"] - return { + result = { "edge_index": edge_index, "edge_distance": edge_dist, "edge_distance_vec": distance_vec, @@ -200,3 +294,6 @@ def generate_graph( "offset_distances": cell_offset_distances, "neighbors": neighbors, } + if send_info is not None: + result["send_info"] = send_info + return result diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index c6640e4c7f..1810ba85ae 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ 
b/src/fairchem/core/models/uma/escn_md.py @@ -9,6 +9,7 @@ import logging import os +import types from dataclasses import dataclass from typing import TYPE_CHECKING, Literal @@ -20,6 +21,15 @@ from torch.profiler import record_function from fairchem.core.common import gp_utils +from fairchem.core.common.parallelism.graph_parallel_a2a import ( + GPContext, + build_gp_context, +) +from fairchem.core.common.parallelism.graph_partition import ( + PartitionStrategy, + partition_atoms_index_split, + partition_atoms_spatial, +) from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad from fairchem.core.graph.compute import generate_graph @@ -99,12 +109,10 @@ class GradRegressConfig: hessian_vmap: bool = True -def add_n_empty_edges( - graph_dict: dict, edges_to_add: int, cutoff: float, node_offset: int = 0 -): +def add_n_empty_edges(graph_dict: dict, edges_to_add: int, cutoff: float): graph_dict["edge_index"] = torch.cat( ( - graph_dict["edge_index"].new_ones(2, edges_to_add) * node_offset, + graph_dict["edge_index"].new_zeros(2, edges_to_add), graph_dict["edge_index"], ), dim=1, @@ -315,6 +323,8 @@ def __init__( spin_balanced_channels: list[int] | None = None, edge_chunk_size: int = 1, execution_mode: str = "general", + use_all_to_all_gp: bool = False, + gp_partition_strategy: str = "index_split", ) -> None: super().__init__() self.max_num_elements = max_num_elements @@ -366,6 +376,26 @@ def __init__( ESCNMD_DEFAULT_EDGE_ACTIVATION_CHECKPOINT_CHUNK_SIZE ) self.edge_chunk_size = edge_chunk_size + self.use_all_to_all_gp = use_all_to_all_gp + self.gp_partition_strategy = PartitionStrategy(gp_partition_strategy) + + # Allgather+spatial is not supported because allgather concatenates + # per-rank tensors in rank order, which only matches global atom order + # for contiguous (index_split) partitions. Spatial partitions are + # non-contiguous, so edge_index lookups into the allgathered tensor + # would reference wrong atoms. 
Spatial partitioning is designed for + # A2A where only boundary atoms are exchanged. + if ( + not self.use_all_to_all_gp + and self.gp_partition_strategy == PartitionStrategy.SPATIAL + ): + raise ValueError( + "Spatial partitioning is not supported with all-gather GP. " + "Allgather concatenates per-rank tensors in rank order, which " + "only matches global atom order for contiguous (index_split) " + "partitions. Use use_all_to_all_gp=True with spatial, or " + "gp_partition_strategy='index_split' with allgather." + ) self.backend = get_execution_backend(execution_mode) @@ -588,18 +618,215 @@ def csd_embedding(self, charge, spin, dataset): ) return torch.nn.SiLU()(self.mix_csd(torch.cat((chg_emb, spin_emb), dim=1))) + @torch.compiler.disable + def _compute_a2a_partition( + self, + pos: torch.Tensor, + total_atoms: int, + device: torch.device, + world_size: int, + rank: int, + strategy: PartitionStrategy, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute A2A rank assignments and node partition. + + Separated from _generate_graph so that only the A2A-specific + partitioning is excluded from torch.compile. The BL (all-gather) + path stays fully compilable. + """ + with record_function("a2a_partition"): + if strategy == PartitionStrategy.SPATIAL: + rank_assignments = partition_atoms_spatial(pos, world_size) + else: + rank_assignments = partition_atoms_index_split( + total_atoms, world_size, device + ) + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + + return rank_assignments, node_partition + + @torch.compiler.disable + def _compute_aabb_halo( + self, + pos: torch.Tensor, + node_partition: torch.Tensor, + pbc: torch.Tensor, + cell: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute AABB halo mask for graph generation filtering. 
+ + Builds an axis-aligned bounding box (AABB) around the local + partition's atoms, expanded by the interaction cutoff, then + identifies ALL atoms (including PBC images) that fall within + this box. Graph generation then operates on ~N_halo atoms + instead of N_total, significantly reducing cost at scale. + + No NCCL communication — purely local computation. + + Returns: + Tuple of (halo_mask, shift_vecs) where: + - halo_mask: Boolean mask over all atoms in the AABB halo. + - shift_vecs: PBC shift vectors (n_shifts, 3). + """ + device = pos.device + n_total = len(pos) + + # --- Compute our AABB (expanded by cutoff) --- + local_pos = pos[node_partition] + lo = local_pos.min(dim=0)[0] - self.cutoff + hi = local_pos.max(dim=0)[0] + self.cutoff + + cell_sq = cell.view(3, 3) if cell.dim() == 3 else cell + pbc_flat = pbc.view(3) if pbc.dim() == 2 else pbc + + # Build shift vectors for periodic images (up to 27). + shift_components = [] + for d in range(3): + if pbc_flat[d]: + shift_components.append(torch.tensor([-1, 0, 1], device=device)) + else: + shift_components.append(torch.tensor([0], device=device)) + grid = torch.cartesian_prod(*shift_components) + shift_vecs = grid.float() @ cell_sq + + # --- Compute our halo mask (atoms we need) --- + halo_mask = torch.zeros(n_total, dtype=torch.bool, device=device) + for shift in shift_vecs: + shifted = pos + shift + in_box = ((shifted >= lo) & (shifted <= hi)).all(dim=-1) + halo_mask |= in_box + + return halo_mask, shift_vecs + + @torch.compiler.disable + def _compute_halo_graph( + self, + data_dict: dict, + node_partition: torch.Tensor, + rank_assignments: torch.Tensor, + pbc: torch.Tensor, + halo_mask: torch.Tensor, + ) -> dict | None: + """ + Try to generate graph using AABB halo filtering. + + Uses the pre-computed halo_mask from _compute_aabb_halo to + filter graph gen input from N_total to ~N_halo atoms. 
+ Returns the graph_dict if the halo achieves significant + reduction, or None to fall back to full graph generation. + + The edge_index in the returned graph_dict is remapped to + global coordinates for use by the backbone forward pass. + Additionally, halo-local metadata is stored so that + ``build_gp_context`` can work with O(N_halo)-sized tensors + instead of O(N_total). + + Args: + data_dict: Full data dictionary with pos, cell, etc. + node_partition: Local atom indices (global coords). + rank_assignments: Rank assignment per atom (global). + pbc: Periodic boundary conditions. + halo_mask: Pre-computed boolean halo mask. + + Returns: + graph_dict with edge_index in global coordinates and + halo metadata for build_gp_context, + or None if halo didn't help. + """ + with record_function("a2a_halo_filter"): + pos = data_dict["pos"] + cell = data_dict["cell"] + n_total = len(pos) + + n_halo = halo_mask.sum().item() + + # Only use halo filtering if it reduces atoms enough. + if n_halo >= n_total * 0.95: + return None + + halo_indices = halo_mask.nonzero(as_tuple=True)[0] + + # Map global indices to halo-local indices + global_to_halo = torch.full( + (n_total,), -1, dtype=torch.long, device=pos.device + ) + global_to_halo[halo_indices] = torch.arange(n_halo, device=pos.device) + + # Create subset data for graph generation. 
+ data_subset = types.SimpleNamespace() + data_subset.pos = pos[halo_indices] + data_subset.cell = cell + data_subset.natoms = torch.tensor([n_halo], device=pos.device) + data_subset.batch = torch.zeros(n_halo, dtype=torch.long, device=pos.device) + data_subset.pbc = pbc + + # Remap partition and rank_assignments to halo-local + node_partition_local = global_to_halo[node_partition] + rank_assignments_local = rank_assignments[halo_indices] + + graph_dict = generate_graph( + data_subset, + cutoff=self.cutoff, + max_neighbors=self.max_neighbors, + enforce_max_neighbors_strictly=(self.enforce_max_neighbors_strictly), + radius_pbc_version=self.radius_pbc_version, + pbc=pbc, + node_partition=node_partition_local, + rank_assignments=rank_assignments_local, + rank=gp_utils.get_gp_rank(), + world_size=gp_utils.get_gp_world_size(), + ) + + # Remap edge_index from halo-local to global for use + # by the backbone forward pass (source/target embedding + # lookups index into atomic_numbers_full). + graph_dict["edge_index"] = halo_indices[graph_dict["edge_index"]] + + # Remap send_info indices from halo-local to global so + # build_gp_context can use them with global-coord + # rank_assignments and edge_index. + if "send_info" in graph_dict: + si = graph_dict["send_info"] + if si["send_indices_global"].numel() > 0: + si["send_indices_global"] = halo_indices[si["send_indices_global"]] + + return graph_dict + def _generate_graph(self, data_dict): - data_dict["gp_node_offset"] = 0 node_partition = None + rank_assignments = None if gp_utils.initialized(): # create the partitions atomic_numbers_full = data_dict["atomic_numbers_full"] - node_partition = torch.tensor_split( - torch.arange( - len(atomic_numbers_full), device=atomic_numbers_full.device - ), - gp_utils.get_gp_world_size(), - )[gp_utils.get_gp_rank()] + + if self.use_all_to_all_gp: + # All-to-all: compute rank_assignments FIRST, then derive + # node_partition from them. 
This ensures the + # graph-generation partition and the GPContext partition + # are identical, avoiding index mismatches that cause + # OOB crashes. + natoms = len(atomic_numbers_full) + rank_assignments, node_partition = self._compute_a2a_partition( + pos=data_dict["pos"], + total_atoms=natoms, + device=atomic_numbers_full.device, + world_size=gp_utils.get_gp_world_size(), + rank=gp_utils.get_gp_rank(), + strategy=self.gp_partition_strategy, + ) + else: + # All-gather: only supports contiguous (index_split) + # partitioning. __init__ blocks spatial+allgather. + node_partition = torch.tensor_split( + torch.arange( + len(atomic_numbers_full), + device=atomic_numbers_full.device, + ), + gp_utils.get_gp_world_size(), + )[gp_utils.get_gp_rank()] + assert ( node_partition.numel() > 0 ), "Looks like there is no atoms in this graph paralell partition. Cannot proceed" @@ -616,17 +843,48 @@ def _generate_graph(self, data_dict): assert ( pbc.all() or (~pbc).all() ), "We can only accept pbc that is all true or all false" - # for v2 graph gen we used to pass node_partition as part of the data_dict directly to radius_pbc to allow it generate partial graphs - # to make it more general to accomodate v3, we scrapped and instead have generate_graph handle the partitioning after the graph has been generated - graph_dict = generate_graph( - data_dict, - cutoff=self.cutoff, - max_neighbors=self.max_neighbors, - enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, - radius_pbc_version=self.radius_pbc_version, - pbc=pbc, - node_partition=node_partition, - ) + + # AABB halo optimization for A2A: + # Compute AABB bounding box around local partition, + # expanded by cutoff, to filter graph gen input from + # N_total to ~N_halo atoms. Pure local computation, + # no NCCL needed. 
+ graph_dict = None + if self.use_all_to_all_gp and rank_assignments is not None: + halo_mask, shift_vecs = self._compute_aabb_halo( + data_dict["pos"], + node_partition, + pbc, + data_dict["cell"], + ) + graph_dict = self._compute_halo_graph( + data_dict, + node_partition, + rank_assignments, + pbc, + halo_mask, + ) + + if graph_dict is None: + # Full graph gen (no halo filter, or halo didn't help) + graph_dict = generate_graph( + data_dict, + cutoff=self.cutoff, + max_neighbors=self.max_neighbors, + enforce_max_neighbors_strictly=( + self.enforce_max_neighbors_strictly + ), + radius_pbc_version=self.radius_pbc_version, + pbc=pbc, + node_partition=node_partition, + rank_assignments=( + rank_assignments if self.use_all_to_all_gp else None + ), + rank=(gp_utils.get_gp_rank() if self.use_all_to_all_gp else None), + world_size=( + gp_utils.get_gp_world_size() if self.use_all_to_all_gp else None + ), + ) else: # this assume edge_index is provided assert ( @@ -670,10 +928,57 @@ def _generate_graph(self, data_dict): node_partition ] data_dict["batch"] = data_dict["batch_full"][node_partition] - data_dict["gp_node_offset"] = node_partition.min().item() + + # Build GPContext for all-to-all communication + if self.use_all_to_all_gp: + with record_function("a2a_build_gp_context"): + gp_ctx = build_gp_context( + edge_index=graph_dict["edge_index"], + rank_assignments=rank_assignments, + rank=gp_utils.get_gp_rank(), + world_size=gp_utils.get_gp_world_size(), + send_info=graph_dict.get("send_info"), + node_partition=node_partition, + ) + data_dict["gp_ctx"] = gp_ctx + # Store rank_assignments so output heads can reorder + # gathered forces/stress from partition-concatenated order + # back to global index order. Only needed for A2A where + # partitions are non-consecutive (spatial). + data_dict["gp_rank_assignments"] = rank_assignments + # A2A scatter_target comes from gp_ctx.edge_index_local[1] + # which is set dynamically in InteractionBlock.forward(). 
+ data_dict["scatter_target"] = None + else: + # Allgather: pre-compute local target indices for scatter + # operations. Maps global edge targets to 0-based local + # partition indices. Works for both contiguous (index_split) + # and non-contiguous (spatial) partitions. + total_atoms = len(data_dict["atomic_numbers_full"]) + device = graph_dict["edge_index"].device + global_to_local = torch.zeros( + total_atoms, dtype=torch.long, device=device + ) + global_to_local[node_partition] = torch.arange( + len(node_partition), device=device + ) + data_dict["scatter_target"] = global_to_local[ + graph_dict["edge_index"][1] + ] if graph_dict["edge_index"].shape[1] == 0: - add_n_empty_edges(graph_dict, 1, self.cutoff, data_dict["gp_node_offset"]) + add_n_empty_edges(graph_dict, 1, self.cutoff) + # Also update scatter_target for the dummy edge (target=0) + if ( + "scatter_target" in data_dict + and data_dict["scatter_target"] is not None + ): + data_dict["scatter_target"] = torch.cat( + [ + data_dict["scatter_target"].new_zeros(1), + data_dict["scatter_target"], + ] + ) return graph_dict @@ -749,6 +1054,13 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: ) self.log_MOLE_stats() + # Retrieve precomputed all-to-all context (needed for edge embedding + # and message passing layers) + gp_ctx: GPContext | None = data_dict.get("gp_ctx", None) + send_indices: torch.Tensor | None = None + if gp_ctx is not None: + send_indices = gp_ctx.send_indices + # edge degree embedding with record_function("edge embedding"): dist_scaled = graph_dict["edge_distance"] / self.cutoff @@ -773,9 +1085,16 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: x_message = self.edge_degree_embedding( x_message, x_edge, - graph_dict["edge_index"], + # scatter_target contains pre-computed local target indices + # for scattering edge messages to nodes. For A2A, use + # gp_ctx.edge_index_local[1] (already local). 
For allgather, + # use pre-computed global→local mapped targets. + gp_ctx.edge_index_local[1] + if gp_ctx is not None + else data_dict.get( + "scatter_target", default=graph_dict["edge_index"][1] + ), wigner_inv_envelope, - data_dict["gp_node_offset"], ) ############################################################### @@ -800,7 +1119,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: 0 ], sys_node_embedding=sys_node_embedding, - node_offset=data_dict["gp_node_offset"], + scatter_target=data_dict.get("scatter_target", default=None), + gp_ctx=gp_ctx, + send_indices=send_indices, ) # balance any channels requested x_message = self.balance_channels( @@ -1190,6 +1511,18 @@ def forward(self, data_dict: AtomicData, emb: dict[str, torch.Tensor]): forces = gp_utils.gather_from_model_parallel_region( forces, data_dict["atomic_numbers_full"].shape[0] ) + # A2A spatial partitions are non-consecutive, so the + # gathered forces are in partition-concatenated order + # (NOT global index order). Reorder to match positions. 
+ ra = data_dict.get("gp_rank_assignments", None) + if ra is not None: + ws = gp_utils.get_gp_world_size() + perm = torch.cat( + [(ra == r).nonzero(as_tuple=True)[0] for r in range(ws)] + ) + forces_ordered = torch.empty_like(forces) + forces_ordered[perm] = forces + forces = forces_ordered return {"forces": forces} diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 614f5ca8c3..70b72a8c24 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -16,6 +16,11 @@ from typing_extensions import Literal from fairchem.core.common import gp_utils +from fairchem.core.common.parallelism.graph_parallel_a2a import ( + GPContext, + all_to_all_collect, + all_to_all_collect_compiled, +) from fairchem.core.models.uma.nn.activation import ( GateActivation, SeparableS2Activation_M, @@ -77,7 +82,8 @@ def __init__( ) extra_m0_output_channels = self.lmax * self.hidden_channels elif self.act_type == "s2": - # NOTE: this is the only place where the SO3 grid of the edges (lmax/mmax) is used + # NOTE: this is the only place where the SO3 grid of the + # edges (lmax/mmax) is used self.act = SeparableS2Activation_M( lmax=self.lmax, mmax=self.mmax, @@ -117,37 +123,90 @@ def forward( wigner, wigner_inv_envelope, total_atoms_across_gp_ranks, - node_offset: int = 0, + scatter_target: torch.Tensor | None = None, + gp_ctx: GPContext | None = None, + send_indices: torch.Tensor | None = None, ): - # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. - if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad( - x, total_atoms_across_gp_ranks - ) + """ + Forward pass with support for both all-gather and all-to-all GP. + + When gp_ctx is provided, uses all-to-all to collect only the + needed remote embeddings. Otherwise falls back to all-gather. 
+ + Args: + scatter_target: Pre-computed local target indices [E] for + scattering edge messages to nodes. For allgather, this + is ``edge_index[1]`` mapped to local partition space. + For A2A, derived from ``gp_ctx.edge_index_local[1]``. + If None, defaults to ``edge_index[1]`` (no GP). + """ + if gp_ctx is not None and gp_utils.initialized(): + # All-to-all path: collect only needed remote embeddings. + # When x requires grad (autograd forces/stress), we use the + # AllToAllCollect autograd.Function so gradients flow through + # the communication. This creates a graph break (same as BL's + # GatherFromModelParallelRegionSumGradPadded), but is necessary + # because funcoll autograd crashes with torch.compile (SymInt + # split sizes). When no autograd is needed, use the + # compile-friendly funcoll (no graph break). + needs_grad = torch.is_grad_enabled() and x.requires_grad + if not self.training and not needs_grad: + # Eval path without autograd: compile-friendly funcoll. + with record_function("a2a_collect_compiled"): + x_received = all_to_all_collect_compiled( + x, + gp_ctx, + send_indices, + ) + x_full = torch.cat([x, x_received], dim=0) + edge_index_local = gp_ctx.edge_index_local + else: + # Training or eval+autograd: AllToAllCollect autograd.Function. + # Supports backward; creates graph break (same as BL). 
+ with record_function("a2a_collect"): + x_received = all_to_all_collect(x, gp_ctx, send_indices) + x_full = torch.cat([x, x_received], dim=0) + edge_index_local = gp_ctx.edge_index_local + local_scatter_target = edge_index_local[1] + elif gp_utils.initialized(): + # Legacy all-gather path + with record_function("allgather_collect"): + x_full = gp_utils.gather_from_model_parallel_region_sum_grad( + x, total_atoms_across_gp_ranks + ) + edge_index_local = edge_index + local_scatter_target = scatter_target else: x_full = x + edge_index_local = edge_index + local_scatter_target = ( + scatter_target if scatter_target is not None else edge_index[1] + ) if self.activation_checkpoint_chunk_size is None: return self.forward_chunk( x_full, x.shape[0], x_edge, - edge_index, + edge_index_local, wigner, wigner_inv_envelope, - node_offset, + local_scatter_target, ) - edge_index_partitions = edge_index.split( + edge_index_partitions = edge_index_local.split( self.activation_checkpoint_chunk_size, dim=1 ) + scatter_target_partitions = local_scatter_target.split( + self.activation_checkpoint_chunk_size + ) wigner_partitions = wigner.split(self.activation_checkpoint_chunk_size, dim=0) wigner_inv_partitions = wigner_inv_envelope.split( self.activation_checkpoint_chunk_size, dim=0 ) x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0) new_embeddings = [] - # when chunking, we need to keep track of the start index of the chunk and give this information - # to the mole layers + # when chunking, we need to keep track of the start index + # of the chunk and give this information to the mole layers ac_mole_start_idx = 0 for idx in range(len(edge_index_partitions)): @@ -160,7 +219,7 @@ def forward( edge_index_partitions[idx], wigner_partitions[idx], wigner_inv_partitions[idx], - node_offset, + scatter_target_partitions[idx], ac_mole_start_idx, use_reentrant=False, ) @@ -179,11 +238,11 @@ def forward_chunk( edge_index, wigner, wigner_inv_envelope, - node_offset: int 
= 0, + scatter_target: torch.Tensor | None = None, ac_mole_start_idx: int = 0, ): - # here we need to update the ac_start_idx of the mole layers under here for this chunking to - # work properly with MoLE together + # here we need to update the ac_start_idx of the mole layers + # under here for this chunking to work properly with MoLE set_mole_ac_start_index(self, ac_mole_start_idx) with record_function("SO2Conv"): @@ -196,9 +255,8 @@ def forward_chunk( new_embedding = self.backend.permute_wigner_inv_edge_to_node( x_message, wigner_inv_envelope, - edge_index, + scatter_target if scatter_target is not None else edge_index[1], x_original_shape, - node_offset, ) # reset ac start index @@ -354,7 +412,9 @@ def forward( wigner_inv_envelope, total_atoms_across_gp_ranks, sys_node_embedding=None, - node_offset: int = 0, + scatter_target: torch.Tensor | None = None, + gp_ctx: GPContext | None = None, + send_indices: torch.Tensor | None = None, ): x_res = x x = self.norm_1(x) @@ -370,7 +430,9 @@ def forward( wigner, wigner_inv_envelope, total_atoms_across_gp_ranks=total_atoms_across_gp_ranks, - node_offset=node_offset, + scatter_target=scatter_target, + gp_ctx=gp_ctx, + send_indices=send_indices, ) x = x + x_res diff --git a/src/fairchem/core/models/uma/nn/embedding.py b/src/fairchem/core/models/uma/nn/embedding.py index bb4e4c4748..3f25be41fc 100644 --- a/src/fairchem/core/models/uma/nn/embedding.py +++ b/src/fairchem/core/models/uma/nn/embedding.py @@ -78,9 +78,8 @@ def forward_chunk( self, x, x_edge, - edge_index, + scatter_target, wigner_inv_envelope, - node_offset=0, ): radial = self.rad_func(x_edge) @@ -88,46 +87,42 @@ def forward_chunk( x, radial, wigner_inv_envelope, - edge_index, + scatter_target, self.m_0_num_coefficients, self.sphere_channels, self.rescale_factor, - node_offset, ) def forward( self, x, x_edge, - edge_index, + scatter_target, wigner_inv_envelope, - node_offset=0, ): if self.activation_checkpoint_chunk_size is None: return self.forward_chunk( x, 
x_edge, - edge_index, + scatter_target, wigner_inv_envelope, - node_offset, ) - edge_index_partitions = edge_index.split( - self.activation_checkpoint_chunk_size, dim=1 + scatter_target_partitions = scatter_target.split( + self.activation_checkpoint_chunk_size ) wigner_inv_partitions = wigner_inv_envelope.split( self.activation_checkpoint_chunk_size, dim=0 ) x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0) - for idx in range(len(edge_index_partitions)): + for idx in range(len(scatter_target_partitions)): x = torch.utils.checkpoint.checkpoint( self.forward_chunk, x, x_edge_partitions[idx], - edge_index_partitions[idx], + scatter_target_partitions[idx], wigner_inv_partitions[idx], - node_offset, use_reentrant=False, ) diff --git a/src/fairchem/core/models/uma/nn/execution_backends.py b/src/fairchem/core/models/uma/nn/execution_backends.py index 3fac942a4b..a2b5387924 100644 --- a/src/fairchem/core/models/uma/nn/execution_backends.py +++ b/src/fairchem/core/models/uma/nn/execution_backends.py @@ -181,9 +181,8 @@ def node_to_edge_wigner_permute( def permute_wigner_inv_edge_to_node( x_message: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, num_nodes: int, - node_offset: int = 0, ) -> torch.Tensor: """ Rotate M->L and scatter edge messages to nodes. @@ -193,9 +192,9 @@ def permute_wigner_inv_edge_to_node( Args: x_message: Edge message features [E, M, C] wigner_inv: Inverse Wigner matrices [E, L, M] - edge_index: Edge indices [2, E] + scatter_target: Pre-computed local target indices [E] + for scattering into node output tensor. 
num_nodes: Total number of nodes (output size) - node_offset: Offset for node indices (for chunking) Returns: Node embeddings [N, L, C] accumulated from edge messages @@ -208,7 +207,7 @@ def permute_wigner_inv_edge_to_node( dtype=x_rotated.dtype, device=x_rotated.device, ) - new_embedding.index_add_(0, edge_index[1] - node_offset, x_rotated) + new_embedding.index_add_(0, scatter_target, x_rotated) return new_embedding @staticmethod @@ -216,11 +215,10 @@ def edge_degree_scatter( x: torch.Tensor, radial_output: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, m_0_num_coefficients: int, sphere_channels: int, rescale_factor: float, - node_offset: int = 0, ) -> torch.Tensor: """ Edge degree embedding: rotate radial and scatter to nodes. @@ -232,12 +230,12 @@ def edge_degree_scatter( radial_output: RadialMLP output [E, m0 * C] wigner_inv: Wigner inverse with envelope pre-fused [E, L, m0] or [E, L, L] - edge_index: Edge indices [2, E] + scatter_target: Pre-computed local target indices [E] + for scattering into node output tensor. 
m_0_num_coefficients: Number of m=0 coefficients (3 for lmax=2) sphere_channels: Number of channels C rescale_factor: Aggregation rescale factor - node_offset: Node offset for graph parallelism Returns: Updated node features [N, L, C] @@ -256,7 +254,7 @@ def edge_degree_scatter( # Scatter to destination nodes with rescaling return x.index_add( 0, - edge_index[1] - node_offset, + scatter_target, x_edge_embedding / rescale_factor, ) @@ -377,9 +375,8 @@ def node_to_edge_wigner_permute( def permute_wigner_inv_edge_to_node( x_message: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, num_nodes: int, - node_offset: int = 0, ) -> torch.Tensor: from fairchem.core.models.uma.triton import ( UMASFastGPUPermuteWignerInvEdgeToNode, @@ -393,7 +390,7 @@ def permute_wigner_inv_edge_to_node( dtype=x_rotated.dtype, device=x_rotated.device, ) - new_embedding.index_add_(0, edge_index[1] - node_offset, x_rotated) + new_embedding.index_add_(0, scatter_target, x_rotated) return new_embedding @staticmethod @@ -401,11 +398,10 @@ def edge_degree_scatter( x: torch.Tensor, radial_output: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, m_0_num_coefficients: int, sphere_channels: int, rescale_factor: float, - node_offset: int = 0, ) -> torch.Tensor: radial = radial_output.reshape(-1, m_0_num_coefficients, sphere_channels) @@ -417,7 +413,7 @@ def edge_degree_scatter( return x.index_add( 0, - edge_index[1] - node_offset, + scatter_target, x_edge_embedding / rescale_factor, ) diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index a2bb71b1ec..322fb5bb9f 100644 --- a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -38,6 +38,89 @@ # ========================================================================= +class TestPartitionAtomsIndexSplit: + """ + Tests for 
partition_atoms_index_split. + """ + + def test_single_rank(self): + result = partition_atoms_index_split(5, 1, torch.device("cpu")) + assert result.shape == (5,) + assert (result == 0).all() + + def test_even_split(self): + result = partition_atoms_index_split(6, 3, torch.device("cpu")) + assert result.shape == (6,) + # Atoms 0,1 -> rank 0; atoms 2,3 -> rank 1; atoms 4,5 -> rank 2 + assert result[0] == 0 + assert result[1] == 0 + assert result[2] == 1 + assert result[3] == 1 + assert result[4] == 2 + assert result[5] == 2 + + def test_uneven_split(self): + result = partition_atoms_index_split(5, 2, torch.device("cpu")) + assert result.shape == (5,) + # 5 atoms, 2 ranks: [0,1,2] -> rank 0, [3,4] -> rank 1 + assert (result[:3] == 0).all() + assert (result[3:] == 1).all() + + def test_more_ranks_than_atoms(self): + result = partition_atoms_index_split(2, 5, torch.device("cpu")) + assert result.shape == (2,) + # Each atom gets its own rank + for i in range(2): + assert result[i].item() >= 0 + assert result[i].item() < 5 + + +class TestPartitionAtomsSpatial: + """ + Tests for partition_atoms_spatial. + """ + + def test_single_rank(self): + pos = torch.randn(10, 3) + result = partition_atoms_spatial(pos, 1) + assert result.shape == (10,) + assert (result == 0).all() + + def test_balanced_output(self): + pos = torch.randn(100, 3) + result = partition_atoms_spatial(pos, 4) + assert result.shape == (100,) + # Check all ranks are assigned + for r in range(4): + count = (result == r).sum() + assert count > 0, f"Rank {r} has no atoms" + # Check balance: each should have ~25 atoms (±1) + for r in range(4): + count = (result == r).sum().item() + assert 24 <= count <= 26, f"Rank {r} has {count} atoms, expected ~25" + + def test_spatially_separated_clusters(self): + """ + Atoms in distinct spatial clusters should be assigned to different ranks. 
+ """ + pos = torch.cat( + [ + torch.randn(20, 3) + torch.tensor([0.0, 0.0, 0.0]), + torch.randn(20, 3) + torch.tensor([100.0, 0.0, 0.0]), + ] + ) + result = partition_atoms_spatial(pos, 2) + # The two clusters should be mostly on different ranks + rank_cluster_0 = result[:20].mode()[0].item() + rank_cluster_1 = result[20:].mode()[0].item() + assert rank_cluster_0 != rank_cluster_1 + + def test_more_ranks_than_atoms(self): + pos = torch.randn(3, 3) + result = partition_atoms_spatial(pos, 5) + assert result.shape == (3,) + + class TestBuildGPContext: """ Tests for build_gp_context (non-distributed, simulates single rank). @@ -149,11 +232,6 @@ def test_edge_split_no_remote_edges(self): assert ctx.remote_edge_idx.numel() == 0 -# ========================================================================= -# Distributed tests: A2A vs All-Gather correctness -# ========================================================================= - - def _a2a_simple_layer(x, edge_index, rank_assignments, natoms): """ A simple message passing layer using all-to-all communication. @@ -549,10 +627,148 @@ def test_a2a_spatial_partition(): # ========================================================================= -# Distributed tests: compiled collect correctness +# Distributed tests: send_info optimization correctness # ========================================================================= +def send_info_optimization_test(atomic_numbers, edge_index): + """ + Verify that pre-computed send_info from filter_edges_by_node_partition + produces the same GPContext as the _sparse_index_exchange path. 
+ """ + from fairchem.core.graph.compute import filter_edges_by_node_partition + + rank = gp_utils.get_gp_rank() + world_size = gp_utils.get_gp_world_size() + natoms = atomic_numbers.shape[0] + + rank_assignments = partition_atoms_index_split( + natoms, world_size, torch.device("cpu") + ) + + # Get this rank's partition + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + + # Filter edges with send_info computation + neighbors = torch.tensor([edge_index.shape[1]]) + edge_index_filtered, _, _, send_info = filter_edges_by_node_partition( + node_partition=node_partition, + edge_index=edge_index, + cell_offsets=torch.zeros(edge_index.shape[1], 3), + neighbors=neighbors, + num_atoms=natoms, + rank_assignments=rank_assignments, + rank=rank, + world_size=world_size, + ) + + # Build GPContext WITH send_info (skip _sparse_index_exchange) + ctx_with_send_info = build_gp_context( + edge_index_filtered, + rank_assignments, + rank, + world_size, + send_info=send_info, + ) + + # Build GPContext WITHOUT send_info (use _sparse_index_exchange) + ctx_without = build_gp_context( + edge_index_filtered, + rank_assignments, + rank, + world_size, + ) + + # Compare the two contexts + send_counts_match = torch.equal( + ctx_with_send_info.send_counts, ctx_without.send_counts + ) + recv_counts_match = torch.equal( + ctx_with_send_info.recv_counts, ctx_without.recv_counts + ) + + # send_indices should select the same atoms (may differ in order + # within a rank's chunk, but counts must match) + si_a = ctx_with_send_info.send_indices + si_b = ctx_without.send_indices + if si_a is not None and si_b is not None: + send_indices_match = si_a.shape == si_b.shape and torch.equal( + si_a.sort()[0], si_b.sort()[0] + ) + elif si_a is None and si_b is None: + send_indices_match = True + else: + send_indices_match = False + + # Functional test: both contexts should produce identical + # all-to-all results + x = atomic_numbers[node_partition].unsqueeze(1).float() + + x_recv_opt = 
all_to_all_collect( + x, ctx_with_send_info, ctx_with_send_info.send_indices + ) + x_recv_ref = all_to_all_collect(x, ctx_without, ctx_without.send_indices) + + functional_match = torch.allclose(x_recv_opt, x_recv_ref, atol=1e-6) + + return { + "rank": rank, + "send_counts_match": send_counts_match, + "recv_counts_match": recv_counts_match, + "send_indices_match": send_indices_match, + "functional_match": functional_match, + "send_counts_opt": ctx_with_send_info.send_counts, + "send_counts_ref": ctx_without.send_counts, + } + + +def test_send_info_matches_fused_exchange(): + """ + Verify that pre-computed send_info from filter_edges_by_node_partition + produces identical GPContext and all-to-all results as the + _sparse_index_exchange path. + """ + num_atoms = 8 + # Dense graph: all atoms connected + src = [] + dst = [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + atomic_numbers = torch.arange( + 2, 2 + num_atoms, dtype=torch.float, requires_grad=False + ) + + config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + send_info_optimization_test, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["send_counts_match"], ( + f"Rank {result['rank']}: send_counts mismatch. 
" + f"opt={result['send_counts_opt']}, " + f"ref={result['send_counts_ref']}" + ) + assert result[ + "recv_counts_match" + ], f"Rank {result['rank']}: recv_counts mismatch" + assert result[ + "send_indices_match" + ], f"Rank {result['rank']}: send_indices mismatch" + assert result["functional_match"], ( + f"Rank {result['rank']}: functional mismatch — " + f"all-to-all produced different embeddings" + ) + + def compiled_collect_test( atomic_numbers: torch.Tensor, edge_index: torch.Tensor, diff --git a/tests/core/units/mlip_unit/test_predict.py b/tests/core/units/mlip_unit/test_predict.py index 2f6b403b75..481d432e45 100644 --- a/tests/core/units/mlip_unit/test_predict.py +++ b/tests/core/units/mlip_unit/test_predict.py @@ -197,8 +197,17 @@ def test_multiple_dataset_predict(internal_graph_gen_version): npt.assert_allclose(pred_forces[batch_batch == 2], pt.get_forces(), atol=ATOL) -def _test_parallel_predict_unit_impl(workers, device, checkpointing, graph_gen_version): - """Implementation of parallel predict unit test.""" +def _test_parallel_predict_unit_impl( + workers, device, checkpointing, graph_gen_version, gp_mode=None +): + """ + Implementation of parallel predict unit test. + + Args: + gp_mode: Optional dict with GP overrides, e.g. + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}. + If None, uses default GP settings (allgather + index_split). 
+ """ seed = 42 runs = 2 model_path = pretrained_checkpoint_path_from_name("uma-s-1p1") @@ -213,12 +222,17 @@ def _test_parallel_predict_unit_impl(workers, device, checkpointing, graph_gen_v atoms = get_fcc_crystal_by_num_atoms(num_atoms) atomic_data = AtomicData.from_ase(atoms, task_name=["omat"]) + overrides = None + if gp_mode is not None: + overrides = {"backbone": gp_mode} + seed_everywhere(seed) ppunit = ParallelMLIPPredictUnit( inference_model_path=model_path, device=device, inference_settings=ifsets, num_workers=workers, + overrides=overrides, ) for _ in range(runs): pp_results = ppunit.predict(atomic_data) @@ -248,33 +262,59 @@ def _test_parallel_predict_unit_impl(workers, device, checkpointing, graph_gen_v @pytest.mark.serial() @pytest.mark.parametrize( - "workers, checkpointing, graph_gen_version", + "workers, checkpointing, graph_gen_version, gp_mode", [ - (1, False, 2), - (2, False, 2), - (1, False, 3), - (1, True, 3), - (2, False, 3), + # Default GP mode (allgather + index_split) + (1, False, 2, None), + (2, False, 2, None), + (1, False, 3, None), + (1, True, 3, None), + (2, False, 3, None), + # A2A + spatial (requires workers >= 2 for actual GP) + (2, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + # A2A + index_split + ( + 2, + False, + 2, + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, + ), ], ) -def test_parallel_predict_unit_cpu(workers, checkpointing, graph_gen_version): - _test_parallel_predict_unit_impl(workers, "cpu", checkpointing, graph_gen_version) +def test_parallel_predict_unit_cpu(workers, checkpointing, graph_gen_version, gp_mode): + _test_parallel_predict_unit_impl( + workers, "cpu", checkpointing, graph_gen_version, gp_mode + ) @pytest.mark.gpu() @pytest.mark.parametrize( - "workers, checkpointing, graph_gen_version", + "workers, checkpointing, graph_gen_version, gp_mode", [ - (1, False, 2), - (1, True, 2), - (1, True, 3), - (1, False, 3), - # (2, False), - # (2, True), + # Default 
GP mode (allgather + index_split) + (1, False, 2, None), + (1, True, 2, None), + (1, True, 3, None), + (1, False, 3, None), + # GP modes with 1 worker — verifies code paths don't crash + # (single rank = no actual GP comms, but exercises config flow) + (1, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + ( + 1, + False, + 2, + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, + ), + # For local 8-GPU runs: uncomment to test multi-worker GPU GP + # (2, False, 2, None), + # (2, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + # (2, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), ], ) -def test_parallel_predict_unit_gpu(workers, checkpointing, graph_gen_version): - _test_parallel_predict_unit_impl(workers, "cuda", checkpointing, graph_gen_version) +def test_parallel_predict_unit_gpu(workers, checkpointing, graph_gen_version, gp_mode): + _test_parallel_predict_unit_impl( + workers, "cuda", checkpointing, graph_gen_version, gp_mode + ) def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): From f8c83883c6c712716d3fa5aaf97fbbabd3bd1e7d Mon Sep 17 00:00:00 2001 From: rgao user Date: Sat, 2 May 2026 07:15:26 +0000 Subject: [PATCH 03/26] Add A2A GP mode to MD consistency test Parametrize test_merge_mole_md_consistency with A2A+spatial GP mode to verify that all-to-all graph parallel with spatial partitioning works correctly across multiple MD timesteps (NVT and NPT). Skips A2A modes when workers < 2 since GP requires multi-rank. 2 new test cases pass (NVT+A2A, NPT+A2A), 2 correctly skipped. 
--- tests/core/units/mlip_unit/test_predict.py | 28 ++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/tests/core/units/mlip_unit/test_predict.py b/tests/core/units/mlip_unit/test_predict.py index 481d432e45..92f0c64695 100644 --- a/tests/core/units/mlip_unit/test_predict.py +++ b/tests/core/units/mlip_unit/test_predict.py @@ -1044,7 +1044,15 @@ def test_batch_server_predict_unit_multiple_systems( @pytest.mark.parametrize("workers", [0, 2]) @pytest.mark.parametrize("ensemble", ["nvt", "npt"]) @pytest.mark.parametrize("device", ["cpu"]) -def test_merge_mole_md_consistency(workers, ensemble, device): +@pytest.mark.parametrize( + "gp_mode", + [ + None, # default allgather + index_split + # A2A + spatial — tests multi-step MD with spatial repartitioning + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, + ], +) +def test_merge_mole_md_consistency(workers, ensemble, device, gp_mode): """Test merge_mole vs no-merge consistency over MD trajectory. Runs 3 trials: @@ -1055,9 +1063,16 @@ def test_merge_mole_md_consistency(workers, ensemble, device): Compares the relative drift of A-C against baseline A-B to ensure merge_mole doesn't introduce additional numerical drift beyond the inherent noise between identical runs. + + When gp_mode is not None, passes backbone overrides to enable + A2A graph parallel with the specified partition strategy. 
""" import torch + # A2A GP modes require workers >= 2 to actually exercise multi-rank + if gp_mode is not None and workers < 2: + pytest.skip("A2A GP mode requires workers >= 2") + torch.use_deterministic_algorithms(True) from ase import units @@ -1133,6 +1148,9 @@ def run_md_trial(atoms, calc, seed, steps): "stresses": np.array(stresses), } + # Build overrides for GP mode + overrides = {"backbone": gp_mode} if gp_mode is not None else None + # Trial A: no merge settings_no_merge = InferenceSettings(merge_mole=False, **base_settings) predict_unit_A = pretrained_mlip.get_predict_unit( @@ -1140,6 +1158,7 @@ def run_md_trial(atoms, calc, seed, steps): device=device, inference_settings=settings_no_merge, workers=workers, + overrides=overrides, ) calc_A = FAIRChemCalculator(predict_unit_A, task_name="omat") results_A = run_md_trial(atoms_template, calc_A, seed=42, steps=md_steps) @@ -1151,6 +1170,7 @@ def run_md_trial(atoms, calc, seed, steps): device=device, inference_settings=settings_no_merge, workers=workers, + overrides=overrides, ) calc_B = FAIRChemCalculator(predict_unit_B, task_name="omat") results_B = run_md_trial(atoms_template, calc_B, seed=42, steps=md_steps) @@ -1159,7 +1179,11 @@ def run_md_trial(atoms, calc, seed, steps): # Trial C: merge settings_merge = InferenceSettings(merge_mole=True, **base_settings) predict_unit_C = pretrained_mlip.get_predict_unit( - "uma-s-1p1", device=device, inference_settings=settings_merge, workers=workers + "uma-s-1p1", + device=device, + inference_settings=settings_merge, + workers=workers, + overrides=overrides, ) calc_C = FAIRChemCalculator(predict_unit_C, task_name="omat") results_C = run_md_trial(atoms_template, calc_C, seed=42, steps=md_steps) From c45cb751a1854d64b1c9abab82d8da5ec6335683 Mon Sep 17 00:00:00 2001 From: rgao user Date: Tue, 5 May 2026 17:14:50 +0000 Subject: [PATCH 04/26] Harden A2A GP: multi-system guard, defensive assertions, expand test coverage - Add is_single_system guard in _generate_graph so 
AABB halo bails out for multi-system batches (prevents cell.view(3,3) crash) - Add defensive assertion in _compute_halo_graph verifying all local partition atoms appear in the halo mask - Add A2A+index_split mode to MD consistency test parametrization - Add A2A+spatial mode to batch predict test (exercises halo bail-out path for multi-system batches) - Document single-system limitation in _compute_aabb_halo docstring --- src/fairchem/core/models/uma/escn_md.py | 21 ++++++++- tests/core/units/mlip_unit/test_predict.py | 51 ++++++++++++++++------ 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 1810ba85ae..176db9bba2 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -665,6 +665,10 @@ def _compute_aabb_halo( No NCCL communication — purely local computation. + Note: Only supports single-system inputs (not batched). + For multi-system batches, the caller should skip halo + filtering and fall back to full graph generation. + Returns: Tuple of (halo_mask, shift_vecs) where: - halo_mask: Boolean mask over all atoms in the AABB halo. @@ -764,6 +768,10 @@ def _compute_halo_graph( # Remap partition and rank_assignments to halo-local node_partition_local = global_to_halo[node_partition] + assert (node_partition_local >= 0).all(), ( + "Local partition atoms not found in halo — " + "AABB expansion may be too small" + ) rank_assignments_local = rank_assignments[halo_indices] graph_dict = generate_graph( @@ -849,9 +857,18 @@ def _generate_graph(self, data_dict): # expanded by cutoff, to filter graph gen input from # N_total to ~N_halo atoms. Pure local computation, # no NCCL needed. + # Note: AABB halo only supports single-system inputs. + # Multi-system batches skip halo and use full graph gen. 
graph_dict = None - if self.use_all_to_all_gp and rank_assignments is not None: - halo_mask, shift_vecs = self._compute_aabb_halo( + is_single_system = data_dict["cell"].dim() == 2 or ( + data_dict["cell"].dim() == 3 and data_dict["cell"].shape[0] == 1 + ) + if ( + self.use_all_to_all_gp + and rank_assignments is not None + and is_single_system + ): + halo_mask, _ = self._compute_aabb_halo( data_dict["pos"], node_partition, pbc, diff --git a/tests/core/units/mlip_unit/test_predict.py b/tests/core/units/mlip_unit/test_predict.py index 92f0c64695..bf96de7921 100644 --- a/tests/core/units/mlip_unit/test_predict.py +++ b/tests/core/units/mlip_unit/test_predict.py @@ -317,8 +317,17 @@ def test_parallel_predict_unit_gpu(workers, checkpointing, graph_gen_version, gp ) -def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): - """Implementation of parallel predict unit batch test.""" +def _test_parallel_predict_unit_batch_impl( + workers, device, checkpointing, gp_mode=None +): + """ + Implementation of parallel predict unit batch test. + + Args: + gp_mode: Optional dict with GP overrides, e.g. + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}. + If None, uses default GP settings (allgather + index_split). 
+ """ seed = 42 runs = 1 model_path = pretrained_checkpoint_path_from_name("uma-s-1p1") @@ -353,11 +362,17 @@ def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): ) atomic_data = atomicdata_list_to_batch([h2o_data, o_data]) seed_everywhere(seed) + + overrides = None + if gp_mode is not None: + overrides = {"backbone": gp_mode} + ppunit = ParallelMLIPPredictUnit( inference_model_path=model_path, device=device, inference_settings=ifsets, num_workers=workers, + overrides=overrides, ) for _ in range(runs): pp_results = ppunit.predict(atomic_data) @@ -385,28 +400,34 @@ def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): @pytest.mark.serial() @pytest.mark.parametrize( - "workers, checkpointing", + "workers, checkpointing, gp_mode", [ - (1, False), - (2, True), + (1, False, None), + (2, True, None), + # A2A + spatial with batch (exercises halo bail-out for multi-system) + ( + 2, + False, + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, + ), ], ) -def test_parallel_predict_unit_batch(workers, checkpointing): - _test_parallel_predict_unit_batch_impl(workers, "cpu", checkpointing) +def test_parallel_predict_unit_batch(workers, checkpointing, gp_mode): + _test_parallel_predict_unit_batch_impl(workers, "cpu", checkpointing, gp_mode) @pytest.mark.gpu() @pytest.mark.parametrize( - "workers, checkpointing", + "workers, checkpointing, gp_mode", [ - (1, True), - (1, False), - # (2, True), - # (2, False), + (1, True, None), + (1, False, None), + # (2, True, None), + # (2, False, None), ], ) -def test_parallel_predict_unit_batch_gpu(workers, checkpointing): - _test_parallel_predict_unit_batch_impl(workers, "cuda", checkpointing) +def test_parallel_predict_unit_batch_gpu(workers, checkpointing, gp_mode): + _test_parallel_predict_unit_batch_impl(workers, "cuda", checkpointing, gp_mode) @pytest.mark.gpu() @@ -1050,6 +1071,8 @@ def test_batch_server_predict_unit_multiple_systems( None, # default allgather + index_split # 
A2A + spatial — tests multi-step MD with spatial repartitioning {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, + # A2A + index_split — tests A2A with contiguous partitioning + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, ], ) def test_merge_mole_md_consistency(workers, ensemble, device, gp_mode): From 360f38c9cfad07edc24bc0e6266dc0c1b9da2aee Mon Sep 17 00:00:00 2001 From: rgao user Date: Wed, 3 Jun 2026 22:58:51 +0000 Subject: [PATCH 05/26] Refactor build_gp_context and add allgather index exchange variant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract _resolve_send_metadata and _validate_gp_mappings from the monolithic build_gp_context (270→120 lines). Validation now uses fast-path .any().item() checks with full diagnostics only on error. Add _allgather_index_exchange: packs recv_counts + needed_atoms into one buffer for a single all-gather call (vs 2x all-to-all). Wired into eSCN-MD via gp_index_exchange_method constructor parameter. Benchmarks at GP=64 (8 nodes, 256K atoms): 2xA2A +22.9%, 1xAG +20.3% vs baseline all-gather GP. Both variants beat baseline at scale. 
--- .../common/parallelism/graph_parallel_a2a.py | 430 +++++++++++------- src/fairchem/core/models/uma/escn_md.py | 3 + 2 files changed, 265 insertions(+), 168 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index c65ffbcb75..b442154485 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -203,206 +203,172 @@ def _sparse_index_exchange( return send_counts, send_indices_global -@torch.compiler.disable -def build_gp_context( - edge_index: torch.Tensor, - rank_assignments: torch.Tensor, +def _allgather_index_exchange( + needed_atoms: torch.Tensor, + recv_counts: torch.Tensor, rank: int, world_size: int, - send_info: dict | None = None, - node_partition: torch.Tensor | None = None, -) -> GPContext: + device: torch.device, + natoms: int, +) -> tuple[torch.Tensor, torch.Tensor | None]: """ - Build the GP context from edge connectivity and atom assignments. - - Determines which non-local atoms this rank needs (edge sources from - other ranks), exchanges atom indices via all-to-all, and computes - all communication metadata. + Single all-gather variant of index exchange. - When send_info is provided (pre-computed during graph filtering in - filter_edges_by_node_partition), the NCCL index-exchange collective - is skipped entirely — send_counts and send_indices_global are taken - directly from send_info. - - Args: - edge_index: Edge index filtered to edges whose targets are in - this rank's partition, shape (2, num_edges). - Row 0 = source, row 1 = target. - rank_assignments: Rank assignment for each atom, - shape (total_atoms,). - rank: This rank's GP rank. - world_size: GP world size. - send_info: Pre-computed send/recv metadata from graph filtering. - If provided, must contain: - - send_counts: Tensor of shape (world_size,) with count of - atoms to send to each rank. 
- - send_indices_global: Tensor of global atom indices to send, - sorted by destination rank. - When provided, _sparse_index_exchange is skipped. - node_partition: Pre-computed atom indices in this rank's - partition. If provided, avoids recomputing from - rank_assignments. + Packs [recv_counts (P ints), needed_atoms (padded to natoms)] into + one buffer and performs a single all_gather call. Each rank then + extracts its send_counts and send_indices from the gathered data. - Returns: - GPContext with all metadata needed for all-to-all communication. + Communication volume: O(P * (P + natoms)). + Advantage: one collective instead of two, potentially better NCCL + utilization for small messages. """ - total_atoms = rank_assignments.shape[0] - device = rank_assignments.device + if not gp_utils.initialized(): + return ( + torch.zeros(world_size, dtype=torch.long, device=device), + None, + ) - # Atoms owned by this rank (reuse pre-computed if available) - if node_partition is None: - node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] - total_local_atoms = node_partition.shape[0] + gp_group = gp_utils.get_gp_group() - # Find which non-local atoms this rank needs as edge sources. - # Since edge_index is already filtered to edges whose targets are - # in this rank's partition, every edge has a local target. We only - # need to find edges where the SOURCE is remote (not in our partition). 
- local_mask = rank_assignments == rank # (total_atoms,) bool - src_is_remote = ~local_mask[edge_index[0]] + # Pack: [recv_counts (P), needed_atoms padded to natoms] + buf_size = world_size + natoms + send_buf = torch.zeros(buf_size, dtype=torch.long, device=device) + send_buf[:world_size] = recv_counts + n_needed = needed_atoms.numel() + if n_needed > 0: + send_buf[world_size : world_size + n_needed] = needed_atoms + + # Single all-gather + recv_bufs = [ + torch.empty(buf_size, dtype=torch.long, device=device) + for _ in range(world_size) + ] + dist.all_gather(recv_bufs, send_buf.contiguous(), group=gp_group) + + # Unpack count matrix: count_matrix[r][s] = how many atoms rank r + # needs from rank s. + count_matrix = torch.stack([buf[:world_size] for buf in recv_bufs]) + # send_counts[r] = count_matrix[r][rank] = atoms rank r needs from me + send_counts = count_matrix[:, rank].contiguous() + + # Extract send_indices: for each rank r, find atoms they need from me + send_indices_parts = [] + for r in range(world_size): + if r == rank: + continue + length = count_matrix[r, rank].item() + if length > 0: + offset = count_matrix[r, :rank].sum().item() + send_indices_parts.append( + recv_bufs[r][world_size + offset : world_size + offset + length] + ) - # Remote sources needed for local targets - # Use boolean mask + nonzero instead of .unique(sorted=True) on raw - # edge sources — O(N) scatter + scan vs O(E log E) sort. 
- needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) - needed_mask[edge_index[0, src_is_remote]] = True - needed_mask &= ~local_mask # exclude local atoms (safety) - needed_atoms = needed_mask.nonzero(as_tuple=True)[0] + if send_indices_parts: + send_indices_global = torch.cat(send_indices_parts) + else: + send_indices_global = torch.empty(0, dtype=torch.long, device=device) - total_needed_atoms = needed_atoms.shape[0] - needed_from_ranks = rank_assignments[needed_atoms] + return send_counts, send_indices_global - # Compute recv_counts: how many atoms we receive from each rank - if total_needed_atoms > 0: - recv_counts = torch.bincount(needed_from_ranks, minlength=world_size).to( - dtype=torch.long, device=device - ) - else: - recv_counts = torch.zeros(world_size, dtype=torch.long, device=device) - recv_counts[rank] = 0 # Never receive from self - - # CRITICAL: Sort needed_atoms by source rank to match recv_buf ordering. - # all_to_all fills recv_buf by source rank: [atoms from rank 0 | atoms - # from rank 1 | ...]. Within each rank, atoms are in the order we - # requested them (global index order, since argsort is stable). - # global_to_local must assign local indices in this SAME order, - # otherwise local index i maps to recv_buf[i] which has a DIFFERENT - # atom's embedding. - # With index_split, global index order == rank order (no-op sort). - # With spatial, global index order != rank order → sort is essential. - sort_order = needed_from_ranks.argsort(stable=True) - needed_atoms = needed_atoms[sort_order] - needed_from_ranks = needed_from_ranks[sort_order] - # Use pre-computed send_info when available to skip the - # _sparse_index_exchange NCCL collectives entirely. 
+def _resolve_send_metadata( + send_info: dict | None, + index_exchange_method: str, + needed_atoms: torch.Tensor, + recv_counts: torch.Tensor, + rank: int, + world_size: int, + device: torch.device, + total_atoms: int, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Resolve send_counts and send_indices_global. + + Uses pre-computed send_info when available; otherwise runs + the configured index exchange collective. + """ if send_info is not None: - # Pre-computed during graph filtering — use it. - send_counts = send_info["send_counts"] - send_indices_global = send_info["send_indices_global"] - else: - with record_function("a2a_sparse_index_exchange"): - send_counts, send_indices_global = _sparse_index_exchange( + return send_info["send_counts"], send_info["send_indices_global"] + + if index_exchange_method == "allgather": + with record_function("a2a_allgather_index_exchange"): + return _allgather_index_exchange( needed_atoms=needed_atoms, recv_counts=recv_counts, rank=rank, world_size=world_size, device=device, + natoms=total_atoms, ) - # Build global_to_local mapping: - # Local atoms: index 0..total_local_atoms-1 (in order of - # node_partition) - # Received atoms: index total_local_atoms..total_local_atoms+total_needed - # IMPORTANT: needed_atoms is sorted by source rank (not global index) - # to match the recv_buf ordering from all_to_all. This ensures that - # local index (total_local + i) maps to recv_buf[i], which contains - # the embedding of needed_atoms[i]. 
- global_to_local = torch.full((total_atoms,), -1, dtype=torch.long, device=device) - # Map local atoms - global_to_local[node_partition] = torch.arange( - total_local_atoms, dtype=torch.long, device=device - ) - # Map needed remote atoms (in recv_buf order = source rank order) - global_to_local[needed_atoms] = torch.arange( - total_local_atoms, - total_local_atoms + total_needed_atoms, - dtype=torch.long, - device=device, - ) + with record_function("a2a_sparse_index_exchange"): + return _sparse_index_exchange( + needed_atoms=needed_atoms, + recv_counts=recv_counts, + rank=rank, + world_size=world_size, + device=device, + ) - # Convert send_indices from global to local - send_indices = None - has_send = send_indices_global is not None - if has_send and send_indices_global.numel() > 0: - send_indices = global_to_local[send_indices_global] - elif has_send: - send_indices = torch.empty(0, dtype=torch.long, device=device) - # Precompute edge_index_local - edge_index_local = global_to_local[edge_index] +def _validate_gp_mappings( + send_indices: torch.Tensor | None, + send_indices_global: torch.Tensor | None, + edge_index_local: torch.Tensor, + edge_index: torch.Tensor, + rank: int, + rank_assignments: torch.Tensor, + local_mask: torch.Tensor, + needed_atoms: torch.Tensor, + node_partition: torch.Tensor, + total_local_atoms: int, + total_atoms: int, + total_needed_atoms: int, + send_info: dict | None, +) -> None: + """ + Validate index mappings; raise with diagnostics on failure. - # Batch ALL GPU→CPU scalar extractions into a single transfer. - # This batches send_counts, recv_counts, AND validation scalars - # into ONE .cpu() call, eliminating extra GPU→CPU syncs. 
- bad_edge_count = (edge_index_local < 0).sum().unsqueeze(0).to(torch.long) - send_valid = ( - torch.ones(1, dtype=torch.long, device=device) - if send_indices is None or send_indices.numel() == 0 - else ( - ((send_indices >= 0) & (send_indices < total_local_atoms)) - .all() - .unsqueeze(0) - .to(torch.long) - ) - ) - all_cpu = torch.cat([send_counts, recv_counts, bad_edge_count, send_valid]).cpu() - send_splits = all_cpu[:world_size].tolist() - recv_splits = all_cpu[world_size : 2 * world_size].tolist() - total_recv = sum(recv_splits) - n_bad = int(all_cpu[2 * world_size].item()) - send_ok = int(all_cpu[2 * world_size + 1].item()) + Only called when fast-path checks detect an error, so + GPU→CPU transfers here are acceptable. + """ + device = edge_index.device - # Validate AFTER the batched CPU transfer (no extra GPU syncs). - if not send_ok: - # Diagnostic: identify which send_indices are out of range. + if send_indices is not None and send_indices.numel() > 0: bad_mask = (send_indices < 0) | (send_indices >= total_local_atoms) - n_bad_send = bad_mask.sum().item() - n_total_send = send_indices.numel() - bad_global = send_indices_global[bad_mask][:10].tolist() - bad_ra = rank_assignments[send_indices_global[bad_mask][:10]].tolist() - raise RuntimeError( - f"Rank {rank}: received requests for atoms not in our " - f"partition ({n_bad_send}/{n_total_send} OOB). " - f"bad_global={bad_global}, bad_ranks={bad_ra}. " - f"This usually means rank_assignments differs across " - f"ranks (e.g. non-deterministic crystal generation)." - ) + if bad_mask.any(): + n_bad_send = bad_mask.sum().item() + bad_global = send_indices_global[bad_mask][:10].tolist() + bad_ra = rank_assignments[send_indices_global[bad_mask][:10]].tolist() + raise RuntimeError( + f"Rank {rank}: received requests for atoms not " + f"in our partition " + f"({n_bad_send}/{send_indices.numel()} OOB). " + f"bad_global={bad_global}, " + f"bad_ranks={bad_ra}. 
" + f"This usually means rank_assignments differs " + f"across ranks (e.g. non-deterministic crystal " + f"generation)." + ) + + n_bad = (edge_index_local < 0).sum().item() if n_bad > 0: - # Only compute diagnostics in the error path (rare). bad_cols = (edge_index_local < 0).any(dim=0) bad_globals = edge_index[:, bad_cols].unique() bad_ranks = rank_assignments[bad_globals] - # Compute edge-based needed atoms for comparison. edge_src_remote = ~local_mask[edge_index[0]] edge_needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) edge_needed_mask[edge_index[0, edge_src_remote]] = True edge_needed_mask &= ~local_mask edge_needed_count = edge_needed_mask.sum().item() - # Check which edge-needed atoms are NOT in our needed_atoms. needed_set = torch.zeros(total_atoms, dtype=torch.bool, device=device) needed_set[needed_atoms] = True missing = edge_needed_mask & ~needed_set missing_count = missing.sum().item() - - # Check for local atoms in needed_atoms. local_in_needed = (local_mask & needed_set).sum().item() - - # Check for bad local atoms (local atoms mapped to -1). 
- local_is_bad = local_mask[bad_globals] - n_local_bad = local_is_bad.sum().item() + n_local_bad = local_mask[bad_globals].sum().item() logging.error( f"Rank {rank}: GP DIAGNOSTIC — " @@ -411,10 +377,12 @@ def build_gp_context( f"needed_atoms_count={total_needed_atoms}, " f"missing_from_needed={missing_count}, " f"local_in_needed={local_in_needed}, " - f"n_local_bad={n_local_bad}/{len(bad_globals)} bad globals, " + f"n_local_bad={n_local_bad}/{len(bad_globals)} " + f"bad globals, " f"total_atoms={total_atoms}, " f"total_local={total_local_atoms}, " - f"node_partition_range=[{node_partition.min().item()}, " + f"node_partition_range=" + f"[{node_partition.min().item()}, " f"{node_partition.max().item()}], " f"send_info_provided={send_info is not None}, " f"bad_globals[:20]={bad_globals.tolist()[:20]}, " @@ -424,17 +392,144 @@ def build_gp_context( missing_indices = missing.nonzero(as_tuple=True)[0][:10] logging.error( f"Rank {rank}: Missing atoms (edge-needed but " - f"not in needed_atoms): {missing_indices.tolist()}" + f"not in needed_atoms): " + f"{missing_indices.tolist()}" ) raise RuntimeError( - f"Rank {rank}: edge_index has {n_bad} endpoints not in " - f"global_to_local mapping. This indicates a mismatch " - f"between graph edges and partition assignments." + f"Rank {rank}: edge_index has {n_bad} endpoints " + f"not in global_to_local mapping. This indicates " + f"a mismatch between graph edges and partition " + f"assignments." + ) + + +@torch.compiler.disable +def build_gp_context( + edge_index: torch.Tensor, + rank_assignments: torch.Tensor, + rank: int, + world_size: int, + send_info: dict | None = None, + node_partition: torch.Tensor | None = None, + index_exchange_method: str = "a2a", +) -> GPContext: + """ + Build the GP context from edge connectivity and atom assignments. + + Args: + edge_index: Edge index filtered to this rank's partition, + shape (2, num_edges). Row 0 = source, row 1 = target. 
+ rank_assignments: Rank owner for each atom, + shape (total_atoms,). + rank: This rank's GP rank. + world_size: GP world size. + send_info: Pre-computed send metadata (skips index exchange). + node_partition: Pre-computed local atom indices. + index_exchange_method: "a2a" (2x all-to-all) or "allgather" + (1x all-gather). + + Returns: + GPContext with all communication metadata. + """ + total_atoms = rank_assignments.shape[0] + device = rank_assignments.device + + if node_partition is None: + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + total_local_atoms = node_partition.shape[0] + + # Find remote atoms needed as edge sources. + local_mask = rank_assignments == rank + src_is_remote = ~local_mask[edge_index[0]] + needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) + needed_mask[edge_index[0, src_is_remote]] = True + needed_mask &= ~local_mask + needed_atoms = needed_mask.nonzero(as_tuple=True)[0] + + total_needed_atoms = needed_atoms.shape[0] + needed_from_ranks = rank_assignments[needed_atoms] + + if total_needed_atoms > 0: + recv_counts = torch.bincount(needed_from_ranks, minlength=world_size).to( + dtype=torch.long, device=device + ) + else: + recv_counts = torch.zeros(world_size, dtype=torch.long, device=device) + recv_counts[rank] = 0 + + # Sort needed_atoms by source rank to match A2A recv_buf ordering. + sort_order = needed_from_ranks.argsort(stable=True) + needed_atoms = needed_atoms[sort_order] + needed_from_ranks = needed_from_ranks[sort_order] + + # Resolve send metadata (pre-computed or via collective). + send_counts, send_indices_global = _resolve_send_metadata( + send_info, + index_exchange_method, + needed_atoms, + recv_counts, + rank, + world_size, + device, + total_atoms, + ) + + # Build global-to-local index mapping. 
+ # Local atoms: [0, total_local_atoms) + # Remote atoms: [total_local_atoms, total_local_atoms + needed) + global_to_local = torch.full((total_atoms,), -1, dtype=torch.long, device=device) + global_to_local[node_partition] = torch.arange( + total_local_atoms, dtype=torch.long, device=device + ) + global_to_local[needed_atoms] = torch.arange( + total_local_atoms, + total_local_atoms + total_needed_atoms, + dtype=torch.long, + device=device, + ) + + # Convert send_indices from global to local. + send_indices = None + if send_indices_global is not None: + if send_indices_global.numel() > 0: + send_indices = global_to_local[send_indices_global] + else: + send_indices = torch.empty(0, dtype=torch.long, device=device) + + # Remap edge_index to local indices. + edge_index_local = global_to_local[edge_index] + + # Single GPU-to-CPU transfer for send/recv splits. + splits_cpu = torch.stack([send_counts, recv_counts]).cpu() + send_splits = splits_cpu[0].tolist() + recv_splits = splits_cpu[1].tolist() + total_recv = sum(recv_splits) + + # Validate (only touches GPU on error path). + has_bad_edges = (edge_index_local < 0).any().item() + has_bad_send = ( + send_indices is not None + and send_indices.numel() > 0 + and ((send_indices < 0) | (send_indices >= total_local_atoms)).any().item() + ) + if has_bad_edges or has_bad_send: + _validate_gp_mappings( + send_indices, + send_indices_global, + edge_index_local, + edge_index, + rank, + rank_assignments, + local_mask, + needed_atoms, + node_partition, + total_local_atoms, + total_atoms, + total_needed_atoms, + send_info, ) # Precompute local/remote edge indices for comm-compute overlap. - # An edge is "local-source" if its source atom is owned by this - # rank (index < total_local_atoms in the remapped edge_index_local). 
local_edge_mask = edge_index_local[0] < total_local_atoms local_edge_idx = local_edge_mask.nonzero(as_tuple=True)[0] remote_edge_idx = (~local_edge_mask).nonzero(as_tuple=True)[0] @@ -453,7 +548,6 @@ def build_gp_context( total_needed_atoms=total_needed_atoms, send_indices=send_indices, edge_index_local=edge_index_local, - # Precompute Python lists once (avoids .tolist() per layer per fwd) send_splits=send_splits, recv_splits=recv_splits, total_recv=total_recv, diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 176db9bba2..368805ba75 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -325,6 +325,7 @@ def __init__( execution_mode: str = "general", use_all_to_all_gp: bool = False, gp_partition_strategy: str = "index_split", + gp_index_exchange_method: str = "a2a", ) -> None: super().__init__() self.max_num_elements = max_num_elements @@ -378,6 +379,7 @@ def __init__( self.edge_chunk_size = edge_chunk_size self.use_all_to_all_gp = use_all_to_all_gp self.gp_partition_strategy = PartitionStrategy(gp_partition_strategy) + self.gp_index_exchange_method = gp_index_exchange_method # Allgather+spatial is not supported because allgather concatenates # per-rank tensors in rank order, which only matches global atom order @@ -956,6 +958,7 @@ def _generate_graph(self, data_dict): world_size=gp_utils.get_gp_world_size(), send_info=graph_dict.get("send_info"), node_partition=node_partition, + index_exchange_method=self.gp_index_exchange_method, ) data_dict["gp_ctx"] = gp_ctx # Store rank_assignments so output heads can reorder From 3a06f49242b06c56ed01d526c6963575edcfcb18 Mon Sep 17 00:00:00 2001 From: rgao user Date: Wed, 3 Jun 2026 23:09:13 +0000 Subject: [PATCH 06/26] Remove allgather index exchange variant, keep 2xA2A Benchmarks showed 2xA2A is ~1.5% faster than 1xallgather at GP=64 across 8 nodes (0.593 vs 0.584 ns/day) due to lower communication volume with variable splits. 
Remove the inferior allgather variant, the index_exchange_method parameter, and the dispatch logic. Add benchmark rationale to _sparse_index_exchange docstring. --- .../common/parallelism/graph_parallel_a2a.py | 100 +++--------------- src/fairchem/core/models/uma/escn_md.py | 3 - 2 files changed, 12 insertions(+), 91 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index b442154485..14b3b486c6 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -131,7 +131,17 @@ def _sparse_index_exchange( Step 2: Exchange actual atom indices with variable split sizes. Sends only the exact number of indices needed (no padding), - keeping communication volume minimal. + keeping communication volume minimal: O(sum of needed counts). + + We considered a single-allgather alternative that packs counts + + padded indices into one buffer and calls all_gather once, but it + sends O(P * natoms_per_rank) regardless of sparsity. Benchmarks + on H200 at GP=64 across 8 nodes (256K atoms, UMA-S) showed: + 2x A2A (this fn): 0.593 ns/day (+22.8% vs baseline AG-GP) + 1x all-gather: 0.584 ns/day (+20.9% vs baseline AG-GP) + The 2x A2A is ~1.5% faster because the variable-split second + A2A transfers less data when each rank only needs atoms from + nearby neighbors, not all P ranks. Args: needed_atoms: Global indices of atoms this rank needs, @@ -203,104 +213,23 @@ def _sparse_index_exchange( return send_counts, send_indices_global -def _allgather_index_exchange( - needed_atoms: torch.Tensor, - recv_counts: torch.Tensor, - rank: int, - world_size: int, - device: torch.device, - natoms: int, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Single all-gather variant of index exchange. - - Packs [recv_counts (P ints), needed_atoms (padded to natoms)] into - one buffer and performs a single all_gather call. 
Each rank then - extracts its send_counts and send_indices from the gathered data. - - Communication volume: O(P * (P + natoms)). - Advantage: one collective instead of two, potentially better NCCL - utilization for small messages. - """ - if not gp_utils.initialized(): - return ( - torch.zeros(world_size, dtype=torch.long, device=device), - None, - ) - - gp_group = gp_utils.get_gp_group() - - # Pack: [recv_counts (P), needed_atoms padded to natoms] - buf_size = world_size + natoms - send_buf = torch.zeros(buf_size, dtype=torch.long, device=device) - send_buf[:world_size] = recv_counts - n_needed = needed_atoms.numel() - if n_needed > 0: - send_buf[world_size : world_size + n_needed] = needed_atoms - - # Single all-gather - recv_bufs = [ - torch.empty(buf_size, dtype=torch.long, device=device) - for _ in range(world_size) - ] - dist.all_gather(recv_bufs, send_buf.contiguous(), group=gp_group) - - # Unpack count matrix: count_matrix[r][s] = how many atoms rank r - # needs from rank s. - count_matrix = torch.stack([buf[:world_size] for buf in recv_bufs]) - # send_counts[r] = count_matrix[r][rank] = atoms rank r needs from me - send_counts = count_matrix[:, rank].contiguous() - - # Extract send_indices: for each rank r, find atoms they need from me - send_indices_parts = [] - for r in range(world_size): - if r == rank: - continue - length = count_matrix[r, rank].item() - if length > 0: - offset = count_matrix[r, :rank].sum().item() - send_indices_parts.append( - recv_bufs[r][world_size + offset : world_size + offset + length] - ) - - if send_indices_parts: - send_indices_global = torch.cat(send_indices_parts) - else: - send_indices_global = torch.empty(0, dtype=torch.long, device=device) - - return send_counts, send_indices_global - - def _resolve_send_metadata( send_info: dict | None, - index_exchange_method: str, needed_atoms: torch.Tensor, recv_counts: torch.Tensor, rank: int, world_size: int, device: torch.device, - total_atoms: int, ) -> tuple[torch.Tensor, 
torch.Tensor | None]: """ Resolve send_counts and send_indices_global. Uses pre-computed send_info when available; otherwise runs - the configured index exchange collective. + the sparse index exchange collective. """ if send_info is not None: return send_info["send_counts"], send_info["send_indices_global"] - if index_exchange_method == "allgather": - with record_function("a2a_allgather_index_exchange"): - return _allgather_index_exchange( - needed_atoms=needed_atoms, - recv_counts=recv_counts, - rank=rank, - world_size=world_size, - device=device, - natoms=total_atoms, - ) - with record_function("a2a_sparse_index_exchange"): return _sparse_index_exchange( needed_atoms=needed_atoms, @@ -411,7 +340,6 @@ def build_gp_context( world_size: int, send_info: dict | None = None, node_partition: torch.Tensor | None = None, - index_exchange_method: str = "a2a", ) -> GPContext: """ Build the GP context from edge connectivity and atom assignments. @@ -425,8 +353,6 @@ def build_gp_context( world_size: GP world size. send_info: Pre-computed send metadata (skips index exchange). node_partition: Pre-computed local atom indices. - index_exchange_method: "a2a" (2x all-to-all) or "allgather" - (1x all-gather). Returns: GPContext with all communication metadata. @@ -465,13 +391,11 @@ def build_gp_context( # Resolve send metadata (pre-computed or via collective). send_counts, send_indices_global = _resolve_send_metadata( send_info, - index_exchange_method, needed_atoms, recv_counts, rank, world_size, device, - total_atoms, ) # Build global-to-local index mapping. 
diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 368805ba75..176db9bba2 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -325,7 +325,6 @@ def __init__( execution_mode: str = "general", use_all_to_all_gp: bool = False, gp_partition_strategy: str = "index_split", - gp_index_exchange_method: str = "a2a", ) -> None: super().__init__() self.max_num_elements = max_num_elements @@ -379,7 +378,6 @@ def __init__( self.edge_chunk_size = edge_chunk_size self.use_all_to_all_gp = use_all_to_all_gp self.gp_partition_strategy = PartitionStrategy(gp_partition_strategy) - self.gp_index_exchange_method = gp_index_exchange_method # Allgather+spatial is not supported because allgather concatenates # per-rank tensors in rank order, which only matches global atom order @@ -958,7 +956,6 @@ def _generate_graph(self, data_dict): world_size=gp_utils.get_gp_world_size(), send_info=graph_dict.get("send_info"), node_partition=node_partition, - index_exchange_method=self.gp_index_exchange_method, ) data_dict["gp_ctx"] = gp_ctx # Store rank_assignments so output heads can reorder From 45e218eea7b935d11e53c4992a426cea081787a9 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:01:44 +0000 Subject: [PATCH 07/26] Remove all_to_all_collect_compiled (dead code) Inference always requires autograd for force computation, so the no-grad compile-friendly path was never hit. Remove the function, its dispatch branch in escn_md_block, and associated tests. 
--- .../core/common/parallelism/__init__.py | 2 - .../common/parallelism/graph_parallel_a2a.py | 68 ------------- src/fairchem/core/models/uma/escn_md_block.py | 30 +----- .../parallelism/test_a2a_correctness.py | 23 +---- .../common/parallelism/test_graph_parallel.py | 99 ------------------- 5 files changed, 5 insertions(+), 217 deletions(-) diff --git a/src/fairchem/core/common/parallelism/__init__.py b/src/fairchem/core/common/parallelism/__init__.py index 99b668f46b..f261a2ace6 100644 --- a/src/fairchem/core/common/parallelism/__init__.py +++ b/src/fairchem/core/common/parallelism/__init__.py @@ -11,7 +11,6 @@ AllToAllCollect, GPContext, all_to_all_collect, - all_to_all_collect_compiled, build_gp_context, ) from fairchem.core.common.parallelism.graph_partition import ( @@ -25,7 +24,6 @@ "GPContext", "PartitionStrategy", "all_to_all_collect", - "all_to_all_collect_compiled", "build_gp_context", "partition_atoms_index_split", "partition_atoms_spatial", diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index 14b3b486c6..91da6b7ec2 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -723,71 +723,3 @@ def all_to_all_collect( gp_ctx.recv_splits, gp_ctx.total_recv, ) - - -def all_to_all_collect_compiled( - x_local: torch.Tensor, - gp_ctx: GPContext, - send_indices: torch.Tensor, -) -> torch.Tensor: - """ - Compile-friendly all-to-all collect using functional collectives. - - Uses ``torch.distributed._functional_collectives.all_to_all_single`` - which is a registered PyTorch op — torch.compile can trace through - it WITHOUT creating a graph break. This eliminates the per-layer - graph break from the ``@torch.compiler.disable`` on - ``AllToAllCollect.forward()``. - - This function does NOT support autograd — gradients will not flow - through the communication. 
When gradients are needed (e.g., autograd - forces via ``torch.autograd.grad(energy, pos)``), use - ``all_to_all_collect`` instead, which uses an autograd.Function - with proper backward support. - - NOTE: ``all_to_all_single_autograd`` (the funcoll autograd variant) - crashes with torch.compile because it doesn't handle symbolic split - sizes (SymInt). Both BL (all-gather) and A2A have a graph break - when autograd is needed, so this is not a regression vs baseline. - - For MD simulation with spatial partitioning, the split sizes are - effectively constant for hundreds of steps (atoms barely move per - timestep), so torch.compile guards on the split sizes will pass - without recompilation. - - NOTE: Requires NCCL backend. Functional collectives are not - supported on Gloo. CPU/Gloo tests should use ``all_to_all_collect``. - - Args: - x_local: Local atom embeddings, shape (local_atoms, *features). - gp_ctx: Graph parallel context. - send_indices: Local indices of atoms to send. - - Returns: - x_received: Remote atom embeddings, - shape (total_needed, *features). - """ - if send_indices is None: - raise ValueError( - "send_indices is None — build_gp_context should always " - "compute send_indices. Check GP setup." 
- ) - - # Gather atoms to send (compile-friendly indexing) - x_send = x_local[send_indices].contiguous() - - # Use functional collective — no graph break, no autograd - gp_group = gp_utils.get_gp_group() - - from torch.distributed._functional_collectives import ( - all_to_all_single as functional_a2a, - ) - - x_recv = functional_a2a( - x_send, - output_split_sizes=gp_ctx.recv_splits, - input_split_sizes=gp_ctx.send_splits, - group=gp_group, - ) - - return x_recv diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 70b72a8c24..a4bd835472 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -19,7 +19,6 @@ from fairchem.core.common.parallelism.graph_parallel_a2a import ( GPContext, all_to_all_collect, - all_to_all_collect_compiled, ) from fairchem.core.models.uma.nn.activation import ( GateActivation, @@ -142,31 +141,10 @@ def forward( """ if gp_ctx is not None and gp_utils.initialized(): # All-to-all path: collect only needed remote embeddings. - # When x requires grad (autograd forces/stress), we use the - # AllToAllCollect autograd.Function so gradients flow through - # the communication. This creates a graph break (same as BL's - # GatherFromModelParallelRegionSumGradPadded), but is necessary - # because funcoll autograd crashes with torch.compile (SymInt - # split sizes). When no autograd is needed, use the - # compile-friendly funcoll (no graph break). - needs_grad = torch.is_grad_enabled() and x.requires_grad - if not self.training and not needs_grad: - # Eval path without autograd: compile-friendly funcoll. - with record_function("a2a_collect_compiled"): - x_received = all_to_all_collect_compiled( - x, - gp_ctx, - send_indices, - ) - x_full = torch.cat([x, x_received], dim=0) - edge_index_local = gp_ctx.edge_index_local - else: - # Training or eval+autograd: AllToAllCollect autograd.Function. 
- # Supports backward; creates graph break (same as BL). - with record_function("a2a_collect"): - x_received = all_to_all_collect(x, gp_ctx, send_indices) - x_full = torch.cat([x, x_received], dim=0) - edge_index_local = gp_ctx.edge_index_local + with record_function("a2a_collect"): + x_received = all_to_all_collect(x, gp_ctx, send_indices) + x_full = torch.cat([x, x_received], dim=0) + edge_index_local = gp_ctx.edge_index_local local_scatter_target = edge_index_local[1] elif gp_utils.initialized(): # Legacy all-gather path diff --git a/tests/core/common/parallelism/test_a2a_correctness.py b/tests/core/common/parallelism/test_a2a_correctness.py index d2c458fdae..64c9a805bc 100644 --- a/tests/core/common/parallelism/test_a2a_correctness.py +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -33,7 +33,6 @@ from fairchem.core.common import distutils, gp_utils from fairchem.core.common.parallelism.graph_parallel_a2a import ( all_to_all_collect, - all_to_all_collect_compiled, build_gp_context, ) from fairchem.core.common.parallelism.graph_partition import ( @@ -109,19 +108,8 @@ def _correctness_test_inner( send_indices = gp_ctx.send_indices - # Test both collect functions + # Test collect function x_recv_autograd = all_to_all_collect(x_local, gp_ctx, send_indices) - x_recv_compiled = all_to_all_collect_compiled(x_local, gp_ctx, send_indices) - - # Verify shapes - assert x_recv_autograd.shape == x_recv_compiled.shape, ( - f"Rank {rank}: shape mismatch " - f"autograd={x_recv_autograd.shape} " - f"vs compiled={x_recv_compiled.shape}" - ) - - # Verify autograd == compiled - values_match = torch.allclose(x_recv_autograd, x_recv_compiled, atol=1e-6) # Verify received values are correct: # x_recv should contain embeddings of gp_ctx.needed_atoms @@ -166,7 +154,6 @@ def _correctness_test_inner( "local_atoms": gp_ctx.total_local_atoms, "needed_atoms": gp_ctx.total_needed_atoms, "num_edges": rank_edge_index.shape[1], - "values_match": values_match, "recv_correct": 
recv_correct, "edge_valid": edge_valid, "edge_in_bounds": edge_in_bounds, @@ -217,7 +204,6 @@ def test_a2a_correctness_gloo(strategy, num_atoms): for result in all_rank_results: r = result["rank"] - assert result["values_match"], f"Rank {r}: autograd vs compiled mismatch" assert result["recv_correct"], ( f"Rank {r}: received embeddings don't match " f"expected values" ) @@ -269,7 +255,6 @@ def test_a2a_consistency_across_graph_sizes(strategy): for result in all_rank_results: r = result["rank"] - assert result["values_match"], f"Rank {r}: autograd vs compiled mismatch" assert result["recv_correct"], ( f"Rank {r}: received embeddings don't match " f"expected values" ) @@ -310,17 +295,14 @@ def _multidim_test_inner(x_global, pos, edge_index, num_atoms, strategy): send_indices = gp_ctx.send_indices x_recv = all_to_all_collect(x_local, gp_ctx, send_indices) - x_recv_c = all_to_all_collect_compiled(x_local, gp_ctx, send_indices) # Verify expected = x_global[gp_ctx.needed_atoms] recv_correct = torch.allclose(x_recv, expected, atol=1e-6) - compiled_match = torch.allclose(x_recv, x_recv_c, atol=1e-6) return { "rank": rank, "recv_correct": recv_correct, - "compiled_match": compiled_match, "recv_shape": x_recv.shape, "expected_shape": expected.shape, } @@ -370,9 +352,6 @@ def test_a2a_multidim_embeddings(strategy): f"shape={result['recv_shape']} " f"vs {result['expected_shape']}" ) - assert result["compiled_match"], ( - f"Rank {r}: autograd vs compiled mismatch " f"for multidim" - ) # ========================================================================= diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index 322fb5bb9f..8b0e9ebbcd 100644 --- a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -17,7 +17,6 @@ ) from fairchem.core.common.parallelism.graph_parallel_a2a import ( all_to_all_collect, - all_to_all_collect_compiled, 
build_gp_context, ) from fairchem.core.common.parallelism.graph_partition import ( @@ -767,101 +766,3 @@ def test_send_info_matches_fused_exchange(): f"Rank {result['rank']}: functional mismatch — " f"all-to-all produced different embeddings" ) - - -def compiled_collect_test( - atomic_numbers: torch.Tensor, - edge_index: torch.Tensor, - strategy: str, -): - """ - Verify that all_to_all_collect_compiled produces same results - as the original all_to_all_collect (autograd version). - """ - rank = gp_utils.get_gp_rank() - world_size = gp_utils.get_gp_world_size() - - # Partition atoms (deterministic positions — same across all ranks) - if strategy == "spatial": - torch.manual_seed(42) - pos = torch.rand(atomic_numbers.shape[0], 3) - rank_assignments = partition_atoms_spatial(pos, world_size, pos.device) - else: - rank_assignments = partition_atoms_index_split( - atomic_numbers.shape[0], world_size, atomic_numbers.device - ) - - # Filter edges to this rank's partition - node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] - target_mask = (rank_assignments == rank)[edge_index[1]] - rank_edge_index = edge_index[:, target_mask] - - gp_ctx = build_gp_context( - rank_edge_index, - rank_assignments, - rank, - world_size, - node_partition=node_partition, - ) - - send_indices = gp_ctx.send_indices - - x = atomic_numbers[node_partition].unsqueeze(1).float() - - # Reference: original autograd version - x_ref = all_to_all_collect(x, gp_ctx, send_indices) - - # Test: compiled functional version - x_compiled = all_to_all_collect_compiled(x, gp_ctx, send_indices) - - match = torch.allclose(x_ref, x_compiled, atol=1e-6) - shape_match = x_ref.shape == x_compiled.shape - - return { - "rank": rank, - "match": match, - "shape_match": shape_match, - "ref_shape": x_ref.shape, - "compiled_shape": x_compiled.shape, - } - - -@pytest.mark.parametrize("strategy", ["index_split", "spatial"]) -def test_compiled_collect_matches_autograd(strategy): - """ - Verify that 
compile-friendly functional collective produces - identical results to the @torch.compiler.disable autograd version. - """ - num_atoms = 12 - src = [] - dst = [] - for i in range(num_atoms): - for j in range(num_atoms): - if i != j: - src.append(i) - dst.append(j) - edge_index = torch.tensor([src, dst], dtype=torch.long) - atomic_numbers = torch.arange( - 2, 2 + num_atoms, dtype=torch.float, requires_grad=False - ) - - config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - all_rank_results = spawn_multi_process( - config, - compiled_collect_test, - init_pg_and_rank_and_launch_test, - atomic_numbers, - edge_index, - strategy, - ) - - for result in all_rank_results: - assert result["shape_match"], ( - f"Rank {result['rank']}: shape mismatch. " - f"ref={result['ref_shape']}, " - f"compiled={result['compiled_shape']}" - ) - assert result["match"], ( - f"Rank {result['rank']}: functional mismatch — " - f"compiled collect produced different embeddings" - ) From 13754ca7b3531e56bf0bbf44aff437a8537c090b Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:20:09 +0000 Subject: [PATCH 08/26] Replace _validate_gp_mappings with simple assertions The complex 13-parameter diagnostic function was hard to follow. Replaced with two inline asserts for the only error conditions: negative edge_index_local entries and out-of-bounds send_indices. 
--- .../common/parallelism/graph_parallel_a2a.py | 126 ++---------------- 1 file changed, 12 insertions(+), 114 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index 91da6b7ec2..78b9aa79ca 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -7,7 +7,6 @@ from __future__ import annotations -import logging from dataclasses import dataclass import torch @@ -240,98 +239,6 @@ def _resolve_send_metadata( ) -def _validate_gp_mappings( - send_indices: torch.Tensor | None, - send_indices_global: torch.Tensor | None, - edge_index_local: torch.Tensor, - edge_index: torch.Tensor, - rank: int, - rank_assignments: torch.Tensor, - local_mask: torch.Tensor, - needed_atoms: torch.Tensor, - node_partition: torch.Tensor, - total_local_atoms: int, - total_atoms: int, - total_needed_atoms: int, - send_info: dict | None, -) -> None: - """ - Validate index mappings; raise with diagnostics on failure. - - Only called when fast-path checks detect an error, so - GPU→CPU transfers here are acceptable. - """ - device = edge_index.device - - if send_indices is not None and send_indices.numel() > 0: - bad_mask = (send_indices < 0) | (send_indices >= total_local_atoms) - if bad_mask.any(): - n_bad_send = bad_mask.sum().item() - bad_global = send_indices_global[bad_mask][:10].tolist() - bad_ra = rank_assignments[send_indices_global[bad_mask][:10]].tolist() - raise RuntimeError( - f"Rank {rank}: received requests for atoms not " - f"in our partition " - f"({n_bad_send}/{send_indices.numel()} OOB). " - f"bad_global={bad_global}, " - f"bad_ranks={bad_ra}. " - f"This usually means rank_assignments differs " - f"across ranks (e.g. non-deterministic crystal " - f"generation)." 
- ) - - n_bad = (edge_index_local < 0).sum().item() - if n_bad > 0: - bad_cols = (edge_index_local < 0).any(dim=0) - bad_globals = edge_index[:, bad_cols].unique() - bad_ranks = rank_assignments[bad_globals] - - edge_src_remote = ~local_mask[edge_index[0]] - edge_needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) - edge_needed_mask[edge_index[0, edge_src_remote]] = True - edge_needed_mask &= ~local_mask - edge_needed_count = edge_needed_mask.sum().item() - - needed_set = torch.zeros(total_atoms, dtype=torch.bool, device=device) - needed_set[needed_atoms] = True - missing = edge_needed_mask & ~needed_set - missing_count = missing.sum().item() - local_in_needed = (local_mask & needed_set).sum().item() - n_local_bad = local_mask[bad_globals].sum().item() - - logging.error( - f"Rank {rank}: GP DIAGNOSTIC — " - f"{n_bad} entries in edge_index_local are -1. " - f"edge_needed={edge_needed_count}, " - f"needed_atoms_count={total_needed_atoms}, " - f"missing_from_needed={missing_count}, " - f"local_in_needed={local_in_needed}, " - f"n_local_bad={n_local_bad}/{len(bad_globals)} " - f"bad globals, " - f"total_atoms={total_atoms}, " - f"total_local={total_local_atoms}, " - f"node_partition_range=" - f"[{node_partition.min().item()}, " - f"{node_partition.max().item()}], " - f"send_info_provided={send_info is not None}, " - f"bad_globals[:20]={bad_globals.tolist()[:20]}, " - f"bad_ranks[:20]={bad_ranks.tolist()[:20]}" - ) - if missing_count > 0: - missing_indices = missing.nonzero(as_tuple=True)[0][:10] - logging.error( - f"Rank {rank}: Missing atoms (edge-needed but " - f"not in needed_atoms): " - f"{missing_indices.tolist()}" - ) - raise RuntimeError( - f"Rank {rank}: edge_index has {n_bad} endpoints " - f"not in global_to_local mapping. This indicates " - f"a mismatch between graph edges and partition " - f"assignments." 
- ) - - @torch.compiler.disable def build_gp_context( edge_index: torch.Tensor, @@ -429,28 +336,19 @@ def build_gp_context( recv_splits = splits_cpu[1].tolist() total_recv = sum(recv_splits) - # Validate (only touches GPU on error path). - has_bad_edges = (edge_index_local < 0).any().item() - has_bad_send = ( - send_indices is not None - and send_indices.numel() > 0 - and ((send_indices < 0) | (send_indices >= total_local_atoms)).any().item() + # Validate mappings. + assert not (edge_index_local < 0).any().item(), ( + f"Rank {rank}: edge_index_local has negative entries — " + f"graph edges reference atoms not in any partition's " + f"global_to_local mapping." ) - if has_bad_edges or has_bad_send: - _validate_gp_mappings( - send_indices, - send_indices_global, - edge_index_local, - edge_index, - rank, - rank_assignments, - local_mask, - needed_atoms, - node_partition, - total_local_atoms, - total_atoms, - total_needed_atoms, - send_info, + if send_indices is not None and send_indices.numel() > 0: + assert ( + not ((send_indices < 0) | (send_indices >= total_local_atoms)).any().item() + ), ( + f"Rank {rank}: send_indices out of bounds " + f"[0, {total_local_atoms}) — remote rank requested " + f"atoms not in our partition." ) # Precompute local/remote edge indices for comm-compute overlap. 
From 37f739352a90711b61db8ae9116d30efe8f2ea18 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:26:14 +0000 Subject: [PATCH 09/26] Use torch._assert_async for GPU-only validation (no host sync) --- .../common/parallelism/graph_parallel_a2a.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index 78b9aa79ca..a44cd23de6 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -336,19 +336,17 @@ def build_gp_context( recv_splits = splits_cpu[1].tolist() total_recv = sum(recv_splits) - # Validate mappings. - assert not (edge_index_local < 0).any().item(), ( - f"Rank {rank}: edge_index_local has negative entries — " - f"graph edges reference atoms not in any partition's " - f"global_to_local mapping." + # Validate mappings (async — no device-to-host sync). + torch._assert_async( + ~(edge_index_local < 0).any(), + "edge_index_local has negative entries — graph edges " + "reference atoms not in global_to_local mapping.", ) if send_indices is not None and send_indices.numel() > 0: - assert ( - not ((send_indices < 0) | (send_indices >= total_local_atoms)).any().item() - ), ( - f"Rank {rank}: send_indices out of bounds " - f"[0, {total_local_atoms}) — remote rank requested " - f"atoms not in our partition." + torch._assert_async( + ~((send_indices < 0) | (send_indices >= total_local_atoms)).any(), + "send_indices out of bounds — remote rank requested " + "atoms not in our partition.", ) # Precompute local/remote edge indices for comm-compute overlap. 
From 1d6fefc316557c4402a6cc9811d9397ca09983b8 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:34:45 +0000 Subject: [PATCH 10/26] Add graph parallel verification doc from rgao_a2a_comms --- .../parallelism/graph_parallel_verfication.md | 312 ++++++++++++++++++ 1 file changed, 312 insertions(+) create mode 100644 src/fairchem/core/common/parallelism/graph_parallel_verfication.md diff --git a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md new file mode 100644 index 0000000000..6089d061e7 --- /dev/null +++ b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md @@ -0,0 +1,312 @@ +# A2A Graph Parallel — Test & Benchmark Runbook + +All commands are run-ready. Copy-paste from the fairchem repo root: + +```bash +cd /home/rgao/fairchem +source .venv/bin/activate +``` + +--- + +## 1. Correctness Tests + +### 1a. GPContext unit tests (CPU, <1s) + +```bash +pytest tests/core/common/parallelism/test_graph_parallel.py::TestBuildGPContext -v +``` + +5 tests: `test_basic_context_building`, `test_global_to_local_mapping`, +`test_no_cross_partition_edges`, `test_edge_split_indices`, `test_edge_split_no_remote_edges`. + +### 1b. Distributed A2A primitives (CPU/Gloo, ~10s) + +```bash +pytest tests/core/common/parallelism/test_graph_parallel.py -v -k "not TestBuildGPContext" +``` + +8 tests: forward equivalence (2 graph topologies), backward gradient, multi-rank (2 and 3), +spatial partition, compiled-vs-autograd (index_split + spatial). + +### 1c. A2A embedding correctness (CPU/Gloo, ~30s) + +```bash +pytest tests/core/common/parallelism/test_a2a_correctness.py -v +``` + +8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2, +1536-dim embeddings × 2. + +### 1d. 
Full-model GPU correctness (8 GPUs, ~5 min) + +```bash +torchrun --nproc_per_node=8 tests/core/common/parallelism/test_a2a_correctness.py --natoms 1000 +``` + +Loads UMA-S checkpoint, runs BL vs A2A inference, compares energy/forces/stress (tol=1e-4). + +### 1e. Predict pipeline + MD consistency (CPU, PR3 branch, ~2 min) + +```bash +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_cpu -v -k "spatial or index_split" +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_gpu -v -k "spatial or index_split" +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_batch -v -k "spatial" +pytest tests/core/units/mlip_unit/test_predict.py::test_merge_mole_md_consistency -v -k "spatial or index_split" +``` + +### 1f. All correctness at once + +```bash +pytest tests/core/common/parallelism/ -v +``` + +--- + +## 2. Regression Tests + +### 2a. Existing tests must pass unchanged + +```bash +pytest tests/core/common/test_gp_utils.py -v +pytest tests/core/models/uma/test_escn_md.py -v +pytest tests/core/models/uma/test_compile.py -v +pytest tests/core/components/test_uma_speed_benchmark.py -v +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_cpu -v -k "not spatial and not index_split" +``` + +### 2b. All regression at once + +```bash +pytest tests/core/common/test_gp_utils.py tests/core/models/uma/test_escn_md.py tests/core/models/uma/test_compile.py tests/core/components/test_uma_speed_benchmark.py -v +``` + +--- + +## 3. Performance Benchmarks + +All benchmarks use `InferenceBenchRunner` with FCC crystal systems (deterministic), +`timeiters=100`, `repeats=5`, turbo mode (compile=True, tf32=True). +Results written to `benchmark_results.json` in the run directory under `/checkpoint/ocp/rgao/speed/`. + +### 3a. 
8-GPU single-node (local, ~10 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=local_8gpu \ + runner.natoms_list=[8000,32000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=local_8gpu \ + runner.natoms_list=[8000,32000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A within ±5% of BL (NVLink hides comm cost). + +### 3b. 16-GPU 2-node (SLURM, ~15 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=2 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=16 \ + job.scheduler.slurm.qos=h200_dev \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[16000,64000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=2 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=16 \ + job.scheduler.slurm.qos=h200_dev \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[16000,64000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +2-3%. + +### 3c. 
32-GPU 4-node (SLURM, ~15 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=4 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=32 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[32000,128000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=4 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=32 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[32000,128000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +5-10%. + +### 3d. 64-GPU 8-node — CANONICAL BENCHMARK (SLURM, ~30 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=8 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=64 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[64000,256000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=8 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=64 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[64000,256000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +10-20%. + +### 3e. 
128-GPU 16-node (SLURM, ~30 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=16 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=128 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[128000,512000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=16 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=128 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[128000,512000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +30-40%. + +### 3f. 256-GPU 32-node (SLURM, ~45 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=32 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=256 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[256000,1024000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=32 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=256 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[256000,1024000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +60-80%. + +--- + +## 4. 
Reading Results + +```bash +# Find latest runs +ls -lt /checkpoint/ocp/rgao/speed/ | head -20 + +# Read a specific result (replace RUN_DIR with a run directory listed above) +python3 -c " +import json, sys +with open(sys.argv[1]) as f: + d = json.load(f) +ws = d['config']['world_size'] +for model, data in d['model_to_qps_data'].items(): + print(f'Model: {model}, world_size={ws}') + for natoms, ns_day in data: + atoms_per_gpu = natoms // ws + print(f' {natoms:>7} atoms ({atoms_per_gpu}/gpu): {ns_day:.3f} ns/day') +" /checkpoint/ocp/rgao/speed/RUN_DIR/benchmark_results.json +``` + +--- + +## 5. Reference Results (H200 turbo) + +### 4k atoms/rank (canonical): + +| GPUs | BL ns/day | A2A ns/day | Speedup | BL WS Eff | A2A WS Eff | +|------|-----------|------------|---------|-----------|------------| +| 8 | 0.651 | 0.670 | +2.9% | 100% | 100% | +| 16 | 0.610 | 0.626 | +2.6% | 93.7% | 93.6% | +| 32 | 0.568 | 0.623 | +9.7% | 87.3% | 93.1% | +| 64 | 0.477 | 0.570 | +19.5% | 73.2% | 85.1% | +| 128 | 0.359 | 0.498 | +38.7% | 55.1% | 74.3% | +| 256 | 0.246 | 0.433 | +76.0% | 37.8% | 64.6% | + +### 1k atoms/rank: + +| GPUs | BL ns/day | A2A ns/day | Speedup | +|------|-----------|------------|---------| +| 8 | 2.088 | 2.010 | -3.7% | +| 16 | 2.043 | 1.990 | -2.6% | +| 32 | 1.769 | 1.848 | +4.5% | +| 64 | 1.488 | 1.626 | +9.3% | From b129bfa0c16b1e2514af40624c1ecdf7198b4bd3 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:44:01 +0000 Subject: [PATCH 11/26] Remove send_info precomputation (dead code with radius_pbc v2) v2 radius graph does internal edge filtering so the full edge_index is never exposed, making send_info derivation in compute.py impossible. Revert compute.py to main, remove _resolve_send_metadata, and always use _sparse_index_exchange in build_gp_context.
--- .../common/parallelism/graph_parallel_a2a.py | 47 ++---- src/fairchem/core/graph/compute.py | 129 ++-------------- src/fairchem/core/models/uma/escn_md.py | 20 --- .../common/parallelism/test_graph_parallel.py | 143 ------------------ 4 files changed, 25 insertions(+), 314 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index a44cd23de6..373cb574b6 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -212,40 +212,12 @@ def _sparse_index_exchange( return send_counts, send_indices_global -def _resolve_send_metadata( - send_info: dict | None, - needed_atoms: torch.Tensor, - recv_counts: torch.Tensor, - rank: int, - world_size: int, - device: torch.device, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Resolve send_counts and send_indices_global. - - Uses pre-computed send_info when available; otherwise runs - the sparse index exchange collective. - """ - if send_info is not None: - return send_info["send_counts"], send_info["send_indices_global"] - - with record_function("a2a_sparse_index_exchange"): - return _sparse_index_exchange( - needed_atoms=needed_atoms, - recv_counts=recv_counts, - rank=rank, - world_size=world_size, - device=device, - ) - - @torch.compiler.disable def build_gp_context( edge_index: torch.Tensor, rank_assignments: torch.Tensor, rank: int, world_size: int, - send_info: dict | None = None, node_partition: torch.Tensor | None = None, ) -> GPContext: """ @@ -258,7 +230,6 @@ def build_gp_context( shape (total_atoms,). rank: This rank's GP rank. world_size: GP world size. - send_info: Pre-computed send metadata (skips index exchange). node_partition: Pre-computed local atom indices. 
Returns: @@ -295,15 +266,15 @@ def build_gp_context( needed_atoms = needed_atoms[sort_order] needed_from_ranks = needed_from_ranks[sort_order] - # Resolve send metadata (pre-computed or via collective). - send_counts, send_indices_global = _resolve_send_metadata( - send_info, - needed_atoms, - recv_counts, - rank, - world_size, - device, - ) + # Exchange send metadata via collective. + with record_function("a2a_sparse_index_exchange"): + send_counts, send_indices_global = _sparse_index_exchange( + needed_atoms=needed_atoms, + recv_counts=recv_counts, + rank=rank, + world_size=world_size, + device=device, + ) # Build global-to-local index mapping. # Local atoms: [0, total_local_atoms) diff --git a/src/fairchem/core/graph/compute.py b/src/fairchem/core/graph/compute.py index beb3dc2adc..83e18dc1a3 100644 --- a/src/fairchem/core/graph/compute.py +++ b/src/fairchem/core/graph/compute.py @@ -22,89 +22,27 @@ def filter_edges_by_node_partition( cell_offsets: torch.Tensor, neighbors: torch.Tensor, num_atoms: int, - rank_assignments: torch.Tensor | None = None, - rank: int | None = None, - world_size: int | None = None, -) -> ( - tuple[torch.Tensor, torch.Tensor, torch.Tensor] - | tuple[torch.Tensor, torch.Tensor, torch.Tensor, dict] -): - """ - Filter edges to keep only those where the target atom belongs to - the node partition. - - When rank_assignments, rank, and world_size are provided, also - computes send_info: which local atoms need to be sent to which - ranks for all-to-all graph parallel communication. This exploits - access to the full (pre-filter) edge_index to derive send - metadata locally, eliminating the need for an NCCL index-exchange - collective in build_gp_context. +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Filter edges to keep only those where the target atom belongs to the node partition. 
+ edge_index is shape (2, num_edges) where the first row is the source atom index and the second row is the target atom index for each edge + cell_offsets is shape (num_edges, 3) + neighbors is cardinality of the edge_index per system in the batch Args: - node_partition: Atom indices in the current rank's partition. - edge_index: Full edge index, shape (2, num_edges). - cell_offsets: Cell offsets, shape (num_edges, 3). - neighbors: Edge count per system in the batch. - num_atoms: Total atoms across all batches. - rank_assignments: Rank for each atom, shape (num_atoms,). - If provided along with rank and world_size, send_info - is computed and returned as a 4th element. - rank: This rank's GP rank. - world_size: GP world size. + node_partition: Tensor of atom indices belonging to the current rank's partition. + edge_index: Edge index tensor of shape (2, num_edges), where row 0 is the source and 1 is the target atom. + cell_offsets: Cell offsets tensor of shape (num_edges, 3). + neighbors: Tensor with edge count per system in the batch (length = num_systems). + num_atoms: Total number of atoms across all batches. Used to create a boolean mask for filtering. Returns: - Filtered (edge_index, cell_offsets, neighbors). - If rank_assignments is provided, also returns send_info dict - with keys: send_counts, send_indices_global. + Filtered edge_index, cell_offsets, and neighbors tensors. """ target_atoms = edge_index[1] node_mask = torch.zeros(num_atoms, dtype=torch.bool, device=target_atoms.device) node_mask[node_partition] = True local_edge_mask = node_mask[target_atoms] - # Compute send info BEFORE discarding non-local edges. - # An edge (src, tgt) where src is LOCAL and tgt is REMOTE means - # src must be sent to rank_assignments[tgt]. 
- send_info = None - if rank_assignments is not None and rank is not None and world_size is not None: - src_is_local = node_mask[edge_index[0]] - tgt_is_remote = ~local_edge_mask - send_edge_mask = src_is_local & tgt_is_remote - - if send_edge_mask.any(): - send_src = edge_index[0, send_edge_mask] - send_dst_rank = rank_assignments[edge_index[1, send_edge_mask]] - - # Unique (dst_rank, src_atom) pairs, sorted by rank then atom. - # Key layout: dst_rank * num_atoms + src_atom ensures rank-major - # ordering, matching what the index exchange produces. - key = send_dst_rank.to(torch.long) * num_atoms + send_src.to(torch.long) - unique_keys = key.unique(sorted=True) - send_ranks = unique_keys // num_atoms - send_atoms = unique_keys % num_atoms - - send_counts = torch.zeros( - world_size, dtype=torch.long, device=edge_index.device - ) - send_counts.scatter_add_( - 0, - send_ranks, - torch.ones_like(send_ranks), - ) - send_info = { - "send_counts": send_counts, - "send_indices_global": send_atoms, - } - else: - send_info = { - "send_counts": torch.zeros( - world_size, dtype=torch.long, device=edge_index.device - ), - "send_indices_global": torch.empty( - 0, dtype=torch.long, device=edge_index.device - ), - } - # Create system index for each edge to track which system each edge belongs to num_systems = neighbors.shape[0] edge_system_idx = torch.repeat_interleave( @@ -117,8 +55,6 @@ def filter_edges_by_node_partition( if neighbors.shape[0] == 1: # If there's only one system, we can skip the scatter_add step and just return the count of remaining edges new_neighbors = local_edge_mask.sum(dtype=neighbors.dtype).unsqueeze(0) - if send_info is not None: - return edge_index, cell_offsets, new_neighbors, send_info return edge_index, cell_offsets, new_neighbors filtered_edge_system_idx = edge_system_idx[local_edge_mask] @@ -133,8 +69,6 @@ def filter_edges_by_node_partition( torch.ones_like(filtered_edge_system_idx, dtype=neighbors.dtype), ) - if send_info is not None: - return 
edge_index, cell_offsets, new_neighbors, send_info return edge_index, cell_offsets, new_neighbors @@ -189,12 +123,8 @@ def generate_graph( radius_pbc_version: int, pbc: torch.Tensor, node_partition: torch.Tensor | None = None, - rank_assignments: torch.Tensor | None = None, - rank: int | None = None, - world_size: int | None = None, ) -> dict: - """ - Generate a graph representation from atomic structure data. + """Generate a graph representation from atomic structure data. Args: data (dict): A dictionary containing a batch of molecular structures. @@ -208,9 +138,6 @@ def generate_graph( radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA) pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc node_partition (torch.Tensor | None): The partitioning of the nodes (atoms) for distributed inference. If provided, returned graph will be filtered to keep only edges where the target atom (edge_index[1,:]) belongs to the current rank's partition. - rank_assignments: Rank for each atom (for A2A send_info). - rank: This rank's GP rank (for A2A send_info). - world_size: GP world size (for A2A send_info). Returns: dict: A dictionary containing the generated graph with the following keys: @@ -220,19 +147,13 @@ def generate_graph( - 'cell_offsets' (torch.Tensor): Offsets of the cell vectors for each edge. - 'offset_distances' (torch.Tensor): Distances between the atoms connected by the edges, including the cell offsets. - 'neighbors' (torch.Tensor): Number of neighbors for each atom. - - 'send_info' (dict, optional): Send metadata for A2A GP when rank_assignments is provided. """ if radius_pbc_version == 1: radius_graph_pbc_fn = radius_graph_pbc elif radius_pbc_version == 2: radius_graph_pbc_fn = radius_graph_pbc_v2 if node_partition is not None: - # Use setattr for compatibility with SimpleNamespace - # (used by halo filtering) and regular data dicts. 
- try: - data["node_partition"] = node_partition - except TypeError: - data.node_partition = node_partition + data["node_partition"] = node_partition elif radius_pbc_version == 3: radius_graph_pbc_fn = radius_graph_pbc_nvidia else: @@ -246,30 +167,15 @@ def generate_graph( pbc=pbc, ) - # V2 does its own internal edge filtering when node_partition is set, - # which is faster than post-filtering. However, this means send_info - # cannot be computed here for v2 (the full edge_index is needed). - # Instead, build_gp_context falls back to _sparse_index_exchange - # (~4ms NCCL collective) when send_info is None. Bypassing v2's - # internal filter to compute send_info was benchmarked and is ~12ms - # SLOWER because v2 generates edges for ALL atoms instead of local - # partition. - send_info = None + # for v2 it is still faster right now to not do this post filtering, need to investigate further if node_partition is not None and radius_pbc_version != 2: - filter_result = filter_edges_by_node_partition( + edge_index, cell_offsets, neighbors = filter_edges_by_node_partition( node_partition, edge_index, cell_offsets, neighbors, num_atoms=data.pos.shape[0], - rank_assignments=rank_assignments, - rank=rank, - world_size=world_size, ) - if rank_assignments is not None: - edge_index, cell_offsets, neighbors, send_info = filter_result - else: - edge_index, cell_offsets, neighbors = filter_result out = get_pbc_distances( data.pos, @@ -286,7 +192,7 @@ def generate_graph( cell_offset_distances = out["offsets"] distance_vec = out["distance_vec"] - result = { + return { "edge_index": edge_index, "edge_distance": edge_dist, "edge_distance_vec": distance_vec, @@ -294,6 +200,3 @@ def generate_graph( "offset_distances": cell_offset_distances, "neighbors": neighbors, } - if send_info is not None: - result["send_info"] = send_info - return result diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 176db9bba2..db8daa995f 100644 --- 
a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -772,7 +772,6 @@ def _compute_halo_graph( "Local partition atoms not found in halo — " "AABB expansion may be too small" ) - rank_assignments_local = rank_assignments[halo_indices] graph_dict = generate_graph( data_subset, @@ -782,9 +781,6 @@ def _compute_halo_graph( radius_pbc_version=self.radius_pbc_version, pbc=pbc, node_partition=node_partition_local, - rank_assignments=rank_assignments_local, - rank=gp_utils.get_gp_rank(), - world_size=gp_utils.get_gp_world_size(), ) # Remap edge_index from halo-local to global for use @@ -792,14 +788,6 @@ def _compute_halo_graph( # lookups index into atomic_numbers_full). graph_dict["edge_index"] = halo_indices[graph_dict["edge_index"]] - # Remap send_info indices from halo-local to global so - # build_gp_context can use them with global-coord - # rank_assignments and edge_index. - if "send_info" in graph_dict: - si = graph_dict["send_info"] - if si["send_indices_global"].numel() > 0: - si["send_indices_global"] = halo_indices[si["send_indices_global"]] - return graph_dict def _generate_graph(self, data_dict): @@ -894,13 +882,6 @@ def _generate_graph(self, data_dict): radius_pbc_version=self.radius_pbc_version, pbc=pbc, node_partition=node_partition, - rank_assignments=( - rank_assignments if self.use_all_to_all_gp else None - ), - rank=(gp_utils.get_gp_rank() if self.use_all_to_all_gp else None), - world_size=( - gp_utils.get_gp_world_size() if self.use_all_to_all_gp else None - ), ) else: # this assume edge_index is provided @@ -954,7 +935,6 @@ def _generate_graph(self, data_dict): rank_assignments=rank_assignments, rank=gp_utils.get_gp_rank(), world_size=gp_utils.get_gp_world_size(), - send_info=graph_dict.get("send_info"), node_partition=node_partition, ) data_dict["gp_ctx"] = gp_ctx diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index 8b0e9ebbcd..e5302d1cbe 
100644 --- a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -623,146 +623,3 @@ def test_a2a_spatial_partition(): f"Rank {result['rank']}: spatial partitioning produced " f"different global results than index partitioning" ) - - -# ========================================================================= -# Distributed tests: send_info optimization correctness -# ========================================================================= - - -def send_info_optimization_test(atomic_numbers, edge_index): - """ - Verify that pre-computed send_info from filter_edges_by_node_partition - produces the same GPContext as the _sparse_index_exchange path. - """ - from fairchem.core.graph.compute import filter_edges_by_node_partition - - rank = gp_utils.get_gp_rank() - world_size = gp_utils.get_gp_world_size() - natoms = atomic_numbers.shape[0] - - rank_assignments = partition_atoms_index_split( - natoms, world_size, torch.device("cpu") - ) - - # Get this rank's partition - node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] - - # Filter edges with send_info computation - neighbors = torch.tensor([edge_index.shape[1]]) - edge_index_filtered, _, _, send_info = filter_edges_by_node_partition( - node_partition=node_partition, - edge_index=edge_index, - cell_offsets=torch.zeros(edge_index.shape[1], 3), - neighbors=neighbors, - num_atoms=natoms, - rank_assignments=rank_assignments, - rank=rank, - world_size=world_size, - ) - - # Build GPContext WITH send_info (skip _sparse_index_exchange) - ctx_with_send_info = build_gp_context( - edge_index_filtered, - rank_assignments, - rank, - world_size, - send_info=send_info, - ) - - # Build GPContext WITHOUT send_info (use _sparse_index_exchange) - ctx_without = build_gp_context( - edge_index_filtered, - rank_assignments, - rank, - world_size, - ) - - # Compare the two contexts - send_counts_match = torch.equal( - ctx_with_send_info.send_counts, 
ctx_without.send_counts - ) - recv_counts_match = torch.equal( - ctx_with_send_info.recv_counts, ctx_without.recv_counts - ) - - # send_indices should select the same atoms (may differ in order - # within a rank's chunk, but counts must match) - si_a = ctx_with_send_info.send_indices - si_b = ctx_without.send_indices - if si_a is not None and si_b is not None: - send_indices_match = si_a.shape == si_b.shape and torch.equal( - si_a.sort()[0], si_b.sort()[0] - ) - elif si_a is None and si_b is None: - send_indices_match = True - else: - send_indices_match = False - - # Functional test: both contexts should produce identical - # all-to-all results - x = atomic_numbers[node_partition].unsqueeze(1).float() - - x_recv_opt = all_to_all_collect( - x, ctx_with_send_info, ctx_with_send_info.send_indices - ) - x_recv_ref = all_to_all_collect(x, ctx_without, ctx_without.send_indices) - - functional_match = torch.allclose(x_recv_opt, x_recv_ref, atol=1e-6) - - return { - "rank": rank, - "send_counts_match": send_counts_match, - "recv_counts_match": recv_counts_match, - "send_indices_match": send_indices_match, - "functional_match": functional_match, - "send_counts_opt": ctx_with_send_info.send_counts, - "send_counts_ref": ctx_without.send_counts, - } - - -def test_send_info_matches_fused_exchange(): - """ - Verify that pre-computed send_info from filter_edges_by_node_partition - produces identical GPContext and all-to-all results as the - _sparse_index_exchange path. 
- """ - num_atoms = 8 - # Dense graph: all atoms connected - src = [] - dst = [] - for i in range(num_atoms): - for j in range(num_atoms): - if i != j: - src.append(i) - dst.append(j) - edge_index = torch.tensor([src, dst], dtype=torch.long) - atomic_numbers = torch.arange( - 2, 2 + num_atoms, dtype=torch.float, requires_grad=False - ) - - config = PGConfig(backend="gloo", world_size=2, gp_group_size=2, use_gp=True) - all_rank_results = spawn_multi_process( - config, - send_info_optimization_test, - init_pg_and_rank_and_launch_test, - atomic_numbers, - edge_index, - ) - - for result in all_rank_results: - assert result["send_counts_match"], ( - f"Rank {result['rank']}: send_counts mismatch. " - f"opt={result['send_counts_opt']}, " - f"ref={result['send_counts_ref']}" - ) - assert result[ - "recv_counts_match" - ], f"Rank {result['rank']}: recv_counts mismatch" - assert result[ - "send_indices_match" - ], f"Rank {result['rank']}: send_indices mismatch" - assert result["functional_match"], ( - f"Rank {result['rank']}: functional mismatch — " - f"all-to-all produced different embeddings" - ) From 7558df46e4ce5792a07a18c6991a9830b2c120de Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:48:29 +0000 Subject: [PATCH 12/26] Make all GPContext fields required (no optional defaults) --- .../common/parallelism/graph_parallel_a2a.py | 51 +++++++------------ 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index 373cb574b6..9cf266aa10 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -81,12 +81,9 @@ class GPContext: Shape: (total_atoms,), with -1 for atoms not accessible. total_local_atoms: Number of atoms in this rank's partition. total_needed_atoms: Total atoms needed from other ranks. 
- send_indices: Precomputed local indices of atoms to send, ordered by - destination rank. Computed once at build time to avoid per-forward - all-to-all index exchange. None if not yet computed. - edge_index_local: Precomputed edge index remapped to local indices. - None if not yet computed (set by build_gp_context when edge_index - is provided). + send_indices: Local indices of atoms to send, ordered by + destination rank. + edge_index_local: Edge index remapped to local indices. """ rank: int @@ -100,20 +97,13 @@ class GPContext: global_to_local: torch.Tensor total_local_atoms: int total_needed_atoms: int - send_indices: torch.Tensor | None = None - edge_index_local: torch.Tensor | None = None - # Precomputed Python lists to avoid repeated .tolist() in AllToAllCollect - send_splits: list[int] | None = None - recv_splits: list[int] | None = None - total_recv: int | None = None - # Precomputed integer indices for local/remote edges (for - # comm-compute overlap). Local edges have source atoms owned by - # this rank (edge_index_local[0] < total_local_atoms), remote - # edges have sources from other ranks. Using integer indices - # instead of boolean masks for compile-friendly indexing (avoids - # dynamic-shape boolean masking in compiled graphs). - local_edge_idx: torch.Tensor | None = None - remote_edge_idx: torch.Tensor | None = None + send_indices: torch.Tensor + edge_index_local: torch.Tensor + send_splits: list[int] + recv_splits: list[int] + total_recv: int + local_edge_idx: torch.Tensor + remote_edge_idx: torch.Tensor def _sparse_index_exchange( @@ -122,7 +112,7 @@ def _sparse_index_exchange( rank: int, world_size: int, device: torch.device, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, torch.Tensor]: """ Variable-split index exchange using two small all-to-alls. 
@@ -156,7 +146,7 @@ def _sparse_index_exchange( if not gp_utils.initialized(): return ( torch.zeros(world_size, dtype=torch.long, device=device), - None, + torch.empty(0, dtype=torch.long, device=device), ) gp_group = gp_utils.get_gp_group() @@ -291,12 +281,10 @@ def build_gp_context( ) # Convert send_indices from global to local. - send_indices = None - if send_indices_global is not None: - if send_indices_global.numel() > 0: - send_indices = global_to_local[send_indices_global] - else: - send_indices = torch.empty(0, dtype=torch.long, device=device) + if send_indices_global.numel() > 0: + send_indices = global_to_local[send_indices_global] + else: + send_indices = torch.empty(0, dtype=torch.long, device=device) # Remap edge_index to local indices. edge_index_local = global_to_local[edge_index] @@ -313,7 +301,7 @@ def build_gp_context( "edge_index_local has negative entries — graph edges " "reference atoms not in global_to_local mapping.", ) - if send_indices is not None and send_indices.numel() > 0: + if send_indices.numel() > 0: torch._assert_async( ~((send_indices < 0) | (send_indices >= total_local_atoms)).any(), "send_indices out of bounds — remote rank requested " @@ -573,11 +561,6 @@ def all_to_all_collect( x_received: Remote atom embeddings, shape (total_needed, *features). """ - if send_indices is None: - raise ValueError( - "send_indices is None — build_gp_context should always " - "compute send_indices. Check GP setup." 
- ) return AllToAllCollect.apply( x_local, send_indices, From c20b024856fca0996a135291c9a227c7a96314e8 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 00:53:58 +0000 Subject: [PATCH 13/26] Remove torchrun entrypoint; use fairchem CLI for multi-GPU tests --- .../parallelism/graph_parallel_verfication.md | 21 +- .../parallelism/test_a2a_correctness.py | 259 +----------------- 2 files changed, 23 insertions(+), 257 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md index 6089d061e7..292a7391b3 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md +++ b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md @@ -38,13 +38,28 @@ pytest tests/core/common/parallelism/test_a2a_correctness.py -v 8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2, 1536-dim embeddings × 2. -### 1d. Full-model GPU correctness (8 GPUs, ~5 min) +### 1d. Full-model GPU correctness (8 GPUs, ~10 min) + +Run BL and A2A benchmarks at 1000 atoms with a single repeat and compare outputs: ```bash -torchrun --nproc_per_node=8 tests/core/common/parallelism/test_a2a_correctness.py --natoms 1000 +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=local_8gpu \ + runner.natoms_list=[1000] \ + runner.timeiters=1 \ + runner.repeats=1 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=local_8gpu \ + runner.natoms_list=[1000] \ + runner.timeiters=1 \ + runner.repeats=1 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' ``` -Loads UMA-S checkpoint, runs BL vs A2A inference, compares energy/forces/stress (tol=1e-4). +Verify energy/forces/stress match between BL and A2A (tol=1e-4). ### 1e. 
Predict pipeline + MD consistency (CPU, PR3 branch, ~2 min) diff --git a/tests/core/common/parallelism/test_a2a_correctness.py b/tests/core/common/parallelism/test_a2a_correctness.py index 64c9a805bc..e2134dcebb 100644 --- a/tests/core/common/parallelism/test_a2a_correctness.py +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -4,33 +4,23 @@ This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. -Multi-GPU correctness test: A2A (all-to-all) vs BL (all-gather baseline). +A2A (all-to-all) graph parallel correctness tests. -Verifies that the A2A graph parallel implementation produces numerically -identical results to the BL baseline across multiple GPU counts. +Verifies that the A2A graph parallel implementation produces +correct results via multi-process Gloo tests (CPU). -Run directly via torchrun: - torchrun --nproc_per_node=N test_a2a_correctness.py [--natoms 1000] - -Or via pytest (2-process CPU with Gloo): +Run via pytest: pytest test_a2a_correctness.py -v - -The test creates an FCC crystal, loads the UMA-S checkpoint, and runs -inference in both BL and A2A modes. The outputs (energy, forces, stress) -are gathered to rank 0 and compared numerically. 
""" from __future__ import annotations -import argparse -import json import logging -import sys import pytest import torch -from fairchem.core.common import distutils, gp_utils +from fairchem.core.common import gp_utils from fairchem.core.common.parallelism.graph_parallel_a2a import ( all_to_all_collect, build_gp_context, @@ -44,10 +34,6 @@ init_pg_and_rank_and_launch_test, spawn_multi_process, ) -from fairchem.core.datasets.atomic_data import AtomicData -from fairchem.core.datasets.common_structures import ( - get_fcc_crystal_by_num_atoms, -) logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -352,238 +338,3 @@ def test_a2a_multidim_embeddings(strategy): f"shape={result['recv_shape']} " f"vs {result['expected_shape']}" ) - - -# ========================================================================= -# Full model correctness test (GPU, run via torchrun or SLURM) -# ========================================================================= - - -def _resolve_checkpoint(): - """ - Resolve the UMA-S checkpoint path using the fairchem pretrained - model API. - """ - from fairchem.core.calculate.pretrained_mlip import ( - pretrained_checkpoint_path_from_name, - ) - - return pretrained_checkpoint_path_from_name(model_name="uma-s-1p2") - - -def _run_full_model_comparison( - natoms: int = 1000, - results_file: str | None = None, -): - """ - Run the full UMA-S model in both BL and A2A modes and compare - outputs. - - Must be called inside a torchrun process group. 
- """ - from fairchem.core.units.mlip_unit import MLIPPredictUnit - - rank = distutils.get_rank() - world_size = distutils.get_world_size() - - if rank == 0: - logger.info(f"Running correctness test: {natoms} atoms, " f"{world_size} GPUs") - - checkpoint_path = _resolve_checkpoint() - if rank == 0: - logger.info(f"Using checkpoint: {checkpoint_path}") - - # Create input system - atoms = get_fcc_crystal_by_num_atoms(natoms, atom_type="Al") - actual_natoms = len(atoms) - if rank == 0: - logger.info(f"Created FCC Al crystal: {actual_natoms} atoms") - - data = AtomicData.from_ase( - input_atoms=atoms, - max_neigh=200, - radius=6.0, - task_name="oc20", - r_edges=False, - r_data_keys=["spin", "charge"], - ) - - # -- Run BL (all-gather baseline) -- - if rank == 0: - logger.info("Loading model for BL (all-gather) mode...") - - predictor_bl = MLIPPredictUnit.from_checkpoint( - checkpoint_path, - device=torch.device("cuda"), - inference_settings={ - "tf32": False, - "compile": False, - "activation_checkpointing": False, - "merge_mole": False, - }, - overrides={ - "backbone": { - "use_all_to_all_gp": False, - }, - }, - ) - predictor_bl.model.eval() - - # Warm up + run BL - with torch.no_grad(): - _ = predictor_bl.predict(data) - bl_out = predictor_bl.predict(data) - - bl_energy = bl_out["energy"].clone() - bl_forces = bl_out["forces"].clone() - bl_stress = bl_out.get("stress", torch.tensor([])).clone() - - if rank == 0: - logger.info(f"BL energy: {bl_energy.item():.6f}") - logger.info(f"BL forces shape: {bl_forces.shape}") - logger.info(f"BL forces norm: {bl_forces.norm():.6f}") - - # Clean up BL model - del predictor_bl - torch.cuda.empty_cache() - - # -- Run A2A (all-to-all with spatial partitioning) -- - if rank == 0: - logger.info("Loading model for A2A (all-to-all) mode...") - - predictor_a2a = MLIPPredictUnit.from_checkpoint( - checkpoint_path, - device=torch.device("cuda"), - inference_settings={ - "tf32": False, - "compile": False, - "activation_checkpointing": False, 
- "merge_mole": False, - }, - overrides={ - "backbone": { - "use_all_to_all_gp": True, - "gp_partition_strategy": "spatial", - }, - }, - ) - predictor_a2a.model.eval() - - # Warm up + run A2A - with torch.no_grad(): - _ = predictor_a2a.predict(data) - a2a_out = predictor_a2a.predict(data) - - a2a_energy = a2a_out["energy"].clone() - a2a_forces = a2a_out["forces"].clone() - a2a_stress = a2a_out.get("stress", torch.tensor([])).clone() - - if rank == 0: - logger.info(f"A2A energy: {a2a_energy.item():.6f}") - logger.info(f"A2A forces shape: {a2a_forces.shape}") - logger.info(f"A2A forces norm: {a2a_forces.norm():.6f}") - - # -- Compare outputs -- - # Energy should match across all ranks (reduced) - energy_diff = abs(bl_energy.item() - a2a_energy.item()) - energy_match = energy_diff < 1e-4 - - # Forces: each rank only has forces for its local atoms. - # Gather all forces to rank 0 for comparison. - # BL forces are already the full set on all ranks. - # A2A forces need gathering. - if bl_forces.shape == a2a_forces.shape: - force_diff = (bl_forces - a2a_forces).abs().max().item() - force_match = force_diff < 1e-4 - force_rmse = (bl_forces - a2a_forces).pow(2).mean().sqrt().item() - else: - # Different shapes — gather and compare - force_diff = float("nan") - force_match = False - force_rmse = float("nan") - - # Stress - if bl_stress.numel() > 0 and a2a_stress.numel() > 0: - stress_diff = (bl_stress - a2a_stress).abs().max().item() - stress_match = stress_diff < 1e-4 - else: - stress_diff = 0.0 - stress_match = True - - results = { - "natoms": actual_natoms, - "world_size": world_size, - "energy_bl": bl_energy.item(), - "energy_a2a": a2a_energy.item(), - "energy_diff": energy_diff, - "energy_match": energy_match, - "force_max_diff": force_diff, - "force_rmse": force_rmse, - "force_match": force_match, - "stress_max_diff": stress_diff, - "stress_match": stress_match, - "all_match": energy_match and force_match and stress_match, - } - - if rank == 0: - logger.info(f"\n{'=' * 
60}") - logger.info("CORRECTNESS TEST RESULTS") - logger.info(f"{'=' * 60}") - logger.info(f"Atoms: {actual_natoms}") - logger.info(f"GPUs: {world_size}") - logger.info(f"Energy BL: {bl_energy.item():.6f}") - logger.info(f"Energy A2A: {a2a_energy.item():.6f}") - logger.info(f"Energy diff: {energy_diff:.2e}") - logger.info(f"Force max Δ: {force_diff:.2e}") - logger.info(f"Force RMSE: {force_rmse:.2e}") - logger.info(f"Stress max Δ: {stress_diff:.2e}") - status = "✓ PASS" if results["all_match"] else "✗ FAIL" - logger.info(f"ALL MATCH: {status}") - logger.info(f"{'=' * 60}") - - if results_file: - with open(results_file, "w") as f: - json.dump(results, f, indent=2) - logger.info(f"Results saved to {results_file}") - - return results - - -# ========================================================================= -# CLI entrypoint for SLURM / torchrun -# ========================================================================= - - -def main(): - parser = argparse.ArgumentParser(description="A2A vs BL correctness test") - parser.add_argument( - "--natoms", - type=int, - default=1000, - help="Target number of atoms in FCC crystal", - ) - parser.add_argument( - "--results-file", - type=str, - default=None, - help="Path to save JSON results", - ) - args = parser.parse_args() - - # Initialize distributed - distutils.setup({"submit": False, "cpu": False}) - gp_utils.setup_gp(distutils.get_world_size()) - - try: - results = _run_full_model_comparison( - natoms=args.natoms, - results_file=args.results_file, - ) - if not results["all_match"]: - sys.exit(1) - finally: - distutils.cleanup() - - -if __name__ == "__main__": - main() From 256a1448b81ef8e712555e5d833373a5e31ad1e9 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 01:02:48 +0000 Subject: [PATCH 14/26] Add GPU/NCCL tests for A2A graph parallel primitives --- .../parallelism/test_a2a_correctness.py | 150 ++++++++++++++++ .../common/parallelism/test_graph_parallel.py | 163 ++++++++++++++++++ 2 files changed, 
313 insertions(+) diff --git a/tests/core/common/parallelism/test_a2a_correctness.py b/tests/core/common/parallelism/test_a2a_correctness.py index e2134dcebb..b134463d89 100644 --- a/tests/core/common/parallelism/test_a2a_correctness.py +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -338,3 +338,153 @@ def test_a2a_multidim_embeddings(strategy): f"shape={result['recv_shape']} " f"vs {result['expected_shape']}" ) + + +# ========================================================================= +# GPU tests (NCCL, 2 processes) +# ========================================================================= + + +def _to_cuda(*tensors): + device = torch.device(f"cuda:{gp_utils.get_gp_rank()}") + return tuple(t.to(device) for t in tensors) + + +def _correctness_test_inner_gpu( + atomic_numbers, pos, edge_index, num_atoms, partition_strategy +): + (atomic_numbers, pos, edge_index) = _to_cuda(atomic_numbers, pos, edge_index) + return _correctness_test_inner( + atomic_numbers, pos, edge_index, num_atoms, partition_strategy + ) + + +def _multidim_test_inner_gpu(x_global, pos, edge_index, num_atoms, strategy): + (x_global, pos, edge_index) = _to_cuda(x_global, pos, edge_index) + return _multidim_test_inner(x_global, pos, edge_index, num_atoms, strategy) + + +@pytest.mark.gpu() +@pytest.mark.parametrize( + "strategy,num_atoms", + [ + ("index_split", 8), + ("index_split", 20), + ("spatial", 8), + ("spatial", 20), + ], +) +def test_a2a_correctness_gpu(strategy, num_atoms): + src, dst = [], [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + atomic_numbers = torch.arange(2, 2 + num_atoms, dtype=torch.float) + pos = torch.randn(num_atoms, 3) * 10 + + config = PGConfig(backend="nccl", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + _correctness_test_inner_gpu, + init_pg_and_rank_and_launch_test, + 
atomic_numbers, + pos, + edge_index, + num_atoms, + strategy, + ) + + for result in all_rank_results: + r = result["rank"] + assert result[ + "recv_correct" + ], f"Rank {r}: received embeddings don't match expected values on GPU" + assert result[ + "edge_valid" + ], f"Rank {r}: edge_index_local has negative entries on GPU" + assert result[ + "edge_in_bounds" + ], f"Rank {r}: edge_index_local has out-of-bounds entries on GPU" + assert result[ + "mp_match" + ], f"Rank {r}: message passing result differs from reference on GPU" + + +@pytest.mark.gpu() +@pytest.mark.parametrize("strategy", ["index_split", "spatial"]) +def test_a2a_consistency_across_graph_sizes_gpu(strategy): + num_atoms = 16 + + src, dst = [], [] + for i in range(num_atoms): + for d in [-2, -1, 1, 2]: + j = (i + d) % num_atoms + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + atomic_numbers = torch.arange(10, 10 + num_atoms, dtype=torch.float) + pos = torch.zeros(num_atoms, 3) + pos[:, 0] = torch.arange(num_atoms, dtype=torch.float) + + config = PGConfig(backend="nccl", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + _correctness_test_inner_gpu, + init_pg_and_rank_and_launch_test, + atomic_numbers, + pos, + edge_index, + num_atoms, + strategy, + ) + + for result in all_rank_results: + r = result["rank"] + assert result[ + "recv_correct" + ], f"Rank {r}: received embeddings don't match expected values on GPU" + assert result[ + "mp_match" + ], f"Rank {r}: message passing result differs from reference on GPU" + + +@pytest.mark.gpu() +@pytest.mark.parametrize("strategy", ["index_split", "spatial"]) +def test_a2a_multidim_embeddings_gpu(strategy): + num_atoms = 12 + embed_dim = 16 + + src, dst = [], [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + torch.manual_seed(42) + x_global = 
torch.randn(num_atoms, embed_dim) + pos = torch.randn(num_atoms, 3) * 10 + + config = PGConfig(backend="nccl", world_size=2, gp_group_size=2, use_gp=True) + + all_rank_results = spawn_multi_process( + config, + _multidim_test_inner_gpu, + init_pg_and_rank_and_launch_test, + x_global, + pos, + edge_index, + num_atoms, + strategy, + ) + + for result in all_rank_results: + r = result["rank"] + assert result["recv_correct"], ( + f"Rank {r}: multidim recv mismatch on GPU, " + f"shape={result['recv_shape']} " + f"vs {result['expected_shape']}" + ) diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index e5302d1cbe..f67ad7d94a 100644 --- a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -623,3 +623,166 @@ def test_a2a_spatial_partition(): f"Rank {result['rank']}: spatial partitioning produced " f"different global results than index partitioning" ) + + +# ========================================================================= +# GPU tests (NCCL, 2 processes) +# ========================================================================= + + +def _to_cuda(*tensors): + device = torch.device(f"cuda:{gp_utils.get_gp_rank()}") + return tuple(t.to(device) for t in tensors) + + +def a2a_vs_allgather_test_gpu(atomic_numbers, edge_index): + (atomic_numbers, edge_index) = _to_cuda(atomic_numbers, edge_index) + return a2a_vs_allgather_test(atomic_numbers, edge_index) + + +def a2a_backward_test_gpu(atomic_numbers, edge_index): + (atomic_numbers, edge_index) = _to_cuda(atomic_numbers, edge_index) + return a2a_backward_test(atomic_numbers, edge_index) + + +def a2a_spatial_partition_test_gpu(atomic_numbers, edge_index, pos): + (atomic_numbers, edge_index, pos) = _to_cuda(atomic_numbers, edge_index, pos) + return a2a_spatial_partition_test(atomic_numbers, edge_index, pos) + + +@pytest.mark.gpu() +@pytest.mark.parametrize( + "num_atoms, edges", + [ + (4, 
[[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]), + (5, [[0, 0, 0, 0, 1, 2, 3, 4], [1, 2, 3, 4, 0, 0, 0, 0]]), + ( + 4, + [ + [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2], + ], + ), + ], +) +def test_a2a_vs_allgather_gpu(num_atoms, edges): + atomic_numbers = torch.arange( + 2, 2 + num_atoms, dtype=torch.float, requires_grad=False + ) + edge_index = torch.tensor(edges, dtype=torch.long) + + config = PGConfig(backend="nccl", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + a2a_vs_allgather_test_gpu, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["match"], ( + f"Rank {result['rank']}: all-gather and all-to-all produced " + f"different results on GPU.\n" + f"allgather: {result['allgather']}\n" + f"all_to_all: {result['all_to_all']}" + ) + + +@pytest.mark.gpu() +def test_a2a_backward_gpu(): + atomic_numbers = torch.tensor([2.0, 3.0, 5.0, 7.0]) + edge_index = torch.tensor( + [[0, 0, 1, 1, 2, 2, 3, 3], [1, 2, 0, 3, 0, 3, 1, 2]], + dtype=torch.long, + ) + + config = PGConfig(backend="nccl", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + a2a_backward_test_gpu, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["energy_match"], ( + f"Rank {result['rank']}: energy mismatch on GPU. " + f"AG={result['allgather_energy']}, " + f"A2A={result['all_to_all_energy']}" + ) + assert result["forces_match"], ( + f"Rank {result['rank']}: forces mismatch on GPU. 
" + f"AG={result['allgather_forces']}, " + f"A2A={result['all_to_all_forces']}" + ) + + +@pytest.mark.gpu() +@pytest.mark.parametrize("world_size", [2, 3]) +def test_a2a_multi_rank_gpu(world_size): + num_atoms = 6 + src = list(range(num_atoms)) + dst = [(i + 1) % num_atoms for i in range(num_atoms)] + edge_src = src + dst + edge_dst = dst + src + edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long) + atomic_numbers = torch.arange(2, 2 + num_atoms, dtype=torch.float) + + config = PGConfig( + backend="nccl", + world_size=world_size, + gp_group_size=world_size, + use_gp=True, + ) + all_rank_results = spawn_multi_process( + config, + a2a_vs_allgather_test_gpu, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + ) + + for result in all_rank_results: + assert result["match"], ( + f"world_size={world_size}, " f"rank {result['rank']}: mismatch on GPU" + ) + + +@pytest.mark.gpu() +def test_a2a_spatial_partition_gpu(): + num_atoms = 8 + pos = torch.cat( + [ + torch.randn(4, 3) + torch.tensor([0.0, 0.0, 0.0]), + torch.randn(4, 3) + torch.tensor([100.0, 0.0, 0.0]), + ] + ) + atomic_numbers = torch.arange( + 2, 2 + num_atoms, dtype=torch.float, requires_grad=False + ) + src = [] + dst = [] + for i in range(num_atoms): + for j in range(num_atoms): + if i != j: + src.append(i) + dst.append(j) + edge_index = torch.tensor([src, dst], dtype=torch.long) + + config = PGConfig(backend="nccl", world_size=2, gp_group_size=2, use_gp=True) + all_rank_results = spawn_multi_process( + config, + a2a_spatial_partition_test_gpu, + init_pg_and_rank_and_launch_test, + atomic_numbers, + edge_index, + pos, + ) + + for result in all_rank_results: + assert result["match"], ( + f"Rank {result['rank']}: spatial partitioning produced " + f"different global results than index partitioning on GPU" + ) From 7112ab892308aae55fe4fc948c1a05ff791bfdba Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 01:07:18 +0000 Subject: [PATCH 15/26] Skip GPU/NCCL tests when 
fewer than 2 GPUs available --- .../common/parallelism/test_a2a_correctness.py | 11 ++++++++--- .../common/parallelism/test_graph_parallel.py | 15 +++++++++++---- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/tests/core/common/parallelism/test_a2a_correctness.py b/tests/core/common/parallelism/test_a2a_correctness.py index b134463d89..e26e4c97f6 100644 --- a/tests/core/common/parallelism/test_a2a_correctness.py +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -344,6 +344,11 @@ def test_a2a_multidim_embeddings(strategy): # GPU tests (NCCL, 2 processes) # ========================================================================= +_skip_if_lt_2_gpus = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Requires at least 2 GPUs", +) + def _to_cuda(*tensors): device = torch.device(f"cuda:{gp_utils.get_gp_rank()}") @@ -364,7 +369,7 @@ def _multidim_test_inner_gpu(x_global, pos, edge_index, num_atoms, strategy): return _multidim_test_inner(x_global, pos, edge_index, num_atoms, strategy) -@pytest.mark.gpu() +@_skip_if_lt_2_gpus @pytest.mark.parametrize( "strategy,num_atoms", [ @@ -413,7 +418,7 @@ def test_a2a_correctness_gpu(strategy, num_atoms): ], f"Rank {r}: message passing result differs from reference on GPU" -@pytest.mark.gpu() +@_skip_if_lt_2_gpus @pytest.mark.parametrize("strategy", ["index_split", "spatial"]) def test_a2a_consistency_across_graph_sizes_gpu(strategy): num_atoms = 16 @@ -451,7 +456,7 @@ def test_a2a_consistency_across_graph_sizes_gpu(strategy): ], f"Rank {r}: message passing result differs from reference on GPU" -@pytest.mark.gpu() +@_skip_if_lt_2_gpus @pytest.mark.parametrize("strategy", ["index_split", "spatial"]) def test_a2a_multidim_embeddings_gpu(strategy): num_atoms = 12 diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index f67ad7d94a..5a8b7ebf97 100644 --- 
a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -629,6 +629,11 @@ def test_a2a_spatial_partition(): # GPU tests (NCCL, 2 processes) # ========================================================================= +_skip_if_lt_2_gpus = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 2, + reason="Requires at least 2 GPUs", +) + def _to_cuda(*tensors): device = torch.device(f"cuda:{gp_utils.get_gp_rank()}") @@ -650,7 +655,7 @@ def a2a_spatial_partition_test_gpu(atomic_numbers, edge_index, pos): return a2a_spatial_partition_test(atomic_numbers, edge_index, pos) -@pytest.mark.gpu() +@_skip_if_lt_2_gpus @pytest.mark.parametrize( "num_atoms, edges", [ @@ -689,7 +694,7 @@ def test_a2a_vs_allgather_gpu(num_atoms, edges): ) -@pytest.mark.gpu() +@_skip_if_lt_2_gpus def test_a2a_backward_gpu(): atomic_numbers = torch.tensor([2.0, 3.0, 5.0, 7.0]) edge_index = torch.tensor( @@ -719,9 +724,11 @@ def test_a2a_backward_gpu(): ) -@pytest.mark.gpu() +@_skip_if_lt_2_gpus @pytest.mark.parametrize("world_size", [2, 3]) def test_a2a_multi_rank_gpu(world_size): + if torch.cuda.device_count() < world_size: + pytest.skip(f"Requires at least {world_size} GPUs") num_atoms = 6 src = list(range(num_atoms)) dst = [(i + 1) % num_atoms for i in range(num_atoms)] @@ -750,7 +757,7 @@ def test_a2a_multi_rank_gpu(world_size): ) -@pytest.mark.gpu() +@_skip_if_lt_2_gpus def test_a2a_spatial_partition_gpu(): num_atoms = 8 pos = torch.cat( From 50ab32d2bebb43836cc034dd8bded47be34b05ac Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 01:08:47 +0000 Subject: [PATCH 16/26] Use CI env var to skip multi-GPU tests instead of device_count check --- .../common/parallelism/test_a2a_correctness.py | 13 +++++++------ .../common/parallelism/test_graph_parallel.py | 18 +++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/core/common/parallelism/test_a2a_correctness.py 
b/tests/core/common/parallelism/test_a2a_correctness.py index e26e4c97f6..b2fbce17de 100644 --- a/tests/core/common/parallelism/test_a2a_correctness.py +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -16,6 +16,7 @@ from __future__ import annotations import logging +import os import pytest import torch @@ -344,9 +345,9 @@ def test_a2a_multidim_embeddings(strategy): # GPU tests (NCCL, 2 processes) # ========================================================================= -_skip_if_lt_2_gpus = pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 2, - reason="Requires at least 2 GPUs", +_skip_if_ci = pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="Multi-GPU test, skipped in CI", ) @@ -369,7 +370,7 @@ def _multidim_test_inner_gpu(x_global, pos, edge_index, num_atoms, strategy): return _multidim_test_inner(x_global, pos, edge_index, num_atoms, strategy) -@_skip_if_lt_2_gpus +@_skip_if_ci @pytest.mark.parametrize( "strategy,num_atoms", [ @@ -418,7 +419,7 @@ def test_a2a_correctness_gpu(strategy, num_atoms): ], f"Rank {r}: message passing result differs from reference on GPU" -@_skip_if_lt_2_gpus +@_skip_if_ci @pytest.mark.parametrize("strategy", ["index_split", "spatial"]) def test_a2a_consistency_across_graph_sizes_gpu(strategy): num_atoms = 16 @@ -456,7 +457,7 @@ def test_a2a_consistency_across_graph_sizes_gpu(strategy): ], f"Rank {r}: message passing result differs from reference on GPU" -@_skip_if_lt_2_gpus +@_skip_if_ci @pytest.mark.parametrize("strategy", ["index_split", "spatial"]) def test_a2a_multidim_embeddings_gpu(strategy): num_atoms = 12 diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index 5a8b7ebf97..e7aa1e4fa1 100644 --- a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -7,6 +7,8 @@ from __future__ import annotations +import os + import pytest import torch 
@@ -629,9 +631,9 @@ def test_a2a_spatial_partition(): # GPU tests (NCCL, 2 processes) # ========================================================================= -_skip_if_lt_2_gpus = pytest.mark.skipif( - not torch.cuda.is_available() or torch.cuda.device_count() < 2, - reason="Requires at least 2 GPUs", +_skip_if_ci = pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="Multi-GPU test, skipped in CI", ) @@ -655,7 +657,7 @@ def a2a_spatial_partition_test_gpu(atomic_numbers, edge_index, pos): return a2a_spatial_partition_test(atomic_numbers, edge_index, pos) -@_skip_if_lt_2_gpus +@_skip_if_ci @pytest.mark.parametrize( "num_atoms, edges", [ @@ -694,7 +696,7 @@ def test_a2a_vs_allgather_gpu(num_atoms, edges): ) -@_skip_if_lt_2_gpus +@_skip_if_ci def test_a2a_backward_gpu(): atomic_numbers = torch.tensor([2.0, 3.0, 5.0, 7.0]) edge_index = torch.tensor( @@ -724,11 +726,9 @@ def test_a2a_backward_gpu(): ) -@_skip_if_lt_2_gpus +@_skip_if_ci @pytest.mark.parametrize("world_size", [2, 3]) def test_a2a_multi_rank_gpu(world_size): - if torch.cuda.device_count() < world_size: - pytest.skip(f"Requires at least {world_size} GPUs") num_atoms = 6 src = list(range(num_atoms)) dst = [(i + 1) % num_atoms for i in range(num_atoms)] @@ -757,7 +757,7 @@ def test_a2a_multi_rank_gpu(world_size): ) -@_skip_if_lt_2_gpus +@_skip_if_ci def test_a2a_spatial_partition_gpu(): num_atoms = 8 pos = torch.cat( From af68f8084113e9ca5d6973d540e867148bf791a5 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 01:35:38 +0000 Subject: [PATCH 17/26] Add full-model GP correctness test (no-GP vs allgather vs A2A) --- .../parallelism/graph_parallel_verfication.md | 23 +---- tests/core/units/mlip_unit/test_predict.py | 95 +++++++++++++++++++ 2 files changed, 100 insertions(+), 18 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md index 292a7391b3..78c244db90 100644 --- 
a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md +++ b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md @@ -38,28 +38,15 @@ pytest tests/core/common/parallelism/test_a2a_correctness.py -v 8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2, 1536-dim embeddings × 2. -### 1d. Full-model GPU correctness (8 GPUs, ~10 min) - -Run BL and A2A benchmarks at 1000 atoms with a single repeat and compare outputs: +### 1d. Full-model GPU correctness (4+ GPUs, ~5 min) ```bash -# BL baseline -fairchem -c configs/uma/speed/uma-speed.yaml \ - job=local_8gpu \ - runner.natoms_list=[1000] \ - runner.timeiters=1 \ - runner.repeats=1 - -# A2A + spatial -fairchem -c configs/uma/speed/uma-speed.yaml \ - job=local_8gpu \ - runner.natoms_list=[1000] \ - runner.timeiters=1 \ - runner.repeats=1 \ - '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +pytest tests/core/units/mlip_unit/test_predict.py::test_full_model_gp_correctness -v ``` -Verify energy/forces/stress match between BL and A2A (tol=1e-4). +27 tests: 3 atom counts (10, 50, 100) × 9 configs (no-GP, allgather, A2A-spatial, +A2A-index_split at 1/2/4 workers). Compares energy/forces/stress against single-GPU +reference (tol: energy/stress 5e-4, forces 1e-4). Skipped on CI (`CI=true`). ### 1e. 
Predict pipeline + MD consistency (CPU, PR3 branch, ~2 min) diff --git a/tests/core/units/mlip_unit/test_predict.py b/tests/core/units/mlip_unit/test_predict.py index bf96de7921..3bc35e8a41 100644 --- a/tests/core/units/mlip_unit/test_predict.py +++ b/tests/core/units/mlip_unit/test_predict.py @@ -9,6 +9,7 @@ import contextlib import logging +import os from copy import deepcopy import numpy as np @@ -1860,3 +1861,97 @@ def test_execution_mode_not_set_when_conditions_not_met(model_name): f"Expected execution_mode to be None when activation_checkpointing=True, " f"got {predict_unit.inference_settings.execution_mode}" ) + + +# ========================================================================= +# Full-model GP correctness: no-GP vs all-gather vs A2A +# Skipped on CI (no multi-GPU), run locally or via SLURM +# ========================================================================= + +_skip_if_ci = pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="Multi-GPU test, skipped in CI", +) + + +@_skip_if_ci +@pytest.mark.gpu() +@pytest.mark.parametrize("num_atoms", [10, 50, 100]) +@pytest.mark.parametrize( + "workers, gp_mode", + [ + # All-gather (default GP) + (1, None), + (2, None), + (4, None), + # A2A + spatial + (1, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + (2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + (4, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + # A2A + index_split + (1, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), + (2, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), + (4, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), + ], +) +def test_full_model_gp_correctness(num_atoms, workers, gp_mode): + seed = 42 + model_path = pretrained_checkpoint_path_from_name("uma-s-1p1") + ifsets = InferenceSettings( + tf32=False, + merge_mole=True, + activation_checkpointing=False, + internal_graph_gen_version=2, + 
external_graph_gen=False, + ) + atoms = get_fcc_crystal_by_num_atoms(num_atoms) + atomic_data = AtomicData.from_ase(atoms, task_name=["omat"]) + + overrides = None + if gp_mode is not None: + overrides = {"backbone": gp_mode} + + seed_everywhere(seed) + ppunit = ParallelMLIPPredictUnit( + inference_model_path=model_path, + device="cuda", + inference_settings=ifsets, + num_workers=workers, + overrides=overrides, + ) + pp_results = ppunit.predict(atomic_data) + distutils.cleanup_gp_ray() + + seed_everywhere(seed) + ref_unit = pretrained_mlip.get_predict_unit( + "uma-s-1p1", device="cuda", inference_settings=ifsets + ) + ref_results = ref_unit.predict(atomic_data) + + assert torch.allclose( + pp_results["energy"].detach().cpu(), + ref_results["energy"].detach().cpu(), + atol=ATOL, + ), ( + f"Energy mismatch: workers={workers}, gp_mode={gp_mode}, " + f"num_atoms={num_atoms}, " + f"pp={pp_results['energy'].item():.6f}, " + f"ref={ref_results['energy'].item():.6f}" + ) + assert torch.allclose( + pp_results["forces"].detach().cpu(), + ref_results["forces"].detach().cpu(), + atol=FORCE_TOL, + ), ( + f"Forces mismatch: workers={workers}, gp_mode={gp_mode}, " + f"num_atoms={num_atoms}, " + f"max_diff={torch.max(torch.abs(pp_results['forces'].detach().cpu() - ref_results['forces'].detach().cpu())).item():.6e}" + ) + assert torch.allclose( + pp_results["stress"].detach().cpu(), + ref_results["stress"].detach().cpu(), + atol=ATOL, + ), ( + f"Stress mismatch: workers={workers}, gp_mode={gp_mode}, " + f"num_atoms={num_atoms}" + ) From 6168d23c2c843670e3ca4e07f0b8170c12d59420 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 16:54:49 +0000 Subject: [PATCH 18/26] Add GP correctness runner for fairchem CLI --- configs/uma/correctness/gp-correctness.yaml | 29 +++ configs/uma/correctness/job/local.yaml | 5 + configs/uma/correctness/job/local_8gpu.yaml | 6 + configs/uma/correctness/job/slurm.yaml | 11 + .../parallelism/graph_parallel_verfication.md | 21 +- 
.../benchmark/_single/gp_correctness.py | 200 ++++++++++++++++++ 6 files changed, 267 insertions(+), 5 deletions(-) create mode 100644 configs/uma/correctness/gp-correctness.yaml create mode 100644 configs/uma/correctness/job/local.yaml create mode 100644 configs/uma/correctness/job/local_8gpu.yaml create mode 100644 configs/uma/correctness/job/slurm.yaml create mode 100644 src/fairchem/core/components/benchmark/_single/gp_correctness.py diff --git a/configs/uma/correctness/gp-correctness.yaml b/configs/uma/correctness/gp-correctness.yaml new file mode 100644 index 0000000000..98affa0c34 --- /dev/null +++ b/configs/uma/correctness/gp-correctness.yaml @@ -0,0 +1,29 @@ +# GP correctness test: compares single-GPU reference against +# all-gather, A2A-spatial, and A2A-index_split modes. +# +# Usage: +# fairchem -c configs/uma/correctness/gp-correctness.yaml # 1 GPU +# fairchem -c configs/uma/correctness/gp-correctness.yaml job=local_8gpu # 8 GPUs + +defaults: + - job: local + - _self_ + +uma_s_1p2: + _target_: fairchem.core.calculate.pretrained_mlip.pretrained_checkpoint_path_from_name + model_name: "uma-s-1p2" + +runner: + _target_: fairchem.core.components.benchmark._single.gp_correctness.GPCorrectnessRunner + model_checkpoint: ${uma_s_1p2} + device: "cuda" + natoms_list: [10, 50, 100] + seed: 42 + inference_settings: + _target_: fairchem.core.units.mlip_unit.api.inference.InferenceSettings + tf32: False + activation_checkpointing: False + merge_mole: True + compile: False + external_graph_gen: False + internal_graph_gen_version: 2 diff --git a/configs/uma/correctness/job/local.yaml b/configs/uma/correctness/job/local.yaml new file mode 100644 index 0000000000..83ef20ef8f --- /dev/null +++ b/configs/uma/correctness/job/local.yaml @@ -0,0 +1,5 @@ +device_type: CUDA +scheduler: + mode: LOCAL + ranks_per_node: 1 +run_dir: /checkpoint/ocp/rgao/correctness diff --git a/configs/uma/correctness/job/local_8gpu.yaml b/configs/uma/correctness/job/local_8gpu.yaml new file mode 
100644 index 0000000000..d3c3cf6fda --- /dev/null +++ b/configs/uma/correctness/job/local_8gpu.yaml @@ -0,0 +1,6 @@ +device_type: CUDA +scheduler: + mode: LOCAL + ranks_per_node: 8 +graph_parallel_group_size: 8 +run_dir: /checkpoint/ocp/rgao/correctness diff --git a/configs/uma/correctness/job/slurm.yaml b/configs/uma/correctness/job/slurm.yaml new file mode 100644 index 0000000000..48fb091775 --- /dev/null +++ b/configs/uma/correctness/job/slurm.yaml @@ -0,0 +1,11 @@ +run_dir: /checkpoint/ocp/rgao/correctness +scheduler: + mode: SLURM + num_nodes: 1 + ranks_per_node: 8 + slurm: + account: ocp + qos: h200_dev + mem_gb: 0 + cpus_per_task: 24 +graph_parallel_group_size: 8 diff --git a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md index 78c244db90..b523330136 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md +++ b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md @@ -38,15 +38,26 @@ pytest tests/core/common/parallelism/test_a2a_correctness.py -v 8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2, 1536-dim embeddings × 2. -### 1d. Full-model GPU correctness (4+ GPUs, ~5 min) +### 1d. Full-model GPU correctness (8 GPUs, ~5 min) ```bash -pytest tests/core/units/mlip_unit/test_predict.py::test_full_model_gp_correctness -v +# Local 8-GPU +fairchem -c configs/uma/correctness/gp-correctness.yaml job=local_8gpu + +# SLURM +fairchem -c configs/uma/correctness/gp-correctness.yaml job=slurm ``` -27 tests: 3 atom counts (10, 50, 100) × 9 configs (no-GP, allgather, A2A-spatial, -A2A-index_split at 1/2/4 workers). Compares energy/forces/stress against single-GPU -reference (tol: energy/stress 5e-4, forces 1e-4). Skipped on CI (`CI=true`). +9 checks: 3 atom counts (10, 50, 100) × 3 GP modes (allgather, A2A-spatial, +A2A-index_split). Each mode compared against single-GPU no-GP reference. 
+Compares energy/forces/stress (tol: energy/stress 5e-4, forces 1e-4). +Results written to `correctness_results.json` in the run directory. + +Also available as a pytest (skipped on CI): + +```bash +pytest tests/core/units/mlip_unit/test_predict.py::test_full_model_gp_correctness -v +``` ### 1e. Predict pipeline + MD consistency (CPU, PR3 branch, ~2 min) diff --git a/src/fairchem/core/components/benchmark/_single/gp_correctness.py b/src/fairchem/core/components/benchmark/_single/gp_correctness.py new file mode 100644 index 0000000000..00a78ced78 --- /dev/null +++ b/src/fairchem/core/components/benchmark/_single/gp_correctness.py @@ -0,0 +1,200 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. + +GP correctness runner: compares single-GPU reference against +all-gather and all-to-all graph parallel modes. + +Usage: + fairchem -c configs/uma/correctness/gp-correctness.yaml job=local_8gpu +""" + +from __future__ import annotations + +import json +import logging +import os +import random + +import numpy as np +import torch + +from fairchem.core.common import distutils +from fairchem.core.components.runner import Runner +from fairchem.core.datasets.atomic_data import AtomicData +from fairchem.core.datasets.common_structures import get_fcc_crystal_by_num_atoms +from fairchem.core.units.mlip_unit import MLIPPredictUnit +from fairchem.core.units.mlip_unit.api.inference import ( + InferenceSettings, + inference_settings_default, +) + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +ENERGY_TOL = 5e-4 +FORCE_TOL = 1e-4 +STRESS_TOL = 5e-4 + + +def _seed_everywhere(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +class GPCorrectnessRunner(Runner): + def __init__( + self, + model_checkpoint: 
str, + natoms_list: list[int] | None = None, + seed: int = 42, + device: str = "cuda", + inference_settings: InferenceSettings = inference_settings_default(), # noqa: B008 + overrides: dict | None = None, + ): + self.model_checkpoint = model_checkpoint + self.natoms_list = natoms_list or [10, 50, 100] + self.seed = seed + self.device = device + self.inference_settings = inference_settings + self.overrides = overrides or {} + + def run(self) -> None: + run_dir = self.job_config.metadata.results_dir + os.makedirs(run_dir, exist_ok=True) + + rank = distutils.get_rank() + world_size = distutils.get_world_size() + + gp_modes = [ + ("allgather", {}), + ( + "a2a_spatial", + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, + ), + ( + "a2a_index_split", + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, + ), + ] + + results = [] + + for num_atoms in self.natoms_list: + _seed_everywhere(self.seed) + atoms = get_fcc_crystal_by_num_atoms(num_atoms) + data = AtomicData.from_ase(atoms, task_name="omat") + + # Reference: single-rank prediction (no GP overrides) + _seed_everywhere(self.seed) + ref_predictor = MLIPPredictUnit( + self.model_checkpoint, + self.device, + overrides=self.overrides, + inference_settings=self.inference_settings, + ) + ref_result = ref_predictor.predict(data) + ref_energy = ref_result["energy"].detach().cpu() + ref_forces = ref_result["forces"].detach().cpu() + ref_stress = ref_result["stress"].detach().cpu() + del ref_predictor + torch.cuda.empty_cache() + + for mode_name, gp_overrides in gp_modes: + merged = {**self.overrides} + if "backbone" not in merged: + merged["backbone"] = {} + merged["backbone"].update(gp_overrides) + + _seed_everywhere(self.seed) + predictor = MLIPPredictUnit( + self.model_checkpoint, + self.device, + overrides=merged, + inference_settings=self.inference_settings, + ) + gp_result = predictor.predict(data) + gp_energy = gp_result["energy"].detach().cpu() + gp_forces = 
gp_result["forces"].detach().cpu() + gp_stress = gp_result["stress"].detach().cpu() + del predictor + torch.cuda.empty_cache() + + energy_diff = torch.abs(gp_energy - ref_energy).max().item() + forces_diff = torch.abs(gp_forces - ref_forces).max().item() + stress_diff = torch.abs(gp_stress - ref_stress).max().item() + + energy_ok = energy_diff < ENERGY_TOL + forces_ok = forces_diff < FORCE_TOL + stress_ok = stress_diff < STRESS_TOL + all_ok = energy_ok and forces_ok and stress_ok + + result = { + "num_atoms": num_atoms, + "world_size": world_size, + "mode": mode_name, + "energy_diff": energy_diff, + "forces_diff": forces_diff, + "stress_diff": stress_diff, + "energy_ok": energy_ok, + "forces_ok": forces_ok, + "stress_ok": stress_ok, + "pass": all_ok, + } + results.append(result) + + status = "PASS" if all_ok else "FAIL" + logger.info( + f"[{status}] {mode_name} | {num_atoms} atoms | " + f"ws={world_size} | " + f"energy_diff={energy_diff:.2e} | " + f"forces_diff={forces_diff:.2e} | " + f"stress_diff={stress_diff:.2e}" + ) + + # Write results JSON (rank 0 only) + if rank == 0: + output = { + "world_size": world_size, + "model": self.model_checkpoint, + "tolerances": { + "energy": ENERGY_TOL, + "forces": FORCE_TOL, + "stress": STRESS_TOL, + }, + "results": results, + } + results_path = os.path.join(run_dir, "correctness_results.json") + with open(results_path, "w") as f: + json.dump(output, f, indent=2) + logger.info(f"Results written to {results_path}") + + # Summary + n_pass = sum(1 for r in results if r["pass"]) + n_total = len(results) + logger.info(f"Summary: {n_pass}/{n_total} passed") + + if n_pass < n_total: + failures = [r for r in results if not r["pass"]] + for f in failures: + logger.error( + f"FAILED: {f['mode']} | {f['num_atoms']} atoms | " + f"energy={f['energy_diff']:.2e} " + f"forces={f['forces_diff']:.2e} " + f"stress={f['stress_diff']:.2e}" + ) + raise AssertionError( + f"{n_total - n_pass}/{n_total} GP correctness checks failed" + ) + + def 
save_state(self, checkpoint_location: str, is_preemption: bool = False) -> bool: + return True + + def load_state(self, checkpoint_location: str | None) -> None: + pass From 1a87db076b76b2f3c963b49069117dd852977adb Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 16:56:08 +0000 Subject: [PATCH 19/26] Revert "Add GP correctness runner for fairchem CLI" This reverts commit 6168d23c2c843670e3ca4e07f0b8170c12d59420. --- configs/uma/correctness/gp-correctness.yaml | 29 --- configs/uma/correctness/job/local.yaml | 5 - configs/uma/correctness/job/local_8gpu.yaml | 6 - configs/uma/correctness/job/slurm.yaml | 11 - .../parallelism/graph_parallel_verfication.md | 21 +- .../benchmark/_single/gp_correctness.py | 200 ------------------ 6 files changed, 5 insertions(+), 267 deletions(-) delete mode 100644 configs/uma/correctness/gp-correctness.yaml delete mode 100644 configs/uma/correctness/job/local.yaml delete mode 100644 configs/uma/correctness/job/local_8gpu.yaml delete mode 100644 configs/uma/correctness/job/slurm.yaml delete mode 100644 src/fairchem/core/components/benchmark/_single/gp_correctness.py diff --git a/configs/uma/correctness/gp-correctness.yaml b/configs/uma/correctness/gp-correctness.yaml deleted file mode 100644 index 98affa0c34..0000000000 --- a/configs/uma/correctness/gp-correctness.yaml +++ /dev/null @@ -1,29 +0,0 @@ -# GP correctness test: compares single-GPU reference against -# all-gather, A2A-spatial, and A2A-index_split modes. 
-# -# Usage: -# fairchem -c configs/uma/correctness/gp-correctness.yaml # 1 GPU -# fairchem -c configs/uma/correctness/gp-correctness.yaml job=local_8gpu # 8 GPUs - -defaults: - - job: local - - _self_ - -uma_s_1p2: - _target_: fairchem.core.calculate.pretrained_mlip.pretrained_checkpoint_path_from_name - model_name: "uma-s-1p2" - -runner: - _target_: fairchem.core.components.benchmark._single.gp_correctness.GPCorrectnessRunner - model_checkpoint: ${uma_s_1p2} - device: "cuda" - natoms_list: [10, 50, 100] - seed: 42 - inference_settings: - _target_: fairchem.core.units.mlip_unit.api.inference.InferenceSettings - tf32: False - activation_checkpointing: False - merge_mole: True - compile: False - external_graph_gen: False - internal_graph_gen_version: 2 diff --git a/configs/uma/correctness/job/local.yaml b/configs/uma/correctness/job/local.yaml deleted file mode 100644 index 83ef20ef8f..0000000000 --- a/configs/uma/correctness/job/local.yaml +++ /dev/null @@ -1,5 +0,0 @@ -device_type: CUDA -scheduler: - mode: LOCAL - ranks_per_node: 1 -run_dir: /checkpoint/ocp/rgao/correctness diff --git a/configs/uma/correctness/job/local_8gpu.yaml b/configs/uma/correctness/job/local_8gpu.yaml deleted file mode 100644 index d3c3cf6fda..0000000000 --- a/configs/uma/correctness/job/local_8gpu.yaml +++ /dev/null @@ -1,6 +0,0 @@ -device_type: CUDA -scheduler: - mode: LOCAL - ranks_per_node: 8 -graph_parallel_group_size: 8 -run_dir: /checkpoint/ocp/rgao/correctness diff --git a/configs/uma/correctness/job/slurm.yaml b/configs/uma/correctness/job/slurm.yaml deleted file mode 100644 index 48fb091775..0000000000 --- a/configs/uma/correctness/job/slurm.yaml +++ /dev/null @@ -1,11 +0,0 @@ -run_dir: /checkpoint/ocp/rgao/correctness -scheduler: - mode: SLURM - num_nodes: 1 - ranks_per_node: 8 - slurm: - account: ocp - qos: h200_dev - mem_gb: 0 - cpus_per_task: 24 -graph_parallel_group_size: 8 diff --git a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md 
b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md index b523330136..78c244db90 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md +++ b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md @@ -38,27 +38,16 @@ pytest tests/core/common/parallelism/test_a2a_correctness.py -v 8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2, 1536-dim embeddings × 2. -### 1d. Full-model GPU correctness (8 GPUs, ~5 min) - -```bash -# Local 8-GPU -fairchem -c configs/uma/correctness/gp-correctness.yaml job=local_8gpu - -# SLURM -fairchem -c configs/uma/correctness/gp-correctness.yaml job=slurm -``` - -9 checks: 3 atom counts (10, 50, 100) × 3 GP modes (allgather, A2A-spatial, -A2A-index_split). Each mode compared against single-GPU no-GP reference. -Compares energy/forces/stress (tol: energy/stress 5e-4, forces 1e-4). -Results written to `correctness_results.json` in the run directory. - -Also available as a pytest (skipped on CI): +### 1d. Full-model GPU correctness (4+ GPUs, ~5 min) ```bash pytest tests/core/units/mlip_unit/test_predict.py::test_full_model_gp_correctness -v ``` +27 tests: 3 atom counts (10, 50, 100) × 9 configs (no-GP, allgather, A2A-spatial, +A2A-index_split at 1/2/4 workers). Compares energy/forces/stress against single-GPU +reference (tol: energy/stress 5e-4, forces 1e-4). Skipped on CI (`CI=true`). + ### 1e. Predict pipeline + MD consistency (CPU, PR3 branch, ~2 min) ```bash diff --git a/src/fairchem/core/components/benchmark/_single/gp_correctness.py b/src/fairchem/core/components/benchmark/_single/gp_correctness.py deleted file mode 100644 index 00a78ced78..0000000000 --- a/src/fairchem/core/components/benchmark/_single/gp_correctness.py +++ /dev/null @@ -1,200 +0,0 @@ -""" -Copyright (c) Meta Platforms, Inc. and affiliates. - -This source code is licensed under the MIT license found in the -LICENSE file in the root directory of this source tree. 
- -GP correctness runner: compares single-GPU reference against -all-gather and all-to-all graph parallel modes. - -Usage: - fairchem -c configs/uma/correctness/gp-correctness.yaml job=local_8gpu -""" - -from __future__ import annotations - -import json -import logging -import os -import random - -import numpy as np -import torch - -from fairchem.core.common import distutils -from fairchem.core.components.runner import Runner -from fairchem.core.datasets.atomic_data import AtomicData -from fairchem.core.datasets.common_structures import get_fcc_crystal_by_num_atoms -from fairchem.core.units.mlip_unit import MLIPPredictUnit -from fairchem.core.units.mlip_unit.api.inference import ( - InferenceSettings, - inference_settings_default, -) - -logging.basicConfig( - level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" -) -logger = logging.getLogger(__name__) - -ENERGY_TOL = 5e-4 -FORCE_TOL = 1e-4 -STRESS_TOL = 5e-4 - - -def _seed_everywhere(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -class GPCorrectnessRunner(Runner): - def __init__( - self, - model_checkpoint: str, - natoms_list: list[int] | None = None, - seed: int = 42, - device: str = "cuda", - inference_settings: InferenceSettings = inference_settings_default(), # noqa: B008 - overrides: dict | None = None, - ): - self.model_checkpoint = model_checkpoint - self.natoms_list = natoms_list or [10, 50, 100] - self.seed = seed - self.device = device - self.inference_settings = inference_settings - self.overrides = overrides or {} - - def run(self) -> None: - run_dir = self.job_config.metadata.results_dir - os.makedirs(run_dir, exist_ok=True) - - rank = distutils.get_rank() - world_size = distutils.get_world_size() - - gp_modes = [ - ("allgather", {}), - ( - "a2a_spatial", - {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, - ), - ( - "a2a_index_split", - {"use_all_to_all_gp": True, "gp_partition_strategy": 
"index_split"}, - ), - ] - - results = [] - - for num_atoms in self.natoms_list: - _seed_everywhere(self.seed) - atoms = get_fcc_crystal_by_num_atoms(num_atoms) - data = AtomicData.from_ase(atoms, task_name="omat") - - # Reference: single-rank prediction (no GP overrides) - _seed_everywhere(self.seed) - ref_predictor = MLIPPredictUnit( - self.model_checkpoint, - self.device, - overrides=self.overrides, - inference_settings=self.inference_settings, - ) - ref_result = ref_predictor.predict(data) - ref_energy = ref_result["energy"].detach().cpu() - ref_forces = ref_result["forces"].detach().cpu() - ref_stress = ref_result["stress"].detach().cpu() - del ref_predictor - torch.cuda.empty_cache() - - for mode_name, gp_overrides in gp_modes: - merged = {**self.overrides} - if "backbone" not in merged: - merged["backbone"] = {} - merged["backbone"].update(gp_overrides) - - _seed_everywhere(self.seed) - predictor = MLIPPredictUnit( - self.model_checkpoint, - self.device, - overrides=merged, - inference_settings=self.inference_settings, - ) - gp_result = predictor.predict(data) - gp_energy = gp_result["energy"].detach().cpu() - gp_forces = gp_result["forces"].detach().cpu() - gp_stress = gp_result["stress"].detach().cpu() - del predictor - torch.cuda.empty_cache() - - energy_diff = torch.abs(gp_energy - ref_energy).max().item() - forces_diff = torch.abs(gp_forces - ref_forces).max().item() - stress_diff = torch.abs(gp_stress - ref_stress).max().item() - - energy_ok = energy_diff < ENERGY_TOL - forces_ok = forces_diff < FORCE_TOL - stress_ok = stress_diff < STRESS_TOL - all_ok = energy_ok and forces_ok and stress_ok - - result = { - "num_atoms": num_atoms, - "world_size": world_size, - "mode": mode_name, - "energy_diff": energy_diff, - "forces_diff": forces_diff, - "stress_diff": stress_diff, - "energy_ok": energy_ok, - "forces_ok": forces_ok, - "stress_ok": stress_ok, - "pass": all_ok, - } - results.append(result) - - status = "PASS" if all_ok else "FAIL" - logger.info( - 
f"[{status}] {mode_name} | {num_atoms} atoms | " - f"ws={world_size} | " - f"energy_diff={energy_diff:.2e} | " - f"forces_diff={forces_diff:.2e} | " - f"stress_diff={stress_diff:.2e}" - ) - - # Write results JSON (rank 0 only) - if rank == 0: - output = { - "world_size": world_size, - "model": self.model_checkpoint, - "tolerances": { - "energy": ENERGY_TOL, - "forces": FORCE_TOL, - "stress": STRESS_TOL, - }, - "results": results, - } - results_path = os.path.join(run_dir, "correctness_results.json") - with open(results_path, "w") as f: - json.dump(output, f, indent=2) - logger.info(f"Results written to {results_path}") - - # Summary - n_pass = sum(1 for r in results if r["pass"]) - n_total = len(results) - logger.info(f"Summary: {n_pass}/{n_total} passed") - - if n_pass < n_total: - failures = [r for r in results if not r["pass"]] - for f in failures: - logger.error( - f"FAILED: {f['mode']} | {f['num_atoms']} atoms | " - f"energy={f['energy_diff']:.2e} " - f"forces={f['forces_diff']:.2e} " - f"stress={f['stress_diff']:.2e}" - ) - raise AssertionError( - f"{n_total - n_pass}/{n_total} GP correctness checks failed" - ) - - def save_state(self, checkpoint_location: str, is_preemption: bool = False) -> bool: - return True - - def load_state(self, checkpoint_location: str | None) -> None: - pass From cc71ad70f8c0c4d05ff7eb197ad420d961aff654 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 20:55:55 +0000 Subject: [PATCH 20/26] Simplify build_gp_context: use remote_mask, rename needed_from_ranks_sorted --- .../core/common/parallelism/graph_parallel_a2a.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index 9cf266aa10..da70784e67 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -233,11 +233,11 @@ def build_gp_context( 
total_local_atoms = node_partition.shape[0] # Find remote atoms needed as edge sources. - local_mask = rank_assignments == rank - src_is_remote = ~local_mask[edge_index[0]] + remote_mask = rank_assignments != rank + src_is_remote = remote_mask[edge_index[0]] needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) needed_mask[edge_index[0, src_is_remote]] = True - needed_mask &= ~local_mask + needed_mask &= remote_mask needed_atoms = needed_mask.nonzero(as_tuple=True)[0] total_needed_atoms = needed_atoms.shape[0] @@ -254,7 +254,7 @@ def build_gp_context( # Sort needed_atoms by source rank to match A2A recv_buf ordering. sort_order = needed_from_ranks.argsort(stable=True) needed_atoms = needed_atoms[sort_order] - needed_from_ranks = needed_from_ranks[sort_order] + needed_from_ranks_sorted = needed_from_ranks[sort_order] # Exchange send metadata via collective. with record_function("a2a_sparse_index_exchange"): @@ -319,7 +319,7 @@ def build_gp_context( node_partition=node_partition, rank_assignments=rank_assignments, needed_atoms=needed_atoms, - needed_from_ranks=needed_from_ranks, + needed_from_ranks=needed_from_ranks_sorted, send_counts=send_counts, recv_counts=recv_counts, global_to_local=global_to_local, From 0467a4e673c9bc41d82c3acdb1a6bd56a589928e Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 20:58:47 +0000 Subject: [PATCH 21/26] Remove redundant remote_mask filter in build_gp_context --- src/fairchem/core/common/parallelism/graph_parallel_a2a.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index da70784e67..ce393dab16 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -237,7 +237,6 @@ def build_gp_context( src_is_remote = remote_mask[edge_index[0]] needed_mask = torch.zeros(total_atoms, dtype=torch.bool, device=device) 
needed_mask[edge_index[0, src_is_remote]] = True - needed_mask &= remote_mask needed_atoms = needed_mask.nonzero(as_tuple=True)[0] total_needed_atoms = needed_atoms.shape[0] From 88c776e66187b0f32e036ba489dec32f9cdbc36f Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 21:01:38 +0000 Subject: [PATCH 22/26] Update GPContext docstring with missing field descriptions --- .../core/common/parallelism/graph_parallel_a2a.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index ce393dab16..7f540f74a2 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -71,8 +71,10 @@ class GPContext: rank_assignments: For every atom in the global graph, which rank owns it. Shape: (total_atoms,), dtype: int. needed_atoms: Global indices of non-local atoms this rank needs - (sources of edges whose targets are in this rank's partition). - needed_from_ranks: For each atom in needed_atoms, which rank owns it. + (sources of edges whose targets are in this rank's partition), + sorted by source rank. + needed_from_ranks: For each atom in needed_atoms, which rank owns it + (sorted to match needed_atoms ordering). send_counts: Number of atoms to send to each rank. Shape: (world_size,). recv_counts: Number of atoms to receive from each rank. Shape: (world_size,). @@ -84,6 +86,11 @@ class GPContext: send_indices: Local indices of atoms to send, ordered by destination rank. edge_index_local: Edge index remapped to local indices. + send_splits: Per-rank split sizes for the embedding send buffer. + recv_splits: Per-rank split sizes for the embedding recv buffer. + total_recv: Total number of embeddings to receive (sum of recv_splits). + local_edge_idx: Indices into edge_index_local where source is a local atom. 
+ remote_edge_idx: Indices into edge_index_local where source is a remote atom. """ rank: int From 6dc97085969dee89ee6a1b444082a62f69974f67 Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 22:56:57 +0000 Subject: [PATCH 23/26] Slim GPContext to runtime-only fields Remove 6 construction-only fields from GPContext (node_partition, rank_assignments, needed_atoms, needed_from_ranks, global_to_local, total_needed_atoms) that are never read after build_gp_context returns. Update tests to verify through edge_index_local and message passing instead of accessing removed intermediates. --- .../common/parallelism/graph_parallel_a2a.py | 47 +++--------- .../parallelism/test_a2a_correctness.py | 72 ++++++++----------- .../common/parallelism/test_graph_parallel.py | 49 +++++++------ 3 files changed, 70 insertions(+), 98 deletions(-) diff --git a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py index 7f540f74a2..8dfd151f20 100644 --- a/src/fairchem/core/common/parallelism/graph_parallel_a2a.py +++ b/src/fairchem/core/common/parallelism/graph_parallel_a2a.py @@ -58,33 +58,21 @@ def _safe_all_to_all( @dataclass class GPContext: """ - Graph parallel context holding per-rank atom assignments - and communication metadata for all-to-all. + Graph parallel context holding communication metadata for all-to-all. - This replaces the all-gather approach by tracking which atoms - each rank needs from other ranks for its local edge computations. + Runtime-only struct: every field is needed for the forward/backward pass. + Construction intermediates (node_partition, rank_assignments, needed_atoms, + global_to_local, etc.) are computed in build_gp_context but not stored. Attributes: rank: Current GP rank. world_size: Number of GP ranks. - node_partition: Global indices of atoms owned by this rank. - rank_assignments: For every atom in the global graph, which rank owns it. - Shape: (total_atoms,), dtype: int. 
- needed_atoms: Global indices of non-local atoms this rank needs - (sources of edges whose targets are in this rank's partition), - sorted by source rank. - needed_from_ranks: For each atom in needed_atoms, which rank owns it - (sorted to match needed_atoms ordering). - send_counts: Number of atoms to send to each rank. Shape: (world_size,). - recv_counts: Number of atoms to receive from each rank. - Shape: (world_size,). - global_to_local: Mapping from global atom index to position in the - local concatenated tensor [local_atoms | received_atoms]. - Shape: (total_atoms,), with -1 for atoms not accessible. total_local_atoms: Number of atoms in this rank's partition. - total_needed_atoms: Total atoms needed from other ranks. send_indices: Local indices of atoms to send, ordered by destination rank. + send_counts: Number of atoms to send to each rank. Shape: (world_size,). + recv_counts: Number of atoms to receive from each rank. + Shape: (world_size,). edge_index_local: Edge index remapped to local indices. send_splits: Per-rank split sizes for the embedding send buffer. recv_splits: Per-rank split sizes for the embedding recv buffer. @@ -95,16 +83,10 @@ class GPContext: rank: int world_size: int - node_partition: torch.Tensor - rank_assignments: torch.Tensor - needed_atoms: torch.Tensor - needed_from_ranks: torch.Tensor - send_counts: torch.Tensor - recv_counts: torch.Tensor - global_to_local: torch.Tensor total_local_atoms: int - total_needed_atoms: int send_indices: torch.Tensor + send_counts: torch.Tensor + recv_counts: torch.Tensor edge_index_local: torch.Tensor send_splits: list[int] recv_splits: list[int] @@ -260,7 +242,6 @@ def build_gp_context( # Sort needed_atoms by source rank to match A2A recv_buf ordering. sort_order = needed_from_ranks.argsort(stable=True) needed_atoms = needed_atoms[sort_order] - needed_from_ranks_sorted = needed_from_ranks[sort_order] # Exchange send metadata via collective. 
with record_function("a2a_sparse_index_exchange"): @@ -322,16 +303,10 @@ def build_gp_context( return GPContext( rank=rank, world_size=world_size, - node_partition=node_partition, - rank_assignments=rank_assignments, - needed_atoms=needed_atoms, - needed_from_ranks=needed_from_ranks_sorted, - send_counts=send_counts, - recv_counts=recv_counts, - global_to_local=global_to_local, total_local_atoms=total_local_atoms, - total_needed_atoms=total_needed_atoms, send_indices=send_indices, + send_counts=send_counts, + recv_counts=recv_counts, edge_index_local=edge_index_local, send_splits=send_splits, recv_splits=recv_splits, diff --git a/tests/core/common/parallelism/test_a2a_correctness.py b/tests/core/common/parallelism/test_a2a_correctness.py index b2fbce17de..3956be90b5 100644 --- a/tests/core/common/parallelism/test_a2a_correctness.py +++ b/tests/core/common/parallelism/test_a2a_correctness.py @@ -15,7 +15,6 @@ from __future__ import annotations -import logging import os import pytest @@ -36,10 +35,6 @@ spawn_multi_process, ) -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - # ========================================================================= # Pytest-compatible distributed tests (CPU, Gloo, 2 processes) # ========================================================================= @@ -98,12 +93,6 @@ def _correctness_test_inner( # Test collect function x_recv_autograd = all_to_all_collect(x_local, gp_ctx, send_indices) - # Verify received values are correct: - # x_recv should contain embeddings of gp_ctx.needed_atoms - # in the correct order (sorted by source rank). 
- expected_values = x_global[gp_ctx.needed_atoms] - recv_correct = torch.allclose(x_recv_autograd, expected_values, atol=1e-6) - # Verify edge_index_local is valid x_full = torch.cat([x_local, x_recv_autograd], dim=0) edge_valid = (gp_ctx.edge_index_local >= 0).all().item() @@ -139,9 +128,7 @@ def _correctness_test_inner( "partition_strategy": partition_strategy, "world_size": world_size, "local_atoms": gp_ctx.total_local_atoms, - "needed_atoms": gp_ctx.total_needed_atoms, "num_edges": rank_edge_index.shape[1], - "recv_correct": recv_correct, "edge_valid": edge_valid, "edge_in_bounds": edge_in_bounds, "mp_match": mp_match, @@ -162,9 +149,8 @@ def test_a2a_correctness_gloo(strategy, num_atoms): Verify A2A correctness at 2 GPUs using Gloo backend. Creates a dense graph (all atoms connected) and verifies that: - 1. Autograd and compiled collect produce identical results - 2. Received embeddings contain correct values - 3. Message passing produces correct aggregation + 1. Edge indices are valid (non-negative, in-bounds) + 2. 
Message passing produces correct aggregation vs reference """ # Create dense graph src, dst = [], [] @@ -191,9 +177,6 @@ def test_a2a_correctness_gloo(strategy, num_atoms): for result in all_rank_results: r = result["rank"] - assert result["recv_correct"], ( - f"Rank {r}: received embeddings don't match " f"expected values" - ) assert result["edge_valid"], f"Rank {r}: edge_index_local has negative entries" assert result[ "edge_in_bounds" @@ -242,9 +225,6 @@ def test_a2a_consistency_across_graph_sizes(strategy): for result in all_rank_results: r = result["rank"] - assert result["recv_correct"], ( - f"Rank {r}: received embeddings don't match " f"expected values" - ) assert result["mp_match"], ( f"Rank {r}: message passing result differs " f"from reference" ) @@ -283,15 +263,34 @@ def _multidim_test_inner(x_global, pos, edge_index, num_atoms, strategy): x_recv = all_to_all_collect(x_local, gp_ctx, send_indices) - # Verify - expected = x_global[gp_ctx.needed_atoms] - recv_correct = torch.allclose(x_recv, expected, atol=1e-6) + # Verify message passing produces the same result as non-distributed. 
+ x_full = torch.cat([x_local, x_recv], dim=0) + x_source = x_full[gp_ctx.edge_index_local[0]] + local_result = torch.zeros( + gp_ctx.total_local_atoms, + x_source.shape[1], + dtype=x_source.dtype, + device=x_source.device, + ) + local_result.index_add_(0, gp_ctx.edge_index_local[1], x_source) + + # Reference: compute the same aggregation on the full graph + x_source_ref = x_global[rank_edge_index[0]] + ref_result = torch.zeros( + num_atoms, + x_source_ref.shape[1], + dtype=x_source_ref.dtype, + device=x_source_ref.device, + ) + ref_result.index_add_(0, rank_edge_index[1], x_source_ref) + ref_local = ref_result[node_partition] + + mp_match = torch.allclose(local_result, ref_local, atol=1e-6) return { "rank": rank, - "recv_correct": recv_correct, + "mp_match": mp_match, "recv_shape": x_recv.shape, - "expected_shape": expected.shape, } @@ -334,10 +333,8 @@ def test_a2a_multidim_embeddings(strategy): for result in all_rank_results: r = result["rank"] - assert result["recv_correct"], ( - f"Rank {r}: multidim recv mismatch, " - f"shape={result['recv_shape']} " - f"vs {result['expected_shape']}" + assert result["mp_match"], ( + f"Rank {r}: multidim mp mismatch, " f"recv_shape={result['recv_shape']}" ) @@ -405,9 +402,6 @@ def test_a2a_correctness_gpu(strategy, num_atoms): for result in all_rank_results: r = result["rank"] - assert result[ - "recv_correct" - ], f"Rank {r}: received embeddings don't match expected values on GPU" assert result[ "edge_valid" ], f"Rank {r}: edge_index_local has negative entries on GPU" @@ -449,9 +443,6 @@ def test_a2a_consistency_across_graph_sizes_gpu(strategy): for result in all_rank_results: r = result["rank"] - assert result[ - "recv_correct" - ], f"Rank {r}: received embeddings don't match expected values on GPU" assert result[ "mp_match" ], f"Rank {r}: message passing result differs from reference on GPU" @@ -489,8 +480,7 @@ def test_a2a_multidim_embeddings_gpu(strategy): for result in all_rank_results: r = result["rank"] - assert 
result["recv_correct"], ( - f"Rank {r}: multidim recv mismatch on GPU, " - f"shape={result['recv_shape']} " - f"vs {result['expected_shape']}" + assert result["mp_match"], ( + f"Rank {r}: multidim mp mismatch on GPU, " + f"recv_shape={result['recv_shape']}" ) diff --git a/tests/core/common/parallelism/test_graph_parallel.py b/tests/core/common/parallelism/test_graph_parallel.py index e7aa1e4fa1..d82d75ca4b 100644 --- a/tests/core/common/parallelism/test_graph_parallel.py +++ b/tests/core/common/parallelism/test_graph_parallel.py @@ -141,32 +141,38 @@ def test_basic_context_building(self): assert ctx.rank == 0 assert ctx.world_size == 2 assert ctx.total_local_atoms == 2 - assert torch.equal(ctx.node_partition, torch.tensor([0, 1])) - # Rank 0 targets: atoms 0 and 1 - # Edge (2, 0): src=2 is remote, tgt=0 is local -> need atom 2 - # Edge (1, 0): src=1 is local -> don't need - # Edge (0, 1): src=0 is local -> don't need - assert 2 in ctx.needed_atoms + # Edge (2, 0): src=2 is remote -> should appear as remote src in edge_index_local + # Verify edge_index_local has some remote sources (index >= total_local_atoms) + has_remote_src = (ctx.edge_index_local[0] >= ctx.total_local_atoms).any() + assert has_remote_src - def test_global_to_local_mapping(self): + def test_edge_index_local_validity(self): """ - Verify that global_to_local correctly maps to local indices. + Verify that edge_index_local indices are valid and correctly + separate local vs remote sources. 
""" + # 4 atoms: rank 0 owns [0,1], rank 1 owns [2,3] + # Full graph edges, then filter to rank 0's targets edge_index = torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) rank_assignments = torch.tensor([0, 0, 1, 1]) - ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2) + # Filter to edges where target belongs to rank 0 + target_mask = (rank_assignments == 0)[edge_index[1]] + rank_edge_index = edge_index[:, target_mask] + + ctx = build_gp_context(rank_edge_index, rank_assignments, rank=0, world_size=2) + + # All indices should be non-negative + assert (ctx.edge_index_local >= 0).all() - # Local atoms [0, 1] should map to indices [0, 1] - assert ctx.global_to_local[0] == 0 - assert ctx.global_to_local[1] == 1 + # All indices should be in bounds + total = ctx.total_local_atoms + ctx.total_recv + assert (ctx.edge_index_local < total).all() - # Remote atoms that are needed should map to - # indices >= total_local_atoms - for atom in ctx.needed_atoms: - local_idx = ctx.global_to_local[atom].item() - assert local_idx >= ctx.total_local_atoms + # All targets should be local (we filtered to rank 0 targets) + n_local = ctx.total_local_atoms + assert (ctx.edge_index_local[1] < n_local).all() def test_no_cross_partition_edges(self): """ @@ -176,8 +182,9 @@ def test_no_cross_partition_edges(self): rank_assignments = torch.tensor([0, 0, 1, 1]) ctx = build_gp_context(edge_index, rank_assignments, rank=0, world_size=2) - assert ctx.total_needed_atoms == 0 - assert ctx.needed_atoms.numel() == 0 + # All sources should be local (no remote atoms needed) + assert (ctx.edge_index_local[0] < ctx.total_local_atoms).all() + assert ctx.recv_counts.sum() == 0 def test_edge_split_indices(self): """ @@ -252,8 +259,8 @@ def _a2a_simple_layer(x, edge_index, rank_assignments, natoms): # Combine local + received x_full = torch.cat([x, x_received], dim=0) - # Remap edges to local space (inline, no helper function needed) - edge_index_local = gp_ctx.global_to_local[edge_index] + # 
Use precomputed local edge index + edge_index_local = gp_ctx.edge_index_local # Simple message passing: source embeddings aggregated to targets x_source = x_full[edge_index_local[0]] From fb7ed9c93238e1b1eb6f8c6110a581382e2c957c Mon Sep 17 00:00:00 2001 From: rgao user Date: Fri, 5 Jun 2026 23:19:45 +0000 Subject: [PATCH 24/26] Increase Gloo process group init timeout from 10s to 120s The 10s timeout causes flaky failures in CI where runners are resource-constrained and process rendezvous takes longer. --- src/fairchem/core/common/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fairchem/core/common/test_utils.py b/src/fairchem/core/common/test_utils.py index b8acb93ec4..0182b515cf 100644 --- a/src/fairchem/core/common/test_utils.py +++ b/src/fairchem/core/common/test_utils.py @@ -86,7 +86,7 @@ def init_pg_and_rank_and_launch_test( dist.init_process_group( init_method=init_method, backend=pg_setup_params.backend, - timeout=timedelta(seconds=10), + timeout=timedelta(seconds=120), ) # setup gp if pg_setup_params.use_gp: From 3e8059b92dec17cf9d2909ba1da5459d5ddbbab4 Mon Sep 17 00:00:00 2001 From: rgao user Date: Sun, 7 Jun 2026 03:00:43 +0000 Subject: [PATCH 25/26] Remove AABB halo optimization from A2A graph parallel Benchmarks at 64 GPUs showed halo adds ~5% overhead vs plain A2A. The _sparse_index_exchange already minimizes communication, and radius_graph_pbc_v2 grid indexing makes graph gen fast enough that filtering to a halo subset adds cost without measurable savings. Removes ~150 lines: _compute_aabb_halo, _compute_halo_graph, use_aabb_halo flag, and the halo call site in _generate_graph. Also adds timeout_hr to slurm benchmark config. 
--- configs/uma/speed/job/slurm.yaml | 1 + src/fairchem/core/models/uma/escn_md.py | 205 ++---------------------- 2 files changed, 15 insertions(+), 191 deletions(-) diff --git a/configs/uma/speed/job/slurm.yaml b/configs/uma/speed/job/slurm.yaml index 0006ba20b3..9088fe575f 100644 --- a/configs/uma/speed/job/slurm.yaml +++ b/configs/uma/speed/job/slurm.yaml @@ -10,4 +10,5 @@ scheduler: qos: h200_ocp_high mem_gb: 0 cpus_per_task: 24 + timeout_hr: 1 graph_parallel_group_size: 64 diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index db8daa995f..ac3c4ac875 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -9,7 +9,6 @@ import logging import os -import types from dataclasses import dataclass from typing import TYPE_CHECKING, Literal @@ -646,150 +645,6 @@ def _compute_a2a_partition( return rank_assignments, node_partition - @torch.compiler.disable - def _compute_aabb_halo( - self, - pos: torch.Tensor, - node_partition: torch.Tensor, - pbc: torch.Tensor, - cell: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute AABB halo mask for graph generation filtering. - - Builds an axis-aligned bounding box (AABB) around the local - partition's atoms, expanded by the interaction cutoff, then - identifies ALL atoms (including PBC images) that fall within - this box. Graph generation then operates on ~N_halo atoms - instead of N_total, significantly reducing cost at scale. - - No NCCL communication — purely local computation. - - Note: Only supports single-system inputs (not batched). - For multi-system batches, the caller should skip halo - filtering and fall back to full graph generation. - - Returns: - Tuple of (halo_mask, shift_vecs) where: - - halo_mask: Boolean mask over all atoms in the AABB halo. - - shift_vecs: PBC shift vectors (n_shifts, 3). 
- """ - device = pos.device - n_total = len(pos) - - # --- Compute our AABB (expanded by cutoff) --- - local_pos = pos[node_partition] - lo = local_pos.min(dim=0)[0] - self.cutoff - hi = local_pos.max(dim=0)[0] + self.cutoff - - cell_sq = cell.view(3, 3) if cell.dim() == 3 else cell - pbc_flat = pbc.view(3) if pbc.dim() == 2 else pbc - - # Build shift vectors for periodic images (up to 27). - shift_components = [] - for d in range(3): - if pbc_flat[d]: - shift_components.append(torch.tensor([-1, 0, 1], device=device)) - else: - shift_components.append(torch.tensor([0], device=device)) - grid = torch.cartesian_prod(*shift_components) - shift_vecs = grid.float() @ cell_sq - - # --- Compute our halo mask (atoms we need) --- - halo_mask = torch.zeros(n_total, dtype=torch.bool, device=device) - for shift in shift_vecs: - shifted = pos + shift - in_box = ((shifted >= lo) & (shifted <= hi)).all(dim=-1) - halo_mask |= in_box - - return halo_mask, shift_vecs - - @torch.compiler.disable - def _compute_halo_graph( - self, - data_dict: dict, - node_partition: torch.Tensor, - rank_assignments: torch.Tensor, - pbc: torch.Tensor, - halo_mask: torch.Tensor, - ) -> dict | None: - """ - Try to generate graph using AABB halo filtering. - - Uses the pre-computed halo_mask from _compute_aabb_halo to - filter graph gen input from N_total to ~N_halo atoms. - Returns the graph_dict if the halo achieves significant - reduction, or None to fall back to full graph generation. - - The edge_index in the returned graph_dict is remapped to - global coordinates for use by the backbone forward pass. - Additionally, halo-local metadata is stored so that - ``build_gp_context`` can work with O(N_halo)-sized tensors - instead of O(N_total). - - Args: - data_dict: Full data dictionary with pos, cell, etc. - node_partition: Local atom indices (global coords). - rank_assignments: Rank assignment per atom (global). - pbc: Periodic boundary conditions. - halo_mask: Pre-computed boolean halo mask. 
- - Returns: - graph_dict with edge_index in global coordinates and - halo metadata for build_gp_context, - or None if halo didn't help. - """ - with record_function("a2a_halo_filter"): - pos = data_dict["pos"] - cell = data_dict["cell"] - n_total = len(pos) - - n_halo = halo_mask.sum().item() - - # Only use halo filtering if it reduces atoms enough. - if n_halo >= n_total * 0.95: - return None - - halo_indices = halo_mask.nonzero(as_tuple=True)[0] - - # Map global indices to halo-local indices - global_to_halo = torch.full( - (n_total,), -1, dtype=torch.long, device=pos.device - ) - global_to_halo[halo_indices] = torch.arange(n_halo, device=pos.device) - - # Create subset data for graph generation. - data_subset = types.SimpleNamespace() - data_subset.pos = pos[halo_indices] - data_subset.cell = cell - data_subset.natoms = torch.tensor([n_halo], device=pos.device) - data_subset.batch = torch.zeros(n_halo, dtype=torch.long, device=pos.device) - data_subset.pbc = pbc - - # Remap partition and rank_assignments to halo-local - node_partition_local = global_to_halo[node_partition] - assert (node_partition_local >= 0).all(), ( - "Local partition atoms not found in halo — " - "AABB expansion may be too small" - ) - - graph_dict = generate_graph( - data_subset, - cutoff=self.cutoff, - max_neighbors=self.max_neighbors, - enforce_max_neighbors_strictly=(self.enforce_max_neighbors_strictly), - radius_pbc_version=self.radius_pbc_version, - pbc=pbc, - node_partition=node_partition_local, - ) - - # Remap edge_index from halo-local to global for use - # by the backbone forward pass (source/target embedding - # lookups index into atomic_numbers_full). 
- graph_dict["edge_index"] = halo_indices[graph_dict["edge_index"]] - - return graph_dict - def _generate_graph(self, data_dict): node_partition = None rank_assignments = None @@ -840,49 +695,15 @@ def _generate_graph(self, data_dict): pbc.all() or (~pbc).all() ), "We can only accept pbc that is all true or all false" - # AABB halo optimization for A2A: - # Compute AABB bounding box around local partition, - # expanded by cutoff, to filter graph gen input from - # N_total to ~N_halo atoms. Pure local computation, - # no NCCL needed. - # Note: AABB halo only supports single-system inputs. - # Multi-system batches skip halo and use full graph gen. - graph_dict = None - is_single_system = data_dict["cell"].dim() == 2 or ( - data_dict["cell"].dim() == 3 and data_dict["cell"].shape[0] == 1 + graph_dict = generate_graph( + data_dict, + cutoff=self.cutoff, + max_neighbors=self.max_neighbors, + enforce_max_neighbors_strictly=(self.enforce_max_neighbors_strictly), + radius_pbc_version=self.radius_pbc_version, + pbc=pbc, + node_partition=node_partition, ) - if ( - self.use_all_to_all_gp - and rank_assignments is not None - and is_single_system - ): - halo_mask, _ = self._compute_aabb_halo( - data_dict["pos"], - node_partition, - pbc, - data_dict["cell"], - ) - graph_dict = self._compute_halo_graph( - data_dict, - node_partition, - rank_assignments, - pbc, - halo_mask, - ) - - if graph_dict is None: - # Full graph gen (no halo filter, or halo didn't help) - graph_dict = generate_graph( - data_dict, - cutoff=self.cutoff, - max_neighbors=self.max_neighbors, - enforce_max_neighbors_strictly=( - self.enforce_max_neighbors_strictly - ), - radius_pbc_version=self.radius_pbc_version, - pbc=pbc, - node_partition=node_partition, - ) else: # this assume edge_index is provided assert ( @@ -1086,10 +907,12 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: # for scattering edge messages to nodes. For A2A, use # gp_ctx.edge_index_local[1] (already local). 
For allgather, # use pre-computed global→local mapped targets. - gp_ctx.edge_index_local[1] - if gp_ctx is not None - else data_dict.get( - "scatter_target", default=graph_dict["edge_index"][1] + ( + gp_ctx.edge_index_local[1] + if gp_ctx is not None + else data_dict.get( + "scatter_target", default=graph_dict["edge_index"][1] + ) ), wigner_inv_envelope, ) From 83025bedbc09b886a7033ee50c3e8b9ddc75950d Mon Sep 17 00:00:00 2001 From: rgao user Date: Mon, 8 Jun 2026 23:12:18 +0000 Subject: [PATCH 26/26] Revert timeout_hr addition in slurm benchmark config --- configs/uma/speed/job/slurm.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/uma/speed/job/slurm.yaml b/configs/uma/speed/job/slurm.yaml index 9088fe575f..0006ba20b3 100644 --- a/configs/uma/speed/job/slurm.yaml +++ b/configs/uma/speed/job/slurm.yaml @@ -10,5 +10,4 @@ scheduler: qos: h200_ocp_high mem_gb: 0 cpus_per_task: 24 - timeout_hr: 1 graph_parallel_group_size: 64