Skip to content

fix(distributed): wire to torch-sla DSparseTensor API (post PR #31)#14

Closed
walkerchi wants to merge 5 commits into
devfrom
feat/torch-sla-dsparse-integration
Closed

fix(distributed): wire to torch-sla DSparseTensor API (post PR #31)#14
walkerchi wants to merge 5 commits into
devfrom
feat/torch-sla-dsparse-integration

Conversation

@walkerchi

Copy link
Copy Markdown
Collaborator

Summary

torch-sla PR #31 removed the legacy single-process `DSparseTensor(values, row, col, shape=..., num_partitions=..., coords=..., partition_method=...)` constructor in favour of classmethod entry points (`partition` / `from_local` / `from_sparse_local`).

`distributed_element_assemble` was calling the dead ctor, so every distributed assembly path errored out:

```
TypeError: DSparseTensor.init() got an unexpected keyword argument 'shape'
```

Fix

Small `_build_dsparse` helper that wraps the merged global COO through `DSparseTensor.partition(A_global, mesh, partition_method, coords)`:

  • torch.distributed initialised → real distributed-shard DSparseTensor on the live world (each rank gets its own row-shard).
  • Not initialised (single-process driver + multi-thread assembly, or unit tests) → `mesh=None` gives world=1; the DSparseTensor holds the entire global matrix locally. matvec / solve still compose with the rest of torch-sla.

New test

`test_multiproc_distributed_assemble_matvec_2procs`: spawns 2 gloo ranks, each runs TensorMesh distributed assembly, builds a DSparseTensor, runs matvec in Shard(0) space (x sliced to owned), allgathers the result, compares to the single-mesh reference assembly.

Test plan

  • 11/11 distributed assembly tests pass on CPU
  • Multi-process gloo end-to-end (2 ranks, matvec match to fp64 eps)
  • CUDA / multi-GPU runtime — pending GPU box

walkerchi added 3 commits June 14, 2026 22:27
torch-sla PR #31 removed the legacy DSparseTensor(values, row, col,
shape=..., num_partitions=..., coords=..., partition_method=...) ctor
in favour of the classmethod constructors (partition / from_local /
from_sparse_local). The old single-process simulator path is gone.

distributed_element_assemble was calling the dead ctor, so every
distributed assembly path errored out with TypeError on import.

Switched to ``DSparseTensor.partition(A_global, mesh, partition_method,
coords)`` via a small ``_build_dsparse`` helper:

  * torch.distributed initialised -> real distributed-shard
    DSparseTensor on the live world (one rank, one row-shard).
  * Not initialised (single-process driver + multi-thread assembly,
    or unit tests) -> mesh=None gives world=1; the DSparseTensor
    holds the entire global matrix locally. matvec / solve still
    compose with the rest of torch-sla.

Added test_multiproc_distributed_assemble_matvec_2procs: spawns 2
gloo ranks, each runs TensorMesh distributed assembly, builds a
DSparseTensor, runs matvec in Shard(0) space (x sliced to owned),
allgathers the result, compares to reference single-mesh assembly.
Rel err 0 vs reference on a triangle rectangle mesh.

All 11 distributed tests pass.
…t tests

The original PR only had a multi-proc *matvec* test. Now adds three
real end-to-end tests that exercise the full distributed-solve path
TensorMesh -> DSparseTensor -> cg_shard -> allgather -> compare:

  * test_multiproc_distributed_solve_2procs: assemble Mass matrix on
    2 gloo ranks, build b = M @ x_ref from a single-process reference,
    run distributed CG, allgather x, verify x_dist == x_ref to 1e-6.
    Mass instead of Laplace because Laplace has a constant null space
    that distributed CG drifts along (separate issue).

  * test_multiproc_distributed_solve_4procs: same with world=4 to
    exercise more partitions + more halo edges.

  * test_multiproc_poisson_dirichlet_2procs: full physics path on
    2 ranks (assemble + Condenser for Dirichlet BCs + single-process
    SparseMatrix solve via distributed_element_assemble_to_sparse,
    compare to single-process reference). The to_sparse path returns
    a tensormesh SparseMatrix, not a DSparseTensor -- this test
    verifies the distributed assembly produces the same global matrix
    as single-process even under multi-process orchestration.

Implementation note: the solve worker calls cg_shard directly (raw
tensor in / out) instead of the higher-level torch_sla.solve(D, b_dt)
wrapper. The wrapper expects DTensor[Shard(0)] for b; building one
from a manually-sliced owned tensor via DTensor.from_local works but
adds DTensor wrapping overhead unrelated to the test goal.

14/14 distributed tests pass.
The previous test_multiproc_poisson_dirichlet_2procs only exercised
distributed_element_assemble_to_sparse which returns a SparseMatrix
(single-process) -- never touched the actual DSparseTensor distributed
solve path. Adds two tests that close that gap:

  test_multiproc_poisson_dirichlet_dsparse_2procs
  test_multiproc_poisson_dirichlet_dsparse_4procs

Pipeline per rank:
  1. TensorMesh distributed assembly -> global Laplace K + load f.
  2. Single-process Condenser strips Dirichlet DOFs -> K_inner, f_inner
     (cheap at this size; the partitioning happens after).
  3. Wrap K_inner as torch_sla.SparseTensor + DSparseTensor.partition
     (partition_method=metis, since Condenser breaks the 1:1 mesh-point
     mapping that RCB/coordinate methods would need).
  4. Slice f_inner to this rank's owned rows.
  5. Run distributed CG via cg_shard with M_apply=identity.
  6. Allgather x_owned -> u_inner_global -> condenser.recover() -> full u.
  7. Compare to single-process reference u: rel-err < 1e-6 required.

This is the actual integration deliverable: real PDE, real distributed
matrix, real distributed solver, end-to-end matches single-process.

16/16 distributed tests pass on CPU (2 new + 14 existing).
@walkerchi walkerchi changed the base branch from main to dev June 15, 2026 08:03
walkerchi and others added 2 commits June 15, 2026 16:09
The two end-to-end distributed solve tests called the internal
cg_shard primitive directly. Tests represent the user surface, so they
should go through the public torch_sla.solve API with iterative
defaults scoped via SolverConfig. Same dispatch path under the hood
(SolverConfig -> solve -> cg_shard for DSparseTensor+CG) but the
example now reads like the documented usage.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
…atter

Three sites in test_distributed_assemble.py each open-coded the same
pad -> 3x dist.all_gather -> Python for-loop scatter pattern (~14 lines
each). The torch-sla side just landed a vectorised public helper
(sparsexlab/torch-sla#41 -- gather_owned_to_global with one index_put_
over an all_gather_into_tensor buffer), so swap each site to a single
call.

-44 lines, +6 lines, same semantics.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant