Production-grade sparse MoE training runtime. Designed to keep large-scale pre-training jobs alive end-to-end: sparse Top-K routing in custom Triton, DP+EP distributed training with TP support in core layers and PP work in progress, on PyTorch 2.12+, asynchronous sharded checkpointing through a two-tier (NVMe → S3 / MinIO) durable store, and a TorchElastic state-machine that evicts dead ranks, reshards experts, and hot-resumes training without operator intervention.
- Why this exists
- System architecture
- Hardware & software requirements
- Installation
- Local CPU / Gloo regression workflow
- Cluster-scale multi-GPU training
- Configuration reference
- Mathematical invariants & CI gates
- Telemetry envelope
- Fault-injection / chaos workflow
- Repository layout
- License
At frontier-lab scale three engineering disciplines that are normally separate teams must instead be co-designed in a single repository:
| Layer | Concern | This repo's contribution |
|---|---|---|
| Hardware-aware kernels | Memory coalescing, SRAM tiling, Tensor-Core feeding for sparse Top-K routing | pkg/kernels/moe_router.py — Triton forward + (planned) backward, dynamic-bound masking, 128-byte aligned loads |
| Distributed runtime | DP+EP training with TP support in core layers, FSDP2 DTensor sharding, EP all_to_all_single overlapped with compute |
pkg/distributed/parallel_mesh.py — init_device_mesh((dp, ep)) with TP axis reserved, dedicated comm stream overlap, dedicated CUDA streams |
| Fault-tolerant infra | Async pinned-memory checkpointing, S3/MinIO mirror, evict→reshard→reload state-machine | pkg/elastic/fault_monitor.py — TorchElastic harness, SHARDED_STATE_DICT, signal-driven flush |
moe-engine keeps these three layers in one binary so an MFU regression or a
checkpoint-stall bug can be isolated to a single line, not a six-team incident.
┌─────────────────────────────────────────────┐
│ train.py (entrypoint) │
│ argparse → load_config → build_topology │
└──────────────┬──────────────────────────────┘
│
┌───────────────────────────────────┼────────────────────────────────────┐
│ │ │
▼ ▼ ▼
┌────────────────────────────┐ ┌──────────────────────────────┐ ┌──────────────────────────────┐
│ pkg/distributed/ │ │ pkg/elastic/ │ │ pkg/kernels/ │
│ parallel_mesh.py │ │ fault_monitor.py │ │ moe_router.py │
│ │ │ │ │ │
│ • ParallelTopology │ │ • ElasticTrainerHarness │ │ • MoERouter (nn.Module) │
│ • init_device_mesh((dp,ep))│ │ • AsyncCheckpointer │ │ ├─ forward: Triton @jit │
│ with TP axis reserved │ │ • _PinnedHostStager │ │ │ - dynamic-bound mask │
│ • DistributedMoELayer │ │ • ClusterStateMachine │ │ │ - 128B aligned loads │
│ • apply_fsdp2(...) │ │ • LocalNVMeAdapter │ │ └─ backward: Triton JIT │
│ • all_to_all_dispatch │ │ • S3Adapter (boto3) │ │ • CPU autograd fallback │
│ on dedicated comm stream │ │ │ │ │
└───────────┬────────────────┘ └─────────────┬────────────────┘ └──────────────┬───────────────┘
│ │ │
│ DeviceMesh sub-meshes │ pinned-host snapshot queue │ routing tokens
│ ("pp","dp","ep","tp") │ │ + gating weights
│ │ │
▼ ▼ ▼
┌────────────────────┐ ┌───────────────────────────┐ ┌──────────────────────┐
│ NCCL / Gloo │ │ tier-1 NVMe (staging) │ │ Triton runtime │
│ process groups │ │ tier-2 S3 / MinIO mirror │ │ (CUDA / ROCm) │
│ (one per axis) │ │ background I/O thread×N │ │ │
└────────────────────┘ └───────────┬───────────────┘ └──────────────────────┘
│
▼
┌───────────────────────────┐
│ TorchElastic agent │
│ (rdzv: c10d / etcd) │
│ rendezvous → restart loop │
└───────────────────────────┘
Data-flow per training step
───────────────────────────
ids → embed → (TP shard) → block_0 ── … ── block_N → norm → lm_head → loss
│
▼
DistributedMoELayer.forward
┌────────────────────────────┐
│ 1. router (Triton fwd) │
│ 2. sort by target EP rank │
│ 3. all_to_all_single │
│ on a dedicated comm stream ──► launch ─────────┐
│ 4. independent compute ─── overlap ───►│ GPU compute
│ 5. work.wait() on dispatch │ in flight
│ 6. local SwiGLU experts │
│ 7. all_to_all_combine on a dedicated comm stream ──┘
│ 8. weight ⊗ combine → reduce-K │
└──────────────────────────────────┘
Per-rank a coordinate identifies its mesh slice. Sub-meshes are obtained by name:
mesh["dp"] (for FSDP2 sharding), mesh["ep"] (for all_to_all_single),
mesh["tp"] (for column-parallel linears), and mesh["pp"] is reserved
for future pipeline stage send/recv support.
| Component | Minimum | Recommended | Notes |
|---|---|---|---|
| Python | 3.10 | 3.11 | |
| PyTorch | 2.5 | 2.12+ | init_device_mesh, FSDP2 (fully_shard), DCP |
| Triton | 3.0 | 3.x latest | required for GPU forward kernel |
| CUDA | 12.1 | 12.4+ | for H100/H200/B200 BF16 paths |
| NCCL | 2.20 | 2.21+ | needed for TORCH_NCCL_ASYNC_ERROR_HANDLING |
boto3 |
1.34 | latest | only if streaming to S3/MinIO |
moto |
5.x | latest | local S3 mock for the chaos suite |
| Profile | GPUs | Interconnect | Notes |
|---|---|---|---|
| Smoke / CI | none (CPU + Gloo) | localhost loopback | full unit + integration suite |
| Single-node dev | 1× H100 80GB | PCIe Gen5 | world=1 degenerate path |
| Pod (one node) | 8× H100 SXM5 | NVLink 4 | TP across the NVLink island, EP within node |
| Cluster | 256–10 240 H100 | NVLink + InfiniBand 400G | TP intra-node (size 8), PP inter-node, DP via FSDP2, EP across all GPUs |
The default config (configs/default.yaml) targets H100 SXM5 with a peak of
989 TFLOP/s BF16. Override telemetry.hardware_peak_tflops for B200/MI300X.
git clone <this-repo> moe-engine && cd moe-engine
# Recommended: a fresh venv / conda env with python 3.11
python -m venv .venv && source .venv/bin/activate
pip install -U pip wheel
pip install -r requirements.txt
# Optional: GPU-only Triton kernels
pip install triton==3.* # already pinned in requirements.txt for cu12
# Optional: S3/MinIO mirror
pip install boto3 botocoreVerify the install:
python -c "import torch, triton; print(torch.__version__, triton.__version__)"Every code path in this repo no-ops cleanly on a 1-rank world. You can run the entire non-chaos test suite on a laptop:
# Unit + integration tests, ~20 s on a modern laptop
pytest -m "not chaos" -v
# Single-rank end-to-end smoke (5 training steps, toy 64-d model)
python train.py --config configs/smoke.yaml --smokeTo exercise the elastic / chaos suite with simulated multi-rank Gloo (spawned as subprocesses on localhost):
# Baseline + Scenario B (storage stall). Scenario A (sudden node failure)
# is lower priority and not included in default runs.
GLOO_SOCKET_IFNAME=lo pytest -m chaos -v -k "baseline or scenario_b"torchrun \
--standalone \
--nnodes=1 \
--nproc_per_node=8 \
train.py --config configs/default.yamlSet the same RDZV_ENDPOINT on every node (typically the head node's IP+port
or an etcd cluster). The launcher script scripts/launch.sh wraps this:
# On every node:
NUM_NODES=32 \
GPUS_PER_NODE=8 \
RDZV_ENDPOINT=head-node:29500 \
RUN_ID=moe-run-001 \
bash scripts/launch.shThe launcher injects the NCCL fail-fast environment variables
(TORCH_NCCL_ASYNC_ERROR_HANDLING=1, TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=30)
and points the elastic agent at train.py. Workers can be
added or removed mid-run; surviving ranks reshard experts and hot-resume from
the most recent async checkpoint.
configs/default.yaml::parallelism must satisfy
tensor_parallel · pipeline_parallel · data_parallel · expert_parallel = WORLD_SIZE.
Example for 256 GPUs (32 nodes × 8 H100):
tensor_parallel: 8(intra-node, NVLink)pipeline_parallel: 4(inter-node, IB)data_parallel: 4(FSDP2 across remaining axis)expert_parallel: 2
The mesh constructor enforces this product equality; missized configs fail fast at boot rather than mid-step.
configs/default.yaml is the source of truth; configs/smoke.yaml shrinks
every dimension for laptop runs. Every block:
model: # transformer hyperparameters
parallelism: # topology axes, must product to WORLD_SIZE
training: # batch sizes, optimizer, schedule
checkpoint: # local NVMe dir, remote URI, async workers, retention
elastic: # rendezvous, heartbeat interval, drop grace, min_nodes
telemetry: # log dirs, MFU target, hardware peak TFLOPs| Invariant | Statement | Tested in |
|---|---|---|
| Mesh shape | dp_size · ep_size · tp_size = world_size for active axes |
tests/test_distributed.py, tests/test_distributed_invariants.py |
| Token conservation | Σ_r dispatched_r = B·S·K |
tests/test_distributed.py, tests/test_chaos.py |
| Numerical tolerance | atol < 1e-5, rtol < 1e-5 (fp64 reference) |
tests/test_kernels.py |
| Checksum identity | hash(state_dict_post_load) == hash(state_dict_pre_save) |
tests/test_elastic.py |
| Monotonic progression | step_{n+1} > step_n across every restart generation |
tests/test_chaos.py |
| Comm-compute overlap | EP dispatch/combine use a dedicated CUDA stream and event synchronization | pkg/distributed/parallel_mesh.py::DistributedMoELayer.forward |
| Async-ckpt zero-leak | After harness.checkpoint(), no device-resident references survive into the writer thread queue |
tests/test_elastic.py::test_async_ckpt_no_device_refs |
| MFU target | >= 0.55 of theoretical peak BF16 |
pkg/utils/mfu.py |
| Dynamic MoE FLOP | FLOPs_step = 2·T_active·P_dense + 2·T_routed·P_experts |
pkg/utils/mfu.py |
CI fails on violation of any of the above.
Every training step emits one structured JSON record (also fanned out to
TensorBoard via pkg/telemetry/logger.py):
{
"step": 1024,
"loss": 2.187,
"mfu": 0.612,
"tokens_per_sec": 184320,
"wall_clock_ms": 412.3,
"kernel": {
"sram_bytes_per_block": 49152,
"achieved_bw_gbps": 1843.0,
"tokens_per_expert_mean": 256.4,
"tokens_per_expert_std": 18.7,
"used_triton": true
},
"collective": {
"all_to_all_dispatch_ms": 0.87,
"all_to_all_combine_ms": 0.91,
"all_gather_ms": 1.21,
"reduce_scatter_ms": 1.05,
"overlap_ratio": 0.78
},
"memory": {
"peak_allocated_gb": 62.4,
"reserved_gb": 70.0,
"leak_delta_gb": 0.0,
"pinned_host_mb": 1842.0
},
"infra": {
"async_ckpt_commit_ms": 412.0,
"active_nodes": 1250,
"ep_world_size": 64,
"tp_world_size": 8,
"pp_world_size": 4,
"dp_world_size": 4,
"restart_generation": 2
}
}The chaos suite spawns 4 Gloo workers as subprocesses on localhost and exercises the full TorchElastic restart loop:
GLOO_SOCKET_IFNAME=lo pytest -m chaos -vScenarios:
| Scenario | What it injects | What it verifies |
|---|---|---|
| baseline | nothing | end-to-end correctness, monotonic step progression |
| A: sudden node failure | SIGKILL to one worker mid-step |
(lower priority, not included in default test runs) |
| B: storage stall | injected 5-second time.sleep inside the storage adapter |
async writer queue drains in background, training step never blocks, ckpt eventually commits |
Status: baseline ✅, scenario_b ✅, scenario_a ⏸️ (lower priority)
| Component | Status | Detail |
|---|---|---|
| Triton router kernel — forward | ✅ CI-verified | Fused matmul+softmax+topK+renorm; single HBM pass; SRAM tile 64×64 |
| Triton router kernel — backward | ✅ CI-verified | Analytic ∂/∂logits through softmax→topK→renorm; atol=rtol=1e-5 |
| Token conservation invariant | ✅ CI-verified | sum(dispatch_cnt) == N×K asserted every forward; 100-seed sweep |
| Expert load imbalance metric | ✅ v0.2 | max_load / mean_load tracked per step; logged to telemetry |
| Router z-loss | ✅ v0.2 | Auxiliary regularisation signal; emitted per step |
| DP+EP device mesh | ✅ CI-verified | init_device_mesh (PyTorch 2.5+); degenerate 1-rank fallback |
| EP all-to-all (dispatch + combine) | ✅ CI-verified | Non-blocking all_to_all_single; dedicated CUDA stream; event sync |
| Compute-comm overlap | ✅ | Expert FFN runs on default stream while a2a is in flight |
| Comm/compute overlap ratio | ✅ v0.3 | dispatch_ms / expert_compute_ms; emitted in collective telemetry block |
| FSDP2 sharding | ✅ | fully_shard along DP axis; per-param DTensor; MixedPrecision |
| Tensor Parallelism | ✅ v0.2 | ColumnParallelLinear + RowParallelLinear; both w_gate and w_up ColumnParallel; all_reduce in RowParallel; 2-rank mp.spawn verified |
| Sequence Parallelism | ✅ v0.2 | scatter/gather_sequence_parallel; active when tp_size > 1 |
| SP all-gather fusion | ✅ v0.3 | next_weight param fuses backward all-gather with next projection matmul; halves SP collectives; 2-rank mp.spawn verified |
| Pipeline Parallelism (single-process) | ✅ v0.2 | PipelineStage + 1F1B schedule; warmup/steady/drain phases; 13 unit tests |
| Pipeline Parallelism (multi-process) | ✅ v0.3 | run_1f1b_distributed; real dist.send/dist.recv on PP group; activation tagging; 2-rank mp.spawn verified |
| MFU accounting | ✅ v0.2 | MoE-sparse formula: (K/E)×P_expert; MFUAccountant streaming tracker |
| Real CUDA telemetry | ✅ v0.2 | CUDA events on dispatch + combine; memory_stats() peak GB |
| WandB integration | ✅ v0.3 | WandBSink; active when WANDB_API_KEY set; --wandb-project flag; log_config() records hyperparameters |
| Async two-tier checkpointing | ✅ CI-verified | Pinned host → NVMe (O_DIRECT, 256 MB chunks, atomic rename) → S3 |
| TorchElastic state machine | ✅ CI-verified | Evict → reshard (round-robin) → reload → resume |
| Etcd rendezvous | ✅ v0.2 | ElasticTrainerHarness backend selector; c10d (<100 nodes) / etcd (>100) |
| Prometheus metrics | ✅ v0.3 | Optional in-process /metrics endpoint; 10 gauges (incl. expert_compute_ms, comm_compute_overlap_ratio) |
| Docker + docker-compose | ✅ v0.2 | Multi-stage image; 1/4/8-GPU compose targets; monitoring stack |
| Kubernetes manifests | ✅ v0.2 | Single-node Job + multi-node Indexed Job; PVC; etcd rendezvous |
| Benchmark suite | ✅ v0.2 | benchmarks/run_benchmark.py; CPU+GPU sweeps; JSON/CSV output |
| Chaos: storage stall (Scenario B) | ✅ CI-verified | 10s injected stall; queue drains; no deadlock |
| Chaos: node kill + recovery (Scenario A) | ~85% pass rate; Gloo connectFullMesh timeout on 4-rank restart |
|
| Nsight/CUPTI integration | ❌ Planned v0.4 | Requires GPU hardware |
| Real multi-node benchmark data | ❌ Planned v0.4 | Requires sustained cluster access |
moe-engine/
├── pkg/
│ ├── kernels/moe_router.py Triton fwd+bwd kernel, MoERouter module
│ ├── distributed/parallel_mesh.py 4D mesh, DistributedMoELayer, TP/SP/PP layers
│ ├── elastic/fault_monitor.py AsyncCheckpointer, ClusterStateMachine, harness
│ ├── telemetry/logger.py Structured JSONL + TensorBoard + Prometheus + WandB
│ └── utils/
│ ├── mfu.py MoE-aware MFU accounting + streaming tracker
│ └── config.py YAML config loader
├── benchmarks/
│ ├── run_benchmark.py Reproducible benchmark suite (CPU+GPU)
│ └── BENCHMARKS.md Methodology + results + engineering notes
├── deploy/
│ ├── docker/
│ │ ├── Dockerfile Multi-stage image (builder + runtime)
│ │ ├── docker-compose.yml smoke / 4-GPU / 8-GPU / monitoring targets
│ │ └── prometheus.yml Prometheus scrape config
│ └── k8s/
│ ├── namespace.yaml
│ ├── configmap.yaml
│ ├── training-job.yaml Single-node 8-GPU Job
│ ├── training-job-multinode.yaml 16-node Indexed Job + etcd rendezvous
│ └── pvc.yaml ReadWriteMany checkpoint PVC
├── configs/
│ ├── default.yaml H100-scale production config
│ └── smoke.yaml CPU-only 2-step test config
├── tests/ Full test suite (15 files, 145 tests)
├── docs/ Architecture, design, operations docs
├── train.py TorchElastic entrypoint
├── roadmap.md Honest status + next actions
└── pyproject.toml
-
Replace Gloo with NCCL everywhere — Gloo's
connectFullMeshis O(N²) in the number of ranks. At 1000+ ranks the re-formation time after a node drop dominates recovery. NCCL uses a ring-based topology that scales logarithmically. -
Gradient checkpointing at the expert level — At extreme scale, the expert activation tensors for the combine step can't all stay live simultaneously. Selectively recomputing the
w_up × silu(w_gate)activation per expert halves peak memory for the MoE layers. -
Overlapped NVMe checkpoint streaming — The current design copies the full shard to pinned memory before enqueuing. At very large shard sizes (80GB+), a better design streams tensor-by-tensor directly from CUDA to NVMe without staging the full shard in host RAM.
-
Expert-level capacity overflow handling — The current
capacity_factor=1.25simply drops overflow tokens. Production systems (Switch Transformer, GShard) re-route overflow to the second-choice expert. This requires a second router pass and adds ~5% router overhead but is essential for training stability. -
Sequence parallelism by default at TP>2 — At long context (128K tokens), the hidden state tensor for a single sequence doesn't fit on one GPU at fp32. Sequence parallelism is not optional at those scales; it should be the default codepath, not an auxiliary feature.
Apache-2.0. See LICENSE.