Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
4e2df75
update
Jintao-Huang Apr 29, 2026
c388954
update
Jintao-Huang Apr 29, 2026
76af2bc
update
Jintao-Huang Apr 29, 2026
54e3343
update
Jintao-Huang Apr 29, 2026
25a45bd
update
Jintao-Huang Apr 30, 2026
3210617
update
Jintao-Huang Apr 30, 2026
e1355dd
Merge branch 'main' into support_gemma4
Jintao-Huang May 3, 2026
5b4e118
update
Jintao-Huang May 4, 2026
d1d2246
update
Jintao-Huang May 4, 2026
24cd697
Merge branch 'main' into support_gemma4
Jintao-Huang May 4, 2026
bba8144
merge
Jintao-Huang May 4, 2026
14b1644
fix
Jintao-Huang May 5, 2026
196a58f
update
Jintao-Huang May 5, 2026
68e33a7
update
Jintao-Huang May 5, 2026
9736c3e
Merge branch 'main' into support_gemma4
Jintao-Huang May 5, 2026
44ddaec
update
Jintao-Huang May 5, 2026
b3cc043
Merge branch 'main' into support_gemma4
Jintao-Huang May 5, 2026
7563bc4
Merge branch 'main' into support_gemma4
Jintao-Huang May 8, 2026
4a74289
Merge branch 'main' into support_gemma4
Jintao-Huang May 8, 2026
0de0ebb
update
Jintao-Huang May 9, 2026
2a81bf0
update
Jintao-Huang May 9, 2026
7e05d3d
update
Jintao-Huang May 9, 2026
8da05df
fix
Jintao-Huang May 9, 2026
d1eff8a
fix
Jintao-Huang May 9, 2026
e545c4f
Merge remote-tracking branch 'refs/remotes/origin/support_gemma4' int…
Jintao-Huang May 9, 2026
0c22e68
fix
Jintao-Huang May 9, 2026
63511bd
Merge remote-tracking branch 'refs/remotes/origin/support_gemma4' int…
Jintao-Huang May 9, 2026
fa5360b
fix
Jintao-Huang May 9, 2026
d25db28
update
Jintao-Huang May 9, 2026
bfbcbc4
update
Jintao-Huang May 9, 2026
2300825
fix
Jintao-Huang May 9, 2026
cda31a5
update
Jintao-Huang May 9, 2026
7e6fb75
update
Jintao-Huang May 9, 2026
e1d0851
update
Jintao-Huang May 9, 2026
0178948
Merge branch 'main' into support_gemma4
Jintao-Huang May 9, 2026
e3cbe5d
update
Jintao-Huang May 9, 2026
4ae40a2
update
Jintao-Huang May 11, 2026
3123407
update
Jintao-Huang May 11, 2026
03dc221
update
Jintao-Huang May 11, 2026
41465ff
update
Jintao-Huang May 11, 2026
23cd670
fix
Jintao-Huang May 11, 2026
9bc5596
Merge branch 'main' into support_gemma4
Jintao-Huang May 11, 2026
4d2dc1c
Merge branch 'main' into support_gemma4
Jintao-Huang May 15, 2026
6ad3a43
Merge branch 'main' into support_gemma4
Jintao-Huang May 15, 2026
8460a3a
update
Jintao-Huang May 16, 2026
23356d9
update
Jintao-Huang May 16, 2026
a3406b2
Merge branch 'main' into support_gemma4
Jintao-Huang May 17, 2026
be8f320
update
Jintao-Huang May 18, 2026
bdb29e0
fix pp
Jintao-Huang May 18, 2026
d097e68
update
Jintao-Huang May 18, 2026
5c29f77
fix
Jintao-Huang May 18, 2026
82cd1ad
update
Jintao-Huang May 18, 2026
472f2f7
fix
Jintao-Huang May 18, 2026
13dc948
fix
Jintao-Huang May 18, 2026
3d3e90e
fix
Jintao-Huang May 18, 2026
ed370d9
update
Jintao-Huang May 18, 2026
b3dfdb9
update
Jintao-Huang May 18, 2026
8744cc4
update
Jintao-Huang May 18, 2026
3c3c4b8
update
Jintao-Huang May 19, 2026
e10c6e9
update
Jintao-Huang May 19, 2026
a4ebaf3
fix
Jintao-Huang May 19, 2026
22de4b4
fix
Jintao-Huang May 19, 2026
ed138dc
update
Jintao-Huang May 19, 2026
2fe48e6
fix
Jintao-Huang May 19, 2026
811ee6f
fix
Jintao-Huang May 19, 2026
c288d02
fix
Jintao-Huang May 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ The following is the list of models supported by MCore-Bridge:
| Series | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni<br />qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr<br />qwen3_5, qwen3_5_moe |
| Gemma | gemma4 |
| GLM | glm4v, glm4v_moe |
| Kimi | kimi_vl |
| InternVL | internvl_chat, internvl |
Expand Down
1 change: 1 addition & 0 deletions README_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ uv pip install -e . --torch-backend=auto
| 系列 | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni<br />qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr<br />qwen3_5, qwen3_5_moe |
| Gemma | gemma4 |
| GLM | glm4v, glm4v_moe |
| Kimi | kimi_vl |
| InternVL | internvl_chat, internvl |
Expand Down
132 changes: 74 additions & 58 deletions src/mcore_bridge/bridge/gpt_bridge.py

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
import torch.nn.functional as F
from functools import partial
from transformers import PretrainedConfig
from typing import Any, Dict

Expand All @@ -25,7 +27,7 @@
# moe
'moe_ffn_hidden_size': ['moe_intermediate_size'],
'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'],
'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k', 'top_k_experts'],
'moe_router_num_groups': ['n_group'],
'moe_router_group_topk': ['topk_group'],
'num_moe_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'],
Expand Down Expand Up @@ -153,6 +155,17 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
n_shared_experts = res.pop('n_shared_experts')
elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}:
res['rotary_interleaved'] = True
elif hf_model_type in {'gemma4'}:
res['qk_layernorm'] = True
# If set to "vision", pass attention_mask manually.
if hf_config.text_config.use_bidirectional_attention is None:
res['window_size'] = f'{window_size - 1},0'
Comment thread
Jintao-Huang marked this conversation as resolved.
window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types])
res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]'
res['softmax_scale'] = 1.
res['swiglu'] = False
res['gated_linear_unit'] = True
res['activation_func'] = partial(F.gelu, approximate='tanh')
elif llm_model_type == 'gpt_oss':
res['add_bias_linear'] = True
res['bias_dropout_fusion'] = False
Expand All @@ -161,7 +174,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res['quick_geglu'] = True
res['activation_func_clamp_value'] = 7
res['glu_linear_offset'] = 1
res['window_size'] = f'{window_size},0'
res['window_size'] = f'{window_size - 1},0'
if layer_types is None:
res['window_attn_skip_freq'] = '2'
else:
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
gemma4 = 'gemma4'

kimi_k25 = 'kimi_k25'

Expand Down
72 changes: 45 additions & 27 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def sharded_state_dict(

class GPTModel(McoreGPTModel):
config: ModelConfig
extra_forward_keys = []

def __init__(
self,
Expand Down Expand Up @@ -105,9 +106,7 @@ def __init__(
for i in range(len(self.decoder.layers)):
if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
del self.decoder.layers[i].self_attention.rotary_pos_emb
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
self._set_inv_freq()
if self.config.task_type == 'seq_cls' and self.post_process:
self.output_layer = OutputLayerLinear(
config.hidden_size,
Expand Down Expand Up @@ -216,7 +215,35 @@ def _preprocess(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)
rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
decoder_input, position_ids, packed_seq_params=packed_seq_params)

if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset)

def _set_inv_freq(self):
self.attention_scaling = 1.
new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)

def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
Expand Down Expand Up @@ -251,26 +278,7 @@ def _preprocess(
rotary_seq_len,
packed_seq=packed_seq,
)

if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
or self.config.flash_decode) and rotary_pos_cos is not None
and inference_context.is_static_batching()):
current_batch_size = input_ids.shape[0]
sequence_len_offset = torch.tensor(
[inference_context.sequence_len_offset] * current_batch_size,
dtype=torch.int32,
device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
)
else:
sequence_len_offset = None

# Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
# reference held by this caller function, enabling early garbage collection for
# inference. Skip wrapping if decoder_input is logged after decoder completion.
if in_inference_mode and not has_config_logger_enabled(self.config):
decoder_input = WrappedTensor(decoder_input)

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
return rotary_pos_emb, rotary_pos_cos, rotary_pos_sin

# Code borrowed from NVIDIA/Megatron-LM
def forward(
Expand Down Expand Up @@ -314,17 +322,24 @@ def forward(
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
if isinstance(rotary_pos_emb, dict):
for k, v in rotary_pos_emb.items():
decoder_rotary_pos_emb[k] = v[position_ids[0]]
else:
decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]

mtp_decoder_input = decoder_input
if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
input_tensor = self.get_input_tensor()
input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
self.set_input_tensor(input_tensor)
kwargs = {}
full_attention_mask = attention_mask
if isinstance(full_attention_mask, dict):
full_attention_mask = full_attention_mask['full_attention']
if mcore_016 and attention_mask is not None:
assert packed_seq_params is None
padding_mask = ~((~attention_mask).sum(dim=(1, 2)) > 0)
padding_mask = ~((~full_attention_mask).sum(dim=(1, 2)) > 0)
if self.config.context_parallel_size > 1:
padding_mask = split_cp_inputs(padding_mask, None, 1)
tp_size = self.config.tensor_model_parallel_size
Expand Down Expand Up @@ -366,6 +381,9 @@ def forward(
inference_context=inference_context,
)

def _forward_output_layer(self, hidden_states, *args, **kwargs):
return self.output_layer(hidden_states, *args, **kwargs)[0]

def _postprocess(
self,
hidden_states,
Expand Down Expand Up @@ -432,7 +450,7 @@ def _postprocess(
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_unroll_steps):
# output
mtp_logits, _ = self.output_layer(
mtp_logits = self._forward_output_layer(
hidden_states_list[mtp_layer_number + 1],
weight=output_weight,
runtime_gather_output=runtime_gather_output,
Expand Down Expand Up @@ -493,7 +511,7 @@ def _postprocess(
if self.config.task_type == 'embedding':
logits = F.normalize(hidden_states, p=2, dim=-1)
else:
logits, _ = self.output_layer(
logits = self._forward_output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
if self.config.task_type == 'generative_reranker':
logits = gather_from_tensor_model_parallel_region(logits)
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/gpts/bailing_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class BailingMoeBridge(GPTBridge):
hf_expert_bias_key = 'gate.expert_bias'
hf_o_proj_key = 'dense'

def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs):
self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'query_key_value.weight', to_mcore)
assert not self.config.add_bias_linear
return hf_state_dict
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/gpts/minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class MinimaxM2Bridge(GPTBridge):
hf_mlp_prefix = 'block_sparse_moe'
hf_expert_bias_key = 'e_score_correction_bias'

def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore):
def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs):
self._set_state_dict(mg_attn, 'q_norm.weight', hf_state_dict, 'q_norm.weight', to_mcore)
self._set_state_dict(mg_attn, 'k_norm.weight', hf_state_dict, 'k_norm.weight', to_mcore)

Expand Down
2 changes: 2 additions & 0 deletions src/mcore_bridge/model/mm_gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def forward(
packed_seq_params: PackedSeqParams = None,
**kwargs,
) -> torch.Tensor:
extra_kwargs = {k: kwargs[k] for k in self.language_model.extra_forward_keys}
if decoder_input is not None:
pass
elif self.pre_process:
Expand All @@ -90,6 +91,7 @@ def forward(
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
kwargs = {}
kwargs.update(extra_kwargs)
return self.language_model(
input_ids=input_ids,
position_ids=position_ids,
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/mm_gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl
from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl
Loading
Loading