🛣️ Path to 30B MoE long-context SFT training
Tracking issue / draft PR for SFTTrainer scaling to frontier-scale MoE.
Train Qwen3-30B-A3B (and 235B-A22B) end-to-end with TRL's SFTTrainer at long context (16k → 1M) on 8×H100 nodes, with MFU competitive with dense training. This page tracks the cross-repo work (transformers / accelerate / TRL / DeepSpeed / Liger) that this depends on, the known issues, and the performance numbers we land along the way.
Community contributions and reproduction reports on different hardware are very welcome. Please drop a comment with your config + numbers.
Why this is non-trivial
Qwen3-30B-A3B has 128 experts (8 active per token). Out of the box it currently doesn't train cleanly on any combination of FSDP2 / DeepSpeed / Expert Parallel / Sequence Parallel / Context Parallel with trl. This issue tracks the recipes to get there.
Status snapshot — Qwen3-30B-A3B (2026-05-06)
Best end-to-end-correct configurations on 2× / 4× / 8× p5.48xlarge nodes (H100 SXM5, 989.5 TFLOPS bf16 peak).
Window MFU peak is the per-log-window throughput; adj. is the causal-corrected value (matches the Llama 2/3 / DS-Ulysses convention: half the attention FLOPs disappear under causal masking).
Early Qwen3-235B-A22B numbers: 32k 8n EP=64 + Liger + compile = 70.1 % window MFU, healthy loss.
Hardware / software baseline
TRL: benchmark-sft-moe branch (this branch will land via the PRs in §TRL below)
transformers: 5.6.0.dev0 + the qwen3-moe-ep-v2 series of fixes (see §transformers below)
accelerate: 1.13.0 + an in-place _prepare_tphas_ep skip (§accelerate)
DeepSpeed: stock — no fork, no in-venv patch. DS's MoE auto-detection is extended in transformers/trainer.py via a transparent monkey-patch on DeepSpeedEngine._configure_distributed_model (see §DeepSpeed; sketch after this list).
PyTorch: 2.10.0+cu128
Liger Kernel: fused CrossEntropy + RMSNorm + RoPE (Triton). SwiGLU patch must be disabled under EP (--liger_kernel_config '{"swiglu":false}')
Attention kernel: kernels-community/vllm-flash-attn3
MoE kernel: kernels-community/sonic-moe ("revision": "ep-support") selected via --experts_implementation sonicmoe
Dataset: THUDM/LongAlign-10k packed with --packing --packing_strategy wrapped
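For the §DeepSpeed line above, a minimal sketch of what the _configure_distributed_model monkey-patch amounts to, assuming it simply tags EP-sharded expert weights the way DeepSpeed's own MoE layers do. The ".experts." name filter and the group name are illustrative assumptions, not the actual transformers/trainer.py code:

```python
# Hedged sketch only — the real patch in transformers/trainer.py differs in detail.
from deepspeed.runtime.engine import DeepSpeedEngine

_original_configure = DeepSpeedEngine._configure_distributed_model

def _configure_distributed_model_with_ep(self, model):
    # Tag EP-sharded expert weights the way DeepSpeed's own MoE layers do, so
    # stock DeepSpeed's MoE auto-detection keeps them out of the dense
    # data-parallel all-reduce.
    for name, param in model.named_parameters():
        if ".experts." in name:
            param.allreduce = False          # checked by deepspeed.moe.utils.is_moe_param
            param.group_name = "ep_experts"  # reduction bucket for the expert group
    return _original_configure(self, model)

DeepSpeedEngine._configure_distributed_model = _configure_distributed_model_with_ep
```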
Linked PRs
transformers
Core EP support + correctness fixes. The first three are the load-bearing ones. Without them, EP > 1 produces silently wrong expert outputs.
✅ Merged #45436 — Add EP support for Qwen3 MoE, fix GroupedGemmParallel for 2D meshes
✅ Merged #45473 — Fix EP routing: RouterParallel shape, tp_plan property, grouped_mm sentinels (3 bugs that combined to produce silently wrong logits at every EP > 1; first regression tests for EP)
✅ Merged #45433 — Integrate the kernels-community/sonic-moe CuteDSL fused MoE kernel as a selectable _experts_implementation. Drop-in for grouped_mm; +23 % steady-state MFU vs grouped_mm on 16k EP=8 (single biggest kernel win on this stack).
✅ Merged #45621 — sonicmoe kernel-side sentinel fix. Drops the wrapper-level expert_ids.clamp workaround; +5–8 pp peak MFU at 16k–32k EP=8. TRL-side contribution from this work: the grouped_mm_experts_forward wrapper-level masked_fill_ pre/post-mask pair (integrations/moe.py) that pairs with the kernel fix to keep gradients clean on EP-sentinel rows (sketch after this list).
🟡 Open #45994 — Fix reported loss over-counting under TP and EP.
⏳ Planned — tp_plan loader rank-0 gate (Layer 1 of the duplicate-load issue: every TP/EP rank reads the full dense replica from disk; ~480 GB of redundant disk I/O per node on 30B EP=8). Draft at benchmark/upstream_issue_tp_plan_duplicate_load.md.
⏳ Planned — ValueError guard when enable_expert_parallel=True and cpu_ram_efficient_loading=True are combined
⏳ Planned — DS-Z2 + EP transformers PR (4 patches: tensor_parallel.py:GroupedGemmParallel.post_shard_wrap backend branch, Trainer.create_accelerator_and_postprocess MoE-group setup, Trainer.create_optimizer MoE param split, Trainer._clip_grad_norm cross-mesh skip) + a post-deepspeed.initialize engine attribute patch that replaces the originally-planned DeepSpeed-side change
⏳ Planned — wire DistributedConfig(enable_expert_parallel=True) into accelerate.state.parallelism_config.ep_size once accelerate exposes that field
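A minimal sketch of the wrapper-level pre/post-mask idea from #45621's TRL-side companion, under assumed shapes and a -1 sentinel; the real grouped_mm_experts_forward wrapper in integrations/moe.py differs in detail:

```python
import torch

def masked_experts_forward(expert_fn, hidden_states, expert_ids, sentinel=-1):
    # Rows whose expert id is the EP sentinel belong to no local expert.
    invalid = (expert_ids == sentinel).unsqueeze(-1)        # [num_rows, 1]
    # Pre-mask: zero the inputs of sentinel rows so the fused kernel never sees garbage.
    hidden_states = hidden_states.masked_fill(invalid, 0.0)
    out = expert_fn(hidden_states, expert_ids)              # kernel handles sentinel ids post-#45621
    # Post-mask: zero the same rows in the output so no spurious gradient flows back.
    return out.masked_fill(invalid, 0.0)
```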
accelerate
✅ Merged #4022 — Fix regional compilation with FSDP2. Enables torch_compile=True.
⏳ Planned — _prepare_tp early-return when model.has_ep (post-#45662 EP params become DTensors and trigger an ImportError on ReplicateParallel; single-line check — sketch after this list)
⏳ Planned — first-class ep_size field in ParallelismConfig + submesh_ep_size divisor in prepare_data_loader (mirrors existing TP handling exactly). Today TRL piggy-backs on the TP-replication path by exposing the EP mesh as "tp". Highest-leverage upstream PR remaining.
⏳ Planned — fsdp2_prepare_model capture/restore around the meta-move so EP ignored_params survive the FSDP rank-0 broadcast (paired with the transformers cpu_ram_efficient_loading work above)
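A sketch of the single-line _prepare_tp guard from the first planned item above, assuming the flag surfaces as a has_ep attribute on the model; the final upstream form may gate on a different attribute, and the rest of the method is accelerate's existing code, unchanged:

```python
def _prepare_tp(self, model, *args, **kwargs):
    if getattr(model, "has_ep", False):
        # EP params are already DTensors on their own mesh; wrapping them again
        # with ReplicateParallel raises ImportError, so skip TP preparation.
        return model
    # ... existing accelerate TP preparation continues below ...
    return model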
TRL
The benchmark-sft-moe branch will be split into a series of PRs against main. Order is independence-first — anything that can ship alone goes first.
✅ Merged #5698 — compute_flops_per_token + compute_mfu in trl/trainer/utils.py. Pure-Python, no SFTTrainer coupling. Includes causal-correct attention and the embed_flops=0 / lm_head_flops=2*V*h accounting fix (sketch after this list).
⏳ SFTTrainer.log() MFU integration (window + cumulative; corrects num_input_tokens_seen overcount by cp_size × sp_size)
⏳ enable_expert_parallel + expert_parallel_size + experts_implementation config fields and the EP branch in SFTTrainer.__init__ (depends on the transformers PRs landing)
⏳ SP --pad_to_multiple_of auto-default when accelerator.parallelism_config.sp_size > 1
✅ Merged#5575 by @qgallouedec — Chunked cross-entropy loss for SFT (up to −50 % VRAM). Load-bearing for this stack: it's what frees the ~20 GB lm_head logit tensor and lets the EP-replicated expert buffer fit at 32k → 128k context, which unlocks every long-context champion in the table above.
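A rough sketch of the accounting the #5698 helpers implement, as described above — causal-halved attention FLOPs, embed_flops=0, lm_head_flops=2*V*h — with a conventional 3× forward/backward factor. Argument names and the per-layer breakdown are assumptions; the merged utilities differ in detail:

```python
def compute_flops_per_token(hidden_size, intermediate_size, num_layers, vocab_size,
                            seq_len, num_active_experts=1, causal=True):
    # Per-token forward matmul FLOPs: 2 * (weights touched); MoE counts active experts only.
    attn_proj = 2 * 4 * hidden_size * hidden_size                        # q, k, v, o projections
    moe_mlp = 2 * 3 * hidden_size * intermediate_size * num_active_experts  # gated MLP, active experts
    # Attention score/value matmuls: 2 * 2 * seq_len * hidden_size per token,
    # halved under a causal mask (the "adj." convention in the table above).
    attn_scores = 2 * 2 * seq_len * hidden_size * (0.5 if causal else 1.0)
    # Accounting fix from the PR: embedding lookup is 0 FLOPs, lm_head is 2*V*h.
    lm_head = 2 * vocab_size * hidden_size
    return num_layers * (attn_proj + moe_mlp + attn_scores) + lm_head

def compute_mfu(tokens_per_second, flops_per_token, num_gpus,
                peak_flops_per_gpu=989.5e12, fwd_bwd_factor=3.0):
    # MFU = achieved training FLOPs / peak hardware FLOPs (backward ≈ 2× forward, hence 3×).
    return fwd_bwd_factor * flops_per_token * tokens_per_second / (peak_flops_per_gpu * num_gpus)
```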
Known issues / open blockers
Status: 🔧 active workaround · 🟡 local, upstream fix might be different · ⛔ not yet fixed.
🔧 EP-aware DataLoader sharding — Local fix in transformers/trainer.py that exposes the EP mesh as "tp" so accelerate's existing TP-replication branch in prepare_data_loader divides correctly. Without it, every world rank gets a unique micro-batch but the 8 ranks of an EP group must see the same batch (EP shards experts only, not data) → silent NCCL hang on the EP all-reduce after a random number of steps. Highest-leverage upstream cleanup remaining.
🟡 _clip_grad_norm cross-mesh failure — clip_grad_norm_ → _foreach_norm stacks per-param norms; stacking DTensors on different meshes (EP mesh + FSDP DP mesh) errors with RuntimeError: All operands in aten.stack.default must have the same mesh. Local skip returns tensor(0.0) for telemetry — gradients are not actually clipped (sketch after this list). Fine for benchmarks, unsafe for production. Proper fix is in PyTorch.
⛔ FSDP + EP + compile Adam crash: _group_tensors_by_device_and_dtype strict-asserts grouped tensors share device+dtype; EP DTensors (EP mesh, size 8) and FSDP DP DTensors (size 16) trip the assert. DS-Z2+EP+compile works because DS uses plain nn.Parameter (no DTensor mesh).
⛔ FA3 + CP incompatible: accelerate hard-guards CP to sdpa attention. Long-context MoE has to choose between FA3 throughput (SP path) and ring-attention seq sharding (CP path). Cost: ~5 pp MFU at 32k FSDP+EP+CP=2.
⛔ Loss-zero / NaN at large EP + long-ctx: at "≥32k context AND large total rank count", loss.item() PRE-backward is NaN from step 1, masked downstream so the trainer reports loss=0. NaN appears inside the optimizer step for the FSDP/DS-sharded entry-side params (embed_tokens.weight + layers.0.self_attn.{q,k,v,o}_proj.weight). Forward + backward math is clean; bug is in (grad → optim_state → param). High priority — the 81 % window MFU rows on this stack are throughput-real but training-incorrect today. Healthy convergence at 64k 4n EP=32 (1.87 loss) and at SP-path long-context (1.56 loss).
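A minimal sketch of the local _clip_grad_norm skip described in the 🟡 item above; the mesh check and the method shape are assumptions about the local patch (benchmark-only, since gradients are genuinely not clipped when meshes differ):

```python
import torch
from torch.distributed.tensor import DTensor

def _clip_grad_norm(self, max_grad_norm):
    params = [p for p in self.model.parameters() if p.grad is not None]
    meshes = {id(p.device_mesh) for p in params if isinstance(p, DTensor)}
    if len(meshes) > 1:
        # EP params (EP mesh) and FSDP DP params (DP mesh) cannot be stacked by
        # _foreach_norm; skip clipping and return a dummy norm for telemetry only.
        return torch.tensor(0.0)
    return torch.nn.utils.clip_grad_norm_(params, max_grad_norm)
```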
How to reproduce
SFTTrainer with --enable_expert_parallel requires patched checkouts of transformers and accelerate until the PRs in §transformers and §accelerate land. The fork branches below carry exactly the local fixes described above.
Uses uv — install with curl -LsSf https://astral.sh/uv/install.sh | sh if you don't already have it.
# 1. TRL — this branch
git clone --branch benchmark-sft-moe https://github.com/huggingface/trl.git
cd trl
# 2. Fresh venv in the TRL repo
uv venv .venv --python 3.11
source .venv/bin/activate
uv pip install -e .
# 3. Patched transformers — the rebased fork carrying everything in §transformers
#    that is not yet upstream (EP+FSDP DTensor wrap, DS-Z2+EP trainer hooks,
#    `_configure_distributed_model` monkey-patch for stock DS, sonicmoe wrapper).
git clone --branch ds-ep-integration https://github.com/AmineDiro/transformers.git ../transformers
uv pip install -e ../transformers --no-deps --reinstall
# 4. Patched accelerate — `_prepare_tp` skip when `model._device_mesh` is set
#    by transformers' EP path (see §accelerate).
git clone --branch ep-fixes https://github.com/AmineDiro/accelerate.git ../accelerate
uv pip install -e ../accelerate --no-deps --reinstall
# 5. Pin torch + runtime extras the kernel + logging stack expects but TRL's
#    pyproject does not pin. Versions matched against the validated benchmark venv.
uv pip install --reinstall torch==2.10.0
uv pip install deepspeed==0.18.9 liger-kernel==0.7.0 \
kernels==0.13.0 trackio==0.23.0 \
nvidia-cutlass-dsl==4.4.2 apache-tvm-ffi==0.1.9
# 6. The benchmark templates expect a `venv` at /fsx/amine_dirhoussi/trl/.venv.
#    For an out-of-tree clone, edit benchmark/templates/launch.sh.j2 to point
#    at your venv (or pass it via run_benchmark.py if you patch the template
#    rendering to thread `venv_path` through).
# 7. Submit the 32k 2-node champion (DS-Z2 + EP=8 + FA3 + sonicmoe + Liger,
#    65 % window MFU). Each row in the YAML corresponds to one cell in the
#    headline table — pick a different --run-index for the others.
python benchmark/run_benchmark.py \
--config benchmark/configs/qwen3_30b_a3b.yaml \
--submit \
--run-index <row-from-table-above>
Repro verification (2026-05-06): a fresh install of the three forks above ran the four configs below end-to-end. The infrastructure path works as documented, and throughput-correct numbers match the headline table (mfu_window peak 84.89 % at 64k 4n EP=32, 66.65 % at 32k 2n EP=8).
Help wanted
Reproduction reports on non-H100 hardware (H200, MI300X, B200) — especially the EP=8 16k FSDP2 baseline
Cross-checks of the MFU formula in trl/trainer/utils.py:compute_flops_per_token against your own training runs