diff --git a/src/fairchem/core/common/parallelism/graph_parallel_verfication.md b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md new file mode 100644 index 0000000000..78c244db90 --- /dev/null +++ b/src/fairchem/core/common/parallelism/graph_parallel_verfication.md @@ -0,0 +1,314 @@ +# A2A Graph Parallel — Test & Benchmark Runbook + +All commands are run-ready. Copy-paste from the fairchem repo root: + +```bash +cd /home/rgao/fairchem +source .venv/bin/activate +``` + +--- + +## 1. Correctness Tests + +### 1a. GPContext unit tests (CPU, <1s) + +```bash +pytest tests/core/common/parallelism/test_graph_parallel.py::TestBuildGPContext -v +``` + +5 tests: `test_basic_context_building`, `test_global_to_local_mapping`, +`test_no_cross_partition_edges`, `test_edge_split_indices`, `test_edge_split_no_remote_edges`. + +### 1b. Distributed A2A primitives (CPU/Gloo, ~10s) + +```bash +pytest tests/core/common/parallelism/test_graph_parallel.py -v -k "not TestBuildGPContext" +``` + +8 tests: forward equivalence (2 graph topologies), backward gradient, multi-rank (2 and 3), +spatial partition, compiled-vs-autograd (index_split + spatial). + +### 1c. A2A embedding correctness (CPU/Gloo, ~30s) + +```bash +pytest tests/core/common/parallelism/test_a2a_correctness.py -v +``` + +8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2, +1536-dim embeddings × 2. + +### 1d. Full-model GPU correctness (4+ GPUs, ~5 min) + +```bash +pytest tests/core/units/mlip_unit/test_predict.py::test_full_model_gp_correctness -v +``` + +27 tests: 3 atom counts (10, 50, 100) × 9 configs (no-GP, allgather, A2A-spatial, +A2A-index_split at 1/2/4 workers). Compares energy/forces/stress against single-GPU +reference (tol: energy/stress 5e-4, forces 1e-4). Skipped on CI (`CI=true`). + +### 1e. 
Predict pipeline + MD consistency (CPU + GPU, PR3 branch, ~2 min) + +```bash +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_cpu -v -k "spatial or index_split" +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_gpu -v -k "spatial or index_split" +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_batch -v -k "spatial" +pytest tests/core/units/mlip_unit/test_predict.py::test_merge_mole_md_consistency -v -k "spatial or index_split" +``` + +### 1f. All parallelism correctness tests (1a–1c) at once + +```bash +pytest tests/core/common/parallelism/ -v +``` + +--- + +## 2. Regression Tests + +### 2a. Existing tests must pass unchanged + +```bash +pytest tests/core/common/test_gp_utils.py -v +pytest tests/core/models/uma/test_escn_md.py -v +pytest tests/core/models/uma/test_compile.py -v +pytest tests/core/components/test_uma_speed_benchmark.py -v +pytest tests/core/units/mlip_unit/test_predict.py::test_parallel_predict_unit_cpu -v -k "not spatial and not index_split" +``` + +### 2b. All regression at once (excluding the `test_predict.py` case from 2a) + +```bash +pytest tests/core/common/test_gp_utils.py tests/core/models/uma/test_escn_md.py tests/core/models/uma/test_compile.py tests/core/components/test_uma_speed_benchmark.py -v +``` + +--- + +## 3. Performance Benchmarks + +All benchmarks use `InferenceBenchRunner` with FCC crystal systems (deterministic), +`timeiters=100`, `repeats=5`, turbo mode (compile=True, tf32=True). +Results written to `benchmark_results.json` in the run directory under `/checkpoint/ocp/rgao/speed/`. + +### 3a. 
8-GPU single-node (local, ~10 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=local_8gpu \ + runner.natoms_list=[8000,32000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=local_8gpu \ + runner.natoms_list=[8000,32000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A within ±5% of BL (NVLink hides comm cost). + +### 3b. 16-GPU 2-node (SLURM, ~15 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=2 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=16 \ + job.scheduler.slurm.qos=h200_dev \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[16000,64000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=2 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=16 \ + job.scheduler.slurm.qos=h200_dev \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[16000,64000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +2-3%. + +### 3c. 
32-GPU 4-node (SLURM, ~15 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=4 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=32 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[32000,128000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=4 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=32 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[32000,128000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +5-10%. + +### 3d. 64-GPU 8-node — CANONICAL BENCHMARK (SLURM, ~30 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=8 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=64 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[64000,256000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=8 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=64 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[64000,256000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +10-20%. + +### 3e. 
128-GPU 16-node (SLURM, ~30 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=16 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=128 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[128000,512000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=16 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=128 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[128000,512000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +30-40%. + +### 3f. 256-GPU 32-node (SLURM, ~45 min each) + +```bash +# BL baseline +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=32 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=256 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[256000,1024000] \ + runner.timeiters=100 \ + runner.repeats=5 + +# A2A + spatial +fairchem -c configs/uma/speed/uma-speed.yaml \ + job=slurm \ + job.scheduler.num_nodes=32 \ + job.scheduler.ranks_per_node=8 \ + job.graph_parallel_group_size=256 \ + job.scheduler.slurm.qos=h200_alignment_shared \ + job.scheduler.slurm.account=ocp \ + runner.natoms_list=[256000,1024000] \ + runner.timeiters=100 \ + runner.repeats=5 \ + '+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}' +``` + +Expected: A2A +60-80%. + +--- + +## 4. 
Reading Results + +```bash +# Find latest runs +ls -lt /checkpoint/ocp/rgao/speed/ | head -20 + +# Read a specific result (set RUN_DIR to one of the run directories listed above) +python3 -c " +import json, sys +with open(sys.argv[1]) as f: + d = json.load(f) +ws = d['config']['world_size'] +for model, data in d['model_to_qps_data'].items(): + print(f'Model: {model}, world_size={ws}') + for natoms, ns_day in data: + atoms_per_gpu = natoms // ws + print(f' {natoms:>7} atoms ({atoms_per_gpu}/gpu): {ns_day:.3f} ns/day') +" /checkpoint/ocp/rgao/speed/${RUN_DIR}/benchmark_results.json +``` + +--- + +## 5. Reference Results (H200 turbo) + +### 4k atoms/rank (canonical): + +| GPUs | BL ns/day | A2A ns/day | Speedup | BL WS Eff | A2A WS Eff | +|------|-----------|------------|---------|-----------|------------| +| 8 | 0.651 | 0.670 | +2.9% | 100% | 100% | +| 16 | 0.610 | 0.626 | +2.6% | 93.7% | 93.6% | +| 32 | 0.568 | 0.623 | +9.7% | 87.3% | 93.1% | +| 64 | 0.477 | 0.570 | +19.5% | 73.2% | 85.1% | +| 128 | 0.359 | 0.498 | +38.7% | 55.1% | 74.3% | +| 256 | 0.246 | 0.433 | +76.0% | 37.8% | 64.6% | + +### 1k atoms/rank: + +| GPUs | BL ns/day | A2A ns/day | Speedup | +|------|-----------|------------|---------| +| 8 | 2.088 | 2.010 | -3.7% | +| 16 | 2.043 | 1.990 | -2.6% | +| 32 | 1.769 | 1.848 | +4.5% | +| 64 | 1.488 | 1.626 | +9.3% | diff --git a/src/fairchem/core/models/uma/escn_md.py b/src/fairchem/core/models/uma/escn_md.py index 0c8e82fd37..dfe8315111 100644 --- a/src/fairchem/core/models/uma/escn_md.py +++ b/src/fairchem/core/models/uma/escn_md.py @@ -19,6 +19,15 @@ from torch.profiler import record_function from fairchem.core.common import gp_utils +from fairchem.core.common.parallelism.graph_parallel_a2a import ( + GPContext, + build_gp_context, +) +from fairchem.core.common.parallelism.graph_partition import ( + PartitionStrategy, + partition_atoms_index_split, + partition_atoms_spatial, +) from fairchem.core.common.registry import registry from fairchem.core.common.utils import conditional_grad from fairchem.core.graph.compute 
import generate_graph @@ -93,12 +102,10 @@ class GradRegressConfig: hessian_vmap: bool = True -def add_n_empty_edges( - graph_dict: dict, edges_to_add: int, cutoff: float, node_offset: int = 0 -): +def add_n_empty_edges(graph_dict: dict, edges_to_add: int, cutoff: float): graph_dict["edge_index"] = torch.cat( ( - graph_dict["edge_index"].new_ones(2, edges_to_add) * node_offset, + graph_dict["edge_index"].new_zeros(2, edges_to_add), graph_dict["edge_index"], ), dim=1, @@ -309,6 +316,8 @@ def __init__( spin_balanced_channels: list[int] | None = None, edge_chunk_size: int = 1, execution_mode: str = "general", + use_all_to_all_gp: bool = False, + gp_partition_strategy: str = "index_split", ) -> None: super().__init__() self.max_num_elements = max_num_elements @@ -360,6 +369,26 @@ def __init__( ESCNMD_DEFAULT_EDGE_ACTIVATION_CHECKPOINT_CHUNK_SIZE ) self.edge_chunk_size = edge_chunk_size + self.use_all_to_all_gp = use_all_to_all_gp + self.gp_partition_strategy = PartitionStrategy(gp_partition_strategy) + + # Allgather+spatial is not supported because allgather concatenates + # per-rank tensors in rank order, which only matches global atom order + # for contiguous (index_split) partitions. Spatial partitions are + # non-contiguous, so edge_index lookups into the allgathered tensor + # would reference wrong atoms. Spatial partitioning is designed for + # A2A where only boundary atoms are exchanged. + if ( + not self.use_all_to_all_gp + and self.gp_partition_strategy == PartitionStrategy.SPATIAL + ): + raise ValueError( + "Spatial partitioning is not supported with all-gather GP. " + "Allgather concatenates per-rank tensors in rank order, which " + "only matches global atom order for contiguous (index_split) " + "partitions. Use use_all_to_all_gp=True with spatial, or " + "gp_partition_strategy='index_split' with allgather." 
+ ) self.backend = get_execution_backend(execution_mode) @@ -582,18 +611,67 @@ def csd_embedding(self, charge, spin, dataset): ) return torch.nn.SiLU()(self.mix_csd(torch.cat((chg_emb, spin_emb), dim=1))) + @torch.compiler.disable + def _compute_a2a_partition( + self, + pos: torch.Tensor, + total_atoms: int, + device: torch.device, + world_size: int, + rank: int, + strategy: PartitionStrategy, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Compute A2A rank assignments and node partition. + + Separated from _generate_graph so that only the A2A-specific + partitioning is excluded from torch.compile. The BL (all-gather) + path stays fully compilable. + """ + with record_function("a2a_partition"): + if strategy == PartitionStrategy.SPATIAL: + rank_assignments = partition_atoms_spatial(pos, world_size) + else: + rank_assignments = partition_atoms_index_split( + total_atoms, world_size, device + ) + node_partition = (rank_assignments == rank).nonzero(as_tuple=True)[0] + + return rank_assignments, node_partition + def _generate_graph(self, data_dict): - data_dict["gp_node_offset"] = 0 node_partition = None + rank_assignments = None if gp_utils.initialized(): # create the partitions atomic_numbers_full = data_dict["atomic_numbers_full"] - node_partition = torch.tensor_split( - torch.arange( - len(atomic_numbers_full), device=atomic_numbers_full.device - ), - gp_utils.get_gp_world_size(), - )[gp_utils.get_gp_rank()] + + if self.use_all_to_all_gp: + # All-to-all: compute rank_assignments FIRST, then derive + # node_partition from them. This ensures the + # graph-generation partition and the GPContext partition + # are identical, avoiding index mismatches that cause + # OOB crashes. 
+ natoms = len(atomic_numbers_full) + rank_assignments, node_partition = self._compute_a2a_partition( + pos=data_dict["pos"], + total_atoms=natoms, + device=atomic_numbers_full.device, + world_size=gp_utils.get_gp_world_size(), + rank=gp_utils.get_gp_rank(), + strategy=self.gp_partition_strategy, + ) + else: + # All-gather: only supports contiguous (index_split) + # partitioning. __init__ blocks spatial+allgather. + node_partition = torch.tensor_split( + torch.arange( + len(atomic_numbers_full), + device=atomic_numbers_full.device, + ), + gp_utils.get_gp_world_size(), + )[gp_utils.get_gp_rank()] + assert ( node_partition.numel() > 0 ), "Looks like there is no atoms in this graph paralell partition. Cannot proceed" @@ -607,13 +685,15 @@ def _generate_graph(self, data_dict): "pbc" in data_dict ), "Since always_use_pbc is False, pbc conditions must be supplied by the input data" pbc = data_dict["pbc"] - # for v2 graph gen we used to pass node_partition as part of the data_dict directly to radius_pbc to allow it generate partial graphs - # to make it more general to accomodate v3, we scrapped and instead have generate_graph handle the partitioning after the graph has been generated + assert ( + pbc.all() or (~pbc).all() + ), "We can only accept pbc that is all true or all false" + graph_dict = generate_graph( data_dict, cutoff=self.cutoff, max_neighbors=self.max_neighbors, - enforce_max_neighbors_strictly=self.enforce_max_neighbors_strictly, + enforce_max_neighbors_strictly=(self.enforce_max_neighbors_strictly), radius_pbc_version=self.radius_pbc_version, pbc=pbc, node_partition=node_partition, @@ -661,10 +741,56 @@ def _generate_graph(self, data_dict): node_partition ] data_dict["batch"] = data_dict["batch_full"][node_partition] - data_dict["gp_node_offset"] = node_partition.min().item() + + # Build GPContext for all-to-all communication + if self.use_all_to_all_gp: + with record_function("a2a_build_gp_context"): + gp_ctx = build_gp_context( + 
edge_index=graph_dict["edge_index"], + rank_assignments=rank_assignments, + rank=gp_utils.get_gp_rank(), + world_size=gp_utils.get_gp_world_size(), + node_partition=node_partition, + ) + data_dict["gp_ctx"] = gp_ctx + # Store rank_assignments so output heads can reorder + # gathered forces/stress from partition-concatenated order + # back to global index order. Only needed for A2A where + # partitions are non-consecutive (spatial). + data_dict["gp_rank_assignments"] = rank_assignments + # A2A scatter_target comes from gp_ctx.edge_index_local[1] + # which is set dynamically in InteractionBlock.forward(). + data_dict["scatter_target"] = None + else: + # Allgather: pre-compute local target indices for scatter + # operations. Maps global edge targets to 0-based local + # partition indices. Works for both contiguous (index_split) + # and non-contiguous (spatial) partitions. + total_atoms = len(data_dict["atomic_numbers_full"]) + device = graph_dict["edge_index"].device + global_to_local = torch.zeros( + total_atoms, dtype=torch.long, device=device + ) + global_to_local[node_partition] = torch.arange( + len(node_partition), device=device + ) + data_dict["scatter_target"] = global_to_local[ + graph_dict["edge_index"][1] + ] if graph_dict["edge_index"].shape[1] == 0: - add_n_empty_edges(graph_dict, 1, self.cutoff, data_dict["gp_node_offset"]) + add_n_empty_edges(graph_dict, 1, self.cutoff) + # Also update scatter_target for the dummy edge (target=0) + if ( + "scatter_target" in data_dict + and data_dict["scatter_target"] is not None + ): + data_dict["scatter_target"] = torch.cat( + [ + data_dict["scatter_target"].new_zeros(1), + data_dict["scatter_target"], + ] + ) return graph_dict @@ -740,6 +866,13 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: ) self.log_MOLE_stats() + # Retrieve precomputed all-to-all context (needed for edge embedding + # and message passing layers) + gp_ctx: GPContext | None = data_dict.get("gp_ctx", None) + send_indices: 
torch.Tensor | None = None + if gp_ctx is not None: + send_indices = gp_ctx.send_indices + # edge degree embedding with record_function("edge embedding"): dist_scaled = graph_dict["edge_distance"] / self.cutoff @@ -764,9 +897,18 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: x_message = self.edge_degree_embedding( x_message, x_edge, - graph_dict["edge_index"], + # scatter_target contains pre-computed local target indices + # for scattering edge messages to nodes. For A2A, use + # gp_ctx.edge_index_local[1] (already local). For allgather, + # use pre-computed global→local mapped targets. + ( + gp_ctx.edge_index_local[1] + if gp_ctx is not None + else data_dict.get( + "scatter_target", default=graph_dict["edge_index"][1] + ) + ), wigner_inv_envelope, - data_dict["gp_node_offset"], ) ############################################################### @@ -791,7 +933,9 @@ def forward(self, data_dict: AtomicData) -> dict[str, torch.Tensor]: 0 ], sys_node_embedding=sys_node_embedding, - node_offset=data_dict["gp_node_offset"], + scatter_target=data_dict.get("scatter_target", default=None), + gp_ctx=gp_ctx, + send_indices=send_indices, ) # balance any channels requested x_message = self.balance_channels( @@ -1134,6 +1278,18 @@ def forward(self, data_dict: AtomicData, emb: dict[str, torch.Tensor]): forces = gp_utils.gather_from_model_parallel_region( forces, data_dict["atomic_numbers_full"].shape[0] ) + # A2A spatial partitions are non-consecutive, so the + # gathered forces are in partition-concatenated order + # (NOT global index order). Reorder to match positions. 
+ ra = data_dict.get("gp_rank_assignments", None) + if ra is not None: + ws = gp_utils.get_gp_world_size() + perm = torch.cat( + [(ra == r).nonzero(as_tuple=True)[0] for r in range(ws)] + ) + forces_ordered = torch.empty_like(forces) + forces_ordered[perm] = forces + forces = forces_ordered return {"forces": forces} diff --git a/src/fairchem/core/models/uma/escn_md_block.py b/src/fairchem/core/models/uma/escn_md_block.py index 614f5ca8c3..a4bd835472 100644 --- a/src/fairchem/core/models/uma/escn_md_block.py +++ b/src/fairchem/core/models/uma/escn_md_block.py @@ -16,6 +16,10 @@ from typing_extensions import Literal from fairchem.core.common import gp_utils +from fairchem.core.common.parallelism.graph_parallel_a2a import ( + GPContext, + all_to_all_collect, +) from fairchem.core.models.uma.nn.activation import ( GateActivation, SeparableS2Activation_M, @@ -77,7 +81,8 @@ def __init__( ) extra_m0_output_channels = self.lmax * self.hidden_channels elif self.act_type == "s2": - # NOTE: this is the only place where the SO3 grid of the edges (lmax/mmax) is used + # NOTE: this is the only place where the SO3 grid of the + # edges (lmax/mmax) is used self.act = SeparableS2Activation_M( lmax=self.lmax, mmax=self.mmax, @@ -117,37 +122,69 @@ def forward( wigner, wigner_inv_envelope, total_atoms_across_gp_ranks, - node_offset: int = 0, + scatter_target: torch.Tensor | None = None, + gp_ctx: GPContext | None = None, + send_indices: torch.Tensor | None = None, ): - # we perform the all gather upfront once during each forward call so we don't need to repeat this multiple times during activation checkpointing. - if gp_utils.initialized(): - x_full = gp_utils.gather_from_model_parallel_region_sum_grad( - x, total_atoms_across_gp_ranks - ) + """ + Forward pass with support for both all-gather and all-to-all GP. + + When gp_ctx is provided, uses all-to-all to collect only the + needed remote embeddings. Otherwise falls back to all-gather. 
+ + Args: + scatter_target: Pre-computed local target indices [E] for + scattering edge messages to nodes. For allgather, this + is ``edge_index[1]`` mapped to local partition space. + For A2A, derived from ``gp_ctx.edge_index_local[1]``. + If None, defaults to ``edge_index[1]`` (no GP). + """ + if gp_ctx is not None and gp_utils.initialized(): + # All-to-all path: collect only needed remote embeddings. + with record_function("a2a_collect"): + x_received = all_to_all_collect(x, gp_ctx, send_indices) + x_full = torch.cat([x, x_received], dim=0) + edge_index_local = gp_ctx.edge_index_local + local_scatter_target = edge_index_local[1] + elif gp_utils.initialized(): + # Legacy all-gather path + with record_function("allgather_collect"): + x_full = gp_utils.gather_from_model_parallel_region_sum_grad( + x, total_atoms_across_gp_ranks + ) + edge_index_local = edge_index + local_scatter_target = scatter_target else: x_full = x + edge_index_local = edge_index + local_scatter_target = ( + scatter_target if scatter_target is not None else edge_index[1] + ) if self.activation_checkpoint_chunk_size is None: return self.forward_chunk( x_full, x.shape[0], x_edge, - edge_index, + edge_index_local, wigner, wigner_inv_envelope, - node_offset, + local_scatter_target, ) - edge_index_partitions = edge_index.split( + edge_index_partitions = edge_index_local.split( self.activation_checkpoint_chunk_size, dim=1 ) + scatter_target_partitions = local_scatter_target.split( + self.activation_checkpoint_chunk_size + ) wigner_partitions = wigner.split(self.activation_checkpoint_chunk_size, dim=0) wigner_inv_partitions = wigner_inv_envelope.split( self.activation_checkpoint_chunk_size, dim=0 ) x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0) new_embeddings = [] - # when chunking, we need to keep track of the start index of the chunk and give this information - # to the mole layers + # when chunking, we need to keep track of the start index + # of the chunk and give 
this information to the mole layers ac_mole_start_idx = 0 for idx in range(len(edge_index_partitions)): @@ -160,7 +197,7 @@ def forward( edge_index_partitions[idx], wigner_partitions[idx], wigner_inv_partitions[idx], - node_offset, + scatter_target_partitions[idx], ac_mole_start_idx, use_reentrant=False, ) @@ -179,11 +216,11 @@ def forward_chunk( edge_index, wigner, wigner_inv_envelope, - node_offset: int = 0, + scatter_target: torch.Tensor | None = None, ac_mole_start_idx: int = 0, ): - # here we need to update the ac_start_idx of the mole layers under here for this chunking to - # work properly with MoLE together + # here we need to update the ac_start_idx of the mole layers + # under here for this chunking to work properly with MoLE set_mole_ac_start_index(self, ac_mole_start_idx) with record_function("SO2Conv"): @@ -196,9 +233,8 @@ def forward_chunk( new_embedding = self.backend.permute_wigner_inv_edge_to_node( x_message, wigner_inv_envelope, - edge_index, + scatter_target if scatter_target is not None else edge_index[1], x_original_shape, - node_offset, ) # reset ac start index @@ -354,7 +390,9 @@ def forward( wigner_inv_envelope, total_atoms_across_gp_ranks, sys_node_embedding=None, - node_offset: int = 0, + scatter_target: torch.Tensor | None = None, + gp_ctx: GPContext | None = None, + send_indices: torch.Tensor | None = None, ): x_res = x x = self.norm_1(x) @@ -370,7 +408,9 @@ def forward( wigner, wigner_inv_envelope, total_atoms_across_gp_ranks=total_atoms_across_gp_ranks, - node_offset=node_offset, + scatter_target=scatter_target, + gp_ctx=gp_ctx, + send_indices=send_indices, ) x = x + x_res diff --git a/src/fairchem/core/models/uma/nn/embedding.py b/src/fairchem/core/models/uma/nn/embedding.py index bb4e4c4748..3f25be41fc 100644 --- a/src/fairchem/core/models/uma/nn/embedding.py +++ b/src/fairchem/core/models/uma/nn/embedding.py @@ -78,9 +78,8 @@ def forward_chunk( self, x, x_edge, - edge_index, + scatter_target, wigner_inv_envelope, - node_offset=0, ): 
radial = self.rad_func(x_edge) @@ -88,46 +87,42 @@ def forward_chunk( x, radial, wigner_inv_envelope, - edge_index, + scatter_target, self.m_0_num_coefficients, self.sphere_channels, self.rescale_factor, - node_offset, ) def forward( self, x, x_edge, - edge_index, + scatter_target, wigner_inv_envelope, - node_offset=0, ): if self.activation_checkpoint_chunk_size is None: return self.forward_chunk( x, x_edge, - edge_index, + scatter_target, wigner_inv_envelope, - node_offset, ) - edge_index_partitions = edge_index.split( - self.activation_checkpoint_chunk_size, dim=1 + scatter_target_partitions = scatter_target.split( + self.activation_checkpoint_chunk_size ) wigner_inv_partitions = wigner_inv_envelope.split( self.activation_checkpoint_chunk_size, dim=0 ) x_edge_partitions = x_edge.split(self.activation_checkpoint_chunk_size, dim=0) - for idx in range(len(edge_index_partitions)): + for idx in range(len(scatter_target_partitions)): x = torch.utils.checkpoint.checkpoint( self.forward_chunk, x, x_edge_partitions[idx], - edge_index_partitions[idx], + scatter_target_partitions[idx], wigner_inv_partitions[idx], - node_offset, use_reentrant=False, ) diff --git a/src/fairchem/core/models/uma/nn/execution_backends.py b/src/fairchem/core/models/uma/nn/execution_backends.py index 3fac942a4b..a2b5387924 100644 --- a/src/fairchem/core/models/uma/nn/execution_backends.py +++ b/src/fairchem/core/models/uma/nn/execution_backends.py @@ -181,9 +181,8 @@ def node_to_edge_wigner_permute( def permute_wigner_inv_edge_to_node( x_message: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, num_nodes: int, - node_offset: int = 0, ) -> torch.Tensor: """ Rotate M->L and scatter edge messages to nodes. 
@@ -193,9 +192,9 @@ def permute_wigner_inv_edge_to_node( Args: x_message: Edge message features [E, M, C] wigner_inv: Inverse Wigner matrices [E, L, M] - edge_index: Edge indices [2, E] + scatter_target: Pre-computed local target indices [E] + for scattering into node output tensor. num_nodes: Total number of nodes (output size) - node_offset: Offset for node indices (for chunking) Returns: Node embeddings [N, L, C] accumulated from edge messages @@ -208,7 +207,7 @@ def permute_wigner_inv_edge_to_node( dtype=x_rotated.dtype, device=x_rotated.device, ) - new_embedding.index_add_(0, edge_index[1] - node_offset, x_rotated) + new_embedding.index_add_(0, scatter_target, x_rotated) return new_embedding @staticmethod @@ -216,11 +215,10 @@ def edge_degree_scatter( x: torch.Tensor, radial_output: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, m_0_num_coefficients: int, sphere_channels: int, rescale_factor: float, - node_offset: int = 0, ) -> torch.Tensor: """ Edge degree embedding: rotate radial and scatter to nodes. @@ -232,12 +230,12 @@ def edge_degree_scatter( radial_output: RadialMLP output [E, m0 * C] wigner_inv: Wigner inverse with envelope pre-fused [E, L, m0] or [E, L, L] - edge_index: Edge indices [2, E] + scatter_target: Pre-computed local target indices [E] + for scattering into node output tensor. 
m_0_num_coefficients: Number of m=0 coefficients (3 for lmax=2) sphere_channels: Number of channels C rescale_factor: Aggregation rescale factor - node_offset: Node offset for graph parallelism Returns: Updated node features [N, L, C] @@ -256,7 +254,7 @@ def edge_degree_scatter( # Scatter to destination nodes with rescaling return x.index_add( 0, - edge_index[1] - node_offset, + scatter_target, x_edge_embedding / rescale_factor, ) @@ -377,9 +375,8 @@ def node_to_edge_wigner_permute( def permute_wigner_inv_edge_to_node( x_message: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, num_nodes: int, - node_offset: int = 0, ) -> torch.Tensor: from fairchem.core.models.uma.triton import ( UMASFastGPUPermuteWignerInvEdgeToNode, @@ -393,7 +390,7 @@ def permute_wigner_inv_edge_to_node( dtype=x_rotated.dtype, device=x_rotated.device, ) - new_embedding.index_add_(0, edge_index[1] - node_offset, x_rotated) + new_embedding.index_add_(0, scatter_target, x_rotated) return new_embedding @staticmethod @@ -401,11 +398,10 @@ def edge_degree_scatter( x: torch.Tensor, radial_output: torch.Tensor, wigner_inv: torch.Tensor, - edge_index: torch.Tensor, + scatter_target: torch.Tensor, m_0_num_coefficients: int, sphere_channels: int, rescale_factor: float, - node_offset: int = 0, ) -> torch.Tensor: radial = radial_output.reshape(-1, m_0_num_coefficients, sphere_channels) @@ -417,7 +413,7 @@ def edge_degree_scatter( return x.index_add( 0, - edge_index[1] - node_offset, + scatter_target, x_edge_embedding / rescale_factor, ) diff --git a/tests/core/units/mlip_unit/test_predict.py b/tests/core/units/mlip_unit/test_predict.py index e384dafde0..b8ec984fae 100644 --- a/tests/core/units/mlip_unit/test_predict.py +++ b/tests/core/units/mlip_unit/test_predict.py @@ -9,6 +9,7 @@ import contextlib import logging +import os from copy import deepcopy import numpy as np @@ -197,8 +198,17 @@ def test_multiple_dataset_predict(internal_graph_gen_version): 
npt.assert_allclose(pred_forces[batch_batch == 2], pt.get_forces(), atol=ATOL) -def _test_parallel_predict_unit_impl(workers, device, checkpointing, graph_gen_version): - """Implementation of parallel predict unit test.""" +def _test_parallel_predict_unit_impl( + workers, device, checkpointing, graph_gen_version, gp_mode=None +): + """ + Implementation of parallel predict unit test. + + Args: + gp_mode: Optional dict with GP overrides, e.g. + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}. + If None, uses default GP settings (allgather + index_split). + """ seed = 42 runs = 2 model_path = pretrained_checkpoint_path_from_name("uma-s-1p1") @@ -213,12 +223,17 @@ def _test_parallel_predict_unit_impl(workers, device, checkpointing, graph_gen_v atoms = get_fcc_crystal_by_num_atoms(num_atoms) atomic_data = AtomicData.from_ase(atoms, task_name=["omat"]) + overrides = None + if gp_mode is not None: + overrides = {"backbone": gp_mode} + seed_everywhere(seed) ppunit = ParallelMLIPPredictUnit( inference_model_path=model_path, device=device, inference_settings=ifsets, num_workers=workers, + overrides=overrides, ) for _ in range(runs): pp_results = ppunit.predict(atomic_data) @@ -248,37 +263,72 @@ def _test_parallel_predict_unit_impl(workers, device, checkpointing, graph_gen_v @pytest.mark.serial() @pytest.mark.parametrize( - "workers, checkpointing, graph_gen_version", + "workers, checkpointing, graph_gen_version, gp_mode", [ - (1, False, 2), - (2, False, 2), - (1, False, 3), - (1, True, 3), - (2, False, 3), + # Default GP mode (allgather + index_split) + (1, False, 2, None), + (2, False, 2, None), + (1, False, 3, None), + (1, True, 3, None), + (2, False, 3, None), + # A2A + spatial (requires workers >= 2 for actual GP) + (2, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + # A2A + index_split + ( + 2, + False, + 2, + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, + ), ], ) -def 
test_parallel_predict_unit_cpu(workers, checkpointing, graph_gen_version): - _test_parallel_predict_unit_impl(workers, "cpu", checkpointing, graph_gen_version) +def test_parallel_predict_unit_cpu(workers, checkpointing, graph_gen_version, gp_mode): + _test_parallel_predict_unit_impl( + workers, "cpu", checkpointing, graph_gen_version, gp_mode + ) @pytest.mark.gpu() @pytest.mark.parametrize( - "workers, checkpointing, graph_gen_version", + "workers, checkpointing, graph_gen_version, gp_mode", [ - (1, False, 2), - (1, True, 2), - (1, True, 3), - (1, False, 3), - # (2, False), - # (2, True), + # Default GP mode (allgather + index_split) + (1, False, 2, None), + (1, True, 2, None), + (1, True, 3, None), + (1, False, 3, None), + # GP modes with 1 worker — verifies code paths don't crash + # (single rank = no actual GP comms, but exercises config flow) + (1, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + ( + 1, + False, + 2, + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, + ), + # For local 8-GPU runs: uncomment to test multi-worker GPU GP + # (2, False, 2, None), + # (2, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + # (2, False, 2, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), ], ) -def test_parallel_predict_unit_gpu(workers, checkpointing, graph_gen_version): - _test_parallel_predict_unit_impl(workers, "cuda", checkpointing, graph_gen_version) +def test_parallel_predict_unit_gpu(workers, checkpointing, graph_gen_version, gp_mode): + _test_parallel_predict_unit_impl( + workers, "cuda", checkpointing, graph_gen_version, gp_mode + ) -def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): - """Implementation of parallel predict unit batch test.""" +def _test_parallel_predict_unit_batch_impl( + workers, device, checkpointing, gp_mode=None +): + """ + Implementation of parallel predict unit batch test. 
+ + Args: + gp_mode: Optional dict with GP overrides, e.g. + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}. + If None, uses default GP settings (allgather + index_split). + """ seed = 42 runs = 1 model_path = pretrained_checkpoint_path_from_name("uma-s-1p1") @@ -313,11 +363,17 @@ def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): ) atomic_data = atomicdata_list_to_batch([h2o_data, o_data]) seed_everywhere(seed) + + overrides = None + if gp_mode is not None: + overrides = {"backbone": gp_mode} + ppunit = ParallelMLIPPredictUnit( inference_model_path=model_path, device=device, inference_settings=ifsets, num_workers=workers, + overrides=overrides, ) for _ in range(runs): pp_results = ppunit.predict(atomic_data) @@ -345,28 +401,34 @@ def _test_parallel_predict_unit_batch_impl(workers, device, checkpointing): @pytest.mark.serial() @pytest.mark.parametrize( - "workers, checkpointing", + "workers, checkpointing, gp_mode", [ - (1, False), - (2, True), + (1, False, None), + (2, True, None), + # A2A + spatial with batch (exercises halo bail-out for multi-system) + ( + 2, + False, + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, + ), ], ) -def test_parallel_predict_unit_batch(workers, checkpointing): - _test_parallel_predict_unit_batch_impl(workers, "cpu", checkpointing) +def test_parallel_predict_unit_batch(workers, checkpointing, gp_mode): + _test_parallel_predict_unit_batch_impl(workers, "cpu", checkpointing, gp_mode) @pytest.mark.gpu() @pytest.mark.parametrize( - "workers, checkpointing", + "workers, checkpointing, gp_mode", [ - (1, True), - (1, False), - # (2, True), - # (2, False), + (1, True, None), + (1, False, None), + # (2, True, None), + # (2, False, None), ], ) -def test_parallel_predict_unit_batch_gpu(workers, checkpointing): - _test_parallel_predict_unit_batch_impl(workers, "cuda", checkpointing) +def test_parallel_predict_unit_batch_gpu(workers, checkpointing, gp_mode): + 
_test_parallel_predict_unit_batch_impl(workers, "cuda", checkpointing, gp_mode) @pytest.mark.gpu() @@ -1004,7 +1066,17 @@ def test_batch_server_predict_unit_multiple_systems(batch_server_handle): @pytest.mark.parametrize("workers", [0, 2]) @pytest.mark.parametrize("ensemble", ["nvt", "npt"]) @pytest.mark.parametrize("device", ["cpu"]) -def test_merge_mole_md_consistency(workers, ensemble, device): +@pytest.mark.parametrize( + "gp_mode", + [ + None, # default allgather + index_split + # A2A + spatial — tests multi-step MD with spatial repartitioning + {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}, + # A2A + index_split — tests A2A with contiguous partitioning + {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}, + ], +) +def test_merge_mole_md_consistency(workers, ensemble, device, gp_mode): """Test merge_mole vs no-merge consistency over MD trajectory. Runs 3 trials: @@ -1015,9 +1087,16 @@ def test_merge_mole_md_consistency(workers, ensemble, device): Compares the relative drift of A-C against baseline A-B to ensure merge_mole doesn't introduce additional numerical drift beyond the inherent noise between identical runs. + + When gp_mode is not None, passes backbone overrides to enable + A2A graph parallel with the specified partition strategy. 
""" import torch + # A2A GP modes require workers >= 2 to actually exercise multi-rank + if gp_mode is not None and workers < 2: + pytest.skip("A2A GP mode requires workers >= 2") + torch.use_deterministic_algorithms(True) from ase import units @@ -1093,6 +1172,9 @@ def run_md_trial(atoms, calc, seed, steps): "stresses": np.array(stresses), } + # Build overrides for GP mode + overrides = {"backbone": gp_mode} if gp_mode is not None else None + # Trial A: no merge settings_no_merge = InferenceSettings(merge_mole=False, **base_settings) predict_unit_A = pretrained_mlip.get_predict_unit( @@ -1100,6 +1182,7 @@ def run_md_trial(atoms, calc, seed, steps): device=device, inference_settings=settings_no_merge, workers=workers, + overrides=overrides, ) calc_A = FAIRChemCalculator(predict_unit_A, task_name="omat") results_A = run_md_trial(atoms_template, calc_A, seed=42, steps=md_steps) @@ -1111,6 +1194,7 @@ def run_md_trial(atoms, calc, seed, steps): device=device, inference_settings=settings_no_merge, workers=workers, + overrides=overrides, ) calc_B = FAIRChemCalculator(predict_unit_B, task_name="omat") results_B = run_md_trial(atoms_template, calc_B, seed=42, steps=md_steps) @@ -1119,7 +1203,11 @@ def run_md_trial(atoms, calc, seed, steps): # Trial C: merge settings_merge = InferenceSettings(merge_mole=True, **base_settings) predict_unit_C = pretrained_mlip.get_predict_unit( - "uma-s-1p1", device=device, inference_settings=settings_merge, workers=workers + "uma-s-1p1", + device=device, + inference_settings=settings_merge, + workers=workers, + overrides=overrides, ) calc_C = FAIRChemCalculator(predict_unit_C, task_name="omat") results_C = run_md_trial(atoms_template, calc_C, seed=42, steps=md_steps) @@ -1773,3 +1861,97 @@ def test_execution_mode_not_set_when_conditions_not_met(model_name): f"Expected execution_mode to be None when activation_checkpointing=True, " f"got {predict_unit.inference_settings.execution_mode}" ) + + +# 
========================================================================= +# Full-model GP correctness: no-GP vs all-gather vs A2A +# Skipped on CI (no multi-GPU), run locally or via SLURM +# ========================================================================= + +_skip_if_ci = pytest.mark.skipif( + os.environ.get("CI") == "true", + reason="Multi-GPU test, skipped in CI", +) + + +@_skip_if_ci +@pytest.mark.gpu() +@pytest.mark.parametrize("num_atoms", [10, 50, 100]) +@pytest.mark.parametrize( + "workers, gp_mode", + [ + # All-gather (default GP) + (1, None), + (2, None), + (4, None), + # A2A + spatial + (1, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + (2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + (4, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}), + # A2A + index_split + (1, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), + (2, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), + (4, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}), + ], +) +def test_full_model_gp_correctness(num_atoms, workers, gp_mode): + seed = 42 + model_path = pretrained_checkpoint_path_from_name("uma-s-1p1") + ifsets = InferenceSettings( + tf32=False, + merge_mole=True, + activation_checkpointing=False, + internal_graph_gen_version=2, + external_graph_gen=False, + ) + atoms = get_fcc_crystal_by_num_atoms(num_atoms) + atomic_data = AtomicData.from_ase(atoms, task_name=["omat"]) + + overrides = None + if gp_mode is not None: + overrides = {"backbone": gp_mode} + + seed_everywhere(seed) + ppunit = ParallelMLIPPredictUnit( + inference_model_path=model_path, + device="cuda", + inference_settings=ifsets, + num_workers=workers, + overrides=overrides, + ) + pp_results = ppunit.predict(atomic_data) + distutils.cleanup_gp_ray() + + seed_everywhere(seed) + ref_unit = pretrained_mlip.get_predict_unit( + "uma-s-1p1", device="cuda", inference_settings=ifsets + ) + ref_results = 
ref_unit.predict(atomic_data) + + assert torch.allclose( + pp_results["energy"].detach().cpu(), + ref_results["energy"].detach().cpu(), + atol=ATOL, + ), ( + f"Energy mismatch: workers={workers}, gp_mode={gp_mode}, " + f"num_atoms={num_atoms}, " + f"pp={pp_results['energy'].item():.6f}, " + f"ref={ref_results['energy'].item():.6f}" + ) + assert torch.allclose( + pp_results["forces"].detach().cpu(), + ref_results["forces"].detach().cpu(), + atol=FORCE_TOL, + ), ( + f"Forces mismatch: workers={workers}, gp_mode={gp_mode}, " + f"num_atoms={num_atoms}, " + f"max_diff={torch.max(torch.abs(pp_results['forces'].detach().cpu() - ref_results['forces'].detach().cpu())).item():.6e}" + ) + assert torch.allclose( + pp_results["stress"].detach().cpu(), + ref_results["stress"].detach().cpu(), + atol=ATOL, + ), ( + f"Stress mismatch: workers={workers}, gp_mode={gp_mode}, " + f"num_atoms={num_atoms}" + )