Skip to content

[Draft] a2a model#2021

Draft
rayg1234 wants to merge 27 commits into
mainfrom
rgao_a2a_model
Draft

[Draft] a2a model#2021
rayg1234 wants to merge 27 commits into
mainfrom
rgao_a2a_model

Conversation

@rayg1234

@rayg1234 rayg1234 commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

No description provided.

rgao user added 4 commits May 12, 2026 05:48
All-to-all (A2A) communication module for graph parallelism, replacing
the all-gather approach with point-to-point exchange of only boundary
atoms.

Key components:
- GPContext: dataclass holding per-rank atom assignments and A2A metadata
- build_gp_context: builds communication plan from edge connectivity
- AllToAllCollect: autograd-compatible A2A embedding exchange
- all_to_all_collect_compiled: torch.compile-friendly variant using
  functional collectives (no graph break)
- _safe_all_to_all: Gloo fallback for CPU testing
- _sparse_index_exchange: variable-split index exchange

Tests (22 total, all CPU/Gloo):
- 5 unit tests for build_gp_context (context building, global-to-local
  mapping, edge split indices)
- 9 distributed tests comparing A2A vs all-gather correctness
  (forward, backward, multi-rank, spatial partition)
- 8 distributed correctness tests (dense/sparse graphs, multi-dim
  embeddings, index_split + spatial strategies)
Model integration for all-to-all graph parallelism:

eSCN-MD backbone (escn_md.py):
- Add use_all_to_all_gp and gp_partition_strategy config options
- Replace all-gather with A2A embedding exchange when enabled
- Spatial partitioning via Morton Z-order curve
- AABB halo filtering for reduced graph generation input
- Support both autograd and compiled A2A collect variants
- Block allgather+spatial combination (unsupported)

eSCN-MD block (escn_md_block.py):
- Add A2A collect integration in message passing layers
- Use precomputed edge split indices for local/remote separation

Graph generation (compute.py):
- Extend filter_edges_by_node_partition with send_info computation
- Eliminates NCCL index-exchange collective in build_gp_context
- Pass rank_assignments through generate_graph

Embedding/execution backends:
- Thread GP context through embedding and execution layers

Tests:
- Add send_info optimization correctness test
- Expand ParallelPredictUnit tests with A2A+spatial and
  A2A+index_split GP modes (CPU: workers=2, GPU: workers=1)
- 48 parallelism tests + 25 predict tests all pass
Parametrize test_merge_mole_md_consistency with A2A+spatial GP mode
to verify that all-to-all graph parallel with spatial partitioning
works correctly across multiple MD timesteps (NVT and NPT).

Skips A2A modes when workers < 2 since GP requires multi-rank.
2 new test cases pass (NVT+A2A, NPT+A2A), 2 correctly skipped.
…coverage

- Add is_single_system guard in _generate_graph so AABB halo bails
  out for multi-system batches (prevents cell.view(3,3) crash)
- Add defensive assertion in _compute_halo_graph verifying all
  local partition atoms appear in the halo mask
- Add A2A+index_split mode to MD consistency test parametrization
- Add A2A+spatial mode to batch predict test (exercises halo
  bail-out path for multi-system batches)
- Document single-system limitation in _compute_aabb_halo docstring
@meta-cla meta-cla Bot added the cla signed label Jun 3, 2026
rgao user added 23 commits June 3, 2026 22:58
Extract _resolve_send_metadata and _validate_gp_mappings from the
monolithic build_gp_context (270→120 lines). Validation now uses
fast-path .any().item() checks with full diagnostics only on error.

Add _allgather_index_exchange: packs recv_counts + needed_atoms into
one buffer for a single all-gather call (vs 2x all-to-all). Wired
into eSCN-MD via gp_index_exchange_method constructor parameter.

Benchmarks at GP=64 (8 nodes, 256K atoms): 2xA2A +22.9%, 1xAG +20.3%
vs baseline all-gather GP. Both variants beat baseline at scale.
Benchmarks showed 2xA2A is ~1.5% faster than 1xallgather at GP=64
across 8 nodes (0.593 vs 0.584 ns/day) due to lower communication
volume with variable splits. Remove the inferior allgather variant,
the index_exchange_method parameter, and the dispatch logic. Add
benchmark rationale to _sparse_index_exchange docstring.
Inference always requires autograd for force computation, so the
no-grad compile-friendly path was never hit. Remove the function,
its dispatch branch in escn_md_block, and associated tests.
The complex 13-parameter diagnostic function was hard to follow.
Replaced with two inline asserts for the only error conditions:
negative edge_index_local entries and out-of-bounds send_indices.
v2 radius graph does internal edge filtering so the full edge_index
is never exposed, making send_info derivation in compute.py
impossible. Revert compute.py to main, remove _resolve_send_metadata,
and always use _sparse_index_exchange in build_gp_context.
Remove 6 construction-only fields from GPContext (node_partition,
rank_assignments, needed_atoms, needed_from_ranks, global_to_local,
total_needed_atoms) that are never read after build_gp_context returns.
Update tests to verify through edge_index_local and message passing
instead of accessing removed intermediates.
The 10s timeout causes flaky failures in CI where runners are
resource-constrained and process rendezvous takes longer.
Benchmarks at 64 GPUs showed halo adds ~5% overhead vs plain A2A.
The _sparse_index_exchange already minimizes communication, and
radius_graph_pbc_v2 grid indexing makes graph gen fast enough that
filtering to a halo subset adds cost without measurable savings.

Removes ~150 lines: _compute_aabb_halo, _compute_halo_graph,
use_aabb_halo flag, and the halo call site in _generate_graph.
Also adds timeout_hr to slurm benchmark config.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant