[nemotron] Add NemotronH (Mamba2-Transformer hybrid MoE) text-LLM sup…#1516

Draft
YunchaoYang wants to merge 2 commits into main from yy/nemotron-h-phase1-pr

@YunchaoYang

Contributor

…port

Phase 1 of NVIDIA Nemotron-3-Nano-Omni-30B-A3B: text-only LLM with a 52-layer 3-way hybrid decoder.

  • Architecture: single-purpose blocks per layer (23 Mamba2 SSM, 23 MoE FFN, 6 GQA attention). Pattern is configurable via NemotronHConfig.hybrid_override_pattern.
  • Mamba2 mixer: chunked scan via mamba_ssm Triton kernels with a PyTorch fallback; incremental decoding via selective_state_update. Group-wise gated RMSNorm (8 groups of 512), with gate-first ordering to match HF.
  • MoE: 128 experts, top-6 sigmoid routing with e_score_correction_bias, plus one always-on shared expert. Squared-ReLU activation. Float32 output accumulation. Batched expert_hit transfer to avoid per-iteration CPU-GPU sync.
  • GQA attention: 32 query heads, 2 KV heads, full head dim, no RoPE (position information comes from the Mamba2 layers).
  • HuggingFace interop: bidirectional state dict conversion handles language_model.backbone.* key prefix, fused projections, expert stacking, and GQA head layout.
  • Tensor parallelism: shard_specs for attention, embedding, MoE (column-sharded inner FFN + row-sharded output + TP all-reduce), and final_proj.
  • Asset card and composition/models.py registration.
  • 60/60 unit tests pass (config, decoder layer, Mamba2 mixer, MoE).
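
For readers new to this routing scheme, the MoE bullet above can be sketched in plain Python for a single token. This is an illustration, not the code in this PR: `route_token`, `squared_relu`, and `moe_forward` are hypothetical names, the experts are scalar functions, and the sketch assumes the common convention that `e_score_correction_bias` only influences which experts are selected, while the mixing weights come from the unbiased sigmoid scores renormalized over the chosen experts. The PR text does not confirm that convention.

```python
import math


def squared_relu(x):
    """Squared-ReLU activation: relu(x) ** 2."""
    return max(x, 0.0) ** 2


def route_token(logits, correction_bias, top_k):
    """Select top_k experts for one token from raw router logits.

    Assumption (not stated in the PR): the bias shifts the selection only;
    the returned weights use the unbiased sigmoid scores, renormalized.
    """
    scores = [1.0 / (1.0 + math.exp(-x)) for x in logits]
    biased = [s + b for s, b in zip(scores, correction_bias)]
    # Rank experts by biased score, highest first, and keep the first top_k.
    chosen = sorted(range(len(scores)), key=lambda i: biased[i], reverse=True)[:top_k]
    total = sum(scores[i] for i in chosen)
    weights = [scores[i] / total for i in chosen]
    return chosen, weights


def moe_forward(x, logits, correction_bias, experts, shared_expert, top_k=6):
    """Combine the always-on shared expert with the weighted routed experts."""
    chosen, weights = route_token(logits, correction_bias, top_k)
    out = shared_expert(x)  # the shared expert runs for every token
    for i, w in zip(chosen, weights):
        out += w * experts[i](x)
    return out
```

With 128 experts and `top_k=6`, each token would run 6 routed experts plus the shared one. A large positive bias entry can pull an otherwise low-scoring expert into the selected set without inflating its mixing weight.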

What does this PR do? Please describe:
A summary of the change or the issue that is fixed.

Fixes #{issue number}

Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.

Check list:

  • Was the content of this PR discussed and approved via a GitHub issue? (no need for typos or documentation improvements)
  • Did you read the contributor guideline?
  • Did you make sure that your PR does only one thing instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests?
  • Did you verify new and existing tests pass locally with your changes?
  • Did you update the CHANGELOG? (no need for typos, documentation, or minor internal changes)

@meta-cla (bot) added the CLA Signed label on Jun 4, 2026

The three isinstance asserts in the decoder layer's forward pass (one each
for block_type == "mamba", "attention", and "moe") fail when an external
wrap policy adds NemotronHMamba2Mixer / NemotronHMoE /
StandardMultiheadAttention to FSDP's auto-wrap. After FSDP child-wrapping,
``self.mixer`` is an ``FSDP(NemotronHMamba2Mixer)`` proxy rather than the
bare class, so ``isinstance(self.mixer, NemotronHMamba2Mixer)`` evaluates
to False and the assert aborts the forward pass on the first step.

``block_type`` (set at construction in NemotronHFactory.create_decoder_layer)
is already the authoritative tag for which branch to take; the isinstance
checks were redundant guards. Replace them with ``typing.cast`` so mypy's
union-narrowing stays satisfied without runtime cost.
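
A minimal, self-contained sketch of the failure mode and the fix, using
stand-in classes rather than the real modules. ``Mamba2Mixer``, ``Wrapper``,
and ``DecoderLayer`` below are hypothetical simplifications; ``Wrapper`` only
mimics the relevant property of an FSDP-wrapped child, namely that it forwards
calls but is not an instance of the wrapped class.

```python
from typing import Union, cast


class Mamba2Mixer:
    """Hypothetical stand-in for NemotronHMamba2Mixer."""

    def forward(self, x: str) -> str:
        return f"mamba({x})"


class Wrapper:
    """Hypothetical stand-in for an FSDP-style proxy around a child module."""

    def __init__(self, module: Mamba2Mixer) -> None:
        self._module = module

    def forward(self, x: str) -> str:
        return self._module.forward(x)


class DecoderLayer:
    def __init__(self, block_type: str, mixer: Union[Mamba2Mixer, Wrapper]) -> None:
        self.block_type = block_type  # authoritative tag, set at construction
        self.mixer = mixer

    def forward_with_assert(self, x: str) -> str:
        # Old guard: a runtime type check that a wrapped mixer cannot pass.
        if self.block_type == "mamba":
            assert isinstance(self.mixer, Mamba2Mixer)
            return self.mixer.forward(x)
        raise ValueError(self.block_type)

    def forward_with_cast(self, x: str) -> str:
        # New form: cast() only informs the type checker; at runtime it
        # returns its argument unchanged, so the wrapper passes through.
        if self.block_type == "mamba":
            mixer = cast(Mamba2Mixer, self.mixer)
            return mixer.forward(x)
        raise ValueError(self.block_type)
```

With a bare ``Mamba2Mixer`` both methods return the same result; with
``Wrapper(Mamba2Mixer())`` only ``forward_with_cast`` succeeds. The trade-off
is that a mismatch between ``block_type`` and the actual mixer is no longer
caught at runtime.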

No behavior change for the default (single-class) wrap policy used by the
existing recipes; unblocks wrapping the inner mixers separately when
sharding the 30B MoE block.

60/60 unit tests pass.