
fix(distributed): wire to torch-sla DSparseTensor API (post PR #31) #16

Closed
walkerchi wants to merge 5 commits into dev/core from feat/torch-sla-dsparse-integration

@walkerchi

Collaborator

Re-opens #14 against the new dev/core integration branch. The original dev branch was retired during the rename, so GitHub auto-closed the PR and refused to reopen it against a deleted base.

Summary

Status

Multi-proc tests pass under Gloo (CPU, 16/16, ~30s on autodl 2x A100).

Test plan

  • pytest tests/distributed/test_distributed_assemble.py on autodl, Gloo backend, world=2 and world=4

(continuation of closed PR #14)

walkerchi and others added 5 commits June 14, 2026 22:27
torch-sla PR #31 removed the legacy DSparseTensor(values, row, col,
shape=..., num_partitions=..., coords=..., partition_method=...) ctor
in favour of the classmethod constructors (partition / from_local /
from_sparse_local). The old single-process simulator path is gone.

distributed_element_assemble was calling the dead ctor, so every
distributed assembly path errored out with TypeError on import.

Switched to ``DSparseTensor.partition(A_global, mesh, partition_method,
coords)`` via a small ``_build_dsparse`` helper:

  * torch.distributed initialised -> real distributed-shard
    DSparseTensor on the live world (one rank, one row-shard).
  * Not initialised (single-process driver + multi-thread assembly,
    or unit tests) -> mesh=None gives world=1; the DSparseTensor
    holds the entire global matrix locally. matvec / solve still
    compose with the rest of torch-sla.
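
The two branches above come down to one check on the process group. A minimal sketch of that check, using only `torch.distributed` calls; `world_info` is a hypothetical name for illustration and is not the actual `_build_dsparse` helper, and the guarded import is only there so the snippet also runs where torch is not installed:

```python
try:
    import torch.distributed as dist
except ImportError:
    # Assumption for this sketch: no torch means "behave like world=1".
    dist = None


def world_info():
    """Return (rank, world_size); (0, 1) when no process group is live."""
    if dist is not None and dist.is_available() and dist.is_initialized():
        # Live world: one rank, one row-shard.
        return dist.get_rank(), dist.get_world_size()
    # Single-process driver or unit test: the whole matrix stays local.
    return 0, 1


rank, world_size = world_info()
```

In the uninitialised case the result is `(0, 1)`, which corresponds to the mesh=None path described above.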

Added test_multiproc_distributed_assemble_matvec_2procs: spawns 2
gloo ranks, each runs TensorMesh distributed assembly, builds a
DSparseTensor, runs matvec in Shard(0) space (x sliced to owned),
allgathers the result, compares to reference single-mesh assembly.
Rel err 0 vs reference on a triangle rectangle mesh.
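
What the matvec test checks can be sketched in a single process with plain lists. This is an illustration only: the contiguous row split, the `halo` dictionary and all function names are assumptions made for the sketch, not torch-sla's partitioning or halo exchange.

```python
def reference_matvec(A, x):
    """Dense single-process matvec used as the reference."""
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def block_ranges(n, world):
    """Near-even contiguous row ranges, one per simulated rank."""
    base, extra = divmod(n, world)
    ranges, start = [], 0
    for r in range(world):
        stop = start + base + (1 if r < extra else 0)
        ranges.append((start, stop))
        start = stop
    return ranges


def sharded_matvec(A, x, world):
    """Each simulated rank multiplies its owned rows, then results are joined."""
    parts = []
    for start, stop in block_ranges(len(A), world):
        x_owned = x[start:stop]  # x sliced to the owned rows
        # Halo: entries of x outside the owned range that owned rows touch.
        halo = {
            j: x[j]
            for i in range(start, stop)
            for j, a in enumerate(A[i])
            if a != 0 and not start <= j < stop
        }
        y_owned = []
        for i in range(start, stop):
            acc = 0
            for j, a in enumerate(A[i]):
                if a != 0:
                    xj = x_owned[j - start] if start <= j < stop else halo[j]
                    acc += a * xj
            y_owned.append(acc)
        parts.append(y_owned)
    # Stand-in for the allgather: concatenate owned results in rank order.
    return [v for part in parts for v in part]


# Small tridiagonal matrix with integer entries, so the comparison is exact.
A = [
    [2, -1, 0, 0, 0],
    [-1, 2, -1, 0, 0],
    [0, -1, 2, -1, 0],
    [0, 0, -1, 2, -1],
    [0, 0, 0, -1, 2],
]
x = [1, 2, 3, 4, 5]
y_ref = reference_matvec(A, x)
y_dist = sharded_matvec(A, x, world=2)
```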

All 11 distributed tests pass.
…t tests

The original PR only had a multi-proc *matvec* test. Now adds three
real end-to-end tests that exercise the full distributed-solve path
TensorMesh -> DSparseTensor -> cg_shard -> allgather -> compare:

  * test_multiproc_distributed_solve_2procs: assemble Mass matrix on
    2 gloo ranks, build b = M @ x_ref from a single-process reference,
    run distributed CG, allgather x, verify x_dist == x_ref to 1e-6.
    Mass instead of Laplace because Laplace has a constant null space
    that distributed CG drifts along (separate issue).

  * test_multiproc_distributed_solve_4procs: same with world=4 to
    exercise more partitions + more halo edges.

  * test_multiproc_poisson_dirichlet_2procs: full physics path on
    2 ranks (assemble + Condenser for Dirichlet BCs + single-process
    SparseMatrix solve via distributed_element_assemble_to_sparse,
    compare to single-process reference). The to_sparse path returns
    a tensormesh SparseMatrix, not a DSparseTensor -- this test
    verifies the distributed assembly produces the same global matrix
    as single-process even under multi-process orchestration.
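
The null-space remark in the first test can be checked on a small 1D analogue (the tests themselves use a triangle mesh; the 1D linear elements and all names below are assumptions for illustration). Without Dirichlet conditions the stiffness matrix maps the constant vector to zero, while the mass matrix does not:

```python
def assemble_1d(n_el, length=1.0):
    """Assemble P1 stiffness K and mass M on a uniform 1D mesh."""
    h = length / n_el
    n = n_el + 1
    K = [[0.0] * n for _ in range(n)]
    M = [[0.0] * n for _ in range(n)]
    for e in range(n_el):
        for a in (0, 1):
            for b in (0, 1):
                # Element stiffness (1/h) [[1, -1], [-1, 1]]
                K[e + a][e + b] += (1.0 if a == b else -1.0) / h
                # Element mass (h/6) [[2, 1], [1, 2]]
                M[e + a][e + b] += (2.0 if a == b else 1.0) * h / 6.0
    return K, M


def matvec(A, x):
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


K, M = assemble_1d(n_el=4)
ones = [1.0] * 5
K_ones = matvec(K, ones)  # every entry is zero: constants lie in the null space
M_ones = matvec(M, ones)  # strictly positive entries that sum to the domain length
```

Because `K @ ones` vanishes, a solution of the pure Laplace system is only determined up to an additive constant, which is consistent with the choice of the mass matrix for the solve tests.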

Implementation note: the solve worker calls cg_shard directly (raw
tensor in / out) instead of the higher-level torch_sla.solve(D, b_dt)
wrapper. The wrapper expects DTensor[Shard(0)] for b; building one
from a manually-sliced owned tensor via DTensor.from_local works but
adds DTensor wrapping overhead unrelated to the test goal.

14/14 distributed tests pass.

The previous test_multiproc_poisson_dirichlet_2procs only exercised
distributed_element_assemble_to_sparse which returns a SparseMatrix
(single-process) -- never touched the actual DSparseTensor distributed
solve path. Adds two tests that close that gap:

  test_multiproc_poisson_dirichlet_dsparse_2procs
  test_multiproc_poisson_dirichlet_dsparse_4procs

Pipeline per rank:
  1. TensorMesh distributed assembly -> global Laplace K + load f.
  2. Single-process Condenser strips Dirichlet DOFs -> K_inner, f_inner
     (cheap at this size; the partitioning happens after).
  3. Wrap K_inner as torch_sla.SparseTensor + DSparseTensor.partition
     (partition_method=metis, since Condenser breaks the 1:1 mesh-point
     mapping that RCB/coordinate methods would need).
  4. Slice f_inner to this rank's owned rows.
  5. Run distributed CG via cg_shard with M_apply=identity.
  6. Allgather x_owned -> u_inner_global -> condenser.recover() -> full u.
  7. Compare to single-process reference u: rel-err < 1e-6 required.
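
Steps 2, 5, 6 and 7 can be sketched in one process on a 1D Poisson problem. This is a hedged illustration, not the tensormesh Condenser or torch-sla's cg_shard: `condense`, `recover` and `cg` are hypothetical stand-ins, the matrices are dense lists, and the partitioning of steps 3 and 4 is left out.

```python
def matvec(A, x):
    return [sum(a * xj for a, xj in zip(row, x)) for row in A]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def assemble_poisson_1d(n_el):
    """P1 stiffness K and load f for -u'' = 1 on (0, 1), uniform mesh."""
    h = 1.0 / n_el
    n = n_el + 1
    K = [[0.0] * n for _ in range(n)]
    f = [0.0] * n
    for e in range(n_el):
        K[e][e] += 1.0 / h
        K[e][e + 1] -= 1.0 / h
        K[e + 1][e] -= 1.0 / h
        K[e + 1][e + 1] += 1.0 / h
        f[e] += h / 2.0
        f[e + 1] += h / 2.0
    return K, f


def condense(K, f, dirichlet):
    """Strip Dirichlet DOFs; dirichlet maps DOF index -> prescribed value."""
    inner = [i for i in range(len(f)) if i not in dirichlet]
    K_inner = [[K[i][j] for j in inner] for i in inner]
    # Move the known boundary values to the right-hand side.
    f_inner = [f[i] - sum(K[i][j] * g for j, g in dirichlet.items()) for i in inner]
    return K_inner, f_inner, inner


def recover(u_inner, inner, dirichlet, n):
    """Rebuild the full solution from inner values and boundary values."""
    u = [0.0] * n
    for j, g in dirichlet.items():
        u[j] = g
    for i, ui in zip(inner, u_inner):
        u[i] = ui
    return u


def cg(A, b, rtol=1e-12, max_iter=200):
    """Unpreconditioned conjugate gradients (identity preconditioner)."""
    x = [0.0] * len(b)
    r = list(b)
    p = list(r)
    rs = dot(r, r)
    stop = (rtol ** 2) * dot(b, b)
    for _ in range(max_iter):
        if rs <= stop:
            break
        Ap = matvec(A, p)
        alpha = rs / dot(p, Ap)
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * ai for ri, ai in zip(r, Ap)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


n_el = 8
K, f = assemble_poisson_1d(n_el)
dirichlet = {0: 0.0, n_el: 0.0}  # u(0) = u(1) = 0
K_inner, f_inner, inner = condense(K, f, dirichlet)
u = recover(cg(K_inner, f_inner), inner, dirichlet, n_el + 1)

# Exact solution u(x) = x (1 - x) / 2, sampled at the mesh nodes.
exact = [(i / n_el) * (1.0 - i / n_el) / 2.0 for i in range(n_el + 1)]
err = max(abs(a - b) for a, b in zip(u, exact))
```

In the real tests the comparison target is the single-process solve rather than an analytic solution; the analytic reference is used here only to keep the sketch self-contained.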

This is the actual integration deliverable: real PDE, real distributed
matrix, real distributed solver, end-to-end matches single-process.

16/16 distributed tests pass on CPU (2 new + 14 existing).

The two end-to-end distributed solve tests called the internal
cg_shard primitive directly. Tests represent the user surface, so they
should go through the public torch_sla.solve API with iterative
defaults scoped via SolverConfig. Same dispatch path under the hood
(SolverConfig -> solve -> cg_shard for DSparseTensor+CG) but the
example now reads like the documented usage.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…atter

Three sites in test_distributed_assemble.py each open-coded the same
pad -> 3x dist.all_gather -> Python for-loop scatter pattern (~14 lines
each). The torch-sla side just landed a vectorised public helper
(sparsexlab/torch-sla#41 -- gather_owned_to_global with one index_put_
over an all_gather_into_tensor buffer), so swap each site to a single
call.

-44 lines, +6 lines, same semantics.
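
The difference between the two patterns can be sketched in a single process with plain lists. This does not call gather_owned_to_global, whose signature is not shown here; both function names are hypothetical, and the split of the three gathers into indices, values and counts is an assumption about the open-coded version.

```python
def gather_padded_loop(owned_idx, owned_val, n_global):
    """Pad -> gather -> per-rank Python loop scatter (the replaced pattern)."""
    max_len = max(len(idx) for idx in owned_idx)
    # Pad each rank's buffers to a common length; -1 marks padding slots.
    idx_buf = [idx + [-1] * (max_len - len(idx)) for idx in owned_idx]
    val_buf = [val + [0.0] * (max_len - len(val)) for val in owned_val]
    count_buf = [len(idx) for idx in owned_idx]
    out = [0.0] * n_global
    for rank in range(len(owned_idx)):
        # Only the first count_buf[rank] slots hold real data.
        for k in range(count_buf[rank]):
            out[idx_buf[rank][k]] = val_buf[rank][k]
    return out


def gather_flat_scatter(owned_idx, owned_val, n_global):
    """Flatten every rank's buffers, then scatter in one pass."""
    flat_idx = [i for idx in owned_idx for i in idx]
    flat_val = [v for val in owned_val for v in val]
    out = [0.0] * n_global
    for i, v in zip(flat_idx, flat_val):
        out[i] = v
    return out


# Two simulated ranks with non-contiguous ownership of five global rows.
owned_idx = [[0, 3, 4], [1, 2]]
owned_val = [[10.0, 13.0, 14.0], [11.0, 12.0]]
looped = gather_padded_loop(owned_idx, owned_val, 5)
flat = gather_flat_scatter(owned_idx, owned_val, 5)
```

Both variants yield the same global vector; the single-pass form is the shape that a tensor-level scatter such as `index_put_` can replace with one call.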

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
@walkerchi

Collaborator Author

Superseded by #18 (single rolled-up PR for the whole stack — #16 commits are included).

@walkerchi walkerchi closed this Jun 15, 2026