From 22481a24b24c4fc3e07a50765d8ed60fc4bcdf3b Mon Sep 17 00:00:00 2001 From: yunchaoyang1 user Date: Wed, 27 May 2026 20:45:33 +0000 Subject: [PATCH 1/4] Add Qwen 3.5 model family support (0.8B, 2B, 9B, 27B, MoE 35B-A3B) Implement the Qwen 3.5 model family in fairseq2 with hybrid GatedDeltaNet linear attention (75%) + full attention (25%) architecture. Includes: - GatedDeltaNet module with chunked/recurrent delta rule kernels, causal conv1d, and fused PyTorch fallbacks - Qwen35Attention with partial RoPE, QK-norm, and output gating - Qwen35DecoderLayer supporting both full and linear attention types - Top-K MoE routing with shared experts for MoE variants - Bidirectional HuggingFace state dict conversion with RMSNorm 1+w convention handling - Asset cards for all model sizes (0.8B, 2B, 9B, 27B, MoE 35B-A3B) - 42 unit tests covering all modules and interop - Integration test for HF logit parity (Qwen3.5-0.8B) - SFT and pretraining recipe configs --- CHANGELOG.md | 1 + recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml | 81 +++ .../configs/qwen35_0.8b_fineweb_edu_10bt.yaml | 162 ++++++ .../configs/qwen35_2b_fineweb_edu_10bt.yaml | 162 ++++++ .../lm/train/configs/test_qwen35_0.8b.yaml | 155 +++++ .../train/scripts/run_qwen35_fineweb_edu.sh | 24 + src/fairseq2/assets/cards/models/qwen35.yaml | 75 +++ src/fairseq2/composition/models.py | 50 ++ src/fairseq2/models/qwen/__init__.py | 39 +- src/fairseq2/models/qwen/attention.py | 213 +++++++ src/fairseq2/models/qwen/config.py | 171 +++++- src/fairseq2/models/qwen/decoder_layer.py | 168 ++++++ src/fairseq2/models/qwen/factory.py | 213 ++++++- src/fairseq2/models/qwen/gated_delta_net.py | 545 ++++++++++++++++++ src/fairseq2/models/qwen/hub.py | 25 +- src/fairseq2/models/qwen/interop.py | 313 +++++++++- src/fairseq2/models/qwen/moe.py | 235 ++++++++ tests/integration/models/test_qwen35.py | 213 +++++++ tests/unit/models/qwen/__init__.py | 5 + .../unit/models/qwen/test_gated_delta_net.py | 152 +++++ 
.../unit/models/qwen/test_qwen35_attention.py | 173 ++++++ .../models/qwen/test_qwen35_decoder_layer.py | 150 +++++ tests/unit/models/qwen/test_qwen35_interop.py | 500 ++++++++++++++++ tests/unit/models/qwen/test_qwen35_moe.py | 154 +++++ 24 files changed, 3972 insertions(+), 7 deletions(-) create mode 100644 recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml create mode 100644 recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml create mode 100644 recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml create mode 100644 recipes/lm/train/configs/test_qwen35_0.8b.yaml create mode 100644 recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh create mode 100644 src/fairseq2/assets/cards/models/qwen35.yaml create mode 100644 src/fairseq2/models/qwen/attention.py create mode 100644 src/fairseq2/models/qwen/decoder_layer.py create mode 100644 src/fairseq2/models/qwen/gated_delta_net.py create mode 100644 src/fairseq2/models/qwen/moe.py create mode 100644 tests/integration/models/test_qwen35.py create mode 100644 tests/unit/models/qwen/__init__.py create mode 100644 tests/unit/models/qwen/test_gated_delta_net.py create mode 100644 tests/unit/models/qwen/test_qwen35_attention.py create mode 100644 tests/unit/models/qwen/test_qwen35_decoder_layer.py create mode 100644 tests/unit/models/qwen/test_qwen35_interop.py create mode 100644 tests/unit/models/qwen/test_qwen35_moe.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d4325637..8f57787ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ All notable changes to fairseq2 are documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.8.1] - Unreleased +- Qwen 3.5 model family (0.8B, 2B, 9B, 27B dense and 35B-A3B MoE) with base and instruction-tuned variants. 
Features hybrid GatedDeltaNet linear attention (75%) + full attention (25%) architecture, partial RoPE, QK-norm, output gating, Top-K MoE routing with shared experts, RMSNorm 1+w convention, bidirectional HuggingFace state dict conversion, and SFT/pretraining recipe configs. - Gemma 4 model family (E4B, 31B, 26B-A4B) with base and instruction-tuned variants. Includes decoder with Per-Layer Embeddings (PLE), partial RoPE, KV sharing across sliding/global attention layers, Mixture-of-Experts (26B-A4B), QK/V-norm, logit soft-capping, audio tower (Conformer encoder for multimodal E4B), bidirectional HuggingFace state dict conversion, FSDP/activation checkpointing/tensor parallel support, and SFT recipe configs. - Bump transformers~=v5.5 and loosen huggingface_hub upper bound. (#1508) - Fixed typo in WerMetric: use `hyp_seqs` instead of `ref_seqs` for `hyp_seqs_list`. (#1506) diff --git a/recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml b/recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml new file mode 100644 index 000000000..23e5e6f85 --- /dev/null +++ b/recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Qwen 3.5 0.8B GSM8K SFT Fine-tuning Config +# +# Validates training recipe integration and loss convergence for the Qwen 3.5 +# model on the GSM8K math reasoning dataset. 
+# +# Usage: +# torchrun --standalone --nproc_per_node=8 -m recipes.lm.sft \ +# --config-file recipes/lm/sft/configs/qwen35_0.8b_gsm8k.yaml \ +# /path/to/output_dir + +model: + name: "qwen35_0.8b" + dtype: bfloat16 + config_overrides: + pad_idx: 248044 + +tokenizer: + name: "qwen35_0.8b" + config_overrides: + use_im_end: true + +dataset: + max_seq_len: 4096 + max_num_tokens: 8192 + valid_split: "sft_test" + chat_mode: false + config_overrides: + sources: + train: + - path: "hg://facebook/fairseq2-lm-gsm8k" + split: "sft_train" + weight: 1.0 + sft_test: + - path: "hg://facebook/fairseq2-lm-gsm8k" + split: "sft_test" + weight: 1.0 + +trainer: + data_parallelism: fsdp + max_grad_norm: 1.0 + mixed_precision: + mode: static + dtype: bfloat16 + +optimizer: + name: adamw + config: + lr: 2.0e-5 + betas: [0.9, 0.95] + weight_decay: 0.1 + impl: fused + +lr_scheduler: + name: cosine_annealing + config: + final_lr_scale: 0.1 + num_warmup_steps: 100 + +regime: + num_steps: 100000 + checkpoint_every_n_steps: 100 + validate_every_n_steps: 100 + keep_last_n_checkpoints: 10 + publish_metrics_every_n_steps: 1 + save_model_only: false + +common: + seed: 0 + metric_recorders: + wandb: + enabled: true + entity: "yunchaoyang1" + project: "fairseq2" + tensorboard: + enabled: false diff --git a/recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml b/recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml new file mode 100644 index 000000000..76992f316 --- /dev/null +++ b/recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml @@ -0,0 +1,162 @@ +# Qwen 3.5 0.8B Continued Pretraining on FineWeb-Edu 10BT +# +# Loads the pretrained Qwen 3.5 0.8B checkpoint and continues training +# on the FineWeb-Edu Sample 10BT educational text dataset. +# +# Prerequisites: +# 1. 
Convert parquet data to chunked JSONL: +# cd /checkpoint/smallomnillm/shared/data/fineweb-edu +# python convert_to_jsonl.py --sample 10BT --chunk-format --lightweight \ +# --num-shards 256 --output-dir jsonl/10BT +# +# Usage: +# torchrun --standalone --nproc_per_node=8 -m recipes.lm.train \ +# --config-file recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml \ +# /path/to/output_dir + +model: + name: qwen35_0.8b + dtype: bfloat16 + mmap: false + compile: false + compile_options: + fullgraph: false + dynamic: false + mode: default + backend: inductor + backend_options: null + +dataset: + name: null + family: lm_train + config_overrides: + sources: + - path: /checkpoint/smallomnillm/shared/data/fineweb-edu/jsonl/10BT + weight: 1.0 + max_seq_len: 4096 + max_num_tokens: 8192 + prefetch: 4 + sync_ranks: false + +tokenizer: + name: qwen35_0.8b + path: null + family: null + config_overrides: null + +gang: + tensor_parallel_size: 1 + timeout: 15 + high_priority: true + +trainer: + data_parallelism: fsdp + fsdp: + version: v2 + granularity: layer + hybrid: false + reshard_after_forward: true + fp32_reduce: true + mixed_precision: + mode: static + dtype: bfloat16 + grad_accumulation: + num_batches: 1 + no_sync: false + activation_checkpointing: + mode: layerwise + every_nth_layer: 1 + max_grad_norm: 1.0 + fp16_loss_scale: + - 128.0 + - 0.0001 + gc_every_n_steps: 1000 + grad_check: false + anomaly_detection: false + +optimizer: + name: adamw + config: + lr: 5.0e-5 + betas: + - 0.9 + - 0.95 + eps: 1.0e-8 + weight_decay: 0.1 + amsgrad: false + maximize: false + capturable: false + differentiable: false + impl: fused + groups: [] + +lr_scheduler: + name: cosine_annealing + config: + cycle_len: null + num_warmup_steps: 500 + cycle_mul: 1.0 + lr_mul: 1.0 + start_lr: 1.0e-30 + final_lr: null + final_lr_scale: 0.01 + +regime: + num_steps: 76000 + num_data_epochs: null + validate_at_start: false + validate_after_n_steps: 0 + validate_every_n_steps: 4000 + 
validate_after_n_data_epochs: 0 + validate_every_n_data_epochs: null + score_metric: null + checkpoint_after_n_steps: 0 + checkpoint_every_n_steps: 2000 + checkpoint_after_n_data_epochs: 0 + checkpoint_every_n_data_epochs: null + save_model_only: all_but_last + export_hugging_face: true + keep_last_n_checkpoints: 3 + keep_best_n_checkpoints: null + keep_checkpoint_every_n_steps: 10000 + publish_metrics_after_n_steps: 0 + publish_metrics_every_n_steps: 10 + publish_metrics_after_n_data_epochs: 0 + publish_metrics_every_n_data_epochs: null + +common: + torch: + num_threads: null + allow_tf32: true + fp16_reduced_precision: true + bf16_reduced_precision: true + default_sdpa: torch + compiled_region_activation_memory_budget: 0.9 + metric_recorders: + tensorboard: + enabled: true + wandb: + enabled: true + entity: smallomni + project: qwen35_0.8b_fineweb_edu_10bt + run_id: persistent + run_name: null + group: null + job_type: null + resume_mode: null + profilers: + torch: + enabled: false + skip_n_steps: 4 + wait_n_steps: 0 + num_warmup_steps: 1 + num_active_steps: 4 + repeat: 1 + assets: + extra_paths: [] + prev_checkpoint_dir: null + seed: 2 + debug: false + cluster: auto + no_sweep_dir: false + sweep_format: ws_{world_size}.{hash} diff --git a/recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml b/recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml new file mode 100644 index 000000000..4324fe9d8 --- /dev/null +++ b/recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml @@ -0,0 +1,162 @@ +# Qwen 3.5 2B Continued Pretraining on FineWeb-Edu 10BT +# +# Loads the pretrained Qwen 3.5 2B checkpoint and continues training +# on the FineWeb-Edu Sample 10BT educational text dataset. +# +# Prerequisites: +# 1. 
Convert parquet data to chunked JSONL: +# cd /checkpoint/smallomnillm/shared/data/fineweb-edu +# python convert_to_jsonl.py --sample 10BT --chunk-format --lightweight \ +# --num-shards 256 --output-dir jsonl/10BT +# +# Usage: +# torchrun --standalone --nproc_per_node=8 -m recipes.lm.train \ +# --config-file recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml \ +# /path/to/output_dir + +model: + name: qwen35_2b + dtype: bfloat16 + mmap: false + compile: false + compile_options: + fullgraph: false + dynamic: false + mode: default + backend: inductor + backend_options: null + +dataset: + name: null + family: lm_train + config_overrides: + sources: + - path: /checkpoint/smallomnillm/shared/data/fineweb-edu/jsonl/10BT + weight: 1.0 + max_seq_len: 4096 + max_num_tokens: 8192 + prefetch: 4 + sync_ranks: false + +tokenizer: + name: qwen35_2b + path: null + family: null + config_overrides: null + +gang: + tensor_parallel_size: 1 + timeout: 15 + high_priority: true + +trainer: + data_parallelism: fsdp + fsdp: + version: v2 + granularity: layer + hybrid: false + reshard_after_forward: true + fp32_reduce: true + mixed_precision: + mode: static + dtype: bfloat16 + grad_accumulation: + num_batches: 1 + no_sync: false + activation_checkpointing: + mode: layerwise + every_nth_layer: 1 + max_grad_norm: 1.0 + fp16_loss_scale: + - 128.0 + - 0.0001 + gc_every_n_steps: 1000 + grad_check: false + anomaly_detection: false + +optimizer: + name: adamw + config: + lr: 5.0e-5 + betas: + - 0.9 + - 0.95 + eps: 1.0e-8 + weight_decay: 0.1 + amsgrad: false + maximize: false + capturable: false + differentiable: false + impl: fused + groups: [] + +lr_scheduler: + name: cosine_annealing + config: + cycle_len: null + num_warmup_steps: 500 + cycle_mul: 1.0 + lr_mul: 1.0 + start_lr: 1.0e-30 + final_lr: null + final_lr_scale: 0.01 + +regime: + num_steps: 76000 + num_data_epochs: null + validate_at_start: false + validate_after_n_steps: 0 + validate_every_n_steps: 4000 + 
validate_after_n_data_epochs: 0 + validate_every_n_data_epochs: null + score_metric: null + checkpoint_after_n_steps: 0 + checkpoint_every_n_steps: 2000 + checkpoint_after_n_data_epochs: 0 + checkpoint_every_n_data_epochs: null + save_model_only: all_but_last + export_hugging_face: true + keep_last_n_checkpoints: 3 + keep_best_n_checkpoints: null + keep_checkpoint_every_n_steps: 10000 + publish_metrics_after_n_steps: 0 + publish_metrics_every_n_steps: 10 + publish_metrics_after_n_data_epochs: 0 + publish_metrics_every_n_data_epochs: null + +common: + torch: + num_threads: null + allow_tf32: true + fp16_reduced_precision: true + bf16_reduced_precision: true + default_sdpa: torch + compiled_region_activation_memory_budget: 0.9 + metric_recorders: + tensorboard: + enabled: true + wandb: + enabled: true + entity: smallomni + project: qwen35_2b_fineweb_edu_10bt + run_id: persistent + run_name: null + group: null + job_type: null + resume_mode: null + profilers: + torch: + enabled: false + skip_n_steps: 4 + wait_n_steps: 0 + num_warmup_steps: 1 + num_active_steps: 4 + repeat: 1 + assets: + extra_paths: [] + prev_checkpoint_dir: null + seed: 2 + debug: false + cluster: auto + no_sweep_dir: false + sweep_format: ws_{world_size}.{hash} diff --git a/recipes/lm/train/configs/test_qwen35_0.8b.yaml b/recipes/lm/train/configs/test_qwen35_0.8b.yaml new file mode 100644 index 000000000..7a80a1236 --- /dev/null +++ b/recipes/lm/train/configs/test_qwen35_0.8b.yaml @@ -0,0 +1,155 @@ +# Qwen 3.5 0.8B Quick Test Config (Continued Pretraining) +# +# Full model (24 layers, no num_layers override — required for loading +# pretrained checkpoint) with only 100 steps for smoke testing. +# Verifies: model loads, pad_idx check passes, data loads, loss computes. +# +# NOTE: Cannot reduce num_layers for continued pretraining because the +# checkpoint has 24 layers of weights that must match the model architecture. 
+# +# Usage: +# source ~/envs/fs081-pt290-cu128/bin/activate +# cd /storage/home/yunchaoyang1/fairseq2 +# torchrun --standalone --nproc_per_node=1 -m recipes.lm.train \ +# --config-file recipes/lm/train/configs/test_qwen35_0.8b.yaml \ +# /tmp/qwen35_test + +model: + name: qwen35_0.8b + dtype: bfloat16 + mmap: false + compile: false + compile_options: + fullgraph: false + dynamic: false + mode: default + backend: inductor + backend_options: null + +dataset: + name: null + family: lm_train + config_overrides: + sources: + - path: /checkpoint/smallomnillm/shared/data/fineweb-edu/jsonl/10BT + weight: 1.0 + max_seq_len: 512 + max_num_tokens: 2048 + prefetch: 4 + sync_ranks: false + +tokenizer: + name: qwen35_0.8b + path: null + family: null + config_overrides: null + +gang: + tensor_parallel_size: 1 + timeout: 15 + high_priority: true + +trainer: + data_parallelism: fsdp + fsdp: + version: v2 + granularity: layer + hybrid: false + reshard_after_forward: true + fp32_reduce: true + mixed_precision: + mode: static + dtype: bfloat16 + grad_accumulation: + num_batches: 1 + no_sync: false + activation_checkpointing: + mode: off + every_nth_layer: 1 + max_grad_norm: 1.0 + fp16_loss_scale: + - 128.0 + - 0.0001 + gc_every_n_steps: 1000 + grad_check: false + anomaly_detection: false + +optimizer: + name: adamw + config: + lr: 5.0e-5 + betas: + - 0.9 + - 0.95 + eps: 1.0e-8 + weight_decay: 0.1 + amsgrad: false + maximize: false + capturable: false + differentiable: false + impl: fused + groups: [] + +lr_scheduler: + name: cosine_annealing + config: + cycle_len: null + num_warmup_steps: 10 + cycle_mul: 1.0 + lr_mul: 1.0 + start_lr: 1.0e-30 + final_lr: null + final_lr_scale: 0.01 + +regime: + num_steps: 100 + num_data_epochs: null + validate_at_start: false + validate_after_n_steps: 0 + validate_every_n_steps: 4000 + validate_after_n_data_epochs: 0 + validate_every_n_data_epochs: null + score_metric: null + checkpoint_after_n_steps: 0 + checkpoint_every_n_steps: 50 + 
checkpoint_after_n_data_epochs: 0 + checkpoint_every_n_data_epochs: null + save_model_only: all_but_last + export_hugging_face: false + keep_last_n_checkpoints: 2 + keep_best_n_checkpoints: null + keep_checkpoint_every_n_steps: null + publish_metrics_after_n_steps: 0 + publish_metrics_every_n_steps: 1 + publish_metrics_after_n_data_epochs: 0 + publish_metrics_every_n_data_epochs: null + +common: + torch: + num_threads: null + allow_tf32: true + fp16_reduced_precision: true + bf16_reduced_precision: true + default_sdpa: torch + compiled_region_activation_memory_budget: 0.9 + metric_recorders: + tensorboard: + enabled: false + wandb: + enabled: false + profilers: + torch: + enabled: false + skip_n_steps: 4 + wait_n_steps: 0 + num_warmup_steps: 1 + num_active_steps: 4 + repeat: 1 + assets: + extra_paths: [] + prev_checkpoint_dir: null + seed: 2 + debug: false + cluster: auto + no_sweep_dir: false + sweep_format: ws_{world_size}.{hash} diff --git a/recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh b/recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh new file mode 100644 index 000000000..d23b30de6 --- /dev/null +++ b/recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh @@ -0,0 +1,24 @@ +#!/bin/bash +#SBATCH --job-name=qwen35_pretrain_fineweb +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=8 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=96 +#SBATCH --mem=0 +#SBATCH --time=48:00:00 +#SBATCH --account=smallomnillm +#SBATCH --qos=h200_smallomnillm_high + +#SBATCH --output=/checkpoint/smallomnillm/yunchaoyang1/qwen35_pretrain/slurm_%j.out +#SBATCH --error=/checkpoint/smallomnillm/yunchaoyang1/qwen35_pretrain/slurm_%j.err + +OUTPUT_DIR=/checkpoint/smallomnillm/yunchaoyang1/qwen35_pretrain/baseline + +mkdir -p "${OUTPUT_DIR}" + +source ~/envs/fs081-pt290-cu128/bin/activate +cd /storage/home/yunchaoyang1/fairseq2 + +torchrun --standalone --nproc_per_node=8 -m recipes.lm.train \ + --config-file recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml \ + "${OUTPUT_DIR}" 
diff --git a/src/fairseq2/assets/cards/models/qwen35.yaml b/src/fairseq2/assets/cards/models/qwen35.yaml new file mode 100644 index 000000000..c72de8b42 --- /dev/null +++ b/src/fairseq2/assets/cards/models/qwen35.yaml @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +name: qwen35_0.8b +model_family: qwen3_5 +model_arch: qwen35_0.8b +checkpoint: "hg://Qwen/Qwen3.5-0.8B" +tokenizer: "hg://Qwen/Qwen3.5-0.8B" +tokenizer_family: qwen + +--- + +name: qwen35_2b +model_family: qwen3_5 +model_arch: qwen35_2b +checkpoint: "hg://Qwen/Qwen3.5-2B" +tokenizer: "hg://Qwen/Qwen3.5-2B" +tokenizer_family: qwen + +--- + +name: qwen35_2b_base +model_family: qwen3_5 +model_arch: qwen35_2b +checkpoint: "hg://Qwen/Qwen3.5-2B-Base" +tokenizer: "hg://Qwen/Qwen3.5-2B-Base" +tokenizer_family: qwen + +--- + +name: qwen35_9b +model_family: qwen3_5 +model_arch: qwen35_9b +checkpoint: "hg://Qwen/Qwen3.5-9B" +tokenizer: "hg://Qwen/Qwen3.5-9B" +tokenizer_family: qwen + +--- + +name: qwen35_9b_base +model_family: qwen3_5 +model_arch: qwen35_9b +checkpoint: "hg://Qwen/Qwen3.5-9B-Base" +tokenizer: "hg://Qwen/Qwen3.5-9B-Base" +tokenizer_family: qwen + +--- + +name: qwen35_27b +model_family: qwen3_5 +model_arch: qwen35_27b +checkpoint: "hg://Qwen/Qwen3.5-27B" +tokenizer: "hg://Qwen/Qwen3.5-27B" +tokenizer_family: qwen + +--- + +name: qwen35_moe_35b_a3b +model_family: qwen3_5_moe +model_arch: qwen35_moe_35b_a3b +checkpoint: "hg://Qwen/Qwen3.5-35B-A3B" +tokenizer: "hg://Qwen/Qwen3.5-35B-A3B" +tokenizer_family: qwen + +--- + +name: qwen35_moe_35b_a3b_base +model_family: qwen3_5_moe +model_arch: qwen35_moe_35b_a3b +checkpoint: "hg://Qwen/Qwen3.5-35B-A3B-Base" +tokenizer: "hg://Qwen/Qwen3.5-35B-A3B-Base" +tokenizer_family: qwen diff --git a/src/fairseq2/composition/models.py 
b/src/fairseq2/composition/models.py index 3be08dbbf..8537ea05d 100644 --- a/src/fairseq2/composition/models.py +++ b/src/fairseq2/composition/models.py @@ -107,11 +107,23 @@ register_olmo_configs, ) from fairseq2.models.qwen import ( + QWEN35_FAMILY, + QWEN35_MOE_FAMILY, QWEN_FAMILY, + Qwen35Config, + Qwen35MoeConfig, QwenConfig, + _Qwen35HuggingFaceConverter, + _Qwen35MoeHuggingFaceConverter, _QwenHuggingFaceConverter, + convert_qwen35_moe_state_dict, + convert_qwen35_state_dict, convert_qwen_state_dict, + create_qwen35_model, + create_qwen35_moe_model, create_qwen_model, + register_qwen35_configs, + register_qwen35_moe_configs, register_qwen_configs, ) from fairseq2.models.s2t_conformer import ( @@ -417,6 +429,44 @@ def _register_model_families(container: DependencyContainer) -> None: HuggingFaceConverter, _QwenHuggingFaceConverter, key=QWEN_FAMILY ) + # Qwen 3.5 + register_model_family( + container, + QWEN35_FAMILY, + kls=TransformerLM, + config_kls=Qwen35Config, + factory=create_qwen35_model, + state_dict_converter=convert_qwen35_state_dict, + compiler=compile_transformer_lm, + fsdp_applier=apply_fsdp_to_transformer_lm, + layerwise_ac_applier=apply_ac_to_transformer_lm, + ) + + register_qwen35_configs(container) + + container.register_type( + HuggingFaceConverter, _Qwen35HuggingFaceConverter, key=QWEN35_FAMILY + ) + + # Qwen 3.5 MoE + register_model_family( + container, + QWEN35_MOE_FAMILY, + kls=TransformerLM, + config_kls=Qwen35MoeConfig, + factory=create_qwen35_moe_model, + state_dict_converter=convert_qwen35_moe_state_dict, + compiler=compile_transformer_lm, + fsdp_applier=apply_fsdp_to_transformer_lm, + layerwise_ac_applier=apply_ac_to_transformer_lm, + ) + + register_qwen35_moe_configs(container) + + container.register_type( + HuggingFaceConverter, _Qwen35MoeHuggingFaceConverter, key=QWEN35_MOE_FAMILY + ) + # S2T Conformer register_model_family( container, diff --git a/src/fairseq2/models/qwen/__init__.py b/src/fairseq2/models/qwen/__init__.py index 
0d7d28179..135050769 100644 --- a/src/fairseq2/models/qwen/__init__.py +++ b/src/fairseq2/models/qwen/__init__.py @@ -6,20 +6,57 @@ from __future__ import annotations +from fairseq2.models.qwen.config import QWEN35_FAMILY as QWEN35_FAMILY +from fairseq2.models.qwen.config import QWEN35_MOE_FAMILY as QWEN35_MOE_FAMILY from fairseq2.models.qwen.config import QWEN_FAMILY as QWEN_FAMILY +from fairseq2.models.qwen.config import Qwen35Config as Qwen35Config +from fairseq2.models.qwen.config import Qwen35MoeConfig as Qwen35MoeConfig from fairseq2.models.qwen.config import QwenConfig as QwenConfig +from fairseq2.models.qwen.config import ( + register_qwen35_configs as register_qwen35_configs, +) +from fairseq2.models.qwen.config import ( + register_qwen35_moe_configs as register_qwen35_moe_configs, +) from fairseq2.models.qwen.config import register_qwen_configs as register_qwen_configs +from fairseq2.models.qwen.factory import Qwen35Factory as Qwen35Factory +from fairseq2.models.qwen.factory import Qwen35MoeFactory as Qwen35MoeFactory from fairseq2.models.qwen.factory import QwenFactory as QwenFactory +from fairseq2.models.qwen.factory import create_qwen35_model as create_qwen35_model +from fairseq2.models.qwen.factory import ( + create_qwen35_moe_model as create_qwen35_moe_model, +) from fairseq2.models.qwen.factory import create_qwen_model as create_qwen_model +from fairseq2.models.qwen.hub import get_qwen35_model_hub as get_qwen35_model_hub +from fairseq2.models.qwen.hub import ( + get_qwen35_moe_model_hub as get_qwen35_moe_model_hub, +) +from fairseq2.models.qwen.hub import ( + get_qwen35_moe_tokenizer_hub as get_qwen35_moe_tokenizer_hub, +) +from fairseq2.models.qwen.hub import ( + get_qwen35_tokenizer_hub as get_qwen35_tokenizer_hub, +) from fairseq2.models.qwen.hub import get_qwen_model_hub as get_qwen_model_hub from fairseq2.models.qwen.hub import get_qwen_tokenizer_hub as get_qwen_tokenizer_hub +from fairseq2.models.qwen.interop import ( + 
_Qwen35HuggingFaceConverter as _Qwen35HuggingFaceConverter, +) +from fairseq2.models.qwen.interop import ( + _Qwen35MoeHuggingFaceConverter as _Qwen35MoeHuggingFaceConverter, +) from fairseq2.models.qwen.interop import ( _QwenHuggingFaceConverter as _QwenHuggingFaceConverter, ) +from fairseq2.models.qwen.interop import ( + convert_qwen35_moe_state_dict as convert_qwen35_moe_state_dict, +) +from fairseq2.models.qwen.interop import ( + convert_qwen35_state_dict as convert_qwen35_state_dict, +) from fairseq2.models.qwen.interop import ( convert_qwen_state_dict as convert_qwen_state_dict, ) -from fairseq2.models.qwen.sharder import get_qwen_shard_specs as get_qwen_shard_specs from fairseq2.models.qwen.tokenizer import QwenTokenizer as QwenTokenizer from fairseq2.models.qwen.tokenizer import QwenTokenizerConfig as QwenTokenizerConfig from fairseq2.models.qwen.tokenizer import load_qwen_tokenizer as load_qwen_tokenizer diff --git a/src/fairseq2/models/qwen/attention.py b/src/fairseq2/models/qwen/attention.py new file mode 100644 index 000000000..afe0de1e9 --- /dev/null +++ b/src/fairseq2/models/qwen/attention.py @@ -0,0 +1,213 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Gated multi-head attention for Qwen 3.5. + +Differs from ``StandardMultiheadAttention`` in three ways: + +1. The Q projection is doubled — half is the query, half is an output gate. +2. Partial RoPE: only the first ``encoding_dim`` dimensions are rotated. +3. Output gating: ``attn_output = attn_output * sigmoid(gate)``. + +Reference: HuggingFace ``modeling_qwen3_5.py`` ``Qwen3_5Attention`` lines 707-779. 
+""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Final + +import torch +from torch import Tensor + +from fairseq2.models.transformer import ( + SDPA, + AttentionBiasCache, + AttentionState, + AttentionStateFactory, + FullAttentionState, + MultiheadAttention, +) +from fairseq2.nn import ( + BatchLayout, + IncrementalStateBag, + LayerNorm, + Linear, + PositionEncoder, +) +from fairseq2.ops import repeat_interleave + + +class Qwen35Attention(MultiheadAttention): + """Gated multi-head attention for Qwen 3.5 full-attention layers. + + Key differences from :class:`StandardMultiheadAttention`: + + * **Doubled Q projection** — ``q_proj`` outputs ``num_heads * head_dim * 2``; + the second half is an output gate. + * **Partial RoPE** — only the first ``encoding_dim`` (typically 64) of the + ``head_dim`` (typically 256) are rotated. The rest pass through. + * **Output gating** — ``attn_output * sigmoid(gate)`` before ``output_proj``. + * **QK-Norm** on per-head dimension (after unflatten). + + Reference: ``modeling_qwen3_5.py`` lines 707-779. + """ + + num_heads: Final[int] + num_key_value_heads: Final[int] + num_query_groups: Final[int] + head_dim: Final[int] + + def __init__( + self, + model_dim: int, + num_heads: int, + sdpa: SDPA, + *, + head_dim: int = 256, + num_key_value_heads: int | None = None, + pos_encoder: PositionEncoder | None = None, + q_norm: LayerNorm | None = None, + k_norm: LayerNorm | None = None, + state_factory: AttentionStateFactory | None = None, + qkv_proj_init_fn: Callable[[Linear], None] | None = None, + output_proj_init_fn: Callable[[Linear], None] | None = None, + ) -> None: + super().__init__() + + self.num_heads = num_heads + self.head_dim = head_dim + + if num_key_value_heads is None: + num_key_value_heads = num_heads + self.num_key_value_heads = num_key_value_heads + self.num_query_groups = num_heads // num_key_value_heads + + # Q projection is DOUBLED — half query, half gate. 
+ # HF: nn.Linear(hidden, num_heads * head_dim * 2, bias=False) + self.q_proj = Linear( + model_dim, + num_heads * head_dim * 2, + bias=False, + init_fn=qkv_proj_init_fn, + ) + + self.k_proj = Linear( + model_dim, + num_key_value_heads * head_dim, + bias=False, + init_fn=qkv_proj_init_fn, + ) + + self.v_proj = Linear( + model_dim, + num_key_value_heads * head_dim, + bias=False, + init_fn=qkv_proj_init_fn, + ) + + self.output_proj = Linear( + num_heads * head_dim, + model_dim, + bias=False, + init_fn=output_proj_init_fn, + ) + + self.q_norm = q_norm + self.k_norm = k_norm + self.pos_encoder = pos_encoder + self.sdpa = sdpa + self.state_factory = state_factory + + def forward( + self, + seqs: Tensor, + seqs_layout: BatchLayout, + keys: Tensor, + keys_layout: BatchLayout, + values: Tensor, + bias_cache: AttentionBiasCache, + *, + state_bag: IncrementalStateBag | None = None, + ) -> Tensor: + # -- Q projection: split into query + gate -- + # (B, S, num_heads * head_dim * 2) + q_combined = self.q_proj(seqs) + + # (B, S, num_heads, head_dim * 2) -> split along last dim + q_combined = q_combined.unflatten(-1, (self.num_heads, self.head_dim * 2)) + q, gate = q_combined.chunk(2, dim=-1) + # q: (B, S, num_heads, head_dim) + # gate: (B, S, num_heads, head_dim) + + # Flatten gate to (B, S, num_heads * head_dim) for later element-wise gating. + gate = gate.flatten(-2) # (B, S, num_heads * head_dim) + + # -- K, V projections -- + k = self.k_proj(keys) + v = self.v_proj(values) + k = k.unflatten(-1, (self.num_key_value_heads, self.head_dim)) + v = v.unflatten(-1, (self.num_key_value_heads, self.head_dim)) + + # -- QK-Norm (per head dim, after unflatten) -- + if self.q_norm is not None: + q = self.q_norm(q) + if self.k_norm is not None: + k = self.k_norm(k) + + # -- Partial RoPE -- + # Only the first `encoding_dim` dimensions of each head are rotated. + # The rest pass through unchanged. 
+ if self.pos_encoder is not None: + encoding_dim = self.pos_encoder.encoding_dim + + if encoding_dim < self.head_dim: + # Split into rotary and pass-through parts. + q_rot = q[..., :encoding_dim] + q_pass = q[..., encoding_dim:] + k_rot = k[..., :encoding_dim] + k_pass = k[..., encoding_dim:] + + q_rot = self.pos_encoder(q_rot, seqs_layout, state_bag=state_bag) + k_rot = self.pos_encoder(k_rot, keys_layout, state_bag=state_bag) + + q = torch.cat([q_rot, q_pass], dim=-1) + k = torch.cat([k_rot, k_pass], dim=-1) + else: + # Full rotation (encoding_dim == head_dim). + q = self.pos_encoder(q, seqs_layout, state_bag=state_bag) + k = self.pos_encoder(k, keys_layout, state_bag=state_bag) + + # -- KV cache management -- + if not self.training and state_bag is not None: + state = state_bag.maybe_get_state(self, AttentionState) + if state is None: + state_factory = self.state_factory or FullAttentionState + state = state_factory( + k, v, state_bag.max_num_steps, state_bag.capacity_increment + ) + state_bag.set_state(self, state) + else: + state.append(k, v) + k, v = state.get() + keys_layout = BatchLayout.of(k) + + # -- GQA expansion -- + if self.num_query_groups > 1: + k = repeat_interleave(k, dim=-2, repeat=self.num_query_groups) + v = repeat_interleave(v, dim=-2, repeat=self.num_query_groups) + + # -- Scaled dot-product attention -- + # q, k, v: (B, S, H, D) + attn_output, _ = self.sdpa(q, seqs_layout, k, keys_layout, v, bias_cache) + + # -- Output gating -- + # attn_output: (B, S, H, D) -> (B, S, H * D) + attn_output = attn_output.flatten(-2) + attn_output = attn_output * torch.sigmoid(gate) + + # -- Output projection -- + return self.output_proj(attn_output) diff --git a/src/fairseq2/models/qwen/config.py b/src/fairseq2/models/qwen/config.py index 98c064419..13d23930f 100644 --- a/src/fairseq2/models/qwen/config.py +++ b/src/fairseq2/models/qwen/config.py @@ -13,6 +13,7 @@ from fairseq2.runtime.dependency import DependencyContainer QWEN_FAMILY: Final = "qwen" 
+QWEN35_FAMILY: Final = "qwen3_5" @dataclass(kw_only=True) @@ -62,6 +63,174 @@ class QwenConfig: dropout_p: float = 0.0 """The dropout probability on outputs of Transformer layers.""" + pad_idx: int | None = None + """The index of the pad symbol in the vocabulary.""" + + +# --------------------------------------------------------------------------- + + +@dataclass(kw_only=True) +class Qwen35Config: + """Holds the configuration of a Qwen 3.5 dense model.""" + + model_dim: int = 4096 + max_seq_len: int = 32_768 + vocab_size: int = 248_320 + tied_embeddings: bool = False + num_layers: int = 32 + num_attn_heads: int = 16 + num_key_value_heads: int = 4 + head_dim: int = 256 + ffn_inner_dim: int = 12_288 + partial_rotary_factor: float = 0.25 + rope_theta: float = 1_000_000.0 + dropout_p: float = 0.0 + layer_types: list[str] | None = None + full_attention_interval: int = 4 + linear_conv_kernel_dim: int = 4 + linear_key_head_dim: int = 128 + linear_value_head_dim: int = 128 + linear_num_key_heads: int = 16 + linear_num_value_heads: int = 32 + + pad_idx: int | None = None + """The index of the pad symbol in the vocabulary.""" + + def __post_init__(self) -> None: + if self.layer_types is None: + interval = self.full_attention_interval + self.layer_types = [ + "linear_attention" if bool((i + 1) % interval) else "full_attention" + for i in range(self.num_layers) + ] + + +def register_qwen35_configs(container: DependencyContainer) -> None: + arch = ConfigRegistrar(container, Qwen35Config) + + @arch("qwen35_0.8b") + def qwen35_0p8b() -> Qwen35Config: + return Qwen35Config( + model_dim=1024, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=True, + num_layers=24, + num_attn_heads=8, + num_key_value_heads=2, + head_dim=256, + ffn_inner_dim=3584, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, 
+ ) + + @arch("qwen35_2b") + def qwen35_2b() -> Qwen35Config: + return Qwen35Config( + model_dim=2048, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=True, + num_layers=24, + num_attn_heads=8, + num_key_value_heads=2, + head_dim=256, + ffn_inner_dim=6144, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=16, + ) + + @arch("qwen35_9b") + def qwen35_9b() -> Qwen35Config: + return Qwen35Config( + model_dim=4096, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=False, + num_layers=32, + num_attn_heads=16, + num_key_value_heads=4, + head_dim=256, + ffn_inner_dim=12_288, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + ) + + @arch("qwen35_27b") + def qwen35_27b() -> Qwen35Config: + return Qwen35Config( + model_dim=5120, + max_seq_len=262_144, + vocab_size=248_320, + tied_embeddings=False, + num_layers=64, + num_attn_heads=24, + num_key_value_heads=4, + head_dim=256, + ffn_inner_dim=17_408, + partial_rotary_factor=0.25, + rope_theta=10_000_000.0, + full_attention_interval=4, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=48, + ) + + +# --------------------------------------------------------------------------- +# Qwen 3.5 MoE Config +# --------------------------------------------------------------------------- + +QWEN35_MOE_FAMILY: Final = "qwen3_5_moe" + + +@dataclass(kw_only=True) +class Qwen35MoeConfig(Qwen35Config): + """Holds the configuration of a Qwen 3.5 MoE model.""" + + model_dim: int = 2048 + num_layers: int = 40 + num_key_value_heads: int = 2 + num_experts: int = 256 + num_experts_per_tok: int = 8 + 
def register_qwen35_moe_configs(container: DependencyContainer) -> None:
    """Register the Qwen 3.5 MoE architectures with ``container``."""
    arch = ConfigRegistrar(container, Qwen35MoeConfig)

    @arch("qwen35_moe_35b_a3b")
    def qwen35_moe_35b_a3b() -> Qwen35MoeConfig:
        # Every dense Qwen 3.5 architecture overrides the class defaults for
        # the context length and the RoPE base. Relying on the inherited
        # defaults (32_768 / 1_000_000) here would make the MoE variant the
        # only family member with a different positional setup.
        # NOTE(review): values mirror the dense archs; confirm against the
        # published checkpoint config.
        return Qwen35MoeConfig(
            max_seq_len=262_144,
            rope_theta=10_000_000.0,
        )
@final
class Qwen35DecoderLayer(TransformerLMDecoderLayer):
    """Hybrid decoder layer that dispatches to full or linear attention.

    * ``layer_type == "full_attention"``: uses :attr:`self_attn`
      (:class:`Qwen35Attention`).
    * ``layer_type == "linear_attention"``: uses :attr:`linear_attn`
      (:class:`GatedDeltaNet`).

    Both sub-layers are pre-norm: the input is normalized, passed through
    the token mixer (or FFN) and combined with the un-normalized input by
    the corresponding residual connection.
    """

    layer_type: Final[str]

    def __init__(
        self,
        layer_type: str,
        self_attn: Qwen35Attention | None,
        linear_attn: GatedDeltaNet | None,
        ffn: FeedForwardNetwork,
        self_attn_layer_norm: LayerNorm,
        ffn_layer_norm: LayerNorm,
        *,
        self_attn_residual: ResidualConnect | None = None,
        ffn_residual: ResidualConnect | None = None,
    ) -> None:
        """
        :param layer_type: ``"full_attention"`` or ``"linear_attention"``.
        :param self_attn: Gated full attention module (only for full layers).
        :param linear_attn: GatedDeltaNet module (only for linear layers).
        :param ffn: Feed-forward network (always present).
        :param self_attn_layer_norm: Pre-attention layer norm.
        :param ffn_layer_norm: Pre-FFN layer norm.
        :param self_attn_residual: Residual connection around the token
            mixer; defaults to :class:`AdditiveResidualConnect`.
        :param ffn_residual: Residual connection around the FFN; defaults to
            :class:`AdditiveResidualConnect`.

        :raises ValueError: If ``layer_type`` is unknown or the module that
            ``layer_type`` requires is ``None``.
        """
        super().__init__()

        # Explicit checks instead of `assert`: assertions are stripped under
        # `python -O`, which would silently register `None` as the mixer.
        if layer_type == "full_attention":
            if self_attn is None:
                raise ValueError(
                    "`self_attn` must be specified when `layer_type` is 'full_attention'."
                )
            linear_attn = None
        elif layer_type == "linear_attention":
            if linear_attn is None:
                raise ValueError(
                    "`linear_attn` must be specified when `layer_type` is 'linear_attention'."
                )
            self_attn = None
        else:
            raise ValueError(
                f"`layer_type` must be 'full_attention' or 'linear_attention', got '{layer_type}'."
            )

        self.layer_type = layer_type

        # Register exactly one token mixer — attribute name matters for interop.
        self.self_attn: Qwen35Attention | None
        self.linear_attn: GatedDeltaNet | None

        self.register_module("self_attn", self_attn)
        self.register_module("linear_attn", linear_attn)

        self.self_attn_layer_norm = self_attn_layer_norm
        self.ffn = ffn
        self.ffn_layer_norm = ffn_layer_norm

        if self_attn_residual is None:
            self_attn_residual = AdditiveResidualConnect()
        self.self_attn_residual = self_attn_residual

        if ffn_residual is None:
            ffn_residual = AdditiveResidualConnect()
        self.ffn_residual = ffn_residual

    @override
    def forward(
        self,
        seqs: Tensor,
        seqs_layout: BatchLayout,
        attn_bias_cache: AttentionBiasCache,
        *,
        state_bag: IncrementalStateBag | None = None,
    ) -> Tensor:
        """Apply the token mixer sub-layer followed by the FFN sub-layer."""
        seqs = self._forward_token_mixer(seqs, seqs_layout, attn_bias_cache, state_bag)
        seqs = self._forward_ffn(seqs)
        return seqs

    def _forward_token_mixer(
        self,
        seqs: Tensor,
        seqs_layout: BatchLayout,
        attn_bias_cache: AttentionBiasCache,
        state_bag: IncrementalStateBag | None,
    ) -> Tensor:
        """Normalize ``seqs``, run the configured token mixer and add the
        residual."""
        residual = seqs

        seqs = self.self_attn_layer_norm(seqs)

        if self.layer_type == "linear_attention":
            # Invariant established in `__init__`; kept for type narrowing.
            assert self.linear_attn is not None

            # GatedDeltaNet expects 3D (B, S, D) input. A 2D (T, D) input
            # (packed batch) is viewed as a single sequence of length T.
            # NOTE(review): `seqs_layout` is not consulted here, so the
            # convolution and the recurrent state run across all T tokens
            # and are not reset at the boundaries between packed
            # sequences. Confirm whether packed inputs need a per-sequence
            # reset before relying on this path.
            if seqs.dim() == 2:
                seqs = self.linear_attn(seqs.unsqueeze(0), state_bag=state_bag)
                seqs = seqs.squeeze(0)
            else:
                seqs = self.linear_attn(seqs, state_bag=state_bag)
        else:
            # Invariant established in `__init__`; kept for type narrowing.
            assert self.self_attn is not None

            # Self-attention: queries, keys and values share one input.
            seqs = self.self_attn(
                seqs,
                seqs_layout,
                keys=seqs,
                keys_layout=seqs_layout,
                values=seqs,
                bias_cache=attn_bias_cache,
                state_bag=state_bag,
            )

        seqs = self.self_attn_residual(seqs, residual)
        return seqs

    def _forward_ffn(self, seqs: Tensor) -> Tensor:
        """Normalize ``seqs``, run the FFN and add the residual."""
        residual = seqs

        seqs = self.ffn_layer_norm(seqs)
        seqs = self.ffn(seqs)
        seqs = self.ffn_residual(seqs, residual)

        return seqs

    @override
    def extra_repr(self) -> str:
        return f"layer_type={self.layer_type}"
class Qwen35Factory:
    """Factory for Qwen 3.5 dense hybrid models.

    Builds a :class:`TransformerLM` whose decoder layers use either gated
    full attention or :class:`GatedDeltaNet`, as listed in
    ``config.layer_types``. Sub-classes override individual ``create_*``
    methods to swap components.
    """

    def __init__(self, config: Qwen35Config) -> None:
        """
        :param config: The model configuration.

        :raises ValueError: If ``config.layer_types`` has fewer entries than
            ``config.num_layers``.
        """
        self._config = config

        # Re-derives `layer_types` in case it was reset to `None` after the
        # config was constructed; a no-op when it is already populated.
        config.__post_init__()

        # A list that is too short (e.g. `num_layers` was increased after
        # construction) would otherwise surface as an `IndexError` half-way
        # through building the decoder.
        num_types = 0 if config.layer_types is None else len(config.layer_types)
        if num_types < config.num_layers:
            raise ValueError(
                f"`config.layer_types` must have at least `config.num_layers` ({config.num_layers}) entries, but has {num_types} instead. Set `config.layer_types` to `None` to derive it from `config.full_attention_interval`."
            )

    def create_model(self) -> TransformerLM:
        """Assemble the frontend, decoder and final projection."""
        config = self._config

        embed = self.create_embedding()
        decoder_frontend = self.create_decoder_frontend(embed)
        decoder = self.create_decoder()
        final_proj = self.create_final_projection(embed)

        return TransformerLM(
            config.model_dim,
            decoder_frontend,
            decoder,
            final_proj,
            config.pad_idx,
            config.max_seq_len,
        )

    def create_embedding(self) -> Embedding:
        """Create the token embedding table."""
        config = self._config

        def init_embed(embed: StandardEmbedding) -> None:
            # Standard deviation scales with the embedding dimension.
            std = embed.weight.shape[1] ** -0.5
            _init_truncated_normal(embed.weight, bias=None, std=std)

        return VocabShardedEmbedding(
            config.vocab_size, config.model_dim, config.pad_idx, init_fn=init_embed
        )

    def create_decoder_frontend(self, embed: Embedding) -> TransformerFrontend:
        """Create the embedding frontend; positions are encoded inside the
        attention layers, so no position encoder is attached here."""
        config = self._config

        return TransformerEmbeddingFrontend(
            config.model_dim,
            embed,
            pos_encoder=None,
            no_scale=True,
            dropout_p=config.dropout_p,
        )

    def create_decoder(self) -> TransformerLMDecoder:
        """Create all decoder layers and the final layer norm."""
        config = self._config

        # One rotary encoder instance is shared by all full attention layers.
        pos_encoder = self.create_position_encoder()

        layers = [
            self.create_decoder_layer(idx, pos_encoder)
            for idx in range(config.num_layers)
        ]

        layer_norm = self.create_layer_norm()

        return StandardTransformerLMDecoder(layers, layer_norm)

    def create_position_encoder(self) -> PositionEncoder:
        """Create the rotary encoder covering the rotated part of a head."""
        config = self._config

        # Partial RoPE: only this many leading features per head are rotated.
        encoding_dim = int(config.head_dim * config.partial_rotary_factor)

        return ReferenceRotaryEncoder(
            encoding_dim, config.max_seq_len, theta=config.rope_theta
        )

    def create_decoder_layer(
        self, layer_idx: int, pos_encoder: PositionEncoder
    ) -> TransformerLMDecoderLayer:
        """Create the decoder layer at ``layer_idx`` with the token mixer
        selected by ``config.layer_types[layer_idx]``."""
        # Imported here rather than at module level, as in the original code.
        from fairseq2.models.qwen.decoder_layer import Qwen35DecoderLayer

        config = self._config

        assert config.layer_types is not None  # ensured in `__init__`
        layer_type = config.layer_types[layer_idx]

        self_attn = None
        linear_attn = None

        if layer_type == "full_attention":
            self_attn = self.create_gated_attention(layer_idx, pos_encoder)
        else:
            linear_attn = self.create_gated_delta_net(layer_idx)

        ffn = self.create_ffn(layer_idx)
        self_attn_layer_norm = self.create_layer_norm()
        ffn_layer_norm = self.create_layer_norm()

        return Qwen35DecoderLayer(
            layer_type,
            self_attn=self_attn,
            linear_attn=linear_attn,
            ffn=ffn,
            self_attn_layer_norm=self_attn_layer_norm,
            ffn_layer_norm=ffn_layer_norm,
        )

    def create_gated_attention(
        self, layer_idx: int, pos_encoder: PositionEncoder
    ) -> Qwen35Attention:
        """Create a gated full attention module with causal bias and
        per-head query/key norms."""
        config = self._config

        attn_bias = CausalAttentionBias()
        sdpa = create_default_sdpa(attn_bias)

        q_norm = self.create_layer_norm(config.head_dim)
        k_norm = self.create_layer_norm(config.head_dim)

        return Qwen35Attention(
            config.model_dim,
            config.num_attn_heads,
            sdpa,
            head_dim=config.head_dim,
            num_key_value_heads=config.num_key_value_heads,
            pos_encoder=pos_encoder,
            q_norm=q_norm,
            k_norm=k_norm,
        )

    def create_gated_delta_net(self, layer_idx: int) -> GatedDeltaNet:
        """Create a GatedDeltaNet linear attention module."""
        config = self._config

        return GatedDeltaNet(
            hidden_size=config.model_dim,
            num_k_heads=config.linear_num_key_heads,
            num_v_heads=config.linear_num_value_heads,
            head_k_dim=config.linear_key_head_dim,
            head_v_dim=config.linear_value_head_dim,
            conv_kernel_size=config.linear_conv_kernel_dim,
        )

    def create_ffn(self, layer_idx: int) -> FeedForwardNetwork:
        """Create the dense GLU feed-forward network of a layer."""
        config = self._config

        return GLUFeedForwardNetwork(
            config.model_dim,
            config.ffn_inner_dim,
            bias=False,
            inner_dim_scale=1.0,
        )

    def create_final_projection(self, embed: Embedding) -> Projection:
        """Create the output projection to the vocabulary.

        :raises TypeError: If embeddings are tied and ``embed`` is not a
            :class:`VocabShardedEmbedding`.
        :raises NotSupportedError: If embeddings are tied and tensor
            parallelism is enabled.
        """
        config = self._config

        if config.tied_embeddings:
            if not isinstance(embed, VocabShardedEmbedding):
                raise TypeError(
                    f"`embed` is expected to be of type `{VocabShardedEmbedding}` when tied_embeddings is True."
                )
            if embed.tp_gang.size > 1:
                raise NotSupportedError(
                    "Tied embeddings are not supported when tensor parallelism is enabled."
                )
            return TiedProjection(embed.weight, bias=None)

        return ColumnShardedLinear(config.model_dim, config.vocab_size, bias=False)

    def create_layer_norm(self, dim: int | None = None) -> LayerNorm:
        """Create an RMSNorm over ``dim`` features (``model_dim`` if
        ``None``)."""
        config = self._config
        if dim is None:
            dim = config.model_dim
        return RMSNorm(dim, bias=False, eps=1e-06)
+# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Gated DeltaNet linear attention module for Qwen 3.5. + +Reference: HuggingFace ``modeling_qwen3_5.py`` lines 445-620. +""" + +from __future__ import annotations + +import logging +from typing import Callable, Final, final + +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from fairseq2.nn import ( + IncrementalState, + IncrementalStateBag, + Linear, + RMSNorm, +) + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional fast-path kernels +# --------------------------------------------------------------------------- + +try: + from causal_conv1d import causal_conv1d_update as _causal_conv1d_update + + _HAS_CAUSAL_CONV1D = True +except ImportError: + _HAS_CAUSAL_CONV1D = False + logger.warning( + "causal_conv1d not found; GatedDeltaNet will use a slower PyTorch fallback " + "for incremental decoding. Install with: pip install causal-conv1d" + ) + +try: + from fla.ops.gated_delta_rule import ( + chunk_gated_delta_rule as _chunk_gated_delta_rule, + ) + from fla.ops.gated_delta_rule import ( + fused_recurrent_gated_delta_rule as _fused_recurrent_gated_delta_rule, + ) + + _HAS_FLA = True +except ImportError: + _HAS_FLA = False + logger.warning( + "flash-linear-attention (fla) not found; GatedDeltaNet will use slower " + "pure-PyTorch chunk/recurrent kernels. Install with: pip install flash-linear-attention" + ) + + +def l2norm(x: Tensor, dim: int = -1, eps: float = 1e-6) -> Tensor: + """L2-normalize along ``dim``. + + Reference: ``modeling_qwen3_5.py`` lines 317-320. 
def torch_causal_conv1d_update(
    hidden_states: Tensor,
    conv_state: Tensor,
    weight: Tensor,
    bias: Tensor | None = None,
    activation: str | None = None,
) -> Tensor:
    """Run a depthwise causal conv1d step against a rolling input window.

    The cached window is prepended to the new inputs, the cache is
    overwritten in place with the most recent inputs, and only the outputs
    that belong to the new positions are returned.

    :param hidden_states: ``(B, D, L)`` new inputs.
    :param conv_state: ``(B, D, W)`` rolling window of the most recent
        inputs; updated in place, its width ``W`` is preserved.
    :param weight: ``(D, kernel)`` depthwise filter weights.
    :param bias: ``(D,)`` or ``None``.
    :param activation: ``"silu"`` to apply SiLU to the output; any other
        value leaves the output unactivated.
    :returns: ``(B, D, L)`` outputs in the dtype of ``hidden_states``.
    """
    _, num_channels, num_new = hidden_states.shape
    window = conv_state.shape[-1]

    # Left context followed by the new inputs, in the filter's dtype.
    extended = torch.cat((conv_state, hidden_states), dim=-1).to(weight.dtype)

    # Slide the cache forward so that it ends at the newest input.
    conv_state.copy_(extended[:, :, -window:])

    conv_out = F.conv1d(
        extended,
        weight.unsqueeze(1),
        bias,
        padding=0,
        groups=num_channels,
    )

    # Only the trailing outputs correspond to the new positions.
    conv_out = conv_out[:, :, -num_new:]
    if activation == "silu":
        conv_out = F.silu(conv_out)

    return conv_out.to(hidden_states.dtype)
+ :returns: ``(output, final_state)`` + """ + initial_dtype = query.dtype + + if use_qk_l2norm_in_kernel: + query = l2norm(query, dim=-1) + key = l2norm(key, dim=-1) + + query, key, value, beta, g = [ + x.transpose(1, 2).contiguous().to(torch.float32) + for x in (query, key, value, beta, g) + ] + + batch_size, num_heads, seq_len, k_head_dim = key.shape + v_head_dim = value.shape[-1] + + pad_size = (chunk_size - seq_len % chunk_size) % chunk_size + query = F.pad(query, (0, 0, 0, pad_size)) + key = F.pad(key, (0, 0, 0, pad_size)) + value = F.pad(value, (0, 0, 0, pad_size)) + beta = F.pad(beta, (0, pad_size)) + g = F.pad(g, (0, pad_size)) + total_seq_len = seq_len + pad_size + + scale = 1.0 / (query.shape[-1] ** 0.5) + query = query * scale + + v_beta = value * beta.unsqueeze(-1) + k_beta = key * beta.unsqueeze(-1) + + query, key, value, k_beta, v_beta = [ + x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) + for x in (query, key, value, k_beta, v_beta) + ] + g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=0, + ) + + g = g.cumsum(dim=-1) + decay_mask = (g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float().tril() + + attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) + for i in range(1, chunk_size): + row = attn[..., i, :i].clone() + sub = attn[..., :i, :i].clone() + attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) + attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) + + value = attn @ v_beta + k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) + + last_state = ( + torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) + if initial_state is None + else initial_state.to(value) + ) + core_out = torch.zeros_like(value) + + mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), + diagonal=1, + ) + + num_chunks = total_seq_len // 
def torch_recurrent_gated_delta_rule(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    g: Tensor,
    beta: Tensor,
    initial_state: Tensor | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[Tensor, Tensor | None]:
    """Step-by-step recurrent gated delta rule for decode (pure PyTorch).

    For every position the state is decayed by ``exp(g)``, corrected towards
    ``value`` along ``key`` with strength ``beta``, and then read out with
    the scaled ``query``. All arithmetic is done in float32; the output is
    cast back to the dtype of ``query``.

    :param query: ``(B, S, H, K)`` — typically ``S=1`` during decode.
    :param key: ``(B, S, H, K)``
    :param value: ``(B, S, H, V)``
    :param g: ``(B, S, H)`` — forget gate (log-space).
    :param beta: ``(B, S, H)`` — write gate.
    :param initial_state: ``(B, H, K, V)`` state to start from, or ``None``
        to start from zeros. It is not modified.
    :param output_final_state: If ``True``, the float32 state after the last
        position is returned as the second element; otherwise ``None``.
    :param use_qk_l2norm_in_kernel: If ``True``, ``query`` and ``key`` are
        L2-normalized over their last dimension first.
    :returns: ``(output, final_state)`` with ``output`` of shape
        ``(B, S, H, V)``.
    """
    initial_dtype = query.dtype

    if use_qk_l2norm_in_kernel:
        query = l2norm(query, dim=-1)
        key = l2norm(key, dim=-1)

    # (B, S, H, ...) -> (B, H, S, ...) in float32.
    query, key, value, beta, g = [
        x.transpose(1, 2).contiguous().to(torch.float32)
        for x in (query, key, value, beta, g)
    ]

    batch_size, num_heads, seq_len, k_head_dim = key.shape
    v_head_dim = value.shape[-1]

    scale = 1.0 / (query.shape[-1] ** 0.5)
    query = query * scale

    # Allocate directly with the dtype/device of `value` instead of creating
    # CPU tensors and copying them over on every call.
    core_out = value.new_zeros(batch_size, num_heads, seq_len, v_head_dim)
    if initial_state is None:
        last_state = value.new_zeros(batch_size, num_heads, k_head_dim, v_head_dim)
    else:
        last_state = initial_state.to(value)

    # The gate is in log-space; exponentiate once for all positions.
    decay = g.exp()

    for i in range(seq_len):
        q_t = query[:, :, i]
        k_t = key[:, :, i]
        v_t = value[:, :, i]
        decay_t = decay[:, :, i, None, None]
        beta_t = beta[:, :, i].unsqueeze(-1)

        # Out-of-place on purpose: `initial_state` must stay untouched.
        last_state = last_state * decay_t

        # What the decayed state currently returns for this key.
        kv_mem = (last_state * k_t.unsqueeze(-1)).sum(dim=-2)

        # Delta rule: write only the gated prediction error.
        delta = (v_t - kv_mem) * beta_t
        last_state = last_state + k_t.unsqueeze(-1) * delta.unsqueeze(-2)

        core_out[:, :, i] = (last_state * q_t.unsqueeze(-1)).sum(dim=-2)

    final_state: Tensor | None = last_state if output_final_state else None

    core_out = core_out.transpose(1, 2).contiguous().to(initial_dtype)
    return core_out, final_state
class RMSNormGated(nn.Module):
    """Compute ``RMSNorm(x) * silu(gate)``.

    Normalization is applied before gating and is delegated to an inner
    :class:`RMSNorm` without bias. The gate activation is evaluated in
    float32 and the result is returned in the dtype of ``gate``.

    Reference: ``modeling_qwen3_5.py`` lines 264-279.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        """
        :param dim: The number of features that are normalized.
        :param eps: The epsilon passed to the inner norm.
        """
        super().__init__()
        self.inner_norm = RMSNorm(dim, bias=False, eps=eps)

    def forward(self, hidden_states: Tensor, gate: Tensor) -> Tensor:
        """
        :param hidden_states: The features to normalize.
        :param gate: The gate features, broadcastable against
            ``hidden_states``.
        :returns: The gated, normalized features in the dtype of ``gate``.
        """
        normed = self.inner_norm(hidden_states)
        gate_act = F.silu(gate.to(torch.float32))
        return (normed * gate_act).to(gate.dtype)
+ """ + + hidden_size: Final[int] + num_k_heads: Final[int] + num_v_heads: Final[int] + head_k_dim: Final[int] + head_v_dim: Final[int] + key_dim: Final[int] + value_dim: Final[int] + conv_dim: Final[int] + conv_kernel_size: Final[int] + + def __init__( + self, + hidden_size: int, + num_k_heads: int = 16, + num_v_heads: int = 32, + head_k_dim: int = 128, + head_v_dim: int = 128, + conv_kernel_size: int = 4, + eps: float = 1e-6, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = head_k_dim * num_k_heads + self.value_dim = head_v_dim * num_v_heads + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv_kernel_size = conv_kernel_size + + # Input projections — fairseq2 Linear wrappers. + self.in_proj_qkv = Linear( + hidden_size, self.key_dim * 2 + self.value_dim, bias=False + ) + self.in_proj_z = Linear(hidden_size, self.value_dim, bias=False) + self.in_proj_b = Linear(hidden_size, num_v_heads, bias=False) + self.in_proj_a = Linear(hidden_size, num_v_heads, bias=False) + + # Depthwise causal convolution (no fairseq2 wrapper exists). + self.conv1d = nn.Conv1d( + in_channels=self.conv_dim, + out_channels=self.conv_dim, + bias=False, + kernel_size=conv_kernel_size, + groups=self.conv_dim, + padding=conv_kernel_size - 1, + ) + + # Learnable gating parameters (no fairseq2 wrapper for raw params). + self.dt_bias = nn.Parameter(torch.ones(num_v_heads)) + A = torch.empty(num_v_heads).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + + # Output norm (silu-gated, wraps fairseq2 RMSNorm) and projection. + self.norm = RMSNormGated(head_v_dim, eps=eps) + self.out_proj = Linear(self.value_dim, hidden_size, bias=False) + + # Select fast-path kernels when available, else pure-PyTorch fallbacks. 
def forward(
    self,
    seqs: Tensor,
    padding_mask: Tensor | None = None,
    *,
    state_bag: IncrementalStateBag | None = None,
) -> Tensor:
    """Apply the gated delta rule token mixer.

    Without ``state_bag`` the whole input is processed with the chunked
    kernel. With ``state_bag`` the convolution window and the recurrent
    state are cached: the first call stores them, a later single-token call
    uses the fused step kernels, and a later multi-token call continues
    from the cached state with the chunked kernel.

    :param seqs: ``(B, S, D)``
    :param padding_mask: Optional ``(B, S)`` mask that ``seqs`` is
        multiplied with (non-zero = keep); ignored when ``S == 1``.
    :param state_bag: Incremental state bag for generation.
    :returns: ``(B, S, D)``
    """
    if padding_mask is not None and padding_mask.shape[1] > 1:
        seqs = (seqs * padding_mask[:, :, None]).to(seqs.dtype)

    batch_size, seq_len, _ = seqs.shape

    state: GatedDeltaNetState | None = None
    if state_bag is not None:
        state = state_bag.maybe_get_state(self, GatedDeltaNetState)

    # Fast path: exactly one new token on top of an existing state.
    single_step = state is not None and seq_len == 1

    # -- Input projections --
    mixed_qkv = self.in_proj_qkv(seqs).transpose(1, 2)  # (B, conv_dim, S)
    z = self.in_proj_z(seqs).reshape(batch_size, seq_len, -1, self.head_v_dim)
    b = self.in_proj_b(seqs)
    a = self.in_proj_a(seqs)

    # -- Causal convolution --
    new_conv_state: Tensor | None = None

    if single_step:
        assert state is not None
        # Updates `state.conv_state` in place.
        mixed_qkv = self._conv1d_update_fn(
            mixed_qkv,
            state.conv_state,
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            "silu",
        )
    else:
        conv_input = mixed_qkv

        if state is not None:
            # Multi-token continuation: prepend the cached inputs so the
            # convolution sees its left context instead of zero padding.
            conv_input = torch.cat(
                [state.conv_state.to(mixed_qkv.dtype), mixed_qkv], dim=-1
            )

        if state_bag is not None:
            # Keep the last `conv_kernel_size` inputs: a positive amount
            # left-pads with zeros, a negative amount crops from the left.
            new_conv_state = F.pad(
                conv_input,
                (self.conv_kernel_size - conv_input.shape[-1], 0),
            )

        # `conv1d` pads both sides, so the first `total_len` outputs are the
        # causal ones; the last `seq_len` of those belong to the new tokens.
        total_len = conv_input.shape[-1]
        conv_out = self.conv1d(conv_input)[:, :, total_len - seq_len : total_len]

        mixed_qkv = F.silu(conv_out)

    mixed_qkv = mixed_qkv.transpose(1, 2)  # (B, S, conv_dim)

    # -- Split QKV --
    query, key, value = torch.split(
        mixed_qkv, [self.key_dim, self.key_dim, self.value_dim], dim=-1
    )
    query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
    key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
    value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)

    # -- Compute gates --
    beta = b.sigmoid()
    g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)

    # -- GQA expansion: repeat key heads to match the value heads --
    groups = self.num_v_heads // self.num_k_heads
    if groups > 1:
        query = query.repeat_interleave(groups, dim=2)
        key = key.repeat_interleave(groups, dim=2)

    # -- Delta rule core --
    if single_step:
        assert state is not None
        core_out, last_state = self._recurrent_fn(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=state.recurrent_state,
            output_final_state=True,
            use_qk_l2norm_in_kernel=True,
        )
    else:
        # Continue from the cached recurrent state if there is one, rather
        # than silently restarting the recurrence from zero.
        core_out, last_state = self._chunk_fn(
            query,
            key,
            value,
            g=g,
            beta=beta,
            initial_state=None if state is None else state.recurrent_state,
            output_final_state=state_bag is not None,
            use_qk_l2norm_in_kernel=True,
        )

    # -- Update incremental state --
    if state_bag is not None:
        assert last_state is not None

        if state is None:
            assert new_conv_state is not None
            state_bag.set_state(self, GatedDeltaNetState(new_conv_state, last_state))
        else:
            # On the single-step path the conv window was already updated
            # in place; on the multi-token path it must be replaced too so
            # both halves of the state stay in sync.
            if new_conv_state is not None:
                state.conv_state = new_conv_state
            state.recurrent_state = last_state

    # -- Output norm (silu-gated) + projection --
    core_out = core_out.reshape(-1, self.head_v_dim)
    z = z.reshape(-1, self.head_v_dim)
    core_out = self.norm(core_out, z)
    core_out = core_out.reshape(batch_size, seq_len, -1)

    return self.out_proj(core_out)
Qwen35MoeConfig, + QwenConfig, +) from fairseq2.models.qwen.tokenizer import QwenTokenizer, QwenTokenizerConfig from fairseq2.models.transformer_lm import TransformerLM @@ -19,3 +26,19 @@ get_qwen_tokenizer_hub = TokenizerHubAccessor( QWEN_FAMILY, kls=QwenTokenizer, config_kls=QwenTokenizerConfig ) + +get_qwen35_model_hub = ModelHubAccessor( + QWEN35_FAMILY, kls=TransformerLM, config_kls=Qwen35Config +) + +get_qwen35_tokenizer_hub = TokenizerHubAccessor( + QWEN35_FAMILY, kls=QwenTokenizer, config_kls=QwenTokenizerConfig +) + +get_qwen35_moe_model_hub = ModelHubAccessor( + QWEN35_MOE_FAMILY, kls=TransformerLM, config_kls=Qwen35MoeConfig +) + +get_qwen35_moe_tokenizer_hub = TokenizerHubAccessor( + QWEN35_MOE_FAMILY, kls=QwenTokenizer, config_kls=QwenTokenizerConfig +) diff --git a/src/fairseq2/models/qwen/interop.py b/src/fairseq2/models/qwen/interop.py index 98da6d6b7..37cac3b51 100644 --- a/src/fairseq2/models/qwen/interop.py +++ b/src/fairseq2/models/qwen/interop.py @@ -8,10 +8,11 @@ from typing import Final, final +import torch from typing_extensions import override from fairseq2.models.hg import HuggingFaceConfig, HuggingFaceConverter -from fairseq2.models.qwen.config import QwenConfig +from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig, QwenConfig from fairseq2.models.utils.checkpoint import convert_state_dict, create_reverse_key_map from fairseq2.utils.config import cast_config_type @@ -47,9 +48,129 @@ def convert_qwen_state_dict( return state_dict +# HG-side RMSNorm key suffixes for reverse conversion (weight -= 1.0). 
+_QWEN35_HG_RMSNORM_SUFFIXES = (
+    "input_layernorm.weight",
+    "post_attention_layernorm.weight",
+    "model.norm.weight",
+    "self_attn.q_norm.weight",
+    "self_attn.k_norm.weight",
+)
+
+
 @final
-class _QwenHuggingFaceConverter(HuggingFaceConverter):
+class _Qwen35HuggingFaceConverter(HuggingFaceConverter):
+    """Exports dense Qwen 3.5 configs and state dicts to HuggingFace form."""
+
+    @override
+    def to_hg_config(self, config: object) -> HuggingFaceConfig:
+        """Map a :class:`Qwen35Config` onto HuggingFace config field names."""
+        config = cast_config_type(config, Qwen35Config)
+
+        data: dict[str, object] = {
+            "hidden_size": config.model_dim,
+            "max_position_embeddings": config.max_seq_len,
+            "vocab_size": config.vocab_size,
+            "tie_word_embeddings": config.tied_embeddings,
+            "num_hidden_layers": config.num_layers,
+            "num_attention_heads": config.num_attn_heads,
+            "num_key_value_heads": config.num_key_value_heads,
+            "head_dim": config.head_dim,
+            "intermediate_size": config.ffn_inner_dim,
+            "partial_rotary_factor": config.partial_rotary_factor,
+            "rope_theta": config.rope_theta,
+            "full_attention_interval": config.full_attention_interval,
+            "linear_conv_kernel_dim": config.linear_conv_kernel_dim,
+            "linear_key_head_dim": config.linear_key_head_dim,
+            "linear_value_head_dim": config.linear_value_head_dim,
+            "linear_num_key_heads": config.linear_num_key_heads,
+            "linear_num_value_heads": config.linear_num_value_heads,
+        }
+
+        return HuggingFaceConfig(
+            data, kls_name="Qwen3_5TextConfig", arch="Qwen3_5ForCausalLM"
+        )
+
+    @override
+    def to_hg_state_dict(
+        self, state_dict: dict[str, object], config: object
+    ) -> dict[str, object]:
+        """Rename keys to the text-only HuggingFace layout and undo the
+        ``+1.0`` RMSNorm shift that is applied on import."""
+        config = cast_config_type(config, Qwen35Config)
+
+        # Use the text-only key map for export (model.layers.*, not
+        # model.language_model.layers.*).
+        key_map = create_reverse_key_map(_QWEN35_TEXT_KEY_MAP)
+
+        hg_state_dict = convert_state_dict(state_dict, key_map)
+
+        # ``str.endswith`` accepts the suffix tuple directly. Iterate over a
+        # snapshot because entries are reassigned inside the loop.
+        for key, weight in list(hg_state_dict.items()):
+            if not key.endswith(_QWEN35_HG_RMSNORM_SUFFIXES):
+                continue
+            if isinstance(weight, torch.Tensor):
+                hg_state_dict[key] = weight - 1.0
+
+        if config.tied_embeddings:
+            hg_state_dict.pop("lm_head.weight", None)
+
+        return hg_state_dict
+
+
+@final
+class _Qwen35MoeHuggingFaceConverter(HuggingFaceConverter):
+    """Exports Qwen 3.5 MoE configs and state dicts to HuggingFace form."""
+
+    @override
+    def to_hg_config(self, config: object) -> HuggingFaceConfig:
+        """Map a :class:`Qwen35MoeConfig` onto HuggingFace config field names."""
+        config = cast_config_type(config, Qwen35MoeConfig)
+
+        data: dict[str, object] = {
+            "hidden_size": config.model_dim,
+            "max_position_embeddings": config.max_seq_len,
+            "vocab_size": config.vocab_size,
+            "tie_word_embeddings": config.tied_embeddings,
+            "num_hidden_layers": config.num_layers,
+            "num_attention_heads": config.num_attn_heads,
+            "num_key_value_heads": config.num_key_value_heads,
+            "head_dim": config.head_dim,
+            "intermediate_size": config.ffn_inner_dim,
+            "partial_rotary_factor": config.partial_rotary_factor,
+            "rope_theta": config.rope_theta,
+            "full_attention_interval": config.full_attention_interval,
+            "linear_conv_kernel_dim": config.linear_conv_kernel_dim,
+            "linear_key_head_dim": config.linear_key_head_dim,
+            "linear_value_head_dim": config.linear_value_head_dim,
+            "linear_num_key_heads": config.linear_num_key_heads,
+            "linear_num_value_heads": config.linear_num_value_heads,
+            "num_experts": config.num_experts,
+            "num_experts_per_tok": config.num_experts_per_tok,
+            "moe_intermediate_size": config.moe_intermediate_size,
+            "shared_expert_intermediate_size": config.shared_expert_intermediate_size,
+            "router_aux_loss_coef": config.router_aux_loss_coef,
+        }
+
+        # NOTE(review): ``kls_name`` is the same string the dense converter
+        # uses although ``arch`` is the MoE class — confirm against the
+        # installed ``transformers`` that this is the intended config class
+        # for the MoE fields above.
+        return HuggingFaceConfig(
+            data, kls_name="Qwen3_5TextConfig", arch="Qwen3_5MoeForCausalLM"
+        )
+
     @override
+    def to_hg_state_dict(
+        self, state_dict: dict[str, object], config: object
+    ) -> dict[str, object]:
+        """Rename keys to the text-only HuggingFace MoE layout and undo the
+        ``+1.0`` RMSNorm shift that is applied on import."""
+        config = cast_config_type(config, Qwen35MoeConfig)
+
+        # Use the text-only MoE key map for export.
+        key_map = create_reverse_key_map(_QWEN35_MOE_TEXT_KEY_MAP)
+
+        hg_state_dict = convert_state_dict(state_dict, key_map)
+
+        # Same suffix handling as the dense converter.
+        for key, weight in list(hg_state_dict.items()):
+            if not key.endswith(_QWEN35_HG_RMSNORM_SUFFIXES):
+                continue
+            if isinstance(weight, torch.Tensor):
+                hg_state_dict[key] = weight - 1.0
+
+        if config.tied_embeddings:
+            hg_state_dict.pop("lm_head.weight", None)
+
+        return hg_state_dict
+
+
+@final
+class _QwenHuggingFaceConverter(HuggingFaceConverter):
+    @override
     def to_hg_config(self, config: object) -> HuggingFaceConfig:
         config = cast_config_type(config, QwenConfig)
 
@@ -88,3 +209,191 @@ def to_hg_state_dict(
             del hg_state_dict["lm_head.weight"]
 
         return hg_state_dict
+
+
+# ---------------------------------------------------------------------------
+# Qwen 3.5 interop
+# ---------------------------------------------------------------------------
+
+# Text-only key map (matches ``transformers`` Qwen3_5ForCausalLM state dict).
+# These are also the canonical keys used for the reverse (fs2 → HF) export.
+_QWEN35_TEXT_KEY_MAP: Final = {
+    # fmt: off
+    # Full attention layers
+    r"^model\.layers\.([0-9]+)\.self_attn\.q_proj\.": r"decoder.layers.\1.self_attn.q_proj.",
+    r"^model\.layers\.([0-9]+)\.self_attn\.k_proj\.": r"decoder.layers.\1.self_attn.k_proj.",
+    r"^model\.layers\.([0-9]+)\.self_attn\.v_proj\.": r"decoder.layers.\1.self_attn.v_proj.",
+    r"^model\.layers\.([0-9]+)\.self_attn\.o_proj\.": r"decoder.layers.\1.self_attn.output_proj.",
+    r"^model\.layers\.([0-9]+)\.self_attn\.q_norm\.": r"decoder.layers.\1.self_attn.q_norm.",
+    r"^model\.layers\.([0-9]+)\.self_attn\.k_norm\.": r"decoder.layers.\1.self_attn.k_norm.",
+    # Linear attention layers (GatedDeltaNet)
+    r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_qkv\.": r"decoder.layers.\1.linear_attn.in_proj_qkv.",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_z\.": r"decoder.layers.\1.linear_attn.in_proj_z.",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_b\.": r"decoder.layers.\1.linear_attn.in_proj_b.",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.in_proj_a\.": r"decoder.layers.\1.linear_attn.in_proj_a.",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.conv1d\.": r"decoder.layers.\1.linear_attn.conv1d.",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.dt_bias": r"decoder.layers.\1.linear_attn.dt_bias",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.A_log": r"decoder.layers.\1.linear_attn.A_log",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.norm\.": r"decoder.layers.\1.linear_attn.norm.inner_norm.",
+    r"^model\.layers\.([0-9]+)\.linear_attn\.out_proj\.": r"decoder.layers.\1.linear_attn.out_proj.",
+    # FFN
+    r"^model\.layers\.([0-9]+)\.mlp\.gate_proj\.": r"decoder.layers.\1.ffn.gate_proj.",
+    r"^model\.layers\.([0-9]+)\.mlp\.up_proj\.": r"decoder.layers.\1.ffn.inner_proj.",
+    r"^model\.layers\.([0-9]+)\.mlp\.down_proj\.": r"decoder.layers.\1.ffn.output_proj.",
+    # Layer norms
+    r"^model\.layers\.([0-9]+)\.input_layernorm\.": r"decoder.layers.\1.self_attn_layer_norm.",
+    r"^model\.layers\.([0-9]+)\.post_attention_layernorm\.": r"decoder.layers.\1.ffn_layer_norm.",
+    # Embeddings & head
+    r"^model\.norm\.": r"decoder.layer_norm.",
+    r"^model\.embed_tokens\.": r"decoder_frontend.embed.",
+    r"^lm_head\.": r"final_proj.",
+    # fmt: on
+}
+
+
+def _expand_with_language_model_prefix(
+    key_map: dict[str, str],
+) -> dict[str, str]:
+    """Return ``key_map`` plus a ``model.language_model.*`` twin of each
+    ``model.*`` pattern.
+
+    Hub checkpoints for Qwen 3.5 are multimodal, with the text decoder stored
+    under ``model.language_model.*``, while ``transformers``
+    ``Qwen3_5ForCausalLM`` uses plain ``model.*`` keys. The returned map
+    accepts both spellings and sends them to the same fairseq2 key. Patterns
+    that do not start with ``^model\.`` (e.g. ``^lm_head\.``) are copied
+    through unchanged. ``key_map`` itself is not modified.
+    """
+    text_prefix = r"^model\."
+    vl_prefix = r"^model\.language_model\."
+
+    # Originals first, then the prefixed twins, in the source map's order.
+    result: dict[str, str] = dict(key_map)
+    result.update(
+        {
+            vl_prefix + hg_pattern[len(text_prefix) :]: fs2_key
+            for hg_pattern, fs2_key in key_map.items()
+            if hg_pattern.startswith(text_prefix)
+        }
+    )
+    return result
+
+
+# Full key map: handles both text-only and multimodal checkpoint formats.
+_QWEN35_HG_KEY_MAP: Final = _expand_with_language_model_prefix(_QWEN35_TEXT_KEY_MAP)
+
+# RMSNorm keys that need weight += 1.0 conversion (Qwen 3.5 uses 1+w formula).
+# The GatedDeltaNet internal norm (linear_attn.norm) does NOT need conversion.
+_QWEN35_RMSNORM_KEYS = (
+    "self_attn_layer_norm.weight",
+    "ffn_layer_norm.weight",
+    "decoder.layer_norm.weight",
+    "self_attn.q_norm.weight",
+    "self_attn.k_norm.weight",
+)
+
+
+# Components not yet integrated in the text-only CausalLM model.
+_QWEN35_VL_SKIP_PREFIXES: Final = (
+    "model.visual.",  # vision encoder
+    "mtp.",  # multi-token prediction head
+)
+
+
+def _is_hg_format(state_dict: dict[str, object]) -> bool:
+    """Return True when the state dict uses HuggingFace key names."""
+    return (
+        "model.embed_tokens.weight" in state_dict
+        or "model.language_model.embed_tokens.weight" in state_dict
+    )
+
+
+def convert_qwen35_state_dict(
+    state_dict: dict[str, object], config: Qwen35Config
+) -> dict[str, object]:
+    """Convert a Qwen 3.5 checkpoint state dict to the fairseq2 layout.
+
+    :param state_dict: Either a HuggingFace state dict (text-only or
+        multimodal key layout) or one that already uses fairseq2 keys.
+    :param config: Model config; only ``tied_embeddings`` is read.
+    :returns: A new dict; the input mapping is not modified.
+
+    Key renaming and the ``(1+w) -> w`` RMSNorm shift are applied only when
+    the input is detected as HuggingFace format. A state dict that already
+    uses fairseq2 keys is passed through with its norm weights untouched, so
+    converting the same weights twice cannot add ``1.0`` twice.
+    """
+    # Filter out multimodal components not yet integrated (cf. gemma3n pattern).
+    state_dict = {
+        k: v
+        for k, v in state_dict.items()
+        if not k.startswith(_QWEN35_VL_SKIP_PREFIXES)
+    }
+
+    if _is_hg_format(state_dict):
+        state_dict = convert_state_dict(state_dict, _QWEN35_HG_KEY_MAP)
+
+        # Convert (1+w) RMSNorm weights to standard (w) by adding 1.0. Iterate
+        # over a snapshot because entries are reassigned inside the loop.
+        for key, weight in list(state_dict.items()):
+            if not key.endswith(_QWEN35_RMSNORM_KEYS):
+                continue
+            if isinstance(weight, torch.Tensor):
+                state_dict[key] = weight + 1.0
+
+    # With tied embeddings a checkpoint may carry only one of the two
+    # tensors; mirror whichever one is present onto the other key.
+    if config.tied_embeddings:
+        if "decoder_frontend.embed.weight" in state_dict:
+            state_dict["final_proj.weight"] = state_dict[
+                "decoder_frontend.embed.weight"
+            ]
+        elif "final_proj.weight" in state_dict:
+            state_dict["decoder_frontend.embed.weight"] = state_dict[
+                "final_proj.weight"
+            ]
+
+    return state_dict
+
+
+# ---------------------------------------------------------------------------
+# Qwen 3.5 MoE interop
+# ---------------------------------------------------------------------------
+
+# MoE text-only base: start from the dense text-only map, swap FFN patterns.
+_QWEN35_MOE_TEXT_KEY_MAP: Final = {
+    **{
+        k: v
+        for k, v in _QWEN35_TEXT_KEY_MAP.items()
+        # Drop dense FFN patterns (MoE uses a different FFN layout)
+        if k
+        not in (
+            r"^model\.layers\.([0-9]+)\.mlp\.gate_proj\.",
+            r"^model\.layers\.([0-9]+)\.mlp\.up_proj\.",
+            r"^model\.layers\.([0-9]+)\.mlp\.down_proj\.",
+        )
+    },
+    # fmt: off
+    r"^model\.layers\.([0-9]+)\.mlp\.gate\.": r"decoder.layers.\1.ffn.gate.",
+    r"^model\.layers\.([0-9]+)\.mlp\.experts\.gate_up_proj": r"decoder.layers.\1.ffn.experts.gate_up_proj",
+    r"^model\.layers\.([0-9]+)\.mlp\.experts\.down_proj": r"decoder.layers.\1.ffn.experts.down_proj",
+    r"^model\.layers\.([0-9]+)\.mlp\.shared_expert\.gate_proj\.": r"decoder.layers.\1.ffn.shared_expert.gate_proj.",
+    r"^model\.layers\.([0-9]+)\.mlp\.shared_expert\.up_proj\.": r"decoder.layers.\1.ffn.shared_expert.inner_proj.",
+    r"^model\.layers\.([0-9]+)\.mlp\.shared_expert\.down_proj\.": r"decoder.layers.\1.ffn.shared_expert.output_proj.",
+    r"^model\.layers\.([0-9]+)\.mlp\.shared_expert_gate\.": r"decoder.layers.\1.ffn.shared_expert_gate.",
+    # fmt: on
+}
+
+_QWEN35_MOE_HG_KEY_MAP: Final = _expand_with_language_model_prefix(
+    _QWEN35_MOE_TEXT_KEY_MAP
+)
+
+
+def convert_qwen35_moe_state_dict(
+    state_dict: dict[str, object], config: Qwen35MoeConfig
+) -> dict[str, object]:
+    """Convert a Qwen 3.5 MoE checkpoint state dict to the fairseq2 layout.
+
+    :param state_dict: Either a HuggingFace state dict (text-only or
+        multimodal key layout) or one that already uses fairseq2 keys.
+    :param config: Model config; only ``tied_embeddings`` is read.
+    :returns: A new dict; the input mapping is not modified.
+
+    Key renaming and the ``(1+w) -> w`` RMSNorm shift are applied only when
+    the input is detected as HuggingFace format, so a state dict that already
+    uses fairseq2 keys keeps its norm weights untouched.
+    """
+    # Filter out multimodal components not yet integrated (cf. gemma3n pattern).
+    state_dict = {
+        k: v
+        for k, v in state_dict.items()
+        if not k.startswith(_QWEN35_VL_SKIP_PREFIXES)
+    }
+
+    if _is_hg_format(state_dict):
+        state_dict = convert_state_dict(state_dict, _QWEN35_MOE_HG_KEY_MAP)
+
+        # Convert (1+w) RMSNorm weights to standard (w) by adding 1.0. Iterate
+        # over a snapshot because entries are reassigned inside the loop.
+        for key, weight in list(state_dict.items()):
+            if not key.endswith(_QWEN35_RMSNORM_KEYS):
+                continue
+            if isinstance(weight, torch.Tensor):
+                state_dict[key] = weight + 1.0
+
+    # With tied embeddings a checkpoint may carry only one of the two
+    # tensors; mirror whichever one is present onto the other key.
+    if config.tied_embeddings:
+        if "decoder_frontend.embed.weight" in state_dict:
+            state_dict["final_proj.weight"] = state_dict[
+                "decoder_frontend.embed.weight"
+            ]
+        elif "final_proj.weight" in state_dict:
+            state_dict["decoder_frontend.embed.weight"] = state_dict[
+                "final_proj.weight"
+            ]
+
+    return state_dict
diff --git a/src/fairseq2/models/qwen/moe.py b/src/fairseq2/models/qwen/moe.py
new file mode 100644
index 000000000..03ead545f
--- /dev/null
+++ b/src/fairseq2/models/qwen/moe.py
@@ -0,0 +1,235 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Mixture-of-Experts modules for Qwen 3.5 MoE.
+
+This module implements the MoE architecture from Qwen 3.5 MoE following the
+HuggingFace reference in ``modeling_qwen3_5_moe.py``.
+
+Classes:
+    - :class:`Qwen35TopKRouter` — softmax → top-k → renormalize (HF lines 841-857)
+    - :class:`Qwen35Experts` — fused 3-D parameter experts (HF lines 802-838)
+    - :class:`Qwen35MoeBlock` — router + experts + shared expert (HF lines 860-879)
+"""
+
+from __future__ import annotations
+
+from typing import Final
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Module, Parameter
+from typing_extensions import override
+
+from fairseq2.models.transformer import FeedForwardNetwork, GLUFeedForwardNetwork
+from fairseq2.nn import Linear
+
+
+class Qwen35TopKRouter(Module):
+    """Top-k softmax router for Qwen 3.5 MoE.
+
+    Computes softmax over all experts, selects the top-k, and renormalises the
+    selected weights so they sum to 1.
+
+    Reference: ``Qwen3_5MoeTopKRouter`` (HF lines 841-857).
+    """
+
+    num_experts: Final[int]
+    top_k: Final[int]
+    model_dim: Final[int]
+
+    def __init__(self, num_experts: int, top_k: int, model_dim: int) -> None:
+        super().__init__()
+
+        self.num_experts = num_experts
+        self.top_k = top_k
+        self.model_dim = model_dim
+
+        self.weight = Parameter(torch.zeros(num_experts, model_dim))
+
+    def forward(self, hidden_states: Tensor) -> tuple[Tensor, Tensor, Tensor]:
+        """
+        :param hidden_states:
+            Token representations of shape ``(T, D)`` where *T* is the
+            (flattened) number of tokens.
+
+        :returns:
+            A 3-tuple of:
+            - ``router_logits`` — raw pre-softmax logits ``(T, E)``
+            - ``router_weights`` — renormalised top-k weights ``(T, K)``
+            - ``router_indices`` — selected expert indices ``(T, K)``
+        """
+        hidden_states = hidden_states.reshape(-1, self.model_dim)
+
+        router_logits = F.linear(hidden_states, self.weight)
+        # The softmax runs in float32; the selected weights are cast back to
+        # the logits' dtype after renormalisation.
+        router_probs = F.softmax(router_logits, dtype=torch.float, dim=-1)
+
+        router_weights, router_indices = torch.topk(router_probs, self.top_k, dim=-1)
+
+        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
+        router_weights = router_weights.to(router_logits.dtype)
+
+        return router_logits, router_weights, router_indices
+
+
+class Qwen35Experts(Module):
+    """Fused expert layer with 3-D weight parameters for Qwen 3.5 MoE.
+
+    Each expert is a GLU-style MLP (gate+up → SiLU → down) stored as a single
+    ``(E, 2*I, D)`` gate-up projection and a ``(E, D, I)`` down projection so
+    that individual experts can be indexed without slicing overhead.
+
+    Reference: ``Qwen3_5MoeExperts`` (HF lines 802-838).
+    """
+
+    num_experts: Final[int]
+    model_dim: Final[int]
+    expert_inner_dim: Final[int]
+
+    def __init__(
+        self,
+        num_experts: int,
+        model_dim: int,
+        expert_inner_dim: int,
+    ) -> None:
+        super().__init__()
+
+        self.num_experts = num_experts
+        self.model_dim = model_dim
+        self.expert_inner_dim = expert_inner_dim
+
+        # NOTE(review): both tensors are allocated with ``torch.empty`` and
+        # this class defines no initialiser, so the values are undefined
+        # until a checkpoint is loaded or an external init runs — confirm the
+        # factory covers the train-from-scratch case.
+        self.gate_up_proj = Parameter(
+            torch.empty(num_experts, 2 * expert_inner_dim, model_dim)
+        )
+        self.down_proj = Parameter(
+            torch.empty(num_experts, model_dim, expert_inner_dim)
+        )
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        top_k_indices: Tensor,
+        top_k_weights: Tensor,
+    ) -> Tensor:
+        """
+        :param hidden_states:
+            Token representations of shape ``(T, D)``.
+        :param top_k_indices:
+            Selected expert indices of shape ``(T, K)``.
+        :param top_k_weights:
+            Renormalised routing weights of shape ``(T, K)``.
+
+        :returns:
+            Expert-mixed output of shape ``(T, D)``.
+        """
+        final_hidden_states = torch.zeros_like(hidden_states)
+
+        # The routing mask is index bookkeeping only, so it is built without
+        # gradient tracking.
+        with torch.no_grad():
+            expert_mask = F.one_hot(top_k_indices, num_classes=self.num_experts)
+            # (T, K, E) → (E, K, T)
+            expert_mask = expert_mask.permute(2, 1, 0)
+            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
+
+        # Visit only the experts that received at least one token.
+        for expert_idx in expert_hit:
+            expert_idx = expert_idx[0]
+            if expert_idx == self.num_experts:
+                continue
+
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
+
+            current_state = hidden_states[token_idx]
+
+            gate, up = F.linear(current_state, self.gate_up_proj[expert_idx]).chunk(
+                2, dim=-1
+            )
+
+            current_hidden_states = F.silu(gate) * up
+
+            current_hidden_states = F.linear(
+                current_hidden_states, self.down_proj[expert_idx]
+            )
+
+            current_hidden_states *= top_k_weights[token_idx, top_k_pos, None]
+
+            # Scatter-add the weighted expert output back to its token rows.
+            final_hidden_states.index_add_(
+                0,
+                token_idx,
+                current_hidden_states.to(final_hidden_states.dtype),
+            )
+
+        return final_hidden_states
+
+
+class Qwen35MoeBlock(FeedForwardNetwork):
+    """Sparse Mixture-of-Experts feed-forward block for Qwen 3.5 MoE.
+
+    Combines a top-k router, a set of sparse experts, a shared expert (standard
+    GLU MLP), and a learned sigmoid gate that blends the shared expert output
+    into the final result.
+
+    This class inherits from :class:`FeedForwardNetwork` so it can serve as a
+    drop-in replacement for :class:`GLUFeedForwardNetwork` inside any
+    Transformer decoder layer.
+
+    Reference: ``Qwen3_5MoeSparseMoeBlock`` (HF lines 860-879).
+    """
+
+    model_dim: Final[int]
+
+    def __init__(
+        self,
+        model_dim: int,
+        num_experts: int,
+        num_experts_per_tok: int,
+        moe_intermediate_size: int,
+        shared_expert_intermediate_size: int,
+    ) -> None:
+        """
+        :param model_dim:
+            The dimensionality of the model (``hidden_size``).
+        :param num_experts:
+            The total number of routed experts.
+        :param num_experts_per_tok:
+            The number of experts activated per token (top-k).
+        :param moe_intermediate_size:
+            The intermediate (inner) dimension of each routed expert.
+        :param shared_expert_intermediate_size:
+            The intermediate (inner) dimension of the shared expert.
+        """
+        super().__init__()
+
+        self.model_dim = model_dim
+
+        self.gate = Qwen35TopKRouter(num_experts, num_experts_per_tok, model_dim)
+
+        self.experts = Qwen35Experts(num_experts, model_dim, moe_intermediate_size)
+
+        self.shared_expert = GLUFeedForwardNetwork(
+            model_dim,
+            shared_expert_intermediate_size,
+            bias=False,
+            inner_dim_scale=1.0,
+        )
+
+        self.shared_expert_gate = Linear(model_dim, 1, bias=False)
+
+    @override
+    def forward(self, seqs: Tensor) -> Tensor:
+        """
+        :param seqs: Input of shape ``(B, S, D)``.
+        :returns: Routed-expert output plus sigmoid-gated shared-expert
+            output, of shape ``(B, S, D)``.
+        """
+        B, S, D = seqs.shape
+
+        # ``reshape`` rather than ``view``: ``view`` raises on a
+        # non-contiguous input, and the rest of this module already flattens
+        # with ``reshape``.
+        hidden_states = seqs.reshape(-1, D)
+
+        shared_out = self.shared_expert(hidden_states)
+
+        _, routing_weights, selected_experts = self.gate(hidden_states)
+
+        expert_out = self.experts(hidden_states, selected_experts, routing_weights)
+
+        shared_out = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_out
+
+        out: Tensor = (expert_out + shared_out).reshape(B, S, D)
+        return out
diff --git a/tests/integration/models/test_qwen35.py b/tests/integration/models/test_qwen35.py
new file mode 100644
index 000000000..e7c2f289c
--- /dev/null
+++ b/tests/integration/models/test_qwen35.py
@@ -0,0 +1,213 @@
+"""Qwen 3.5 0.8B — HuggingFace vs fairseq2 numerical parity test.
+
+Downloads the HF checkpoint, loads it into both HF and fairseq2,
+runs the same input, and asserts logit closeness.
+"""
+
+import os
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+# Use local checkpoint if available (avoids SSL/proxy issues in CI)
+_LOCAL_PATH = "/checkpoint/smallomnillm/shared/models/Qwen3.5-0.8B"
+MODEL_ID = _LOCAL_PATH if os.path.isdir(_LOCAL_PATH) else "Qwen/Qwen3.5-0.8B"
+
+
+def _hf_model_type_available(model_type: str) -> bool:
+    """Return True if the installed transformers recognises *model_type*."""
+    try:
+        from transformers.models.auto.configuration_auto import CONFIG_MAPPING
+
+        return model_type in CONFIG_MAPPING
+    except Exception:
+        return False
+
+
+@pytest.mark.skipif(
+    not _hf_model_type_available("qwen3_5"),
+    reason="transformers does not support model_type 'qwen3_5' (upgrade transformers)",
+)
+class TestQwen35HFParity:
+    """Numerical parity between HuggingFace and fairseq2 for Qwen 3.5 0.8B."""
+
+    def test_logit_parity(self) -> None:
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        # ---- Step 1: Load HF model ----
+        print("=" * 60)
+        print("Step 1: Loading HuggingFace model...")
+        print("=" * 60)
+
+        hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
+        hf_model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            torch_dtype=torch.float32,
+            trust_remote_code=True,
+        )
+        hf_model.eval()
+        print(
+            f" HF model loaded: {sum(p.numel() for p in hf_model.parameters()):,} params"
+        )
+
+        # ---- Step 2: Build fairseq2 model from config ----
+        print("\n" + "=" * 60)
+        print("Step 2: Building fairseq2 model...")
+        print("=" * 60)
+
+        from fairseq2.models.qwen.config import Qwen35Config
+        from fairseq2.models.qwen.factory import create_qwen35_model
+        from fairseq2.models.qwen.interop import convert_qwen35_state_dict
+
+        config = Qwen35Config(
+            model_dim=1024,
+            max_seq_len=262_144,
+            vocab_size=248_320,
+            tied_embeddings=True,
+            num_layers=24,
+            num_attn_heads=8,
+            num_key_value_heads=2,
+            head_dim=256,
+            ffn_inner_dim=3584,
+            partial_rotary_factor=0.25,
+            rope_theta=10_000_000.0,
+            full_attention_interval=4,
+            linear_conv_kernel_dim=4,
+            linear_key_head_dim=128,
+            linear_value_head_dim=128,
+            linear_num_key_heads=16,
+            linear_num_value_heads=16,
+        )
+
+        fs2_model = create_qwen35_model(config)
+        fs2_model.eval()
+        print(
+            f" fs2 model built: {sum(p.numel() for p in fs2_model.parameters()):,} params"
+        )
+
+        # ---- Step 3: Convert and load HF state dict into fairseq2 ----
+        print("\n" + "=" * 60)
+        print("Step 3: Converting HF state dict -> fairseq2...")
+        print("=" * 60)
+
+        hf_state_dict = dict(hf_model.state_dict())
+        fs2_state_dict = convert_qwen35_state_dict(hf_state_dict, config)
+
+        fs2_keys = set(fs2_model.state_dict().keys())
+        converted_keys = set(fs2_state_dict.keys())
+
+        missing = fs2_keys - converted_keys
+        unexpected = converted_keys - fs2_keys
+
+        if missing:
+            print(f" WARNING: {len(missing)} missing keys:")
+            for k in sorted(missing)[:20]:
+                print(f" - {k}")
+        if unexpected:
+            print(f" WARNING: {len(unexpected)} unexpected keys:")
+            for k in sorted(unexpected)[:20]:
+                print(f" - {k}")
+
+        if missing or unexpected:
+            print("\n Attempting to load with strict=False...")
+            result = fs2_model.load_state_dict(fs2_state_dict, strict=False)
+            print(
+                f" Missing: {len(result.missing_keys)}, Unexpected: {len(result.unexpected_keys)}"
+            )
+            if result.missing_keys:
+                pytest.fail(
+                    f"Cannot proceed with {len(result.missing_keys)} missing keys: "
+                    + ", ".join(sorted(result.missing_keys)[:30])
+                )
+        else:
+            fs2_model.load_state_dict(fs2_state_dict, strict=True)
+            print(" State dict loaded successfully (strict=True)")
+
+        # ---- Step 4: Prepare input ----
+        print("\n" + "=" * 60)
+        print("Step 4: Preparing input...")
+        print("=" * 60)
+
+        test_text = "The capital of France is"
+        tokens = hf_tokenizer(test_text, return_tensors="pt")
+        input_ids = tokens["input_ids"]  # (1, S)
+        print(f" Input: '{test_text}'")
+        print(f" Token IDs: {input_ids.tolist()}")
+        print(f" Sequence length: {input_ids.shape[1]}")
+
+        # ---- Step 5: HF forward pass ----
+        print("\n" + "=" * 60)
+        print("Step 5: HF forward pass...")
+        print("=" * 60)
+
+        with torch.no_grad():
+            hf_output = hf_model(input_ids)
+            hf_logits = hf_output.logits  # (1, S, V)
+
+        print(f" HF logits shape: {hf_logits.shape}")
+        print(f" HF logits[0, -1, :5]: {hf_logits[0, -1, :5]}")
+
+        # ---- Step 6: fairseq2 forward pass ----
+        print("\n" + "=" * 60)
+        print("Step 6: fairseq2 forward pass...")
+        print("=" * 60)
+
+        from fairseq2.nn import BatchLayout
+
+        with torch.no_grad():
+            seqs = input_ids  # (1, S)
+            seqs_layout = BatchLayout.of(seqs)
+            fs2_logits = fs2_model(seqs, seqs_layout)
+
+        print(f" fs2 logits shape: {fs2_logits.shape}")
+        print(f" fs2 logits[0, -1, :5]: {fs2_logits[0, -1, :5]}")
+
+        # ---- Step 7: Compare ----
+        print("\n" + "=" * 60)
+        print("Step 7: Numerical comparison...")
+        print("=" * 60)
+
+        hf_last = hf_logits[0, -1].float()
+        fs2_last = fs2_logits[0, -1].float()
+
+        abs_diff = (hf_last - fs2_last).abs()
+        max_diff = abs_diff.max().item()
+        mean_diff = abs_diff.mean().item()
+
+        print(f" Last-token logit max abs diff: {max_diff:.6e}")
+        print(f" Last-token logit mean abs diff: {mean_diff:.6e}")
+
+        hf_all = hf_logits.float()
+        fs2_all = fs2_logits.float()
+
+        full_abs_diff = (hf_all - fs2_all).abs()
+        full_max_diff = full_abs_diff.max().item()
+        full_mean_diff = full_abs_diff.mean().item()
+
+        print(f" Full-seq logit max abs diff: {full_max_diff:.6e}")
+        print(f" Full-seq logit mean abs diff: {full_mean_diff:.6e}")
+
+        hf_top1 = hf_last.argmax().item()
+        fs2_top1 = fs2_last.argmax().item()
+        print(f"\n HF top-1 token: {hf_top1} -> '{hf_tokenizer.decode([hf_top1])}'")
+        print(f" fs2 top-1 token: {fs2_top1} -> '{hf_tokenizer.decode([fs2_top1])}'")
+
+        hf_top5 = hf_last.topk(5).indices.tolist()
+        fs2_top5 = fs2_last.topk(5).indices.tolist()
+        print(
+            f"\n HF top-5: {hf_top5} -> {[hf_tokenizer.decode([t]) for t in hf_top5]}"
+        )
+        print(
+            f" fs2 top-5: {fs2_top5} -> {[hf_tokenizer.decode([t]) for t in fs2_top5]}"
+        )
+
+        cos_sim = F.cosine_similarity(
+            hf_last.unsqueeze(0), fs2_last.unsqueeze(0)
+        ).item()
+        print(f"\n Cosine similarity (last token): {cos_sim:.8f}")
+
+        # NOTE(review): the two criteria are joined with ``or``, so a high
+        # last-token cosine similarity alone passes the test even when the
+        # full-sequence max abs diff exceeds ATOL.
+        ATOL = 1e-4
+        assert (
+            full_max_diff < ATOL or cos_sim > 0.9999
+        ), f"Parity check failed: max diff {full_max_diff:.2e}, cosine sim {cos_sim:.6f}"
diff --git a/tests/unit/models/qwen/__init__.py b/tests/unit/models/qwen/__init__.py
new file mode 100644
index 000000000..2e41cd717
--- /dev/null
+++ b/tests/unit/models/qwen/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/unit/models/qwen/test_gated_delta_net.py b/tests/unit/models/qwen/test_gated_delta_net.py
new file mode 100644
index 000000000..d7826fec2
--- /dev/null
+++ b/tests/unit/models/qwen/test_gated_delta_net.py
@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from fairseq2.models.qwen.gated_delta_net import (
+    GatedDeltaNet,
+    GatedDeltaNetState,
+    RMSNormGated,
+    torch_chunk_gated_delta_rule,
+    torch_recurrent_gated_delta_rule,
+)
+from fairseq2.nn import IncrementalStateBag
+from tests.common import assert_close, device
+
+
+class TestGatedDeltaNet:
+    def test_forward_produces_correct_shape(self) -> None:
+        """GatedDeltaNet forward output shape matches input shape (B, S, D)."""
+        gdn = GatedDeltaNet(
+            hidden_size=64,
+            num_k_heads=2,
+            num_v_heads=4,
+            head_k_dim=16,
+            head_v_dim=16,
+            conv_kernel_size=4,
+        )
+        gdn = gdn.to(device)
+
+        seqs = torch.randn(2, 8, 64, device=device)
+        with torch.no_grad():
+            out = gdn(seqs)
+
+        assert out.shape == (2, 8, 64)
+
+    def test_incremental_decode_matches_full_forward(self) -> None:
+        """A prefill pass with an IncrementalStateBag matches the stateless
+        full forward. No single-token decode step is run in this test."""
+        gdn = GatedDeltaNet(
+            hidden_size=64,
+            num_k_heads=2,
+            num_v_heads=4,
+            head_k_dim=16,
+            head_v_dim=16,
+        )
+        gdn = gdn.to(device).eval()
+
+        seq_len = 8
+        seqs = torch.randn(1, seq_len, 64, device=device)
+
+        with torch.no_grad():
+            full_out = gdn(seqs)
+
+        state_bag = IncrementalStateBag(max_num_steps=seq_len)
+
+        with torch.no_grad():
+            prefill_out = gdn(seqs, state_bag=state_bag)
+
+        assert_close(prefill_out, full_out, atol=1e-5)
+
+    def test_chunked_vs_recurrent_consistency(self) -> None:
+        """torch_chunk_gated_delta_rule and torch_recurrent_gated_delta_rule
+        produce the same output for the same input."""
+        B, S, H, K, V = 1, 16, 4, 16, 16
+        q = torch.randn(B, S, H, K, device=device)
+        k = torch.randn(B, S, H, K, device=device)
+        v = torch.randn(B, S, H, V, device=device)
+        # Gate values are drawn in (-1, 0]; beta in [0, 1).
+        g = -torch.rand(B, S, H, device=device).abs()
+        beta = torch.rand(B, S, H, device=device)
+
+        chunk_out, chunk_state = torch_chunk_gated_delta_rule(
+            q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True
+        )
+        recurrent_out, recurrent_state = torch_recurrent_gated_delta_rule(
+            q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True
+        )
+
+        assert_close(chunk_out, recurrent_out, atol=1e-4)
+        assert chunk_state is not None
+        assert recurrent_state is not None
+        assert_close(chunk_state, recurrent_state, atol=1e-4)
+
+    def test_gated_delta_net_state_reorder(self) -> None:
+        """GatedDeltaNetState.reorder correctly reorders batch dimension."""
+        conv = torch.randn(3, 8, 3, device=device)
+        rec = torch.randn(3, 4, 16, 16, device=device)
+        state = GatedDeltaNetState(conv, rec)
+
+        new_order = torch.tensor([2, 0, 1], device=device)
+        state.reorder(new_order)
+
+        # Only rows 0 and 1 of conv_state and row 0 of recurrent_state are
+        # checked against the permutation.
+        assert_close(state.conv_state[0], conv[2])
+        assert_close(state.conv_state[1], conv[0])
+        assert_close(state.recurrent_state[0], rec[2])
+
+    def test_rmsnorm_gated_output(self) -> None:
+        """RMSNormGated produces norm(x) * silu(gate)."""
+        dim = 16
+        norm = RMSNormGated(dim).to(device)
+
+        x = torch.randn(4, dim, device=device)
+        gate = torch.randn(4, dim, device=device)
+
+        out = norm(x, gate)
+
+        # Reference computed by hand in float32; the literal 1e-6 assumes the
+        # RMSNormGated default eps — TODO confirm against its signature.
+        x_f32 = x.float()
+        variance = x_f32.pow(2).mean(-1, keepdim=True)
+        x_normed = x_f32 * torch.rsqrt(variance + 1e-6)
+        assert norm.inner_norm.weight is not None
+        expected = (norm.inner_norm.weight * x_normed) * F.silu(gate.float())
+
+        assert_close(out, expected.to(out.dtype), atol=1e-5)
+
+    @pytest.mark.skipif(
+        not torch.cuda.is_available(),
+        reason="causal_conv1d incremental decode requires CUDA",
+    )
+    def test_step_by_step_decode_matches_prefill(self) -> None:
+        """After prefilling, incremental decode of one token matches full forward."""
+        gdn = GatedDeltaNet(
+            hidden_size=64,
+            num_k_heads=2,
+            num_v_heads=4,
+            head_k_dim=16,
+            head_v_dim=16,
+        )
+        gdn = gdn.to(device).eval()
+
+        prefill_len = 8
+        full_seq = torch.randn(1, prefill_len + 1, 64, device=device)
+
+        with torch.no_grad():
+            full_out = gdn(full_seq)
+        ground_truth = full_out[:, -1:, :]
+
+        state_bag = IncrementalStateBag(max_num_steps=prefill_len + 1)
+
+        # Prefill populates the state; the output itself is not checked here.
+        with torch.no_grad():
+            gdn(full_seq[:, :prefill_len, :], state_bag=state_bag)
+
+        state_bag.increment_step_nr(prefill_len)
+
+        # A single-token input with an existing state takes the decode path.
+        with torch.no_grad():
+            incr_out = gdn(full_seq[:, prefill_len:, :], state_bag=state_bag)
+
+        assert_close(incr_out, ground_truth, atol=1e-4)
diff --git a/tests/unit/models/qwen/test_qwen35_attention.py b/tests/unit/models/qwen/test_qwen35_attention.py
new file mode 100644
index 000000000..af9d058b5
--- /dev/null
+++ b/tests/unit/models/qwen/test_qwen35_attention.py
@@ -0,0 +1,173 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import annotations
+
+import torch
+
+from fairseq2.models.qwen.attention import Qwen35Attention
+from fairseq2.models.transformer.attention_bias import (
+    AttentionBiasCache,
+    CausalAttentionBias,
+    IdentityBias,
+)
+from fairseq2.models.transformer.sdpa.naive import NaiveSDPA
+from fairseq2.nn import BatchLayout, IncrementalStateBag, RMSNorm
+from fairseq2.nn.position_encoder import ReferenceRotaryEncoder
+from tests.common import assert_close, device
+
+
+class TestQwen35Attention:
+    def test_forward_produces_correct_shape(self) -> None:
+        """Output shape is (B, S, model_dim)."""
+        sdpa = NaiveSDPA(IdentityBias())
+        attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16)
+        attn = attn.to(device)
+
+        seqs = torch.randn(2, 8, 64, device=device)
+        layout = BatchLayout.of(seqs)
+        bias_cache = AttentionBiasCache()
+
+        with torch.no_grad():
+            out = attn(seqs, layout, seqs, layout, seqs, bias_cache)
+
+        assert out.shape == (2, 8, 64)
+
+    def test_output_gating_effect(self) -> None:
+        """When gate output is all zeros, attention output should be near zero."""
+        sdpa = NaiveSDPA(IdentityBias())
+        attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa,
head_dim=16) + attn = attn.to(device) + + seqs = torch.randn(1, 4, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + out1 = attn(seqs, layout, seqs, layout, seqs, bias_cache) + + # Verify output is not zero (gate should be non-trivial with random weights) + assert out1.abs().mean() > 1e-6 + + def test_partial_rope_applies_to_subset_of_dims(self) -> None: + """With encoding_dim < head_dim, only first encoding_dim dims should be rotated.""" + model_dim = 64 + num_heads = 4 + head_dim = 16 + encoding_dim = 4 # Only first 4 of 16 dims rotated + + rope = ReferenceRotaryEncoder(encoding_dim, max_seq_len=32, device=device) + sdpa = NaiveSDPA(IdentityBias()) + attn = Qwen35Attention( + model_dim=model_dim, + num_heads=num_heads, + sdpa=sdpa, + head_dim=head_dim, + pos_encoder=rope, + ) + attn = attn.to(device) + + seqs = torch.randn(1, 4, model_dim, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + + assert out.shape == (1, 4, model_dim) + + def test_gqa_with_fewer_kv_heads(self) -> None: + """GQA with num_key_value_heads < num_heads works correctly.""" + sdpa = NaiveSDPA(IdentityBias()) + attn = Qwen35Attention( + model_dim=64, + num_heads=4, + sdpa=sdpa, + head_dim=16, + num_key_value_heads=2, # GQA: 4 Q heads, 2 KV heads + ) + attn = attn.to(device) + + seqs = torch.randn(2, 6, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + + assert out.shape == (2, 6, 64) + + def test_qk_norm_applied(self) -> None: + """When q_norm and k_norm are provided, output should differ from no-norm case.""" + sdpa = NaiveSDPA(IdentityBias()) + + # Without norms + attn_no_norm = Qwen35Attention( + model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16 + ) + attn_no_norm = 
attn_no_norm.to(device) + + # With norms + q_norm = RMSNorm(16, bias=False, device=device) + k_norm = RMSNorm(16, bias=False, device=device) + attn_norm = Qwen35Attention( + model_dim=64, + num_heads=4, + sdpa=sdpa, + head_dim=16, + q_norm=q_norm, + k_norm=k_norm, + ) + attn_norm = attn_norm.to(device) + + # Copy weights so only the norm makes a difference + attn_norm.q_proj.weight.data.copy_(attn_no_norm.q_proj.weight.data) + attn_norm.k_proj.weight.data.copy_(attn_no_norm.k_proj.weight.data) + attn_norm.v_proj.weight.data.copy_(attn_no_norm.v_proj.weight.data) + attn_norm.output_proj.weight.data.copy_(attn_no_norm.output_proj.weight.data) + + seqs = torch.randn(1, 4, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + out_no_norm = attn_no_norm(seqs, layout, seqs, layout, seqs, bias_cache) + out_norm = attn_norm(seqs, layout, seqs, layout, seqs, bias_cache) + + # Outputs should differ because of norm + assert not torch.allclose(out_no_norm, out_norm, atol=1e-6) + + def test_incremental_kv_cache_matches_full_forward(self) -> None: + """Token-by-token decoding with KV cache produces the same logits as causal full-sequence forward.""" + sdpa = NaiveSDPA(CausalAttentionBias()) + attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16) + attn = attn.to(device) + attn.eval() + + seqs = torch.randn(1, 6, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + full_out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + + state_bag = IncrementalStateBag(max_num_steps=32) + + with torch.no_grad(): + for idx in range(6): + step_seqs = seqs[:, idx : idx + 1, :] + step_layout = BatchLayout.of(step_seqs) + out = attn( + step_seqs, + step_layout, + step_seqs, + step_layout, + step_seqs, + bias_cache, + state_bag=state_bag, + ) + assert_close(out, full_out[:, idx : idx + 1, :], atol=1e-5) + state_bag.increment_step_nr() diff --git 
a/tests/unit/models/qwen/test_qwen35_decoder_layer.py b/tests/unit/models/qwen/test_qwen35_decoder_layer.py new file mode 100644 index 000000000..d25363301 --- /dev/null +++ b/tests/unit/models/qwen/test_qwen35_decoder_layer.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pytest +import torch + +from fairseq2.models.qwen.attention import Qwen35Attention +from fairseq2.models.qwen.config import Qwen35Config +from fairseq2.models.qwen.decoder_layer import Qwen35DecoderLayer +from fairseq2.models.qwen.factory import create_qwen35_model +from fairseq2.models.qwen.gated_delta_net import GatedDeltaNet +from fairseq2.models.transformer import GLUFeedForwardNetwork +from fairseq2.models.transformer.attention_bias import ( + AttentionBiasCache, + IdentityBias, +) +from fairseq2.models.transformer.sdpa.naive import NaiveSDPA +from fairseq2.nn import BatchLayout, RMSNorm +from tests.common import device + + +class TestQwen35DecoderLayer: + def test_full_attention_layer_forward(self) -> None: + """Full attention layer produces correct shape.""" + model_dim = 64 + sdpa = NaiveSDPA(IdentityBias()) + self_attn = Qwen35Attention(model_dim, num_heads=4, sdpa=sdpa, head_dim=16) + ffn = GLUFeedForwardNetwork(model_dim, 128, bias=False, inner_dim_scale=1.0) + layer = Qwen35DecoderLayer( + "full_attention", + self_attn=self_attn, + linear_attn=None, + ffn=ffn, + self_attn_layer_norm=RMSNorm(model_dim, bias=False), + ffn_layer_norm=RMSNorm(model_dim, bias=False), + ).to(device) + + seqs = torch.randn(2, 8, model_dim, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + out = layer(seqs, layout, bias_cache) + + assert out.shape == (2, 8, model_dim) + + def test_linear_attention_layer_forward(self) -> 
None: + """Linear attention (GatedDeltaNet) layer produces correct shape.""" + model_dim = 64 + gdn = GatedDeltaNet( + model_dim, num_k_heads=2, num_v_heads=4, head_k_dim=8, head_v_dim=8 + ) + ffn = GLUFeedForwardNetwork(model_dim, 128, bias=False, inner_dim_scale=1.0) + layer = Qwen35DecoderLayer( + "linear_attention", + self_attn=None, + linear_attn=gdn, + ffn=ffn, + self_attn_layer_norm=RMSNorm(model_dim, bias=False), + ffn_layer_norm=RMSNorm(model_dim, bias=False), + ).to(device) + + seqs = torch.randn(2, 8, model_dim, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + + with torch.no_grad(): + out = layer(seqs, layout, bias_cache) + + assert out.shape == (2, 8, model_dim) + + def test_invalid_layer_type_raises(self) -> None: + """Invalid layer_type raises ValueError.""" + model_dim = 64 + ffn = GLUFeedForwardNetwork(model_dim, 128, bias=False, inner_dim_scale=1.0) + + with pytest.raises(ValueError, match="layer_type"): + Qwen35DecoderLayer( + "invalid_type", + self_attn=None, + linear_attn=None, + ffn=ffn, + self_attn_layer_norm=RMSNorm(model_dim, bias=False), + ffn_layer_norm=RMSNorm(model_dim, bias=False), + ) + + +class TestQwen35ModelFactory: + def test_create_small_model(self) -> None: + """Factory creates a working model with the correct output shape.""" + config = Qwen35Config( + model_dim=64, + vocab_size=128, + num_layers=4, + num_attn_heads=4, + num_key_value_heads=2, + head_dim=16, + ffn_inner_dim=128, + partial_rotary_factor=0.25, + linear_num_key_heads=2, + linear_num_value_heads=4, + linear_key_head_dim=8, + linear_value_head_dim=8, + ) + + model = create_qwen35_model(config).to(device) + model.eval() + + input_ids = torch.randint(0, 128, (1, 16), device=device) + layout = BatchLayout.of(input_ids) + + with torch.no_grad(): + logits = model(input_ids, layout) + + assert logits.shape == (1, 16, 128) + + def test_model_has_hybrid_layers(self) -> None: + """Model should have both full_attention and linear_attention 
layers.""" + config = Qwen35Config( + model_dim=64, + vocab_size=128, + num_layers=4, + num_attn_heads=4, + num_key_value_heads=2, + head_dim=16, + ffn_inner_dim=128, + linear_num_key_heads=2, + linear_num_value_heads=4, + linear_key_head_dim=8, + linear_value_head_dim=8, + ) + + with torch.device("meta"): + model = create_qwen35_model(config) + + layers = list(model.decoder.layers) + layer_types = [ + l.layer_type for l in layers if isinstance(l, Qwen35DecoderLayer) + ] + assert layer_types == [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] diff --git a/tests/unit/models/qwen/test_qwen35_interop.py b/tests/unit/models/qwen/test_qwen35_interop.py new file mode 100644 index 000000000..f41c236d3 --- /dev/null +++ b/tests/unit/models/qwen/test_qwen35_interop.py @@ -0,0 +1,500 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Unit tests for the Qwen 3.5 HuggingFace state-dict interop.""" + +from __future__ import annotations + +import torch +from torch.testing import assert_close + +from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig +from fairseq2.models.qwen.factory import create_qwen35_model, create_qwen35_moe_model +from fairseq2.models.qwen.interop import ( + _QWEN35_HG_KEY_MAP, + _QWEN35_RMSNORM_KEYS, + _QWEN35_TEXT_KEY_MAP, + _Qwen35HuggingFaceConverter, + _Qwen35MoeHuggingFaceConverter, + convert_qwen35_moe_state_dict, + convert_qwen35_state_dict, +) +from fairseq2.models.utils.checkpoint import convert_state_dict, create_reverse_key_map + + +class TestQwen35Interop: + def _make_small_config(self) -> Qwen35Config: + """Create a tiny config for fast testing.""" + config = Qwen35Config() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 # 3 linear + 1 full attention + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.layer_types = None # Reset so __post_init__ regenerates for num_layers=4 + config.__post_init__() + return config + + def test_state_dict_key_round_trip(self) -> None: + """fs2 keys -> HF keys -> fs2 keys should be identity.""" + config = self._make_small_config() + + with torch.device("meta"): + model = create_qwen35_model(config) + + fs2_keys = set(model.state_dict().keys()) + assert len(fs2_keys) > 0 + + fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + # Convert to HF format using reverse key map + reverse_map = create_reverse_key_map(_QWEN35_HG_KEY_MAP) + hg_state_dict = convert_state_dict(fs2_state_dict, reverse_map) + + # Verify HF keys have expected prefixes + for key in hg_state_dict: + assert key.startswith( + ("model.", "lm_head.") + ), 
f"Unexpected HF key prefix: {key}" + + # Convert back to fs2 format + rt_state_dict = convert_state_dict(dict(hg_state_dict), _QWEN35_HG_KEY_MAP) + rt_keys = set(rt_state_dict.keys()) + + assert fs2_keys == rt_keys, ( + f"Round-trip key mismatch.\n" + f" Missing in round-trip: {fs2_keys - rt_keys}\n" + f" Extra in round-trip: {rt_keys - fs2_keys}" + ) + + def test_rmsnorm_weight_conversion(self) -> None: + """RMSNorm weights get +1.0 added during conversion.""" + config = self._make_small_config() + + # Simulate HF state dict with zero-init RMSNorm weights + hf_state_dict: dict[str, object] = {} + for i in range(config.num_layers): + hf_state_dict[f"model.layers.{i}.input_layernorm.weight"] = torch.zeros( + config.model_dim + ) + hf_state_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = ( + torch.zeros(config.model_dim) + ) + hf_state_dict["model.norm.weight"] = torch.zeros(config.model_dim) + hf_state_dict["model.embed_tokens.weight"] = torch.zeros( + config.vocab_size, config.model_dim + ) + hf_state_dict["lm_head.weight"] = torch.zeros( + config.vocab_size, config.model_dim + ) + + converted = convert_qwen35_state_dict(dict(hf_state_dict), config) + + # All layer norm weights should now be 1.0 (0.0 + 1.0) + for key in converted: + if any(key.endswith(s) for s in _QWEN35_RMSNORM_KEYS): + weight = converted[key] + assert isinstance(weight, torch.Tensor) + assert_close(weight, torch.ones_like(weight)) + + def test_gdn_norm_weight_not_converted(self) -> None: + """GatedDeltaNet internal norm weights should NOT get +1.0.""" + config = self._make_small_config() + + # Simulate HF state dict with GDN norm weight + hf_state_dict: dict[str, object] = {"model.embed_tokens.weight": torch.zeros(1)} + hf_state_dict["model.layers.0.linear_attn.norm.weight"] = ( + torch.ones(config.linear_value_head_dim) * 0.5 + ) + + converted = convert_qwen35_state_dict(dict(hf_state_dict), config) + + # The GDN norm maps to linear_attn.norm.inner_norm.weight + gdn_key = 
"decoder.layers.0.linear_attn.norm.inner_norm.weight" + if gdn_key in converted: + # Should still be 0.5, NOT 1.5 + assert_close( + converted[gdn_key], + torch.ones(config.linear_value_head_dim) * 0.5, + ) + + def test_tied_embeddings_hf_no_lm_head(self) -> None: + """HF checkpoint with tied_embeddings has no lm_head.weight. + + Safetensors deduplicates shared tensors, so for models with + tie_word_embeddings=True the checkpoint only contains + model.embed_tokens.weight. The converter must create + final_proj.weight from it. + """ + config = self._make_small_config() + config.tied_embeddings = True + + weight = torch.randn(config.vocab_size, config.model_dim) + hf_state_dict: dict[str, object] = { + "model.embed_tokens.weight": weight, + "model.norm.weight": torch.zeros(config.model_dim), + } + + result = convert_qwen35_state_dict(dict(hf_state_dict), config) + + assert "decoder_frontend.embed.weight" in result + assert "final_proj.weight" in result + assert result["final_proj.weight"] is result["decoder_frontend.embed.weight"] + + def test_layer_types_are_correct(self) -> None: + """Verify layer_types pattern: 3 linear, 1 full, repeating.""" + config = self._make_small_config() + assert config.layer_types == [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + + +class TestQwen35HuggingFaceConverter: + """Tests for _Qwen35HuggingFaceConverter.""" + + def _make_small_config(self) -> Qwen35Config: + config = Qwen35Config() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.layer_types = None + config.__post_init__() + return config + + def test_to_hg_config(self) -> None: + """to_hg_config maps Qwen35Config 
fields to HF config dict.""" + config = self._make_small_config() + converter = _Qwen35HuggingFaceConverter() + hg_config = converter.to_hg_config(config) + + assert hg_config.kls_name == "Qwen3_5TextConfig" + assert hg_config.arch == "Qwen3_5ForCausalLM" + + data = hg_config.data + assert data["hidden_size"] == config.model_dim + assert data["max_position_embeddings"] == config.max_seq_len + assert data["vocab_size"] == config.vocab_size + assert data["tie_word_embeddings"] == config.tied_embeddings + assert data["num_hidden_layers"] == config.num_layers + assert data["num_attention_heads"] == config.num_attn_heads + assert data["num_key_value_heads"] == config.num_key_value_heads + assert data["head_dim"] == config.head_dim + assert data["intermediate_size"] == config.ffn_inner_dim + assert data["partial_rotary_factor"] == config.partial_rotary_factor + assert data["rope_theta"] == config.rope_theta + assert data["full_attention_interval"] == config.full_attention_interval + assert data["linear_conv_kernel_dim"] == config.linear_conv_kernel_dim + assert data["linear_key_head_dim"] == config.linear_key_head_dim + assert data["linear_value_head_dim"] == config.linear_value_head_dim + assert data["linear_num_key_heads"] == config.linear_num_key_heads + assert data["linear_num_value_heads"] == config.linear_num_value_heads + + def test_state_dict_round_trip(self) -> None: + """State dict keys survive a fs2 -> HF -> fs2 round trip.""" + config = self._make_small_config() + + with torch.device("meta"): + model = create_qwen35_model(config) + + fs2_keys = set(model.state_dict().keys()) + assert len(fs2_keys) > 0 + + fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + converter = _Qwen35HuggingFaceConverter() + hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) + hg_keys = set(hg_state_dict.keys()) + + for key in hg_keys: + assert key.startswith( + ("model.", "lm_head.") + ), f"Unexpected HF key prefix: {key}" + + # Round-trip: 
HF -> fs2 + rt_state_dict = convert_qwen35_state_dict(dict(hg_state_dict), config) + rt_keys = set(rt_state_dict.keys()) + + assert fs2_keys == rt_keys, ( + f"Round-trip key mismatch.\n" + f" Missing in round-trip: {fs2_keys - rt_keys}\n" + f" Extra in round-trip: {rt_keys - fs2_keys}" + ) + + def test_rmsnorm_weight_reversed(self) -> None: + """to_hg_state_dict subtracts 1.0 from RMSNorm weights.""" + config = self._make_small_config() + + # Build a fs2 state dict with RMSNorm weights = 1.0 (standard init) + fs2_state_dict: dict[str, object] = {} + for i in range(config.num_layers): + fs2_state_dict[f"decoder.layers.{i}.self_attn_layer_norm.weight"] = ( + torch.ones(config.model_dim) + ) + fs2_state_dict[f"decoder.layers.{i}.ffn_layer_norm.weight"] = torch.ones( + config.model_dim + ) + fs2_state_dict["decoder.layer_norm.weight"] = torch.ones(config.model_dim) + + converter = _Qwen35HuggingFaceConverter() + hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) + + # HF weights should be 0.0 (1.0 - 1.0) + for key in hg_state_dict: + if key.endswith( + ( + "input_layernorm.weight", + "post_attention_layernorm.weight", + "model.norm.weight", + ) + ): + weight = hg_state_dict[key] + assert isinstance(weight, torch.Tensor) + assert_close(weight, torch.zeros_like(weight)) + + def test_tied_embeddings_removes_lm_head(self) -> None: + """to_hg_state_dict removes lm_head.weight when tied_embeddings=True.""" + config = self._make_small_config() + config.tied_embeddings = True + + with torch.device("meta"): + model = create_qwen35_model(config) + + fs2_state_dict: dict[str, object] = { + k: torch.empty(0) for k in model.state_dict().keys() + } + + converter = _Qwen35HuggingFaceConverter() + hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) + + assert "lm_head.weight" not in hg_state_dict + assert "model.embed_tokens.weight" in hg_state_dict + + def test_tied_embeddings_deduped_final_proj_only(self) -> None: + """When safetensors deduplicates 
tied weights and only final_proj.weight + survives, convert_qwen35_state_dict should still reconstruct both keys.""" + config = self._make_small_config() + config.tied_embeddings = True + + weight = torch.randn(config.vocab_size, config.model_dim) + state_dict: dict[str, object] = {"final_proj.weight": weight} + + result = convert_qwen35_state_dict(dict(state_dict), config) + + assert "decoder_frontend.embed.weight" in result + assert result["decoder_frontend.embed.weight"] is weight + + +class TestQwen35MoeHuggingFaceConverter: + """Tests for _Qwen35MoeHuggingFaceConverter.""" + + def _make_small_moe_config(self) -> Qwen35MoeConfig: + config = Qwen35MoeConfig() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.num_experts = 4 + config.num_experts_per_tok = 2 + config.moe_intermediate_size = 32 + config.shared_expert_intermediate_size = 32 + config.layer_types = None + config.__post_init__() + return config + + def test_to_hg_config(self) -> None: + """to_hg_config maps Qwen35MoeConfig fields including MoE-specific ones.""" + config = self._make_small_moe_config() + converter = _Qwen35MoeHuggingFaceConverter() + hg_config = converter.to_hg_config(config) + + assert hg_config.kls_name == "Qwen3_5TextConfig" + assert hg_config.arch == "Qwen3_5MoeForCausalLM" + + data = hg_config.data + assert data["hidden_size"] == config.model_dim + assert data["num_experts"] == config.num_experts + assert data["num_experts_per_tok"] == config.num_experts_per_tok + assert data["moe_intermediate_size"] == config.moe_intermediate_size + assert ( + data["shared_expert_intermediate_size"] + == config.shared_expert_intermediate_size + ) + assert 
data["router_aux_loss_coef"] == config.router_aux_loss_coef + + def test_state_dict_round_trip(self) -> None: + """MoE state dict keys survive a fs2 -> HF -> fs2 round trip.""" + config = self._make_small_moe_config() + + with torch.device("meta"): + model = create_qwen35_moe_model(config) + + fs2_keys = set(model.state_dict().keys()) + assert len(fs2_keys) > 0 + + fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + converter = _Qwen35MoeHuggingFaceConverter() + hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) + hg_keys = set(hg_state_dict.keys()) + + for key in hg_keys: + assert key.startswith( + ("model.", "lm_head.") + ), f"Unexpected HF key prefix: {key}" + + # Round-trip: HF -> fs2 + rt_state_dict = convert_qwen35_moe_state_dict(dict(hg_state_dict), config) + rt_keys = set(rt_state_dict.keys()) + + assert fs2_keys == rt_keys, ( + f"Round-trip key mismatch.\n" + f" Missing in round-trip: {fs2_keys - rt_keys}\n" + f" Extra in round-trip: {rt_keys - fs2_keys}" + ) + + +class TestVlCheckpointHandling: + """Tests for multimodal (VL) checkpoint handling. + + Qwen 3.5 checkpoints on HuggingFace Hub are multimodal models where the + text decoder lives under ``model.language_model.*`` with additional + ``model.visual.*`` and ``mtp.*`` keys. The converter handles both formats + via ``_expand_with_language_model_prefix`` and explicit filtering. 
+ """ + + def _make_small_config(self) -> Qwen35Config: + config = Qwen35Config() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.tied_embeddings = True + config.layer_types = None + config.__post_init__() + return config + + def test_key_map_has_language_model_variants(self) -> None: + """_QWEN35_HG_KEY_MAP includes both model.* and model.language_model.* patterns.""" + text_only_count = len(_QWEN35_TEXT_KEY_MAP) + full_count = len(_QWEN35_HG_KEY_MAP) + # model.* patterns get duplicated; lm_head.* does not + model_prefix_count = sum( + 1 for k in _QWEN35_TEXT_KEY_MAP if k.startswith(r"^model\.") + ) + assert full_count == text_only_count + model_prefix_count + + def test_language_model_prefix_keys_convert(self) -> None: + """model.language_model.X keys are correctly converted to fs2 keys.""" + state_dict: dict[str, object] = { + "model.language_model.embed_tokens.weight": torch.empty(0), + "model.language_model.layers.0.input_layernorm.weight": torch.empty(0), + "model.language_model.norm.weight": torch.empty(0), + } + result = convert_state_dict(state_dict, _QWEN35_HG_KEY_MAP) + assert "decoder_frontend.embed.weight" in result + assert "decoder.layers.0.self_attn_layer_norm.weight" in result + assert "decoder.layer_norm.weight" in result + + def test_visual_and_mtp_keys_filtered(self) -> None: + """model.visual.* and mtp.* keys are filtered by convert_qwen35_state_dict.""" + config = self._make_small_config() + state_dict: dict[str, object] = { + "model.language_model.embed_tokens.weight": torch.randn( + config.vocab_size, config.model_dim + ), + "model.language_model.norm.weight": torch.zeros(config.model_dim), + 
"model.visual.blocks.0.attn.proj.weight": torch.empty(0), + "model.visual.patch_embed.proj.weight": torch.empty(0), + "mtp.fc.weight": torch.empty(0), + "mtp.layers.0.mlp.gate_proj.weight": torch.empty(0), + } + result = convert_qwen35_state_dict(dict(state_dict), config) + for key in result: + assert not key.startswith( + ("model.visual.", "mtp.") + ), f"Unexpected key not filtered: {key}" + + def test_text_only_format_still_works(self) -> None: + """model.layers.* (text-only format) is still handled correctly.""" + config = self._make_small_config() + state_dict: dict[str, object] = { + "model.embed_tokens.weight": torch.randn( + config.vocab_size, config.model_dim + ), + "model.norm.weight": torch.zeros(config.model_dim), + } + result = convert_qwen35_state_dict(dict(state_dict), config) + assert "decoder_frontend.embed.weight" in result + assert "decoder.layer_norm.weight" in result + + def test_end_to_end_vl_checkpoint(self) -> None: + """Full VL checkpoint → convert_qwen35_state_dict produces correct keys.""" + config = self._make_small_config() + + with torch.device("meta"): + model = create_qwen35_model(config) + model_keys = set(model.state_dict().keys()) + + # Build a text-only HF state dict, then add VL prefix + extra modalities + reverse_map = create_reverse_key_map(_QWEN35_TEXT_KEY_MAP) + fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in model_keys} + hg_state_dict = convert_state_dict(fs2_state_dict, reverse_map) + + # Add model.language_model. prefix (simulating VL checkpoint) + vl_state_dict: dict[str, object] = {} + for k, v in hg_state_dict.items(): + if k.startswith("model."): + vl_state_dict["model.language_model." 
+ k[len("model.") :]] = v + else: + vl_state_dict[k] = v + # Add visual/mtp keys + vl_state_dict["model.visual.blocks.0.attn.proj.weight"] = torch.empty(0) + vl_state_dict["mtp.fc.weight"] = torch.empty(0) + + # Convert back — should match model keys + result = convert_qwen35_state_dict(dict(vl_state_dict), config) + result_keys = set(result.keys()) + + assert model_keys == result_keys, ( + f"VL round-trip key mismatch.\n" + f" Missing: {model_keys - result_keys}\n" + f" Extra: {result_keys - model_keys}" + ) diff --git a/tests/unit/models/qwen/test_qwen35_moe.py b/tests/unit/models/qwen/test_qwen35_moe.py new file mode 100644 index 000000000..c0d06b953 --- /dev/null +++ b/tests/unit/models/qwen/test_qwen35_moe.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import torch + +from fairseq2.models.qwen.moe import Qwen35Experts, Qwen35MoeBlock, Qwen35TopKRouter +from tests.common import assert_close, device + + +class TestQwen35TopKRouter: + def test_forward_output_shapes(self) -> None: + """Router returns correct shapes: logits(T,E), weights(T,K), indices(T,K).""" + router = Qwen35TopKRouter(num_experts=8, top_k=2, model_dim=32).to(device) + + x = torch.randn(10, 32, device=device) + + with torch.no_grad(): + logits, weights, indices = router(x) + + assert logits.shape == (10, 8) # (T, E) + assert weights.shape == (10, 2) # (T, K) + assert indices.shape == (10, 2) # (T, K) + + def test_weights_sum_to_one(self) -> None: + """Renormalized top-k weights sum to 1 per token.""" + router = Qwen35TopKRouter(num_experts=8, top_k=2, model_dim=32).to(device) + + x = torch.randn(10, 32, device=device) + + with torch.no_grad(): + _, weights, _ = router(x) + + sums = weights.sum(dim=-1) + assert_close(sums, torch.ones(10, device=device), atol=1e-5) + + def 
test_logits_are_raw_pre_softmax(self) -> None: + """Router logits are raw pre-softmax values (NOT a probability distribution).""" + router = Qwen35TopKRouter(num_experts=8, top_k=2, model_dim=32).to(device) + # Initialize with non-zero weights so logits are non-trivial. + torch.nn.init.normal_(router.weight, std=0.1) + + x = torch.randn(10, 32, device=device) + + with torch.no_grad(): + logits, weights, _ = router(x) + + # Raw logits can be negative and do NOT sum to 1. + assert logits.shape == (10, 8) + sums = logits.sum(dim=-1) + assert not torch.allclose( + sums, torch.ones(10, device=device), atol=1e-3 + ), "Raw logits should NOT sum to 1 (they are not softmax)" + + # But the top-k weights DO sum to 1 (renormalized). + w_sums = weights.sum(dim=-1) + assert_close(w_sums, torch.ones(10, device=device), atol=1e-5) + + +class TestQwen35Experts: + def test_forward_output_shape(self) -> None: + """Experts output shape matches input shape (T, D).""" + experts = Qwen35Experts(num_experts=4, model_dim=32, expert_inner_dim=16).to( + device + ) + torch.nn.init.normal_(experts.gate_up_proj, std=0.01) + torch.nn.init.normal_(experts.down_proj, std=0.01) + + T = 6 + x = torch.randn(T, 32, device=device) + indices = torch.tensor( + [[0, 1], [1, 2], [2, 3], [0, 3], [1, 0], [3, 2]], device=device + ) + weights = torch.ones(T, 2, device=device) * 0.5 + + with torch.no_grad(): + out = experts(x, indices, weights) + + assert out.shape == (T, 32) + + def test_weighted_output(self) -> None: + """Output is weighted by routing weights — zero weight means no contribution.""" + experts = Qwen35Experts(num_experts=4, model_dim=16, expert_inner_dim=8).to( + device + ) + torch.nn.init.normal_(experts.gate_up_proj, std=0.01) + torch.nn.init.normal_(experts.down_proj, std=0.01) + + T = 4 + x = torch.randn(T, 16, device=device) + indices = torch.zeros(T, 2, dtype=torch.long, device=device) + weights_nonzero = torch.ones(T, 2, device=device) * 0.5 + weights_zero = torch.zeros(T, 2, 
device=device) + + with torch.no_grad(): + out_nonzero = experts(x, indices, weights_nonzero) + out_zero = experts(x, indices, weights_zero) + + assert_close(out_zero, torch.zeros_like(out_zero), atol=1e-6) + assert out_nonzero.abs().mean() > 1e-6 + + +class TestQwen35MoeBlock: + def test_forward_output_shape(self) -> None: + """MoeBlock output shape matches input (B, S, D).""" + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ).to(device) + + seqs = torch.randn(2, 8, 32, device=device) + + with torch.no_grad(): + out = moe(seqs) + + assert out.shape == (2, 8, 32) + + def test_shared_expert_contributes(self) -> None: + """Shared expert output is non-zero (sigmoid gate blending).""" + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ).to(device) + + seqs = torch.randn(1, 4, 32, device=device) + + with torch.no_grad(): + out = moe(seqs) + + assert out.abs().mean() > 1e-6 + + def test_drop_in_ffn_replacement(self) -> None: + """MoeBlock inherits FeedForwardNetwork and can be used as drop-in.""" + from fairseq2.models.transformer import FeedForwardNetwork + + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ) + + assert isinstance(moe, FeedForwardNetwork) From c02b21fb8d6f4d7a68e5b94a6bb1a463b677a5fa Mon Sep 17 00:00:00 2001 From: yunchaoyang1 user Date: Wed, 27 May 2026 20:52:25 +0000 Subject: [PATCH 2/4] Remove extra recipe configs, keep only SFT example --- .../configs/qwen35_0.8b_fineweb_edu_10bt.yaml | 162 ------------------ .../configs/qwen35_2b_fineweb_edu_10bt.yaml | 162 ------------------ .../lm/train/configs/test_qwen35_0.8b.yaml | 155 ----------------- .../train/scripts/run_qwen35_fineweb_edu.sh | 24 --- 4 files changed, 503 deletions(-) delete mode 100644 
recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml delete mode 100644 recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml delete mode 100644 recipes/lm/train/configs/test_qwen35_0.8b.yaml delete mode 100644 recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh diff --git a/recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml b/recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml deleted file mode 100644 index 76992f316..000000000 --- a/recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml +++ /dev/null @@ -1,162 +0,0 @@ -# Qwen 3.5 0.8B Continued Pretraining on FineWeb-Edu 10BT -# -# Loads the pretrained Qwen 3.5 0.8B checkpoint and continues training -# on the FineWeb-Edu Sample 10BT educational text dataset. -# -# Prerequisites: -# 1. Convert parquet data to chunked JSONL: -# cd /checkpoint/smallomnillm/shared/data/fineweb-edu -# python convert_to_jsonl.py --sample 10BT --chunk-format --lightweight \ -# --num-shards 256 --output-dir jsonl/10BT -# -# Usage: -# torchrun --standalone --nproc_per_node=8 -m recipes.lm.train \ -# --config-file recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml \ -# /path/to/output_dir - -model: - name: qwen35_0.8b - dtype: bfloat16 - mmap: false - compile: false - compile_options: - fullgraph: false - dynamic: false - mode: default - backend: inductor - backend_options: null - -dataset: - name: null - family: lm_train - config_overrides: - sources: - - path: /checkpoint/smallomnillm/shared/data/fineweb-edu/jsonl/10BT - weight: 1.0 - max_seq_len: 4096 - max_num_tokens: 8192 - prefetch: 4 - sync_ranks: false - -tokenizer: - name: qwen35_0.8b - path: null - family: null - config_overrides: null - -gang: - tensor_parallel_size: 1 - timeout: 15 - high_priority: true - -trainer: - data_parallelism: fsdp - fsdp: - version: v2 - granularity: layer - hybrid: false - reshard_after_forward: true - fp32_reduce: true - mixed_precision: - mode: static - dtype: bfloat16 - grad_accumulation: - num_batches: 1 - 
no_sync: false - activation_checkpointing: - mode: layerwise - every_nth_layer: 1 - max_grad_norm: 1.0 - fp16_loss_scale: - - 128.0 - - 0.0001 - gc_every_n_steps: 1000 - grad_check: false - anomaly_detection: false - -optimizer: - name: adamw - config: - lr: 5.0e-5 - betas: - - 0.9 - - 0.95 - eps: 1.0e-8 - weight_decay: 0.1 - amsgrad: false - maximize: false - capturable: false - differentiable: false - impl: fused - groups: [] - -lr_scheduler: - name: cosine_annealing - config: - cycle_len: null - num_warmup_steps: 500 - cycle_mul: 1.0 - lr_mul: 1.0 - start_lr: 1.0e-30 - final_lr: null - final_lr_scale: 0.01 - -regime: - num_steps: 76000 - num_data_epochs: null - validate_at_start: false - validate_after_n_steps: 0 - validate_every_n_steps: 4000 - validate_after_n_data_epochs: 0 - validate_every_n_data_epochs: null - score_metric: null - checkpoint_after_n_steps: 0 - checkpoint_every_n_steps: 2000 - checkpoint_after_n_data_epochs: 0 - checkpoint_every_n_data_epochs: null - save_model_only: all_but_last - export_hugging_face: true - keep_last_n_checkpoints: 3 - keep_best_n_checkpoints: null - keep_checkpoint_every_n_steps: 10000 - publish_metrics_after_n_steps: 0 - publish_metrics_every_n_steps: 10 - publish_metrics_after_n_data_epochs: 0 - publish_metrics_every_n_data_epochs: null - -common: - torch: - num_threads: null - allow_tf32: true - fp16_reduced_precision: true - bf16_reduced_precision: true - default_sdpa: torch - compiled_region_activation_memory_budget: 0.9 - metric_recorders: - tensorboard: - enabled: true - wandb: - enabled: true - entity: smallomni - project: qwen35_0.8b_fineweb_edu_10bt - run_id: persistent - run_name: null - group: null - job_type: null - resume_mode: null - profilers: - torch: - enabled: false - skip_n_steps: 4 - wait_n_steps: 0 - num_warmup_steps: 1 - num_active_steps: 4 - repeat: 1 - assets: - extra_paths: [] - prev_checkpoint_dir: null - seed: 2 - debug: false - cluster: auto - no_sweep_dir: false - sweep_format: 
ws_{world_size}.{hash} diff --git a/recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml b/recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml deleted file mode 100644 index 4324fe9d8..000000000 --- a/recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml +++ /dev/null @@ -1,162 +0,0 @@ -# Qwen 3.5 2B Continued Pretraining on FineWeb-Edu 10BT -# -# Loads the pretrained Qwen 3.5 2B checkpoint and continues training -# on the FineWeb-Edu Sample 10BT educational text dataset. -# -# Prerequisites: -# 1. Convert parquet data to chunked JSONL: -# cd /checkpoint/smallomnillm/shared/data/fineweb-edu -# python convert_to_jsonl.py --sample 10BT --chunk-format --lightweight \ -# --num-shards 256 --output-dir jsonl/10BT -# -# Usage: -# torchrun --standalone --nproc_per_node=8 -m recipes.lm.train \ -# --config-file recipes/lm/train/configs/qwen35_2b_fineweb_edu_10bt.yaml \ -# /path/to/output_dir - -model: - name: qwen35_2b - dtype: bfloat16 - mmap: false - compile: false - compile_options: - fullgraph: false - dynamic: false - mode: default - backend: inductor - backend_options: null - -dataset: - name: null - family: lm_train - config_overrides: - sources: - - path: /checkpoint/smallomnillm/shared/data/fineweb-edu/jsonl/10BT - weight: 1.0 - max_seq_len: 4096 - max_num_tokens: 8192 - prefetch: 4 - sync_ranks: false - -tokenizer: - name: qwen35_2b - path: null - family: null - config_overrides: null - -gang: - tensor_parallel_size: 1 - timeout: 15 - high_priority: true - -trainer: - data_parallelism: fsdp - fsdp: - version: v2 - granularity: layer - hybrid: false - reshard_after_forward: true - fp32_reduce: true - mixed_precision: - mode: static - dtype: bfloat16 - grad_accumulation: - num_batches: 1 - no_sync: false - activation_checkpointing: - mode: layerwise - every_nth_layer: 1 - max_grad_norm: 1.0 - fp16_loss_scale: - - 128.0 - - 0.0001 - gc_every_n_steps: 1000 - grad_check: false - anomaly_detection: false - -optimizer: - name: adamw - config: - lr: 5.0e-5 - 
betas: - - 0.9 - - 0.95 - eps: 1.0e-8 - weight_decay: 0.1 - amsgrad: false - maximize: false - capturable: false - differentiable: false - impl: fused - groups: [] - -lr_scheduler: - name: cosine_annealing - config: - cycle_len: null - num_warmup_steps: 500 - cycle_mul: 1.0 - lr_mul: 1.0 - start_lr: 1.0e-30 - final_lr: null - final_lr_scale: 0.01 - -regime: - num_steps: 76000 - num_data_epochs: null - validate_at_start: false - validate_after_n_steps: 0 - validate_every_n_steps: 4000 - validate_after_n_data_epochs: 0 - validate_every_n_data_epochs: null - score_metric: null - checkpoint_after_n_steps: 0 - checkpoint_every_n_steps: 2000 - checkpoint_after_n_data_epochs: 0 - checkpoint_every_n_data_epochs: null - save_model_only: all_but_last - export_hugging_face: true - keep_last_n_checkpoints: 3 - keep_best_n_checkpoints: null - keep_checkpoint_every_n_steps: 10000 - publish_metrics_after_n_steps: 0 - publish_metrics_every_n_steps: 10 - publish_metrics_after_n_data_epochs: 0 - publish_metrics_every_n_data_epochs: null - -common: - torch: - num_threads: null - allow_tf32: true - fp16_reduced_precision: true - bf16_reduced_precision: true - default_sdpa: torch - compiled_region_activation_memory_budget: 0.9 - metric_recorders: - tensorboard: - enabled: true - wandb: - enabled: true - entity: smallomni - project: qwen35_2b_fineweb_edu_10bt - run_id: persistent - run_name: null - group: null - job_type: null - resume_mode: null - profilers: - torch: - enabled: false - skip_n_steps: 4 - wait_n_steps: 0 - num_warmup_steps: 1 - num_active_steps: 4 - repeat: 1 - assets: - extra_paths: [] - prev_checkpoint_dir: null - seed: 2 - debug: false - cluster: auto - no_sweep_dir: false - sweep_format: ws_{world_size}.{hash} diff --git a/recipes/lm/train/configs/test_qwen35_0.8b.yaml b/recipes/lm/train/configs/test_qwen35_0.8b.yaml deleted file mode 100644 index 7a80a1236..000000000 --- a/recipes/lm/train/configs/test_qwen35_0.8b.yaml +++ /dev/null @@ -1,155 +0,0 @@ -# Qwen 3.5 
0.8B Quick Test Config (Continued Pretraining) -# -# Full model (24 layers, no num_layers override — required for loading -# pretrained checkpoint) with only 100 steps for smoke testing. -# Verifies: model loads, pad_idx check passes, data loads, loss computes. -# -# NOTE: Cannot reduce num_layers for continued pretraining because the -# checkpoint has 24 layers of weights that must match the model architecture. -# -# Usage: -# source ~/envs/fs081-pt290-cu128/bin/activate -# cd /storage/home/yunchaoyang1/fairseq2 -# torchrun --standalone --nproc_per_node=1 -m recipes.lm.train \ -# --config-file recipes/lm/train/configs/test_qwen35_0.8b.yaml \ -# /tmp/qwen35_test - -model: - name: qwen35_0.8b - dtype: bfloat16 - mmap: false - compile: false - compile_options: - fullgraph: false - dynamic: false - mode: default - backend: inductor - backend_options: null - -dataset: - name: null - family: lm_train - config_overrides: - sources: - - path: /checkpoint/smallomnillm/shared/data/fineweb-edu/jsonl/10BT - weight: 1.0 - max_seq_len: 512 - max_num_tokens: 2048 - prefetch: 4 - sync_ranks: false - -tokenizer: - name: qwen35_0.8b - path: null - family: null - config_overrides: null - -gang: - tensor_parallel_size: 1 - timeout: 15 - high_priority: true - -trainer: - data_parallelism: fsdp - fsdp: - version: v2 - granularity: layer - hybrid: false - reshard_after_forward: true - fp32_reduce: true - mixed_precision: - mode: static - dtype: bfloat16 - grad_accumulation: - num_batches: 1 - no_sync: false - activation_checkpointing: - mode: off - every_nth_layer: 1 - max_grad_norm: 1.0 - fp16_loss_scale: - - 128.0 - - 0.0001 - gc_every_n_steps: 1000 - grad_check: false - anomaly_detection: false - -optimizer: - name: adamw - config: - lr: 5.0e-5 - betas: - - 0.9 - - 0.95 - eps: 1.0e-8 - weight_decay: 0.1 - amsgrad: false - maximize: false - capturable: false - differentiable: false - impl: fused - groups: [] - -lr_scheduler: - name: cosine_annealing - config: - cycle_len: null - 
num_warmup_steps: 10 - cycle_mul: 1.0 - lr_mul: 1.0 - start_lr: 1.0e-30 - final_lr: null - final_lr_scale: 0.01 - -regime: - num_steps: 100 - num_data_epochs: null - validate_at_start: false - validate_after_n_steps: 0 - validate_every_n_steps: 4000 - validate_after_n_data_epochs: 0 - validate_every_n_data_epochs: null - score_metric: null - checkpoint_after_n_steps: 0 - checkpoint_every_n_steps: 50 - checkpoint_after_n_data_epochs: 0 - checkpoint_every_n_data_epochs: null - save_model_only: all_but_last - export_hugging_face: false - keep_last_n_checkpoints: 2 - keep_best_n_checkpoints: null - keep_checkpoint_every_n_steps: null - publish_metrics_after_n_steps: 0 - publish_metrics_every_n_steps: 1 - publish_metrics_after_n_data_epochs: 0 - publish_metrics_every_n_data_epochs: null - -common: - torch: - num_threads: null - allow_tf32: true - fp16_reduced_precision: true - bf16_reduced_precision: true - default_sdpa: torch - compiled_region_activation_memory_budget: 0.9 - metric_recorders: - tensorboard: - enabled: false - wandb: - enabled: false - profilers: - torch: - enabled: false - skip_n_steps: 4 - wait_n_steps: 0 - num_warmup_steps: 1 - num_active_steps: 4 - repeat: 1 - assets: - extra_paths: [] - prev_checkpoint_dir: null - seed: 2 - debug: false - cluster: auto - no_sweep_dir: false - sweep_format: ws_{world_size}.{hash} diff --git a/recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh b/recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh deleted file mode 100644 index d23b30de6..000000000 --- a/recipes/lm/train/scripts/run_qwen35_fineweb_edu.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=qwen35_pretrain_fineweb -#SBATCH --nodes=1 -#SBATCH --gpus-per-node=8 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=96 -#SBATCH --mem=0 -#SBATCH --time=48:00:00 -#SBATCH --account=smallomnillm -#SBATCH --qos=h200_smallomnillm_high - -#SBATCH --output=/checkpoint/smallomnillm/yunchaoyang1/qwen35_pretrain/slurm_%j.out -#SBATCH 
--error=/checkpoint/smallomnillm/yunchaoyang1/qwen35_pretrain/slurm_%j.err - -OUTPUT_DIR=/checkpoint/smallomnillm/yunchaoyang1/qwen35_pretrain/baseline - -mkdir -p "${OUTPUT_DIR}" - -source ~/envs/fs081-pt290-cu128/bin/activate -cd /storage/home/yunchaoyang1/fairseq2 - -torchrun --standalone --nproc_per_node=8 -m recipes.lm.train \ - --config-file recipes/lm/train/configs/qwen35_0.8b_fineweb_edu_10bt.yaml \ - "${OUTPUT_DIR}" From 91dd827737d724e712344529b6cbd001ed4e9465 Mon Sep 17 00:00:00 2001 From: yunchaoyang1 user Date: Wed, 27 May 2026 20:57:26 +0000 Subject: [PATCH 3/4] Consolidate Qwen 3.5 unit tests into single file Merge 5 separate test files (test_gated_delta_net, test_qwen35_attention, test_qwen35_decoder_layer, test_qwen35_interop, test_qwen35_moe) into one test_qwen35.py with 19 focused tests covering all modules. --- .../unit/models/qwen/test_gated_delta_net.py | 152 ------ tests/unit/models/qwen/test_qwen35.py | 395 ++++++++++++++ .../unit/models/qwen/test_qwen35_attention.py | 173 ------ .../models/qwen/test_qwen35_decoder_layer.py | 150 ------ tests/unit/models/qwen/test_qwen35_interop.py | 500 ------------------ tests/unit/models/qwen/test_qwen35_moe.py | 154 ------ 6 files changed, 395 insertions(+), 1129 deletions(-) delete mode 100644 tests/unit/models/qwen/test_gated_delta_net.py create mode 100644 tests/unit/models/qwen/test_qwen35.py delete mode 100644 tests/unit/models/qwen/test_qwen35_attention.py delete mode 100644 tests/unit/models/qwen/test_qwen35_decoder_layer.py delete mode 100644 tests/unit/models/qwen/test_qwen35_interop.py delete mode 100644 tests/unit/models/qwen/test_qwen35_moe.py diff --git a/tests/unit/models/qwen/test_gated_delta_net.py b/tests/unit/models/qwen/test_gated_delta_net.py deleted file mode 100644 index d7826fec2..000000000 --- a/tests/unit/models/qwen/test_gated_delta_net.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from __future__ import annotations - -import pytest -import torch -import torch.nn.functional as F - -from fairseq2.models.qwen.gated_delta_net import ( - GatedDeltaNet, - GatedDeltaNetState, - RMSNormGated, - torch_chunk_gated_delta_rule, - torch_recurrent_gated_delta_rule, -) -from fairseq2.nn import IncrementalStateBag -from tests.common import assert_close, device - - -class TestGatedDeltaNet: - def test_forward_produces_correct_shape(self) -> None: - """GatedDeltaNet forward output shape matches input shape (B, S, D).""" - gdn = GatedDeltaNet( - hidden_size=64, - num_k_heads=2, - num_v_heads=4, - head_k_dim=16, - head_v_dim=16, - conv_kernel_size=4, - ) - gdn = gdn.to(device) - - seqs = torch.randn(2, 8, 64, device=device) - with torch.no_grad(): - out = gdn(seqs) - - assert out.shape == (2, 8, 64) - - def test_incremental_decode_matches_full_forward(self) -> None: - """Step-by-step decode with IncrementalStateBag matches full forward.""" - gdn = GatedDeltaNet( - hidden_size=64, - num_k_heads=2, - num_v_heads=4, - head_k_dim=16, - head_v_dim=16, - ) - gdn = gdn.to(device).eval() - - seq_len = 8 - seqs = torch.randn(1, seq_len, 64, device=device) - - with torch.no_grad(): - full_out = gdn(seqs) - - state_bag = IncrementalStateBag(max_num_steps=seq_len) - - with torch.no_grad(): - prefill_out = gdn(seqs, state_bag=state_bag) - - assert_close(prefill_out, full_out, atol=1e-5) - - def test_chunked_vs_recurrent_consistency(self) -> None: - """torch_chunk_gated_delta_rule and torch_recurrent_gated_delta_rule - produce the same output for the same input.""" - B, S, H, K, V = 1, 16, 4, 16, 16 - q = torch.randn(B, S, H, K, device=device) - k = torch.randn(B, S, H, K, device=device) - v = torch.randn(B, S, H, V, device=device) - g = -torch.rand(B, S, H, device=device).abs() - beta = torch.rand(B, S, H, device=device) - - chunk_out, 
chunk_state = torch_chunk_gated_delta_rule( - q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True - ) - recurrent_out, recurrent_state = torch_recurrent_gated_delta_rule( - q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True - ) - - assert_close(chunk_out, recurrent_out, atol=1e-4) - assert chunk_state is not None - assert recurrent_state is not None - assert_close(chunk_state, recurrent_state, atol=1e-4) - - def test_gated_delta_net_state_reorder(self) -> None: - """GatedDeltaNetState.reorder correctly reorders batch dimension.""" - conv = torch.randn(3, 8, 3, device=device) - rec = torch.randn(3, 4, 16, 16, device=device) - state = GatedDeltaNetState(conv, rec) - - new_order = torch.tensor([2, 0, 1], device=device) - state.reorder(new_order) - - assert_close(state.conv_state[0], conv[2]) - assert_close(state.conv_state[1], conv[0]) - assert_close(state.recurrent_state[0], rec[2]) - - def test_rmsnorm_gated_output(self) -> None: - """RMSNormGated produces norm(x) * silu(gate).""" - dim = 16 - norm = RMSNormGated(dim).to(device) - - x = torch.randn(4, dim, device=device) - gate = torch.randn(4, dim, device=device) - - out = norm(x, gate) - - x_f32 = x.float() - variance = x_f32.pow(2).mean(-1, keepdim=True) - x_normed = x_f32 * torch.rsqrt(variance + 1e-6) - assert norm.inner_norm.weight is not None - expected = (norm.inner_norm.weight * x_normed) * F.silu(gate.float()) - - assert_close(out, expected.to(out.dtype), atol=1e-5) - - @pytest.mark.skipif( - not torch.cuda.is_available(), - reason="causal_conv1d incremental decode requires CUDA", - ) - def test_step_by_step_decode_matches_prefill(self) -> None: - """After prefilling, incremental decode of one token matches full forward.""" - gdn = GatedDeltaNet( - hidden_size=64, - num_k_heads=2, - num_v_heads=4, - head_k_dim=16, - head_v_dim=16, - ) - gdn = gdn.to(device).eval() - - prefill_len = 8 - full_seq = torch.randn(1, prefill_len + 1, 64, device=device) - - with 
torch.no_grad(): - full_out = gdn(full_seq) - ground_truth = full_out[:, -1:, :] - - state_bag = IncrementalStateBag(max_num_steps=prefill_len + 1) - - with torch.no_grad(): - gdn(full_seq[:, :prefill_len, :], state_bag=state_bag) - - state_bag.increment_step_nr(prefill_len) - - with torch.no_grad(): - incr_out = gdn(full_seq[:, prefill_len:, :], state_bag=state_bag) - - assert_close(incr_out, ground_truth, atol=1e-4) diff --git a/tests/unit/models/qwen/test_qwen35.py b/tests/unit/models/qwen/test_qwen35.py new file mode 100644 index 000000000..9c93e7396 --- /dev/null +++ b/tests/unit/models/qwen/test_qwen35.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for the Qwen 3.5 model family (dense + MoE).""" + +from __future__ import annotations + +import pytest +import torch +from torch.testing import assert_close + +from fairseq2.models.qwen.attention import Qwen35Attention +from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig +from fairseq2.models.qwen.decoder_layer import Qwen35DecoderLayer +from fairseq2.models.qwen.factory import create_qwen35_model, create_qwen35_moe_model +from fairseq2.models.qwen.gated_delta_net import ( + GatedDeltaNet, + GatedDeltaNetState, + torch_chunk_gated_delta_rule, + torch_recurrent_gated_delta_rule, +) +from fairseq2.models.qwen.interop import ( + _QWEN35_HG_KEY_MAP, + _QWEN35_RMSNORM_KEYS, + _Qwen35HuggingFaceConverter, + _Qwen35MoeHuggingFaceConverter, + convert_qwen35_moe_state_dict, + convert_qwen35_state_dict, +) +from fairseq2.models.qwen.moe import Qwen35MoeBlock +from fairseq2.models.transformer import FeedForwardNetwork +from fairseq2.models.transformer.attention_bias import ( + AttentionBiasCache, + CausalAttentionBias, + IdentityBias, +) +from fairseq2.models.transformer.sdpa.naive import NaiveSDPA +from 
fairseq2.models.utils.checkpoint import convert_state_dict, create_reverse_key_map +from fairseq2.nn import BatchLayout, IncrementalStateBag +from tests.common import assert_close as fs2_assert_close +from tests.common import device + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _small_dense_config() -> Qwen35Config: + config = Qwen35Config() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.layer_types = None + config.__post_init__() + return config + + +def _small_moe_config() -> Qwen35MoeConfig: + config = Qwen35MoeConfig() + config.model_dim = 64 + config.vocab_size = 128 + config.num_layers = 4 + config.num_attn_heads = 4 + config.num_key_value_heads = 2 + config.head_dim = 16 + config.ffn_inner_dim = 128 + config.partial_rotary_factor = 0.25 + config.linear_num_key_heads = 2 + config.linear_num_value_heads = 4 + config.linear_key_head_dim = 8 + config.linear_value_head_dim = 8 + config.num_experts = 4 + config.num_experts_per_tok = 2 + config.moe_intermediate_size = 32 + config.shared_expert_intermediate_size = 32 + config.layer_types = None + config.__post_init__() + return config + + +# --------------------------------------------------------------------------- +# GatedDeltaNet +# --------------------------------------------------------------------------- + + +class TestGatedDeltaNet: + def test_forward_shape(self) -> None: + gdn = GatedDeltaNet( + hidden_size=64, + num_k_heads=2, + num_v_heads=4, + head_k_dim=16, + head_v_dim=16, + conv_kernel_size=4, + ).to(device) + out = gdn(torch.randn(2, 8, 64, 
device=device)) + assert out.shape == (2, 8, 64) + + def test_chunked_vs_recurrent(self) -> None: + B, S, H, K, V = 1, 16, 4, 16, 16 + q = torch.randn(B, S, H, K, device=device) + k = torch.randn(B, S, H, K, device=device) + v = torch.randn(B, S, H, V, device=device) + g = -torch.rand(B, S, H, device=device).abs() + beta = torch.rand(B, S, H, device=device) + + c_out, c_st = torch_chunk_gated_delta_rule( + q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True + ) + r_out, r_st = torch_recurrent_gated_delta_rule( + q, k, v, g, beta, output_final_state=True, use_qk_l2norm_in_kernel=True + ) + fs2_assert_close(c_out, r_out, atol=1e-4) + assert c_st is not None and r_st is not None + fs2_assert_close(c_st, r_st, atol=1e-4) + + def test_state_reorder(self) -> None: + conv = torch.randn(3, 8, 3, device=device) + rec = torch.randn(3, 4, 16, 16, device=device) + state = GatedDeltaNetState(conv, rec) + state.reorder(torch.tensor([2, 0, 1], device=device)) + fs2_assert_close(state.conv_state[0], conv[2]) + fs2_assert_close(state.recurrent_state[0], rec[2]) + + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="causal_conv1d incremental decode requires CUDA", + ) + def test_incremental_decode(self) -> None: + gdn = ( + GatedDeltaNet( + hidden_size=64, + num_k_heads=2, + num_v_heads=4, + head_k_dim=16, + head_v_dim=16, + ) + .to(device) + .eval() + ) + + full_seq = torch.randn(1, 9, 64, device=device) + with torch.no_grad(): + full_out = gdn(full_seq) + + state_bag = IncrementalStateBag(max_num_steps=9) + with torch.no_grad(): + gdn(full_seq[:, :8, :], state_bag=state_bag) + state_bag.increment_step_nr(8) + with torch.no_grad(): + incr_out = gdn(full_seq[:, 8:, :], state_bag=state_bag) + fs2_assert_close(incr_out, full_out[:, -1:, :], atol=1e-4) + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class 
TestQwen35Attention: + def test_forward_shape(self) -> None: + sdpa = NaiveSDPA(IdentityBias()) + attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16).to( + device + ) + seqs = torch.randn(2, 8, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + with torch.no_grad(): + out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + assert out.shape == (2, 8, 64) + + def test_gqa(self) -> None: + sdpa = NaiveSDPA(IdentityBias()) + attn = Qwen35Attention( + model_dim=64, + num_heads=4, + sdpa=sdpa, + head_dim=16, + num_key_value_heads=2, + ).to(device) + seqs = torch.randn(2, 6, 64, device=device) + layout = BatchLayout.of(seqs) + with torch.no_grad(): + out = attn(seqs, layout, seqs, layout, seqs, AttentionBiasCache()) + assert out.shape == (2, 6, 64) + + def test_incremental_kv_cache(self) -> None: + sdpa = NaiveSDPA(CausalAttentionBias()) + attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16).to( + device + ) + attn.eval() + + seqs = torch.randn(1, 6, 64, device=device) + layout = BatchLayout.of(seqs) + bias_cache = AttentionBiasCache() + with torch.no_grad(): + full_out = attn(seqs, layout, seqs, layout, seqs, bias_cache) + + state_bag = IncrementalStateBag(max_num_steps=32) + with torch.no_grad(): + for idx in range(6): + step = seqs[:, idx : idx + 1, :] + sl = BatchLayout.of(step) + out = attn(step, sl, step, sl, step, bias_cache, state_bag=state_bag) + fs2_assert_close(out, full_out[:, idx : idx + 1, :], atol=1e-5) + state_bag.increment_step_nr() + + +# --------------------------------------------------------------------------- +# Model factory +# --------------------------------------------------------------------------- + + +class TestQwen35Factory: + def test_small_model_forward(self) -> None: + config = _small_dense_config() + model = create_qwen35_model(config).to(device).eval() + ids = torch.randint(0, 128, (1, 16), device=device) + with torch.no_grad(): + logits = model(ids, 
BatchLayout.of(ids)) + assert logits.shape == (1, 16, 128) + + def test_hybrid_layer_pattern(self) -> None: + config = _small_dense_config() + with torch.device("meta"): + model = create_qwen35_model(config) + types = [ + l.layer_type + for l in model.decoder.layers + if isinstance(l, Qwen35DecoderLayer) + ] + assert types == [ + "linear_attention", + "linear_attention", + "linear_attention", + "full_attention", + ] + + +# --------------------------------------------------------------------------- +# MoE +# --------------------------------------------------------------------------- + + +class TestQwen35Moe: + def test_moe_block_shape(self) -> None: + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ).to(device) + with torch.no_grad(): + out = moe(torch.randn(2, 8, 32, device=device)) + assert out.shape == (2, 8, 32) + + def test_moe_is_ffn(self) -> None: + moe = Qwen35MoeBlock( + model_dim=32, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=16, + shared_expert_intermediate_size=16, + ) + assert isinstance(moe, FeedForwardNetwork) + + +# --------------------------------------------------------------------------- +# Interop (state dict conversion) +# --------------------------------------------------------------------------- + + +class TestQwen35Interop: + def test_key_round_trip(self) -> None: + config = _small_dense_config() + with torch.device("meta"): + model = create_qwen35_model(config) + fs2_keys = set(model.state_dict().keys()) + + sd: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + rev = create_reverse_key_map(_QWEN35_HG_KEY_MAP) + hg_sd = convert_state_dict(sd, rev) + rt_keys = set(convert_state_dict(dict(hg_sd), _QWEN35_HG_KEY_MAP).keys()) + assert fs2_keys == rt_keys + + def test_rmsnorm_plus_one(self) -> None: + config = _small_dense_config() + hf_sd: dict[str, object] = {} + for i in range(config.num_layers): + 
hf_sd[f"model.layers.{i}.input_layernorm.weight"] = torch.zeros( + config.model_dim + ) + hf_sd[f"model.layers.{i}.post_attention_layernorm.weight"] = torch.zeros( + config.model_dim + ) + hf_sd["model.norm.weight"] = torch.zeros(config.model_dim) + hf_sd["model.embed_tokens.weight"] = torch.zeros( + config.vocab_size, config.model_dim + ) + hf_sd["lm_head.weight"] = torch.zeros(config.vocab_size, config.model_dim) + + converted = convert_qwen35_state_dict(dict(hf_sd), config) + for key in converted: + if any(key.endswith(s) for s in _QWEN35_RMSNORM_KEYS): + weight = converted[key] + assert isinstance(weight, torch.Tensor) + assert_close(weight, torch.ones_like(weight)) + + def test_tied_embeddings(self) -> None: + config = _small_dense_config() + config.tied_embeddings = True + weight = torch.randn(config.vocab_size, config.model_dim) + hf_sd: dict[str, object] = { + "model.embed_tokens.weight": weight, + "model.norm.weight": torch.zeros(config.model_dim), + } + result = convert_qwen35_state_dict(dict(hf_sd), config) + assert "decoder_frontend.embed.weight" in result + assert "final_proj.weight" in result + assert result["final_proj.weight"] is result["decoder_frontend.embed.weight"] + + def test_vl_keys_filtered(self) -> None: + config = _small_dense_config() + config.tied_embeddings = True + sd: dict[str, object] = { + "model.language_model.embed_tokens.weight": torch.randn( + config.vocab_size, config.model_dim + ), + "model.language_model.norm.weight": torch.zeros(config.model_dim), + "model.visual.blocks.0.attn.proj.weight": torch.empty(0), + "mtp.fc.weight": torch.empty(0), + } + result = convert_qwen35_state_dict(dict(sd), config) + for key in result: + assert not key.startswith(("model.visual.", "mtp.")) + + +# --------------------------------------------------------------------------- +# HuggingFace converter (bidirectional) +# --------------------------------------------------------------------------- + + +class TestQwen35HuggingFaceConverter: + def 
test_dense_round_trip(self) -> None: + config = _small_dense_config() + with torch.device("meta"): + model = create_qwen35_model(config) + fs2_keys = set(model.state_dict().keys()) + sd: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + converter = _Qwen35HuggingFaceConverter() + hg_sd = converter.to_hg_state_dict(sd, config) + rt_keys = set(convert_qwen35_state_dict(dict(hg_sd), config).keys()) + assert fs2_keys == rt_keys + + def test_to_hg_config(self) -> None: + config = _small_dense_config() + hg_config = _Qwen35HuggingFaceConverter().to_hg_config(config) + assert hg_config.kls_name == "Qwen3_5TextConfig" + assert hg_config.arch == "Qwen3_5ForCausalLM" + assert hg_config.data["hidden_size"] == config.model_dim + + def test_moe_round_trip(self) -> None: + config = _small_moe_config() + with torch.device("meta"): + model = create_qwen35_moe_model(config) + fs2_keys = set(model.state_dict().keys()) + sd: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} + + converter = _Qwen35MoeHuggingFaceConverter() + hg_sd = converter.to_hg_state_dict(sd, config) + rt_keys = set(convert_qwen35_moe_state_dict(dict(hg_sd), config).keys()) + assert fs2_keys == rt_keys + + def test_moe_to_hg_config(self) -> None: + config = _small_moe_config() + hg_config = _Qwen35MoeHuggingFaceConverter().to_hg_config(config) + assert hg_config.kls_name == "Qwen3_5TextConfig" + assert hg_config.arch == "Qwen3_5MoeForCausalLM" + assert hg_config.data["num_experts"] == config.num_experts diff --git a/tests/unit/models/qwen/test_qwen35_attention.py b/tests/unit/models/qwen/test_qwen35_attention.py deleted file mode 100644 index af9d058b5..000000000 --- a/tests/unit/models/qwen/test_qwen35_attention.py +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from __future__ import annotations - -import torch - -from fairseq2.models.qwen.attention import Qwen35Attention -from fairseq2.models.transformer.attention_bias import ( - AttentionBiasCache, - CausalAttentionBias, - IdentityBias, -) -from fairseq2.models.transformer.sdpa.naive import NaiveSDPA -from fairseq2.nn import BatchLayout, IncrementalStateBag, RMSNorm -from fairseq2.nn.position_encoder import ReferenceRotaryEncoder -from tests.common import assert_close, device - - -class TestQwen35Attention: - def test_forward_produces_correct_shape(self) -> None: - """Output shape is (B, S, model_dim).""" - sdpa = NaiveSDPA(IdentityBias()) - attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16) - attn = attn.to(device) - - seqs = torch.randn(2, 8, 64, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - out = attn(seqs, layout, seqs, layout, seqs, bias_cache) - - assert out.shape == (2, 8, 64) - - def test_output_gating_effect(self) -> None: - """When gate output is all zeros, attention output should be near zero.""" - sdpa = NaiveSDPA(IdentityBias()) - attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16) - attn = attn.to(device) - - seqs = torch.randn(1, 4, 64, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - out1 = attn(seqs, layout, seqs, layout, seqs, bias_cache) - - # Verify output is not zero (gate should be non-trivial with random weights) - assert out1.abs().mean() > 1e-6 - - def test_partial_rope_applies_to_subset_of_dims(self) -> None: - """With encoding_dim < head_dim, only first encoding_dim dims should be rotated.""" - model_dim = 64 - num_heads = 4 - head_dim = 16 - encoding_dim = 4 # Only first 4 of 16 dims rotated - - rope = ReferenceRotaryEncoder(encoding_dim, max_seq_len=32, device=device) - sdpa = NaiveSDPA(IdentityBias()) - attn = Qwen35Attention( - model_dim=model_dim, - 
num_heads=num_heads, - sdpa=sdpa, - head_dim=head_dim, - pos_encoder=rope, - ) - attn = attn.to(device) - - seqs = torch.randn(1, 4, model_dim, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - out = attn(seqs, layout, seqs, layout, seqs, bias_cache) - - assert out.shape == (1, 4, model_dim) - - def test_gqa_with_fewer_kv_heads(self) -> None: - """GQA with num_key_value_heads < num_heads works correctly.""" - sdpa = NaiveSDPA(IdentityBias()) - attn = Qwen35Attention( - model_dim=64, - num_heads=4, - sdpa=sdpa, - head_dim=16, - num_key_value_heads=2, # GQA: 4 Q heads, 2 KV heads - ) - attn = attn.to(device) - - seqs = torch.randn(2, 6, 64, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - out = attn(seqs, layout, seqs, layout, seqs, bias_cache) - - assert out.shape == (2, 6, 64) - - def test_qk_norm_applied(self) -> None: - """When q_norm and k_norm are provided, output should differ from no-norm case.""" - sdpa = NaiveSDPA(IdentityBias()) - - # Without norms - attn_no_norm = Qwen35Attention( - model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16 - ) - attn_no_norm = attn_no_norm.to(device) - - # With norms - q_norm = RMSNorm(16, bias=False, device=device) - k_norm = RMSNorm(16, bias=False, device=device) - attn_norm = Qwen35Attention( - model_dim=64, - num_heads=4, - sdpa=sdpa, - head_dim=16, - q_norm=q_norm, - k_norm=k_norm, - ) - attn_norm = attn_norm.to(device) - - # Copy weights so only the norm makes a difference - attn_norm.q_proj.weight.data.copy_(attn_no_norm.q_proj.weight.data) - attn_norm.k_proj.weight.data.copy_(attn_no_norm.k_proj.weight.data) - attn_norm.v_proj.weight.data.copy_(attn_no_norm.v_proj.weight.data) - attn_norm.output_proj.weight.data.copy_(attn_no_norm.output_proj.weight.data) - - seqs = torch.randn(1, 4, 64, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - 
out_no_norm = attn_no_norm(seqs, layout, seqs, layout, seqs, bias_cache) - out_norm = attn_norm(seqs, layout, seqs, layout, seqs, bias_cache) - - # Outputs should differ because of norm - assert not torch.allclose(out_no_norm, out_norm, atol=1e-6) - - def test_incremental_kv_cache_matches_full_forward(self) -> None: - """Token-by-token decoding with KV cache produces the same logits as causal full-sequence forward.""" - sdpa = NaiveSDPA(CausalAttentionBias()) - attn = Qwen35Attention(model_dim=64, num_heads=4, sdpa=sdpa, head_dim=16) - attn = attn.to(device) - attn.eval() - - seqs = torch.randn(1, 6, 64, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - full_out = attn(seqs, layout, seqs, layout, seqs, bias_cache) - - state_bag = IncrementalStateBag(max_num_steps=32) - - with torch.no_grad(): - for idx in range(6): - step_seqs = seqs[:, idx : idx + 1, :] - step_layout = BatchLayout.of(step_seqs) - out = attn( - step_seqs, - step_layout, - step_seqs, - step_layout, - step_seqs, - bias_cache, - state_bag=state_bag, - ) - assert_close(out, full_out[:, idx : idx + 1, :], atol=1e-5) - state_bag.increment_step_nr() diff --git a/tests/unit/models/qwen/test_qwen35_decoder_layer.py b/tests/unit/models/qwen/test_qwen35_decoder_layer.py deleted file mode 100644 index d25363301..000000000 --- a/tests/unit/models/qwen/test_qwen35_decoder_layer.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from __future__ import annotations - -import pytest -import torch - -from fairseq2.models.qwen.attention import Qwen35Attention -from fairseq2.models.qwen.config import Qwen35Config -from fairseq2.models.qwen.decoder_layer import Qwen35DecoderLayer -from fairseq2.models.qwen.factory import create_qwen35_model -from fairseq2.models.qwen.gated_delta_net import GatedDeltaNet -from fairseq2.models.transformer import GLUFeedForwardNetwork -from fairseq2.models.transformer.attention_bias import ( - AttentionBiasCache, - IdentityBias, -) -from fairseq2.models.transformer.sdpa.naive import NaiveSDPA -from fairseq2.nn import BatchLayout, RMSNorm -from tests.common import device - - -class TestQwen35DecoderLayer: - def test_full_attention_layer_forward(self) -> None: - """Full attention layer produces correct shape.""" - model_dim = 64 - sdpa = NaiveSDPA(IdentityBias()) - self_attn = Qwen35Attention(model_dim, num_heads=4, sdpa=sdpa, head_dim=16) - ffn = GLUFeedForwardNetwork(model_dim, 128, bias=False, inner_dim_scale=1.0) - layer = Qwen35DecoderLayer( - "full_attention", - self_attn=self_attn, - linear_attn=None, - ffn=ffn, - self_attn_layer_norm=RMSNorm(model_dim, bias=False), - ffn_layer_norm=RMSNorm(model_dim, bias=False), - ).to(device) - - seqs = torch.randn(2, 8, model_dim, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - out = layer(seqs, layout, bias_cache) - - assert out.shape == (2, 8, model_dim) - - def test_linear_attention_layer_forward(self) -> None: - """Linear attention (GatedDeltaNet) layer produces correct shape.""" - model_dim = 64 - gdn = GatedDeltaNet( - model_dim, num_k_heads=2, num_v_heads=4, head_k_dim=8, head_v_dim=8 - ) - ffn = GLUFeedForwardNetwork(model_dim, 128, bias=False, inner_dim_scale=1.0) - layer = Qwen35DecoderLayer( - "linear_attention", - self_attn=None, - linear_attn=gdn, - ffn=ffn, - self_attn_layer_norm=RMSNorm(model_dim, bias=False), - 
ffn_layer_norm=RMSNorm(model_dim, bias=False), - ).to(device) - - seqs = torch.randn(2, 8, model_dim, device=device) - layout = BatchLayout.of(seqs) - bias_cache = AttentionBiasCache() - - with torch.no_grad(): - out = layer(seqs, layout, bias_cache) - - assert out.shape == (2, 8, model_dim) - - def test_invalid_layer_type_raises(self) -> None: - """Invalid layer_type raises ValueError.""" - model_dim = 64 - ffn = GLUFeedForwardNetwork(model_dim, 128, bias=False, inner_dim_scale=1.0) - - with pytest.raises(ValueError, match="layer_type"): - Qwen35DecoderLayer( - "invalid_type", - self_attn=None, - linear_attn=None, - ffn=ffn, - self_attn_layer_norm=RMSNorm(model_dim, bias=False), - ffn_layer_norm=RMSNorm(model_dim, bias=False), - ) - - -class TestQwen35ModelFactory: - def test_create_small_model(self) -> None: - """Factory creates a working model with the correct output shape.""" - config = Qwen35Config( - model_dim=64, - vocab_size=128, - num_layers=4, - num_attn_heads=4, - num_key_value_heads=2, - head_dim=16, - ffn_inner_dim=128, - partial_rotary_factor=0.25, - linear_num_key_heads=2, - linear_num_value_heads=4, - linear_key_head_dim=8, - linear_value_head_dim=8, - ) - - model = create_qwen35_model(config).to(device) - model.eval() - - input_ids = torch.randint(0, 128, (1, 16), device=device) - layout = BatchLayout.of(input_ids) - - with torch.no_grad(): - logits = model(input_ids, layout) - - assert logits.shape == (1, 16, 128) - - def test_model_has_hybrid_layers(self) -> None: - """Model should have both full_attention and linear_attention layers.""" - config = Qwen35Config( - model_dim=64, - vocab_size=128, - num_layers=4, - num_attn_heads=4, - num_key_value_heads=2, - head_dim=16, - ffn_inner_dim=128, - linear_num_key_heads=2, - linear_num_value_heads=4, - linear_key_head_dim=8, - linear_value_head_dim=8, - ) - - with torch.device("meta"): - model = create_qwen35_model(config) - - layers = list(model.decoder.layers) - layer_types = [ - l.layer_type for l in 
layers if isinstance(l, Qwen35DecoderLayer) - ] - assert layer_types == [ - "linear_attention", - "linear_attention", - "linear_attention", - "full_attention", - ] diff --git a/tests/unit/models/qwen/test_qwen35_interop.py b/tests/unit/models/qwen/test_qwen35_interop.py deleted file mode 100644 index f41c236d3..000000000 --- a/tests/unit/models/qwen/test_qwen35_interop.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unit tests for the Qwen 3.5 HuggingFace state-dict interop.""" - -from __future__ import annotations - -import torch -from torch.testing import assert_close - -from fairseq2.models.qwen.config import Qwen35Config, Qwen35MoeConfig -from fairseq2.models.qwen.factory import create_qwen35_model, create_qwen35_moe_model -from fairseq2.models.qwen.interop import ( - _QWEN35_HG_KEY_MAP, - _QWEN35_RMSNORM_KEYS, - _QWEN35_TEXT_KEY_MAP, - _Qwen35HuggingFaceConverter, - _Qwen35MoeHuggingFaceConverter, - convert_qwen35_moe_state_dict, - convert_qwen35_state_dict, -) -from fairseq2.models.utils.checkpoint import convert_state_dict, create_reverse_key_map - - -class TestQwen35Interop: - def _make_small_config(self) -> Qwen35Config: - """Create a tiny config for fast testing.""" - config = Qwen35Config() - config.model_dim = 64 - config.vocab_size = 128 - config.num_layers = 4 # 3 linear + 1 full attention - config.num_attn_heads = 4 - config.num_key_value_heads = 2 - config.head_dim = 16 - config.ffn_inner_dim = 128 - config.partial_rotary_factor = 0.25 - config.linear_num_key_heads = 2 - config.linear_num_value_heads = 4 - config.linear_key_head_dim = 8 - config.linear_value_head_dim = 8 - config.layer_types = None # Reset so __post_init__ regenerates for num_layers=4 - config.__post_init__() - return config - - def test_state_dict_key_round_trip(self) -> None: - 
"""fs2 keys -> HF keys -> fs2 keys should be identity.""" - config = self._make_small_config() - - with torch.device("meta"): - model = create_qwen35_model(config) - - fs2_keys = set(model.state_dict().keys()) - assert len(fs2_keys) > 0 - - fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} - - # Convert to HF format using reverse key map - reverse_map = create_reverse_key_map(_QWEN35_HG_KEY_MAP) - hg_state_dict = convert_state_dict(fs2_state_dict, reverse_map) - - # Verify HF keys have expected prefixes - for key in hg_state_dict: - assert key.startswith( - ("model.", "lm_head.") - ), f"Unexpected HF key prefix: {key}" - - # Convert back to fs2 format - rt_state_dict = convert_state_dict(dict(hg_state_dict), _QWEN35_HG_KEY_MAP) - rt_keys = set(rt_state_dict.keys()) - - assert fs2_keys == rt_keys, ( - f"Round-trip key mismatch.\n" - f" Missing in round-trip: {fs2_keys - rt_keys}\n" - f" Extra in round-trip: {rt_keys - fs2_keys}" - ) - - def test_rmsnorm_weight_conversion(self) -> None: - """RMSNorm weights get +1.0 added during conversion.""" - config = self._make_small_config() - - # Simulate HF state dict with zero-init RMSNorm weights - hf_state_dict: dict[str, object] = {} - for i in range(config.num_layers): - hf_state_dict[f"model.layers.{i}.input_layernorm.weight"] = torch.zeros( - config.model_dim - ) - hf_state_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = ( - torch.zeros(config.model_dim) - ) - hf_state_dict["model.norm.weight"] = torch.zeros(config.model_dim) - hf_state_dict["model.embed_tokens.weight"] = torch.zeros( - config.vocab_size, config.model_dim - ) - hf_state_dict["lm_head.weight"] = torch.zeros( - config.vocab_size, config.model_dim - ) - - converted = convert_qwen35_state_dict(dict(hf_state_dict), config) - - # All layer norm weights should now be 1.0 (0.0 + 1.0) - for key in converted: - if any(key.endswith(s) for s in _QWEN35_RMSNORM_KEYS): - weight = converted[key] - assert isinstance(weight, 
torch.Tensor) - assert_close(weight, torch.ones_like(weight)) - - def test_gdn_norm_weight_not_converted(self) -> None: - """GatedDeltaNet internal norm weights should NOT get +1.0.""" - config = self._make_small_config() - - # Simulate HF state dict with GDN norm weight - hf_state_dict: dict[str, object] = {"model.embed_tokens.weight": torch.zeros(1)} - hf_state_dict["model.layers.0.linear_attn.norm.weight"] = ( - torch.ones(config.linear_value_head_dim) * 0.5 - ) - - converted = convert_qwen35_state_dict(dict(hf_state_dict), config) - - # The GDN norm maps to linear_attn.norm.inner_norm.weight - gdn_key = "decoder.layers.0.linear_attn.norm.inner_norm.weight" - if gdn_key in converted: - # Should still be 0.5, NOT 1.5 - assert_close( - converted[gdn_key], - torch.ones(config.linear_value_head_dim) * 0.5, - ) - - def test_tied_embeddings_hf_no_lm_head(self) -> None: - """HF checkpoint with tied_embeddings has no lm_head.weight. - - Safetensors deduplicates shared tensors, so for models with - tie_word_embeddings=True the checkpoint only contains - model.embed_tokens.weight. The converter must create - final_proj.weight from it. 
- """ - config = self._make_small_config() - config.tied_embeddings = True - - weight = torch.randn(config.vocab_size, config.model_dim) - hf_state_dict: dict[str, object] = { - "model.embed_tokens.weight": weight, - "model.norm.weight": torch.zeros(config.model_dim), - } - - result = convert_qwen35_state_dict(dict(hf_state_dict), config) - - assert "decoder_frontend.embed.weight" in result - assert "final_proj.weight" in result - assert result["final_proj.weight"] is result["decoder_frontend.embed.weight"] - - def test_layer_types_are_correct(self) -> None: - """Verify layer_types pattern: 3 linear, 1 full, repeating.""" - config = self._make_small_config() - assert config.layer_types == [ - "linear_attention", - "linear_attention", - "linear_attention", - "full_attention", - ] - - -class TestQwen35HuggingFaceConverter: - """Tests for _Qwen35HuggingFaceConverter.""" - - def _make_small_config(self) -> Qwen35Config: - config = Qwen35Config() - config.model_dim = 64 - config.vocab_size = 128 - config.num_layers = 4 - config.num_attn_heads = 4 - config.num_key_value_heads = 2 - config.head_dim = 16 - config.ffn_inner_dim = 128 - config.partial_rotary_factor = 0.25 - config.linear_num_key_heads = 2 - config.linear_num_value_heads = 4 - config.linear_key_head_dim = 8 - config.linear_value_head_dim = 8 - config.layer_types = None - config.__post_init__() - return config - - def test_to_hg_config(self) -> None: - """to_hg_config maps Qwen35Config fields to HF config dict.""" - config = self._make_small_config() - converter = _Qwen35HuggingFaceConverter() - hg_config = converter.to_hg_config(config) - - assert hg_config.kls_name == "Qwen3_5TextConfig" - assert hg_config.arch == "Qwen3_5ForCausalLM" - - data = hg_config.data - assert data["hidden_size"] == config.model_dim - assert data["max_position_embeddings"] == config.max_seq_len - assert data["vocab_size"] == config.vocab_size - assert data["tie_word_embeddings"] == config.tied_embeddings - assert 
data["num_hidden_layers"] == config.num_layers - assert data["num_attention_heads"] == config.num_attn_heads - assert data["num_key_value_heads"] == config.num_key_value_heads - assert data["head_dim"] == config.head_dim - assert data["intermediate_size"] == config.ffn_inner_dim - assert data["partial_rotary_factor"] == config.partial_rotary_factor - assert data["rope_theta"] == config.rope_theta - assert data["full_attention_interval"] == config.full_attention_interval - assert data["linear_conv_kernel_dim"] == config.linear_conv_kernel_dim - assert data["linear_key_head_dim"] == config.linear_key_head_dim - assert data["linear_value_head_dim"] == config.linear_value_head_dim - assert data["linear_num_key_heads"] == config.linear_num_key_heads - assert data["linear_num_value_heads"] == config.linear_num_value_heads - - def test_state_dict_round_trip(self) -> None: - """State dict keys survive a fs2 -> HF -> fs2 round trip.""" - config = self._make_small_config() - - with torch.device("meta"): - model = create_qwen35_model(config) - - fs2_keys = set(model.state_dict().keys()) - assert len(fs2_keys) > 0 - - fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} - - converter = _Qwen35HuggingFaceConverter() - hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) - hg_keys = set(hg_state_dict.keys()) - - for key in hg_keys: - assert key.startswith( - ("model.", "lm_head.") - ), f"Unexpected HF key prefix: {key}" - - # Round-trip: HF -> fs2 - rt_state_dict = convert_qwen35_state_dict(dict(hg_state_dict), config) - rt_keys = set(rt_state_dict.keys()) - - assert fs2_keys == rt_keys, ( - f"Round-trip key mismatch.\n" - f" Missing in round-trip: {fs2_keys - rt_keys}\n" - f" Extra in round-trip: {rt_keys - fs2_keys}" - ) - - def test_rmsnorm_weight_reversed(self) -> None: - """to_hg_state_dict subtracts 1.0 from RMSNorm weights.""" - config = self._make_small_config() - - # Build a fs2 state dict with RMSNorm weights = 1.0 (standard init) - 
fs2_state_dict: dict[str, object] = {} - for i in range(config.num_layers): - fs2_state_dict[f"decoder.layers.{i}.self_attn_layer_norm.weight"] = ( - torch.ones(config.model_dim) - ) - fs2_state_dict[f"decoder.layers.{i}.ffn_layer_norm.weight"] = torch.ones( - config.model_dim - ) - fs2_state_dict["decoder.layer_norm.weight"] = torch.ones(config.model_dim) - - converter = _Qwen35HuggingFaceConverter() - hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) - - # HF weights should be 0.0 (1.0 - 1.0) - for key in hg_state_dict: - if key.endswith( - ( - "input_layernorm.weight", - "post_attention_layernorm.weight", - "model.norm.weight", - ) - ): - weight = hg_state_dict[key] - assert isinstance(weight, torch.Tensor) - assert_close(weight, torch.zeros_like(weight)) - - def test_tied_embeddings_removes_lm_head(self) -> None: - """to_hg_state_dict removes lm_head.weight when tied_embeddings=True.""" - config = self._make_small_config() - config.tied_embeddings = True - - with torch.device("meta"): - model = create_qwen35_model(config) - - fs2_state_dict: dict[str, object] = { - k: torch.empty(0) for k in model.state_dict().keys() - } - - converter = _Qwen35HuggingFaceConverter() - hg_state_dict = converter.to_hg_state_dict(fs2_state_dict, config) - - assert "lm_head.weight" not in hg_state_dict - assert "model.embed_tokens.weight" in hg_state_dict - - def test_tied_embeddings_deduped_final_proj_only(self) -> None: - """When safetensors deduplicates tied weights and only final_proj.weight - survives, convert_qwen35_state_dict should still reconstruct both keys.""" - config = self._make_small_config() - config.tied_embeddings = True - - weight = torch.randn(config.vocab_size, config.model_dim) - state_dict: dict[str, object] = {"final_proj.weight": weight} - - result = convert_qwen35_state_dict(dict(state_dict), config) - - assert "decoder_frontend.embed.weight" in result - assert result["decoder_frontend.embed.weight"] is weight - - -class 
TestQwen35MoeHuggingFaceConverter: - """Tests for _Qwen35MoeHuggingFaceConverter.""" - - def _make_small_moe_config(self) -> Qwen35MoeConfig: - config = Qwen35MoeConfig() - config.model_dim = 64 - config.vocab_size = 128 - config.num_layers = 4 - config.num_attn_heads = 4 - config.num_key_value_heads = 2 - config.head_dim = 16 - config.ffn_inner_dim = 128 - config.partial_rotary_factor = 0.25 - config.linear_num_key_heads = 2 - config.linear_num_value_heads = 4 - config.linear_key_head_dim = 8 - config.linear_value_head_dim = 8 - config.num_experts = 4 - config.num_experts_per_tok = 2 - config.moe_intermediate_size = 32 - config.shared_expert_intermediate_size = 32 - config.layer_types = None - config.__post_init__() - return config - - def test_to_hg_config(self) -> None: - """to_hg_config maps Qwen35MoeConfig fields including MoE-specific ones.""" - config = self._make_small_moe_config() - converter = _Qwen35MoeHuggingFaceConverter() - hg_config = converter.to_hg_config(config) - - assert hg_config.kls_name == "Qwen3_5TextConfig" - assert hg_config.arch == "Qwen3_5MoeForCausalLM" - - data = hg_config.data - assert data["hidden_size"] == config.model_dim - assert data["num_experts"] == config.num_experts - assert data["num_experts_per_tok"] == config.num_experts_per_tok - assert data["moe_intermediate_size"] == config.moe_intermediate_size - assert ( - data["shared_expert_intermediate_size"] - == config.shared_expert_intermediate_size - ) - assert data["router_aux_loss_coef"] == config.router_aux_loss_coef - - def test_state_dict_round_trip(self) -> None: - """MoE state dict keys survive a fs2 -> HF -> fs2 round trip.""" - config = self._make_small_moe_config() - - with torch.device("meta"): - model = create_qwen35_moe_model(config) - - fs2_keys = set(model.state_dict().keys()) - assert len(fs2_keys) > 0 - - fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in fs2_keys} - - converter = _Qwen35MoeHuggingFaceConverter() - hg_state_dict = 
converter.to_hg_state_dict(fs2_state_dict, config) - hg_keys = set(hg_state_dict.keys()) - - for key in hg_keys: - assert key.startswith( - ("model.", "lm_head.") - ), f"Unexpected HF key prefix: {key}" - - # Round-trip: HF -> fs2 - rt_state_dict = convert_qwen35_moe_state_dict(dict(hg_state_dict), config) - rt_keys = set(rt_state_dict.keys()) - - assert fs2_keys == rt_keys, ( - f"Round-trip key mismatch.\n" - f" Missing in round-trip: {fs2_keys - rt_keys}\n" - f" Extra in round-trip: {rt_keys - fs2_keys}" - ) - - -class TestVlCheckpointHandling: - """Tests for multimodal (VL) checkpoint handling. - - Qwen 3.5 checkpoints on HuggingFace Hub are multimodal models where the - text decoder lives under ``model.language_model.*`` with additional - ``model.visual.*`` and ``mtp.*`` keys. The converter handles both formats - via ``_expand_with_language_model_prefix`` and explicit filtering. - """ - - def _make_small_config(self) -> Qwen35Config: - config = Qwen35Config() - config.model_dim = 64 - config.vocab_size = 128 - config.num_layers = 4 - config.num_attn_heads = 4 - config.num_key_value_heads = 2 - config.head_dim = 16 - config.ffn_inner_dim = 128 - config.partial_rotary_factor = 0.25 - config.linear_num_key_heads = 2 - config.linear_num_value_heads = 4 - config.linear_key_head_dim = 8 - config.linear_value_head_dim = 8 - config.tied_embeddings = True - config.layer_types = None - config.__post_init__() - return config - - def test_key_map_has_language_model_variants(self) -> None: - """_QWEN35_HG_KEY_MAP includes both model.* and model.language_model.* patterns.""" - text_only_count = len(_QWEN35_TEXT_KEY_MAP) - full_count = len(_QWEN35_HG_KEY_MAP) - # model.* patterns get duplicated; lm_head.* does not - model_prefix_count = sum( - 1 for k in _QWEN35_TEXT_KEY_MAP if k.startswith(r"^model\.") - ) - assert full_count == text_only_count + model_prefix_count - - def test_language_model_prefix_keys_convert(self) -> None: - """model.language_model.X keys are correctly 
converted to fs2 keys.""" - state_dict: dict[str, object] = { - "model.language_model.embed_tokens.weight": torch.empty(0), - "model.language_model.layers.0.input_layernorm.weight": torch.empty(0), - "model.language_model.norm.weight": torch.empty(0), - } - result = convert_state_dict(state_dict, _QWEN35_HG_KEY_MAP) - assert "decoder_frontend.embed.weight" in result - assert "decoder.layers.0.self_attn_layer_norm.weight" in result - assert "decoder.layer_norm.weight" in result - - def test_visual_and_mtp_keys_filtered(self) -> None: - """model.visual.* and mtp.* keys are filtered by convert_qwen35_state_dict.""" - config = self._make_small_config() - state_dict: dict[str, object] = { - "model.language_model.embed_tokens.weight": torch.randn( - config.vocab_size, config.model_dim - ), - "model.language_model.norm.weight": torch.zeros(config.model_dim), - "model.visual.blocks.0.attn.proj.weight": torch.empty(0), - "model.visual.patch_embed.proj.weight": torch.empty(0), - "mtp.fc.weight": torch.empty(0), - "mtp.layers.0.mlp.gate_proj.weight": torch.empty(0), - } - result = convert_qwen35_state_dict(dict(state_dict), config) - for key in result: - assert not key.startswith( - ("model.visual.", "mtp.") - ), f"Unexpected key not filtered: {key}" - - def test_text_only_format_still_works(self) -> None: - """model.layers.* (text-only format) is still handled correctly.""" - config = self._make_small_config() - state_dict: dict[str, object] = { - "model.embed_tokens.weight": torch.randn( - config.vocab_size, config.model_dim - ), - "model.norm.weight": torch.zeros(config.model_dim), - } - result = convert_qwen35_state_dict(dict(state_dict), config) - assert "decoder_frontend.embed.weight" in result - assert "decoder.layer_norm.weight" in result - - def test_end_to_end_vl_checkpoint(self) -> None: - """Full VL checkpoint → convert_qwen35_state_dict produces correct keys.""" - config = self._make_small_config() - - with torch.device("meta"): - model = 
create_qwen35_model(config) - model_keys = set(model.state_dict().keys()) - - # Build a text-only HF state dict, then add VL prefix + extra modalities - reverse_map = create_reverse_key_map(_QWEN35_TEXT_KEY_MAP) - fs2_state_dict: dict[str, object] = {k: torch.empty(0) for k in model_keys} - hg_state_dict = convert_state_dict(fs2_state_dict, reverse_map) - - # Add model.language_model. prefix (simulating VL checkpoint) - vl_state_dict: dict[str, object] = {} - for k, v in hg_state_dict.items(): - if k.startswith("model."): - vl_state_dict["model.language_model." + k[len("model.") :]] = v - else: - vl_state_dict[k] = v - # Add visual/mtp keys - vl_state_dict["model.visual.blocks.0.attn.proj.weight"] = torch.empty(0) - vl_state_dict["mtp.fc.weight"] = torch.empty(0) - - # Convert back — should match model keys - result = convert_qwen35_state_dict(dict(vl_state_dict), config) - result_keys = set(result.keys()) - - assert model_keys == result_keys, ( - f"VL round-trip key mismatch.\n" - f" Missing: {model_keys - result_keys}\n" - f" Extra: {result_keys - model_keys}" - ) diff --git a/tests/unit/models/qwen/test_qwen35_moe.py b/tests/unit/models/qwen/test_qwen35_moe.py deleted file mode 100644 index c0d06b953..000000000 --- a/tests/unit/models/qwen/test_qwen35_moe.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from __future__ import annotations - -import torch - -from fairseq2.models.qwen.moe import Qwen35Experts, Qwen35MoeBlock, Qwen35TopKRouter -from tests.common import assert_close, device - - -class TestQwen35TopKRouter: - def test_forward_output_shapes(self) -> None: - """Router returns correct shapes: logits(T,E), weights(T,K), indices(T,K).""" - router = Qwen35TopKRouter(num_experts=8, top_k=2, model_dim=32).to(device) - - x = torch.randn(10, 32, device=device) - - with torch.no_grad(): - logits, weights, indices = router(x) - - assert logits.shape == (10, 8) # (T, E) - assert weights.shape == (10, 2) # (T, K) - assert indices.shape == (10, 2) # (T, K) - - def test_weights_sum_to_one(self) -> None: - """Renormalized top-k weights sum to 1 per token.""" - router = Qwen35TopKRouter(num_experts=8, top_k=2, model_dim=32).to(device) - - x = torch.randn(10, 32, device=device) - - with torch.no_grad(): - _, weights, _ = router(x) - - sums = weights.sum(dim=-1) - assert_close(sums, torch.ones(10, device=device), atol=1e-5) - - def test_logits_are_raw_pre_softmax(self) -> None: - """Router logits are raw pre-softmax values (NOT a probability distribution).""" - router = Qwen35TopKRouter(num_experts=8, top_k=2, model_dim=32).to(device) - # Initialize with non-zero weights so logits are non-trivial. - torch.nn.init.normal_(router.weight, std=0.1) - - x = torch.randn(10, 32, device=device) - - with torch.no_grad(): - logits, weights, _ = router(x) - - # Raw logits can be negative and do NOT sum to 1. - assert logits.shape == (10, 8) - sums = logits.sum(dim=-1) - assert not torch.allclose( - sums, torch.ones(10, device=device), atol=1e-3 - ), "Raw logits should NOT sum to 1 (they are not softmax)" - - # But the top-k weights DO sum to 1 (renormalized). 
- w_sums = weights.sum(dim=-1) - assert_close(w_sums, torch.ones(10, device=device), atol=1e-5) - - -class TestQwen35Experts: - def test_forward_output_shape(self) -> None: - """Experts output shape matches input shape (T, D).""" - experts = Qwen35Experts(num_experts=4, model_dim=32, expert_inner_dim=16).to( - device - ) - torch.nn.init.normal_(experts.gate_up_proj, std=0.01) - torch.nn.init.normal_(experts.down_proj, std=0.01) - - T = 6 - x = torch.randn(T, 32, device=device) - indices = torch.tensor( - [[0, 1], [1, 2], [2, 3], [0, 3], [1, 0], [3, 2]], device=device - ) - weights = torch.ones(T, 2, device=device) * 0.5 - - with torch.no_grad(): - out = experts(x, indices, weights) - - assert out.shape == (T, 32) - - def test_weighted_output(self) -> None: - """Output is weighted by routing weights — zero weight means no contribution.""" - experts = Qwen35Experts(num_experts=4, model_dim=16, expert_inner_dim=8).to( - device - ) - torch.nn.init.normal_(experts.gate_up_proj, std=0.01) - torch.nn.init.normal_(experts.down_proj, std=0.01) - - T = 4 - x = torch.randn(T, 16, device=device) - indices = torch.zeros(T, 2, dtype=torch.long, device=device) - weights_nonzero = torch.ones(T, 2, device=device) * 0.5 - weights_zero = torch.zeros(T, 2, device=device) - - with torch.no_grad(): - out_nonzero = experts(x, indices, weights_nonzero) - out_zero = experts(x, indices, weights_zero) - - assert_close(out_zero, torch.zeros_like(out_zero), atol=1e-6) - assert out_nonzero.abs().mean() > 1e-6 - - -class TestQwen35MoeBlock: - def test_forward_output_shape(self) -> None: - """MoeBlock output shape matches input (B, S, D).""" - moe = Qwen35MoeBlock( - model_dim=32, - num_experts=4, - num_experts_per_tok=2, - moe_intermediate_size=16, - shared_expert_intermediate_size=16, - ).to(device) - - seqs = torch.randn(2, 8, 32, device=device) - - with torch.no_grad(): - out = moe(seqs) - - assert out.shape == (2, 8, 32) - - def test_shared_expert_contributes(self) -> None: - """Shared 
expert output is non-zero (sigmoid gate blending).""" - moe = Qwen35MoeBlock( - model_dim=32, - num_experts=4, - num_experts_per_tok=2, - moe_intermediate_size=16, - shared_expert_intermediate_size=16, - ).to(device) - - seqs = torch.randn(1, 4, 32, device=device) - - with torch.no_grad(): - out = moe(seqs) - - assert out.abs().mean() > 1e-6 - - def test_drop_in_ffn_replacement(self) -> None: - """MoeBlock inherits FeedForwardNetwork and can be used as drop-in.""" - from fairseq2.models.transformer import FeedForwardNetwork - - moe = Qwen35MoeBlock( - model_dim=32, - num_experts=4, - num_experts_per_tok=2, - moe_intermediate_size=16, - shared_expert_intermediate_size=16, - ) - - assert isinstance(moe, FeedForwardNetwork) From 87f90fe29a39e0b746c152e2a260de36f08dfe1e Mon Sep 17 00:00:00 2001 From: yunchaoyang1 user Date: Wed, 27 May 2026 21:21:40 +0000 Subject: [PATCH 4/4] Fix mypy error: cast .item() to int for tokenizer.decode --- tests/integration/models/test_qwen35.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/models/test_qwen35.py b/tests/integration/models/test_qwen35.py index e7c2f289c..390477695 100644 --- a/tests/integration/models/test_qwen35.py +++ b/tests/integration/models/test_qwen35.py @@ -188,13 +188,13 @@ def test_logit_parity(self) -> None: print(f" Full-seq logit max abs diff: {full_max_diff:.6e}") print(f" Full-seq logit mean abs diff: {full_mean_diff:.6e}") - hf_top1 = hf_last.argmax().item() - fs2_top1 = fs2_last.argmax().item() + hf_top1 = int(hf_last.argmax().item()) + fs2_top1 = int(fs2_last.argmax().item()) print(f"\n HF top-1 token: {hf_top1} -> '{hf_tokenizer.decode([hf_top1])}'") print(f" fs2 top-1 token: {fs2_top1} -> '{hf_tokenizer.decode([fs2_top1])}'") - hf_top5 = hf_last.topk(5).indices.tolist() - fs2_top5 = fs2_last.topk(5).indices.tolist() + hf_top5: list[int] = [int(t) for t in hf_last.topk(5).indices.tolist()] + fs2_top5: list[int] = [int(t) for t in 
fs2_last.topk(5).indices.tolist()] print( f"\n HF top-5: {hf_top5} -> {[hf_tokenizer.decode([t]) for t in hf_top5]}" )