
Mattral/Composed-Mixture-of-Experts-Engine

moe-engine  ·  A Composed Mixture-of-Experts Engine

Production-grade sparse MoE training runtime, designed to keep large-scale pre-training jobs alive end-to-end. It combines sparse Top-K routing in custom Triton kernels; DP+EP distributed training on PyTorch 2.12+, with TP support in the core layers and PP in progress; asynchronous sharded checkpointing through a two-tier (NVMe → S3/MinIO) durable store; and a TorchElastic state machine that evicts dead ranks, reshards experts, and hot-resumes training without operator intervention.

Apache-2.0 PyTorch Triton


Table of Contents

  1. Why this exists
  2. System architecture
  3. Hardware & software requirements
  4. Installation
  5. Local CPU / Gloo regression workflow
  6. Cluster-scale multi-GPU training
  7. Configuration reference
  8. Mathematical invariants & CI gates
  9. Telemetry envelope
  10. Fault-injection / chaos workflow
  11. Repository layout
  12. License

1. Why this exists

At frontier-lab scale, three engineering disciplines that are normally owned by separate teams must instead be co-designed in a single repository:

| Layer | Concern | This repo's contribution |
|---|---|---|
| Hardware-aware kernels | Memory coalescing, SRAM tiling, Tensor-Core feeding for sparse Top-K routing | pkg/kernels/moe_router.py: Triton forward + (planned) backward, dynamic-bound masking, 128-byte aligned loads |
| Distributed runtime | DP+EP training with TP support in core layers, FSDP2 DTensor sharding, EP all_to_all_single overlapped with compute | pkg/distributed/parallel_mesh.py: init_device_mesh((dp, ep)) with the TP axis reserved; EP collectives on dedicated CUDA comm streams |
| Fault-tolerant infra | Async pinned-memory checkpointing, S3/MinIO mirror, evict→reshard→reload state machine | pkg/elastic/fault_monitor.py: TorchElastic harness, SHARDED_STATE_DICT, signal-driven flush |

moe-engine keeps these three layers in one codebase, so an MFU regression or a checkpoint-stall bug can be traced to a single code path instead of escalating into a six-team incident.


2. System architecture

                                 ┌─────────────────────────────────────────────┐
                                 │              train.py  (entrypoint)         │
                                 │  argparse → load_config → build_topology    │
                                 └──────────────┬──────────────────────────────┘
                                                │
            ┌───────────────────────────────────┼────────────────────────────────────┐
            │                                   │                                    │
            ▼                                   ▼                                    ▼
┌────────────────────────────┐   ┌──────────────────────────────┐   ┌──────────────────────────────┐
│ pkg/distributed/           │   │ pkg/elastic/                 │   │ pkg/kernels/                 │
│   parallel_mesh.py         │   │   fault_monitor.py           │   │   moe_router.py              │
│                            │   │                              │   │                              │
│ • ParallelTopology         │   │ • ElasticTrainerHarness      │   │ • MoERouter (nn.Module)      │
│ • init_device_mesh((dp,ep))│   │ • AsyncCheckpointer          │   │   ├─ forward: Triton @jit    │
│   with TP axis reserved    │   │ • _PinnedHostStager          │   │   │   - dynamic-bound mask   │
│ • DistributedMoELayer      │   │ • ClusterStateMachine        │   │   │   - 128B aligned loads   │
│ • apply_fsdp2(...)         │   │ • LocalNVMeAdapter           │   │   └─ backward: Triton JIT    │
│ • all_to_all_dispatch      │   │ • S3Adapter (boto3)          │   │ • CPU autograd fallback      │
│   on dedicated comm stream │   │                              │   │                              │
└───────────┬────────────────┘   └─────────────┬────────────────┘   └──────────────┬───────────────┘
            │                                  │                                   │
            │   DeviceMesh sub-meshes          │   pinned-host snapshot queue      │  routing tokens
            │   ("pp","dp","ep","tp")          │                                   │  + gating weights
            │                                  │                                   │
            ▼                                  ▼                                   ▼
   ┌────────────────────┐         ┌───────────────────────────┐         ┌──────────────────────┐
   │ NCCL / Gloo        │         │ tier-1 NVMe (staging)     │         │ Triton runtime       │
   │ process groups     │         │ tier-2 S3 / MinIO mirror  │         │ (CUDA / ROCm)        │
   │ (one per axis)     │         │ background I/O thread×N   │         │                      │
   └────────────────────┘         └───────────┬───────────────┘         └──────────────────────┘
                                              │
                                              ▼
                                  ┌───────────────────────────┐
                                  │ TorchElastic agent        │
                                  │ (rdzv: c10d / etcd)       │
                                  │ rendezvous → restart loop │
                                  └───────────────────────────┘

                       Data-flow per training step
                       ───────────────────────────
   ids → embed → (TP shard) → block_0 ── … ── block_N → norm → lm_head → loss
                                  │
                                  ▼
                     DistributedMoELayer.forward
                     ┌──────────────────────────────────────────┐
                     │ 1. router (Triton fwd)                   │
                     │ 2. sort tokens by target EP rank         │
                     │ 3. all_to_all_single (dispatch)          │─────┐ launched async
                     │    on a dedicated comm stream            │     │
                     │ 4. independent compute, overlapped       │     │ comm in flight
                     │ 5. work.wait() on dispatch               │◄────┘
                     │ 6. local SwiGLU experts                  │
                     │ 7. all_to_all_combine                    │
                     │    on a dedicated comm stream            │
                     │ 8. weight ⊗ combine → reduce over K      │
                     └──────────────────────────────────────────┘
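Step 2 ("sort by target EP rank") is what turns router output into the per-rank split sizes an all_to_all_single call needs. A minimal pure-Python sketch, assuming experts are placed contiguously (expert e lives on EP rank e // (num_experts // ep_size)); plan_dispatch is a hypothetical helper, not this repo's API:

```python
from collections import Counter


def plan_dispatch(expert_ids, num_experts, ep_size):
    """Group routed token-slots by destination EP rank.

    expert_ids: flat list of chosen expert indices, one per (token, k) slot.
    Assumes contiguous placement: expert e lives on rank e // experts_per_rank.
    Returns (order, send_counts): `order` is a stable permutation of slot
    indices sorted by destination rank; send_counts[r] is the number of
    slots bound for rank r.
    """
    if num_experts % ep_size != 0:
        raise ValueError("num_experts must be divisible by ep_size")
    experts_per_rank = num_experts // ep_size
    dest = [e // experts_per_rank for e in expert_ids]
    # sorted() is stable, so token order is preserved within each destination.
    order = sorted(range(len(dest)), key=lambda i: dest[i])
    counts = Counter(dest)
    send_counts = [counts.get(r, 0) for r in range(ep_size)]
    return order, send_counts
```

On a real run, send_counts would become the input_split_sizes argument of torch.distributed.all_to_all_single, typically after the ranks first exchange their counts so each one can build output_split_sizes. The counts always sum to N·K, which is the token-conservation invariant in section 8.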

Each rank's coordinate in the mesh identifies its slice. Sub-meshes are obtained by name: mesh["dp"] (FSDP2 sharding), mesh["ep"] (all_to_all_single), and mesh["tp"] (column-parallel linears); mesh["pp"] is reserved for future pipeline-stage send/recv support.
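init_device_mesh lays ranks out row-major over the mesh shape (as torch.arange(world_size).view(mesh_shape)), so a rank's coordinate and the members of each named sub-mesh can be computed by hand. A small sketch with hypothetical helpers rank_to_coords and peers_along, handy for sanity-checking which ranks share a DP or EP group:

```python
from math import prod


def rank_to_coords(rank, mesh_shape, dim_names):
    """Map a global rank to {dim_name: index}, row-major like arange(world).view(shape)."""
    if len(mesh_shape) != len(dim_names):
        raise ValueError("mesh_shape and dim_names must have the same length")
    world = prod(mesh_shape)
    if not 0 <= rank < world:
        raise ValueError(f"rank {rank} outside a world of size {world}")
    coords = {}
    for name, size in zip(reversed(dim_names), reversed(mesh_shape)):
        coords[name] = rank % size  # the last dim varies fastest
        rank //= size
    return {name: coords[name] for name in dim_names}


def peers_along(rank, mesh_shape, dim_names, axis):
    """Ranks in `rank`'s sub-mesh for `axis`: same coordinate on every other dim."""
    me = rank_to_coords(rank, mesh_shape, dim_names)
    return [
        r
        for r in range(prod(mesh_shape))
        if all(
            c == me[n]
            for n, c in rank_to_coords(r, mesh_shape, dim_names).items()
            if n != axis
        )
    ]
```

For a (dp=2, ep=4) mesh, rank 6 sits at dp=1, ep=2; its EP group is ranks 4 to 7 and its DP group is ranks 2 and 6.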


3. Hardware & software requirements

Software

| Component | Minimum | Recommended | Notes |
|---|---|---|---|
| Python | 3.10 | 3.11 | |
| PyTorch | 2.5 | 2.12+ | init_device_mesh, FSDP2 (fully_shard), DCP |
| Triton | 3.0 | latest 3.x | required for the GPU forward kernel |
| CUDA | 12.1 | 12.4+ | for H100/H200/B200 BF16 paths |
| NCCL | 2.20 | 2.21+ | needed for TORCH_NCCL_ASYNC_ERROR_HANDLING |
| boto3 | 1.34 | latest | only if streaming to S3/MinIO |
| moto | 5.x | latest | local S3 mock for the chaos suite |

Hardware

| Profile | GPUs | Interconnect | Notes |
|---|---|---|---|
| Smoke / CI | none (CPU + Gloo) | localhost loopback | full unit + integration suite |
| Single-node dev | 1× H100 80GB | PCIe Gen5 | world=1 degenerate path |
| Pod (one node) | 8× H100 SXM5 | NVLink 4 | TP across the NVLink island, EP within node |
| Cluster | 256–10,240 H100 | NVLink + InfiniBand 400G | TP intra-node (size 8), PP inter-node, DP via FSDP2, EP across all GPUs |

The default config (configs/default.yaml) targets H100 SXM5 with a peak of 989 TFLOP/s BF16. Override telemetry.hardware_peak_tflops for B200/MI300X.
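MFU is the achieved FLOP/s divided by the job's aggregate hardware peak. The sketch below plugs the dynamic MoE FLOP formula from section 8 (taken as written, with its symbols treated as plain inputs) into that ratio and defaults to the 989 TFLOP/s H100 SXM5 figure; the function names are hypothetical, not the API of pkg/utils/mfu.py:

```python
H100_SXM5_BF16_PEAK_TFLOPS = 989.0  # default peak quoted above; override for other GPUs


def moe_step_flops(t_active, p_dense, t_routed, p_experts):
    """FLOPs_step = 2*T_active*P_dense + 2*T_routed*P_experts (section 8, as written)."""
    return 2 * t_active * p_dense + 2 * t_routed * p_experts


def mfu(step_flops, step_time_s, num_gpus, peak_tflops=H100_SXM5_BF16_PEAK_TFLOPS):
    """Model FLOPs utilisation: achieved FLOP/s over the job's aggregate peak FLOP/s."""
    achieved = step_flops / step_time_s
    peak = num_gpus * peak_tflops * 1e12
    return achieved / peak
```

Passing peak_tflops explicitly is the code-level analogue of overriding telemetry.hardware_peak_tflops for B200 or MI300X.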


4. Installation

git clone <this-repo> moe-engine && cd moe-engine

# Recommended: a fresh venv / conda env with Python 3.11
python -m venv .venv && source .venv/bin/activate

pip install -U pip wheel
pip install -r requirements.txt

# Optional: GPU-only Triton kernels
pip install "triton==3.*"         # quoted so the shell does not glob; already pinned in requirements.txt for cu12

# Optional: S3/MinIO mirror
pip install boto3 botocore

Verify the install:

python -c "import torch, triton; print(torch.__version__, triton.__version__)"

5. Local CPU / Gloo regression workflow

Every code path in this repo degrades cleanly to a no-op on a 1-rank world, so you can run the entire non-chaos test suite on a laptop:

# Unit + integration tests, ~20 s on a modern laptop
pytest -m "not chaos" -v

# Single-rank end-to-end smoke (5 training steps, toy 64-d model)
python train.py --config configs/smoke.yaml --smoke

To exercise the elastic / chaos suite with simulated multi-rank Gloo (spawned as subprocesses on localhost):

# Baseline + Scenario B (storage stall). Scenario A (sudden node failure)
# is lower priority and not included in default runs.
GLOO_SOCKET_IFNAME=lo pytest -m chaos -v -k "baseline or scenario_b"

6. Cluster-scale multi-GPU training

6.1 Single-node, 8× GPU

torchrun \
  --standalone \
  --nnodes=1 \
  --nproc_per_node=8 \
  train.py --config configs/default.yaml
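torchrun exports rank information to each worker through environment variables such as RANK, LOCAL_RANK, WORLD_SIZE and TORCHELASTIC_RESTART_COUNT. A sketch of how an entrypoint might read them while falling back to the 1-rank world used by the smoke path; read_launch_env and LaunchInfo are hypothetical names, not train.py's actual code:

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class LaunchInfo:
    rank: int           # global rank across all nodes
    local_rank: int     # index of this worker on its node (the GPU to bind)
    world_size: int     # total number of workers
    restart_count: int  # how many times the elastic agent has restarted workers


def read_launch_env(env=None):
    """Read launcher-provided variables, defaulting to a 1-rank world."""
    env = os.environ if env is None else env
    return LaunchInfo(
        rank=int(env.get("RANK", 0)),
        local_rank=int(env.get("LOCAL_RANK", 0)),
        world_size=int(env.get("WORLD_SIZE", 1)),
        restart_count=int(env.get("TORCHELASTIC_RESTART_COUNT", 0)),
    )
```

On GPU nodes, local_rank is the value usually passed to torch.cuda.set_device before any process group is created.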

6.2 Multi-node TorchElastic (rendezvous via c10d)

Set the same RDZV_ENDPOINT on every node (typically the head node's IP+port or an etcd cluster). The launcher script scripts/launch.sh wraps this:

# On every node:
NUM_NODES=32 \
GPUS_PER_NODE=8 \
RDZV_ENDPOINT=head-node:29500 \
RUN_ID=moe-run-001 \
bash scripts/launch.sh

The launcher injects the NCCL fail-fast environment variables (TORCH_NCCL_ASYNC_ERROR_HANDLING=1, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=30) and points the elastic agent at train.py. Workers can be added or removed mid-run; surviving ranks reshard experts and hot-resume from the most recent async checkpoint.
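The component status table later in this README describes the reshard step as round-robin. Assuming that means dealing expert indices across the surviving ranks in sorted rank order, a minimal sketch (reshard_round_robin is a hypothetical helper) looks like:

```python
def reshard_round_robin(num_experts, surviving_ranks):
    """Deal expert ids 0..num_experts-1 across surviving ranks in sorted rank order."""
    ranks = sorted(set(surviving_ranks))
    if not ranks:
        raise ValueError("no surviving ranks to host experts")
    placement = {rank: [] for rank in ranks}
    for expert in range(num_experts):
        placement[ranks[expert % len(ranks)]].append(expert)
    return placement
```

With 8 experts and rank 2 evicted from a 4-rank group, ranks 0, 1 and 3 receive [0, 3, 6], [1, 4, 7] and [2, 5]. Per-rank expert counts never differ by more than one, and each rank then reloads only the expert weights in its list from the latest checkpoint.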

6.3 Topology selection

configs/default.yaml::parallelism must satisfy tensor_parallel · pipeline_parallel · data_parallel · expert_parallel = WORLD_SIZE.

Example for 256 GPUs (32 nodes × 8 H100):

  • tensor_parallel: 8  (intra-node, NVLink)
  • pipeline_parallel: 4  (inter-node, IB)
  • data_parallel: 4  (FSDP2 across remaining axis)
  • expert_parallel: 2

The mesh constructor enforces this product equality, so mis-sized configs fail fast at boot rather than mid-step.
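The 256-GPU example checks out as 8 × 4 × 4 × 2 = 256. A minimal sketch of such a boot-time check, assuming the four key names listed above and (as an assumption) treating an omitted axis as size 1; validate_topology is hypothetical:

```python
from math import prod

AXES = ("tensor_parallel", "pipeline_parallel", "data_parallel", "expert_parallel")


def validate_topology(parallelism, world_size):
    """Fail fast unless the four parallel degrees multiply to WORLD_SIZE."""
    sizes = {axis: int(parallelism.get(axis, 1)) for axis in AXES}  # assumed default: 1
    for axis, size in sizes.items():
        if size < 1:
            raise ValueError(f"parallelism.{axis} must be >= 1, got {size}")
    total = prod(sizes.values())
    if total != world_size:
        detail = " x ".join(f"{a}={s}" for a, s in sizes.items())
        raise ValueError(f"{detail} = {total}, but WORLD_SIZE = {world_size}")
    return sizes
```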


7. Configuration reference

configs/default.yaml is the source of truth; configs/smoke.yaml shrinks every dimension for laptop runs. The top-level blocks are:

model:           # transformer hyperparameters
parallelism:     # topology axes; their product must equal WORLD_SIZE
training:        # batch sizes, optimizer, schedule
checkpoint:      # local NVMe dir, remote URI, async workers, retention
elastic:         # rendezvous, heartbeat interval, drop grace, min_nodes
telemetry:       # log dirs, MFU target, hardware peak TFLOPs
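A sketch of a loader-side sanity check plus the kind of dotted-key override that the telemetry.hardware_peak_tflops note in section 3 implies. REQUIRED_BLOCKS mirrors the six blocks listed above; missing_blocks and set_dotted are hypothetical helpers that operate on the already-parsed YAML dict, not functions from pkg/utils/config.py:

```python
REQUIRED_BLOCKS = ("model", "parallelism", "training", "checkpoint", "elastic", "telemetry")


def missing_blocks(cfg):
    """Names of required top-level blocks that are absent or not mappings."""
    return [name for name in REQUIRED_BLOCKS if not isinstance(cfg.get(name), dict)]


def set_dotted(cfg, dotted_key, value):
    """Set a nested key such as 'telemetry.hardware_peak_tflops' in place."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value
    return cfg
```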

8. Mathematical invariants & CI gates

| Invariant | Statement | Tested in |
|---|---|---|
| Mesh shape | dp_size · ep_size · tp_size = world_size for active axes | tests/test_distributed.py, tests/test_distributed_invariants.py |
| Token conservation | Σ_r dispatched_r = B·S·K | tests/test_distributed.py, tests/test_chaos.py |
| Numerical tolerance | atol < 1e-5, rtol < 1e-5 (fp64 reference) | tests/test_kernels.py |
| Checksum identity | hash(state_dict_post_load) == hash(state_dict_pre_save) | tests/test_elastic.py |
| Monotonic progression | step_{n+1} > step_n across every restart generation | tests/test_chaos.py |
| Comm-compute overlap | EP dispatch/combine use a dedicated CUDA stream and event synchronization | pkg/distributed/parallel_mesh.py::DistributedMoELayer.forward |
| Async-ckpt zero-leak | After harness.checkpoint(), no device-resident references survive into the writer thread queue | tests/test_elastic.py::test_async_ckpt_no_device_refs |
| MFU target | >= 0.55 of theoretical BF16 peak | pkg/utils/mfu.py |
| Dynamic MoE FLOP | FLOPs_step = 2·T_active·P_dense + 2·T_routed·P_experts | pkg/utils/mfu.py |

CI fails on violation of any of the above.
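Three of these gates reduce to a few lines each. A pure-Python sketch of token conservation, monotonic step progression and an order-independent state checksum; the helper names are hypothetical, and state_digest assumes each tensor has already been serialized to bytes:

```python
import hashlib


def check_token_conservation(dispatch_counts, batch, seq_len, top_k):
    """Σ_r dispatched_r must equal B·S·K: no token-slot lost or duplicated."""
    expected = batch * seq_len * top_k
    total = sum(dispatch_counts)
    if total != expected:
        raise AssertionError(f"dispatched {total} token-slots, expected B*S*K = {expected}")


def check_monotonic_steps(steps):
    """Logged steps must strictly increase, including across restart generations."""
    for prev, cur in zip(steps, steps[1:]):
        if cur <= prev:
            raise AssertionError(f"step regressed or repeated: {prev} -> {cur}")


def state_digest(state):
    """Order-independent SHA-256 over a {name: bytes} state dict."""
    h = hashlib.sha256()
    for name in sorted(state):
        for part in (name.encode("utf-8"), state[name]):
            h.update(len(part).to_bytes(8, "little"))  # length prefix avoids ambiguity
            h.update(part)
    return h.hexdigest()
```

Comparing state_digest before save and after load is one way to express the checksum-identity gate; sorting the keys makes the digest independent of dict insertion order.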


9. Telemetry envelope

Every training step emits one structured JSON record (also fanned out to TensorBoard via pkg/telemetry/logger.py):

{
  "step": 1024,
  "loss": 2.187,
  "mfu": 0.612,
  "tokens_per_sec": 184320,
  "wall_clock_ms": 412.3,
  "kernel": {
    "sram_bytes_per_block": 49152,
    "achieved_bw_gbps": 1843.0,
    "tokens_per_expert_mean": 256.4,
    "tokens_per_expert_std":  18.7,
    "used_triton": true
  },
  "collective": {
    "all_to_all_dispatch_ms": 0.87,
    "all_to_all_combine_ms":  0.91,
    "all_gather_ms": 1.21,
    "reduce_scatter_ms": 1.05,
    "overlap_ratio": 0.78
  },
  "memory": {
    "peak_allocated_gb": 62.4,
    "reserved_gb": 70.0,
    "leak_delta_gb": 0.0,
    "pinned_host_mb": 1842.0
  },
  "infra": {
    "async_ckpt_commit_ms": 412.0,
    "active_nodes": 1250,
    "ep_world_size": 64,
    "tp_world_size": 8,
    "pp_world_size": 4,
    "dp_world_size": 4,
    "restart_generation": 2
  }
}
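The envelope is one JSON object per line (JSON Lines), so each step can be appended to a log file and parsed independently. A hypothetical helper step_record_line that builds such a line from the top-level scalars plus nested sections such as kernel, collective, memory and infra:

```python
import json


def step_record_line(step, loss, mfu, tokens_per_sec, wall_clock_ms, **sections):
    """Build one JSON Lines record: top-level scalars plus nested sections."""
    record = {
        "step": step,
        "loss": loss,
        "mfu": mfu,
        "tokens_per_sec": tokens_per_sec,
        "wall_clock_ms": wall_clock_ms,
    }
    record.update(sections)  # e.g. kernel={...}, collective={...}, memory={...}, infra={...}
    # json.dumps escapes embedded newlines, so each record stays on one line.
    return json.dumps(record, separators=(",", ":")) + "\n"
```

Usage: open the log in append mode and write one line per step, e.g. f.write(step_record_line(1024, 2.187, 0.612, 184320, 412.3, kernel={"used_triton": True})).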

10. Fault-injection / chaos workflow

The chaos suite spawns 4 Gloo workers as subprocesses on localhost and exercises the full TorchElastic restart loop:

GLOO_SOCKET_IFNAME=lo pytest -m chaos -v

Scenarios:

| Scenario | What it injects | What it verifies |
|---|---|---|
| baseline | nothing | end-to-end correctness, monotonic step progression |
| A: sudden node failure | SIGKILL to one worker mid-step | (lower priority; not included in default test runs) |
| B: storage stall | injected 5-second time.sleep inside the storage adapter | async writer queue drains in the background, the training step never blocks, and the checkpoint eventually commits |

Status: baseline ✅, scenario_b ✅, scenario_a ⏸️ (lower priority)
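Scenario B checks that a slow storage backend stalls only the background writer, never the training step. A self-contained sketch of that pattern: an unbounded queue.Queue drained by a daemon thread that writes each shard to a temp file, fsyncs it, and atomically publishes it with os.replace. AsyncWriter and stall_s are hypothetical, and the sketch uses plain buffered writes rather than the O_DIRECT, 256 MB-chunk path described for the real checkpointer:

```python
import os
import queue
import tempfile
import threading
import time


class AsyncWriter:
    """Hypothetical background checkpoint writer (a sketch, not AsyncCheckpointer)."""

    def __init__(self, directory=None, stall_s=0.0):
        self.directory = directory or tempfile.mkdtemp(prefix="ckpt-")
        self.stall_s = stall_s            # injected delay per write, Scenario B style
        self.committed = []               # names published so far, in write order
        self._q = queue.Queue()           # unbounded, so put() never blocks the caller
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def enqueue(self, name, payload):
        """Called from the training loop; returns immediately."""
        self._q.put((name, bytes(payload)))

    def _run(self):
        while True:
            item = self._q.get()
            if item is None:              # sentinel from close()
                return
            name, payload = item
            time.sleep(self.stall_s)      # simulated storage stall
            final = os.path.join(self.directory, name)
            tmp = final + ".tmp"
            with open(tmp, "wb") as f:
                f.write(payload)
                f.flush()
                os.fsync(f.fileno())      # make the bytes durable before publishing
            os.replace(tmp, final)        # atomic rename on POSIX (same filesystem)
            self.committed.append(name)

    def close(self):
        """Drain every queued write (FIFO, sentinel last), then stop the worker."""
        self._q.put(None)
        self._thread.join()
```

Because readers only ever see a file after os.replace, a crash mid-write leaves at worst a stray .tmp file, never a truncated checkpoint under its final name.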


What is actually built

| Component | Status | Detail |
|---|---|---|
| Triton router kernel (forward) | ✅ CI-verified | Fused matmul+softmax+topK+renorm; single HBM pass; SRAM tile 64×64 |
| Triton router kernel (backward) | ✅ CI-verified | Analytic ∂/∂logits through softmax→topK→renorm; atol=rtol=1e-5 |
| Token conservation invariant | ✅ CI-verified | sum(dispatch_cnt) == N×K asserted every forward; 100-seed sweep |
| Expert load imbalance metric | ✅ v0.2 | max_load / mean_load tracked per step; logged to telemetry |
| Router z-loss | ✅ v0.2 | Auxiliary regularisation signal; emitted per step |
| DP+EP device mesh | ✅ CI-verified | init_device_mesh (PyTorch 2.5+); degenerate 1-rank fallback |
| EP all-to-all (dispatch + combine) | ✅ CI-verified | Non-blocking all_to_all_single; dedicated CUDA stream; event sync |
| Compute-comm overlap | | Expert FFN runs on the default stream while the a2a is in flight |
| Comm/compute overlap ratio | ✅ v0.3 | dispatch_ms / expert_compute_ms; emitted in the collective telemetry block |
| FSDP2 sharding | | fully_shard along the DP axis; per-param DTensor; MixedPrecision |
| Tensor Parallelism | ✅ v0.2 | ColumnParallelLinear + RowParallelLinear; w_gate and w_up both column-parallel; all_reduce in RowParallel; 2-rank mp.spawn verified |
| Sequence Parallelism | ✅ v0.2 | scatter/gather_sequence_parallel; active when tp_size > 1 |
| SP all-gather fusion | ✅ v0.3 | next_weight param fuses the backward all-gather with the next projection matmul; halves SP collectives; 2-rank mp.spawn verified |
| Pipeline Parallelism (single-process) | ✅ v0.2 | PipelineStage + 1F1B schedule; warmup/steady/drain phases; 13 unit tests |
| Pipeline Parallelism (multi-process) | ✅ v0.3 | run_1f1b_distributed; real dist.send/dist.recv on the PP group; activation tagging; 2-rank mp.spawn verified |
| MFU accounting | ✅ v0.2 | MoE-sparse formula: (K/E)×P_expert; MFUAccountant streaming tracker |
| Real CUDA telemetry | ✅ v0.2 | CUDA events on dispatch + combine; memory_stats() peak GB |
| WandB integration | ✅ v0.3 | WandBSink; active when WANDB_API_KEY is set; --wandb-project flag; log_config() records hyperparameters |
| Async two-tier checkpointing | ✅ CI-verified | Pinned host → NVMe (O_DIRECT, 256 MB chunks, atomic rename) → S3 |
| TorchElastic state machine | ✅ CI-verified | Evict → reshard (round-robin) → reload → resume |
| Etcd rendezvous | ✅ v0.2 | ElasticTrainerHarness backend selector; c10d (<100 nodes) / etcd (>100) |
| Prometheus metrics | ✅ v0.3 | Optional in-process /metrics endpoint; 10 gauges (incl. expert_compute_ms, comm_compute_overlap_ratio) |
| Docker + docker-compose | ✅ v0.2 | Multi-stage image; 1/4/8-GPU compose targets; monitoring stack |
| Kubernetes manifests | ✅ v0.2 | Single-node Job + multi-node Indexed Job; PVC; etcd rendezvous |
| Benchmark suite | ✅ v0.2 | benchmarks/run_benchmark.py; CPU+GPU sweeps; JSON/CSV output |
| Chaos: storage stall (Scenario B) | ✅ CI-verified | 10s injected stall; queue drains; no deadlock |
| Chaos: node kill + recovery (Scenario A) | ⚠️ Flaky | ~85% pass rate; Gloo connectFullMesh timeout on 4-rank restart |
| Nsight/CUPTI integration | ❌ Planned v0.4 | Requires GPU hardware |
| Real multi-node benchmark data | ❌ Planned v0.4 | Requires sustained cluster access |
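The router rows above describe softmax → top-K → renorm. A pure-Python reference of that math, of the kind a CPU fallback or an fp64 correctness check could compare against; route_topk is a hypothetical name, and breaking ties toward the lower expert index is an assumption:

```python
import math


def route_topk(logits, k):
    """Softmax over expert logits, keep the top-k, renormalise the kept weights.

    Returns (expert_ids, weights) with the weights summing to 1. Ties go to
    the lower expert index because sorted() is stable.
    """
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and the number of experts")
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # max-subtraction for numerical stability
    z = sum(exps)
    probs = [e / z for e in exps]
    top = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
    kept = sum(probs[i] for i in top)
    return top, [probs[i] / kept for i in top]
```

Renormalising the kept probabilities is equivalent to taking a softmax over only the selected logits, which gives a cheap cross-check: for logits [1.0, 3.0, 2.0, 0.5] and k=2, experts 1 and 2 are chosen and their weight ratio is e^(3-2) = e. Each token contributes exactly k slots, so N tokens dispatch N×K slots.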

11. Repository layout

moe-engine/
├── pkg/
│   ├── kernels/moe_router.py        Triton fwd+bwd kernel, MoERouter module
│   ├── distributed/parallel_mesh.py 4D mesh, DistributedMoELayer, TP/SP/PP layers
│   ├── elastic/fault_monitor.py     AsyncCheckpointer, ClusterStateMachine, harness
│   ├── telemetry/logger.py          Structured JSONL + TensorBoard + Prometheus + WandB
│   └── utils/
│       ├── mfu.py                   MoE-aware MFU accounting + streaming tracker
│       └── config.py                YAML config loader
├── benchmarks/
│   ├── run_benchmark.py             Reproducible benchmark suite (CPU+GPU)
│   └── BENCHMARKS.md               Methodology + results + engineering notes
├── deploy/
│   ├── docker/
│   │   ├── Dockerfile               Multi-stage image (builder + runtime)
│   │   ├── docker-compose.yml       smoke / 4-GPU / 8-GPU / monitoring targets
│   │   └── prometheus.yml           Prometheus scrape config
│   └── k8s/
│       ├── namespace.yaml
│       ├── configmap.yaml
│       ├── training-job.yaml        Single-node 8-GPU Job
│       ├── training-job-multinode.yaml  16-node Indexed Job + etcd rendezvous
│       └── pvc.yaml                 ReadWriteMany checkpoint PVC
├── configs/
│   ├── default.yaml                 H100-scale production config
│   └── smoke.yaml                   CPU-only 2-step test config
├── tests/                           Full test suite (15 files, 145 tests)
├── docs/                            Architecture, design, operations docs
├── train.py                         TorchElastic entrypoint
├── roadmap.md                       Honest status + next actions
└── pyproject.toml

What would I do differently at 1000+ GPUs

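  1. Replace Gloo with NCCL everywhere: Gloo's connectFullMesh opens a connection between every pair of ranks, so setup cost grows as O(N²) in the number of ranks. At 1000+ ranks, re-forming the mesh after a node drop dominates recovery time. NCCL bootstraps over a ring instead of a full mesh, and its tree-based collectives scale logarithmically.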

  2. Gradient checkpointing at the expert level: at extreme scale, the expert activation tensors for the combine step cannot all stay live simultaneously. Selectively recomputing each expert's SwiGLU intermediate, silu(x·W_gate) ⊙ (x·W_up), halves peak memory for the MoE layers.

  3. Overlapped NVMe checkpoint streaming: the current design copies the full shard to pinned memory before enqueuing. At very large shard sizes (80 GB+), a better design streams tensor-by-tensor directly from CUDA to NVMe without staging the full shard in host RAM.

  4. Expert-level capacity overflow handling: the current capacity_factor=1.25 simply drops overflow tokens. Switch Transformer and GShard also drop overflow (the token passes through the residual connection unchanged). Re-routing overflow to the token's next-choice expert instead requires a second router pass and adds an estimated ~5% router overhead, but is essential for training stability.

  5. Sequence parallelism by default at TP>2: at long context (128K tokens), the activations a single sequence must keep for the backward pass no longer fit on one GPU at fp32. Sequence parallelism is not optional at those scales; it should be the default code path, not an auxiliary feature.
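Point 4 contrasts dropping overflow tokens with re-routing them. A simplified sketch that gives each token a single slot (its first choice) and, when re-routing is enabled, falls back to its later choices. The capacity rule ceil(capacity_factor · tokens · K / experts) is an assumed top-K generalisation of the Switch Transformer rule, and both helpers are hypothetical:

```python
import math


def expert_capacity(num_tokens, num_experts, top_k, capacity_factor=1.25):
    """Per-expert slot budget (assumed top-k generalisation of the Switch rule)."""
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)


def assign_with_capacity(choices, num_experts, capacity, reroute=False):
    """choices[t] lists token t's experts in preference order.

    Each token places one slot. If its first choice is full, the token is
    dropped (reroute=False) or tries its later choices (reroute=True).
    Returns (assignment, dropped) where assignment[t] is an expert or None.
    """
    load = [0] * num_experts
    assignment, dropped = [], 0
    for prefs in choices:
        candidates = prefs if reroute else prefs[:1]
        for e in candidates:
            if load[e] < capacity:
                load[e] += 1
                assignment.append(e)
                break
        else:  # no candidate had room
            assignment.append(None)
            dropped += 1
    return assignment, dropped
```

With 5 tokens that all prefer expert 0 (second choice expert 1) and a capacity of 3, dropping loses 2 tokens, while re-routing sends them to expert 1 and loses none.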


12. License

Apache-2.0. See LICENSE.

About

moe-engine is a research-grade infrastructure layer for training large Mixture-of-Experts language models at hyperscale. It is designed around one core constraint: at 10K+ GPUs, nodes die continuously. The system must keep training alive end-to-end — routing correctly, checkpointing durably, and resuming without operator intervention.
