diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d4325637..8f57787ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to fairseq2 are documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.8.1] - Unreleased +- Qwen 3.5 model family (0.8B, 2B, 9B, 27B dense and 35B-A3B MoE) with base and instruction-tuned variants. Features hybrid GatedDeltaNet linear attention (75%) + full attention (25%) architecture, partial RoPE, QK-norm, output gating, Top-K MoE routing with shared experts, RMSNorm 1+w convention, bidirectional HuggingFace state dict conversion, and SFT/pretraining recipe configs. - Gemma 4 model family (E4B, 31B, 26B-A4B) with base and instruction-tuned variants. Includes decoder with Per-Layer Embeddings (PLE), partial RoPE, KV sharing across sliding/global attention layers, Mixture-of-Experts (26B-A4B), QK/V-norm, logit soft-capping, audio tower (Conformer encoder for multimodal E4B), bidirectional HuggingFace state dict conversion, FSDP/activation checkpointing/tensor parallel support, and SFT recipe configs. - Bump transformers~=v5.5 and loosen huggingface_hub upper bound. (#1508) - Fixed typo in WerMetric: use `hyp_seqs` instead of `ref_seqs` for `hyp_seqs_list`. (#1506) diff --git a/recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml b/recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml new file mode 100644 index 000000000..23e5e6f85 --- /dev/null +++ b/recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Qwen 3.5 0.8B GSM8K SFT Fine-tuning Config +# +# Validates training recipe integration and loss convergence for the Qwen 3.5 +# model on the GSM8K math reasoning dataset. +# +# Usage: +# torchrun --standalone --nproc_per_node=8 -m recipes.lm.sft \ +# --config-file recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml \ +# /path/to/output_dir + +model: + name: "qwen35_0.8b" + dtype: bfloat16 + config_overrides: + pad_idx: 248044 + +tokenizer: + name: "qwen35_0.8b" + config_overrides: + use_im_end: true + +dataset: + max_seq_len: 4096 + max_num_tokens: 8192 + valid_split: "sft_test" + chat_mode: false + config_overrides: + sources: + train: + - path: "hg://facebook/fairseq2-lm-gsm8k" + split: "sft_train" + weight: 1.0 + sft_test: + - path: "hg://facebook/fairseq2-lm-gsm8k" + split: "sft_test" + weight: 1.0 + +trainer: + data_parallelism: fsdp + max_grad_norm: 1.0 + mixed_precision: + mode: static + dtype: bfloat16 + +optimizer: + name: adamw + config: + lr: 2.0e-5 + betas: [0.9, 0.95] + weight_decay: 0.1 + impl: fused + +lr_scheduler: + name: cosine_annealing + config: + final_lr_scale: 0.1 + num_warmup_steps: 100 + +regime: + num_steps: 100000 + checkpoint_every_n_steps: 100 + validate_every_n_steps: 100 + keep_last_n_checkpoints: 10 + publish_metrics_every_n_steps: 1 + save_model_only: false + +common: + seed: 0 + metric_recorders: + wandb: + enabled: true + entity: "yunchaoyang1" + project: "fairseq2" + tensorboard: + enabled: false diff --git a/src/fairseq2/assets/cards/models/qwen35.yaml b/src/fairseq2/assets/cards/models/qwen35.yaml new file mode 100644 index 000000000..c72de8b42 --- /dev/null +++ b/src/fairseq2/assets/cards/models/qwen35.yaml @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +name: qwen35_0.8b +model_family: qwen3_5 +model_arch: qwen35_0.8b +checkpoint: "/checkpoint/smallomnillm/shared/models/Qwen3.5-0.8B" +tokenizer: "hg://Qwen/Qwen3.5-0.8B" +tokenizer_family: qwen + +--- + +name: qwen35_2b +model_family: qwen3_5 +model_arch: qwen35_2b +checkpoint: "hg://Qwen/Qwen3.5-2B" +tokenizer: "hg://Qwen/Qwen3.5-2B" +tokenizer_family: qwen + +--- + +name: qwen35_2b_base +model_family: qwen3_5 +model_arch: qwen35_2b +checkpoint: "hg://Qwen/Qwen3.5-2B-Base" +tokenizer: "hg://Qwen/Qwen3.5-2B-Base" +tokenizer_family: qwen + +--- + +name: qwen35_9b +model_family: qwen3_5 +model_arch: qwen35_9b +checkpoint: "hg://Qwen/Qwen3.5-9B" +tokenizer: "hg://Qwen/Qwen3.5-9B" +tokenizer_family: qwen + +--- + +name: qwen35_9b_base +model_family: qwen3_5 +model_arch: qwen35_9b +checkpoint: "hg://Qwen/Qwen3.5-9B-Base" +tokenizer: "hg://Qwen/Qwen3.5-9B-Base" +tokenizer_family: qwen + +--- + +name: qwen35_27b +model_family: qwen3_5 +model_arch: qwen35_27b +checkpoint: "hg://Qwen/Qwen3.5-27B" +tokenizer: "hg://Qwen/Qwen3.5-27B" +tokenizer_family: qwen + +--- + +name: qwen35_moe_35b_a3b +model_family: qwen3_5_moe +model_arch: qwen35_moe_35b_a3b +checkpoint: "hg://Qwen/Qwen3.5-35B-A3B" +tokenizer: "hg://Qwen/Qwen3.5-35B-A3B" +tokenizer_family: qwen + +--- + +name: qwen35_moe_35b_a3b_base +model_family: qwen3_5_moe +model_arch: qwen35_moe_35b_a3b +checkpoint: "hg://Qwen/Qwen3.5-35B-A3B-Base" +tokenizer: "hg://Qwen/Qwen3.5-35B-A3B-Base" +tokenizer_family: qwen diff --git a/src/fairseq2/composition/models.py b/src/fairseq2/composition/models.py index 3be08dbbf..8537ea05d 100644 --- a/src/fairseq2/composition/models.py +++ b/src/fairseq2/composition/models.py @@ -107,11 +107,23 @@ register_olmo_configs, ) from fairseq2.models.qwen import ( + QWEN35_FAMILY, + QWEN35_MOE_FAMILY, QWEN_FAMILY, + Qwen35Config, + Qwen35MoeConfig, QwenConfig, + _Qwen35HuggingFaceConverter, + _Qwen35MoeHuggingFaceConverter, _QwenHuggingFaceConverter, + convert_qwen35_moe_state_dict, + convert_qwen35_state_dict, convert_qwen_state_dict, + create_qwen35_model, + create_qwen35_moe_model, create_qwen_model, + register_qwen35_configs, + register_qwen35_moe_configs, register_qwen_configs, ) from fairseq2.models.s2t_conformer import ( @@ -417,6 +429,44 @@ def _register_model_families(container: DependencyContainer) -> None: HuggingFaceConverter, _QwenHuggingFaceConverter, key=QWEN_FAMILY ) + # Qwen 3.5 + register_model_family( + container, + QWEN35_FAMILY, + kls=TransformerLM, + config_kls=Qwen35Config, + factory=create_qwen35_model, + state_dict_converter=convert_qwen35_state_dict, + compiler=compile_transformer_lm, + fsdp_applier=apply_fsdp_to_transformer_lm, + layerwise_ac_applier=apply_ac_to_transformer_lm, + ) + + register_qwen35_configs(container) + + container.register_type( + HuggingFaceConverter, _Qwen35HuggingFaceConverter, key=QWEN35_FAMILY + ) + + # Qwen 3.5 MoE + register_model_family( + container, + QWEN35_MOE_FAMILY, + kls=TransformerLM, + config_kls=Qwen35MoeConfig, + factory=create_qwen35_moe_model, + state_dict_converter=convert_qwen35_moe_state_dict, + compiler=compile_transformer_lm, + fsdp_applier=apply_fsdp_to_transformer_lm, + layerwise_ac_applier=apply_ac_to_transformer_lm, + ) + + register_qwen35_moe_configs(container) + + container.register_type( + HuggingFaceConverter, _Qwen35MoeHuggingFaceConverter, key=QWEN35_MOE_FAMILY + ) + # S2T Conformer register_model_family( container, diff --git a/src/fairseq2/models/qwen/__init__.py b/src/fairseq2/models/qwen/__init__.py index 0d7d28179..135050769 100644 --- a/src/fairseq2/models/qwen/__init__.py +++ b/src/fairseq2/models/qwen/__init__.py @@ -6,20 +6,57 @@ from __future__ import annotations +from fairseq2.models.qwen.config import QWEN35_FAMILY as QWEN35_FAMILY +from fairseq2.models.qwen.config import QWEN35_MOE_FAMILY as QWEN35_MOE_FAMILY from fairseq2.models.qwen.config import QWEN_FAMILY as QWEN_FAMILY +from fairseq2.models.qwen.config import Qwen35Config as Qwen35Config +from fairseq2.models.qwen.config import Qwen35MoeConfig as Qwen35MoeConfig from fairseq2.models.qwen.config import QwenConfig as QwenConfig +from fairseq2.models.qwen.config import ( + register_qwen35_configs as register_qwen35_configs, +) +from fairseq2.models.qwen.config import ( + register_qwen35_moe_configs as register_qwen35_moe_configs, +) from fairseq2.models.qwen.config import register_qwen_configs as register_qwen_configs +from fairseq2.models.qwen.factory import Qwen35Factory as Qwen35Factory +from fairseq2.models.qwen.factory import Qwen35MoeFactory as Qwen35MoeFactory from fairseq2.models.qwen.factory import QwenFactory as QwenFactory +from fairseq2.models.qwen.factory import create_qwen35_model as create_qwen35_model +from fairseq2.models.qwen.factory import ( + create_qwen35_moe_model as create_qwen35_moe_model, +) from fairseq2.models.qwen.factory import create_qwen_model as create_qwen_model +from fairseq2.models.qwen.hub import get_qwen35_model_hub as get_qwen35_model_hub +from fairseq2.models.qwen.hub import ( + get_qwen35_moe_model_hub as get_qwen35_moe_model_hub, +) +from fairseq2.models.qwen.hub import ( + get_qwen35_moe_tokenizer_hub as get_qwen35_moe_tokenizer_hub, +) +from fairseq2.models.qwen.hub import ( + get_qwen35_tokenizer_hub as get_qwen35_tokenizer_hub, +) from fairseq2.models.qwen.hub import get_qwen_model_hub as get_qwen_model_hub from fairseq2.models.qwen.hub import get_qwen_tokenizer_hub as get_qwen_tokenizer_hub +from fairseq2.models.qwen.interop import ( + _Qwen35HuggingFaceConverter as _Qwen35HuggingFaceConverter, +) +from fairseq2.models.qwen.interop import ( + _Qwen35MoeHuggingFaceConverter as _Qwen35MoeHuggingFaceConverter, +) from fairseq2.models.qwen.interop import ( _QwenHuggingFaceConverter as _QwenHuggingFaceConverter, ) +from fairseq2.models.qwen.interop import ( + convert_qwen35_moe_state_dict as convert_qwen35_moe_state_dict, +) +from fairseq2.models.qwen.interop import ( + convert_qwen35_state_dict as convert_qwen35_state_dict, +) from fairseq2.models.qwen.interop import ( convert_qwen_state_dict as convert_qwen_state_dict, ) -from fairseq2.models.qwen.sharder import get_qwen_shard_specs as get_qwen_shard_specs from fairseq2.models.qwen.tokenizer import QwenTokenizer as QwenTokenizer from fairseq2.models.qwen.tokenizer import QwenTokenizerConfig as QwenTokenizerConfig from fairseq2.models.qwen.tokenizer import load_qwen_tokenizer as load_qwen_tokenizer diff --git a/src/fairseq2/models/qwen/attention.py b/src/fairseq2/models/qwen/attention.py new file mode 100644 index 000000000..afe0de1e9 --- /dev/null +++ b/src/fairseq2/models/qwen/attention.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Gated multi-head attention for Qwen 3.5. + +Differs from ``StandardMultiheadAttention`` in three ways: + +1. The Q projection is doubled — half is the query, half is an output gate. +2. Partial RoPE: only the first ``encoding_dim`` dimensions are rotated. +3. Output gating: ``attn_output = attn_output * sigmoid(gate)``. + +Reference: HuggingFace ``modeling_qwen3_5.py`` ``Qwen3_5Attention`` lines 707-779. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Final + +import torch +from torch import Tensor + +from fairseq2.models.transformer import ( + SDPA, + AttentionBiasCache, + AttentionState, + AttentionStateFactory, + FullAttentionState, + MultiheadAttention, +) +from fairseq2.nn import ( + BatchLayout, + IncrementalStateBag, + LayerNorm, + Linear, + PositionEncoder, +) +from fairseq2.ops import repeat_interleave + + +class Qwen35Attention(MultiheadAttention): + """Gated multi-head attention for Qwen 3.5 full-attention layers. + + Key differences from :class:`StandardMultiheadAttention`: + + * **Doubled Q projection** — ``q_proj`` outputs ``num_heads * head_dim * 2``; + the second half is an output gate. + * **Partial RoPE** — only the first ``encoding_dim`` (typically 64) of the + ``head_dim`` (typically 256) are rotated. The rest pass through. + * **Output gating** — ``attn_output * sigmoid(gate)`` before ``output_proj``. + * **QK-Norm** on per-head dimension (after unflatten). + + Reference: ``modeling_qwen3_5.py`` lines 707-779. + """ + + num_heads: Final[int] + num_key_value_heads: Final[int] + num_query_groups: Final[int] + head_dim: Final[int] + + def __init__( + self, + model_dim: int, + num_heads: int, + sdpa: SDPA, + *, + head_dim: int = 256, + num_key_value_heads: int | None = None, + pos_encoder: PositionEncoder | None = None, + q_norm: LayerNorm | None = None, + k_norm: LayerNorm | None = None, + state_factory: AttentionStateFactory | None = None, + qkv_proj_init_fn: Callable[[Linear], None] | None = None, + output_proj_init_fn: Callable[[Linear], None] | None = None, + ) -> None: + super().__init__() + + self.num_heads = num_heads + self.head_dim = head_dim + + if num_key_value_heads is None: + num_key_value_heads = num_heads + self.num_key_value_heads = num_key_value_heads + self.num_query_groups = num_heads // num_key_value_heads + + # Q projection is DOUBLED — half query, half gate. + # HF: nn.Linear(hidden, num_heads * head_dim * 2, bias=False) + self.q_proj = Linear( + model_dim, + num_heads * head_dim * 2, + bias=False, + init_fn=qkv_proj_init_fn, + ) + + self.k_proj = Linear( + model_dim, + num_key_value_heads * head_dim, + bias=False, + init_fn=qkv_proj_init_fn, + ) + + self.v_proj = Linear( + model_dim, + num_key_value_heads * head_dim, + bias=False, + init_fn=qkv_proj_init_fn, + ) + + self.output_proj = Linear( + num_heads * head_dim, + model_dim, + bias=False, + init_fn=output_proj_init_fn, + ) + + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_encoder = pos_encoder + self.sdpa = sdpa + self.state_factory = state_factory + + def forward( + self, + seqs: Tensor, + seqs_layout: BatchLayout, + keys: Tensor, + keys_layout: BatchLayout, + values: Tensor, + bias_cache: AttentionBiasCache, + *, + state_bag: IncrementalStateBag | None = None, + ) -> Tensor: + # -- Q projection: split into query + gate -- + # (B, S, num_heads * head_dim * 2) + q_combined = self.q_proj(seqs) + + # (B, S, num_heads, head_dim * 2) -> split along last dim + q_combined = q_combined.unflatten(-1, (self.num_heads, self.head_dim * 2)) + q, gate = q_combined.chunk(2, dim=-1) + # q: (B, S, num_heads, head_dim) + # gate: (B, S, num_heads, head_dim) + + # Flatten gate to (B, S, num_heads * head_dim) for later element-wise gating. + gate = gate.flatten(-2) # (B, S, num_heads * head_dim) + + # -- K, V projections -- + k = self.k_proj(keys) + v = self.v_proj(values) + k = k.unflatten(-1, (self.num_key_value_heads, self.head_dim)) + v = v.unflatten(-1, (self.num_key_value_heads, self.head_dim)) + + # -- QK-Norm (per head dim, after unflatten) -- + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + + # -- Partial RoPE -- + # Only the first `encoding_dim` dimensions of each head are rotated. + # The rest pass through unchanged. + if self.pos_encoder is not None: + encoding_dim = self.pos_encoder.encoding_dim + + if encoding_dim < self.head_dim: + # Split into rotary and pass-through parts. + q_rot = q[..., :encoding_dim] + q_pass = q[..., encoding_dim:] + k_rot = k[..., :encoding_dim] + k_pass = k[..., encoding_dim:] + + q_rot = self.pos_encoder(q_rot, seqs_layout, state_bag=state_bag) + k_rot = self.pos_encoder(k_rot, keys_layout, state_bag=state_bag) + + q = torch.cat([q_rot, q_pass], dim=-1) + k = torch.cat([k_rot, k_pass], dim=-1) + else: + # Full rotation (encoding_dim == head_dim). + q = self.pos_encoder(q, seqs_layout, state_bag=state_bag) + k = self.pos_encoder(k, keys_layout, state_bag=state_bag) + + # -- KV cache management -- + if not self.training and state_bag is not None: + state = state_bag.maybe_get_state(self, AttentionState) + if state is None: + state_factory = self.state_factory or FullAttentionState + state = state_factory( + k, v, state_bag.max_num_steps, state_bag.capacity_increment + ) + state_bag.set_state(self, state) + else: + state.append(k, v) + k, v = state.get() + keys_layout = BatchLayout.of(k) + + # -- GQA expansion -- + if self.num_query_groups > 1: + k = repeat_interleave(k, dim=-2, repeat=self.num_query_groups) + v = repeat_interleave(v, dim=-2, repeat=self.num_query_groups) + + # -- Scaled dot-product attention -- + # q, k, v: (B, S, H, D) + attn_output, _ = self.sdpa(q, seqs_layout, k, keys_layout, v, bias_cache) + + # -- Output gating -- + # attn_output: (B, S, H, D) -> (B, S, H * D) + attn_output = attn_output.flatten(-2) + attn_output = attn_output * torch.sigmoid(gate) + + # -- Output projection -- + return self.output_proj(attn_output) diff --git a/src/fairseq2/models/qwen/config.py b/src/fairseq2/models/qwen/config.py index 98c064419..13d23930f 100644 --- a/src/fairseq2/models/qwen/config.py +++ b/src/fairseq2/models/qwen/config.py @@ -13,6 +13,7 @@ from fairseq2.runtime.dependency import DependencyContainer QWEN_FAMILY: Final = "qwen" +QWEN35_FAMILY: Final = "qwen3_5" @dataclass(kw_only=True) @@ -62,6 +63,174 @@ class QwenConfig: dropout_p: float = 0.0 """The dropout probability on outputs of Transformer layers.""" + pad_idx: int | None = None + """The index of the pad symbol in the vocabulary.""" + + +# --------------------------------------------------------------------------- + + +@dataclass(kw_only=True) +class Qwen35Config: + """Holds the configuration of a Qwen 3.5 dense model.""" + + model_dim: int = 4096 + max_seq_len: int = 32_768 + vocab_size: int = 248_320 + tied_embeddings: bool = False + num_layers: int = 32 + num_attn_heads: int = 16 + num_key_value_heads: int = 4 + head_dim: int = 256 + ffn_inner_dim: int = 12_288 + partial_rotary_factor: float = 0.25 + rope_theta: float = 1_000_000.0 + dropout_p: float = 0.0 + layer_types: list[str] | None = None + full_attention_interval: int = 4 + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: int = 128 + linear_value_head_dim: int = 128 + linear_num_key_heads: int = 16 + linear_num_value_heads: int = 32 + + pad_idx: int | None = None + """The index of the pad symbol in the vocabulary.""" + + def __post_init__(self) -> None: + if self.layer_types is None: + interval = self.full_attention_interval + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval) else "full_attention" + for i in range(self.num_layers) + ] + + +def register_qwen35_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, Qwen35Config) + + @arch("qwen35_0.8b") + def qwen35_0p8b() -> Qwen35Config: + return Qwen35Config( + model_dim=1024, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=True, + num_layers=24, + num_attn_heads=8, + num_key_value_heads=2, + head_dim=256, + ffn_inner_dim=3584, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, + ) + + @arch("qwen35_2b") + def qwen35_2b() -> Qwen35Config: + return Qwen35Config( + model_dim=2048, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=True, + num_layers=24, + num_attn_heads=8, + num_key_value_heads=2, + head_dim=256, + ffn_inner_dim=6144, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, + ) + + @arch("qwen35_9b") + def qwen35_9b() -> Qwen35Config: + return Qwen35Config( + model_dim=4096, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=False, + num_layers=32, + num_attn_heads=16, + num_key_value_heads=4, + head_dim=256, + ffn_inner_dim=12_288, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + ) + + @arch("qwen35_27b") + def qwen35_27b() -> Qwen35Config: + return Qwen35Config( + model_dim=5120, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=False, + num_layers=64, + num_attn_heads=24, + num_key_value_heads=4, + head_dim=256, + ffn_inner_dim=17_408, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=48, + ) + + +# --------------------------------------------------------------------------- +# Qwen 3.5 MoE Config +# --------------------------------------------------------------------------- + +QWEN35_MOE_FAMILY: Final = "qwen3_5_moe" + + +@dataclass(kw_only=True) +class Qwen35MoeConfig(Qwen35Config): + """Holds the configuration of a Qwen 3.5 MoE model.""" + + model_dim: int = 2048 + num_layers: int = 40 + num_key_value_heads: int = 2 + num_experts: int = 256 + num_experts_per_tok: int = 8 + moe_intermediate_size: int = 512 + shared_expert_intermediate_size: int = 512 + router_aux_loss_coef: float = 0.001 + + +def register_qwen35_moe_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, Qwen35MoeConfig) + + @arch("qwen35_moe_35b_a3b") + def qwen35_moe_35b_a3b() -> Qwen35MoeConfig: + return Qwen35MoeConfig() + + +# --------------------------------------------------------------------------- +# Qwen 2.5 / 3.0 arch configs +# --------------------------------------------------------------------------- + def register_qwen_configs(container: DependencyContainer) -> None: arch = ConfigRegistrar(container, QwenConfig) @@ -76,7 +245,7 @@ def qwen25_3b() -> QwenConfig: config.num_attn_heads = 16 config.num_key_value_heads = 2 config.ffn_inner_dim = 11_008 - config.tied_embeddings = True + config.rope_theta = 1_000_000 return config diff --git a/src/fairseq2/models/qwen/decoder_layer.py b/src/fairseq2/models/qwen/decoder_layer.py new file mode 100644 index 000000000..fa0369cc6 --- /dev/null +++ b/src/fairseq2/models/qwen/decoder_layer.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Hybrid decoder layer for Qwen 3.5. + +Each layer holds EITHER a :class:`Qwen35Attention` (full attention with output +gating) OR a :class:`GatedDeltaNet` (linear attention), dispatched by +``layer_type``. The FFN and layer norms are always present. + +Attribute names ``self_attn`` / ``linear_attn`` match HuggingFace for clean +interop key mapping. + +Reference: HuggingFace ``modeling_qwen3_5.py`` ``Qwen3_5DecoderLayer`` +lines 818-870. +""" + +from __future__ import annotations + +from typing import Final, final + +from torch import Tensor +from typing_extensions import override + +from fairseq2.models.qwen.attention import Qwen35Attention +from fairseq2.models.qwen.gated_delta_net import GatedDeltaNet +from fairseq2.models.transformer import ( + AttentionBiasCache, + FeedForwardNetwork, +) +from fairseq2.models.transformer_lm import TransformerLMDecoderLayer +from fairseq2.nn import ( + AdditiveResidualConnect, + BatchLayout, + IncrementalStateBag, + LayerNorm, + ResidualConnect, +) + + +@final +class Qwen35DecoderLayer(TransformerLMDecoderLayer): + """Hybrid decoder layer that dispatches to full or linear attention. + + * ``layer_type == "full_attention"``: uses :attr:`self_attn` + (:class:`Qwen35Attention`). + * ``layer_type == "linear_attention"``: uses :attr:`linear_attn` + (:class:`GatedDeltaNet`). + """ + + layer_type: Final[str] + + def __init__( + self, + layer_type: str, + self_attn: Qwen35Attention | None, + linear_attn: GatedDeltaNet | None, + ffn: FeedForwardNetwork, + self_attn_layer_norm: LayerNorm, + ffn_layer_norm: LayerNorm, + *, + self_attn_residual: ResidualConnect | None = None, + ffn_residual: ResidualConnect | None = None, + ) -> None: + """ + :param layer_type: ``"full_attention"`` or ``"linear_attention"``. + :param self_attn: Gated full attention module (only for full layers). + :param linear_attn: GatedDeltaNet module (only for linear layers). + :param ffn: Feed-forward network (always present). + :param self_attn_layer_norm: Pre-attention layer norm. + :param ffn_layer_norm: Pre-FFN layer norm. + """ + super().__init__() + + self.layer_type = layer_type + + # Register exactly one token mixer — attribute name matters for interop. + self.self_attn: Qwen35Attention | None + self.linear_attn: GatedDeltaNet | None + + if layer_type == "full_attention": + assert self_attn is not None + self.register_module("self_attn", self_attn) + self.register_module("linear_attn", None) + elif layer_type == "linear_attention": + assert linear_attn is not None + self.register_module("self_attn", None) + self.register_module("linear_attn", linear_attn) + else: + raise ValueError( + f"`layer_type` must be 'full_attention' or 'linear_attention', got '{layer_type}'." + ) + + self.self_attn_layer_norm = self_attn_layer_norm + self.ffn = ffn + self.ffn_layer_norm = ffn_layer_norm + + if self_attn_residual is None: + self_attn_residual = AdditiveResidualConnect() + self.self_attn_residual = self_attn_residual + + if ffn_residual is None: + ffn_residual = AdditiveResidualConnect() + self.ffn_residual = ffn_residual + + @override + def forward( + self, + seqs: Tensor, + seqs_layout: BatchLayout, + attn_bias_cache: AttentionBiasCache, + *, + state_bag: IncrementalStateBag | None = None, + ) -> Tensor: + seqs = self._forward_token_mixer(seqs, seqs_layout, attn_bias_cache, state_bag) + seqs = self._forward_ffn(seqs) + return seqs + + def _forward_token_mixer( + self, + seqs: Tensor, + seqs_layout: BatchLayout, + attn_bias_cache: AttentionBiasCache, + state_bag: IncrementalStateBag | None, + ) -> Tensor: + residual = seqs + + seqs = self.self_attn_layer_norm(seqs) + + if self.layer_type == "linear_attention": + assert self.linear_attn is not None + # GatedDeltaNet expects 3D (B, S, D) but packed sequences are 2D + # (T, D). Unsqueeze to (1, T, D) — treats all packed tokens as one + # long causal sequence, which is correct for recurrent computation. + if seqs.dim() == 2: + seqs = self.linear_attn(seqs.unsqueeze(0), state_bag=state_bag) + seqs = seqs.squeeze(0) + else: + seqs = self.linear_attn(seqs, state_bag=state_bag) + else: + assert self.self_attn is not None + seqs = self.self_attn( + seqs, + seqs_layout, + keys=seqs, + keys_layout=seqs_layout, + values=seqs, + bias_cache=attn_bias_cache, + state_bag=state_bag, + ) + + seqs = self.self_attn_residual(seqs, residual) + return seqs + + def _forward_ffn(self, seqs: Tensor) -> Tensor: + residual = seqs + + seqs = self.ffn_layer_norm(seqs) + seqs = self.ffn(seqs) + seqs = self.ffn_residual(seqs, residual) + + return seqs + + @override + def extra_repr(self) -> str: + return f"layer_type={self.layer_type}" diff --git a/src/fairseq2/models/qwen/factory.py b/src/fairseq2/models/qwen/factory.py index 5aeb96b34..6a6021950 100644 --- a/src/fairseq2/models/qwen/factory.py +++ b/src/fairseq2/models/qwen/factory.py @@ -10,7 +10,9 @@ from torch import Tensor from fairseq2.error import NotSupportedError -from fairseq2.models.qwen.config import QwenConfig +from fairseq2.models.qwen.attention import Qwen35Attention +from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig, QwenConfig +from fairseq2.models.qwen.gated_delta_net import GatedDeltaNet from fairseq2.models.transformer import ( CausalAttentionBias, FeedForwardNetwork, @@ -48,6 +50,10 @@ def create_qwen_model(config: QwenConfig) -> TransformerLM: return QwenFactory(config).create_model() +def create_qwen35_model(config: Qwen35Config) -> TransformerLM: + return Qwen35Factory(config).create_model() + + class QwenFactory: def __init__(self, config: QwenConfig) -> None: self._config = config @@ -85,7 +91,7 @@ def init_embed(embed: StandardEmbedding) -> None: _init_truncated_normal(embed.weight, bias=None, std=std) return VocabShardedEmbedding( - config.vocab_size, config.model_dim, init_fn=init_embed + config.vocab_size, config.model_dim, config.pad_idx, init_fn=init_embed ) def create_decoder_frontend(self, embed: Embedding) -> TransformerFrontend: @@ -265,3 +271,206 @@ def _init_truncated_normal( if bias is not None: nn.init.zeros_(bias) + + +# --------------------------------------------------------------------------- +# Qwen 3.5 Factory +# --------------------------------------------------------------------------- + + +class Qwen35Factory: + """Factory for Qwen 3.5 dense hybrid models.""" + + def __init__(self, config: Qwen35Config) -> None: + self._config = config + config.__post_init__() + + def create_model(self) -> TransformerLM: + config = self._config + + embed = self.create_embedding() + decoder_frontend = self.create_decoder_frontend(embed) + decoder = self.create_decoder() + final_proj = self.create_final_projection(embed) + + return TransformerLM( + config.model_dim, + decoder_frontend, + decoder, + final_proj, + config.pad_idx, + config.max_seq_len, + ) + + def create_embedding(self) -> Embedding: + config = self._config + + def init_embed(embed: StandardEmbedding) -> None: + std = embed.weight.shape[1] ** -0.5 + _init_truncated_normal(embed.weight, bias=None, std=std) + + return VocabShardedEmbedding( + config.vocab_size, config.model_dim, config.pad_idx, init_fn=init_embed + ) + + def create_decoder_frontend(self, embed: Embedding) -> TransformerFrontend: + config = self._config + + return TransformerEmbeddingFrontend( + config.model_dim, + embed, + pos_encoder=None, + no_scale=True, + dropout_p=config.dropout_p, + ) + + def create_decoder(self) -> TransformerLMDecoder: + config = self._config + + pos_encoder = self.create_position_encoder() + + layers = [] + for idx in range(config.num_layers): + layer = self.create_decoder_layer(idx, pos_encoder) + layers.append(layer) + + layer_norm = self.create_layer_norm() + + return StandardTransformerLMDecoder(layers, layer_norm) + + def create_position_encoder(self) -> PositionEncoder: + config = self._config + + encoding_dim = int(config.head_dim * config.partial_rotary_factor) + + return ReferenceRotaryEncoder( + encoding_dim, config.max_seq_len, theta=config.rope_theta + ) + + def create_decoder_layer( + self, layer_idx: int, pos_encoder: PositionEncoder + ) -> TransformerLMDecoderLayer: + from fairseq2.models.qwen.decoder_layer import Qwen35DecoderLayer + + config = self._config + + assert config.layer_types is not None + layer_type = config.layer_types[layer_idx] + + self_attn = None + linear_attn = None + + if layer_type == "full_attention": + self_attn = self.create_gated_attention(layer_idx, pos_encoder) + else: + linear_attn = self.create_gated_delta_net(layer_idx) + + ffn = self.create_ffn(layer_idx) + self_attn_layer_norm = self.create_layer_norm() + ffn_layer_norm = self.create_layer_norm() + + return Qwen35DecoderLayer( + layer_type, + self_attn=self_attn, + linear_attn=linear_attn, + ffn=ffn, + self_attn_layer_norm=self_attn_layer_norm, + ffn_layer_norm=ffn_layer_norm, + ) + + def create_gated_attention( + self, layer_idx: int, pos_encoder: PositionEncoder + ) -> Qwen35Attention: + from fairseq2.models.qwen.attention import Qwen35Attention + + config = self._config + + attn_bias = CausalAttentionBias() + sdpa = create_default_sdpa(attn_bias) + + q_norm = self.create_layer_norm(config.head_dim) + k_norm = self.create_layer_norm(config.head_dim) + + return Qwen35Attention( + config.model_dim, + config.num_attn_heads, + sdpa, + head_dim=config.head_dim, + num_key_value_heads=config.num_key_value_heads, + pos_encoder=pos_encoder, + q_norm=q_norm, + k_norm=k_norm, + ) + + def create_gated_delta_net(self, layer_idx: int) -> GatedDeltaNet: + from fairseq2.models.qwen.gated_delta_net import GatedDeltaNet + + config = self._config + + return GatedDeltaNet( + hidden_size=config.model_dim, + num_k_heads=config.linear_num_key_heads, + num_v_heads=config.linear_num_value_heads, + head_k_dim=config.linear_key_head_dim, + head_v_dim=config.linear_value_head_dim, + conv_kernel_size=config.linear_conv_kernel_dim, + ) + + def create_ffn(self, layer_idx: int) -> FeedForwardNetwork: + config = self._config + + return GLUFeedForwardNetwork( + config.model_dim, + config.ffn_inner_dim, + bias=False, + inner_dim_scale=1.0, + ) + + def create_final_projection(self, embed: Embedding) -> Projection: + config = self._config + + if config.tied_embeddings: + if not isinstance(embed, VocabShardedEmbedding): + raise TypeError( + f"`embed` is expected to be of type `{VocabShardedEmbedding}` when tied_embeddings is True." + ) + if embed.tp_gang.size > 1: + raise NotSupportedError( + "Tied embeddings are not supported when tensor parallelism is enabled." + ) + return TiedProjection(embed.weight, bias=None) + + return ColumnShardedLinear(config.model_dim, config.vocab_size, bias=False) + + def create_layer_norm(self, dim: int | None = None) -> LayerNorm: + config = self._config + if dim is None: + dim = config.model_dim + return RMSNorm(dim, bias=False, eps=1e-06) + + +def create_qwen35_moe_model(config: Qwen35MoeConfig) -> TransformerLM: + return Qwen35MoeFactory(config).create_model() + + +class Qwen35MoeFactory(Qwen35Factory): + """Factory for Qwen 3.5 MoE hybrid models.""" + + _config: Qwen35MoeConfig + + def __init__(self, config: Qwen35MoeConfig) -> None: + super().__init__(config) + self._config = config + + def create_ffn(self, layer_idx: int) -> FeedForwardNetwork: + from fairseq2.models.qwen.moe import Qwen35MoeBlock + + config = self._config + + return Qwen35MoeBlock( + model_dim=config.model_dim, + num_experts=config.num_experts, + num_experts_per_tok=config.num_experts_per_tok, + moe_intermediate_size=config.moe_intermediate_size, + shared_expert_intermediate_size=config.shared_expert_intermediate_size, + ) diff --git a/src/fairseq2/models/qwen/gated_delta_net.py b/src/fairseq2/models/qwen/gated_delta_net.py new file mode 100644 index 000000000..d9e677dc3 --- /dev/null +++ b/src/fairseq2/models/qwen/gated_delta_net.py @@ -0,0 +1,545 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Gated DeltaNet linear attention module for Qwen 3.5. + +Reference: HuggingFace ``modeling_qwen3_5.py`` lines 445-620. +""" + +from __future__ import annotations + +import logging +from typing import Callable, Final, final + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from fairseq2.nn import ( + IncrementalState, + IncrementalStateBag, + Linear, + RMSNorm, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional fast-path kernels +# --------------------------------------------------------------------------- + +try: + from causal_conv1d import causal_conv1d_update as _causal_conv1d_update + + _HAS_CAUSAL_CONV1D = True +except ImportError: + _HAS_CAUSAL_CONV1D = False + logger.warning( + "causal_conv1d not found; GatedDeltaNet will use a slower PyTorch fallback " + "for incremental decoding. Install with: pip install causal-conv1d" + ) + +try: + from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule as _chunk_gated_delta_rule, + ) + from fla.ops.gated_delta_rule import ( + fused_recurrent_gated_delta_rule as _fused_recurrent_gated_delta_rule, + ) + + _HAS_FLA = True +except ImportError: + _HAS_FLA = False + logger.warning( + "flash-linear-attention (fla) not found; GatedDeltaNet will use slower " + "pure-PyTorch chunk/recurrent kernels. Install with: pip install flash-linear-attention" + ) + + +def l2norm(x: Tensor, dim: int = -1, eps: float = 1e-6) -> Tensor: + """L2-normalize along ``dim``. + + Reference: ``modeling_qwen3_5.py`` lines 317-320. + """ + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +# --------------------------------------------------------------------------- +# PyTorch fallback kernels (no external dependencies) +# --------------------------------------------------------------------------- + + +def torch_causal_conv1d_update( + hidden_states: Tensor, + conv_state: Tensor, + weight: Tensor, + bias: Tensor | None = None, + activation: str | None = None, +) -> Tensor: + """Single-step causal conv1d for incremental decoding. + + Reference: ``modeling_qwen3_5.py`` lines 299-314. + + :param hidden_states: ``(B, D, L)`` — typically ``L=1`` during decode. + :param conv_state: ``(B, D, kernel-1)`` — updated in-place. + :param weight: ``(D, kernel)`` — depthwise conv weights. + :param bias: ``(D,)`` or ``None``. + :param activation: ``"silu"`` or ``None``. + :returns: ``(B, D, L)`` convolved output. + """ + _, hidden_size, seq_len = hidden_states.shape + state_len = conv_state.shape[-1] + + hidden_states_new = torch.cat([conv_state, hidden_states], dim=-1).to(weight.dtype) + conv_state.copy_(hidden_states_new[:, :, -state_len:]) + + out = F.conv1d( + hidden_states_new, + weight.unsqueeze(1), + bias, + padding=0, + groups=hidden_size, + ) + if activation == "silu": + out = F.silu(out[:, :, -seq_len:]) + else: + out = out[:, :, -seq_len:] + return out.to(hidden_states.dtype) + + +def torch_chunk_gated_delta_rule( + query: Tensor, + key: Tensor, + value: Tensor, + g: Tensor, + beta: Tensor, + chunk_size: int = 64, + initial_state: Tensor | None = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[Tensor, Tensor | None]: + """Chunked gated delta rule for prefill (pure PyTorch). + + Reference: ``modeling_qwen3_5.py`` lines 323-400. + + :param query: ``(B, S, H, K)`` + :param key: ``(B, S, H, K)`` + :param value: ``(B, S, H, V)`` + :param g: ``(B, S, H)`` — forget gate (log-space). + :param beta: ``(B, S, H)`` — write gate. + :returns: ``(output, final_state)`` + """ + initial_dtype = query.dtype + + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, seq_len, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + pad_size = (chunk_size - seq_len % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = seq_len + pad_size + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_out = torch.zeros_like(value) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + num_chunks = total_seq_len // chunk_size + for i in range(num_chunks): + q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] + attn_i = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_( + mask, 0 + ) + v_prime = k_cumdecay[:, :, i] @ last_state + v_new = v_i - v_prime + attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_state + core_out[:, :, i] = attn_inter + attn_i @ v_new + last_state = ( + last_state * g[:, :, i, -1, None, None].exp() + + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose( + -1, -2 + ) + @ v_new + ) + + final_state: Tensor | None = last_state if output_final_state else None + + core_out = core_out.reshape( + core_out.shape[0], core_out.shape[1], -1, core_out.shape[-1] + ) + core_out = core_out[:, :, :seq_len] + core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_out, final_state + + +def torch_recurrent_gated_delta_rule( + query: Tensor, + key: Tensor, + value: Tensor, + g: Tensor, + beta: Tensor, + initial_state: Tensor | None = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[Tensor, Tensor | None]: + """Step-by-step recurrent gated delta rule for decode (pure PyTorch). + + Reference: ``modeling_qwen3_5.py`` lines 403-442. + + :param query: ``(B, S, H, K)`` — typically ``S=1`` during decode. + :param key: ``(B, S, H, K)`` + :param value: ``(B, S, H, V)`` + :param g: ``(B, S, H)`` — forget gate (log-space). + :param beta: ``(B, S, H)`` — write gate. + :returns: ``(output, final_state)`` + """ + initial_dtype = query.dtype + + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, seq_len, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + core_out = torch.zeros(batch_size, num_heads, seq_len, v_head_dim).to(value) + last_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + + for i in range(seq_len): + q_t = query[:, :, i] + k_t = key[:, :, i] + v_t = value[:, :, i] + g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1) + beta_t = beta[:, :, i].unsqueeze(-1) + + last_state = last_state * g_t + kv_mem = (last_state * k_t.unsqueeze(-1)).sum(dim=-2) + delta = (v_t - kv_mem) * beta_t + last_state = last_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2) + core_out[:, :, i] = (last_state * q_t.unsqueeze(-1)).sum(dim=-2) + + final_state: Tensor | None = last_state if output_final_state else None + + core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype) + return core_out, final_state + + +# --------------------------------------------------------------------------- +# Incremental state +# --------------------------------------------------------------------------- + + +@final +class GatedDeltaNetState(IncrementalState): + """Holds conv and recurrent state for :class:`GatedDeltaNet` during + incremental decoding.""" + + conv_state: Tensor + """``(B, conv_dim, kernel_size - 1)``""" + + recurrent_state: Tensor + """``(B, num_v_heads, head_k_dim, head_v_dim)``""" + + def __init__(self, conv_state: Tensor, recurrent_state: Tensor) -> None: + self.conv_state = conv_state + self.recurrent_state = recurrent_state + + def reorder(self, new_order: Tensor) -> None: + self.conv_state = self.conv_state.index_select(0, new_order) + self.recurrent_state = self.recurrent_state.index_select(0, new_order) + + def size_bytes(self) -> int: + return self.capacity_bytes() + + def capacity_bytes(self) -> int: + c = self.conv_state.numel() * self.conv_state.element_size() + r = self.recurrent_state.numel() * self.recurrent_state.element_size() + return c + r + + +# --------------------------------------------------------------------------- +# RMSNormGated — norm-before-gate with silu +# --------------------------------------------------------------------------- + + +class RMSNormGated(nn.Module): + """``RMSNorm(x) * silu(gate)`` + + Internal norm inside GatedDeltaNet. Uses the standard ``weight=ones`` + formula (NOT the ``1+weight`` variant used by the outer layer norms). + + Reference: ``modeling_qwen3_5.py`` lines 264-279. + """ + + def __init__(self, dim: int, eps: float = 1e-6) -> None: + super().__init__() + self.inner_norm = RMSNorm(dim, bias=False, eps=eps) + + def forward(self, hidden_states: Tensor, gate: Tensor) -> Tensor: + hidden_states = self.inner_norm(hidden_states) + hidden_states = hidden_states * F.silu(gate.to(torch.float32)) + return hidden_states.to(gate.dtype) + + +# --------------------------------------------------------------------------- +# GatedDeltaNet module +# --------------------------------------------------------------------------- + + +class GatedDeltaNet(nn.Module): + """Gated DeltaNet linear attention module for Qwen 3.5. + + Replaces standard multi-head attention in 75% of Qwen 3.5 layers. + Uses causal convolution followed by a gated delta rule recurrence. + + Reference: ``modeling_qwen3_5.py`` ``Qwen3_5GatedDeltaNet`` lines 445-620. + """ + + hidden_size: Final[int] + num_k_heads: Final[int] + num_v_heads: Final[int] + head_k_dim: Final[int] + head_v_dim: Final[int] + key_dim: Final[int] + value_dim: Final[int] + conv_dim: Final[int] + conv_kernel_size: Final[int] + + def __init__( + self, + hidden_size: int, + num_k_heads: int = 16, + num_v_heads: int = 32, + head_k_dim: int = 128, + head_v_dim: int = 128, + conv_kernel_size: int = 4, + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = head_k_dim * num_k_heads + self.value_dim = head_v_dim * num_v_heads + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv_kernel_size = conv_kernel_size + + # Input projections — fairseq2 Linear wrappers. + self.in_proj_qkv = Linear( + hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = Linear(hidden_size, self.value_dim, bias=False) + self.in_proj_b = Linear(hidden_size, num_v_heads, bias=False) + self.in_proj_a = Linear(hidden_size, num_v_heads, bias=False) + + # Depthwise causal convolution (no fairseq2 wrapper exists). + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=conv_kernel_size, + groups=self.conv_dim, + padding=conv_kernel_size - 1, + ) + + # Learnable gating parameters (no fairseq2 wrapper for raw params). + self.dt_bias = nn.Parameter(torch.ones(num_v_heads)) + A = torch.empty(num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + # Output norm (silu-gated, wraps fairseq2 RMSNorm) and projection. + self.norm = RMSNormGated(head_v_dim, eps=eps) + self.out_proj = Linear(self.value_dim, hidden_size, bias=False) + + # Select fast-path kernels when available, else pure-PyTorch fallbacks. + self._conv1d_update_fn: Callable[..., Tensor] = ( + _causal_conv1d_update if _HAS_CAUSAL_CONV1D else torch_causal_conv1d_update + ) + self._chunk_fn: Callable[..., tuple[Tensor, Tensor | None]] = ( + _chunk_gated_delta_rule if _HAS_FLA else torch_chunk_gated_delta_rule + ) + self._recurrent_fn: Callable[..., tuple[Tensor, Tensor | None]] = ( + _fused_recurrent_gated_delta_rule + if _HAS_FLA + else torch_recurrent_gated_delta_rule + ) + + def forward( + self, + seqs: Tensor, + padding_mask: Tensor | None = None, + *, + state_bag: IncrementalStateBag | None = None, + ) -> Tensor: + """ + :param seqs: ``(B, S, D)`` + :param padding_mask: Optional ``(B, S)`` boolean mask (1 = valid). + :param state_bag: Incremental state bag for generation. + :returns: ``(B, S, D)`` + """ + if padding_mask is not None and padding_mask.shape[1] > 1: + seqs = (seqs * padding_mask[:, :, None]).to(seqs.dtype) + + batch_size, seq_len, _ = seqs.shape + + state: GatedDeltaNetState | None = None + if state_bag is not None: + state = state_bag.maybe_get_state(self, GatedDeltaNetState) + + use_cache = state is not None and seq_len == 1 + + # -- Input projections -- + mixed_qkv = self.in_proj_qkv(seqs).transpose(1, 2) # (B, conv_dim, S) + z = self.in_proj_z(seqs).reshape(batch_size, seq_len, -1, self.head_v_dim) + b = self.in_proj_b(seqs) + a = self.in_proj_a(seqs) + + # -- Causal convolution -- + conv_state: Tensor | None = None + + if use_cache: + assert state is not None + mixed_qkv = self._conv1d_update_fn( + mixed_qkv, + state.conv_state, + self.conv1d.weight.squeeze(1), + self.conv1d.bias, + "silu", + ) + else: + if state_bag is not None: + conv_state = F.pad( + mixed_qkv, + (self.conv_kernel_size - mixed_qkv.shape[-1], 0), + ) + + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) # (B, S, conv_dim) + + # -- Split QKV -- + query, key, value = torch.split( + mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1 + ) + query = query.reshape(batch_size, seq_len, -1, self.head_k_dim) + key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) + value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) + + # -- Compute gates -- + beta = b.sigmoid() + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + + # -- GQA expansion -- + groups = self.num_v_heads // self.num_k_heads + if groups > 1: + query = query.repeat_interleave(groups, dim=2) + key = key.repeat_interleave(groups, dim=2) + + # -- Delta rule core -- + if use_cache: + assert state is not None + core_out, last_state = self._recurrent_fn( + query, + key, + value, + g=g, + beta=beta, + initial_state=state.recurrent_state, + output_final_state=True, + use_qk_l2norm_in_kernel=True, + ) + else: + core_out, last_state = self._chunk_fn( + query, + key, + value, + g=g, + beta=beta, + initial_state=None, + output_final_state=state_bag is not None, + use_qk_l2norm_in_kernel=True, + ) + + # -- Update incremental state -- + if state_bag is not None: + if state is None: + assert conv_state is not None and last_state is not None + state_bag.set_state(self, GatedDeltaNetState(conv_state, last_state)) + else: + assert last_state is not None + state.recurrent_state = last_state + + # -- Output norm (silu-gated) + projection -- + core_out = core_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_out = self.norm(core_out, z) + core_out = core_out.reshape(batch_size, seq_len, -1) + + return self.out_proj(core_out) diff --git a/src/fairseq2/models/qwen/hub.py b/src/fairseq2/models/qwen/hub.py index 55df4bcf1..827196df3 100644 --- a/src/fairseq2/models/qwen/hub.py +++ b/src/fairseq2/models/qwen/hub.py @@ -8,7 +8,14 @@ from fairseq2.data.tokenizers import TokenizerHubAccessor from fairseq2.models import ModelHubAccessor -from fairseq2.models.qwen.config import QWEN_FAMILY, QwenConfig +from fairseq2.models.qwen.config import ( + QWEN35_FAMILY, + QWEN35_MOE_FAMILY, + QWEN_FAMILY, + Qwen35Config, + Qwen35MoeConfig, + QwenConfig, +) from fairseq2.models.qwen.tokenizer import QwenTokenizer, QwenTokenizerConfig from fairseq2.models.transformer_lm import TransformerLM @@ -19,3 +26,19 @@ get_qwen_tokenizer_hub = TokenizerHubAccessor( QWEN_FAMILY, kls=QwenTokenizer, config_kls=QwenTokenizerConfig ) + +get_qwen35_model_hub = ModelHubAccessor( + QWEN35_FAMILY, kls=TransformerLM, config_kls=Qwen35Config +) + +get_qwen35_tokenizer_hub = TokenizerHubAccessor( + QWEN35_FAMILY, kls=QwenTokenizer, config_kls=QwenTokenizerConfig +) + +get_qwen35_moe_model_hub = ModelHubAccessor( + QWEN35_MOE_FAMILY, kls=TransformerLM, config_kls=Qwen35MoeConfig +) + +get_qwen35_moe_tokenizer_hub = TokenizerHubAccessor( + QWEN35_MOE_FAMILY, kls=QwenTokenizer, config_kls=QwenTokenizerConfig +) diff --git a/src/fairseq2/models/qwen/interop.py b/src/fairseq2/models/qwen/interop.py index 98da6d6b7..37cac3b51 100644 --- a/src/fairseq2/models/qwen/interop.py +++ b/src/fairseq2/models/qwen/interop.py @@ -8,10 +8,11 @@ from typing import Final, final +import torch from typing_extensions import override from fairseq2.models.hg import HuggingFaceConfig, HuggingFaceConverter -from fairseq2.models.qwen.config import QwenConfig +from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig, QwenConfig from fairseq2.models.utils.checkpoint import convert_state_dict, create_reverse_key_map from fairseq2.utils.config import cast_config_type @@ -47,9 +48,129 @@ def convert_qwen_state_dict( return state_dict +# HG-side RMSNorm key suffixes for reverse conversion (weight -= 1.0). +_QWEN35_HG_RMSNORM_SUFFIXES = ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "model.norm.weight", + "self_attn.q_norm.weight", + "self_attn.k_norm.weight", +) + + @final -class _QwenHuggingFaceConverter(HuggingFaceConverter): +class _Qwen35HuggingFaceConverter(HuggingFaceConverter): + @override + def to_hg_config(self, config: object) -> HuggingFaceConfig: + config = cast_config_type(config, Qwen35Config) + + data: dict[str, object] = { + "hidden_size": config.model_dim, + "max_position_embeddings": config.max_seq_len, + "vocab_size": config.vocab_size, + "tie_word_embeddings": config.tied_embeddings, + "num_hidden_layers": config.num_layers, + "num_attention_heads": config.num_attn_heads, + "num_key_value_heads": config.num_key_value_heads, + "head_dim": config.head_dim, + "intermediate_size": config.ffn_inner_dim, + "partial_rotary_factor": config.partial_rotary_factor, + "rope_theta": config.rope_theta, + "full_attention_interval": config.full_attention_interval, + "linear_conv_kernel_dim": config.linear_conv_kernel_dim, + "linear_key_head_dim": config.linear_key_head_dim, + "linear_value_head_dim": config.linear_value_head_dim, + "linear_num_key_heads": config.linear_num_key_heads, + "linear_num_value_heads": config.linear_num_value_heads, + } + + return HuggingFaceConfig( + data, kls_name="Qwen3_5TextConfig", arch="Qwen3_5ForCausalLM" + ) + + @override + def to_hg_state_dict( + self, state_dict: dict[str, object], config: object + ) -> dict[str, object]: + config = cast_config_type(config, Qwen35Config) + + # Use the text-only key map for export (model.layers.*, not + # model.language_model.layers.*). + key_map = create_reverse_key_map(_QWEN35_TEXT_KEY_MAP) + + hg_state_dict = convert_state_dict(state_dict, key_map) + + for key in list(hg_state_dict.keys()): + if any(key.endswith(suffix) for suffix in _QWEN35_HG_RMSNORM_SUFFIXES): + weight = hg_state_dict[key] + if isinstance(weight, torch.Tensor): + hg_state_dict[key] = weight - 1.0 + + if config.tied_embeddings: + hg_state_dict.pop("lm_head.weight", None) + + return hg_state_dict + + +@final +class _Qwen35MoeHuggingFaceConverter(HuggingFaceConverter): + def to_hg_config(self, config: object) -> HuggingFaceConfig: + config = cast_config_type(config, Qwen35MoeConfig) + + data: dict[str, object] = { + "hidden_size": config.model_dim, + "max_position_embeddings": config.max_seq_len, + "vocab_size": config.vocab_size, + "tie_word_embeddings": config.tied_embeddings, + "num_hidden_layers": config.num_layers, + "num_attention_heads": config.num_attn_heads, + "num_key_value_heads": config.num_key_value_heads, + "head_dim": config.head_dim, + "intermediate_size": config.ffn_inner_dim, + "partial_rotary_factor": config.partial_rotary_factor, + "rope_theta": config.rope_theta, + "full_attention_interval": config.full_attention_interval, + "linear_conv_kernel_dim": config.linear_conv_kernel_dim, + "linear_key_head_dim": config.linear_key_head_dim, + "linear_value_head_dim": config.linear_value_head_dim, + "linear_num_key_heads": config.linear_num_key_heads, + "linear_num_value_heads": config.linear_num_value_heads, + "num_experts": config.num_experts, + "num_experts_per_tok": config.num_experts_per_tok, + "moe_intermediate_size": config.moe_intermediate_size, + "shared_expert_intermediate_size": config.shared_expert_intermediate_size, + "router_aux_loss_coef": config.router_aux_loss_coef, + } + + return HuggingFaceConfig( + data, kls_name="Qwen3_5TextConfig", arch="Qwen3_5MoeForCausalLM" + ) + @override + def to_hg_state_dict( + self, state_dict: dict[str, object], config: object + ) -> dict[str, object]: + config = cast_config_type(config, Qwen35MoeConfig) + + # Use the text-only MoE key map for export. + key_map = create_reverse_key_map(_QWEN35_MOE_TEXT_KEY_MAP) + + hg_state_dict = convert_state_dict(state_dict, key_map) + + for key in list(hg_state_dict.keys()): + if any(key.endswith(suffix) for suffix in _QWEN35_HG_RMSNORM_SUFFIXES): + weight = hg_state_dict[key] + if isinstance(weight, torch.Tensor): + hg_state_dict[key] = weight - 1.0 + + if config.tied_embeddings: + hg_state_dict.pop("lm_head.weight", None) + + return hg_state_dict + + +@final +class _QwenHuggingFaceConverter(HuggingFaceConverter): def to_hg_config(self, config: object) -> HuggingFaceConfig: config = cast_config_type(config, QwenConfig) @@ -88,3 +209,191 @@ def to_hg_state_dict( del hg_state_dict["lm_head.weight"] return hg_state_dict + + +# --------------------------------------------------------------------------- +# Qwen 3.5 interop +# --------------------------------------------------------------------------- + +# Text-only key map (matches ``transformers`` Qwen3_5ForCausalLM state dict). +# These are also the canonical keys used for the reverse (fs2 → HF) export. +_QWEN35_TEXT_KEY_MAP: Final = { + # fmt: off + # Full attention layers + r"^model\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"decoder.layers.\1.self_attn.q_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"decoder.layers.\1.self_attn.v_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.o_proj\.": r"decoder.layers.\1.self_attn.output_proj.", + r"^model\.layers\.([0-9]+)\.self_attn\.q_norm\.": r"decoder.layers.\1.self_attn.q_norm.", + r"^model\.layers\.([0-9]+)\.self_attn\.k_norm\.": r"decoder.layers.\1.self_attn.k_norm.", + # Linear attention layers (GatedDeltaNet) + r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_qkv\.": r"decoder.layers.\1.linear_attn.in_proj_qkv.", + r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_z\.": r"decoder.layers.\1.linear_attn.in_proj_z.", + r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_b\.": r"decoder.layers.\1.linear_attn.in_proj_b.", + r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_a\.": r"decoder.layers.\1.linear_attn.in_proj_a.", + r"^model\.layers\.([0-9]+)\.linear_attn\.conv1d\.": r"decoder.layers.\1.linear_attn.conv1d.", + r"^model\.layers\.([0-9]+)\.linear_attn\.dt_bias": r"decoder.layers.\1.linear_attn.dt_bias", + r"^model\.layers\.([0-9]+)\.linear_attn\.A_log": r"decoder.layers.\1.linear_attn.A_log", + r"^model\.layers\.([0-9]+)\.linear_attn\.norm\.": r"decoder.layers.\1.linear_attn.norm.inner_norm.", + r"^model\.layers\.([0-9]+)\.linear_attn\.out_proj\.": r"decoder.layers.\1.linear_attn.out_proj.", + # FFN + r"^model\.layers\.([0-9]+)\.mlp\.gate_proj\.": r"decoder.layers.\1.ffn.gate_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.up_proj\.": r"decoder.layers.\1.ffn.inner_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.down_proj\.": r"decoder.layers.\1.ffn.output_proj.", + # Layer norms + r"^model\.layers\.([0-9]+)\.input_layernorm\.": r"decoder.layers.\1.self_attn_layer_norm.", + r"^model\.layers\.([0-9]+)\.post_attention_layernorm\.": r"decoder.layers.\1.ffn_layer_norm.", + # Embeddings & head + r"^model\.norm\.": r"decoder.layer_norm.", + r"^model\.embed_tokens\.": r"decoder_frontend.embed.", + r"^lm_head\.": r"final_proj.", + # fmt: on +} + + +def _expand_with_language_model_prefix( + key_map: dict[str, str], +) -> dict[str, str]: + """Add ``model.language_model.*`` variants for every ``model.*`` pattern. + + Qwen 3.5 checkpoints on HuggingFace Hub are multimodal (VL) models where + the text decoder lives under ``model.language_model.*``. This helper + duplicates the text-only patterns so that the key map handles both formats: + + * Text-only (``model.layers.*``) — from ``transformers`` ``Qwen3_5ForCausalLM`` + * Multimodal (``model.language_model.layers.*``) — from safetensors checkpoint + """ + expanded: dict[str, str] = dict(key_map) + for pattern, replacement in key_map.items(): + if pattern.startswith(r"^model\."): + vl_pattern = pattern.replace(r"^model\.", r"^model\.language_model\.", 1) + expanded[vl_pattern] = replacement + return expanded + + +# Full key map: handles both text-only and multimodal checkpoint formats. +_QWEN35_HG_KEY_MAP: Final = _expand_with_language_model_prefix(_QWEN35_TEXT_KEY_MAP) + +# RMSNorm keys that need weight += 1.0 conversion (Qwen 3.5 uses 1+w formula). +# The GatedDeltaNet internal norm (linear_attn.norm) does NOT need conversion. +_QWEN35_RMSNORM_KEYS = ( + "self_attn_layer_norm.weight", + "ffn_layer_norm.weight", + "decoder.layer_norm.weight", + "self_attn.q_norm.weight", + "self_attn.k_norm.weight", +) + + +# Components not yet integrated in the text-only CausalLM model. +_QWEN35_VL_SKIP_PREFIXES: Final = ( + "model.visual.", # vision encoder + "mtp.", # multi-token prediction head +) + + +def _is_hg_format(state_dict: dict[str, object]) -> bool: + """Return True when the state dict uses HuggingFace key names.""" + return ( + "model.embed_tokens.weight" in state_dict + or "model.language_model.embed_tokens.weight" in state_dict + ) + + +def convert_qwen35_state_dict( + state_dict: dict[str, object], config: Qwen35Config +) -> dict[str, object]: + # Filter out multimodal components not yet integrated (cf. gemma3n pattern). + state_dict = { + k: v + for k, v in state_dict.items() + if not k.startswith(_QWEN35_VL_SKIP_PREFIXES) + } + + if _is_hg_format(state_dict): + state_dict = convert_state_dict(state_dict, _QWEN35_HG_KEY_MAP) + + # Convert (1+w) RMSNorm weights to standard (w) by adding 1.0. + for key in list(state_dict.keys()): + if any(key.endswith(suffix) for suffix in _QWEN35_RMSNORM_KEYS): + weight = state_dict[key] + if isinstance(weight, torch.Tensor): + state_dict[key] = weight + 1.0 + + if config.tied_embeddings: + if "decoder_frontend.embed.weight" in state_dict: + state_dict["final_proj.weight"] = state_dict[ + "decoder_frontend.embed.weight" + ] + elif "final_proj.weight" in state_dict: + state_dict["decoder_frontend.embed.weight"] = state_dict[ + "final_proj.weight" + ] + + return state_dict + + +# --------------------------------------------------------------------------- +# Qwen 3.5 MoE interop +# --------------------------------------------------------------------------- + +# MoE text-only base: start from the dense text-only map, swap FFN patterns. +_QWEN35_MOE_TEXT_KEY_MAP: Final = { + **{ + k: v + for k, v in _QWEN35_TEXT_KEY_MAP.items() + # Drop dense FFN patterns (MoE uses a different FFN layout) + if k + not in ( + r"^model\.layers\.([0-9]+)\.mlp\.gate_proj\.", + r"^model\.layers\.([0-9]+)\.mlp\.up_proj\.", + r"^model\.layers\.([0-9]+)\.mlp\.down_proj\.", + ) + }, + # fmt: off + r"^model\.layers\.([0-9]+)\.mlp\.gate\.": r"decoder.layers.\1.ffn.gate.", + r"^model\.layers\.([0-9]+)\.mlp\.experts\.gate_up_proj": r"decoder.layers.\1.ffn.experts.gate_up_proj", + r"^model\.layers\.([0-9]+)\.mlp\.experts\.down_proj": r"decoder.layers.\1.ffn.experts.down_proj", + r"^model\.layers\.([0-9]+)\.mlp\.shared_expert\.gate_proj\.": r"decoder.layers.\1.ffn.shared_expert.gate_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.shared_expert\.up_proj\.": r"decoder.layers.\1.ffn.shared_expert.inner_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.shared_expert\.down_proj\.": r"decoder.layers.\1.ffn.shared_expert.output_proj.", + r"^model\.layers\.([0-9]+)\.mlp\.shared_expert_gate\.": r"decoder.layers.\1.ffn.shared_expert_gate.", + # fmt: on +} + +_QWEN35_MOE_HG_KEY_MAP: Final = _expand_with_language_model_prefix( + _QWEN35_MOE_TEXT_KEY_MAP +) + + +def convert_qwen35_moe_state_dict( + state_dict: dict[str, object], config: Qwen35MoeConfig +) -> dict[str, object]: + # Filter out multimodal components not yet integrated (cf. gemma3n pattern). + state_dict = { + k: v + for k, v in state_dict.items() + if not k.startswith(_QWEN35_VL_SKIP_PREFIXES) + } + + if _is_hg_format(state_dict): + state_dict = convert_state_dict(state_dict, _QWEN35_MOE_HG_KEY_MAP) + + # Convert (1+w) RMSNorm weights to standard (w) by adding 1.0. + for key in list(state_dict.keys()): + if any(key.endswith(suffix) for suffix in _QWEN35_RMSNORM_KEYS): + weight = state_dict[key] + if isinstance(weight, torch.Tensor): + state_dict[key] = weight + 1.0 + + if config.tied_embeddings: + if "decoder_frontend.embed.weight" in state_dict: + state_dict["final_proj.weight"] = state_dict[ + "decoder_frontend.embed.weight" + ] + elif "final_proj.weight" in state_dict: + state_dict["decoder_frontend.embed.weight"] = state_dict[ + "final_proj.weight" + ] + + return state_dict diff --git a/src/fairseq2/models/qwen/moe.py b/src/fairseq2/models/qwen/moe.py new file mode 100644 index 000000000..03ead545f --- /dev/null +++ b/src/fairseq2/models/qwen/moe.py @@ -0,0 +1,235 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Mixture-of-Experts modules for Qwen 3.5 MoE. + +This module implements the MoE architecture from Qwen 3.5 MoE following the +HuggingFace reference in ``modeling_qwen3_5_moe.py``. + +Classes: + - :class:`Qwen35TopKRouter` — softmax → top-k → renormalize (HF lines 841-857) + - :class:`Qwen35Experts` — fused 3-D parameter experts (HF lines 802-838) + - :class:`Qwen35MoeBlock` — router + experts + shared expert (HF lines 860-879) +""" + +from __future__ import annotations + +from typing import Final + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn import Module, Parameter +from typing_extensions import override + +from fairseq2.models.transformer import FeedForwardNetwork, GLUFeedForwardNetwork +from fairseq2.nn import Linear + + +class Qwen35TopKRouter(Module): + """Top-k softmax router for Qwen 3.5 MoE. + + Computes softmax over all experts, selects the top-k, and renormalises the + selected weights so they sum to 1. + + Reference: ``Qwen3_5MoeTopKRouter`` (HF lines 841-857). + """ + + num_experts: Final[int] + top_k: Final[int] + model_dim: Final[int] + + def __init__(self, num_experts: int, top_k: int, model_dim: int) -> None: + super().__init__() + + self.num_experts = num_experts + self.top_k = top_k + self.model_dim = model_dim + + self.weight = Parameter(torch.zeros(num_experts, model_dim)) + + def forward(self, hidden_states: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """ + :param hidden_states: + Token representations of shape ``(T, D)`` where *T* is the + (flattened) number of tokens. + + :returns: + A 3-tuple of: + - ``router_logits`` — raw pre-softmax logits ``(T, E)`` + - ``router_weights`` — renormalised top-k weights ``(T, K)`` + - ``router_indices`` — selected expert indices ``(T, K)`` + """ + hidden_states = hidden_states.reshape(-1, self.model_dim) + + router_logits = F.linear(hidden_states, self.weight) + router_probs = F.softmax(router_logits, dtype=torch.float, dim=-1) + + router_weights, router_indices = torch.topk(router_probs, self.top_k, dim=-1) + + router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True) + router_weights = router_weights.to(router_logits.dtype) + + return router_logits, router_weights, router_indices + + +class Qwen35Experts(Module): + """Fused expert layer with 3-D weight parameters for Qwen 3.5 MoE. + + Each expert is a GLU-style MLP (gate+up → SiLU → down) stored as a single + ``(E, 2*I, D)`` gate-up projection and a ``(E, D, I)`` down projection so + that individual experts can be indexed without slicing overhead. + + Reference: ``Qwen3_5MoeExperts`` (HF lines 802-838). + """ + + num_experts: Final[int] + model_dim: Final[int] + expert_inner_dim: Final[int] + + def __init__( + self, + num_experts: int, + model_dim: int, + expert_inner_dim: int, + ) -> None: + super().__init__() + + self.num_experts = num_experts + self.model_dim = model_dim + self.expert_inner_dim = expert_inner_dim + + self.gate_up_proj = Parameter( + torch.empty(num_experts, 2 * expert_inner_dim, model_dim) + ) + self.down_proj = Parameter( + torch.empty(num_experts, model_dim, expert_inner_dim) + ) + + def forward( + self, + hidden_states: Tensor, + top_k_indices: Tensor, + top_k_weights: Tensor, + ) -> Tensor: + """ + :param hidden_states: + Token representations of shape ``(T, D)``. + :param top_k_indices: + Selected expert indices of shape ``(T, K)``. + :param top_k_weights: + Renormalised routing weights of shape ``(T, K)``. + + :returns: + Expert-mixed output of shape ``(T, D)``. + """ + final_hidden_states = torch.zeros_like(hidden_states) + + with torch.no_grad(): + expert_mask = F.one_hot(top_k_indices, num_classes=self.num_experts) + # (T, K, E) → (E, K, T) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + + current_state = hidden_states[token_idx] + + gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk( + 2, dim=-1 + ) + + current_hidden_states = F.silu(gate) * up + + current_hidden_states = F.linear( + current_hidden_states, self.down_proj[expert_idx] + ) + + current_hidden_states *= top_k_weights[token_idx, top_k_pos, None] + + final_hidden_states.index_add_( + 0, + token_idx, + current_hidden_states.to(final_hidden_states.dtype), + ) + + return final_hidden_states + + +class Qwen35MoeBlock(FeedForwardNetwork): + """Sparse Mixture-of-Experts feed-forward block for Qwen 3.5 MoE. + + Combines a top-k router, a set of sparse experts, a shared expert (standard + GLU MLP), and a learned sigmoid gate that blends the shared expert output + into the final result. + + This class inherits from :class:`FeedForwardNetwork` so it can serve as a + drop-in replacement for :class:`GLUFeedForwardNetwork` inside any + Transformer decoder layer. + + Reference: ``Qwen3_5MoeSparseMoeBlock`` (HF lines 860-879). + """ + + model_dim: Final[int] + + def __init__( + self, + model_dim: int, + num_experts: int, + num_experts_per_tok: int, + moe_intermediate_size: int, + shared_expert_intermediate_size: int, + ) -> None: + """ + :param model_dim: + The dimensionality of the model (``hidden_size``). + :param num_experts: + The total number of routed experts. + :param num_experts_per_tok: + The number of experts activated per token (top-k). + :param moe_intermediate_size: + The intermediate (inner) dimension of each routed expert. + :param shared_expert_intermediate_size: + The intermediate (inner) dimension of the shared expert. + """ + super().__init__() + + self.model_dim = model_dim + + self.gate = Qwen35TopKRouter(num_experts, num_experts_per_tok, model_dim) + + self.experts = Qwen35Experts(num_experts, model_dim, moe_intermediate_size) + + self.shared_expert = GLUFeedForwardNetwork( + model_dim, + shared_expert_intermediate_size, + bias=False, + inner_dim_scale=1.0, + ) + + self.shared_expert_gate = Linear(model_dim, 1, bias=False) + + @override + def forward(self, seqs: Tensor) -> Tensor: + B, S, D = seqs.shape + + hidden_states = seqs.view(-1, D) + + shared_out = self.shared_expert(hidden_states) + + _, routing_weights, selected_experts = self.gate(hidden_states) + + expert_out = self.experts(hidden_states, selected_experts, routing_weights) + + shared_out = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_out + + out: Tensor = (expert_out + shared_out).reshape(B, S, D) + return out diff --git a/tests/integration/models/test_qwen35.py b/tests/integration/models/test_qwen35.py new file mode 100644 index 000000000..390477695 --- /dev/null +++ b/tests/integration/models/test_qwen35.py @@ -0,0 +1,213 @@ +"""Qwen 3.5 0.8B — HuggingFace vs fairseq2 numerical parity test. + +Downloads the HF checkpoint, loads it into both HF and fairseq2, +runs the same input, and asserts logit closeness. +""" + +import os + +import pytest +import torch +import torch.nn.functional as F + +# Use local checkpoint if available (avoids SSL/proxy issues in CI) +_LOCAL_PATH = "/checkpoint/smallomnillm/shared/models/Qwen3.5-0.8B" +MODEL_ID = _LOCAL_PATH if os.path.isdir(_LOCAL_PATH) else "Qwen/Qwen3.5-0.8B" + + +def _hf_model_type_available(model_type: str) -> bool: + """Return True if the installed transformers recognises *model_type*.""" + try: + from transformers.models.auto.configuration_auto import CONFIG_MAPPING + + return model_type in CONFIG_MAPPING + except Exception: + return False + + +@pytest.mark.skipif( + not _hf_model_type_available("qwen3_5"), + reason="transformers does not support model_type 'qwen3_5' (upgrade transformers)", +) +class TestQwen35HFParity: + """Numerical parity between HuggingFace and fairseq2 for Qwen 3.5 0.8B.""" + + def test_logit_parity(self) -> None: + from transformers import AutoModelForCausalLM, AutoTokenizer + + # ---- Step 1: Load HF model ---- + print("=" * 60) + print("Step 1: Loading HuggingFace model...") + print("=" * 60) + + hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) + hf_model = AutoModelForCausalLM.from_pretrained( + MODEL_ID, + torch_dtype=torch.float32, + trust_remote_code=True, + ) + hf_model.eval() + print( + f" HF model loaded: {sum(p.numel() for p in hf_model.parameters()):,} params" + ) + + # ---- Step 2: Build fairseq2 model from config ---- + print("\n" + "=" * 60) + print("Step 2: Building fairseq2 model...") + print("=" * 60) + + from fairseq2.models.qwen.config import Qwen35Config + from fairseq2.models.qwen.factory import create_qwen35_model + from fairseq2.models.qwen.interop import convert_qwen35_state_dict + + config = Qwen35Config( + model_dim=1024, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=True, + num_layers=24, + num_attn_heads=8, + num_key_value_heads=2, + head_dim=256, + ffn_inner_dim=3584, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, + ) + + fs2_model = create_qwen35_model(config) + fs2_model.eval() + print( + f" fs2 model built: {sum(p.numel() for p in fs2_model.parameters()):,} params" + ) + + # ---- Step 3: Convert and load HF state dict into fairseq2 ---- + print("\n" + "=" * 60) + print("Step 3: Converting HF state dict -> fairseq2...") + print("=" * 60) + + hf_state_dict = dict(hf_model.state_dict()) + fs2_state_dict = convert_qwen35_state_dict(hf_state_dict, config) + + fs2_keys = set(fs2_model.state_dict().keys()) + converted_keys = set(fs2_state_dict.keys()) + + missing = fs2_keys - converted_keys + unexpected = converted_keys - fs2_keys + + if missing: + print(f" WARNING: {len(missing)} missing keys:") + for k in sorted(missing)[:20]: + print(f" - {k}") + if unexpected: + print(f" WARNING: {len(unexpected)} unexpected keys:") + for k in sorted(unexpected)[:20]: + print(f" - {k}") + + if missing or unexpected: + print("\n Attempting to load with strict=False...") + result = fs2_model.load_state_dict(fs2_state_dict, strict=False) + print( + f" Missing: {len(result.missing_keys)}, Unexpected: {len(result.unexpected_keys)}" + ) + if result.missing_keys: + pytest.fail( + f"Cannot proceed with {len(result.missing_keys)} missing keys: " + + ", ".join(sorted(result.missing_keys)[:30]) + ) + else: + fs2_model.load_state_dict(fs2_state_dict, strict=True) + print(" State dict loaded successfully (strict=True)") + + # ---- Step 4: Prepare input ---- + print("\n" + "=" * 60) + print("Step 4: Preparing input...") + print("=" * 60) + + test_text = "The capital of France is" + tokens = hf_tokenizer(test_text, return_tensors="pt") + input_ids = tokens["input_ids"] # (1, S) + print(f" Input: '{test_text}'") + print(f" Token IDs: {input_ids.tolist()}") + print(f" Sequence length: {input_ids.shape[1]}") + + # ---- Step 5: HF forward pass ---- + print("\n" + "=" * 60) + print("Step 5: HF forward pass...") + print("=" * 60) + + with torch.no_grad(): + hf_output = hf_model(input_ids) + hf_logits = hf_output.logits # (1, S, V) + + print(f" HF logits shape: {hf_logits.shape}") + print(f" HF logits[0, -1, :5]: {hf_logits[0, -1, :5]}") + + # ---- Step 6: fairseq2 forward pass ---- + print("\n" + "=" * 60) + print("Step 6: fairseq2 forward pass...") + print("=" * 60) + + from fairseq2.nn import BatchLayout + + with torch.no_grad(): + seqs = input_ids # (1, S) + seqs_layout = BatchLayout.of(seqs) + fs2_logits = fs2_model(seqs, seqs_layout) + + print(f" fs2 logits shape: {fs2_logits.shape}") + print(f" fs2 logits[0, -1, :5]: {fs2_logits[0, -1, :5]}") + + # ---- Step 7: Compare ---- + print("\n" + "=" * 60) + print("Step 7: Numerical comparison...") + print("=" * 60) + + hf_last = hf_logits[0, -1].float() + fs2_last = fs2_logits[0, -1].float() + + abs_diff = (hf_last - fs2_last).abs() + max_diff = abs_diff.max().item() + mean_diff = abs_diff.mean().item() + + print(f" Last-token logit max abs diff: {max_diff:.6e}") + print(f" Last-token logit mean abs diff: {mean_diff:.6e}") + + hf_all = hf_logits.float() + fs2_all = fs2_logits.float() + + full_abs_diff = (hf_all - fs2_all).abs() + full_max_diff = full_abs_diff.max().item() + full_mean_diff = full_abs_diff.mean().item() + + print(f" Full-seq logit max abs diff: {full_max_diff:.6e}") + print(f" Full-seq logit mean abs diff: {full_mean_diff:.6e}") + + hf_top1 = int(hf_last.argmax().item()) + fs2_top1 = int(fs2_last.argmax().item()) + print(f"\n HF top-1 token: {hf_top1} -> '{hf_tokenizer.decode([hf_top1])}'") + print(f" fs2 top-1 token: {fs2_top1} -> '{hf_tokenizer.decode([fs2_top1])}'") + + hf_top5: list[int] = [int(t) for t in hf_last.topk(5).indices.tolist()] + fs2_top5: list[int] = [int(t) for t in fs2_last.topk(5).indices.tolist()] + print( + f"\n HF top-5: {hf_top5} -> {[hf_tokenizer.decode([t]) for t in hf_top5]}" + ) + print( + f" fs2 top-5: {fs2_top5} -> {[hf_tokenizer.decode([t]) for t in fs2_top5]}" + ) + + cos_sim = F.cosine_similarity( + hf_last.unsqueeze(0), fs2_last.unsqueeze(0) + ).item() + print(f"\n Cosine similarity (last token): {cos_sim:.8f}") + + ATOL = 1e-4 + assert ( + full_max_diff < ATOL or cos_sim > 0.9999 + ), f"Parity check failed: max diff {full_max_diff:.2e}, cosine sim {cos_sim:.6f}" diff --git a/tests/unit/models/qwen/__init__.py b/tests/unit/models/qwen/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/unit/models/qwen/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/unit/models/qwen/test_qwen35.py b/tests/unit/models/qwen/test_qwen35.py new file mode 100644 index 000000000..9c93e7396 --- /dev/null +++ b/tests/unit/models/qwen/test_qwen35.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for the Qwen 3.5 model family (dense + MoE).""" + +from __future__ import annotations + +import pytest +import torch +from torch.testing import assert_close + +from fairseq2.models.qwen.attention import Qwen35Attention +from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig +from fairseq2.models.qwen.decoder_layer import Qwen35DecoderLayer +from fairseq2.models.qwen.factory import create_qwen35_model, create_qwen35_moe_model +from fairseq2.models.qwen.gated_delta_net import ( + GatedDeltaNet, + GatedDeltaNetState, + torch_chunk_gated_delta_rule, + torch_recurrent_gated_delta_rule, +) +from fairseq2.models.qwen.interop import ( + _QWEN35_HG_KEY_MAP, + _QWEN35_RMSNORM_KEYS, + _Qwen35HuggingFaceConverter, + _Qwen35MoeHuggingFaceConverter, + convert_qwen35_moe_state_dict, + convert_qwen35_state_dict, +) +from fairseq2.models.qwen.moe import Qwen35MoeBlock +from fairseq2.models.transformer import FeedForwardNetwork +from fairseq2.models.transformer.attention_bias import ( + AttentionBiasCache, + CausalAttentionBias, + IdentityBias, +) +from fairseq2.models.transformer.sdpa.naive import NaiveSDPA +from fairseq2.models.utils.checkpoint import convert_state_dict, create_reverse_key_map +from fairseq2.nn import BatchLayout, IncrementalStateBag +from tests.common import assert_close as fs2_assert_close +from tests.common import device + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _small_dense_config() -> Qwen35Config: + config = Qwen35Config() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.layer_types = None + config.__post_init__() + return config + + +def _small_moe_config() -> Qwen35MoeConfig: + config = Qwen35MoeConfig() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.num_experts = 4 + config.num_experts_per_tok = 2 + config.moe_intermediate_size = 32 + config.shared_expert_intermediate_size = 32 + config.layer_types = None + config.__post_init__() + return config + + +# --------------------------------------------------------------------------- +# GatedDeltaNet +# --------------------------------------------------------------------------- + + +class TestGatedDeltaNet: + def test_forward_shape(self) -> None: + gdn = GatedDeltaNet( + hidden_size=64, + num_k_heads=2, + num_v_heads=4, + head_k_dim=16, + head_v_dim=16, + conv_kernel_size=4, + ).to(device) + out = gdn(torch.randn(2, 8, 64, device=device)) + assert out.shape == (2, 8, 64) + + def test_chunked_vs_recurrent(self) -> None: + B, S, H, K, V = 1, 16, 4, 16, 16 + q = torch.randn(B, S, H, K, device=device) + k = torch.randn(B, S, H, K, device=device) + v = torch.randn(B, S, H, V, device=device) + g = -torch.rand(B, S, H, device=device).abs() + beta = torch.rand(B, S, H, device=device) + + c_out, c_st = torch_chunk_gated_delta_rule( + q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True + ) + r_out, r_st = torch_recurrent_gated_delta_rule( + q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True + ) + fs2_assert_close(c_out, r_out, atol=1e-4) + assert c_st is not None and r_st is not None + fs2_assert_close(c_st, r_st, atol=1e-4) + + def test_state_reorder(self) -> None: + conv = torch.randn(3, 8, 3, device=device) + rec = torch.randn(3, 4, 16, 16, device=device) + state = GatedDeltaNetState(conv, rec) + state.reorder(torch.tensor([2, 0, 1], device=device)) + fs2_assert_close(state.conv_state[0], conv[2]) + fs2_assert_close(state.recurrent_state[0], rec[2]) + + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="causal_conv1d incremental decode requires CUDA", + ) + def test_incremental_decode(self) -> None: + gdn = ( + GatedDeltaNet( + hidden_size=64, + num_k_heads=2, + num_v_heads=4, + head_k_dim=16, + head_v_dim=16, + ) + .to(device) + .eval() + ) + + full_seq = torch.randn(1, 9, 64, device=device) + with torch.no_grad(): + full_out = gdn(full_seq) + + state_bag = IncrementalStateBag(max_num_steps=9) + with torch.no_grad(): + gdn(full_seq[:, :8, :], state_bag=state_bag) + state_bag.increment_step_nr(8) + with torch.no_grad(): + incr_out = gdn(full_seq[:, 8:, :], state_bag=state_bag) + fs2_assert_close(incr_out, full_out[:, -1:, :], atol=1e-4) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class TestQwen35Attention: + def test_forward_shape(self) -> None: + sdpa = NaiveSDPA(IdentityBias()) + attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16).to( + device + ) + seqs = torch.randn(2, 8, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + with torch.no_grad(): + out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + assert out.shape == (2, 8, 64) + + def test_gqa(self) -> None: + sdpa = NaiveSDPA(IdentityBias()) + attn = Qwen35Attention( + model_dim=64, + num_heads=4, + sdpa=sdpa, + head_dim=16, + num_key_value_heads=2, + ).to(device) + seqs = torch.randn(2, 6, 64, device=device) + layout = BatchLayout.of(seqs) + with torch.no_grad(): + out = attn(seqs, layout, seqs, layout, seqs, AttentionBiasCache()) + assert out.shape == (2, 6, 64) + + def test_incremental_kv_cache(self) -> None: + sdpa = NaiveSDPA(CausalAttentionBias()) + attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16).to( + device + ) + attn.eval() + + seqs = torch.randn(1, 6, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + with torch.no_grad(): + full_out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + + state_bag = IncrementalStateBag(max_num_steps=32) + with torch.no_grad(): + for idx in range(6): + step = seqs[:, idx : idx + 1, :] + sl = BatchLayout.of(step) + out = attn(step, sl, step, sl, step, bias_cache, state_bag=state_bag) + fs2_assert_close(out, full_out[:, idx : idx + 1, :], atol=1e-5) + state_bag.increment_step_nr() + + +# --------------------------------------------------------------------------- +# Model factory +# --------------------------------------------------------------------------- + + +class TestQwen35Factory: + def test_small_model_forward(self) -> None: + config = _small_dense_config() + model = create_qwen35_model(config).to(device).eval() + ids = torch.randint(0, 128, (1, 16), device=device) + with torch.no_grad(): + logits = model(ids, BatchLayout.of(ids)) + assert logits.shape == (1, 16, 128) + + def test_hybrid_layer_pattern(self) -> None: + config = _small_dense_config() + with torch.device("meta"): + model = create_qwen35_model(config) + types = [ + l.layer_type + for l in model.decoder.layers + if isinstance(l, Qwen35DecoderLayer) + ] + assert types == [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + + +# --------------------------------------------------------------------------- +# MoE +# --------------------------------------------------------------------------- + + +class TestQwen35Moe: + def test_moe_block_shape(self) -> None: + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ).to(device) + with torch.no_grad(): + out = moe(torch.randn(2, 8, 32, device=device)) + assert out.shape == (2, 8, 32) + + def test_moe_is_ffn(self) -> None: + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ) + assert isinstance(moe, FeedForwardNetwork) + + +# --------------------------------------------------------------------------- +# Interop (state dict conversion) +# --------------------------------------------------------------------------- + + +class TestQwen35Interop: + def test_key_round_trip(self) -> None: + config = _small_dense_config() + with torch.device("meta"): + model = create_qwen35_model(config) + fs2_keys = set(model.state_dict().keys()) + + sd: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + rev = create_reverse_key_map(_QWEN35_HG_KEY_MAP) + hg_sd = convert_state_dict(sd, rev) + rt_keys = set(convert_state_dict(dict(hg_sd), _QWEN35_HG_KEY_MAP).keys()) + assert fs2_keys == rt_keys + + def test_rmsnorm_plus_one(self) -> None: + config = _small_dense_config() + hf_sd: dict[str, object] = {} + for i in range(config.num_layers): + hf_sd[f"model.layers.{i}.input_layernorm.weight"] = torch.zeros( + config.model_dim + ) + hf_sd[f"model.layers.{i}.post_attention_layernorm.weight"] = torch.zeros( + config.model_dim + ) + hf_sd["model.norm.weight"] = torch.zeros(config.model_dim) + hf_sd["model.embed_tokens.weight"] = torch.zeros( + config.vocab_size, config.model_dim + ) + hf_sd["lm_head.weight"] = torch.zeros(config.vocab_size, config.model_dim) + + converted = convert_qwen35_state_dict(dict(hf_sd), config) + for key in converted: + if any(key.endswith(s) for s in _QWEN35_RMSNORM_KEYS): + weight = converted[key] + assert isinstance(weight, torch.Tensor) + assert_close(weight, torch.ones_like(weight)) + + def test_tied_embeddings(self) -> None: + config = _small_dense_config() + config.tied_embeddings = True + weight = torch.randn(config.vocab_size, config.model_dim) + hf_sd: dict[str, object] = { + "model.embed_tokens.weight": weight, + "model.norm.weight": torch.zeros(config.model_dim), + } + result = convert_qwen35_state_dict(dict(hf_sd), config) + assert "decoder_frontend.embed.weight" in result + assert "final_proj.weight" in result + assert result["final_proj.weight"] is result["decoder_frontend.embed.weight"] + + def test_vl_keys_filtered(self) -> None: + config = _small_dense_config() + config.tied_embeddings = True + sd: dict[str, object] = { + "model.language_model.embed_tokens.weight": torch.randn( + config.vocab_size, config.model_dim + ), + "model.language_model.norm.weight": torch.zeros(config.model_dim), + "model.visual.blocks.0.attn.proj.weight": torch.empty(0), + "mtp.fc.weight": torch.empty(0), + } + result = convert_qwen35_state_dict(dict(sd), config) + for key in result: + assert not key.startswith(("model.visual.", "mtp.")) + + +# --------------------------------------------------------------------------- +# HuggingFace converter (bidirectional) +# --------------------------------------------------------------------------- + + +class TestQwen35HuggingFaceConverter: + def test_dense_round_trip(self) -> None: + config = _small_dense_config() + with torch.device("meta"): + model = create_qwen35_model(config) + fs2_keys = set(model.state_dict().keys()) + sd: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + converter = _Qwen35HuggingFaceConverter() + hg_sd = converter.to_hg_state_dict(sd, config) + rt_keys = set(convert_qwen35_state_dict(dict(hg_sd), config).keys()) + assert fs2_keys == rt_keys + + def test_to_hg_config(self) -> None: + config = _small_dense_config() + hg_config = _Qwen35HuggingFaceConverter().to_hg_config(config) + assert hg_config.kls_name == "Qwen3_5TextConfig" + assert hg_config.arch == "Qwen3_5ForCausalLM" + assert hg_config.data["hidden_size"] == config.model_dim + + def test_moe_round_trip(self) -> None: + config = _small_moe_config() + with torch.device("meta"): + model = create_qwen35_moe_model(config) + fs2_keys = set(model.state_dict().keys()) + sd: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + converter = _Qwen35MoeHuggingFaceConverter() + hg_sd = converter.to_hg_state_dict(sd, config) + rt_keys = set(convert_qwen35_moe_state_dict(dict(hg_sd), config).keys()) + assert fs2_keys == rt_keys + + def test_moe_to_hg_config(self) -> None: + config = _small_moe_config() + hg_config = _Qwen35MoeHuggingFaceConverter().to_hg_config(config) + assert hg_config.kls_name == "Qwen3_5TextConfig" + assert hg_config.arch == "Qwen3_5MoeForCausalLM" + assert hg_config.data["num_experts"] == config.num_experts