From 4e2df75448d01d28ae6a58906b38371ddad98a9d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 17:42:35 +0800 Subject: [PATCH 01/52] update --- src/mcore_bridge/config/model_config.py | 5 +++-- src/mcore_bridge/model/constant.py | 1 + src/mcore_bridge/model/mm_gpts/__init__.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 943c100..5eec352 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -3,13 +3,13 @@ import os import re import torch.nn.functional as F -from dataclasses import dataclass +from dataclasses import dataclass, field from megatron.core import mpu from megatron.core.transformer import TransformerConfig from transformers import PretrainedConfig from transformers.utils import is_torch_npu_available from transformers.utils.versions import require_version -from typing import List, Literal, Optional, Union +from typing import Any, Dict, List, Literal, Optional, Union from mcore_bridge.utils import get_logger, json_parse_to_dict @@ -229,6 +229,7 @@ class ModelConfig(TransformerConfig): task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm' num_labels: Optional[int] = None mlp_padding_free: bool = False + model_kwargs: Dict[str, Any] = field(default_factory=dict) _mindspeed_defaults_cache = None diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py index 6c61f09..36f9d81 100644 --- a/src/mcore_bridge/model/constant.py +++ b/src/mcore_bridge/model/constant.py @@ -27,6 +27,7 @@ class MLLMModelType: glm4v_moe = 'glm4v_moe' kimi_vl = 'kimi_vl' llama4 = 'llama4' + gemma4 = 'gemma4' kimi_k25 = 'kimi_k25' diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py index d13e4e7..b8ea385 100644 --- a/src/mcore_bridge/model/mm_gpts/__init__.py +++ b/src/mcore_bridge/model/mm_gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl +from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl From c388954d3a86fa9764226fed3b886df0af9497af Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 17:49:36 +0800 Subject: [PATCH 02/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 63 ++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/mcore_bridge/model/mm_gpts/gemma4.py diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py new file mode 100644 index 0000000..9b99428 --- /dev/null +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -0,0 +1,63 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +from transformers import AutoModel, PretrainedConfig + +from mcore_bridge.bridge import GPTBridge + +from ..constant import ModelType +from ..register import ModelLoader, ModelMeta, register_model +from .utils import HuggingFaceVit + + +class Gemma4Vit(HuggingFaceVit): + module_mapping = { + 'model.vision_tower': 'vision_tower', + 'model.embed_vision': 'embed_vision', + 'model.audio_tower': 'audio_tower', + 'model.embed_audio': 'embed_audio', + } + _vision_tower = ['vision_tower', 'audio_tower'] + _aligner = ['embed_vision', 'embed_audio'] + support_multimodal = False + + def prepare_model(self, hf_config: PretrainedConfig): + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + self.vision_tower = AutoModel.from_config(hf_config.vision_config) + self.vocab_size = hf_config.text_config.vocab_size + + language_model = AutoModel.from_config(config=hf_config.text_config) + self.language_model = language_model + self.vocab_size_per_layer_input = hf_config.text_config.vocab_size_per_layer_input + self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None + self.embed_vision = ( + Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) + if hf_config.vision_config is not None else None) + self.embed_audio = ( + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) + if hf_config.audio_config is not None else None) + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + return inputs_embeds + + +class Gemma4Bridge(GPTBridge): + pass + + +class Gemma4Loader(ModelLoader): + pass + # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + # layer_specs = get_gpt_decoder_block_spec( + # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + # for layer_spec in layer_specs.layer_specs: + # pass + # return layer_specs + + +register_model( + ModelMeta( + ModelType.gemma4, + ['gemma4'], + bridge_cls=Gemma4Bridge, + visual_cls=Gemma4Vit, + loader=Gemma4Loader, + )) From 76af2bcebdf2c1011951a219a4bcda87103185ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 19:19:39 +0800 Subject: [PATCH 03/52] update --- src/mcore_bridge/model/gpt_model.py | 65 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 5fb714c..e34c60a 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -110,9 +110,7 @@ def __init__( for i in range(len(self.decoder.layers)): if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb - self.attention_scaling = 1. - new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) - self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + self._set_inv_freq() if self.config.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, @@ -222,7 +220,36 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) + rotary_pos_emb, rotary_pos_cos, decoder_rotary_pos_emb, rotary_pos_sin = self._get_rotary_pos_emb( + decoder_input, position_ids, packed_seq_params=packed_seq_params) + + if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') + or self.config.flash_decode) and rotary_pos_cos is not None + and inference_context.is_static_batching()): + current_batch_size = input_ids.shape[0] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # inference. Skip wrapping if decoder_input is logged after decoder completion. + if in_inference_mode and not has_config_logger_enabled(self.config): + decoder_input = WrappedTensor(decoder_input) + return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, + sequence_len_offset) + + def _set_inv_freq(self): + self.attention_scaling = 1. + new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config) + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None @@ -257,26 +284,13 @@ def _preprocess( rotary_seq_len, packed_seq=packed_seq, ) + decoder_rotary_pos_emb = rotary_pos_emb + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] - if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') - or self.config.flash_decode) and rotary_pos_cos is not None - and inference_context.is_static_batching()): - current_batch_size = input_ids.shape[0] - sequence_len_offset = torch.tensor( - [inference_context.sequence_len_offset] * current_batch_size, - dtype=torch.int32, - device=rotary_pos_cos.device, # Co-locate this with the rotary tensors - ) - else: - sequence_len_offset = None - - # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the - # reference held by this caller function, enabling early garbage collection for - # inference. Skip wrapping if decoder_input is logged after decoder completion. - if in_inference_mode and not has_config_logger_enabled(self.config): - decoder_input = WrappedTensor(decoder_input) - - return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -308,7 +322,7 @@ def forward( inference_context = deprecate_inference_params(inference_context, inference_params) - decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( + decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( self._preprocess( input_ids=input_ids, position_ids=position_ids, @@ -316,11 +330,6 @@ def forward( inference_context=inference_context, packed_seq_params=packed_seq_params, )) - decoder_rotary_pos_emb = rotary_pos_emb - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] mtp_decoder_input = decoder_input if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None: From 54e33435585c2bdd1f5045c162b238e98f0565ba Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 29 Apr 2026 21:47:11 +0800 Subject: [PATCH 04/52] update --- src/mcore_bridge/tuners/patcher.py | 39 +++--------------------------- src/mcore_bridge/utils/__init__.py | 2 +- src/mcore_bridge/utils/utils.py | 36 +++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 37 deletions(-) diff --git a/src/mcore_bridge/tuners/patcher.py b/src/mcore_bridge/tuners/patcher.py index e715c35..f9cae8c 100644 --- a/src/mcore_bridge/tuners/patcher.py +++ b/src/mcore_bridge/tuners/patcher.py @@ -1,6 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -import copy -from contextlib import contextmanager from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.moe.router import TopKRouter @@ -11,6 +9,8 @@ from torch import nn from typing import Optional +from mcore_bridge.utils import patch_deepcopy + from .lora import LoraParallelLinear @@ -37,39 +37,6 @@ def dispatch_megatron( model.dispatch_megatron = dispatch_megatron -@contextmanager -def _patch_deepcopy(): - _origin_deepcopy = copy.deepcopy - copy_keys = ('tp_group', '_tp_group', 'config') - - def new_deepcopy(x, *args, **kwargs): - if not isinstance(x, nn.Module): - return _origin_deepcopy(x, *args, **kwargs) - - saved_values = {} - for key in copy_keys: - val = getattr(x, key, None) - if val is not None: - saved_values[key] = val - setattr(x, key, None) - - try: - res = _origin_deepcopy(x, *args, **kwargs) - finally: - for key, value in saved_values.items(): - setattr(x, key, value) - - for key, value in saved_values.items(): - setattr(res, key, value) - return res - - copy.deepcopy = new_deepcopy - try: - yield - finally: - copy.deepcopy = _origin_deepcopy - - def _patch_lora_model(): if hasattr(LoraModel, '_mcore_patched'): return @@ -77,7 +44,7 @@ def _patch_lora_model(): __origin_init__ = LoraModel.__init__ def __new_init__(self, *args, **kwargs): - with _patch_deepcopy(): + with patch_deepcopy(): __origin_init__(self, *args, **kwargs) if not isinstance(self.model, MegatronModule): return diff --git a/src/mcore_bridge/utils/__init__.py b/src/mcore_bridge/utils/__init__.py index 34cdbd1..d4285be 100644 --- a/src/mcore_bridge/utils/__init__.py +++ b/src/mcore_bridge/utils/__init__.py @@ -6,4 +6,4 @@ from .megatron_utils import get_local_layer_specs, set_random_seed, split_cp_inputs, unwrap_model from .safetensors import SafetensorLazyLoader, StreamingSafetensorSaver from .torch_utils import gc_collect, get_current_device, safe_ddp_context, to_device -from .utils import deep_getattr, get_env_args, json_parse_to_dict +from .utils import deep_getattr, get_env_args, json_parse_to_dict, patch_deepcopy diff --git a/src/mcore_bridge/utils/utils.py b/src/mcore_bridge/utils/utils.py index 6905c41..a7e525b 100644 --- a/src/mcore_bridge/utils/utils.py +++ b/src/mcore_bridge/utils/utils.py @@ -1,5 +1,8 @@ +import copy import json import os +from contextlib import contextmanager +from torch import nn from transformers.utils import strtobool from typing import Callable, Dict, Optional, TypeVar, Union @@ -58,3 +61,36 @@ def deep_getattr(obj, attr: str, default=None): else: obj = getattr(obj, a, default) return obj + + +@contextmanager +def patch_deepcopy(): + _origin_deepcopy = copy.deepcopy + copy_keys = ('tp_group', '_tp_group', 'config') + + def new_deepcopy(x, *args, **kwargs): + if not isinstance(x, nn.Module): + return _origin_deepcopy(x, *args, **kwargs) + + saved_values = {} + for key in copy_keys: + val = getattr(x, key, None) + if val is not None: + saved_values[key] = val + setattr(x, key, None) + + try: + res = _origin_deepcopy(x, *args, **kwargs) + finally: + for key, value in saved_values.items(): + setattr(x, key, value) + + for key, value in saved_values.items(): + setattr(res, key, value) + return res + + copy.deepcopy = new_deepcopy + try: + yield + finally: + copy.deepcopy = _origin_deepcopy From 25a45bda3aa7dde0b9125c33614f94bdc0c4c18e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Apr 2026 10:13:18 +0800 Subject: [PATCH 05/52] update --- src/mcore_bridge/model/gpt_model.py | 2 +- src/mcore_bridge/model/mm_gpt_model.py | 4 ++- src/mcore_bridge/model/mm_gpts/gemma4.py | 41 ++++++++++++++++++------ src/mcore_bridge/model/rope.py | 4 +-- 4 files changed, 38 insertions(+), 13 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index e34c60a..ace31e4 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -220,7 +220,7 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) - rotary_pos_emb, rotary_pos_cos, decoder_rotary_pos_emb, rotary_pos_sin = self._get_rotary_pos_emb( + rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( decoder_input, position_ids, packed_seq_params=packed_seq_params) if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py index b68e82b..b3fc0d6 100644 --- a/src/mcore_bridge/model/mm_gpt_model.py +++ b/src/mcore_bridge/model/mm_gpt_model.py @@ -18,6 +18,7 @@ class MultimodalGPTModel(MegatronModule): + language_model_cls = GPTModel def __init__(self, config: ModelConfig, @@ -29,7 +30,8 @@ def __init__(self, super().__init__(config) self.pre_process = pre_process self.post_process = post_process - self.language_model = GPTModel(config, transformer_layer_spec, pre_process, post_process, *_args, **kwargs) + self.language_model = self.language_model_cls(config, transformer_layer_spec, pre_process, post_process, *_args, + **kwargs) self.vp_stage = self.language_model.vp_stage self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights self.model_meta = config.model_meta diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 9b99428..2d656dd 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,10 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import copy from transformers import AutoModel, PretrainedConfig from mcore_bridge.bridge import GPTBridge from ..constant import ModelType +from ..gpt_model import GPTModel +from ..mm_gpt_model import MultimodalGPTModel from ..register import ModelLoader, ModelMeta, register_model +from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit @@ -22,15 +26,8 @@ class Gemma4Vit(HuggingFaceVit): def prepare_model(self, hf_config: PretrainedConfig): from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder self.vision_tower = AutoModel.from_config(hf_config.vision_config) - self.vocab_size = hf_config.text_config.vocab_size - - language_model = AutoModel.from_config(config=hf_config.text_config) - self.language_model = language_model - self.vocab_size_per_layer_input = hf_config.text_config.vocab_size_per_layer_input self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None - self.embed_vision = ( - Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) - if hf_config.vision_config is not None else None) + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) self.embed_audio = ( Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) if hf_config.audio_config is not None else None) @@ -43,8 +40,34 @@ class Gemma4Bridge(GPTBridge): pass +class Gemma4TextGPTModel(GPTModel): + + def _set_inv_freq(self): + rope_scaling = self.config.rope_scaling + self.config.rope_scaling = rope_scaling['sliding_attention'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + assert attention_scaling == 1, 'not support' + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + # full + self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) + self.config.rope_scaling = rope_scaling['full_attention'] + kwargs = {} + if self.config.rope_scaling['rope_type'] == 'proportional': + kwargs['head_dim_key'] = 'global_head_dim' + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs) + assert attention_scaling == 1, 'not support' + self.full_rotary_pos_emb.inv_freq = new_inv_freq + self.attention_scaling = attention_scaling + + self.config.rope_scaling = rope_scaling + + +class Gemma4GPTModel(MultimodalGPTModel): + language_model_cls = Gemma4TextGPTModel + + class Gemma4Loader(ModelLoader): - pass + model_cls = Gemma4GPTModel # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): # layer_specs = get_gpt_decoder_block_spec( # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py index 5cabe42..e7db3c3 100644 --- a/src/mcore_bridge/model/rope.py +++ b/src/mcore_bridge/model/rope.py @@ -106,12 +106,12 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(config, seq_len=None): +def get_rope_inv_freq(config, seq_len=None, **kwargs): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) dummy_config = _get_dummy_config(config) rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)] - inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len) + inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs) if attention_scaling is None: attention_scaling = 1. return inv_freq, attention_scaling From 32106178dc2476f09e2589262b8c5b3a28b70dce Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Thu, 30 Apr 2026 11:30:56 +0800 Subject: [PATCH 06/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 39 +++++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 2d656dd..b2a7fde 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,8 +1,12 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from transformers import AutoModel, PretrainedConfig +from typing import Optional -from mcore_bridge.bridge import GPTBridge +from mcore_bridge.bridge import MultimodalGPTBridge +from mcore_bridge.config import ModelConfig from ..constant import ModelType from ..gpt_model import GPTModel @@ -36,12 +40,30 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): return inputs_embeds -class Gemma4Bridge(GPTBridge): +class Gemma4SelfAttention(SelfAttention): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + super().__init__(config, submodules, layer_number, *args, **kwargs) + + +class Gemma4Bridge(MultimodalGPTBridge): pass class Gemma4TextGPTModel(GPTModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + print() + def _set_inv_freq(self): rope_scaling = self.config.rope_scaling self.config.rope_scaling = rope_scaling['sliding_attention'] @@ -68,12 +90,13 @@ class Gemma4GPTModel(MultimodalGPTModel): class Gemma4Loader(ModelLoader): model_cls = Gemma4GPTModel - # def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - # layer_specs = get_gpt_decoder_block_spec( - # self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) - # for layer_spec in layer_specs.layer_specs: - # pass - # return layer_specs + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + layer_specs = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + for layer_spec in layer_specs.layer_specs: + layer_spec.submodules.self_attention.module = Gemma4SelfAttention + return layer_specs register_model( From 5b4e118bc804475f58a057a7805f7ff6f312ca4d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 4 May 2026 18:15:41 +0800 Subject: [PATCH 07/52] update --- src/mcore_bridge/model/modules/__init__.py | 1 + .../model/modules/transformer_layer.py | 30 +++++++++++++++++++ src/mcore_bridge/patcher.py | 30 ------------------- 3 files changed, 31 insertions(+), 30 deletions(-) create mode 100644 src/mcore_bridge/model/modules/transformer_layer.py diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index 6fd1ac7..eff1bd6 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -2,3 +2,4 @@ from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer +from .transformer_layer import CustomTransformerLayer diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py new file mode 100644 index 0000000..55aa952 --- /dev/null +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -0,0 +1,30 @@ +import megatron.core +from megatron.core.transformer import TransformerLayer +from packaging import version + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + + +class CustomTransformerLayer(TransformerLayer): + + def forward(self, *args, **kwargs): + """ + Perform a forward pass through the transformer layer. + + This method calls the core computation of a transformer layer, including + self-attention, cross-attention (if applicable), and feed-forward operations. + """ + if not mcore_013: + return super().forward(self, *args, **kwargs) + hidden_states, context = self._forward_attention(*args, **kwargs) + mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs + mask = None + if mlp_padding_free and hidden_states.shape[1] > 1: + mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() + hidden_states = hidden_states[mask][:, None] + output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) + if mask is not None: + new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) + new_output[mask] = output.squeeze(1) + output = new_output + return output, context diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py index 15280b5..527b0c5 100644 --- a/src/mcore_bridge/patcher.py +++ b/src/mcore_bridge/patcher.py @@ -13,7 +13,6 @@ from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, scatter_to_sequence_parallel_region) -from megatron.core.transformer import TransformerLayer from megatron.core.transformer.multi_latent_attention import MLASelfAttention, MultiLatentAttention from megatron.core.transformer.multi_token_prediction import MultiTokenPredictionBlock, get_mtp_layer_offset from megatron.core.utils import deprecate_inference_params @@ -413,34 +412,6 @@ def sharded_state_dict( peft_module.OriginModulesToSaveWrapper = OriginModulesToSaveWrapper -def _patch_TransformerLayer(): - _origin_forward = TransformerLayer.forward - - def forward(self, *_args, **kwargs): - """ - Perform a forward pass through the transformer layer. - - This method calls the core computation of a transformer layer, including - self-attention, cross-attention (if applicable), and feed-forward operations. - """ - if not mcore_013: - return _origin_forward(self, *_args, **kwargs) - hidden_states, context = self._forward_attention(*_args, **kwargs) - mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs - mask = None - if mlp_padding_free and hidden_states.shape[1] > 1: - mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() - hidden_states = hidden_states[mask][:, None] - output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) - if mask is not None: - new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) - new_output[mask] = output.squeeze(1) - output = new_output - return output, context - - TransformerLayer.forward = forward - - def _patch_TELinear(): def __repr__(self): @@ -769,7 +740,6 @@ def apply_patch(): # patch module _patch_mla_attention() _patch_TEGroupedLinear() - _patch_TransformerLayer() _patch_TELinear() _patch_mrope() _patch_mtp() From d1d22462c3e56e89ac34f06aafeb484243cfd78a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 00:07:45 +0800 Subject: [PATCH 08/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index b2a7fde..5b09745 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -4,7 +4,7 @@ from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from transformers import AutoModel, PretrainedConfig from typing import Optional - +from megatron.core.transformer.mlp import MLP from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -51,7 +51,27 @@ def __init__( **kwargs, ): text_config = config.hf_config.text_config + self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' + self.sliding_window = text_config.sliding_window if self.is_sliding else None + kv_channels = config.kv_channels + config.kv_channels = text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim super().__init__(config, submodules, layer_number, *args, **kwargs) + config.kv_channels = kv_channels + +class Gemma4MLP(MLP): + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + self.enable_moe_block = text_config.enable_moe_block + first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + super().__init__(config, submodules, *args, **kwargs) class Gemma4Bridge(MultimodalGPTBridge): @@ -96,6 +116,7 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) for layer_spec in layer_specs.layer_specs: layer_spec.submodules.self_attention.module = Gemma4SelfAttention + layer_spec.submodules.mlp.module = Gemma4MLP return layer_specs From 14b164489b9228b73276b3f6b3560984e483baac Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 12:19:34 +0800 Subject: [PATCH 09/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 11 +- .../model/modules/transformer_layer.py | 208 +++++++++++++++++- src/mcore_bridge/model/register.py | 6 +- 3 files changed, 221 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 5b09745..8cba526 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -2,9 +2,10 @@ import copy from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.mlp import MLP from transformers import AutoModel, PretrainedConfig from typing import Optional -from megatron.core.transformer.mlp import MLP + from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -54,18 +55,24 @@ def __init__( self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' self.sliding_window = text_config.sliding_window if self.is_sliding else None kv_channels = config.kv_channels - config.kv_channels = text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim + config.kv_channels = ( + text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim + ) super().__init__(config, submodules, layer_number, *args, **kwargs) config.kv_channels = kv_channels + class Gemma4MLP(MLP): + def __init__( self, config: ModelConfig, submodules: SelfAttentionSubmodules, + layer_number: int, *args, **kwargs, ): + self.layer_number = layer_number text_config = config.hf_config.text_config self.enable_moe_block = text_config.enable_moe_block first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 55aa952..83c42e3 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,12 +1,218 @@ import megatron.core -from megatron.core.transformer import TransformerLayer +import torch +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.enums import CudaGraphScope, LayerType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, + get_transformer_layer_offset) +from megatron.core.utils import get_pg_rank from packaging import version +from typing import Optional + +from mcore_bridge.utils import get_logger mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') +logger = get_logger() + class CustomTransformerLayer(TransformerLayer): + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: Optional[float] = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + is_mtp_layer: bool = False, + add_layer_offset: bool = True, + pp_layer_offset: Optional[int] = None, + ): + self.submodules_config = submodules + super().__init__(config=config, vp_stage=vp_stage) + + if pg_collection is None: + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + self.pg_collection = pg_collection + self.tp_group = pg_collection.tp + + # MTP inner layers use their own layer numbering (starting from 1 within each MTP depth), + # so they should NOT add the decoder layer offset. The router.py handles MTP layer + # numbering separately by adding config.num_layers to distinguish MTP layers from decoder + # layers in the aux loss tracker. + # + # When add_layer_offset is False, the caller has already included the correct offset + # in layer_number (e.g. when using --hybrid-layer-pattern with fVPP). + if is_mtp_layer or not add_layer_offset: + self.layer_number = layer_number + else: + self.layer_number = layer_number + get_transformer_layer_offset(self.config, vp_stage, + get_pg_rank(pg_collection.pp)) + self.hidden_dropout = config.hidden_dropout if hidden_dropout is None else hidden_dropout + self.is_mtp_layer = is_mtp_layer + + # [Module 1: Input Layernorm] Optional Layernorm on the input data + # TODO: add pytorch only layernorm + self.input_layernorm = submodules.input_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + attention_optional_kwargs = {} + if config.context_parallel_size > 1 and config.cp_comm_type is not None: + if isinstance(config.cp_comm_type, list): + # layer_number is 1-indexed, so we need to subtract 1 to get the correct index + attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type[self.layer_number - 1] + else: + attention_optional_kwargs['cp_comm_type'] = config.cp_comm_type + + attention_optional_kwargs['pg_collection'] = pg_collection + if pp_layer_offset is not None: + attention_optional_kwargs['pp_layer_offset'] = pp_layer_offset + + # [Module 2: SelfAttention] + self.self_attention = build_module( + submodules.self_attention, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # [Module 3: BiasDropoutFusion] + self.self_attn_bda = build_module(submodules.self_attn_bda) + + # [Module 4: Post SelfAttention] Optional Layernorm after self-attn + self.pre_cross_attn_layernorm = submodules.pre_cross_attn_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + # [Module 5: CrossAttention] + self.cross_attention = build_module( + submodules.cross_attention, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # [Module 6: BiasDropoutFusion] + self.cross_attn_bda = build_module(submodules.cross_attn_bda, config=self.config) + + # [Module 7: Pre MLP] Optional Layernorm before MLP + self.pre_mlp_layernorm = submodules.pre_mlp_layernorm( + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + # [Module 8: MLP block] + additional_mlp_kwargs = {} + # import here to avoid circular import + from megatron.core.extensions.transformer_engine import TEFusedMLP + from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP + from megatron.core.transformer.moe.moe_layer import MoELayer + + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. + # We can change MLP to accept pg_collection but it makes the logic implicit + # The conditional below is to make the logic explicit + # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs + if isinstance(submodules.mlp, ModuleSpec): + if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): + additional_mlp_kwargs['pg_collection'] = pg_collection + # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. + if submodules.mlp.module == MoELayer: + additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer + elif submodules.mlp.module == MLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + else: + logger.warning_once(f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.") + self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) + if hasattr(self.mlp, 'set_layer_number'): + self.mlp.set_layer_number(self.layer_number) + + # [Module 9: BiasDropoutFusion] + self.mlp_bda = build_module(submodules.mlp_bda) + + self.is_moe_layer = isinstance(self.mlp, MoELayer) + + self.recompute_input_layernorm = False + self.recompute_pre_mlp_layernorm = False + self.recompute_mlp = False + if self.config.recompute_granularity == 'selective': + assert self.config.recompute_modules is not None + if 'layernorm' in self.config.recompute_modules: + if not isinstance(self.input_layernorm, IdentityOp): + self.recompute_input_layernorm = True + if self.config.fp8 or self.config.fp4: + self.self_attention.set_for_recompute_input_layernorm() + + def can_recompute_pre_mlp_layernorm_for_cudagraph(): + if (not self.is_moe_layer or CudaGraphScope.moe_router not in self.config.cuda_graph_scope + or self.config.cuda_graph_impl == 'local'): + # Not a MoE layer, or not capturing the router part. + return True + if (self.config.moe_shared_expert_intermediate_size is not None + and self.config.moe_shared_expert_overlap): + # If shared expert overlap is used, we cannot make the pre-mlp layernorm + # recomputation, because the shared expert takes the layernorm output as + # input, and it is outside of the CUDA graph scope. + logger.warning( + 'pre_mlp_layernorm recompute is not supported with moe router ' + 'cudagraph + shared expert overlap. Disabling pre_mlp_layernorm ' + 'recompute.', ) + return False + if CudaGraphScope.moe_preprocess in self.config.cuda_graph_scope and ( + self.config.moe_token_dispatcher_type == 'alltoall' or self.config.moe_latent_size): + # Only when capturing the preprocess part and using alltoall token + # dispatcher or latent MoE can we make the pre-mlp layernorm recomputation. + # Because in other cases the layernorm output returns directly as one of the + # outputs of the cudagraph, which will be allocated a static buffer, thus + # not able to be released. + return True + logger.warning( + 'pre_mlp_layernorm recompute is only supported with moe router + ' + 'preprocess cudagraph will alltoall token dispatcher or latent MoE. ' + 'Disabling pre_mlp_layernorm recompute.', ) + return False + + if (not isinstance(self.pre_mlp_layernorm, IdentityOp) + and can_recompute_pre_mlp_layernorm_for_cudagraph()): + self.recompute_pre_mlp_layernorm = True + if self.config.fp8 or self.config.fp4: + if isinstance(self.mlp, MoELayer): + self.mlp.set_for_recompute_pre_mlp_layernorm() + else: + from megatron.core.extensions.transformer_engine import set_save_original_input + + set_save_original_input(self.mlp.linear_fc1) + if 'mlp' in self.config.recompute_modules: + if not self.is_moe_layer: + self.recompute_mlp = True + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + + # @jcasper how should we handle nvfuser? + # Set bias+dropout+add fusion grad_enable execution handler. + # TORCH_MAJOR = int(torch.__version__.split('.')[0]) + # TORCH_MINOR = int(torch.__version__.split('.')[1]) + # use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) + # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad + self.bias_dropout_add_exec_handler = torch.enable_grad + def forward(self, *args, **kwargs): """ Perform a forward pass through the transformer layer. diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 0e67e90..9be5b6f 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -15,7 +15,7 @@ from mcore_bridge.config import ModelConfig from mcore_bridge.utils import get_logger -from .modules import MultiTokenPredictionLayer +from .modules import CustomTransformerLayer, MultiTokenPredictionLayer if TYPE_CHECKING: from .gpt_model import GPTModel @@ -138,6 +138,10 @@ def _set_shared_expert_gate(self, transformer_layer_spec): if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} + def _set_custom_layer(self, transformer_layer_spec): + pass + # CustomTransformerLayer + def build_model( self, pre_process=True, From 196a58fda42452eebe2eabee2f2545f0099e2122 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 13:19:36 +0800 Subject: [PATCH 10/52] update --- src/mcore_bridge/model/gpts/glm4.py | 10 ++-- src/mcore_bridge/model/gpts/minimax_m2.py | 7 +-- src/mcore_bridge/model/mm_gpts/gemma4.py | 4 +- .../model/modules/transformer_layer.py | 45 +++++++++++++----- src/mcore_bridge/model/register.py | 46 ++++++------------- 5 files changed, 58 insertions(+), 54 deletions(-) diff --git a/src/mcore_bridge/model/gpts/glm4.py b/src/mcore_bridge/model/gpts/glm4.py index 5c7a8ca..861f94d 100644 --- a/src/mcore_bridge/model/gpts/glm4.py +++ b/src/mcore_bridge/model/gpts/glm4.py @@ -91,11 +91,11 @@ def _set_layer_attn(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: boo class Glm4Loader(ModelLoader): def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - layer_spec = self._get_transformer_layer_spec() - layer_spec.submodules.self_attention.module = Glm4SelfAttention - layer_spec.submodules.mlp.module = Glm4MLP - transformer_layer.MLP = Glm4MLP # patch - return layer_spec + transformer_layer_spec = super().get_transformer_layer_spec(vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = Glm4SelfAttention + layer_spec.submodules.mlp.module = Glm4MLP + return transformer_layer_spec register_model(ModelMeta( diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index 81b11b6..c03f803 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -95,9 +95,10 @@ def _set_moe_state( class MinimaxM2Loader(ModelLoader): def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - layer_spec = self._get_transformer_layer_spec() - layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention - return layer_spec + transformer_layer_spec = super().get_transformer_layer_spec(vp_stage) + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.submodules.self_attention.module = MinimaxM2SelfAttention + return transformer_layer_spec register_model(ModelMeta( diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 8cba526..8bfc0c2 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -56,8 +56,8 @@ def __init__( self.sliding_window = text_config.sliding_window if self.is_sliding else None kv_channels = config.kv_channels config.kv_channels = ( - text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim - ) + text_config.global_head_dim + if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) super().__init__(config, submodules, layer_number, *args, **kwargs) config.kv_channels = kv_channels diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 83c42e3..4fe0a82 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,7 +1,8 @@ +import enum +import inspect import megatron.core import torch from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.enums import CudaGraphScope, LayerType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -14,6 +15,22 @@ from mcore_bridge.utils import get_logger +try: + from megatron.core.transformer.enums import CudaGraphScope +except ImportError: + + class CudaGraphScope(enum.Enum): + """Cuda Graph Scope - defines which parts of the model to capture.""" + + full_iteration = 1 # Captures the entire training/inference iteration + attn = 2 # Captures attention layers + mlp = 3 # Captures MLP layers (dense layers only) + moe = 4 # Captures MoE layers (drop-and-pad MoE layers only) + moe_router = 5 # Captures MoE router part + moe_preprocess = 6 # Captures MoE preprocessing part (requires moe_router) + mamba = 7 # Captures Mamba layers + + mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') logger = get_logger() @@ -34,7 +51,7 @@ def __init__( pp_layer_offset: Optional[int] = None, ): self.submodules_config = submodules - super().__init__(config=config, vp_stage=vp_stage) + super(TransformerLayer, self).__init__(config=config, vp_stage=vp_stage) if pg_collection is None: pg_collection = ProcessGroupCollection.use_mpu_process_groups() @@ -118,6 +135,9 @@ def __init__( from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP from megatron.core.transformer.moe.moe_layer import MoELayer + from mcore_bridge.model.gpts.glm4 import Glm4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. # We can change MLP to accept pg_collection but it makes the logic implicit # The conditional below is to make the logic explicit @@ -126,16 +146,18 @@ def __init__( if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): additional_mlp_kwargs['pg_collection'] = pg_collection # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. - if submodules.mlp.module == MoELayer: + if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer - elif submodules.mlp.module == MLP: + elif submodules.mlp.module in (MLP, Glm4MLP): assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif submodules.mlp.module == Gemma4MLP: + additional_mlp_kwargs['layer_number'] = layer_number elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp else: - logger.warning_once(f"Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.") + logger.warning_once(f'Unknown MLP type: {type(submodules.mlp)}. Using default kwargs.') self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) if hasattr(self.mlp, 'set_layer_number'): self.mlp.set_layer_number(self.layer_number) @@ -198,12 +220,13 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): if 'mlp' in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True - self.offload_attn_norm = ( - self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp)) - self.offload_mlp_norm = ( - self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + if hasattr(self.config, 'fine_grained_activation_offloading'): + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 9be5b6f..e8eef7d 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -90,41 +90,20 @@ def _replace_spec_dsa(self, layer_spec): layer_spec.submodules.self_attention = dsa_spec def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): - if self.config.num_moe_experts: - transformer_layer_spec = get_gpt_decoder_block_spec( - self.config, - use_transformer_engine=True, - normalization=self.config.normalization, - qk_l2_norm=self.config.qk_l2_norm, - vp_stage=vp_stage) - if self.config.experimental_attention_variant == 'dsa': - for layer_spec in transformer_layer_spec.layer_specs: - self._replace_spec_dsa(layer_spec) - else: - transformer_layer_spec = self._get_transformer_layer_spec() - return transformer_layer_spec - - def _get_transformer_layer_spec(self): - config = self.config - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - config.num_moe_experts, - config.moe_grouped_gemm, - config.qk_layernorm, - config.multi_latent_attention, - qk_l2_norm=config.qk_l2_norm, - ) + transformer_layer_spec = get_gpt_decoder_block_spec( + self.config, + use_transformer_engine=True, + normalization=self.config.normalization, + qk_l2_norm=self.config.qk_l2_norm, + vp_stage=vp_stage) + if self.config.experimental_attention_variant == 'dsa': + for layer_spec in transformer_layer_spec.layer_specs: + self._replace_spec_dsa(layer_spec) return transformer_layer_spec def get_mtp_block_spec(self, transformer_layer_spec, vp_stage: Optional[int] = None): - if hasattr(transformer_layer_spec, 'layer_specs') and len(transformer_layer_spec.layer_specs) == 0: - # Get the decoder layer spec explicitly if no decoder layer in the last stage, - # Only happens with block spec (TransformerBlockSubmodules) when using MoE. - # TODO: remove - transformer_layer_spec_for_mtp = self._get_transformer_layer_spec() - else: - transformer_layer_spec_for_mtp = transformer_layer_spec mtp_block_spec = get_gpt_mtp_block_spec( - self.config, transformer_layer_spec_for_mtp, use_transformer_engine=True, vp_stage=vp_stage) + self.config, transformer_layer_spec, use_transformer_engine=True, vp_stage=vp_stage) if mtp_block_spec is not None: for layer_spec in mtp_block_spec.layer_specs: layer_spec.module = MultiTokenPredictionLayer @@ -139,8 +118,8 @@ def _set_shared_expert_gate(self, transformer_layer_spec): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} def _set_custom_layer(self, transformer_layer_spec): - pass - # CustomTransformerLayer + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.module = CustomTransformerLayer def build_model( self, @@ -150,6 +129,7 @@ def build_model( ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) + self._set_custom_layer(transformer_layer_spec) mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) From 68e33a7dc04b3b8a746993802f98b5437ba705ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 13:31:46 +0800 Subject: [PATCH 11/52] update --- src/mcore_bridge/model/modules/transformer_layer.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 4fe0a82..83dcf5e 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,6 +1,5 @@ import enum import inspect -import megatron.core import torch from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.identity_op import IdentityOp @@ -10,7 +9,6 @@ from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, get_transformer_layer_offset) from megatron.core.utils import get_pg_rank -from packaging import version from typing import Optional from mcore_bridge.utils import get_logger @@ -31,8 +29,6 @@ class CudaGraphScope(enum.Enum): mamba = 7 # Captures Mamba layers -mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') - logger = get_logger() @@ -243,8 +239,6 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ - if not mcore_013: - return super().forward(self, *args, **kwargs) hidden_states, context = self._forward_attention(*args, **kwargs) mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None From 44ddaec8fec777dbda3b627c5718694d7d1bb9a8 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 5 May 2026 20:31:15 +0800 Subject: [PATCH 12/52] update --- .../model/modules/transformer_layer.py | 22 ++++++++++++++++++- src/mcore_bridge/model/register.py | 4 +--- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 83dcf5e..11c940b 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -2,6 +2,8 @@ import inspect import torch from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, + scatter_to_sequence_parallel_region) from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import ModuleSpec, build_module @@ -242,12 +244,30 @@ def forward(self, *args, **kwargs): hidden_states, context = self._forward_attention(*args, **kwargs) mlp_padding_free = self.config.mlp_padding_free and 'attention_mask' in kwargs mask = None + enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1 + pad_size = 0 if mlp_padding_free and hidden_states.shape[1] > 1: + if enable_sp: + hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False) mask = ((~kwargs['attention_mask']).sum(dim=(1, 2)) > 0).t() hidden_states = hidden_states[mask][:, None] + if enable_sp: + tp_size = self.config.tensor_model_parallel_size + num_tokens = hidden_states.shape[0] + remainder = num_tokens % tp_size + if remainder != 0: + pad_size = tp_size - remainder + hidden_states = torch.nn.functional.pad(hidden_states, (0, 0, 0, 0, 0, pad_size)) + hidden_states = scatter_to_sequence_parallel_region(hidden_states) output = self._forward_mlp(hidden_states, kwargs.get('inference_context', None)) if mask is not None: - new_output = hidden_states.new_zeros((*mask.shape, output.shape[-1])) + if enable_sp: + output = gather_from_sequence_parallel_region(output, tensor_parallel_output_grad=False) + if pad_size > 0: + output = output[:-pad_size] + new_output = output.new_zeros((*mask.shape, output.shape[-1])) new_output[mask] = output.squeeze(1) output = new_output + if enable_sp: + output = scatter_to_sequence_parallel_region(output) return output, context diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index e8eef7d..15b37fe 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -4,9 +4,7 @@ from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear -from megatron.core.models.gpt.gpt_layer_specs import (get_gpt_decoder_block_spec, - get_gpt_layer_with_transformer_engine_spec, - get_gpt_mtp_block_spec) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec from packaging import version from torch import nn from typing import TYPE_CHECKING, List, Optional, Type, Union From 0de0ebba80cdb40092a6436403a0e103620b3826 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:37:52 +0800 Subject: [PATCH 13/52] update --- src/mcore_bridge/config/model_config.py | 5 ++--- src/mcore_bridge/config/parser.py | 2 ++ src/mcore_bridge/model/gpt_model.py | 2 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 15 +++++++++++++-- tests/test_mllm.py | 7 ++++++- 5 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py index 5eec352..943c100 100644 --- a/src/mcore_bridge/config/model_config.py +++ b/src/mcore_bridge/config/model_config.py @@ -3,13 +3,13 @@ import os import re import torch.nn.functional as F -from dataclasses import dataclass, field +from dataclasses import dataclass from megatron.core import mpu from megatron.core.transformer import TransformerConfig from transformers import PretrainedConfig from transformers.utils import is_torch_npu_available from transformers.utils.versions import require_version -from typing import Any, Dict, List, Literal, Optional, Union +from typing import List, Literal, Optional, Union from mcore_bridge.utils import get_logger, json_parse_to_dict @@ -229,7 +229,6 @@ class ModelConfig(TransformerConfig): task_type: Literal['causal_lm', 'seq_cls', 'embedding', 'generative_reranker'] = 'causal_lm' num_labels: Optional[int] = None mlp_padding_free: bool = False - model_kwargs: Dict[str, Any] = field(default_factory=dict) _mindspeed_defaults_cache = None diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 877b63c..68ef8e3 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -149,6 +149,8 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True + elif hf_model_type == 'gemma4': + config.qk_layernorm = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 86c9587..5679122 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -311,7 +311,7 @@ def forward( """ inference_context = deprecate_inference_params(inference_context, inference_params) - + # There is a difference in whether rotary_pos_emb can be fused between the decoder and MTP. decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( self._preprocess( input_ids=input_ids, diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 8bfc0c2..a45fb5a 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -15,6 +15,7 @@ from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit +from ..module import CustomTransformerLayer class Gemma4Vit(HuggingFaceVit): @@ -76,9 +77,12 @@ def __init__( text_config = config.hf_config.text_config self.enable_moe_block = text_config.enable_moe_block first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers - is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 - use_double_wide_mlp = config.use_double_wide_mlp and is_kv_shared_layer + is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0 + use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer + ffn_hidden_size = config.ffn_hidden_size + config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1) super().__init__(config, submodules, *args, **kwargs) + config.ffn_hidden_size = ffn_hidden_size class Gemma4Bridge(MultimodalGPTBridge): @@ -110,6 +114,9 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling +class Gemma4TransformerLayer(CustomTransformerLayer): + pass + class Gemma4GPTModel(MultimodalGPTModel): language_model_cls = Gemma4TextGPTModel @@ -127,6 +134,10 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): return layer_specs + def _set_custom_layer(self, transformer_layer_spec): + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.module = Gemma4TransformerLayer + register_model( ModelMeta( ModelType.gemma4, diff --git a/tests/test_mllm.py b/tests/test_mllm.py index f4832f7..bacf76c 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -112,6 +112,10 @@ def test_llava_onevision1_5(): _test_model('lmms-lab/LLaVA-OneVision-1.5-4B-Instruct') +def test_gemma4(): + _test_model('google/gemma-4-E2B-it') + + if __name__ == '__main__': # test_qwen2_5_vl() # test_qwen2_vl() @@ -131,4 +135,5 @@ def test_llava_onevision1_5(): # test_qwen3_omni() # test_llama4() # test_qwen3_5() - test_llava_onevision1_5() + # test_llava_onevision1_5() + test_gemma4() From 2a81bf056b08fb3784c731f56d63d05c90bbee86 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:47:07 +0800 Subject: [PATCH 14/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 51 +++++++++++++++++++++--- 1 file changed, 46 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index a45fb5a..e3ef62d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -53,14 +53,55 @@ def __init__( **kwargs, ): text_config = config.hf_config.text_config - self.is_sliding = text_config.layer_types[layer_number - 1] == 'sliding_attention' + layer_idx = layer_number - 1 + + # Layer type / sliding attention + self.layer_type = text_config.layer_types[layer_idx] + self.is_sliding = self.layer_type == 'sliding_attention' self.sliding_window = text_config.sliding_window if self.is_sliding else None - kv_channels = config.kv_channels - config.kv_channels = ( + + # Head dim: global layers may use a different head dim than sliding ones + self.head_dim = ( text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) - super().__init__(config, submodules, layer_number, *args, **kwargs) - config.kv_channels = kv_channels + + # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set + self.use_alternative_attention = ( + getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) + num_key_value_heads = ( + text_config.num_global_key_value_heads + if self.use_alternative_attention else text_config.num_key_value_heads) + self.num_key_value_groups = text_config.num_attention_heads // num_key_value_heads + + self.is_causal = getattr(text_config, 'use_bidirectional_attention', None) != 'all' + + # Shared KV across the trailing layers + num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = text_config.num_hidden_layers - num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] + if self.is_kv_shared_layer: + # For shared layers, reuse KV from the last non-shared layer of the same type + self.kv_shared_layer_index = ( + len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + self.store_full_length_kv = False + else: + self.kv_shared_layer_index = None + # Non-shared layers that are the last of their type in `prev_layers` must keep full KV + self.store_full_length_kv = ( + self.layer_type in prev_layers + and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + + # Patch config so the underlying linear_qkv is built with the correct shapes + orig_kv_channels = config.kv_channels + orig_num_query_groups = config.num_query_groups + config.kv_channels = self.head_dim + config.num_query_groups = num_key_value_heads + try: + super().__init__(config, submodules, layer_number, *args, **kwargs) + finally: + config.kv_channels = orig_kv_channels + config.num_query_groups = orig_num_query_groups class Gemma4MLP(MLP): From 7e05d3d70ebdcdcea606b26120a6e0af4cdca8d4 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:51:32 +0800 Subject: [PATCH 15/52] update --- src/mcore_bridge/model/gpts/minimax_m2.py | 8 +++++--- src/mcore_bridge/model/mm_gpts/gemma4.py | 6 ++++-- src/mcore_bridge/model/modules/gated_delta_net.py | 7 +++++-- src/mcore_bridge/model/modules/mtp_layer.py | 7 +++++-- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index c03f803..7830215 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -27,9 +27,11 @@ def __init__( k_layernorm = submodules.k_layernorm submodules.q_layernorm = IdentityOp submodules.k_layernorm = IdentityOp - super().__init__(config, submodules, *args, **kwargs) - submodules.q_layernorm = q_layernorm - submodules.k_layernorm = k_layernorm + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + submodules.q_layernorm = q_layernorm + submodules.k_layernorm = k_layernorm self.q_norm = build_module( submodules.q_layernorm, hidden_size=self.hidden_size_per_attention_head * config.num_attention_heads, diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e3ef62d..6c267ce 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -122,8 +122,10 @@ def __init__( use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer ffn_hidden_size = config.ffn_hidden_size config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1) - super().__init__(config, submodules, *args, **kwargs) - config.ffn_hidden_size = ffn_hidden_size + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + config.ffn_hidden_size = ffn_hidden_size class Gemma4Bridge(MultimodalGPTBridge): diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index ef150b7..5fcb167 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -96,10 +96,13 @@ def __init__(self, config: ModelConfig, submodules: 'GatedDeltaNetSubmodules', * submodules.in_proj = IdentityOp if 'cp_comm_type' not in inspect.signature(_GatedDeltaNet).parameters: kwargs.pop('cp_comm_type', None) - super().__init__(config, submodules, *args, **kwargs) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + if config.linear_decoupled_in_proj: + submodules.in_proj = in_proj if not config.linear_decoupled_in_proj: return - submodules.in_proj = in_proj self.in_proj_qkvz_dim = self.qk_dim * 2 + self.v_dim * 2 self.in_proj_ba_dim = self.num_value_heads * 2 del self.in_proj diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 537abf3..8be6aeb 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -29,11 +29,14 @@ def __init__(self, config: ModelConfig, submodules, *args, **kwargs): if config.fp8_param: eh_proj = submodules.eh_proj submodules.eh_proj = IdentityOp - super().__init__(config, submodules, *args, **kwargs) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + if config.fp8_param: + submodules.eh_proj = eh_proj self.tp_group = getattr(self, 'tp_group', None) if not config.fp8_param: return - submodules.eh_proj = eh_proj fp8_context = transformer_engine.pytorch.fp8_model_init(enabled=False) with fp8_context: self.eh_proj = build_module( From 8da05dfc646c04cb857471b53b7e42a17470020e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 10:58:44 +0800 Subject: [PATCH 16/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 30 +++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 6c267ce..e14716d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,7 +1,9 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy +import torch from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from transformers import AutoModel, PretrainedConfig from typing import Optional @@ -18,6 +20,20 @@ from ..module import CustomTransformerLayer +class Gemma4VNorm(torch.nn.Module): + """RMSNorm without learnable scale, mirroring HF `Gemma4RMSNorm(with_scale=False)`.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + return (x * torch.rsqrt(variance + self.eps)).to(orig_dtype) + + class Gemma4Vit(HuggingFaceVit): module_mapping = { 'model.vision_tower': 'vision_tower', @@ -92,16 +108,28 @@ def __init__( self.layer_type in prev_layers and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) - # Patch config so the underlying linear_qkv is built with the correct shapes + # Patch config so the underlying linear_qkv/q_layernorm/k_layernorm are built correctly. + # HF keeps `q_norm` on every layer, but only builds `k_norm`/`v_norm` on non-kv-shared + # layers, so replace `k_layernorm` with `IdentityOp` when this layer shares KV. orig_kv_channels = config.kv_channels orig_num_query_groups = config.num_query_groups + orig_k_layernorm = submodules.k_layernorm config.kv_channels = self.head_dim config.num_query_groups = num_key_value_heads + if self.is_kv_shared_layer: + submodules.k_layernorm = IdentityOp try: super().__init__(config, submodules, layer_number, *args, **kwargs) finally: config.kv_channels = orig_kv_channels config.num_query_groups = orig_num_query_groups + submodules.k_layernorm = orig_k_layernorm + + # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. + # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. + self.v_norm = ( + Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) + if not self.is_kv_shared_layer else None) class Gemma4MLP(MLP): From d1eff8a9f4e25bec42e770813fe47eb3a4b787a5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:00:09 +0800 Subject: [PATCH 17/52] fix --- src/mcore_bridge/config/parser.py | 2 -- src/mcore_bridge/model/mm_gpts/gemma4.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 68ef8e3..877b63c 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -149,8 +149,6 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True - elif hf_model_type == 'gemma4': - config.qk_layernorm = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e3ef62d..e7dcace 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -15,7 +15,7 @@ from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit -from ..module import CustomTransformerLayer +from ..modules import CustomTransformerLayer class Gemma4Vit(HuggingFaceVit): From 0c22e68aec252d8d4a1cad9ef79fbc3d27e3de79 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:02:07 +0800 Subject: [PATCH 18/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e14716d..be0a17f 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -89,8 +89,6 @@ def __init__( if self.use_alternative_attention else text_config.num_key_value_heads) self.num_key_value_groups = text_config.num_attention_heads // num_key_value_heads - self.is_causal = getattr(text_config, 'use_bidirectional_attention', None) != 'all' - # Shared KV across the trailing layers num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) first_kv_shared_layer_idx = text_config.num_hidden_layers - num_kv_shared_layers From fa5360be7b1a8624b66d0c6d232acab958577408 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:13:45 +0800 Subject: [PATCH 19/52] fix --- src/mcore_bridge/config/parser.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 877b63c..21d34c5 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -149,6 +149,8 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True + elif hf_model_type in {'gemma4'}: + res['qk_layernorm'] = True elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False From d25db28302db3cdcb2ad8aa6a75dd9cbfa03cece Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:14:39 +0800 Subject: [PATCH 20/52] update --- src/mcore_bridge/bridge/gpt_bridge.py | 9 +++++++-- src/mcore_bridge/model/mm_gpts/gemma4.py | 21 ++++++++++++--------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index c380328..62f6cb7 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -439,7 +439,8 @@ def _set_state_dict(self, to_mcore: bool, *, offset: float = 0, - is_expert: bool = False): + is_expert: bool = False, + _check_mg_param: bool = True): if '.' in mg_key: module_key, param_key = mg_key.rsplit('.', 1) else: @@ -487,7 +488,11 @@ def _set_state_dict(self, else: mg_param = deep_getattr(sub_module, param_key) if to_mcore: - assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' + if mg_param is None: + if _check_mg_param: + raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}') + else: + return hf_weight = hf_state_dict[hf_key].load() if module_key in { 'embedding.word_embeddings', 'output_layer' diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index a41fdf0..aa4b7fb 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -14,10 +14,10 @@ from ..constant import ModelType from ..gpt_model import GPTModel from ..mm_gpt_model import MultimodalGPTModel +from ..modules import CustomTransformerLayer from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit -from ..modules import CustomTransformerLayer class Gemma4VNorm(torch.nn.Module): @@ -82,8 +82,7 @@ def __init__( if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set - self.use_alternative_attention = ( - getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) + self.use_alternative_attention = (getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) num_key_value_heads = ( text_config.num_global_key_value_heads if self.use_alternative_attention else text_config.num_key_value_heads) @@ -96,8 +95,7 @@ def __init__( prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] if self.is_kv_shared_layer: # For shared layers, reuse KV from the last non-shared layer of the same type - self.kv_shared_layer_index = ( - len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + self.kv_shared_layer_index = (len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) self.store_full_length_kv = False else: self.kv_shared_layer_index = None @@ -126,8 +124,7 @@ def __init__( # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. self.v_norm = ( - Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) - if not self.is_kv_shared_layer else None) + Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) class Gemma4MLP(MLP): @@ -155,7 +152,12 @@ def __init__( class Gemma4Bridge(MultimodalGPTBridge): - pass + + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + self._set_state_dict( + mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore, _check_mg_param=False) + self._set_state_dict( + mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore, _check_mg_param=False) class Gemma4TextGPTModel(GPTModel): @@ -183,6 +185,7 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling + class Gemma4TransformerLayer(CustomTransformerLayer): pass @@ -202,11 +205,11 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_spec.submodules.mlp.module = Gemma4MLP return layer_specs - def _set_custom_layer(self, transformer_layer_spec): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.module = Gemma4TransformerLayer + register_model( ModelMeta( ModelType.gemma4, From bfbcbc4b2747e1fd14eba117460f887ac14d4fe2 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 11:44:16 +0800 Subject: [PATCH 21/52] update --- .../model/modules/transformer_layer.py | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index be07631..74fbf59 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,11 +1,14 @@ import enum import inspect import torch +from megatron.core.extensions.transformer_engine import TEFusedMLP from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.tensor_parallel.mappings import (gather_from_sequence_parallel_region, scatter_to_sequence_parallel_region) from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import (TransformerLayer, TransformerLayerSubmodules, @@ -126,43 +129,13 @@ def __init__( hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, ) - # [Module 8: MLP block] - additional_mlp_kwargs = {} - # import here to avoid circular import - from megatron.core.extensions.transformer_engine import TEFusedMLP - from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP - from megatron.core.transformer.moe.moe_layer import MoELayer - - from mcore_bridge.model.gpts.glm4 import Glm4MLP - from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP - # MLP expects tp_group but MoELayer expects pg_collection to be passed in. - # We can change MLP to accept pg_collection but it makes the logic implicit - # The conditional below is to make the logic explicit - # if submodules.mlp is not a ModuleSpec,we dont have to handle passing additional kwargs - if isinstance(submodules.mlp, ModuleSpec): - if submodules.mlp.module in (MoELayer, TEGroupedMLP, SequentialMLP): - additional_mlp_kwargs['pg_collection'] = pg_collection - # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. - if submodules.mlp.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: - additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer - elif submodules.mlp.module in (MLP, Glm4MLP): - assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' - additional_mlp_kwargs['tp_group'] = pg_collection.tp - elif submodules.mlp.module == Gemma4MLP: - additional_mlp_kwargs['layer_number'] = layer_number - elif TEFusedMLP is not None and submodules.mlp.module == TEFusedMLP: - assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' - additional_mlp_kwargs['tp_group'] = pg_collection.tp - else: - logger.warning_once(f'Unknown MLP type: {submodules.mlp.module}. Using default kwargs.') - self.mlp = build_module(submodules.mlp, config=self.config, **additional_mlp_kwargs) + # [Module 8: MLP block] + self.mlp = self._build_mlp(submodules.mlp) if hasattr(self.mlp, 'set_layer_number'): self.mlp.set_layer_number(self.layer_number) - # [Module 9: BiasDropoutFusion] self.mlp_bda = build_module(submodules.mlp_bda) - self.is_moe_layer = isinstance(self.mlp, MoELayer) self.recompute_input_layernorm = False @@ -234,6 +207,35 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): # self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad self.bias_dropout_add_exec_handler = torch.enable_grad + def _build_mlp(self, mlp_spec): + pg_collection = self.pg_collection + additional_mlp_kwargs = {} + # import here to avoid circular import + from mcore_bridge.model.gpts.glm4 import Glm4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP + + # MLP expects tp_group but MoELayer expects pg_collection to be passed in. + # We can change MLP to accept pg_collection but it makes the logic implicit + # The conditional below is to make the logic explicit + # if smlp_spec is not a ModuleSpec,we dont have to handle passing additional kwargs + if isinstance(mlp_spec, ModuleSpec): + if mlp_spec.module in (MoELayer, TEGroupedMLP, SequentialMLP): + additional_mlp_kwargs['pg_collection'] = pg_collection + # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. + if mlp_spec.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: + additional_mlp_kwargs['is_mtp_layer'] = self.is_mtp_layer + elif mlp_spec.module in (MLP, Glm4MLP): + assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif mlp_spec.module == Gemma4MLP: + additional_mlp_kwargs['layer_number'] = self.layer_number + elif TEFusedMLP is not None and mlp_spec.module == TEFusedMLP: + assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' + additional_mlp_kwargs['tp_group'] = pg_collection.tp + else: + logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. Using default kwargs.') + self.mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + def forward(self, *args, **kwargs): """ Perform a forward pass through the transformer layer. From 2300825168648b45820217bfa5941afb49ed4bac Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 13:13:52 +0800 Subject: [PATCH 22/52] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 11 +++++++---- src/mcore_bridge/model/modules/transformer_layer.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 62f6cb7..837f5f9 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -40,9 +40,12 @@ class GPTBridge: hf_o_proj_key = 'o_proj' hf_attn_prefix = 'self_attn' hf_mlp_prefix = 'mlp' + hf_post_attention_layernorm = 'post_attention_layernorm' hf_gate_key = 'gate.weight' hf_shared_expert_key = None hf_expert_bias_key = 'gate.e_score_correction_bias' + additional_dim0_keys = {} + additional_dim1_keys = {} def __init__(self, config: ModelConfig): self.config = config @@ -124,11 +127,11 @@ def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: 'linear_kv_up_proj', # mtp 'eh_proj', - } + } & self.additional_dim0_keys if self.config.task_type in {'causal_lm', 'generative_reranker'}: dim0_keys.add('output_layer') # RowLinear - dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} + dim1_keys = {'out_proj', 'linear_proj', 'linear_fc2'} & self.additional_dim1_keys if 'lora_A' not in mg_key and 'lora_B' not in mg_key: key, suffix = mg_key.rsplit('.', 2)[-2:] if suffix == 'layer_norm_weight': @@ -1592,13 +1595,13 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool hf_state_dict.update( self._set_moe_state( mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, 'post_attention_layernorm.weight', + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, f'{self.hf_post_attention_layernorm}.weight', to_mcore) else: hf_state_dict.update( self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - 'post_attention_layernorm.weight', to_mcore) + f'{self.hf_post_attention_layernorm}.weight', to_mcore) return hf_state_dict def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 74fbf59..b3bd54b 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -234,7 +234,7 @@ def _build_mlp(self, mlp_spec): additional_mlp_kwargs['tp_group'] = pg_collection.tp else: logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. Using default kwargs.') - self.mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) def forward(self, *args, **kwargs): """ From cda31a5117211c8361f904a719767c879b71c100 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 14:30:39 +0800 Subject: [PATCH 23/52] update --- src/mcore_bridge/bridge/gpt_bridge.py | 11 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 189 ++++++++++++++++++++++- 2 files changed, 194 insertions(+), 6 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 837f5f9..2c410a2 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -1595,8 +1595,8 @@ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool hf_state_dict.update( self._set_moe_state( mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore, is_mtp=is_mtp)) - self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, f'{self.hf_post_attention_layernorm}.weight', - to_mcore) + self._set_state_dict(mg_layer, 'pre_mlp_layernorm.weight', hf_state_dict, + f'{self.hf_post_attention_layernorm}.weight', to_mcore) else: hf_state_dict.update( self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) @@ -1618,13 +1618,16 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + self._set_word_embeddings(mg_model, hf_state_dict, to_mcore) if self.is_multimodal: for prefix, mg_prefix in self.module_mapping.items(): mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}') diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index aa4b7fb..7ba0faf 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -1,10 +1,14 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import copy +import math import torch +from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.tensor_parallel import VocabParallelEmbedding from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.spec_utils import build_module from transformers import AutoModel, PretrainedConfig from typing import Optional @@ -121,6 +125,25 @@ def __init__( config.num_query_groups = orig_num_query_groups submodules.k_layernorm = orig_k_layernorm + # HF kv-shared layers only keep `q_proj` (K/V are reused from an earlier layer), so the + # default mcore `linear_qkv` shape `[Q + 2*KV, hidden]` over-allocates. Rebuild it with + # out_dim = query_projection_size so shapes match HF `q_proj` 1:1 for weight bridging. + # Mirrors attention.py L1275-L1289, minus the `+ 2 * kv_projection_size` term. + if self.is_kv_shared_layer: + self.linear_qkv_out_dim = self.query_projection_size + self.linear_qkv = submodules.linear_qkv( + self.config.hidden_size, + self.linear_qkv_out_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + tp_group=self.pg_collection.tp, + ) + # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. self.v_norm = ( @@ -152,6 +175,9 @@ def __init__( class Gemma4Bridge(MultimodalGPTBridge): + hf_post_attention_layernorm = 'pre_feedforward_layernorm' + additional_dim0_keys = {'per_layer_input_gate', 'per_layer_model_projection'} + additional_dim1_keys = {'per_layer_projection'} def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): self._set_state_dict( @@ -159,12 +185,110 @@ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): self._set_state_dict( mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore, _check_mg_param=False) + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer + is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') + if self.pp_size > 1: + dist.all_reduce(is_lora, group=self.pp_group, op=dist.ReduceOp.MAX) + is_kv_shared_layer = is_kv_shared_layer.item() + if is_kv_shared_layer: + self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) + return hf_state_dict + else: + return super()._set_qkv(mg_attn, hf_state_dict, to_mcore) + + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + for key in [ + 'post_attention_layernorm', 'post_feedforward_layernorm', 'per_layer_input_gate', + 'per_layer_projection', 'post_per_layer_input_norm' + ]: + self._set_state_dict( + mg_layer, + f'{key}.weight', + hf_state_dict if to_mcore else new_hf_state_dict, + f'{key}.weight', + to_mcore, + _check_mg_param=False) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']: + self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight', + to_mcore) + class Gemma4TextGPTModel(GPTModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - print() + text_config = self.config.hf_config.text_config + # HF: `self.unique_layer_types = set(self.config.layer_types)` — needed by the rotary + # embedding selection logic (sliding vs global) when that path is wired up. + self.unique_layer_types = set(text_config.layer_types) + + # HF: Per-Layer Embeddings (PLE). Only populated on the pre-process (PP stage 0) side, + # since the auxiliary signal is derived from `input_ids` / the token embedding output. + # See `modeling_gemma4.py` L1574-L1592 for the reference construction. Built with + # megatron-native parallel modules (mirroring `LanguageModelEmbedding` at + # `gpt_model.py` L150-L157) so the aux signal follows the TP/SP layout of the + # primary embedding. + self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + if self.hidden_size_per_layer_input and self.pre_process: + num_layers = text_config.num_hidden_layers + hidden_size = text_config.hidden_size + total_dim = num_layers * self.hidden_size_per_layer_input + tp_size = self.config.tensor_model_parallel_size + # Pad aux vocab size to be TP-divisible, matching how `GPTModel` pads the main + # `padded_vocab_size` before feeding it into `VocabParallelEmbedding`. + padded_vocab_size_per_layer = math.ceil(text_config.vocab_size_per_layer_input / tp_size) * tp_size + # Vocab-parallel embedding (shard on vocab dim). HF's `Gemma4TextScaledWordEmbedding` + # applies an `embed_scale = hidden_size_per_layer_input**0.5` factor on forward; + # we capture the scale as a sibling attribute so the weight shape stays 1:1 with HF. + self.embed_tokens_per_layer = VocabParallelEmbedding( + num_embeddings=padded_vocab_size_per_layer, + embedding_dim=total_dim, + init_method=self.config.init_method, + config=self.config, + tp_group=self.pg_collection.tp, + ) + self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5 + self.per_layer_input_scale = 2.0**-0.5 + # Column-parallel projection: output dim `num_layers * hidden_size_per_layer_input` + # is split across TP ranks so each rank produces its own shard of the packed + # per-layer input tensor. + self.per_layer_model_projection = build_module( + TEColumnParallelLinear, + hidden_size, + total_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_model_projection', + tp_group=self.pg_collection.tp, + ) + self.per_layer_model_projection_scale = hidden_size**-0.5 + self.per_layer_projection_norm = build_module( + TENorm, + hidden_size=self.hidden_size_per_layer_input, + config=self.config, + eps=self.config.layernorm_epsilon, + ) def _set_inv_freq(self): rope_scaling = self.config.rope_scaling @@ -187,7 +311,68 @@ def _set_inv_freq(self): class Gemma4TransformerLayer(CustomTransformerLayer): - pass + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + text_config = config.hf_config.text_config + hidden_size = self.config.hidden_size + eps = self.config.layernorm_epsilon + + # HF keeps an extra layernorm after self-attn / feedforward (before the residual add). + # mcore's TransformerLayer does not include these, so attach them here. + self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + # HF: `self.register_buffer("layer_scalar", torch.ones(1))` + self.register_buffer('layer_scalar', torch.ones(1)) + + # HF: per-layer input projection branch, only when `hidden_size_per_layer_input` is set. + self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + if self.hidden_size_per_layer_input: + from transformers.activations import ACT2FN + self.act_fn = ACT2FN[text_config.hidden_activation] + # Megatron-style parallel linears (see attention.py L348-361 for `linear_proj`): + # `per_layer_input_gate` is column-parallel (output dim split across TP), then its + # output is consumed by the row-parallel `per_layer_projection` which gathers along TP. + self.per_layer_input_gate = build_module( + TEColumnParallelLinear, + hidden_size, + self.hidden_size_per_layer_input, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_input_gate', + tp_group=self.pg_collection.tp, + ) + self.per_layer_projection = build_module( + TERowParallelLinear, + self.hidden_size_per_layer_input, + hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_projection', + tp_group=self.pg_collection.tp, + ) + self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + # HF: extra layernorms when the layer runs a MoE block in parallel with the dense MLP. + # Router / experts modules are gemma4-specific and intentionally skipped here; they can + # be wired by the bridge/forward override once their mcore counterparts are implemented. + self.enable_moe_block = getattr(text_config, 'enable_moe_block', False) + if self.enable_moe_block: + self.post_feedforward_layernorm_1 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.pre_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) class Gemma4GPTModel(MultimodalGPTModel): From 7e6fb75b421009316e07009590b2788c70e36d03 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 16:52:11 +0800 Subject: [PATCH 24/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 54 +++++++++++++++++++++++- 1 file changed, 52 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 7ba0faf..54dc50d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -47,18 +47,65 @@ class Gemma4Vit(HuggingFaceVit): } _vision_tower = ['vision_tower', 'audio_tower'] _aligner = ['embed_vision', 'embed_audio'] - support_multimodal = False def prepare_model(self, hf_config: PretrainedConfig): - from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder + from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder, Gemma4Model self.vision_tower = AutoModel.from_config(hf_config.vision_config) self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) self.embed_audio = ( Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) if hf_config.audio_config is not None else None) + self.register_buffer("embed_scale", torch.tensor(hf_config.hidden_size**0.5), persistent=False) + self.model_cls = Gemma4Model def get_inputs_embeds(self, inputs_embeds, **kwargs): + input_ids = kwargs.get('input_ids') + inputs_embeds *= self.embed_scale.to(inputs_embeds.dtype) + + hf_config = self.hf_config + input_ids = kwargs.get('input_ids') + pixel_values = kwargs.get('pixel_values') + pixel_values_videos = kwargs.get('pixel_values_videos') + input_features = kwargs.get('input_features') + input_features_mask = kwargs.get('input_features_mask') + image_position_ids = kwargs.get('image_position_ids') + video_position_ids = kwargs.get('video_position_ids') + + image_mask = input_ids == hf_config.image_token_id + video_mask = input_ids == hf_config.video_token_id + audio_mask = input_ids == hf_config.audio_token_id + + if pixel_values is not None: + vision_outputs = self.vision_tower( + pixel_values=pixel_values.to(self.vision_tower.dtype), + pixel_position_ids=image_position_ids, + ) + image_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features) + + if pixel_values_videos is not None: + pixel_values_videos_flat = pixel_values_videos.flatten(0, 1) + video_position_ids_flat = video_position_ids.flatten(0, 1) if video_position_ids is not None else None + vision_outputs = self.vision_tower( + pixel_values=pixel_values_videos_flat.to(self.vision_tower.dtype), + pixel_position_ids=video_position_ids_flat, + ) + video_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) + + if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): + audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True) + audio_features = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state) + audio_features = audio_features[audio_outputs.attention_mask] + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) + return inputs_embeds @@ -309,6 +356,9 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling + def forward(self): + pass + class Gemma4TransformerLayer(CustomTransformerLayer): From e1d085192d83801dd6e1348d1ee2bdc19821a545 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 16:52:22 +0800 Subject: [PATCH 25/52] update --- src/mcore_bridge/bridge/gpt_bridge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 2c410a2..a1899d7 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -44,8 +44,8 @@ class GPTBridge: hf_gate_key = 'gate.weight' hf_shared_expert_key = None hf_expert_bias_key = 'gate.e_score_correction_bias' - additional_dim0_keys = {} - additional_dim1_keys = {} + additional_dim0_keys = set() + additional_dim1_keys = set() def __init__(self, config: ModelConfig): self.config = config From e3cbe5db1bfdca03c1c35fd833a9b5924bb88237 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 9 May 2026 18:13:19 +0800 Subject: [PATCH 26/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 101 ++++++++--------------- 1 file changed, 36 insertions(+), 65 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 54dc50d..3b92bbf 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -2,6 +2,7 @@ import copy import math import torch +import torch.distributed as dist from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.tensor_parallel import VocabParallelEmbedding @@ -49,14 +50,15 @@ class Gemma4Vit(HuggingFaceVit): _aligner = ['embed_vision', 'embed_audio'] def prepare_model(self, hf_config: PretrainedConfig): - from transformers.models.gemma4.modeling_gemma4 import Gemma4MultimodalEmbedder, Gemma4Model + from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4MultimodalEmbedder self.vision_tower = AutoModel.from_config(hf_config.vision_config) + dtype = self.vision_tower.dtype self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None - self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config) + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config).to(dtype) self.embed_audio = ( - Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config) + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype) if hf_config.audio_config is not None else None) - self.register_buffer("embed_scale", torch.tensor(hf_config.hidden_size**0.5), persistent=False) + self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False) self.model_cls = Gemma4Model def get_inputs_embeds(self, inputs_embeds, **kwargs): @@ -75,38 +77,35 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): image_mask = input_ids == hf_config.image_token_id video_mask = input_ids == hf_config.video_token_id audio_mask = input_ids == hf_config.audio_token_id + multimodal_mask = image_mask | video_mask | audio_mask + llm_input_ids = input_ids.clone() + llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id if pixel_values is not None: - vision_outputs = self.vision_tower( - pixel_values=pixel_values.to(self.vision_tower.dtype), - pixel_position_ids=image_position_ids, - ) - image_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + with self.patch_hf_config(): + image_features = self.model_cls.get_image_features( + self, pixel_values, image_position_ids, return_dict=True).pooler_output image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features) if pixel_values_videos is not None: - pixel_values_videos_flat = pixel_values_videos.flatten(0, 1) - video_position_ids_flat = video_position_ids.flatten(0, 1) if video_position_ids is not None else None - vision_outputs = self.vision_tower( - pixel_values=pixel_values_videos_flat.to(self.vision_tower.dtype), - pixel_position_ids=video_position_ids_flat, - ) - video_features = self.embed_vision(inputs_embeds=vision_outputs.last_hidden_state) + with self.patch_hf_config(): + video_features = self.get_video_features( + pixel_values_videos, video_position_ids, return_dict=True).pooler_output video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): - audio_outputs = self.audio_tower(input_features, input_features_mask, return_dict=True) - audio_features = self.embed_audio(inputs_embeds=audio_outputs.last_hidden_state) - audio_features = audio_features[audio_outputs.attention_mask] + with self.patch_hf_config(): + audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) + audio_features = audio_output.pooler_output + audio_features = audio_features[audio_output.attention_mask] audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) - - return inputs_embeds + return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids} class Gemma4SelfAttention(SelfAttention): @@ -155,9 +154,6 @@ def __init__( self.layer_type in prev_layers and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) - # Patch config so the underlying linear_qkv/q_layernorm/k_layernorm are built correctly. - # HF keeps `q_norm` on every layer, but only builds `k_norm`/`v_norm` on non-kv-shared - # layers, so replace `k_layernorm` with `IdentityOp` when this layer shares KV. orig_kv_channels = config.kv_channels orig_num_query_groups = config.num_query_groups orig_k_layernorm = submodules.k_layernorm @@ -172,10 +168,6 @@ def __init__( config.num_query_groups = orig_num_query_groups submodules.k_layernorm = orig_k_layernorm - # HF kv-shared layers only keep `q_proj` (K/V are reused from an earlier layer), so the - # default mcore `linear_qkv` shape `[Q + 2*KV, hidden]` over-allocates. Rebuild it with - # out_dim = query_projection_size so shapes match HF `q_proj` 1:1 for weight bridging. - # Mirrors attention.py L1275-L1289, minus the `+ 2 * kv_projection_size` term. if self.is_kv_shared_layer: self.linear_qkv_out_dim = self.query_projection_size self.linear_qkv = submodules.linear_qkv( @@ -191,8 +183,6 @@ def __init__( tp_group=self.pg_collection.tp, ) - # HF builds a `v_norm` (RMSNorm without learnable scale) for non-kv-shared layers. - # mcore's SelfAttention has no v_layernorm by default, so attach one explicitly here. self.v_norm = ( Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) @@ -236,7 +226,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') if self.pp_size > 1: - dist.all_reduce(is_lora, group=self.pp_group, op=dist.ReduceOp.MAX) + dist.all_reduce(is_kv_shared_layer, group=self.pp_group, op=dist.ReduceOp.MAX) is_kv_shared_layer = is_kv_shared_layer.item() if is_kv_shared_layer: self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) @@ -257,12 +247,7 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i 'per_layer_projection', 'post_per_layer_input_norm' ]: self._set_state_dict( - mg_layer, - f'{key}.weight', - hf_state_dict if to_mcore else new_hf_state_dict, - f'{key}.weight', - to_mcore, - _check_mg_param=False) + mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore, _check_mg_param=False) if to_mcore: hf_state_dict = {} else: @@ -281,29 +266,18 @@ class Gemma4TextGPTModel(GPTModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.pad_embedding = self.embedding.word_embeddings.weight text_config = self.config.hf_config.text_config - # HF: `self.unique_layer_types = set(self.config.layer_types)` — needed by the rotary - # embedding selection logic (sliding vs global) when that path is wired up. + self.text_config = text_config self.unique_layer_types = set(text_config.layer_types) - # HF: Per-Layer Embeddings (PLE). Only populated on the pre-process (PP stage 0) side, - # since the auxiliary signal is derived from `input_ids` / the token embedding output. - # See `modeling_gemma4.py` L1574-L1592 for the reference construction. Built with - # megatron-native parallel modules (mirroring `LanguageModelEmbedding` at - # `gpt_model.py` L150-L157) so the aux signal follows the TP/SP layout of the - # primary embedding. self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) if self.hidden_size_per_layer_input and self.pre_process: num_layers = text_config.num_hidden_layers hidden_size = text_config.hidden_size total_dim = num_layers * self.hidden_size_per_layer_input tp_size = self.config.tensor_model_parallel_size - # Pad aux vocab size to be TP-divisible, matching how `GPTModel` pads the main - # `padded_vocab_size` before feeding it into `VocabParallelEmbedding`. padded_vocab_size_per_layer = math.ceil(text_config.vocab_size_per_layer_input / tp_size) * tp_size - # Vocab-parallel embedding (shard on vocab dim). HF's `Gemma4TextScaledWordEmbedding` - # applies an `embed_scale = hidden_size_per_layer_input**0.5` factor on forward; - # we capture the scale as a sibling attribute so the weight shape stays 1:1 with HF. self.embed_tokens_per_layer = VocabParallelEmbedding( num_embeddings=padded_vocab_size_per_layer, embedding_dim=total_dim, @@ -313,9 +287,6 @@ def __init__(self, *args, **kwargs): ) self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5 self.per_layer_input_scale = 2.0**-0.5 - # Column-parallel projection: output dim `num_layers * hidden_size_per_layer_input` - # is split across TP ranks so each rank produces its own shard of the packed - # per-layer input tensor. self.per_layer_model_projection = build_module( TEColumnParallelLinear, hidden_size, @@ -356,8 +327,18 @@ def _set_inv_freq(self): self.config.rope_scaling = rope_scaling - def forward(self): - pass + def forward(self, *args, **kwargs): + extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} + llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) + if self.hidden_size_per_layer_input and self.pre_process: + per_layer_inputs = (self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale).reshape( + *llm_input_ids.shape, + self.text_config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + extra_block_kwargs['per_layer_inputs'] = per_layer_inputs + kwargs['extra_block_kwargs'] = extra_block_kwargs + return super().forward(*args, **kwargs) class Gemma4TransformerLayer(CustomTransformerLayer): @@ -368,22 +349,15 @@ def __init__(self, config, submodules, *args, **kwargs): hidden_size = self.config.hidden_size eps = self.config.layernorm_epsilon - # HF keeps an extra layernorm after self-attn / feedforward (before the residual add). - # mcore's TransformerLayer does not include these, so attach them here. self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - # HF: `self.register_buffer("layer_scalar", torch.ones(1))` self.register_buffer('layer_scalar', torch.ones(1)) - # HF: per-layer input projection branch, only when `hidden_size_per_layer_input` is set. self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) if self.hidden_size_per_layer_input: from transformers.activations import ACT2FN self.act_fn = ACT2FN[text_config.hidden_activation] - # Megatron-style parallel linears (see attention.py L348-361 for `linear_proj`): - # `per_layer_input_gate` is column-parallel (output dim split across TP), then its - # output is consumed by the row-parallel `per_layer_projection` which gathers along TP. self.per_layer_input_gate = build_module( TEColumnParallelLinear, hidden_size, @@ -412,9 +386,6 @@ def __init__(self, config, submodules, *args, **kwargs): ) self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - # HF: extra layernorms when the layer runs a MoE block in parallel with the dense MLP. - # Router / experts modules are gemma4-specific and intentionally skipped here; they can - # be wired by the bridge/forward override once their mcore counterparts are implemented. self.enable_moe_block = getattr(text_config, 'enable_moe_block', False) if self.enable_moe_block: self.post_feedforward_layernorm_1 = build_module( From 4ae40a254d081c430ea679f6b0e4f694b0accb01 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 11 May 2026 15:02:37 +0800 Subject: [PATCH 27/52] update --- src/mcore_bridge/model/gpt_model.py | 24 +++++++++++++----------- src/mcore_bridge/model/mm_gpts/gemma4.py | 11 ++++++++++- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 197d723..cb04b27 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -210,7 +210,7 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) - rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( + rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( decoder_input, position_ids, packed_seq_params=packed_seq_params) if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') @@ -231,8 +231,7 @@ def _preprocess( if in_inference_mode and not has_config_logger_enabled(self.config): decoder_input = WrappedTensor(decoder_input) - return (decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, - sequence_len_offset) + return (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) def _set_inv_freq(self): self.attention_scaling = 1. @@ -274,13 +273,7 @@ def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, in rotary_seq_len, packed_seq=packed_seq, ) - decoder_rotary_pos_emb = rotary_pos_emb - packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' - if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: - assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] - - return rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin + return rotary_pos_emb, rotary_pos_cos, rotary_pos_sin # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -311,7 +304,7 @@ def forward( inference_context = deprecate_inference_params(inference_context, inference_params) # There is a difference in whether rotary_pos_emb can be fused between the decoder and MTP. - decoder_input, rotary_pos_emb, decoder_rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( + decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = ( self._preprocess( input_ids=input_ids, position_ids=position_ids, @@ -319,6 +312,15 @@ def forward( inference_context=inference_context, packed_seq_params=packed_seq_params, )) + decoder_rotary_pos_emb = rotary_pos_emb + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: + assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' + if isinstance(rotary_pos_emb, dict): + for k, v in rotary_pos_emb.items(): + decoder_rotary_pos_emb[k] = v[position_ids[0]] + else: + decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] mtp_decoder_input = decoder_input if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None: diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 3b92bbf..6e3fd49 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -4,6 +4,7 @@ import torch import torch.distributed as dist from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.tensor_parallel import VocabParallelEmbedding from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules @@ -266,7 +267,6 @@ class Gemma4TextGPTModel(GPTModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.pad_embedding = self.embedding.word_embeddings.weight text_config = self.config.hf_config.text_config self.text_config = text_config self.unique_layer_types = set(text_config.layer_types) @@ -308,6 +308,15 @@ def __init__(self, *args, **kwargs): eps=self.config.layernorm_epsilon, ) + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, decoder_input, + self.config, packed_seq_params) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + full_rotary_pos_emb = self.full_rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb = {'sliding_attention': rotary_pos_emb, 'full_attention': full_rotary_pos_emb} + return rotary_pos_emb, None, None + def _set_inv_freq(self): rope_scaling = self.config.rope_scaling self.config.rope_scaling = rope_scaling['sliding_attention'] From 3123407518dae5ef689b57c8d240b9906a6c1c0a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 11 May 2026 15:35:04 +0800 Subject: [PATCH 28/52] update --- src/mcore_bridge/model/modules/__init__.py | 1 + src/mcore_bridge/model/modules/mtp_layer.py | 1 + .../model/modules/transformer_block.py | 386 ++++++++++++++++++ .../model/modules/transformer_layer.py | 1 + 4 files changed, 389 insertions(+) create mode 100644 src/mcore_bridge/model/modules/transformer_block.py diff --git a/src/mcore_bridge/model/modules/__init__.py b/src/mcore_bridge/model/modules/__init__.py index eff1bd6..885b05b 100644 --- a/src/mcore_bridge/model/modules/__init__.py +++ b/src/mcore_bridge/model/modules/__init__.py @@ -2,4 +2,5 @@ from .gated_delta_net import GatedDeltaNet from .gated_self_attention import GatedSelfAttention from .mtp_layer import MultiTokenPredictionLayer +from .transformer_block import CustomTransformerBlock from .transformer_layer import CustomTransformerLayer diff --git a/src/mcore_bridge/model/modules/mtp_layer.py b/src/mcore_bridge/model/modules/mtp_layer.py index 8be6aeb..5398b71 100644 --- a/src/mcore_bridge/model/modules/mtp_layer.py +++ b/src/mcore_bridge/model/modules/mtp_layer.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import torch import transformer_engine from contextlib import nullcontext diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py new file mode 100644 index 0000000..841c357 --- /dev/null +++ b/src/mcore_bridge/model/modules/transformer_block.py @@ -0,0 +1,386 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import torch +from contextlib import nullcontext +from megatron.core import tensor_parallel +from megatron.core.enums import Fp8Recipe +from megatron.core.extensions.transformer_engine import te_checkpoint +from megatron.core.fp4_utils import get_fp4_context +from megatron.core.fp8_utils import get_fp8_context +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.utils import WrappedTensor, deprecate_inference_params, get_pg_rank, make_viewless_tensor +from typing import List, Optional, Set, Union, cast + +try: + from megatron.core.typed_torch import apply_module +except ImportError: + apply_module = None + + +class CustomTransformerBlock(TransformerBlock): + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + context: torch.Tensor, + context_mask: torch.Tensor, + rotary_pos_emb: torch.Tensor, + attention_bias: torch.Tensor, + packed_seq_params: PackedSeqParams, + use_inner_quantization_context: bool, + padding_mask: Optional[torch.Tensor] = None, + extract_layer_indices: Optional[Set[int]] = None, + layer_offset: int = 0, + ): + """Forward method with activation checkpointing. + + Args: + extract_layer_indices (Set[int], optional): Global layer + indices (across all pipeline stages) from which to + extract features. + layer_offset (int): The global layer offset for the current + pipeline stage. Used to convert local layer indices to + global indices when checking extract_layer_indices. + + Returns: + If extract_layer_indices is empty: hidden_states tensor + If extract_layer_indices is non-empty: (hidden_states, intermediate_hidden_states) tuple + """ + if extract_layer_indices is None: + extract_layer_indices = set() + intermediate_hidden_states: List[torch.Tensor] = [] + + def custom(start: int, end: int): + + def custom_forward( + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + padding_mask=None, + ): + for index in range(start, end): + layer = self._get_layer(index) + + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context(self.config, layer.layer_number - 1) + # TODO: check if fp4 is supported in this case + elif self.config.fp4: + inner_quantization_context = get_fp4_context(self.config, layer.layer_number - 1) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + inference_context=None, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + return hidden_states, context + + return custom_forward + + def checkpoint_handler(forward_func): + """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" + # TODO: check if fp4 is supported in this case + if self.config.fp8 or self.config.fp4: + return te_checkpoint( + forward_func, + self.config.distribute_saved_activations, + tensor_parallel.random.get_cuda_rng_tracker, + self.pg_collection.tp, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + padding_mask, + ) + else: + return tensor_parallel.checkpoint( + forward_func, + self.config.distribute_saved_activations, + hidden_states, + attention_mask, + context, + context_mask, + rotary_pos_emb, + padding_mask, + ) + + if self.config.recompute_method == 'uniform': + # Uniformly divide the total number of Transformer layers and checkpoint + # the input activation of each divided chunk. + # A method to further reduce memory usage reducing checkpoints. + layer_idx = 0 + while layer_idx < self.num_layers_per_pipeline_rank: + chunk_end = min(layer_idx + self.config.recompute_num_layers, self.num_layers_per_pipeline_rank) + hidden_states, context = checkpoint_handler(custom(layer_idx, chunk_end)) + + # Feature extraction for uniform recompute: collect at end of each chunk + # Note: Only the last layer of each chunk can have features collected + for idx in range(layer_idx, chunk_end): + if (idx + layer_offset) in extract_layer_indices: + # For uniform recompute, we can only get features at chunk boundaries + # Limitation: for fine-grained extraction, use 'block' + if idx == chunk_end - 1: + intermediate_hidden_states.append(hidden_states) + + layer_idx += self.config.recompute_num_layers + + elif self.config.recompute_method == 'block': + # Checkpoint the input activation of only a set number of individual + # Transformer layers and skip the rest. + # A method fully use the device memory removing redundant re-computation. + recompute_skip_num_layers = 0 + for layer_idx in range(self.num_layers_per_pipeline_rank): + # Skip recomputation when input grad computation is not needed. + # Need to have at least one input tensor with gradient computation + # for re-enterant autograd engine. + # TODO: check if fp4 is supported in this case + if (self.config.fp8 or self.config.fp4) and not hidden_states.requires_grad: + recompute_skip_num_layers += 1 + if (layer_idx >= recompute_skip_num_layers + and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers): + hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) + else: + hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, + context_mask, rotary_pos_emb) + + # Feature extraction: collect hidden states at specified global layer indices + if (layer_idx + layer_offset) in extract_layer_indices: + intermediate_hidden_states.append(hidden_states) + else: + raise ValueError('Invalid activation recompute method.') + + # Return intermediate hidden states if feature extraction was requested + if len(extract_layer_indices) > 0: + return hidden_states, intermediate_hidden_states + + return hidden_states + + def forward( + self, + hidden_states: Union[torch.Tensor, WrappedTensor], + attention_mask: Optional[torch.Tensor], + context: Optional[torch.Tensor] = None, + context_mask: Optional[torch.Tensor] = None, + rotary_pos_emb: Optional[torch.Tensor] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + rotary_pos_cos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + extract_layer_indices: Optional[Set[int]] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + dynamic_inference_decode_only: Optional[bool] = None, + ): + """ + Perform the forward pass through the transformer block. + + This method handles the core computation of the transformer, including + self-attention, optional cross-attention, and feed-forward operations. + + Args: + hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] + where s is the sequence length, b is the batch size, and h is the hidden size. + Can be passed as a WrappedTensor during inference to avoid an obsolete + reference in the calling function. + attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking + self-attention. + context (Tensor, optional): Context tensor for cross-attention. + context_mask (Tensor, optional): Mask for cross-attention context + rotary_pos_emb (Tensor, optional): Rotary positional embeddings. + rotary_pos_cos (Optional[Tensor]): Rotary embedding cosine. + rotary_pos_sin (Optional[Tensor]): Rotary embedding sine. + rotary_pos_cos_sin (Optional[Tensor]): Combined rotary embedding cosine and sine. + Currently used exclusively for inference with dynamic batching and flashinfer RoPE. + attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable + to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. + Used as an alternative to apply attention mask for TE cuDNN attention. + inference_context (BaseInferenceContext, optional): Parameters for inference-time + optimizations. + packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence + processing. + extract_layer_indices (Set[int], optional): A set of global + layer indices (0-based across all pipeline stages) from + which to extract intermediate hidden states. If + non-empty, the forward pass will collect hidden_states + after each specified layer. + dynamic_inference_decode_only: Optional[bool]: If true, indicates that the current + inference context is for decode-only. This args is only used to uniquely + identify decode and non-decode cuda graph runners in the cuda graph manager. + + Returns: + Union[Tensor, Tuple[Tensor, List[Tensor]]]: + - If extract_layer_indices is None or empty: Returns the output hidden states tensor + of shape [s, b, h]. + - If extract_layer_indices is non-empty: Returns a tuple + of (hidden_states, intermediate_hidden_states) where + intermediate_hidden_states is a list of tensors + corresponding to hidden states after each layer in + extract_layer_indices. + """ + + inference_context = deprecate_inference_params(inference_context, inference_params) + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + + # Initialize feature collection (consistent with FastGen's Wan implementation) + if extract_layer_indices is None: + extract_layer_indices = set() + intermediate_hidden_states: List[torch.Tensor] = [] + + # Calculate the global layer offset for this pipeline stage + # This is needed to convert local layer indices to global indices for feature extraction + pp_group = self.pg_collection.pp if hasattr(self.pg_collection, 'pp') else None + layer_offset = get_transformer_layer_offset(self.config, self.vp_stage, get_pg_rank(pp_group)) + + # Delete the obsolete reference to the initial input tensor if necessary + if isinstance(hidden_states, WrappedTensor): + hidden_states = hidden_states.unwrap() + + if not self.pre_process: + # See set_input_tensor() + hidden_states = self.input_tensor + + # Viewless tensor. + # - We only need to create a viewless tensor in the case of micro batch + # size (mbs) == 1, since in this case, 'hidden_states.transpose()' + # above creates a view tensor, and '.contiguous()' is a pass-through. + # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating + # the need to make it viewless. + # + # However, we don't explicitly check mbs == 1 here because + # make_viewless_tensor() has negligible overhead when its input + # is already viewless. + # + # - For the 'else' case above, calling make_viewless_tensor() here is + # likely redundant, since p2p_communication.py (likely originator) + # already creates viewless tensors. That said, make_viewless_tensor() + # is called here to be future-proof and corner-case-proof. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + if self.config.sequence_parallel: + rng_context = tensor_parallel.get_cuda_rng_tracker().fork() + else: + rng_context = nullcontext() + + # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), + # otherwise do nothing extra at the outer level + # if we are using other fp8 recipes, then the context manager enter&exit are free + # we can wrap fp8_context within the for loop over layers, so that we can fine-grained + # control which layer will be fp8 or bf16 + # For FP4: NVFP4BlockScaling doesn't have delayed scaling, always uses inner context + if self.config.fp8: + use_outer_quantization_context = self.config.fp8_recipe == Fp8Recipe.delayed + use_inner_quantization_context = self.config.fp8_recipe != Fp8Recipe.delayed + outer_quantization_context = ( + get_fp8_context(self.config) if use_outer_quantization_context else nullcontext()) + elif self.config.fp4: + use_outer_quantization_context = False + use_inner_quantization_context = True + outer_quantization_context = nullcontext() + else: + # No quantization + use_outer_quantization_context = False + use_inner_quantization_context = False + outer_quantization_context = nullcontext() + + with rng_context, outer_quantization_context: + # Forward pass. + if self.config.recompute_granularity == 'full' and self.training: + checkpointed_result = self._checkpointed_forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + use_inner_quantization_context=use_inner_quantization_context, + padding_mask=padding_mask, + extract_layer_indices=extract_layer_indices, + layer_offset=layer_offset, + ) + # Handle return value from _checkpointed_forward + if len(extract_layer_indices) > 0: + # (hidden_states, intermediate_hidden_states) tuple + hidden_states, intermediate_hidden_states = checkpointed_result + else: + # No intermediate_hidden_states requested: just hidden_states + hidden_states = checkpointed_result + else: + for l_no, layer in enumerate(self.layers): + # Get appropriate inner quantization context + if use_inner_quantization_context: + if self.config.fp8: + inner_quantization_context = get_fp8_context(self.config, layer.layer_number - 1) + elif self.config.fp4: + inner_quantization_context = get_fp4_context(self.config, layer.layer_number - 1) + else: + inner_quantization_context = nullcontext() + else: + inner_quantization_context = nullcontext() + + with self.offload_context, inner_quantization_context: + hidden_states, context = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + attention_bias=attention_bias, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + ) + + if (torch.is_grad_enabled() and self.config.cpu_offloading + and self.group_prefetch_offload_commit_async is not None): + hidden_states = self.group_prefetch_offload_commit_async(hidden_states) + + # Extract intermediate embeddings using global layer index + if (l_no + layer_offset) in extract_layer_indices: + intermediate_hidden_states.append(hidden_states) + + # Final layer norm. + if self.final_layernorm is not None: + hidden_states = apply_module(self.final_layernorm)(cast(torch.Tensor, hidden_states)) + # TENorm produces a "viewed" tensor. This will result in schedule.py's + # deallocate_output_tensor() throwing an error, so a viewless tensor is + # created to prevent this. + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) + + # If this TransformerBlock is empty, input and output hidden states will be the same node + # on the computational graph and will lead to unexpected errors in pipeline schedules. + if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: + hidden_states = hidden_states.clone() + + if len(extract_layer_indices) > 0: + return hidden_states, intermediate_hidden_states + + return hidden_states diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index b3bd54b..bf119a2 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -1,3 +1,4 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. import enum import inspect import torch From 03dc22168805808c73617f18136758c875fd1f5d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 11 May 2026 15:59:41 +0800 Subject: [PATCH 29/52] update --- src/mcore_bridge/model/mm_gpts/qwen3_vl.py | 11 +------- src/mcore_bridge/model/register.py | 31 +++++++++++++++------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py index 92a90d8..7907bf4 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py @@ -427,16 +427,7 @@ def _get_inputs_embeds(self, inputs_embeds, inputs, visual, hf_config): class Qwen3VLLoader(ModelLoader): - - def _patch_transformer_block(self): - if hasattr(gpt_model, 'OriginTransformerBlock'): - return - gpt_model.OriginTransformerBlock = gpt_model.TransformerBlock - gpt_model.TransformerBlock = Qwen3VLTransformerBlock - - def __init__(self, config): - super().__init__(config) - self._patch_transformer_block() + transformer_block = Qwen3VLTransformerBlock register_model( diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 15b37fe..57be2b6 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -1,9 +1,11 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import megatron.core +from contextlib import contextmanager from dataclasses import dataclass from megatron.core import mpu from megatron.core.enums import ModelType from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear +from megatron.core.models.gpt import gpt_model from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec, get_gpt_mtp_block_spec from packaging import version from torch import nn @@ -13,7 +15,7 @@ from mcore_bridge.config import ModelConfig from mcore_bridge.utils import get_logger -from .modules import CustomTransformerLayer, MultiTokenPredictionLayer +from .modules import CustomTransformerBlock, CustomTransformerLayer, MultiTokenPredictionLayer if TYPE_CHECKING: from .gpt_model import GPTModel @@ -66,6 +68,7 @@ def get_model_meta(mcore_model_type: str) -> ModelMeta: class ModelLoader: model_cls = None + transformer_block = CustomTransformerBlock def __init__(self, config: ModelConfig): from mcore_bridge.model import GPTModel, MultimodalGPTModel @@ -131,17 +134,27 @@ def build_model( mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) - model = self.model_cls( - config=self.config, - transformer_layer_spec=transformer_layer_spec, - pre_process=pre_process, - post_process=post_process, - mtp_block_spec=mtp_block_spec, - vp_stage=vp_stage, - ) + with self._patch_transformer_block(): + model = self.model_cls( + config=self.config, + transformer_layer_spec=transformer_layer_spec, + pre_process=pre_process, + post_process=post_process, + mtp_block_spec=mtp_block_spec, + vp_stage=vp_stage, + ) self._set_linear_is_expert(model) return model + @contextmanager + def _patch_transformer_block(self): + TransformerBlock = gpt_model.TransformerBlock + gpt_model.TransformerBlock = self.transformer_block + try: + yield + finally: + gpt_model.TransformerBlock = TransformerBlock + def _set_linear_is_expert(self, model): for n, module in model.named_modules(): if '.local_experts.' in n and isinstance(module, (TELinear, TELayerNormColumnParallelLinear)) or isinstance( From 41465ff29c130642b50b34aa201398d9dbabce3f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 11 May 2026 16:21:04 +0800 Subject: [PATCH 30/52] update --- src/mcore_bridge/model/mm_gpts/qwen3_vl.py | 297 ++---------------- .../model/modules/transformer_block.py | 40 ++- 2 files changed, 51 insertions(+), 286 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py index 7907bf4..e95741f 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py @@ -1,299 +1,38 @@ # Copyright (c) ModelScope Contributors. All rights reserved. import torch -from contextlib import nullcontext -from megatron.core import parallel_state, tensor_parallel -from megatron.core.enums import Fp8Recipe -from megatron.core.fp8_utils import get_fp8_context -from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.models.gpt import gpt_model -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor -from typing import List, Optional, Union +from megatron.core import parallel_state from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.utils import split_cp_inputs from ..constant import ModelType +from ..modules import CustomTransformerBlock from ..register import ModelLoader, ModelMeta, register_model from .utils import HuggingFaceVit -te_checkpoint = None -try: - import transformer_engine.pytorch as te # pylint: disable=unused-import +class Qwen3VLTransformerBlock(CustomTransformerBlock): - HAVE_TE = True -except ImportError: - HAVE_TE = False + def _layer_forward(self, layer, hidden_states, **kwargs): + deepstack_visual_embeds = kwargs.pop('deepstack_visual_embeds', None) + visual_pos_masks = kwargs.pop('visual_pos_masks', None) + hidden_states, context = layer(hidden_states=hidden_states, **kwargs) + layer_number = layer.layer_number - 1 + if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_number], + ) + return hidden_states, context -if HAVE_TE: - from megatron.core.extensions.transformer_engine import te_checkpoint - - -class Qwen3VLTransformerBlock(gpt_model.TransformerBlock): - # Code borrowed from NVIDIA/Megatron-LM - - def _checkpointed_forward( - self, - hidden_states: torch.Tensor, - attention_mask: torch.Tensor, - context: torch.Tensor, - context_mask: torch.Tensor, - rotary_pos_emb: torch.Tensor, - attention_bias: torch.Tensor, - packed_seq_params: PackedSeqParams, - use_inner_fp8_context: bool, - # args for deepstack - visual_pos_masks: Optional[torch.Tensor] = None, - deepstack_visual_embeds: Optional[List[torch.Tensor]] = None, - ): - """Forward method with activation checkpointing.""" - - def custom(start: int, end: int): - - def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, visual_pos_masks, - deepstack_visual_embeds): - for index in range(start, end): - layer = self._get_layer(index) - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - - 1) if use_inner_fp8_context else nullcontext()) - with inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - inference_context=None, - packed_seq_params=packed_seq_params, - ) - # add visual features to the hidden states of first several layers - layer_number = layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - return hidden_states, context - - return custom_forward - - def checkpoint_handler(forward_func): - """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - if self.config.fp8: - return te_checkpoint( - forward_func, - self.config.distribute_saved_activations, - tensor_parallel.random.get_cuda_rng_tracker, - parallel_state.get_tensor_model_parallel_group(), - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - visual_pos_masks, - deepstack_visual_embeds, - ) - else: - return tensor_parallel.checkpoint( - forward_func, - self.config.distribute_saved_activations, - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - visual_pos_masks, - deepstack_visual_embeds, - ) - - if self.config.recompute_method == 'uniform': - # Uniformly divide the total number of Transformer layers and checkpoint - # the input activation of each divided chunk. - # A method to further reduce memory usage reducing checkpoints. - layer_idx = 0 - while layer_idx < self.num_layers_per_pipeline_rank: - hidden_states, context = checkpoint_handler( - custom(layer_idx, layer_idx + self.config.recompute_num_layers)) - - layer_idx += self.config.recompute_num_layers - - elif self.config.recompute_method == 'block': - # Checkpoint the input activation of only a set number of individual - # Transformer layers and skip the rest. - # A method fully use the device memory removing redundant re-computation. - recompute_skip_num_layers = 0 - for layer_idx in range(self.num_layers_per_pipeline_rank): - # Skip recomputation when input grad computation is not needed. - # Need to have at least one input tensor with gradient computation - # for re-enterant autograd engine. - if self.config.fp8 and not hidden_states.requires_grad: - recompute_skip_num_layers += 1 - if (layer_idx >= recompute_skip_num_layers - and layer_idx < self.config.recompute_num_layers + recompute_skip_num_layers): - hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) - else: - hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, - context_mask, rotary_pos_emb, - visual_pos_masks, deepstack_visual_embeds) - else: - raise ValueError('Invalid activation recompute method.') - - return hidden_states - - def forward( - self, - hidden_states: Union[torch.Tensor, WrappedTensor], - attention_mask: Optional[torch.Tensor], - context: Optional[torch.Tensor] = None, - context_mask: Optional[torch.Tensor] = None, - rotary_pos_emb: Optional[torch.Tensor] = None, - rotary_pos_cos: Optional[torch.Tensor] = None, - rotary_pos_sin: Optional[torch.Tensor] = None, - attention_bias: Optional[torch.Tensor] = None, - inference_context: Optional[BaseInferenceContext] = None, - packed_seq_params: Optional[PackedSeqParams] = None, - sequence_len_offset: Optional[torch.Tensor] = None, - *, - inference_params: Optional[BaseInferenceContext] = None, - # args for deepstack - visual_pos_masks: Optional[torch.Tensor] = None, - deepstack_visual_embeds: Optional[List[torch.Tensor]] = None, - ): - """ - Perform the forward pass through the transformer block. - This method handles the core computation of the transformer, including - self-attention, optional cross-attention, and feed-forward operations. - Args: - hidden_states (Union[Tensor, WrappedTensor]): Input tensor of shape [s, b, h] - where s is the sequence length, b is the batch size, and h is the hidden size. - Can be passed as a WrappedTensor during inference to avoid an obsolete - reference in the calling function. - attention_mask (Tensor): Boolean tensor of shape [1, 1, s, s] for masking - self-attention. - context (Tensor, optional): Context tensor for cross-attention. - context_mask (Tensor, optional): Mask for cross-attention context - rotary_pos_emb (Tensor, optional): Rotary positional embeddings. - attention_bias (Tensor): Bias tensor for Q * K.T of shape in shape broadcastable - to [b, num_head, sq, skv], e.g. [1, 1, sq, skv]. - Used as an alternative to apply attention mask for TE cuDNN attention. - inference_context (BaseInferenceContext, optional): Parameters for inference-time - optimizations. - packed_seq_params (PackedSeqParams, optional): Parameters for packed sequence - processing. - Returns: - Union[Tensor, Tuple[Tensor, Tensor]]: The output hidden states tensor of shape - [s, b, h], and optionally the updated context tensor if cross-attention is used. - """ + def forward(self, *args, **kwargs): + deepstack_visual_embeds = kwargs.get('deepstack_visual_embeds') if deepstack_visual_embeds is not None: assert len(deepstack_visual_embeds) <= len( self.layers), (f'len(deepstack_visual_embeds): {len(deepstack_visual_embeds)}, ' f'len(self.layers): {len(self.layers)}.') - inference_context = deprecate_inference_params(inference_context, inference_params) - - # Delete the obsolete reference to the initial input tensor if necessary - if isinstance(hidden_states, WrappedTensor): - hidden_states = hidden_states.unwrap() - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - if self.config.sequence_parallel: - rng_context = tensor_parallel.get_cuda_rng_tracker().fork() - else: - rng_context = nullcontext() - - # If fp8_recipe is delayed, wrap the entire pass with get_fp8_context(), - # otherwise do nothing extra at the outer level - # if we are using other fp8 recipes, then the context manager enter&exit are free - # we can wrap fp8_context within the for loop over layers, so that we can fine-grained - # control which layer will be fp8 or bf16 - use_outer_fp8_context = self.config.fp8 and self.config.fp8_recipe == Fp8Recipe.delayed - use_inner_fp8_context = self.config.fp8 and self.config.fp8_recipe != Fp8Recipe.delayed - outer_fp8_context = get_fp8_context(self.config) if use_outer_fp8_context else nullcontext() - - with rng_context, outer_fp8_context: - # Forward pass. - if self.config.recompute_granularity == 'full' and self.training: - hidden_states = self._checkpointed_forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - use_inner_fp8_context=use_inner_fp8_context, - visual_pos_masks=visual_pos_masks, - deepstack_visual_embeds=deepstack_visual_embeds, - ) - else: - for l_no, layer in enumerate(self.layers): - inner_fp8_context = ( - get_fp8_context(self.config, layer.layer_number - - 1) if use_inner_fp8_context else nullcontext()) - with self.offload_context, inner_fp8_context: - hidden_states, context = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - context=context, - context_mask=context_mask, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - inference_context=inference_context, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) - # add visual features to the hidden states of first several layers - layer_number = layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - - if (torch.is_grad_enabled() and self.config.cpu_offloading - and self.group_prefetch_offload_commit_async is not None): - hidden_states = self.group_prefetch_offload_commit_async(hidden_states) - - # Final layer norm. - if self.final_layernorm is not None: - hidden_states = self.final_layernorm(hidden_states) - # TENorm produces a "viewed" tensor. This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) - - # If this TransformerBlock is empty, input and output hidden states will be the same node - # on the computational graph and will lead to unexpected errors in pipeline schedules. - if not self.pre_process and len(self.layers) == 0 and not self.final_layernorm: - hidden_states = hidden_states.clone() - - return hidden_states + return super().forward(*args, **kwargs) def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, visual_embeds: torch.Tensor): diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py index 841c357..23282ed 100644 --- a/src/mcore_bridge/model/modules/transformer_block.py +++ b/src/mcore_bridge/model/modules/transformer_block.py @@ -19,6 +19,7 @@ apply_module = None +# Code borrowed from NVIDIA/Megatron-LM class CustomTransformerBlock(TransformerBlock): def _checkpointed_forward( @@ -34,6 +35,7 @@ def _checkpointed_forward( padding_mask: Optional[torch.Tensor] = None, extract_layer_indices: Optional[Set[int]] = None, layer_offset: int = 0, + **kwargs, ): """Forward method with activation checkpointing. @@ -62,6 +64,7 @@ def custom_forward( context_mask, rotary_pos_emb, padding_mask=None, + **kwargs, ): for index in range(start, end): layer = self._get_layer(index) @@ -79,8 +82,9 @@ def custom_forward( inner_quantization_context = nullcontext() with inner_quantization_context: - hidden_states, context = layer( - hidden_states=hidden_states, + hidden_states, context = self._layer_forward( + layer, + hidden_states, attention_mask=attention_mask, context=context, context_mask=context_mask, @@ -89,6 +93,7 @@ def custom_forward( inference_context=None, packed_seq_params=packed_seq_params, padding_mask=padding_mask, + **kwargs, ) return hidden_states, context @@ -109,6 +114,7 @@ def checkpoint_handler(forward_func): context_mask, rotary_pos_emb, padding_mask, + **kwargs, ) else: return tensor_parallel.checkpoint( @@ -120,6 +126,7 @@ def checkpoint_handler(forward_func): context_mask, rotary_pos_emb, padding_mask, + **kwargs, ) if self.config.recompute_method == 'uniform': @@ -159,7 +166,7 @@ def checkpoint_handler(forward_func): hidden_states, context = checkpoint_handler(custom(layer_idx, layer_idx + 1)) else: hidden_states, context = custom(layer_idx, layer_idx + 1)(hidden_states, attention_mask, context, - context_mask, rotary_pos_emb) + context_mask, rotary_pos_emb, **kwargs) # Feature extraction: collect hidden states at specified global layer indices if (layer_idx + layer_offset) in extract_layer_indices: @@ -173,6 +180,19 @@ def checkpoint_handler(forward_func): return hidden_states + def _layer_forward(self, layer, hidden_states, **kwargs): + deepstack_visual_embeds = kwargs.pop('deepstack_visual_embeds', None) + visual_pos_masks = kwargs.pop('visual_pos_masks', None) + hidden_states, context = layer(hidden_states=hidden_states, **kwargs) + layer_number = layer.layer_number - 1 + if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[layer_number], + ) + return hidden_states, context + def forward( self, hidden_states: Union[torch.Tensor, WrappedTensor], @@ -192,6 +212,7 @@ def forward( *, inference_params: Optional[BaseInferenceContext] = None, dynamic_inference_decode_only: Optional[bool] = None, + **kwargs, ): """ Perform the forward pass through the transformer block. @@ -321,6 +342,7 @@ def forward( padding_mask=padding_mask, extract_layer_indices=extract_layer_indices, layer_offset=layer_offset, + **kwargs, ) # Handle return value from _checkpointed_forward if len(extract_layer_indices) > 0: @@ -343,8 +365,9 @@ def forward( inner_quantization_context = nullcontext() with self.offload_context, inner_quantization_context: - hidden_states, context = layer( - hidden_states=hidden_states, + hidden_states, context = self._layer_forward( + layer, + hidden_states, attention_mask=attention_mask, context=context, context_mask=context_mask, @@ -357,7 +380,7 @@ def forward( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, padding_mask=padding_mask, - ) + **kwargs) if (torch.is_grad_enabled() and self.config.cpu_offloading and self.group_prefetch_offload_commit_async is not None): @@ -369,7 +392,10 @@ def forward( # Final layer norm. if self.final_layernorm is not None: - hidden_states = apply_module(self.final_layernorm)(cast(torch.Tensor, hidden_states)) + if apply_module is None: + hidden_states = self.final_layernorm(hidden_states) + else: + hidden_states = apply_module(self.final_layernorm)(cast(torch.Tensor, hidden_states)) # TENorm produces a "viewed" tensor. This will result in schedule.py's # deallocate_output_tensor() throwing an error, so a viewless tensor is # created to prevent this. From 23cd67048eef6c1c3f6df01d7159856a6c7f803f Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 11 May 2026 17:06:25 +0800 Subject: [PATCH 31/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 17 ++++++++++++++++- src/mcore_bridge/model/mm_gpts/qwen3_vl.py | 2 +- .../model/modules/transformer_block.py | 12 +----------- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 6e3fd49..e9b9ece 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -20,7 +20,7 @@ from ..constant import ModelType from ..gpt_model import GPTModel from ..mm_gpt_model import MultimodalGPTModel -from ..modules import CustomTransformerLayer +from ..modules import CustomTransformerBlock, CustomTransformerLayer from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit @@ -404,13 +404,28 @@ def __init__(self, config, submodules, *args, **kwargs): self.pre_feedforward_layernorm_2 = build_module( TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + def forward(self, *args, **kwargs): + per_layer_input = kwargs.pop('per_layer_input', None) + output, context = super().forward(*args, **kwargs) + return output, context + class Gemma4GPTModel(MultimodalGPTModel): language_model_cls = Gemma4TextGPTModel +class Gemma4TransformerBlock(CustomTransformerBlock): + + def _layer_forward(self, layer, hidden_states, **kwargs): + layer_number = layer.layer_number - 1 + per_layer_inputs = kwargs.pop('per_layer_inputs', None) + kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] + return super()._layer_forward(layer, hidden_states, **kwargs) + + class Gemma4Loader(ModelLoader): model_cls = Gemma4GPTModel + transformer_block = Gemma4TransformerBlock def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_specs = get_gpt_decoder_block_spec( diff --git a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py index e95741f..26dfb85 100644 --- a/src/mcore_bridge/model/mm_gpts/qwen3_vl.py +++ b/src/mcore_bridge/model/mm_gpts/qwen3_vl.py @@ -16,7 +16,7 @@ class Qwen3VLTransformerBlock(CustomTransformerBlock): def _layer_forward(self, layer, hidden_states, **kwargs): deepstack_visual_embeds = kwargs.pop('deepstack_visual_embeds', None) visual_pos_masks = kwargs.pop('visual_pos_masks', None) - hidden_states, context = layer(hidden_states=hidden_states, **kwargs) + hidden_states, context = super()._layer_forward(layer, hidden_states, **kwargs) layer_number = layer.layer_number - 1 if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): hidden_states = self._deepstack_process( diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py index 23282ed..0f7f735 100644 --- a/src/mcore_bridge/model/modules/transformer_block.py +++ b/src/mcore_bridge/model/modules/transformer_block.py @@ -181,17 +181,7 @@ def checkpoint_handler(forward_func): return hidden_states def _layer_forward(self, layer, hidden_states, **kwargs): - deepstack_visual_embeds = kwargs.pop('deepstack_visual_embeds', None) - visual_pos_masks = kwargs.pop('visual_pos_masks', None) - hidden_states, context = layer(hidden_states=hidden_states, **kwargs) - layer_number = layer.layer_number - 1 - if deepstack_visual_embeds is not None and layer_number in range(len(deepstack_visual_embeds)): - hidden_states = self._deepstack_process( - hidden_states, - visual_pos_masks, - deepstack_visual_embeds[layer_number], - ) - return hidden_states, context + return layer(hidden_states=hidden_states, **kwargs) def forward( self, From 8460a3aa446c9e13729be447dc135104e26dc47a Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 16 May 2026 13:21:00 +0800 Subject: [PATCH 32/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 26 ++++++++++++++++++------ src/mcore_bridge/model/register.py | 4 ++-- src/mcore_bridge/model/rope.py | 7 ++++--- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e9b9ece..430a13d 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -20,7 +20,7 @@ from ..constant import ModelType from ..gpt_model import GPTModel from ..mm_gpt_model import MultimodalGPTModel -from ..modules import CustomTransformerBlock, CustomTransformerLayer +from ..modules import TransformerBlock, TransformerLayer from ..register import ModelLoader, ModelMeta, register_model from ..rope import get_rope_inv_freq from .utils import HuggingFaceVit @@ -326,10 +326,11 @@ def _set_inv_freq(self): # full self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) self.config.rope_scaling = rope_scaling['full_attention'] - kwargs = {} + kwargs = {'layer_type': 'full_attention'} if self.config.rope_scaling['rope_type'] == 'proportional': kwargs['head_dim_key'] = 'global_head_dim' - new_inv_freq, attention_scaling = get_rope_inv_freq(self.config, **kwargs) + new_inv_freq, attention_scaling = get_rope_inv_freq( + self.config, text_config=self.config.hf_config.text_config, **kwargs) assert attention_scaling == 1, 'not support' self.full_rotary_pos_emb.inv_freq = new_inv_freq self.attention_scaling = attention_scaling @@ -340,17 +341,28 @@ def forward(self, *args, **kwargs): extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) if self.hidden_size_per_layer_input and self.pre_process: + inputs_embeds = kwargs['decoder_input'] per_layer_inputs = (self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale).reshape( *llm_input_ids.shape, self.text_config.num_hidden_layers, self.hidden_size_per_layer_input, ) + per_layer_projection = self.per_layer_model_projection( + inputs_embeds)[0] * self.per_layer_model_projection_scale + per_layer_projection = per_layer_projection.reshape( + *inputs_embeds.shape[:-1], + self.text_config.num_hidden_layers, + self.hidden_size_per_layer_input, + ) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection).transpose(0, 1) + per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + extra_block_kwargs['per_layer_inputs'] = per_layer_inputs kwargs['extra_block_kwargs'] = extra_block_kwargs return super().forward(*args, **kwargs) -class Gemma4TransformerLayer(CustomTransformerLayer): +class Gemma4TransformerLayer(TransformerLayer): def __init__(self, config, submodules, *args, **kwargs): super().__init__(config, submodules, *args, **kwargs) @@ -414,12 +426,14 @@ class Gemma4GPTModel(MultimodalGPTModel): language_model_cls = Gemma4TextGPTModel -class Gemma4TransformerBlock(CustomTransformerBlock): +class Gemma4TransformerBlock(TransformerBlock): def _layer_forward(self, layer, hidden_states, **kwargs): layer_number = layer.layer_number - 1 per_layer_inputs = kwargs.pop('per_layer_inputs', None) kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] + layer_type = self.config.hf_config.text_config.layer_types[layer_number] + kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type] return super()._layer_forward(layer, hidden_states, **kwargs) @@ -435,7 +449,7 @@ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): layer_spec.submodules.mlp.module = Gemma4MLP return layer_specs - def _set_custom_layer(self, transformer_layer_spec): + def _set_transformer_layer(self, transformer_layer_spec): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.module = Gemma4TransformerLayer diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 84bfdb4..22ad26d 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -118,7 +118,7 @@ def _set_shared_expert_gate(self, transformer_layer_spec): if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - def _set_custom_layer(self, transformer_layer_spec): + def _set_transformer_layer(self, transformer_layer_spec): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.module = TransformerLayer @@ -130,7 +130,7 @@ def build_model( ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) - self._set_custom_layer(transformer_layer_spec) + self._set_transformer_layer(transformer_layer_spec) mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py index e7db3c3..a514f43 100644 --- a/src/mcore_bridge/model/rope.py +++ b/src/mcore_bridge/model/rope.py @@ -106,12 +106,13 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(config, seq_len=None, **kwargs): +def get_rope_inv_freq(config, seq_len=None, text_config=None, **kwargs): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) - dummy_config = _get_dummy_config(config) + if text_config is None: + text_config = _get_dummy_config(config) rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)] - inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs) + inv_freq, attention_scaling = rope_init_fn(text_config, 'cpu', seq_len=seq_len, **kwargs) if attention_scaling is None: attention_scaling = 1. return inv_freq, attention_scaling From 23356d9a4a42cfe9c414e731530d7639965de98c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Sat, 16 May 2026 22:46:54 +0800 Subject: [PATCH 33/52] update --- src/mcore_bridge/config/parser.py | 9 + src/mcore_bridge/model/gpt_model.py | 7 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 209 +++++++++++++++++++++-- 3 files changed, 210 insertions(+), 15 deletions(-) diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index f6968a2..c9db851 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import torch.nn.functional as F +from functools import partial from transformers import PretrainedConfig from typing import Any, Dict @@ -155,6 +157,13 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: res['rotary_interleaved'] = True elif hf_model_type in {'gemma4'}: res['qk_layernorm'] = True + res['window_size'] = f'{window_size},0' + window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types]) + res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]' + res['softmax_scale'] = 1. + res['swiglu'] = False + res['gated_linear_unit'] = True + res['activation_func'] = partial(F.gelu, approximate='tanh') elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 8a5856c..96c3c86 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -378,6 +378,9 @@ def forward( inference_context=inference_context, ) + def _forward_output_layer(self, hidden_states, *args, **kwargs): + return self.output_layer(hidden_states, *args, **kwargs)[0] + def _postprocess( self, hidden_states, @@ -444,7 +447,7 @@ def _postprocess( loss_mask = torch.ones_like(mtp_labels) for mtp_layer_number in range(self.config.mtp_unroll_steps): # output - mtp_logits, _ = self.output_layer( + mtp_logits = self._forward_output_layer( hidden_states_list[mtp_layer_number + 1], weight=output_weight, runtime_gather_output=runtime_gather_output, @@ -509,7 +512,7 @@ def _postprocess( if self.config.task_type == 'embedding': logits = F.normalize(hidden_states, p=2, dim=-1) else: - logits, _ = self.output_layer( + logits = self._forward_output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) if self.config.task_type == 'generative_reranker': logits = gather_from_tensor_model_parallel_region(logits) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 430a13d..67f2df0 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -3,16 +3,22 @@ import math import torch import torch.distributed as dist -from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TENorm, TERowParallelLinear +from megatron.core.extensions.transformer_engine import (SplitAlongDim, TEColumnParallelLinear, TENorm, + TERowParallelLinear) +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import _yarn_get_concentration_factor_from_config from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.tensor_parallel import VocabParallelEmbedding from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import build_module +from megatron.core.utils import make_viewless_tensor, nvtx_range_pop, nvtx_range_push +from torch import Tensor from transformers import AutoModel, PretrainedConfig -from typing import Optional +from typing import Optional, Tuple from mcore_bridge.bridge import MultimodalGPTBridge from mcore_bridge.config import ModelConfig @@ -133,7 +139,7 @@ def __init__( if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set - self.use_alternative_attention = (getattr(text_config, 'attention_k_eq_v', False) and not self.is_sliding) + self.use_alternative_attention = (text_config.attention_k_eq_v and not self.is_sliding) num_key_value_heads = ( text_config.num_global_key_value_heads if self.use_alternative_attention else text_config.num_key_value_heads) @@ -187,6 +193,134 @@ def __init__( self.v_norm = ( Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) + def _forward_core_attention( + self, + query, + key, + value, + attention_mask, + attention_bias: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + ): + nvtx_range_push(suffix='core_attention') + attn_mask_type = self.attn_mask_type + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + nvtx_range_pop(suffix='core_attention') + return core_attn_out + + def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params): + nvtx_range_push(suffix='rotary_pos_emb') + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + query = apply_rotary_pos_emb( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + mscale=_yarn_get_concentration_factor_from_config(self.config), + cp_group=self.pg_collection.cp, + ) + if not self.is_kv_shared_layer and k_pos_emb is not None: + key = apply_rotary_pos_emb( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + mscale=_yarn_get_concentration_factor_from_config(self.config), + cp_group=self.pg_collection.cp, + ) + nvtx_range_pop(suffix='rotary_pos_emb') + return query, key + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]: + shared_kv_states = kwargs['shared_kv_states'] + rotary_pos_emb = kwargs.get('rotary_pos_emb') + packed_seq_params = kwargs.get('packed_seq_params') + attention_bias = kwargs.get('attention_bias') + mixed_qkv, _ = self.linear_qkv(hidden_states) + if self.is_kv_shared_layer: + query = mixed_qkv + key, value = shared_kv_states[self.layer_type] + else: + num_query_heads_per_group = (self.num_attention_heads_per_partition // self.num_query_groups_per_partition) + num_qkv_heads_per_group = num_query_heads_per_group + 2 + # If no output gate: [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + # If have output gate: [sq, b, hp] --> [sq, b, ng, (2 * np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + num_qkv_heads_per_group * self.hidden_size_per_attention_head, + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + # If no output gate: [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], None, [sq, b, ng, hn], [sq, b, ng, hn] + split_arg_list = [ + num_query_heads_per_group * self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + if SplitAlongDim is not None: + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + key = self.k_layernorm(key) + value = self.v_norm(value) + # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + query = self.q_layernorm(query) + if isinstance(rotary_pos_emb, torch.Tensor): + rotary_pos_emb = (rotary_pos_emb, ) * 2 + + query, key = self._apply_rotary(query, key, rotary_pos_emb, packed_seq_params) + if self.store_full_length_kv: + shared_kv_states[self.layer_type] = key, value + core_attn_out = self._forward_core_attention(query, key, value, attention_mask, attention_bias, + packed_seq_params) + + nvtx_range_push(suffix='linear_proj') + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix='linear_proj') + return output, bias + class Gemma4MLP(MLP): @@ -214,7 +348,7 @@ def __init__( class Gemma4Bridge(MultimodalGPTBridge): hf_post_attention_layernorm = 'pre_feedforward_layernorm' - additional_dim0_keys = {'per_layer_input_gate', 'per_layer_model_projection'} + additional_dim0_keys = {'embed_tokens_per_layer', 'per_layer_input_gate', 'per_layer_model_projection'} additional_dim1_keys = {'per_layer_projection'} def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): @@ -249,6 +383,7 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i ]: self._set_state_dict( mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore, _check_mg_param=False) + self._set_state_dict(mg_layer, 'layer_scalar', hf_state_dict, 'layer_scalar', to_mcore) if to_mcore: hf_state_dict = {} else: @@ -270,8 +405,8 @@ def __init__(self, *args, **kwargs): text_config = self.config.hf_config.text_config self.text_config = text_config self.unique_layer_types = set(text_config.layer_types) - - self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input + self.final_logit_softcapping = text_config.final_logit_softcapping if self.hidden_size_per_layer_input and self.pre_process: num_layers = text_config.num_hidden_layers hidden_size = text_config.hidden_size @@ -346,7 +481,7 @@ def forward(self, *args, **kwargs): *llm_input_ids.shape, self.text_config.num_hidden_layers, self.hidden_size_per_layer_input, - ) + ).transpose(0, 1) per_layer_projection = self.per_layer_model_projection( inputs_embeds)[0] * self.per_layer_model_projection_scale per_layer_projection = per_layer_projection.reshape( @@ -354,13 +489,22 @@ def forward(self, *args, **kwargs): self.text_config.num_hidden_layers, self.hidden_size_per_layer_input, ) - per_layer_projection = self.per_layer_projection_norm(per_layer_projection).transpose(0, 1) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale extra_block_kwargs['per_layer_inputs'] = per_layer_inputs + extra_block_kwargs['shared_kv_states'] = {} kwargs['extra_block_kwargs'] = extra_block_kwargs return super().forward(*args, **kwargs) + def _forward_output_layer(self, hidden_states, *args, **kwargs): + logits, _ = self.output_layer(hidden_states, *args, **kwargs) + if self.final_logit_softcapping is not None: + logits = logits / self.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.final_logit_softcapping + return logits + class Gemma4TransformerLayer(TransformerLayer): @@ -375,7 +519,7 @@ def __init__(self, config, submodules, *args, **kwargs): self.register_buffer('layer_scalar', torch.ones(1)) - self.hidden_size_per_layer_input = getattr(text_config, 'hidden_size_per_layer_input', None) + self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input if self.hidden_size_per_layer_input: from transformers.activations import ACT2FN self.act_fn = ACT2FN[text_config.hidden_activation] @@ -407,7 +551,7 @@ def __init__(self, config, submodules, *args, **kwargs): ) self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - self.enable_moe_block = getattr(text_config, 'enable_moe_block', False) + self.enable_moe_block = text_config.enable_moe_block if self.enable_moe_block: self.post_feedforward_layernorm_1 = build_module( TENorm, hidden_size=hidden_size, config=self.config, eps=eps) @@ -416,10 +560,49 @@ def __init__(self, config, submodules, *args, **kwargs): self.pre_feedforward_layernorm_2 = build_module( TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - def forward(self, *args, **kwargs): + def _forward_attention(self, hidden_states: Tensor, **kwargs): + context = kwargs.pop('context', None) + residual = hidden_states + input_layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + nvtx_range_push(suffix='self_attention') + attention_output, bias = self.self_attention(input_layernorm_output, **kwargs) + nvtx_range_pop(suffix='self_attention') + attention_output = self.post_attention_layernorm(attention_output) + + hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( + (attention_output, bias), residual, self.hidden_dropout) + return hidden_states, context + + def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None): + # Residual connection. + residual = hidden_states + + # Optional Layer norm post the cross-attention. + pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + mlp_output, bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + if self.enable_moe_block: + pass + mlp_output = self.post_feedforward_layernorm(mlp_output) + hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)((mlp_output, bias), residual, + self.hidden_dropout) + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + return output + + def forward(self, hidden_states, *args, **kwargs): per_layer_input = kwargs.pop('per_layer_input', None) - output, context = super().forward(*args, **kwargs) - return output, context + hidden_states, context = super().forward(hidden_states, *args, **kwargs) + if self.hidden_size_per_layer_input: + residual = hidden_states + hidden_states, _ = self.per_layer_input_gate(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = hidden_states * per_layer_input + hidden_states, _ = self.per_layer_projection(hidden_states) + hidden_states = self.post_per_layer_input_norm(hidden_states) + hidden_states = residual + hidden_states + + hidden_states *= self.layer_scalar + return hidden_states, context class Gemma4GPTModel(MultimodalGPTModel): From be8f320569eab5fb514ee400a195e0b95c95414e Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 10:29:10 +0800 Subject: [PATCH 34/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 53 ++++++++++++------------ 1 file changed, 26 insertions(+), 27 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 67f2df0..2cdcba6 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -10,7 +10,8 @@ from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import _yarn_get_concentration_factor_from_config from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec from megatron.core.packed_seq_params import PackedSeqParams -from megatron.core.tensor_parallel import VocabParallelEmbedding +from megatron.core.parallel_state import get_tensor_model_parallel_rank +from megatron.core.tensor_parallel import VocabParallelEmbedding, all_gather_last_dim_from_tensor_parallel_region from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP @@ -147,7 +148,7 @@ def __init__( # Shared KV across the trailing layers num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) - first_kv_shared_layer_idx = text_config.num_hidden_layers - num_kv_shared_layers + first_kv_shared_layer_idx = config.num_layers - num_kv_shared_layers self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] if self.is_kv_shared_layer: @@ -277,20 +278,24 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu packed_seq_params = kwargs.get('packed_seq_params') attention_bias = kwargs.get('attention_bias') mixed_qkv, _ = self.linear_qkv(hidden_states) + if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size: + mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(mixed_qkv) + idx = get_tensor_model_parallel_rank() // (self.world_size // self.config.num_query_groups) + size = mixed_qkv.size()[-1] // self.config.num_query_groups + mixed_qkv = mixed_qkv[:, :, idx * size:(idx + 1) * size] + if self.is_kv_shared_layer: query = mixed_qkv key, value = shared_kv_states[self.layer_type] else: num_query_heads_per_group = (self.num_attention_heads_per_partition // self.num_query_groups_per_partition) - num_qkv_heads_per_group = num_query_heads_per_group + 2 # If no output gate: [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] # If have output gate: [sq, b, hp] --> [sq, b, ng, (2 * np/ng + 2) * hn] new_tensor_shape = mixed_qkv.size()[:-1] + ( self.num_query_groups_per_partition, - num_qkv_heads_per_group * self.hidden_size_per_attention_head, + (num_query_heads_per_group + 2) * self.hidden_size_per_attention_head, ) mixed_qkv = mixed_qkv.view(*new_tensor_shape) - # If no output gate: [sq, b, ng, (np/ng + 2) * hn] # --> [sq, b, ng, np/ng * hn], None, [sq, b, ng, hn], [sq, b, ng, hn] split_arg_list = [ @@ -306,6 +311,10 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu value = self.v_norm(value) # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size: + idx = get_tensor_model_parallel_rank() % (self.world_size // self.config.num_query_groups) + size = query.shape[2] // (self.world_size // self.config.num_query_groups) + query = query[:, :, idx * size:(idx + 1) * size, :] query = self.q_layernorm(query) if isinstance(rotary_pos_emb, torch.Tensor): rotary_pos_emb = (rotary_pos_emb, ) * 2 @@ -335,7 +344,7 @@ def __init__( self.layer_number = layer_number text_config = config.hf_config.text_config self.enable_moe_block = text_config.enable_moe_block - first_kv_shared_layer_idx = text_config.num_hidden_layers - text_config.num_kv_shared_layers + first_kv_shared_layer_idx = config.num_layers - text_config.num_kv_shared_layers is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0 use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer ffn_hidden_size = config.ffn_hidden_size @@ -408,13 +417,9 @@ def __init__(self, *args, **kwargs): self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input self.final_logit_softcapping = text_config.final_logit_softcapping if self.hidden_size_per_layer_input and self.pre_process: - num_layers = text_config.num_hidden_layers - hidden_size = text_config.hidden_size - total_dim = num_layers * self.hidden_size_per_layer_input - tp_size = self.config.tensor_model_parallel_size - padded_vocab_size_per_layer = math.ceil(text_config.vocab_size_per_layer_input / tp_size) * tp_size + total_dim = self.config.num_layers * self.hidden_size_per_layer_input self.embed_tokens_per_layer = VocabParallelEmbedding( - num_embeddings=padded_vocab_size_per_layer, + num_embeddings=self.vocab_size, embedding_dim=total_dim, init_method=self.config.init_method, config=self.config, @@ -424,18 +429,18 @@ def __init__(self, *args, **kwargs): self.per_layer_input_scale = 2.0**-0.5 self.per_layer_model_projection = build_module( TEColumnParallelLinear, - hidden_size, + self.config.hidden_size, total_dim, config=self.config, init_method=self.config.init_method, gather_output=False, bias=False, - skip_bias_add=True, + skip_bias_add=False, is_expert=False, tp_comm_buffer_name='per_layer_model_projection', tp_group=self.pg_collection.tp, ) - self.per_layer_model_projection_scale = hidden_size**-0.5 + self.per_layer_model_projection_scale = self.config.hidden_size**-0.5 self.per_layer_projection_norm = build_module( TENorm, hidden_size=self.hidden_size_per_layer_input, @@ -477,21 +482,15 @@ def forward(self, *args, **kwargs): llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) if self.hidden_size_per_layer_input and self.pre_process: inputs_embeds = kwargs['decoder_input'] - per_layer_inputs = (self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale).reshape( - *llm_input_ids.shape, - self.text_config.num_hidden_layers, - self.hidden_size_per_layer_input, - ).transpose(0, 1) + per_layer_inputs = self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale + per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs.shape[:-1], self.config.num_layers, + -1).transpose(0, 1) per_layer_projection = self.per_layer_model_projection( inputs_embeds)[0] * self.per_layer_model_projection_scale - per_layer_projection = per_layer_projection.reshape( - *inputs_embeds.shape[:-1], - self.text_config.num_hidden_layers, - self.hidden_size_per_layer_input, - ) + per_layer_projection = per_layer_projection.reshape(*per_layer_projection.shape[:-1], + self.config.num_layers, -1) per_layer_projection = self.per_layer_projection_norm(per_layer_projection) per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale - extra_block_kwargs['per_layer_inputs'] = per_layer_inputs extra_block_kwargs['shared_kv_states'] = {} kwargs['extra_block_kwargs'] = extra_block_kwargs @@ -531,7 +530,7 @@ def __init__(self, config, submodules, *args, **kwargs): init_method=self.config.init_method, gather_output=False, bias=False, - skip_bias_add=True, + skip_bias_add=False, is_expert=False, tp_comm_buffer_name='per_layer_input_gate', tp_group=self.pg_collection.tp, From bdb29e055181e3e6bfacf9e92ab5c70887672427 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 11:21:33 +0800 Subject: [PATCH 35/52] fix pp --- src/mcore_bridge/model/mm_gpts/gemma4.py | 52 ++++++++++++++++-------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 2cdcba6..0b232a1 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -12,6 +12,8 @@ from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import get_tensor_model_parallel_rank from megatron.core.tensor_parallel import VocabParallelEmbedding, all_gather_last_dim_from_tensor_parallel_region +from megatron.core.tensor_parallel.mappings import (gather_from_tensor_model_parallel_region, + scatter_to_tensor_model_parallel_region) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP @@ -368,10 +370,11 @@ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer - is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') - if self.pp_size > 1: - dist.all_reduce(is_kv_shared_layer, group=self.pp_group, op=dist.ReduceOp.MAX) - is_kv_shared_layer = is_kv_shared_layer.item() + if not to_mcore: + is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') + if self.pp_size > 1: + dist.all_reduce(is_kv_shared_layer, group=self.pp_group, op=dist.ReduceOp.MAX) + is_kv_shared_layer = is_kv_shared_layer.item() if is_kv_shared_layer: self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) return hf_state_dict @@ -480,21 +483,38 @@ def _set_inv_freq(self): def forward(self, *args, **kwargs): extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) - if self.hidden_size_per_layer_input and self.pre_process: - inputs_embeds = kwargs['decoder_input'] - per_layer_inputs = self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale - per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs.shape[:-1], self.config.num_layers, - -1).transpose(0, 1) - per_layer_projection = self.per_layer_model_projection( - inputs_embeds)[0] * self.per_layer_model_projection_scale - per_layer_projection = per_layer_projection.reshape(*per_layer_projection.shape[:-1], - self.config.num_layers, -1) - per_layer_projection = self.per_layer_projection_norm(per_layer_projection) - per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + decoder_input = kwargs.get('decoder_input') + if self.hidden_size_per_layer_input: + if decoder_input is None: + # PP + input_tensor = self.get_input_tensor() + per_layer_inputs_dim = self.hidden_size_per_layer_input * self.config.num_layers + input_tensor, per_layer_inputs = input_tensor.split( + [input_tensor.shape[-1] - per_layer_inputs_dim, per_layer_inputs_dim], dim=-1) + self.set_input_tensor(input_tensor) + per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], self.config.num_layers, + self.hidden_size_per_layer_input) + else: + inputs_embeds = decoder_input + per_layer_inputs = self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale + per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs.shape[:-1], self.config.num_layers, + -1).transpose(0, 1) + per_layer_projection = self.per_layer_model_projection( + inputs_embeds)[0] * self.per_layer_model_projection_scale + per_layer_projection = gather_from_tensor_model_parallel_region(per_layer_projection) + per_layer_projection = per_layer_projection.reshape(*per_layer_projection.shape[:-1], + self.config.num_layers, -1) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + per_layer_inputs = scatter_to_tensor_model_parallel_region(per_layer_inputs) extra_block_kwargs['per_layer_inputs'] = per_layer_inputs extra_block_kwargs['shared_kv_states'] = {} kwargs['extra_block_kwargs'] = extra_block_kwargs - return super().forward(*args, **kwargs) + hidden_states = super().forward(*args, **kwargs) + if not self.post_process: + per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], -1) + hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) + return hidden_states def _forward_output_layer(self, hidden_states, *args, **kwargs): logits, _ = self.output_layer(hidden_states, *args, **kwargs) From d097e68b37636afed79949f5eb09d30224d73b1c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 14:15:50 +0800 Subject: [PATCH 36/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 54 +++++++++++++++++++----- 1 file changed, 43 insertions(+), 11 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 0b232a1..8db0d15 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -486,14 +486,7 @@ def forward(self, *args, **kwargs): decoder_input = kwargs.get('decoder_input') if self.hidden_size_per_layer_input: if decoder_input is None: - # PP - input_tensor = self.get_input_tensor() - per_layer_inputs_dim = self.hidden_size_per_layer_input * self.config.num_layers - input_tensor, per_layer_inputs = input_tensor.split( - [input_tensor.shape[-1] - per_layer_inputs_dim, per_layer_inputs_dim], dim=-1) - self.set_input_tensor(input_tensor) - per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], self.config.num_layers, - self.hidden_size_per_layer_input) + per_layer_inputs, shared_kv_states = self.unpack_pp_input() else: inputs_embeds = decoder_input per_layer_inputs = self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale @@ -507,15 +500,54 @@ def forward(self, *args, **kwargs): per_layer_projection = self.per_layer_projection_norm(per_layer_projection) per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale per_layer_inputs = scatter_to_tensor_model_parallel_region(per_layer_inputs) + shared_kv_states = {} extra_block_kwargs['per_layer_inputs'] = per_layer_inputs - extra_block_kwargs['shared_kv_states'] = {} + extra_block_kwargs['shared_kv_states'] = shared_kv_states kwargs['extra_block_kwargs'] = extra_block_kwargs hidden_states = super().forward(*args, **kwargs) if not self.post_process: - per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], -1) - hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) + hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states) return hidden_states + def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): + per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], -1) + hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) + flag = per_layer_inputs.new_zeros(*per_layer_inputs.shape[:2], 1) + if 'sliding_attention' in shared_kv_states: + flag[0] = 1 + sliding_states = torch.concat(shared_kv_states['sliding_attention'], -1) + sliding_states = sliding_states.view(*sliding_states.shape[:2], -1) + hidden_states = torch.concat([hidden_states, sliding_states], dim=-1) + if 'full_attention' in shared_kv_states: + flag[1] = 1 + full_states = torch.concat(shared_kv_states['full_attention'], -1) + full_states = full_states.view(*full_states.shape[:2], -1) + hidden_states = torch.concat([hidden_states, full_states], dim=-1) + hidden_states = torch.concat([hidden_states, flag], dim=-1) + return hidden_states + + def unpack_pp_input(self): + shared_kv_states = {} + input_tensor = self.get_input_tensor() + input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 1, 1], dim=-1) + per_layer_inputs_dim = self.hidden_size_per_layer_input * self.config.num_layers + if flag[1] == 1: + full_head_dim = self.text_config.global_head_dim + input_tensor, full_states = input_tensor.split( + [input_tensor.shape[-1] - full_head_dim * 2, full_head_dim * 2], dim=-1) + shared_kv_states['full_attention'] = full_states[:, :, None].split([full_head_dim, full_head_dim], -1) + if flag[0] == 1: + input_tensor, sliding_states = input_tensor.split( + [input_tensor.shape[-1] - self.config.kv_channels * 2, self.config.kv_channels * 2], dim=-1) + shared_kv_states['sliding_attention'] = sliding_states[:, :, None].split( + [self.config.kv_channels, self.config.kv_channels], -1) + input_tensor, per_layer_inputs = input_tensor.split( + [input_tensor.shape[-1] - per_layer_inputs_dim, per_layer_inputs_dim], dim=-1) + self.set_input_tensor(input_tensor) + per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], self.config.num_layers, + self.hidden_size_per_layer_input) + return per_layer_inputs, shared_kv_states + def _forward_output_layer(self, hidden_states, *args, **kwargs): logits, _ = self.output_layer(hidden_states, *args, **kwargs) if self.final_logit_softcapping is not None: From 5c29f779aaeaff2a2a7f6ae11d554dce7f964f4c Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 16:44:03 +0800 Subject: [PATCH 37/52] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 41 +++++++----- src/mcore_bridge/model/gpts/bailing_moe.py | 2 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 74 +++++++++++++--------- 3 files changed, 72 insertions(+), 45 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 4b22bec..ebf3bc9 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -346,7 +346,7 @@ def _all_gather_tp(self, tensor, tp_dim, is_expert): del output return tensor - def _broadcast_ep_pp(self, tensor, is_expert): + def _broadcast_ep_pp(self, tensor, is_expert, is_scalar: bool = False): pp_group = self.ep_pp_group if is_expert else self.pp_group pp_size = self.ep_pp_size if is_expert else self.pp_size pp_rank = self.ep_pp_rank if is_expert else self.pp_rank @@ -360,7 +360,8 @@ def _broadcast_ep_pp(self, tensor, is_expert): dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} if tensor is None: dist.broadcast(meta_data, src=src_rank, group=pp_group) - assert meta_data[0].item() > 0, f'meta_data: {meta_data}' + if not is_scalar: + assert meta_data[0].item() > 0, f'meta_data: {meta_data}' shape = meta_data[1:1 + meta_data[0]].tolist() dtype = dtype_mapping_r[meta_data[-1].item()] tensor = torch.empty(shape, device='cuda', dtype=dtype) @@ -383,7 +384,10 @@ def _get_weight( # tp/etp mg_scale_inv = None tensor = mg_weight - if tensor is not None: + is_scalar = False + if isinstance(tensor, torch.Tensor) and tensor.ndim == 0: + is_scalar = True + if tensor is not None and not is_scalar: if not isinstance(tensor, (list, tuple)): tensor = [tensor] if self._is_fp8_param(tensor[0]): @@ -392,8 +396,7 @@ def _get_weight( for t in tensor ] tensor = [t._rowwise_data for t in tensor] - del mg_weight - if tensor is not None: + del mg_weight assert isinstance(tensor, (list, tuple)), f'mg_key: {mg_key}' tensor = torch.concat(tensor, dim=0) if mg_scale_inv is not None: @@ -407,7 +410,7 @@ def _get_weight( mg_scale_inv = mg_scale_inv.view(num_local_experts * 2, -1, mg_scale_inv.shape[-1]) tensor = self._all_gather_tp(tensor, tp_dim, is_expert) - tensor = self._broadcast_ep_pp(tensor, is_expert) + tensor = self._broadcast_ep_pp(tensor, is_expert, is_scalar=is_scalar) if tensor.dtype == torch.uint8: mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert) mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert) @@ -443,8 +446,7 @@ def _set_state_dict(self, to_mcore: bool, *, offset: float = 0, - is_expert: bool = False, - _check_mg_param: bool = True): + is_expert: bool = False): if '.' in mg_key: module_key, param_key = mg_key.rsplit('.', 1) else: @@ -493,10 +495,7 @@ def _set_state_dict(self, mg_param = deep_getattr(sub_module, param_key) if to_mcore: if mg_param is None: - if _check_mg_param: - raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}') - else: - return + raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}') hf_weight = hf_state_dict[hf_key].load() if module_key in { 'embedding.word_embeddings', 'output_layer' @@ -534,7 +533,16 @@ def _filter_prefix(state_dict, prefix: str): return state_dict return {k: v for k, v in state_dict.items() if k.startswith(prefix)} - def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.long, op=dist.ReduceOp.MAX): + if to_mcore: + return tensor + tensor = torch.tensor([tensor], dtype=dtype, device='cuda') + if self.pp_size > 1: + dist.all_reduce(tensor, group=self.pp_group, op=op) + tensor = tensor.item() + return tensor + + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): config = self.config num_query_groups = ( config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) @@ -577,10 +585,13 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): self._set_weight( mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv) else: - q_dim = self.config.kv_channels * self.config.num_attention_heads // self.config.num_query_groups + kv_channels = kwargs.get('kv_channels') + if kv_channels is None: + kv_channels = self.config.kv_channels + q_dim = kv_channels * self.config.num_attention_heads // self.config.num_query_groups if self.config.attention_output_gate: q_dim *= 2 - kv_dim = self.config.kv_channels + kv_dim = kv_channels q_block = q_dim // self.fp8_block_size kv_block = kv_dim // self.fp8_block_size is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, diff --git a/src/mcore_bridge/model/gpts/bailing_moe.py b/src/mcore_bridge/model/gpts/bailing_moe.py index f3c383a..7e2a70d 100644 --- a/src/mcore_bridge/model/gpts/bailing_moe.py +++ b/src/mcore_bridge/model/gpts/bailing_moe.py @@ -71,7 +71,7 @@ class BailingMoeBridge(GPTBridge): hf_expert_bias_key = 'gate.expert_bias' hf_o_proj_key = 'dense' - def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'query_key_value.weight', to_mcore) assert not self.config.add_bias_linear return hf_state_dict diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 8db0d15..602de19 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -363,23 +363,28 @@ class Gemma4Bridge(MultimodalGPTBridge): additional_dim1_keys = {'per_layer_projection'} def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): - self._set_state_dict( - mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore, _check_mg_param=False) - self._set_state_dict( - mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore, _check_mg_param=False) + is_kv_shared_layer = self._get_is_kv_shared_layer(mg_attn, to_mcore) + self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) + if not is_kv_shared_layer: + self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) - def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + def _get_is_kv_shared_layer(self, mg_attn, to_mcore): is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer - if not to_mcore: - is_kv_shared_layer = torch.tensor([is_kv_shared_layer], dtype=torch.bool, device='cuda') - if self.pp_size > 1: - dist.all_reduce(is_kv_shared_layer, group=self.pp_group, op=dist.ReduceOp.MAX) - is_kv_shared_layer = is_kv_shared_layer.item() + return self._reduce_tensor_pp_group(is_kv_shared_layer, to_mcore) + + def _get_head_dim(self, mg_attn, to_mcore): + head_dim = 0 if mg_attn is None else mg_attn.head_dim + return self._reduce_tensor_pp_group(head_dim, to_mcore) + + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): + is_kv_shared_layer = self._get_is_kv_shared_layer(mg_attn, to_mcore) + kwargs['kv_channels'] = self._get_head_dim(mg_attn, to_mcore) + assert kwargs['kv_channels'] > 0 if is_kv_shared_layer: self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) return hf_state_dict else: - return super()._set_qkv(mg_attn, hf_state_dict, to_mcore) + return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): hf_prefix = f'{hf_prefix}{layer_idx}.' @@ -393,8 +398,7 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i 'post_attention_layernorm', 'post_feedforward_layernorm', 'per_layer_input_gate', 'per_layer_projection', 'post_per_layer_input_norm' ]: - self._set_state_dict( - mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore, _check_mg_param=False) + self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) self._set_state_dict(mg_layer, 'layer_scalar', hf_state_dict, 'layer_scalar', to_mcore) if to_mcore: hf_state_dict = {} @@ -414,6 +418,7 @@ class Gemma4TextGPTModel(GPTModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition text_config = self.config.hf_config.text_config self.text_config = text_config self.unique_layer_types = set(text_config.layer_types) @@ -510,42 +515,53 @@ def forward(self, *args, **kwargs): return hidden_states def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): - per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], -1) + per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) - flag = per_layer_inputs.new_zeros(*per_layer_inputs.shape[:2], 1) + flag = per_layer_inputs.new_zeros(*hidden_states.shape[:2], 1) if 'sliding_attention' in shared_kv_states: flag[0] = 1 sliding_states = torch.concat(shared_kv_states['sliding_attention'], -1) - sliding_states = sliding_states.view(*sliding_states.shape[:2], -1) + sliding_states = sliding_states.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, sliding_states], dim=-1) if 'full_attention' in shared_kv_states: flag[1] = 1 full_states = torch.concat(shared_kv_states['full_attention'], -1) - full_states = full_states.view(*full_states.shape[:2], -1) + full_states = full_states.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, full_states], dim=-1) hidden_states = torch.concat([hidden_states, flag], dim=-1) return hidden_states def unpack_pp_input(self): + tp_size = self.config.tensor_model_parallel_size shared_kv_states = {} input_tensor = self.get_input_tensor() + self.num_query_groups_per_partition + sequence_len = input_tensor.shape[0] * tp_size if self.config.sequence_parallel else input_tensor.shape[0] input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 1, 1], dim=-1) - per_layer_inputs_dim = self.hidden_size_per_layer_input * self.config.num_layers - if flag[1] == 1: - full_head_dim = self.text_config.global_head_dim - input_tensor, full_states = input_tensor.split( - [input_tensor.shape[-1] - full_head_dim * 2, full_head_dim * 2], dim=-1) - shared_kv_states['full_attention'] = full_states[:, :, None].split([full_head_dim, full_head_dim], -1) - if flag[0] == 1: + flag = flag.detach() + per_layer_inputs_shape = [ + sequence_len, input_tensor.shape[1], self.config.num_layers, self.hidden_size_per_layer_input // tp_size + ] + full_head_dim = self.text_config.global_head_dim + full_states_shape = per_layer_inputs_shape[:2] + [self.num_query_groups_per_partition, full_head_dim * 2] + sliding_states_shape = full_states_shape[:3] + [self.config.kv_channels * 2] + per_layer_inputs_dim = math.prod(per_layer_inputs_shape) // math.prod(input_tensor.shape[:2]) + full_states_dim = math.prod(full_states_shape) // math.prod(input_tensor.shape[:2]) + sliding_states_dim = math.prod(sliding_states_shape) // math.prod(input_tensor.shape[:2]) + if flag[1] != 0: + input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], + dim=-1) + full_states = full_states.reshape(*full_states_shape) + shared_kv_states['full_attention'] = full_states.chunk(2, -1) + if flag[0] != 0: input_tensor, sliding_states = input_tensor.split( - [input_tensor.shape[-1] - self.config.kv_channels * 2, self.config.kv_channels * 2], dim=-1) - shared_kv_states['sliding_attention'] = sliding_states[:, :, None].split( - [self.config.kv_channels, self.config.kv_channels], -1) + [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) + sliding_states = sliding_states.reshape(*sliding_states_shape) + shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1) input_tensor, per_layer_inputs = input_tensor.split( [input_tensor.shape[-1] - per_layer_inputs_dim, per_layer_inputs_dim], dim=-1) self.set_input_tensor(input_tensor) - per_layer_inputs = per_layer_inputs.view(*per_layer_inputs.shape[:2], self.config.num_layers, - self.hidden_size_per_layer_input) + per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs_shape) return per_layer_inputs, shared_kv_states def _forward_output_layer(self, hidden_states, *args, **kwargs): From 82cd1ade282632252f299615dddd51a19b3a3d74 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 17:53:42 +0800 Subject: [PATCH 38/52] update --- src/mcore_bridge/bridge/gpt_bridge.py | 76 +++++++-------- src/mcore_bridge/model/gpts/minimax_m2.py | 2 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 107 +++++++++++----------- 3 files changed, 89 insertions(+), 96 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index ebf3bc9..9ba858e 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -346,7 +346,7 @@ def _all_gather_tp(self, tensor, tp_dim, is_expert): del output return tensor - def _broadcast_ep_pp(self, tensor, is_expert, is_scalar: bool = False): + def _broadcast_ep_pp(self, tensor, is_expert): pp_group = self.ep_pp_group if is_expert else self.pp_group pp_size = self.ep_pp_size if is_expert else self.pp_size pp_rank = self.ep_pp_rank if is_expert else self.pp_rank @@ -360,8 +360,6 @@ def _broadcast_ep_pp(self, tensor, is_expert, is_scalar: bool = False): dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} if tensor is None: dist.broadcast(meta_data, src=src_rank, group=pp_group) - if not is_scalar: - assert meta_data[0].item() > 0, f'meta_data: {meta_data}' shape = meta_data[1:1 + meta_data[0]].tolist() dtype = dtype_mapping_r[meta_data[-1].item()] tensor = torch.empty(shape, device='cuda', dtype=dtype) @@ -384,9 +382,7 @@ def _get_weight( # tp/etp mg_scale_inv = None tensor = mg_weight - is_scalar = False - if isinstance(tensor, torch.Tensor) and tensor.ndim == 0: - is_scalar = True + is_scalar = isinstance(tensor, torch.Tensor) and tensor.ndim == 0 if tensor is not None and not is_scalar: if not isinstance(tensor, (list, tuple)): tensor = [tensor] @@ -410,7 +406,7 @@ def _get_weight( mg_scale_inv = mg_scale_inv.view(num_local_experts * 2, -1, mg_scale_inv.shape[-1]) tensor = self._all_gather_tp(tensor, tp_dim, is_expert) - tensor = self._broadcast_ep_pp(tensor, is_expert, is_scalar=is_scalar) + tensor = self._broadcast_ep_pp(tensor, is_expert) if tensor.dtype == torch.uint8: mg_scale_inv = self._all_gather_tp(mg_scale_inv, tp_dim, is_expert) mg_scale_inv = self._broadcast_ep_pp(mg_scale_inv, is_expert) @@ -547,17 +543,18 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): num_query_groups = ( config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) hidden_size_block = config.hidden_size // self.fp8_block_size + attention_k_eq_v = kwargs.get('attention_k_eq_v', False) + kv_proj_list = ['k_proj'] if attention_k_eq_v else ['k_proj', 'v_proj'] if to_mcore: if isinstance(mg_attn.linear_qkv, LoraParallelLinear): lora_A = hf_state_dict['q_proj.lora_A.weight'].load() - assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and ( - lora_A == hf_state_dict['v_proj.lora_A.weight'].load() - ).all(), 'Need to ensure QKV\'s lora_A are consistent' + assert all((lora_A == hf_state_dict[f'{k}.lora_A.weight'].load()).all() + for k in kv_proj_list), 'Need to ensure QKV\'s lora_A are consistent' q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load() lora_B = torch.cat([ q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])), - hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), - hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), + *(hf_state_dict[f'{k}.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])) + for k in kv_proj_list), ], dim=1).reshape((-1, q_lora_B.shape[-1])) self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A, @@ -566,20 +563,15 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): 'linear_qkv.lora_B.weight') elif not self._peft_format: linear_qkv_weight = torch.cat([ - hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), - hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), - hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), + hf_state_dict[f'{k}.weight'].load().reshape((num_query_groups, -1, config.hidden_size)) + for k in ['q_proj'] + kv_proj_list ], dim=1).reshape((-1, config.hidden_size)) qkv_scale_inv = None if 'q_proj.weight_scale_inv' in hf_state_dict: qkv_scale_inv = torch.cat([ - hf_state_dict['q_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - hf_state_dict['k_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - hf_state_dict['v_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), + hf_state_dict[f'{k}.weight_scale_inv'].load().reshape((num_query_groups, -1, hidden_size_block)) + for k in ['q_proj'] + kv_proj_list ], dim=1).reshape((-1, hidden_size_block)) self._set_weight( @@ -607,15 +599,16 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data, f'linear_qkv.lora_B.{self._adapter_name}.weight') if lora_A is not None: - self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) - for key in ['q_proj', 'k_proj', 'v_proj']: + self._peft_target_modules.update({'q_proj'} | set(kv_proj_list)) + for key in ['q_proj'] + kv_proj_list: hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() - hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, - q_dim:-kv_dim, :].reshape(-1, - lora_B.shape[-1]).clone() - hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone() + hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, q_dim:q_dim + kv_dim:, :].reshape( + -1, lora_B.shape[-1]).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, q_dim + kv_dim:, :].reshape( + -1, lora_B.shape[-1]).clone() elif not self._peft_format: mg_attn_weight, scale_inv = self._get_weight( None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') @@ -623,27 +616,27 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, config.hidden_size)) hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, config.hidden_size).clone() - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape( + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:q_dim + kv_dim, :].reshape( -1, config.hidden_size).clone() - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, - config.hidden_size).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, q_dim + kv_dim:, :].reshape( + -1, config.hidden_size).clone() if scale_inv is not None: scale_inv = scale_inv.reshape((num_query_groups, -1, hidden_size_block)) hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:, :q_block, :].reshape( -1, hidden_size_block).clone() - hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:-kv_block, :].reshape( - -1, hidden_size_block).clone() - hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape( + hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:q_block + kv_block:, :].reshape( -1, hidden_size_block).clone() + if not attention_k_eq_v: + dict['v_proj.weight_scale_inv'] = scale_inv[:, q_block + kv_block:, :].reshape( + -1, hidden_size_block).clone() del mg_attn_weight # Copy bias if (config.add_bias_linear or config.add_qkv_bias) and not self._peft_format: if to_mcore: linear_qkv_bias = torch.cat([ - hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), - hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)), - hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)), + hf_state_dict[f'{k}.bias'].load().reshape((num_query_groups, -1)) for k in ['q_proj'] + kv_proj_list ], dim=1).reshape(-1) self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') @@ -653,8 +646,9 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): if mg_attn_bias is not None: mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() - hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() - hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() + hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:q_dim + kv_dim:].reshape(-1).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, q_dim + kv_dim:].reshape(-1).clone() return hf_state_dict def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): @@ -663,21 +657,21 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int else: hf_state_dict = {} config = self.config - hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore)) + hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore, layer_idx=layer_idx)) self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, f'{self.hf_o_proj_key}.weight', to_mcore) if config.add_bias_linear: self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, f'{self.hf_o_proj_key}.bias', to_mcore) if getattr(config, 'softmax_type', 'vanilla') == 'learnable': self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore) if config.qk_layernorm: - self._set_qk_layernorm(mg_attn, hf_state_dict, to_mcore) + self._set_qk_layernorm(mg_attn, hf_state_dict, to_mcore, layer_idx=layer_idx) if to_mcore: hf_state_dict = {} else: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index 42aa4c6..853e3a5 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -74,7 +74,7 @@ class MinimaxM2Bridge(GPTBridge): hf_mlp_prefix = 'block_sparse_moe' hf_expert_bias_key = 'e_score_correction_bias' - def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_attn, 'q_norm.weight', hf_state_dict, 'q_norm.weight', to_mcore) self._set_state_dict(mg_attn, 'k_norm.weight', hf_state_dict, 'k_norm.weight', to_mcore) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 602de19..e5c2f35 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -134,35 +134,23 @@ def __init__( # Layer type / sliding attention self.layer_type = text_config.layer_types[layer_idx] self.is_sliding = self.layer_type == 'sliding_attention' - self.sliding_window = text_config.sliding_window if self.is_sliding else None # Head dim: global layers may use a different head dim than sliding ones self.head_dim = ( text_config.global_head_dim if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) - # Alternative attention (k == v) for global layers when `attention_k_eq_v` is set self.use_alternative_attention = (text_config.attention_k_eq_v and not self.is_sliding) num_key_value_heads = ( text_config.num_global_key_value_heads if self.use_alternative_attention else text_config.num_key_value_heads) - self.num_key_value_groups = text_config.num_attention_heads // num_key_value_heads - # Shared KV across the trailing layers - num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) - first_kv_shared_layer_idx = config.num_layers - num_kv_shared_layers + self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] - if self.is_kv_shared_layer: - # For shared layers, reuse KV from the last non-shared layer of the same type - self.kv_shared_layer_index = (len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) - self.store_full_length_kv = False - else: - self.kv_shared_layer_index = None - # Non-shared layers that are the last of their type in `prev_layers` must keep full KV - self.store_full_length_kv = ( - self.layer_type in prev_layers - and layer_idx == len(prev_layers) - 1 - prev_layers[::-1].index(self.layer_type)) + self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len( + prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx]) orig_kv_channels = config.kv_channels orig_num_query_groups = config.num_query_groups @@ -178,11 +166,13 @@ def __init__( config.num_query_groups = orig_num_query_groups submodules.k_layernorm = orig_k_layernorm - if self.is_kv_shared_layer: - self.linear_qkv_out_dim = self.query_projection_size + if self.is_kv_shared_layer or self.use_alternative_attention: + linear_qkv_dim = self.query_projection_size + if not self.is_kv_shared_layer: + linear_qkv_dim += self.kv_projection_size self.linear_qkv = submodules.linear_qkv( self.config.hidden_size, - self.linear_qkv_out_dim, + linear_qkv_dim, config=self.config, init_method=self.config.init_method, gather_output=False, @@ -290,27 +280,21 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu query = mixed_qkv key, value = shared_kv_states[self.layer_type] else: + kv_heads_per_group = 1 if self.use_alternative_attention else 2 num_query_heads_per_group = (self.num_attention_heads_per_partition // self.num_query_groups_per_partition) - # If no output gate: [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] - # If have output gate: [sq, b, hp] --> [sq, b, ng, (2 * np/ng + 2) * hn] new_tensor_shape = mixed_qkv.size()[:-1] + ( self.num_query_groups_per_partition, - (num_query_heads_per_group + 2) * self.hidden_size_per_attention_head, + (num_query_heads_per_group + kv_heads_per_group) * self.hidden_size_per_attention_head, ) mixed_qkv = mixed_qkv.view(*new_tensor_shape) - # If no output gate: [sq, b, ng, (np/ng + 2) * hn] - # --> [sq, b, ng, np/ng * hn], None, [sq, b, ng, hn], [sq, b, ng, hn] - split_arg_list = [ - num_query_heads_per_group * self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ] + split_arg_list = [num_query_heads_per_group * self.hidden_size_per_attention_head + ] + [self.hidden_size_per_attention_head] * kv_heads_per_group if SplitAlongDim is not None: - (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + (query, key, value) = SplitAlongDim(mixed_qkv, len(split_arg_list), split_arg_list) else: (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) - key = self.k_layernorm(key) - value = self.v_norm(value) + key = self.k_layernorm(key) + value = self.v_norm(value) # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size: @@ -362,28 +346,38 @@ class Gemma4Bridge(MultimodalGPTBridge): additional_dim0_keys = {'embed_tokens_per_layer', 'per_layer_input_gate', 'per_layer_model_projection'} additional_dim1_keys = {'per_layer_projection'} - def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): - is_kv_shared_layer = self._get_is_kv_shared_layer(mg_attn, to_mcore) + def __init__(self, config: ModelConfig): + super().__init__(config) + self.text_config = config.hf_config.text_config + self.hidden_size_per_layer_input = self.text_config.hidden_size_per_layer_input + + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): + layer_idx = kwargs['layer_idx'] + is_kv_shared_layer = self._get_is_kv_shared_layer(layer_idx) self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) if not is_kv_shared_layer: self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) - def _get_is_kv_shared_layer(self, mg_attn, to_mcore): - is_kv_shared_layer = False if mg_attn is None else mg_attn.is_kv_shared_layer - return self._reduce_tensor_pp_group(is_kv_shared_layer, to_mcore) - - def _get_head_dim(self, mg_attn, to_mcore): - head_dim = 0 if mg_attn is None else mg_attn.head_dim - return self._reduce_tensor_pp_group(head_dim, to_mcore) + def _get_is_kv_shared_layer(self, layer_idx): + text_config = self.text_config + num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = self.config.num_layers - num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + return is_kv_shared_layer def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): - is_kv_shared_layer = self._get_is_kv_shared_layer(mg_attn, to_mcore) - kwargs['kv_channels'] = self._get_head_dim(mg_attn, to_mcore) - assert kwargs['kv_channels'] > 0 + text_config = self.text_config + layer_idx = kwargs['layer_idx'] + is_sliding = text_config.layer_types[layer_idx] == 'sliding_attention' + head_dim = ( + text_config.global_head_dim if not is_sliding and text_config.global_head_dim else text_config.head_dim) + is_kv_shared_layer = self._get_is_kv_shared_layer(layer_idx) if is_kv_shared_layer: self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) return hf_state_dict else: + kwargs['kv_channels'] = head_dim + kwargs['attention_k_eq_v'] = text_config.attention_k_eq_v and not is_sliding return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): @@ -394,11 +388,11 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = {} hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) - for key in [ - 'post_attention_layernorm', 'post_feedforward_layernorm', 'per_layer_input_gate', - 'per_layer_projection', 'post_per_layer_input_norm' - ]: + for key in ['post_attention_layernorm', 'post_feedforward_layernorm']: self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) + if self.hidden_size_per_layer_input: + for key in ['per_layer_input_gate', 'per_layer_projection', 'post_per_layer_input_norm']: + self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) self._set_state_dict(mg_layer, 'layer_scalar', hf_state_dict, 'layer_scalar', to_mcore) if to_mcore: hf_state_dict = {} @@ -409,9 +403,10 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) - for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']: - self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight', - to_mcore) + if self.hidden_size_per_layer_input: + for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']: + self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight', + to_mcore) class Gemma4TextGPTModel(GPTModel): @@ -421,6 +416,7 @@ def __init__(self, *args, **kwargs): self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition text_config = self.config.hf_config.text_config self.text_config = text_config + self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) self.unique_layer_types = set(text_config.layer_types) self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input self.final_logit_softcapping = text_config.final_logit_softcapping @@ -490,6 +486,7 @@ def forward(self, *args, **kwargs): llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) decoder_input = kwargs.get('decoder_input') if self.hidden_size_per_layer_input: + assert self.num_kv_shared_layers > 0, 'not support' if decoder_input is None: per_layer_inputs, shared_kv_states = self.unpack_pp_input() else: @@ -506,11 +503,13 @@ def forward(self, *args, **kwargs): per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale per_layer_inputs = scatter_to_tensor_model_parallel_region(per_layer_inputs) shared_kv_states = {} - extra_block_kwargs['per_layer_inputs'] = per_layer_inputs - extra_block_kwargs['shared_kv_states'] = shared_kv_states + extra_block_kwargs['per_layer_inputs'] = per_layer_inputs + extra_block_kwargs['shared_kv_states'] = shared_kv_states + else: + assert self.num_kv_shared_layers == 0, 'not support' kwargs['extra_block_kwargs'] = extra_block_kwargs hidden_states = super().forward(*args, **kwargs) - if not self.post_process: + if self.hidden_size_per_layer_input and not self.post_process: hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states) return hidden_states From 472f2f7de44163b84544406666fd3191eef7a137 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 17:55:16 +0800 Subject: [PATCH 39/52] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 9ba858e..ee8d36b 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -607,8 +607,9 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, q_dim:q_dim + kv_dim:, :].reshape( -1, lora_B.shape[-1]).clone() if not attention_k_eq_v: - hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, q_dim + kv_dim:, :].reshape( - -1, lora_B.shape[-1]).clone() + hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, + -kv_dim:, :].reshape(-1, + lora_B.shape[-1]).clone() elif not self._peft_format: mg_attn_weight, scale_inv = self._get_weight( None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') @@ -619,7 +620,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:q_dim + kv_dim, :].reshape( -1, config.hidden_size).clone() if not attention_k_eq_v: - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, q_dim + kv_dim:, :].reshape( + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape( -1, config.hidden_size).clone() if scale_inv is not None: scale_inv = scale_inv.reshape((num_query_groups, -1, hidden_size_block)) @@ -628,8 +629,9 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:q_block + kv_block:, :].reshape( -1, hidden_size_block).clone() if not attention_k_eq_v: - dict['v_proj.weight_scale_inv'] = scale_inv[:, q_block + kv_block:, :].reshape( - -1, hidden_size_block).clone() + dict['v_proj.weight_scale_inv'] = scale_inv[:, + -kv_block:, :].reshape(-1, + hidden_size_block).clone() del mg_attn_weight # Copy bias @@ -648,7 +650,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:q_dim + kv_dim:].reshape(-1).clone() if not attention_k_eq_v: - hf_state_dict['v_proj.bias'] = mg_attn_bias[:, q_dim + kv_dim:].reshape(-1).clone() + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() return hf_state_dict def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): From 13dc948270602fd47c5c57f07d31939fa8602e42 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 18:06:54 +0800 Subject: [PATCH 40/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index e5c2f35..c0fface 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -290,9 +290,14 @@ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tu split_arg_list = [num_query_heads_per_group * self.hidden_size_per_attention_head ] + [self.hidden_size_per_attention_head] * kv_heads_per_group if SplitAlongDim is not None: - (query, key, value) = SplitAlongDim(mixed_qkv, len(split_arg_list), split_arg_list) + qkv = SplitAlongDim(mixed_qkv, len(split_arg_list), split_arg_list) else: - (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + qkv = torch.split(mixed_qkv, split_arg_list, dim=3) + if self.use_alternative_attention: + query, key = qkv + value = key + else: + query, key, value = qkv key = self.k_layernorm(key) value = self.v_norm(value) # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] @@ -485,6 +490,7 @@ def forward(self, *args, **kwargs): extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) decoder_input = kwargs.get('decoder_input') + shared_kv_states = {} if self.hidden_size_per_layer_input: assert self.num_kv_shared_layers > 0, 'not support' if decoder_input is None: @@ -502,11 +508,10 @@ def forward(self, *args, **kwargs): per_layer_projection = self.per_layer_projection_norm(per_layer_projection) per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale per_layer_inputs = scatter_to_tensor_model_parallel_region(per_layer_inputs) - shared_kv_states = {} extra_block_kwargs['per_layer_inputs'] = per_layer_inputs - extra_block_kwargs['shared_kv_states'] = shared_kv_states else: assert self.num_kv_shared_layers == 0, 'not support' + extra_block_kwargs['shared_kv_states'] = shared_kv_states kwargs['extra_block_kwargs'] = extra_block_kwargs hidden_states = super().forward(*args, **kwargs) if self.hidden_size_per_layer_input and not self.post_process: @@ -680,7 +685,8 @@ class Gemma4TransformerBlock(TransformerBlock): def _layer_forward(self, layer, hidden_states, **kwargs): layer_number = layer.layer_number - 1 per_layer_inputs = kwargs.pop('per_layer_inputs', None) - kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] + if per_layer_inputs is not None: + kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] layer_type = self.config.hf_config.text_config.layer_types[layer_number] kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type] return super()._layer_forward(layer, hidden_states, **kwargs) From 3d3e90e5b24c3232a5ead06e204b01de3b301fbc Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 18:07:29 +0800 Subject: [PATCH 41/52] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index ee8d36b..6ecdcaa 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -629,7 +629,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:q_block + kv_block:, :].reshape( -1, hidden_size_block).clone() if not attention_k_eq_v: - dict['v_proj.weight_scale_inv'] = scale_inv[:, + hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape(-1, hidden_size_block).clone() del mg_attn_weight From ed370d90f95b77c4345228cb7fb7edff6a487dbb Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 22:03:32 +0800 Subject: [PATCH 42/52] update --- src/mcore_bridge/bridge/gpt_bridge.py | 23 +++++++++++---------- src/mcore_bridge/config/parser.py | 10 +++++---- src/mcore_bridge/model/gpt_model.py | 1 + src/mcore_bridge/model/mm_gpt_model.py | 2 ++ src/mcore_bridge/model/mm_gpts/gemma4.py | 26 ++++++++++++++++-------- 5 files changed, 38 insertions(+), 24 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 6ecdcaa..57477d7 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -540,11 +540,16 @@ def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.long, op=dist.Re def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): config = self.config - num_query_groups = ( - config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) + num_query_groups = kwargs.get('num_query_groups') + if num_query_groups is None: + num_query_groups = ( + config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) hidden_size_block = config.hidden_size // self.fp8_block_size attention_k_eq_v = kwargs.get('attention_k_eq_v', False) kv_proj_list = ['k_proj'] if attention_k_eq_v else ['k_proj', 'v_proj'] + kv_channels = kwargs.get('kv_channels') + if kv_channels is None: + kv_channels = self.config.kv_channels if to_mcore: if isinstance(mg_attn.linear_qkv, LoraParallelLinear): lora_A = hf_state_dict['q_proj.lora_A.weight'].load() @@ -577,10 +582,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): self._set_weight( mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv) else: - kv_channels = kwargs.get('kv_channels') - if kv_channels is None: - kv_channels = self.config.kv_channels - q_dim = kv_channels * self.config.num_attention_heads // self.config.num_query_groups + q_dim = kv_channels * self.config.num_attention_heads // num_query_groups if self.config.attention_output_gate: q_dim *= 2 kv_dim = kv_channels @@ -604,7 +606,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() - hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, q_dim:q_dim + kv_dim:, :].reshape( + hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, q_dim:q_dim + kv_dim, :].reshape( -1, lora_B.shape[-1]).clone() if not attention_k_eq_v: hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, @@ -629,9 +631,8 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:q_block + kv_block:, :].reshape( -1, hidden_size_block).clone() if not attention_k_eq_v: - hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, - -kv_block:, :].reshape(-1, - hidden_size_block).clone() + hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape( + -1, hidden_size_block).clone() del mg_attn_weight # Copy bias @@ -648,7 +649,7 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): if mg_attn_bias is not None: mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() - hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:q_dim + kv_dim:].reshape(-1).clone() + hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:q_dim + kv_dim].reshape(-1).clone() if not attention_k_eq_v: hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() return hf_state_dict diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index c9db851..f00826d 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -157,9 +157,11 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: res['rotary_interleaved'] = True elif hf_model_type in {'gemma4'}: res['qk_layernorm'] = True - res['window_size'] = f'{window_size},0' - window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types]) - res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]' + # If set to "vision", pass attention_mask manually. + if hf_config.text_config.use_bidirectional_attention is None: + res['window_size'] = f'{window_size - 1},0' + window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types]) + res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]' res['softmax_scale'] = 1. res['swiglu'] = False res['gated_linear_unit'] = True @@ -172,7 +174,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: res['quick_geglu'] = True res['activation_func_clamp_value'] = 7 res['glu_linear_offset'] = 1 - res['window_size'] = f'{window_size},0' + res['window_size'] = f'{window_size - 1},0' if layer_types is None: res['window_attn_skip_freq'] = '2' else: diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 778e2c3..f744ba7 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -55,6 +55,7 @@ def sharded_state_dict( class GPTModel(McoreGPTModel): config: ModelConfig + extra_forward_keys = [] def __init__( self, diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py index 2a5f659..2a80c04 100644 --- a/src/mcore_bridge/model/mm_gpt_model.py +++ b/src/mcore_bridge/model/mm_gpt_model.py @@ -79,6 +79,7 @@ def forward( packed_seq_params: PackedSeqParams = None, **kwargs, ) -> torch.Tensor: + extra_kwargs = {k: kwargs[k] for k in self.language_model.extra_forward_keys} if decoder_input is not None: pass elif self.pre_process: @@ -90,6 +91,7 @@ def forward( # decoder will get hidden_states from encoder.input_tensor decoder_input = None kwargs = {} + kwargs.update(extra_kwargs) return self.language_model( input_ids=input_ids, position_ids=position_ids, diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index c0fface..5b2b807 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -141,7 +141,7 @@ def __init__( if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) self.use_alternative_attention = (text_config.attention_k_eq_v and not self.is_sliding) - num_key_value_heads = ( + self.num_key_value_heads = ( text_config.num_global_key_value_heads if self.use_alternative_attention else text_config.num_key_value_heads) # Shared KV across the trailing layers @@ -156,7 +156,7 @@ def __init__( orig_num_query_groups = config.num_query_groups orig_k_layernorm = submodules.k_layernorm config.kv_channels = self.head_dim - config.num_query_groups = num_key_value_heads + config.num_query_groups = self.num_key_value_heads if self.is_kv_shared_layer: submodules.k_layernorm = IdentityOp try: @@ -382,7 +382,11 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): return hf_state_dict else: kwargs['kv_channels'] = head_dim - kwargs['attention_k_eq_v'] = text_config.attention_k_eq_v and not is_sliding + use_alternative_attention = text_config.attention_k_eq_v and not is_sliding + kwargs['attention_k_eq_v'] = use_alternative_attention + kwargs['num_query_groups'] = ( + text_config.num_global_key_value_heads + if use_alternative_attention else text_config.num_key_value_heads) return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): @@ -415,6 +419,7 @@ def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): class Gemma4TextGPTModel(GPTModel): + extra_forward_keys = ['mm_token_type_ids'] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -490,6 +495,7 @@ def forward(self, *args, **kwargs): extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) decoder_input = kwargs.get('decoder_input') + mm_token_type_ids = extra_block_kwargs.pop('mm_token_type_ids', None) shared_kv_states = {} if self.hidden_size_per_layer_input: assert self.num_kv_shared_layers > 0, 'not support' @@ -513,6 +519,8 @@ def forward(self, *args, **kwargs): assert self.num_kv_shared_layers == 0, 'not support' extra_block_kwargs['shared_kv_states'] = shared_kv_states kwargs['extra_block_kwargs'] = extra_block_kwargs + if self.text_config.use_bidirectional_attention == 'vision': + pass hidden_states = super().forward(*args, **kwargs) if self.hidden_size_per_layer_input and not self.post_process: hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states) @@ -521,14 +529,14 @@ def forward(self, *args, **kwargs): def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) - flag = per_layer_inputs.new_zeros(*hidden_states.shape[:2], 1) + flag = per_layer_inputs.new_zeros(*hidden_states.shape[:2], 2) if 'sliding_attention' in shared_kv_states: - flag[0] = 1 + flag[0, 0, 0] = 1 sliding_states = torch.concat(shared_kv_states['sliding_attention'], -1) sliding_states = sliding_states.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, sliding_states], dim=-1) if 'full_attention' in shared_kv_states: - flag[1] = 1 + flag[0, 0, 1] = 1 full_states = torch.concat(shared_kv_states['full_attention'], -1) full_states = full_states.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, full_states], dim=-1) @@ -541,7 +549,7 @@ def unpack_pp_input(self): input_tensor = self.get_input_tensor() self.num_query_groups_per_partition sequence_len = input_tensor.shape[0] * tp_size if self.config.sequence_parallel else input_tensor.shape[0] - input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 1, 1], dim=-1) + input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 2, 2], dim=-1) flag = flag.detach() per_layer_inputs_shape = [ sequence_len, input_tensor.shape[1], self.config.num_layers, self.hidden_size_per_layer_input // tp_size @@ -552,12 +560,12 @@ def unpack_pp_input(self): per_layer_inputs_dim = math.prod(per_layer_inputs_shape) // math.prod(input_tensor.shape[:2]) full_states_dim = math.prod(full_states_shape) // math.prod(input_tensor.shape[:2]) sliding_states_dim = math.prod(sliding_states_shape) // math.prod(input_tensor.shape[:2]) - if flag[1] != 0: + if flag[0, 0, 0].item() != 0: input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], dim=-1) full_states = full_states.reshape(*full_states_shape) shared_kv_states['full_attention'] = full_states.chunk(2, -1) - if flag[0] != 0: + if flag[0, 0, 1].item() != 0: input_tensor, sliding_states = input_tensor.split( [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) sliding_states = sliding_states.reshape(*sliding_states_shape) From b3dfdb945ac5e501e8d15f04b72c0920fd4ccee1 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 18 May 2026 22:37:08 +0800 Subject: [PATCH 43/52] update --- README.md | 1 + README_zh.md | 1 + src/mcore_bridge/model/mm_gpts/gemma4.py | 25 +++++++++++++++++++++++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 1473a16..bf9423a 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,7 @@ The following is the list of models supported by MCore-Bridge: | Series | model_type | | -------- | ------------------------------------------------------------ | | Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni
qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr
qwen3_5, qwen3_5_moe | +| Gemma | gemma4 | | GLM | glm4v, glm4v_moe | | Kimi | kimi_vl | | InternVL | internvl_chat, internvl | diff --git a/README_zh.md b/README_zh.md index 5f170e7..b7ee63d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -142,6 +142,7 @@ uv pip install -e . --torch-backend=auto | 系列 | model_type | | -------- | ------------------------------------------------------------ | | Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni
qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr
qwen3_5, qwen3_5_moe | +| Gemma | gemma4 | | GLM | glm4v, glm4v_moe | | Kimi | kimi_vl | | InternVL | internvl_chat, internvl | diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 5b2b807..9a52635 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -519,13 +519,35 @@ def forward(self, *args, **kwargs): assert self.num_kv_shared_layers == 0, 'not support' extra_block_kwargs['shared_kv_states'] = shared_kv_states kwargs['extra_block_kwargs'] = extra_block_kwargs + attention_mask = kwargs.get('attention_mask') + kwargs['attention_mask'] = {'sliding_attention': attention_mask, 'full_attention': attention_mask} if self.text_config.use_bidirectional_attention == 'vision': - pass + kwargs['attention_mask']['sliding_attention'] = self._create_sliding_attention_mask( + kwargs['attention_mask'], mm_token_type_ids) hidden_states = super().forward(*args, **kwargs) if self.hidden_size_per_layer_input and not self.post_process: hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states) return hidden_states + def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids): + window_size = self.text_config.sliding_window - 1 + seq_len = attention_mask.shape[-1] + + window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda') + window_mask = ~torch.triu(window_mask, diagonal=-window_size) + + is_vision = mm_token_type_ids > 0 + is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) + is_prev_vision[:, 0] = False + vision_group_ids = torch.cumsum((is_vision & ~is_prev_vision).int(), dim=1) - 1 + vision_group_ids = torch.where(is_vision, vision_group_ids, + torch.full_like(vision_group_ids, -1)) + + q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1) + k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2) + same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0) + return attention_mask | window_mask | ~same_vision_group + def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) @@ -697,6 +719,7 @@ def _layer_forward(self, layer, hidden_states, **kwargs): kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] layer_type = self.config.hf_config.text_config.layer_types[layer_number] kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type] + kwargs['attention_mask'] = kwargs['attention_mask'][layer_type] return super()._layer_forward(layer, hidden_states, **kwargs) From 8744cc4b684ed2548f15020af64ec0a82f3ddf12 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 02:54:59 +0800 Subject: [PATCH 44/52] update --- src/mcore_bridge/model/gpt_model.py | 5 ++++- src/mcore_bridge/model/mm_gpts/gemma4.py | 7 +++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index f744ba7..84bf56d 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -334,9 +334,12 @@ def forward( input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0) self.set_input_tensor(input_tensor) kwargs = {} + full_attention_mask = attention_mask + if isinstance(full_attention_mask, dict): + full_attention_mask = full_attention_mask['full_attention'] if mcore_016 and attention_mask is not None: assert packed_seq_params is None - padding_mask = ~((~attention_mask).sum(dim=(1, 2)) > 0) + padding_mask = ~((~full_attention_mask).sum(dim=(1, 2)) > 0) if self.config.context_parallel_size > 1: padding_mask = split_cp_inputs(padding_mask, None, 1) tp_size = self.config.tensor_model_parallel_size diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 9a52635..304b0e7 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -523,7 +523,7 @@ def forward(self, *args, **kwargs): kwargs['attention_mask'] = {'sliding_attention': attention_mask, 'full_attention': attention_mask} if self.text_config.use_bidirectional_attention == 'vision': kwargs['attention_mask']['sliding_attention'] = self._create_sliding_attention_mask( - kwargs['attention_mask'], mm_token_type_ids) + attention_mask, mm_token_type_ids) hidden_states = super().forward(*args, **kwargs) if self.hidden_size_per_layer_input and not self.post_process: hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states) @@ -540,13 +540,12 @@ def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids): is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) is_prev_vision[:, 0] = False vision_group_ids = torch.cumsum((is_vision & ~is_prev_vision).int(), dim=1) - 1 - vision_group_ids = torch.where(is_vision, vision_group_ids, - torch.full_like(vision_group_ids, -1)) + vision_group_ids = torch.where(is_vision, vision_group_ids, torch.full_like(vision_group_ids, -1)) q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1) k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2) same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0) - return attention_mask | window_mask | ~same_vision_group + return (attention_mask | window_mask) & ~same_vision_group def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) From 3c3c4b892c0cba51cd1a4609329610a3c224f434 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 11:28:12 +0800 Subject: [PATCH 45/52] update --- src/mcore_bridge/model/mm_gpts/gemma4.py | 53 ++++++++++++++++++++---- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 304b0e7..c51fd00 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -2,19 +2,20 @@ import copy import math import torch -import torch.distributed as dist from megatron.core.extensions.transformer_engine import (SplitAlongDim, TEColumnParallelLinear, TENorm, TERowParallelLinear) from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import _yarn_get_concentration_factor_from_config from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import get_tensor_model_parallel_rank from megatron.core.tensor_parallel import VocabParallelEmbedding, all_gather_last_dim_from_tensor_parallel_region from megatron.core.tensor_parallel.mappings import (gather_from_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region) from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP from megatron.core.transformer.spec_utils import build_module @@ -157,6 +158,8 @@ def __init__( orig_k_layernorm = submodules.k_layernorm config.kv_channels = self.head_dim config.num_query_groups = self.num_key_value_heads + if self.is_sliding and config.window_size is None: + kwargs['attn_mask_type'] = AttnMaskType.arbitrary if self.is_kv_shared_layer: submodules.k_layernorm = IdentityOp try: @@ -334,7 +337,6 @@ def __init__( ): self.layer_number = layer_number text_config = config.hf_config.text_config - self.enable_moe_block = text_config.enable_moe_block first_kv_shared_layer_idx = config.num_layers - text_config.num_kv_shared_layers is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0 use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer @@ -345,6 +347,13 @@ def __init__( finally: config.ffn_hidden_size = ffn_hidden_size +class Gemma4MoELayer(MoELayer): + def __init__(self, config, *args, **kwargs): + super().__init__(config, *args, **kwargs) + + def forward(self, hidden_states: torch.Tensor): + return super().forward(hidden_states) + class Gemma4Bridge(MultimodalGPTBridge): hf_post_attention_layernorm = 'pre_feedforward_layernorm' @@ -389,6 +398,19 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): if use_alternative_attention else text_config.num_key_value_heads) return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False): + mg_mlp = None if mg_layer is None else mg_layer.mlp + hf_state_dict.update( + self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, + f'{self.hf_post_attention_layernorm}.weight', to_mcore) + if self.text_config.enable_moe_block: + mg_experts = None if mg_layer is None else mg_layer.experts_mlp + hf_state_dict.update( + self._set_moe_state( + mg_experts, hf_state_dict, '', layer_idx, to_mcore, is_mtp=is_mtp)) + return hf_state_dict + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): hf_prefix = f'{hf_prefix}{layer_idx}.' if to_mcore: @@ -403,6 +425,9 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i for key in ['per_layer_input_gate', 'per_layer_projection', 'post_per_layer_input_norm']: self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) self._set_state_dict(mg_layer, 'layer_scalar', hf_state_dict, 'layer_scalar', to_mcore) + if self.text_config.enable_moe_block: + for key in ['post_feedforward_layernorm_1', 'post_feedforward_layernorm_2']: + self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) if to_mcore: hf_state_dict = {} else: @@ -568,7 +593,6 @@ def unpack_pp_input(self): tp_size = self.config.tensor_model_parallel_size shared_kv_states = {} input_tensor = self.get_input_tensor() - self.num_query_groups_per_partition sequence_len = input_tensor.shape[0] * tp_size if self.config.sequence_parallel else input_tensor.shape[0] input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 2, 2], dim=-1) flag = flag.detach() @@ -611,6 +635,9 @@ class Gemma4TransformerLayer(TransformerLayer): def __init__(self, config, submodules, *args, **kwargs): super().__init__(config, submodules, *args, **kwargs) text_config = config.hf_config.text_config + self.enable_moe_block = text_config.enable_moe_block + if self.enable_moe_block: + self.experts_mlp = self._build_mlp(submodules.experts_mlp) hidden_size = self.config.hidden_size eps = self.config.layernorm_epsilon @@ -651,14 +678,11 @@ def __init__(self, config, submodules, *args, **kwargs): ) self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - self.enable_moe_block = text_config.enable_moe_block if self.enable_moe_block: self.post_feedforward_layernorm_1 = build_module( TENorm, hidden_size=hidden_size, config=self.config, eps=eps) self.post_feedforward_layernorm_2 = build_module( TENorm, hidden_size=hidden_size, config=self.config, eps=eps) - self.pre_feedforward_layernorm_2 = build_module( - TENorm, hidden_size=hidden_size, config=self.config, eps=eps) def _forward_attention(self, hidden_states: Tensor, **kwargs): context = kwargs.pop('context', None) @@ -682,7 +706,13 @@ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None) pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) mlp_output, bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) if self.enable_moe_block: - pass + mlp_output_1 = self.post_feedforward_layernorm_1(mlp_output) + mlp_output_2, bias = self.experts_mlp(residual, padding_mask=padding_mask) + mlp_output_2 = self.post_feedforward_layernorm_2(mlp_output_2) + + # Combine mlp and moe outputs + mlp_output = mlp_output_1 + mlp_output_2 + mlp_output = self.post_feedforward_layernorm(mlp_output) hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)((mlp_output, bias), residual, self.hidden_dropout) @@ -727,11 +757,20 @@ class Gemma4Loader(ModelLoader): transformer_block = Gemma4TransformerBlock def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + num_moe_experts = self.config.num_moe_experts + self.config.num_moe_experts = None layer_specs = get_gpt_decoder_block_spec( self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) for layer_spec in layer_specs.layer_specs: layer_spec.submodules.self_attention.module = Gemma4SelfAttention layer_spec.submodules.mlp.module = Gemma4MLP + if num_moe_experts is not None: + self.config.num_moe_experts = num_moe_experts + moe_layer_specs = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + for layer_spec, moe_layer_spec in zip(layer_specs.layer_specs, moe_layer_specs.layer_specs): + layer_spec.submodules.experts_mlp = moe_layer_spec.submodules.mlp + layer_spec.submodules.experts_mlp.module = Gemma4MoELayer return layer_specs def _set_transformer_layer(self, transformer_layer_spec): From e10c6e92ca488c7bb4b6d4b12a1cdc091499f093 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 14:04:14 +0800 Subject: [PATCH 46/52] update --- src/mcore_bridge/bridge/gpt_bridge.py | 11 ++-- src/mcore_bridge/config/parser.py | 2 +- src/mcore_bridge/model/mm_gpts/gemma4.py | 52 ++++++++++++++----- .../model/modules/transformer_layer.py | 4 +- 4 files changed, 48 insertions(+), 21 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 57477d7..23c0e50 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -678,6 +678,12 @@ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) + def _set_router(self, mg_mlp, hf_state_dict, to_mcore): + hf_gate_key = self.hf_gate_key + if self.llm_model_type == 'gpt_oss': + hf_gate_key = 'router.weight' + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) + def _set_moe_state( self, mg_mlp, @@ -692,10 +698,7 @@ def _set_moe_state( else: hf_state_dict = {} config = self.config - hf_gate_key = self.hf_gate_key - if self.llm_model_type == 'gpt_oss': - hf_gate_key = 'router.weight' - self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) + self._set_router(mg_mlp, hf_state_dict, to_mcore) if config.add_bias_linear: self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) if config.moe_router_enable_expert_bias: diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index f00826d..842c1bb 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -27,7 +27,7 @@ # moe 'moe_ffn_hidden_size': ['moe_intermediate_size'], 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'], + 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k', 'top_k_experts'], 'moe_router_num_groups': ['n_group'], 'moe_router_group_topk': ['topk_group'], 'num_moe_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index c51fd00..a329483 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -8,7 +8,6 @@ from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import _yarn_get_concentration_factor_from_config from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec -from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import get_tensor_model_parallel_rank from megatron.core.tensor_parallel import VocabParallelEmbedding, all_gather_last_dim_from_tensor_parallel_region @@ -18,10 +17,12 @@ from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.spec_utils import build_module from megatron.core.utils import make_viewless_tensor, nvtx_range_pop, nvtx_range_push -from torch import Tensor +from torch import Tensor, nn from transformers import AutoModel, PretrainedConfig +from transformers.utils.versions import require_version from typing import Optional, Tuple from mcore_bridge.bridge import MultimodalGPTBridge @@ -36,7 +37,7 @@ from .utils import HuggingFaceVit -class Gemma4VNorm(torch.nn.Module): +class Gemma4RMSNormNoScale(torch.nn.Module): """RMSNorm without learnable scale, mirroring HF `Gemma4RMSNorm(with_scale=False)`.""" def __init__(self, dim: int, eps: float = 1e-6): @@ -186,8 +187,8 @@ def __init__( tp_group=self.pg_collection.tp, ) - self.v_norm = ( - Gemma4VNorm(self.head_dim, eps=self.config.layernorm_epsilon) if not self.is_kv_shared_layer else None) + if not self.is_kv_shared_layer: + self.v_norm = Gemma4RMSNormNoScale(self.head_dim, eps=self.config.layernorm_epsilon) def _forward_core_attention( self, @@ -347,12 +348,31 @@ def __init__( finally: config.ffn_hidden_size = ffn_hidden_size + class Gemma4MoELayer(MoELayer): + def __init__(self, config, *args, **kwargs): + require_version('megatron-core>=0.16.0.dev', 'Gemma4MoELayer requires megatron-core>=0.16.0') super().__init__(config, *args, **kwargs) - - def forward(self, hidden_states: torch.Tensor): - return super().forward(hidden_states) + self.pre_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=config.hidden_size, config=config, eps=config.layernorm_epsilon) + self.norm = Gemma4RMSNormNoScale(config.hidden_size, eps=self.config.layernorm_epsilon) + self.scalar_root_size = config.hidden_size**-0.5 + self.scale = nn.Parameter(torch.ones(config.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_moe_experts)) + self.scale.sequence_parallel = config.sequence_parallel + self.per_expert_scale.sequence_parallel = config.sequence_parallel + + def route(self, hidden_states: torch.Tensor, *args, **kwargs): + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states * self.scale * self.scalar_root_size + probs, routing_map = super().route(hidden_states) + probs = probs * self.per_expert_scale + return probs, routing_map + + def preprocess(self, hidden_states: torch.Tensor, *args, **kwargs): + hidden_states = self.pre_feedforward_layernorm_2(hidden_states) + return super().preprocess(hidden_states, *args, **kwargs) class Gemma4Bridge(MultimodalGPTBridge): @@ -398,17 +418,19 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): if use_alternative_attention else text_config.num_key_value_heads) return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) + def _set_router(self, mg_mlp, hf_state_dict, to_mcore): + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, 'router.proj.weight', to_mcore) + for key in ['per_expert_scale', 'scale']: + self._set_state_dict(mg_mlp, key, hf_state_dict, f'router.{key}', to_mcore) + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False): mg_mlp = None if mg_layer is None else mg_layer.mlp - hf_state_dict.update( - self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) + hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, - f'{self.hf_post_attention_layernorm}.weight', to_mcore) + f'{self.hf_post_attention_layernorm}.weight', to_mcore) if self.text_config.enable_moe_block: mg_experts = None if mg_layer is None else mg_layer.experts_mlp - hf_state_dict.update( - self._set_moe_state( - mg_experts, hf_state_dict, '', layer_idx, to_mcore, is_mtp=is_mtp)) + hf_state_dict.update(self._set_moe_state(mg_experts, hf_state_dict, '', layer_idx, to_mcore, is_mtp=is_mtp)) return hf_state_dict def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): @@ -428,6 +450,8 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i if self.text_config.enable_moe_block: for key in ['post_feedforward_layernorm_1', 'post_feedforward_layernorm_2']: self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) + self._set_state_dict(mg_layer, 'experts_mlp.pre_feedforward_layernorm_2.weight', hf_state_dict, + 'pre_feedforward_layernorm_2.weight', to_mcore) if to_mcore: hf_state_dict = {} else: diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 2520bd8..01e5709 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -216,14 +216,14 @@ def _build_mlp(self, mlp_spec): additional_mlp_kwargs = {} # import here to avoid circular import from mcore_bridge.model.gpts.glm4 import Glm4MLP - from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP, Gemma4MoELayer # MLP expects tp_group but MoELayer expects pg_collection to be passed in. # We can change MLP to accept pg_collection but it makes the logic implicit # The conditional below is to make the logic explicit # if smlp_spec is not a ModuleSpec,we dont have to handle passing additional kwargs if isinstance(mlp_spec, ModuleSpec): - if mlp_spec.module in (MoELayer, TEGroupedMLP, SequentialMLP): + if mlp_spec.module in (MoELayer, Gemma4MoELayer, TEGroupedMLP, SequentialMLP): additional_mlp_kwargs['pg_collection'] = pg_collection # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. if mlp_spec.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: From a4ebaf3dfb0eb1ca3b1e8020899d7b99edb73e0d Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 14:19:39 +0800 Subject: [PATCH 47/52] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 23c0e50..1f6a907 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -751,7 +751,7 @@ def _get_hf_experts_attr(self, is_mtp: bool = False): 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32' }: return False, False - elif self.model_type in {'qwen3_vl_moe', 'llama4'} or self.llm_model_type in {'gpt_oss'}: + elif self.model_type in {'qwen3_vl_moe', 'llama4', 'gemma4'} or self.llm_model_type in {'gpt_oss'}: return True, True else: # default From 22de4b4806fe3057a6566d992c182ebbdb81a0e3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 14:54:39 +0800 Subject: [PATCH 48/52] fix --- src/mcore_bridge/bridge/gpt_bridge.py | 17 +++++++---------- src/mcore_bridge/model/mm_gpts/gemma4.py | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 1f6a907..1d95f31 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -490,8 +490,7 @@ def _set_state_dict(self, else: mg_param = deep_getattr(sub_module, param_key) if to_mcore: - if mg_param is None: - raise ValueError(f'mg_module: {mg_module}, mg_key: {mg_key}') + assert mg_param is not None, f'mg_module: {mg_module}, mg_key: {mg_key}' hf_weight = hf_state_dict[hf_key].load() if module_key in { 'embedding.word_embeddings', 'output_layer' @@ -683,6 +682,10 @@ def _set_router(self, mg_mlp, hf_state_dict, to_mcore): if self.llm_model_type == 'gpt_oss': hf_gate_key = 'router.weight' self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) + if self.config.add_bias_linear: + self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) + if self.config.moe_router_enable_expert_bias: + self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore) def _set_moe_state( self, @@ -697,14 +700,8 @@ def _set_moe_state( hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - config = self.config self._set_router(mg_mlp, hf_state_dict, to_mcore) - if config.add_bias_linear: - self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) - if config.moe_router_enable_expert_bias: - self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore) - - if config.moe_shared_expert_intermediate_size: + if self.config.moe_shared_expert_intermediate_size: hf_shared_expert_key = self.hf_shared_expert_key if hf_shared_expert_key is None: if 'qwen' in self.llm_model_type or self.model_type == 'llama4': @@ -714,7 +711,7 @@ def _set_moe_state( hf_state_dict.update( self._set_mlp_state(None if mg_mlp is None else mg_mlp.shared_experts, hf_state_dict, f'{hf_shared_expert_key}.', layer_idx, to_mcore)) - if config.moe_shared_expert_gate: + if self.config.moe_shared_expert_gate: self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', to_mcore) for ep_rank in range(self.ep_size): diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index a329483..7d31b28 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -75,7 +75,7 @@ def prepare_model(self, hf_config: PretrainedConfig): def get_inputs_embeds(self, inputs_embeds, **kwargs): input_ids = kwargs.get('input_ids') - inputs_embeds *= self.embed_scale.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype) hf_config = self.hf_config input_ids = kwargs.get('input_ids') From ed138dcc3bacddfef57dddfd233e348e0cb3c6f5 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 15:54:47 +0800 Subject: [PATCH 49/52] update --- .../model/modules/transformer_block.py | 74 ++++++++++++++----- 1 file changed, 57 insertions(+), 17 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py index 073f671..5d36431 100644 --- a/src/mcore_bridge/model/modules/transformer_block.py +++ b/src/mcore_bridge/model/modules/transformer_block.py @@ -19,6 +19,51 @@ apply_module = None +class _TensorIdx: + """Sentinel that marks a position in the flatten schema as a tensor index.""" + __slots__ = ('idx', ) + + def __init__(self, idx): + self.idx = idx + + +def _checkpoint_flatten(obj, tensors): + """Recursively flatten a nested structure (dict/tuple/list/Tensor) into a schema. + + Tensors are appended to `tensors` and replaced in the schema by a _TensorIdx sentinel. + Non-tensor leaves (int, bool, None, str, ...) are stored as-is. + The schema mirrors the original structure and is captured in the checkpoint closure. + """ + if torch.is_tensor(obj): + idx = len(tensors) + tensors.append(obj) + return _TensorIdx(idx) + elif isinstance(obj, dict): + # inplace (gemma4 shared_kv_states) + for k, v in obj.items(): + obj[k] = _checkpoint_flatten(v, tensors) + return obj + elif isinstance(obj, (tuple, list)): + return type(obj)(_checkpoint_flatten(v, tensors) for v in obj) + else: + return obj # non-tensor leaf: stored directly in schema + + +def _checkpoint_unflatten(schema, tensors): + """Reconstruct the original structure from a schema and a flat tensors list.""" + if isinstance(schema, _TensorIdx): + return tensors[schema.idx] + elif isinstance(schema, dict): + # inplace (gemma4 shared_kv_states) + for k, v in schema.items(): + schema[k] = _checkpoint_unflatten(v, tensors) + return schema + elif isinstance(schema, (tuple, list)): + return type(schema)(_checkpoint_unflatten(v, tensors) for v in schema) + else: + return schema # non-tensor leaf + + # Code borrowed from NVIDIA/Megatron-LM class TransformerBlock(McoreTransformerBlock): @@ -99,26 +144,25 @@ def custom_forward( return custom_forward - # `tensor_parallel.checkpoint` / `te_checkpoint` only forward *args to the - # wrapped function (torch.utils.checkpoint limitation). Convert kwargs to - # positional args by capturing the keys in closure so tensor kwargs (e.g. - # qwen3-vl's visual_pos_masks / deepstack_visual_embeds) can flow through - # activation recompute and remain in the autograd graph. + # Variables that don't require gradients can be captured via closure. + _ckpt_attention_mask = attention_mask + _ckpt_rotary_pos_emb = rotary_pos_emb extra_kwargs_keys = tuple(kwargs.keys()) - extra_kwargs_values = tuple(kwargs.values()) + _extra_flat_tensors = [] + _extra_schemas = [_checkpoint_flatten(v, _extra_flat_tensors) for v in kwargs.values()] def checkpoint_handler(forward_func): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, padding_mask, - *extra_args): - extra_kwargs = dict(zip(extra_kwargs_keys, extra_args)) + def wrapped_forward(hidden_states, context, context_mask, padding_mask, *extra_flat): + rebuilt = [_checkpoint_unflatten(s, extra_flat) for s in _extra_schemas] + extra_kwargs = dict(zip(extra_kwargs_keys, rebuilt)) return forward_func( hidden_states, - attention_mask, + _ckpt_attention_mask, context, context_mask, - rotary_pos_emb, + _ckpt_rotary_pos_emb, padding_mask, **extra_kwargs, ) @@ -131,24 +175,20 @@ def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary tensor_parallel.random.get_cuda_rng_tracker, self.pg_collection.tp, hidden_states, - attention_mask, context, context_mask, - rotary_pos_emb, padding_mask, - *extra_kwargs_values, + *_extra_flat_tensors, ) else: return tensor_parallel.checkpoint( wrapped_forward, self.config.distribute_saved_activations, hidden_states, - attention_mask, context, context_mask, - rotary_pos_emb, padding_mask, - *extra_kwargs_values, + *_extra_flat_tensors, ) if self.config.recompute_method == 'uniform': From 2fe48e6e77b273847f0481e2f901b17b57ea3e13 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 16:25:57 +0800 Subject: [PATCH 50/52] fix --- src/mcore_bridge/model/modules/transformer_layer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 01e5709..6d500e0 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -134,8 +134,6 @@ def __init__( # [Module 8: MLP block] self.mlp = self._build_mlp(submodules.mlp) - if hasattr(self.mlp, 'set_layer_number'): - self.mlp.set_layer_number(self.layer_number) # [Module 9: BiasDropoutFusion] self.mlp_bda = build_module(submodules.mlp_bda) self.is_moe_layer = isinstance(self.mlp, MoELayer) @@ -238,7 +236,10 @@ def _build_mlp(self, mlp_spec): additional_mlp_kwargs['tp_group'] = pg_collection.tp else: logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. Using default kwargs.') - return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + if hasattr(mlp, 'set_layer_number'): + mlp.set_layer_number(self.layer_number) + return mlp def forward(self, *args, **kwargs): """ From 811ee6fc480d983fd192f333cb1f23180c0c03ca Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 16:32:13 +0800 Subject: [PATCH 51/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index 7d31b28..d5ee19c 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -103,15 +103,16 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): if pixel_values_videos is not None: with self.patch_hf_config(): - video_features = self.get_video_features( - pixel_values_videos, video_position_ids, return_dict=True).pooler_output + video_features = self.model_cls.get_video_features( + self, pixel_values_videos, video_position_ids, return_dict=True).pooler_output video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): with self.patch_hf_config(): - audio_output = self.get_audio_features(input_features, input_features_mask, return_dict=True) + audio_output = self.model_cls.get_audio_features( + self, input_features, input_features_mask, return_dict=True) audio_features = audio_output.pooler_output audio_features = audio_features[audio_output.attention_mask] audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) @@ -582,7 +583,7 @@ def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids): window_size = self.text_config.sliding_window - 1 seq_len = attention_mask.shape[-1] - window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device='cuda') + window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device) window_mask = ~torch.triu(window_mask, diagonal=-window_size) is_vision = mm_token_type_ids > 0 From c288d024a8d8a6ff9f42f64d5ff97193dcea1226 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Tue, 19 May 2026 16:43:10 +0800 Subject: [PATCH 52/52] fix --- src/mcore_bridge/model/mm_gpts/gemma4.py | 27 ++++++++++++------------ 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py index d5ee19c..2c68a73 100644 --- a/src/mcore_bridge/model/mm_gpts/gemma4.py +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -78,7 +78,6 @@ def get_inputs_embeds(self, inputs_embeds, **kwargs): inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype) hf_config = self.hf_config - input_ids = kwargs.get('input_ids') pixel_values = kwargs.get('pixel_values') pixel_values_videos = kwargs.get('pixel_values_videos') input_features = kwargs.get('input_features') @@ -586,16 +585,18 @@ def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids): window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device) window_mask = ~torch.triu(window_mask, diagonal=-window_size) - is_vision = mm_token_type_ids > 0 - is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) - is_prev_vision[:, 0] = False - vision_group_ids = torch.cumsum((is_vision & ~is_prev_vision).int(), dim=1) - 1 - vision_group_ids = torch.where(is_vision, vision_group_ids, torch.full_like(vision_group_ids, -1)) - - q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1) - k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2) - same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0) - return (attention_mask | window_mask) & ~same_vision_group + attention_mask = attention_mask | window_mask + if mm_token_type_ids is not None: + is_vision = mm_token_type_ids > 0 + is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) + is_prev_vision[:, 0] = False + vision_group_ids = torch.cumsum((is_vision & ~is_prev_vision).int(), dim=1) - 1 + vision_group_ids = torch.where(is_vision, vision_group_ids, torch.full_like(vision_group_ids, -1)) + q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1) + k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2) + same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0) + attention_mask = attention_mask & ~same_vision_group + return attention_mask def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) @@ -630,12 +631,12 @@ def unpack_pp_input(self): per_layer_inputs_dim = math.prod(per_layer_inputs_shape) // math.prod(input_tensor.shape[:2]) full_states_dim = math.prod(full_states_shape) // math.prod(input_tensor.shape[:2]) sliding_states_dim = math.prod(sliding_states_shape) // math.prod(input_tensor.shape[:2]) - if flag[0, 0, 0].item() != 0: + if flag[0, 0, 1].item() != 0: input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], dim=-1) full_states = full_states.reshape(*full_states_shape) shared_kv_states['full_attention'] = full_states.chunk(2, -1) - if flag[0, 0, 1].item() != 0: + if flag[0, 0, 0].item() != 0: input_tensor, sliding_states = input_tensor.split( [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) sliding_states = sliding_states.reshape(*sliding_states_shape)