
🛣️ Path to 30B MoE long-context SFT training #5713

@AmineDiro

Description


Tracking issue / draft PR for SFTTrainer scaling to frontier-scale MoE.

Goal: train Qwen3-30B-A3B (and 235B-A22B) end-to-end with TRL's SFTTrainer at long context (16k → 1M) on 8×H100 nodes, with MFU competitive with dense training. This page tracks the cross-repo work (transformers / accelerate / TRL / DeepSpeed / Liger) it depends on, the known issues, and the performance numbers we land along the way.

Community contributions and reproduction reports on different hardware are very welcome. Please drop a comment with your config + numbers.

Why this is non-trivial

Qwen3-30B-A3B has 128 experts (8 active per token). Out of the box, it currently doesn't train cleanly under any combination of FSDP2 / DeepSpeed / Expert Parallel / Sequence Parallel / Context Parallel with TRL. This issue tracks the recipes to get there.

Status snapshot — Qwen3-30B-A3B (2026-05-06)

Best end-to-end-correct configurations on 2× / 4× / 8× p5.48xlarge nodes (H100 SXM5, 989.5 TFLOPS bf16 peak).

Window MFU is the peak per-log-window throughput; adj. is the causal-corrected value (matching the Llama 2/3 / DS-Ulysses convention: half the attention FLOPs disappear under causal masking).
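The causal correction can be sketched as below. This is an illustrative approximation only — the exact accounting (MoE active params, embed_flops=0, lm_head_flops=2*V*h) lives in TRL's compute_flops_per_token in trl/trainer/utils.py, and the function/parameter names here are hypothetical:

```python
def flops_per_token(n_active_params, n_layers, hidden_size, seq_len, causal=True):
    """Rough training FLOPs per token (fwd + bwd), in the usual 6N + attention style."""
    weight_flops = 6 * n_active_params                   # 2N forward + 4N backward
    attn_flops = 12 * n_layers * hidden_size * seq_len   # full-mask QK^T + AV, fwd+bwd
    if causal:
        attn_flops //= 2   # causal masking discards half the score matrix
    return weight_flops + attn_flops


def mfu(tokens_per_sec_per_gpu, flops_per_tok, peak_flops=989.5e12):
    # H100 SXM5 bf16 dense peak = 989.5 TFLOPS (from the table header)
    return tokens_per_sec_per_gpu * flops_per_tok / peak_flops
```

The gap between "window" and "adj." in the table is exactly the `causal=True` halving of the attention term; the weight term is unchanged.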

| Context | Nodes | Recipe | Window MFU | Adj. MFU | Loss |
|---|---|---|---|---|---|
| 16k | 2 | FSDP2 + EP=8 + FA3 + sonicmoe (post-PR #45621 kernel fix) | 48.2 % | 32.3 % | 13.4 ✅ |
| 32k | 2 | DS-Z2 + EP=8 + FA3 + sonicmoe + Liger | 65.0 % | 39.2 % | 8.0 ✅ |
| 64k | 4 | DS-Z2 + EP=32 + FA3 + sonicmoe + Liger (R3) | 72.0 % | 40.1 % | 1.87 ✅ |
| 256k | 8 | DS-Z3 + SP=2 + FA3 + sonicmoe + Liger + compile | 63.6 % | 32.8 % | 1.56 ✅ |
| 512k | 8 | DS-Z3 + SP=4 + FA3 + sonicmoe + Liger + compile | 63.3 % | 32.1 % | |
| 1M | 8 | DS-Z3 + SP=8 + FA3 + sonicmoe + Liger + compile | 62.3 % | 31.4 % | 1.56 ✅ |
| 128k | 4 | DS-Z2 + EP=32 + FA3 + sonicmoe + Liger | 81.7 % | 43.4 % | NaN ⛔ |

For Qwen3-235B-A22B, early numbers: 32k context, 8 nodes, EP=64 + Liger + compile = 70.1 % window MFU with healthy loss.

Hardware / software baseline

  • Cluster: AWS p5.48xlarge — 8× H100 SXM5 80 GB per node, 32× EFA, 3200 Gbps inter-node aggregate
  • TRL: benchmark-sft-moe branch (this branch will land via the PRs in §TRL below)
  • transformers: 5.6.0.dev0 + the qwen3-moe-ep-v2 series of fixes (see §transformers below)
  • accelerate: 1.13.0 + an in-place _prepare_tp has_ep skip (§accelerate)
  • DeepSpeed: stock — no fork, no in-venv patch. DS's MoE auto-detection is extended in transformers/trainer.py via a transparent monkey-patch on DeepSpeedEngine._configure_distributed_model (see §DeepSpeed).
  • PyTorch: 2.10.0+cu128
  • Liger Kernel: fused CrossEntropy + RMSNorm + RoPE (Triton). SwiGLU patch must be disabled under EP (--liger_kernel_config '{"swiglu":false}')
  • Flash Attention 3: kernels-community/vllm-flash-attn3
  • MoE kernel: kernels-community/sonic-moe ("revision": "ep-support") selected via --experts_implementation sonicmoe
  • Dataset: THUDM/LongAlign-10k packed with --packing --packing_strategy wrapped

Linked PRs

transformers

Core EP support + correctness fixes. The first three are the load-bearing ones. Without them, EP > 1 produces silently wrong expert outputs.

  • Merged #45436 — Add EP support for Qwen3 MoE, fix GroupedGemmParallel for 2D meshes
  • Merged #45473 — Fix EP routing: RouterParallel shape, tp_plan property, grouped_mm sentinels (3 bugs that combined to produce silently wrong logits at every EP > 1; first regression tests for EP)
  • Merged #45433 — Integrate the kernels-community/sonic-moe CuteDSL fused MoE kernel as a selectable _experts_implementation. Drop-in for grouped_mm; +23 % steady-state MFU vs grouped_mm at 16k EP=8 (single biggest kernel win on this stack).
  • Merged #45621 — sonicmoe kernel-side sentinel fix. Drops the wrapper-level expert_ids.clamp workaround; +5–8 pp peak MFU at 16k–32k EP=8. TRL-side contribution from this work: the grouped_mm_experts_forward wrapper-level masked_fill_ pre/post-mask pair (integrations/moe.py) that pairs with the kernel fix to keep gradients clean on EP-sentinel rows.
  • 🟡 Open #45994 — Fix reported loss over-counting under TP and EP.
  • 🟡 Open #45662 — EP + FSDP DTensor wrap (lets EP-sharded params survive FSDP2's ignored_params boundary)
  • 🟡 Open #45548 — DeepSpeed-Z3 + EP weight loading
  • 🟡 Open #45649 — FSDP cpu_ram_efficient_loading fixes
  • Planned — tp_plan loader rank-0 gate (Layer 1 of the duplicate-load issue: every TP/EP rank reads the full dense replica from disk; ~480 GB redundant disk I/O per node on 30B EP=8). Draft at benchmark/upstream_issue_tp_plan_duplicate_load.md.
  • Planned — ValueError guard when enable_expert_parallel=True and cpu_ram_efficient_loading=True are combined
  • Planned — DS-Z2 + EP transformers PR (4 patches: tensor_parallel.py:GroupedGemmParallel.post_shard_wrap backend branch, Trainer.create_accelerator_and_postprocess MoE-group setup, Trainer.create_optimizer MoE param split, Trainer._clip_grad_norm cross-mesh skip) + a post-deepspeed.initialize engine attribute patch that replaces the originally-planned DeepSpeed-side change
  • Planned — wire DistributedConfig(enable_expert_parallel=True) into accelerate.state.parallelism_config.ep_size once accelerate exposes that field
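The pre/post-mask pair mentioned under #45621 can be sketched as follows. This is an illustrative toy, not the real wrapper in TRL's integrations/moe.py; the function name and signature here are hypothetical, and only the pattern (mask sentinel-routed rows before and after the expert kernel so they contribute neither activations nor gradients) mirrors the description above:

```python
import torch


def masked_experts_forward(hidden, expert_ids, expert_fn, sentinel=-1):
    """Zero out rows routed to the EP sentinel id around an expert kernel.

    Pre-mask: sentinel rows enter the kernel as zeros.
    Post-mask: whatever the kernel produced on those rows is discarded,
    keeping both outputs and gradients clean on EP-sentinel rows.
    """
    invalid = (expert_ids == sentinel).unsqueeze(-1)   # [tokens, 1]
    hidden = hidden.masked_fill(invalid, 0.0)          # pre-mask
    out = expert_fn(hidden)
    return out.masked_fill(invalid, 0.0)               # post-mask
```

A bias-adding "expert" makes the point: without the post-mask, the bias would leak onto rows that were never really routed to any expert.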

accelerate

  • Merged #4022 — Fix regional compilation under FSDP2; enables torch_compile=True.
  • Planned — _prepare_tp early-return when model.has_ep (post-#45662 EP params become DTensors and trigger an ImportError on ReplicateParallel; single-line check)
  • Planned — first-class ep_size field in ParallelismConfig + submesh_ep_size divisor in prepare_data_loader (mirrors existing TP handling exactly). Today TRL piggy-backs on the TP-replication path by exposing the EP mesh as "tp". Highest-leverage upstream PR remaining.
  • Planned — fsdp2_prepare_model capture/restore around the meta-move so EP ignored_params survive the FSDP rank-0 broadcast (paired with the transformers cpu_ram_efficient_loading work above)

TRL

The benchmark-sft-moe branch will be split into a series of PRs against main. Order is independence-first — anything that can ship alone goes first.

  1. Open #5698 — compute_flops_per_token + compute_mfu in trl/trainer/utils.py. Pure-Python, no SFTTrainer coupling. Includes causal-correct attention and the embed_flops=0 / lm_head_flops=2*V*h accounting fix.
  2. SFTTrainer.log() MFU integration (window + cumulative; corrects num_input_tokens_seen overcount by cp_size × sp_size)
  3. enable_expert_parallel + expert_parallel_size + experts_implementation config fields and the EP branch in SFTTrainer.__init__ (depends on the transformers PRs landing)
  4. ⏳ SP --pad_to_multiple_of auto-default when accelerator.parallelism_config.sp_size > 1
  5. Merged #5575 by @qgallouedec — Chunked cross-entropy loss for SFT (up to −50 % VRAM). Load-bearing for this stack: it's what frees the ~20 GB lm_head logit tensor and lets the EP-replicated expert buffer fit at 32k → 128k context, which unlocks every long-context champion in the table above.
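The idea behind the chunked cross-entropy in #5575 can be sketched as below. This is a minimal illustration under my own naming (the real TRL implementation differs): project one sequence chunk at a time so the full [seq_len, vocab] logit tensor is never materialized.

```python
import torch
import torch.nn.functional as F


def chunked_ce_loss(hidden, lm_head_w, labels, chunk_size=1024):
    """Token-mean cross-entropy without materializing all logits at once.

    hidden: [seq_len, hidden], lm_head_w: [vocab, hidden], labels: [seq_len]
    (-100 = ignored). Each chunk projects to [chunk, vocab], accumulates a
    summed loss, and is freed before the next chunk is projected.
    """
    total = hidden.new_zeros(())
    n_kept = 0
    for i in range(0, hidden.size(0), chunk_size):
        logits = hidden[i:i + chunk_size] @ lm_head_w.t()   # [chunk, vocab]
        lbl = labels[i:i + chunk_size]
        keep = lbl != -100
        if keep.any():
            total = total + F.cross_entropy(logits[keep], lbl[keep], reduction="sum")
            n_kept += int(keep.sum())
    return total / max(n_kept, 1)
```

Peak activation memory for the loss drops from seq_len × vocab to chunk_size × vocab, which is where the "~20 GB lm_head logit tensor" saving at long context comes from.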

Known issues / open blockers

Status: 🔧 active workaround · 🟡 local fix in place, upstream fix may differ · ⛔ not yet fixed.

  • 🔧 EP-aware DataLoader sharding — Local fix in transformers/trainer.py that exposes the EP mesh as "tp" so accelerate's existing TP-replication branch in prepare_data_loader divides correctly. Without it, every world rank gets a unique micro-batch but the 8 ranks of an EP group must see the same batch (EP shards experts only, not data) → silent NCCL hang on the EP all-reduce after a random number of steps. Highest-leverage upstream cleanup remaining.
  • 🟡 _clip_grad_norm cross-mesh failure — clip_grad_norm_'s _foreach_norm stacks per-param norms; stacking DTensors that live on different meshes (EP mesh + FSDP DP mesh) errors with RuntimeError: All operands in aten.stack.default must have the same mesh. The local skip returns tensor(0.0) for telemetry — gradients are not actually clipped. Fine for benchmarks, unsafe for production; the proper fix belongs in PyTorch.
  • ⛔ FSDP + EP + compile Adam crash: _group_tensors_by_device_and_dtype strict-asserts grouped tensors share device+dtype; EP DTensors (EP mesh, size 8) and FSDP DP DTensors (size 16) trip the assert. DS-Z2+EP+compile works because DS uses plain nn.Parameter (no DTensor mesh).
  • ⛔ FA3 + CP incompatible: accelerate hard-guards CP to sdpa attention. Long-context MoE therefore has to choose between FA3 throughput (SP path) and ring-attention sequence sharding (CP path). Cost: ~5 pp MFU at 32k FSDP+EP+CP=2.
  • ⛔ Loss-zero / NaN at large EP + long-ctx: at "≥32k context AND large total rank count", loss.item() PRE-backward is NaN from step 1, masked downstream so the trainer reports loss=0. NaN appears inside the optimizer step for the FSDP/DS-sharded entry-side params (embed_tokens.weight + layers.0.self_attn.{q,k,v,o}_proj.weight). Forward + backward math is clean; bug is in (grad → optim_state → param). High priority — the 81 % window MFU rows on this stack are throughput-real but training-incorrect today. Healthy convergence at 64k 4n EP=32 (1.87 loss) and at SP-path long-context (1.56 loss).
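The EP-aware DataLoader rule from the first bullet above reduces to simple rank arithmetic. A sketch (illustrative names only — the real fix routes through accelerate's TP-replication branch in prepare_data_loader):

```python
def dp_index(global_rank, ep_size):
    """Effective data-parallel index for a rank under expert parallelism.

    EP shards experts, not data, so all ep_size ranks in an EP group must
    see the same micro-batch: the dataloader shard is chosen by EP-group
    index, not by global rank.
    """
    return global_rank // ep_size


def num_data_shards(world_size, ep_size):
    """Number of distinct micro-batch streams, i.e. the true DP degree."""
    return world_size // ep_size
```

If instead every world rank got a unique micro-batch (shard = global rank), the 8 ranks of an EP group would run different token counts through the shared expert all-reduce — the silent NCCL hang described above.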

How to reproduce

SFTTrainer with --enable_expert_parallel requires patched checkouts of transformers and accelerate until the PRs in §transformers and §accelerate land. The fork branches below carry exactly the local fixes.

Uses uv — install with curl -LsSf https://astral.sh/uv/install.sh | sh if you don't already have it.

# 1. TRL — this branch
git clone --branch benchmark-sft-moe https://github.com/huggingface/trl.git
cd trl

# 2. Fresh venv in the TRL repo
uv venv .venv --python 3.11
source .venv/bin/activate
uv pip install -e .

# 3. Patched transformers — the rebased fork carrying everything in §transformers
#    that is not yet upstream (EP+FSDP DTensor wrap, DS-Z2+EP trainer hooks,
#    `_configure_distributed_model` monkey-patch for stock DS, sonicmoe wrapper).
git clone --branch ds-ep-integration https://github.com/AmineDiro/transformers.git ../transformers
uv pip install -e ../transformers --no-deps --reinstall

# 4. Patched accelerate — `_prepare_tp` skip when `model._device_mesh` is set
#    by transformers' EP path (see §accelerate).
git clone --branch ep-fixes https://github.com/AmineDiro/accelerate.git ../accelerate
uv pip install -e ../accelerate --no-deps --reinstall

# 5. Pin torch + runtime extras the kernel + logging stack expects but TRL's
#    pyproject does not pin. Versions matched against the validated benchmark venv.
uv pip install --reinstall torch==2.10.0
uv pip install deepspeed==0.18.9 liger-kernel==0.7.0 \
               kernels==0.13.0 trackio==0.23.0 \
               nvidia-cutlass-dsl==4.4.2 apache-tvm-ffi==0.1.9

# 6. The benchmark templates expect a `venv` at /fsx/amine_dirhoussi/trl/.venv.
#    For an out-of-tree clone, edit benchmark/templates/launch.sh.j2 to point
#    at your venv (or pass it via run_benchmark.py if you patch the template
#    rendering to thread `venv_path` through).

# 7. Submit the 32k 2-node champion (DS-Z2 + EP=8 + FA3 + sonicmoe + Liger,
#    65 % window MFU). Each row in the YAML corresponds to one cell in the
#    headline table — pick a different --run-index for the others.
python benchmark/run_benchmark.py \
    --config benchmark/configs/qwen3_30b_a3b.yaml \
    --submit \
    --run-index <row-from-table-above>

Repro verification (2026-05-06): a fresh install of the three forks above ran four of the configs above end-to-end. The infrastructure path works as documented and throughput-correct numbers match the headline table (mfu_window peak 84.89 % at 64k 4n EP=32, 66.65 % at 32k 2n EP=8).

Help wanted

  • Reproduction reports on non-H100 hardware (H200, MI300X, B200) — especially the EP=8 16k FSDP2 baseline
  • Cross-checks of the MFU formula in trl/trainer/utils.py:compute_flops_per_token against your own training runs

