diff --git a/README.md b/README.md
index 1473a16..bf9423a 100644
--- a/README.md
+++ b/README.md
@@ -145,6 +145,7 @@ The following is the list of models supported by MCore-Bridge:
| Series | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni
qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr
qwen3_5, qwen3_5_moe |
+| Gemma | gemma4 |
| GLM | glm4v, glm4v_moe |
| Kimi | kimi_vl |
| InternVL | internvl_chat, internvl |
diff --git a/README_zh.md b/README_zh.md
index 5f170e7..b7ee63d 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -142,6 +142,7 @@ uv pip install -e . --torch-backend=auto
| 系列 | model_type |
| -------- | ------------------------------------------------------------ |
| Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni
qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr
qwen3_5, qwen3_5_moe |
+| Gemma | gemma4 |
| GLM | glm4v, glm4v_moe |
| Kimi | kimi_vl |
| InternVL | internvl_chat, internvl |
diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py
index 0eceb23..1d95f31 100644
--- a/src/mcore_bridge/bridge/gpt_bridge.py
+++ b/src/mcore_bridge/bridge/gpt_bridge.py
@@ -360,7 +360,6 @@ def _broadcast_ep_pp(self, tensor, is_expert):
dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
if tensor is None:
dist.broadcast(meta_data, src=src_rank, group=pp_group)
- assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
shape = meta_data[1:1 + meta_data[0]].tolist()
dtype = dtype_mapping_r[meta_data[-1].item()]
tensor = torch.empty(shape, device='cuda', dtype=dtype)
@@ -383,7 +382,8 @@ def _get_weight(
# tp/etp
mg_scale_inv = None
tensor = mg_weight
- if tensor is not None:
+ is_scalar = isinstance(tensor, torch.Tensor) and tensor.ndim == 0
+ if tensor is not None and not is_scalar:
if not isinstance(tensor, (list, tuple)):
tensor = [tensor]
if self._is_fp8_param(tensor[0]):
@@ -392,8 +392,7 @@ def _get_weight(
for t in tensor
]
tensor = [t._rowwise_data for t in tensor]
- del mg_weight
- if tensor is not None:
+ del mg_weight
assert isinstance(tensor, (list, tuple)), f'mg_key: {mg_key}'
tensor = torch.concat(tensor, dim=0)
if mg_scale_inv is not None:
@@ -529,22 +528,37 @@ def _filter_prefix(state_dict, prefix: str):
return state_dict
return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
- def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
+ def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.long, op=dist.ReduceOp.MAX):
+ if to_mcore:
+ return tensor
+ tensor = torch.tensor([tensor], dtype=dtype, device='cuda')
+ if self.pp_size > 1:
+ dist.all_reduce(tensor, group=self.pp_group, op=op)
+ tensor = tensor.item()
+ return tensor
+
+ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs):
config = self.config
- num_query_groups = (
- config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads)
+ num_query_groups = kwargs.get('num_query_groups')
+ if num_query_groups is None:
+ num_query_groups = (
+ config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads)
hidden_size_block = config.hidden_size // self.fp8_block_size
+ attention_k_eq_v = kwargs.get('attention_k_eq_v', False)
+ kv_proj_list = ['k_proj'] if attention_k_eq_v else ['k_proj', 'v_proj']
+ kv_channels = kwargs.get('kv_channels')
+ if kv_channels is None:
+ kv_channels = self.config.kv_channels
if to_mcore:
if isinstance(mg_attn.linear_qkv, LoraParallelLinear):
lora_A = hf_state_dict['q_proj.lora_A.weight'].load()
- assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and (
- lora_A == hf_state_dict['v_proj.lora_A.weight'].load()
- ).all(), 'Need to ensure QKV\'s lora_A are consistent'
+ assert all((lora_A == hf_state_dict[f'{k}.lora_A.weight'].load()).all()
+ for k in kv_proj_list), 'Need to ensure QKV\'s lora_A are consistent'
q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load()
lora_B = torch.cat([
q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])),
- hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])),
- hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])),
+ *(hf_state_dict[f'{k}.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1]))
+ for k in kv_proj_list),
],
dim=1).reshape((-1, q_lora_B.shape[-1]))
self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A,
@@ -553,29 +567,24 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
'linear_qkv.lora_B.weight')
elif not self._peft_format:
linear_qkv_weight = torch.cat([
- hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)),
- hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)),
- hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)),
+ hf_state_dict[f'{k}.weight'].load().reshape((num_query_groups, -1, config.hidden_size))
+ for k in ['q_proj'] + kv_proj_list
],
dim=1).reshape((-1, config.hidden_size))
qkv_scale_inv = None
if 'q_proj.weight_scale_inv' in hf_state_dict:
qkv_scale_inv = torch.cat([
- hf_state_dict['q_proj.weight_scale_inv'].load().reshape(
- (num_query_groups, -1, hidden_size_block)),
- hf_state_dict['k_proj.weight_scale_inv'].load().reshape(
- (num_query_groups, -1, hidden_size_block)),
- hf_state_dict['v_proj.weight_scale_inv'].load().reshape(
- (num_query_groups, -1, hidden_size_block)),
+ hf_state_dict[f'{k}.weight_scale_inv'].load().reshape((num_query_groups, -1, hidden_size_block))
+ for k in ['q_proj'] + kv_proj_list
],
dim=1).reshape((-1, hidden_size_block))
self._set_weight(
mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv)
else:
- q_dim = self.config.kv_channels * self.config.num_attention_heads // self.config.num_query_groups
+ q_dim = kv_channels * self.config.num_attention_heads // num_query_groups
if self.config.attention_output_gate:
q_dim *= 2
- kv_dim = self.config.kv_channels
+ kv_dim = kv_channels
q_block = q_dim // self.fp8_block_size
kv_block = kv_dim // self.fp8_block_size
is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv,
@@ -591,15 +600,17 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data,
f'linear_qkv.lora_B.{self._adapter_name}.weight')
if lora_A is not None:
- self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'})
- for key in ['q_proj', 'k_proj', 'v_proj']:
+ self._peft_target_modules.update({'q_proj'} | set(kv_proj_list))
+ for key in ['q_proj'] + kv_proj_list:
hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone()
lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1]))
hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone()
- hf_state_dict['k_proj.lora_B.weight'] = lora_B[:,
- q_dim:-kv_dim, :].reshape(-1,
- lora_B.shape[-1]).clone()
- hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone()
+ hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, q_dim:q_dim + kv_dim, :].reshape(
+ -1, lora_B.shape[-1]).clone()
+ if not attention_k_eq_v:
+ hf_state_dict['v_proj.lora_B.weight'] = lora_B[:,
+ -kv_dim:, :].reshape(-1,
+ lora_B.shape[-1]).clone()
elif not self._peft_format:
mg_attn_weight, scale_inv = self._get_weight(
None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight')
@@ -607,27 +618,27 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, config.hidden_size))
hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1,
config.hidden_size).clone()
- hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape(
+ hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:q_dim + kv_dim, :].reshape(
-1, config.hidden_size).clone()
- hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1,
- config.hidden_size).clone()
+ if not attention_k_eq_v:
+ hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(
+ -1, config.hidden_size).clone()
if scale_inv is not None:
scale_inv = scale_inv.reshape((num_query_groups, -1, hidden_size_block))
hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:, :q_block, :].reshape(
-1, hidden_size_block).clone()
- hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:-kv_block, :].reshape(
- -1, hidden_size_block).clone()
- hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape(
+ hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:q_block + kv_block:, :].reshape(
-1, hidden_size_block).clone()
+ if not attention_k_eq_v:
+ hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape(
+ -1, hidden_size_block).clone()
del mg_attn_weight
# Copy bias
if (config.add_bias_linear or config.add_qkv_bias) and not self._peft_format:
if to_mcore:
linear_qkv_bias = torch.cat([
- hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)),
- hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)),
- hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)),
+ hf_state_dict[f'{k}.bias'].load().reshape((num_query_groups, -1)) for k in ['q_proj'] + kv_proj_list
],
dim=1).reshape(-1)
self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias')
@@ -637,8 +648,9 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
if mg_attn_bias is not None:
mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1))
hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone()
- hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone()
- hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
+ hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:q_dim + kv_dim].reshape(-1).clone()
+ if not attention_k_eq_v:
+ hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone()
return hf_state_dict
def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
@@ -647,24 +659,34 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int
else:
hf_state_dict = {}
config = self.config
- hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore))
+ hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore, layer_idx=layer_idx))
self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, f'{self.hf_o_proj_key}.weight', to_mcore)
if config.add_bias_linear:
self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, f'{self.hf_o_proj_key}.bias', to_mcore)
if getattr(config, 'softmax_type', 'vanilla') == 'learnable':
self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore)
if config.qk_layernorm:
- self._set_qk_layernorm(mg_attn, hf_state_dict, to_mcore)
+ self._set_qk_layernorm(mg_attn, hf_state_dict, to_mcore, layer_idx=layer_idx)
if to_mcore:
hf_state_dict = {}
else:
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
return hf_state_dict
- def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore):
+ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs):
self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore)
self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore)
+ def _set_router(self, mg_mlp, hf_state_dict, to_mcore):
+ hf_gate_key = self.hf_gate_key
+ if self.llm_model_type == 'gpt_oss':
+ hf_gate_key = 'router.weight'
+ self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore)
+ if self.config.add_bias_linear:
+ self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore)
+ if self.config.moe_router_enable_expert_bias:
+ self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore)
+
def _set_moe_state(
self,
mg_mlp,
@@ -678,17 +700,8 @@ def _set_moe_state(
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
else:
hf_state_dict = {}
- config = self.config
- hf_gate_key = self.hf_gate_key
- if self.llm_model_type == 'gpt_oss':
- hf_gate_key = 'router.weight'
- self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore)
- if config.add_bias_linear:
- self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore)
- if config.moe_router_enable_expert_bias:
- self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore)
-
- if config.moe_shared_expert_intermediate_size:
+ self._set_router(mg_mlp, hf_state_dict, to_mcore)
+ if self.config.moe_shared_expert_intermediate_size:
hf_shared_expert_key = self.hf_shared_expert_key
if hf_shared_expert_key is None:
if 'qwen' in self.llm_model_type or self.model_type == 'llama4':
@@ -698,7 +711,7 @@ def _set_moe_state(
hf_state_dict.update(
self._set_mlp_state(None if mg_mlp is None else mg_mlp.shared_experts, hf_state_dict,
f'{hf_shared_expert_key}.', layer_idx, to_mcore))
- if config.moe_shared_expert_gate:
+ if self.config.moe_shared_expert_gate:
self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight',
to_mcore)
for ep_rank in range(self.ep_size):
@@ -735,7 +748,7 @@ def _get_hf_experts_attr(self, is_mtp: bool = False):
'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32'
}:
return False, False
- elif self.model_type in {'qwen3_vl_moe', 'llama4'} or self.llm_model_type in {'gpt_oss'}:
+ elif self.model_type in {'qwen3_vl_moe', 'llama4', 'gemma4'} or self.llm_model_type in {'gpt_oss'}:
return True, True
else:
# default
@@ -1618,13 +1631,16 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i
hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
return hf_state_dict
+ def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore):
+ lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model
+ self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore)
+
def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore):
if to_mcore:
hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
else:
hf_state_dict = {}
- lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model
- self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore)
+ self._set_word_embeddings(mg_model, hf_state_dict, to_mcore)
if self.is_multimodal:
for prefix, mg_prefix in self.module_mapping.items():
mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}')
diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py
index 869f876..842c1bb 100644
--- a/src/mcore_bridge/config/parser.py
+++ b/src/mcore_bridge/config/parser.py
@@ -1,4 +1,6 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
+import torch.nn.functional as F
+from functools import partial
from transformers import PretrainedConfig
from typing import Any, Dict
@@ -25,7 +27,7 @@
# moe
'moe_ffn_hidden_size': ['moe_intermediate_size'],
'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'],
- 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'],
+ 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k', 'top_k_experts'],
'moe_router_num_groups': ['n_group'],
'moe_router_group_topk': ['topk_group'],
'num_moe_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'],
@@ -153,6 +155,17 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
n_shared_experts = res.pop('n_shared_experts')
elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}:
res['rotary_interleaved'] = True
+ elif hf_model_type in {'gemma4'}:
+ res['qk_layernorm'] = True
+ # If set to "vision", pass attention_mask manually.
+ if hf_config.text_config.use_bidirectional_attention is None:
+ res['window_size'] = f'{window_size - 1},0'
+ window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types])
+ res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]'
+ res['softmax_scale'] = 1.
+ res['swiglu'] = False
+ res['gated_linear_unit'] = True
+ res['activation_func'] = partial(F.gelu, approximate='tanh')
elif llm_model_type == 'gpt_oss':
res['add_bias_linear'] = True
res['bias_dropout_fusion'] = False
@@ -161,7 +174,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res['quick_geglu'] = True
res['activation_func_clamp_value'] = 7
res['glu_linear_offset'] = 1
- res['window_size'] = f'{window_size},0'
+ res['window_size'] = f'{window_size - 1},0'
if layer_types is None:
res['window_attn_skip_freq'] = '2'
else:
diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py
index 58b0ee3..9b8dc1b 100644
--- a/src/mcore_bridge/model/constant.py
+++ b/src/mcore_bridge/model/constant.py
@@ -29,6 +29,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
+ gemma4 = 'gemma4'
kimi_k25 = 'kimi_k25'
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 02f8bb0..84bf56d 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -55,6 +55,7 @@ def sharded_state_dict(
class GPTModel(McoreGPTModel):
config: ModelConfig
+ extra_forward_keys = []
def __init__(
self,
@@ -105,9 +106,7 @@ def __init__(
for i in range(len(self.decoder.layers)):
if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'):
del self.decoder.layers[i].self_attention.rotary_pos_emb
- self.attention_scaling = 1.
- new_inv_freq, self.attention_scaling = get_rope_inv_freq(config)
- self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+ self._set_inv_freq()
if self.config.task_type == 'seq_cls' and self.post_process:
self.output_layer = OutputLayerLinear(
config.hidden_size,
@@ -216,7 +215,35 @@ def _preprocess(
if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad:
# fix LoRA incompatibility with gradient checkpointing
decoder_input = decoder_input.requires_grad_(True)
+ rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb(
+ decoder_input, position_ids, packed_seq_params=packed_seq_params)
+ if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
+ or self.config.flash_decode) and rotary_pos_cos is not None
+ and inference_context.is_static_batching()):
+ current_batch_size = input_ids.shape[0]
+ sequence_len_offset = torch.tensor(
+ [inference_context.sequence_len_offset] * current_batch_size,
+ dtype=torch.int32,
+ device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
+ )
+ else:
+ sequence_len_offset = None
+
+ # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
+ # reference held by this caller function, enabling early garbage collection for
+ # inference. Skip wrapping if decoder_input is logged after decoder completion.
+ if in_inference_mode and not has_config_logger_enabled(self.config):
+ decoder_input = WrappedTensor(decoder_input)
+
+ return (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset)
+
+ def _set_inv_freq(self):
+ self.attention_scaling = 1.
+ new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config)
+ self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+
+ def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
@@ -251,26 +278,7 @@ def _preprocess(
rotary_seq_len,
packed_seq=packed_seq,
)
-
- if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration')
- or self.config.flash_decode) and rotary_pos_cos is not None
- and inference_context.is_static_batching()):
- current_batch_size = input_ids.shape[0]
- sequence_len_offset = torch.tensor(
- [inference_context.sequence_len_offset] * current_batch_size,
- dtype=torch.int32,
- device=rotary_pos_cos.device, # Co-locate this with the rotary tensors
- )
- else:
- sequence_len_offset = None
-
- # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the
- # reference held by this caller function, enabling early garbage collection for
- # inference. Skip wrapping if decoder_input is logged after decoder completion.
- if in_inference_mode and not has_config_logger_enabled(self.config):
- decoder_input = WrappedTensor(decoder_input)
-
- return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset
+ return rotary_pos_emb, rotary_pos_cos, rotary_pos_sin
# Code borrowed from NVIDIA/Megatron-LM
def forward(
@@ -314,7 +322,11 @@ def forward(
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion:
assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
- decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
+ if isinstance(rotary_pos_emb, dict):
+ for k, v in rotary_pos_emb.items():
+ decoder_rotary_pos_emb[k] = v[position_ids[0]]
+ else:
+ decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
mtp_decoder_input = decoder_input
if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
@@ -322,9 +334,12 @@ def forward(
input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
self.set_input_tensor(input_tensor)
kwargs = {}
+ full_attention_mask = attention_mask
+ if isinstance(full_attention_mask, dict):
+ full_attention_mask = full_attention_mask['full_attention']
if mcore_016 and attention_mask is not None:
assert packed_seq_params is None
- padding_mask = ~((~attention_mask).sum(dim=(1, 2)) > 0)
+ padding_mask = ~((~full_attention_mask).sum(dim=(1, 2)) > 0)
if self.config.context_parallel_size > 1:
padding_mask = split_cp_inputs(padding_mask, None, 1)
tp_size = self.config.tensor_model_parallel_size
@@ -366,6 +381,9 @@ def forward(
inference_context=inference_context,
)
+ def _forward_output_layer(self, hidden_states, *args, **kwargs):
+ return self.output_layer(hidden_states, *args, **kwargs)[0]
+
def _postprocess(
self,
hidden_states,
@@ -432,7 +450,7 @@ def _postprocess(
loss_mask = torch.ones_like(mtp_labels)
for mtp_layer_number in range(self.config.mtp_unroll_steps):
# output
- mtp_logits, _ = self.output_layer(
+ mtp_logits = self._forward_output_layer(
hidden_states_list[mtp_layer_number + 1],
weight=output_weight,
runtime_gather_output=runtime_gather_output,
@@ -493,7 +511,7 @@ def _postprocess(
if self.config.task_type == 'embedding':
logits = F.normalize(hidden_states, p=2, dim=-1)
else:
- logits, _ = self.output_layer(
+ logits = self._forward_output_layer(
hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output)
if self.config.task_type == 'generative_reranker':
logits = gather_from_tensor_model_parallel_region(logits)
diff --git a/src/mcore_bridge/model/gpts/bailing_moe.py b/src/mcore_bridge/model/gpts/bailing_moe.py
index f3c383a..7e2a70d 100644
--- a/src/mcore_bridge/model/gpts/bailing_moe.py
+++ b/src/mcore_bridge/model/gpts/bailing_moe.py
@@ -71,7 +71,7 @@ class BailingMoeBridge(GPTBridge):
hf_expert_bias_key = 'gate.expert_bias'
hf_o_proj_key = 'dense'
- def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool):
+ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs):
self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'query_key_value.weight', to_mcore)
assert not self.config.add_bias_linear
return hf_state_dict
diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py
index 42aa4c6..853e3a5 100644
--- a/src/mcore_bridge/model/gpts/minimax_m2.py
+++ b/src/mcore_bridge/model/gpts/minimax_m2.py
@@ -74,7 +74,7 @@ class MinimaxM2Bridge(GPTBridge):
hf_mlp_prefix = 'block_sparse_moe'
hf_expert_bias_key = 'e_score_correction_bias'
- def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore):
+ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs):
self._set_state_dict(mg_attn, 'q_norm.weight', hf_state_dict, 'q_norm.weight', to_mcore)
self._set_state_dict(mg_attn, 'k_norm.weight', hf_state_dict, 'k_norm.weight', to_mcore)
diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py
index 2a5f659..2a80c04 100644
--- a/src/mcore_bridge/model/mm_gpt_model.py
+++ b/src/mcore_bridge/model/mm_gpt_model.py
@@ -79,6 +79,7 @@ def forward(
packed_seq_params: PackedSeqParams = None,
**kwargs,
) -> torch.Tensor:
+ extra_kwargs = {k: kwargs[k] for k in self.language_model.extra_forward_keys}
if decoder_input is not None:
pass
elif self.pre_process:
@@ -90,6 +91,7 @@ def forward(
# decoder will get hidden_states from encoder.input_tensor
decoder_input = None
kwargs = {}
+ kwargs.update(extra_kwargs)
return self.language_model(
input_ids=input_ids,
position_ids=position_ids,
diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py
index 9009edb..b862ec6 100644
--- a/src/mcore_bridge/model/mm_gpts/__init__.py
+++ b/src/mcore_bridge/model/mm_gpts/__init__.py
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
-from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl
+from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl
diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py
new file mode 100644
index 0000000..2c68a73
--- /dev/null
+++ b/src/mcore_bridge/model/mm_gpts/gemma4.py
@@ -0,0 +1,814 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import copy
+import math
+import torch
+from megatron.core.extensions.transformer_engine import (SplitAlongDim, TEColumnParallelLinear, TENorm,
+ TERowParallelLinear)
+from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
+from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import _yarn_get_concentration_factor_from_config
+from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
+from megatron.core.packed_seq_params import PackedSeqParams
+from megatron.core.parallel_state import get_tensor_model_parallel_rank
+from megatron.core.tensor_parallel import VocabParallelEmbedding, all_gather_last_dim_from_tensor_parallel_region
+from megatron.core.tensor_parallel.mappings import (gather_from_tensor_model_parallel_region,
+ scatter_to_tensor_model_parallel_region)
+from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
+from megatron.core.transformer.enums import AttnMaskType
+from megatron.core.transformer.identity_op import IdentityOp
+from megatron.core.transformer.mlp import MLP
+from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.spec_utils import build_module
+from megatron.core.utils import make_viewless_tensor, nvtx_range_pop, nvtx_range_push
+from torch import Tensor, nn
+from transformers import AutoModel, PretrainedConfig
+from transformers.utils.versions import require_version
+from typing import Optional, Tuple
+
+from mcore_bridge.bridge import MultimodalGPTBridge
+from mcore_bridge.config import ModelConfig
+
+from ..constant import ModelType
+from ..gpt_model import GPTModel
+from ..mm_gpt_model import MultimodalGPTModel
+from ..modules import TransformerBlock, TransformerLayer
+from ..register import ModelLoader, ModelMeta, register_model
+from ..rope import get_rope_inv_freq
+from .utils import HuggingFaceVit
+
+
+class Gemma4RMSNormNoScale(torch.nn.Module):
+ """RMSNorm without learnable scale, mirroring HF `Gemma4RMSNorm(with_scale=False)`."""
+
+ def __init__(self, dim: int, eps: float = 1e-6):
+ super().__init__()
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ orig_dtype = x.dtype
+ x = x.to(torch.float32)
+ variance = x.pow(2).mean(-1, keepdim=True)
+ return (x * torch.rsqrt(variance + self.eps)).to(orig_dtype)
+
+
+class Gemma4Vit(HuggingFaceVit):
+ module_mapping = {
+ 'model.vision_tower': 'vision_tower',
+ 'model.embed_vision': 'embed_vision',
+ 'model.audio_tower': 'audio_tower',
+ 'model.embed_audio': 'embed_audio',
+ }
+ _vision_tower = ['vision_tower', 'audio_tower']
+ _aligner = ['embed_vision', 'embed_audio']
+
+ def prepare_model(self, hf_config: PretrainedConfig):
+ from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4MultimodalEmbedder
+ self.vision_tower = AutoModel.from_config(hf_config.vision_config)
+ dtype = self.vision_tower.dtype
+ self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None
+ self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config).to(dtype)
+ self.embed_audio = (
+ Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype)
+ if hf_config.audio_config is not None else None)
+ self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False)
+ self.model_cls = Gemma4Model
+
+ def get_inputs_embeds(self, inputs_embeds, **kwargs):
+ input_ids = kwargs.get('input_ids')
+ inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype)
+
+ hf_config = self.hf_config
+ pixel_values = kwargs.get('pixel_values')
+ pixel_values_videos = kwargs.get('pixel_values_videos')
+ input_features = kwargs.get('input_features')
+ input_features_mask = kwargs.get('input_features_mask')
+ image_position_ids = kwargs.get('image_position_ids')
+ video_position_ids = kwargs.get('video_position_ids')
+
+ image_mask = input_ids == hf_config.image_token_id
+ video_mask = input_ids == hf_config.video_token_id
+ audio_mask = input_ids == hf_config.audio_token_id
+ multimodal_mask = image_mask | video_mask | audio_mask
+ llm_input_ids = input_ids.clone()
+ llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id
+
+ if pixel_values is not None:
+ with self.patch_hf_config():
+ image_features = self.model_cls.get_image_features(
+ self, pixel_values, image_position_ids, return_dict=True).pooler_output
+ image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features)
+
+ if pixel_values_videos is not None:
+ with self.patch_hf_config():
+ video_features = self.model_cls.get_video_features(
+ self, pixel_values_videos, video_position_ids, return_dict=True).pooler_output
+ video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features)
+
+ if (input_features is not None and input_features_mask is not None and self.audio_tower is not None):
+ with self.patch_hf_config():
+ audio_output = self.model_cls.get_audio_features(
+ self, input_features, input_features_mask, return_dict=True)
+ audio_features = audio_output.pooler_output
+ audio_features = audio_features[audio_output.attention_mask]
+ audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
+ audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+ inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features)
+ return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids}
+
+
+class Gemma4SelfAttention(SelfAttention):
+
+ def __init__(
+ self,
+ config: ModelConfig,
+ submodules: SelfAttentionSubmodules,
+ layer_number: int,
+ *args,
+ **kwargs,
+ ):
+ text_config = config.hf_config.text_config
+ layer_idx = layer_number - 1
+
+ # Layer type / sliding attention
+ self.layer_type = text_config.layer_types[layer_idx]
+ self.is_sliding = self.layer_type == 'sliding_attention'
+
+ # Head dim: global layers may use a different head dim than sliding ones
+ self.head_dim = (
+ text_config.global_head_dim
+ if not self.is_sliding and text_config.global_head_dim else text_config.head_dim)
+
+ self.use_alternative_attention = (text_config.attention_k_eq_v and not self.is_sliding)
+ self.num_key_value_heads = (
+ text_config.num_global_key_value_heads
+ if self.use_alternative_attention else text_config.num_key_value_heads)
+ # Shared KV across the trailing layers
+ self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0)
+ first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers
+ self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
+ prev_layers = text_config.layer_types[:first_kv_shared_layer_idx]
+ self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len(
+ prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx])
+
+ orig_kv_channels = config.kv_channels
+ orig_num_query_groups = config.num_query_groups
+ orig_k_layernorm = submodules.k_layernorm
+ config.kv_channels = self.head_dim
+ config.num_query_groups = self.num_key_value_heads
+ if self.is_sliding and config.window_size is None:
+ kwargs['attn_mask_type'] = AttnMaskType.arbitrary
+ if self.is_kv_shared_layer:
+ submodules.k_layernorm = IdentityOp
+ try:
+ super().__init__(config, submodules, layer_number, *args, **kwargs)
+ finally:
+ config.kv_channels = orig_kv_channels
+ config.num_query_groups = orig_num_query_groups
+ submodules.k_layernorm = orig_k_layernorm
+
+ if self.is_kv_shared_layer or self.use_alternative_attention:
+ linear_qkv_dim = self.query_projection_size
+ if not self.is_kv_shared_layer:
+ linear_qkv_dim += self.kv_projection_size
+ self.linear_qkv = submodules.linear_qkv(
+ self.config.hidden_size,
+ linear_qkv_dim,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='qkv',
+ tp_group=self.pg_collection.tp,
+ )
+
+ if not self.is_kv_shared_layer:
+ self.v_norm = Gemma4RMSNormNoScale(self.head_dim, eps=self.config.layernorm_epsilon)
+
+ def _forward_core_attention(
+ self,
+ query,
+ key,
+ value,
+ attention_mask,
+ attention_bias: Optional[torch.Tensor] = None,
+ packed_seq_params: Optional[PackedSeqParams] = None,
+ ):
+ nvtx_range_push(suffix='core_attention')
+ attn_mask_type = self.attn_mask_type
+ if self.checkpoint_core_attention and self.training:
+ core_attn_out = self._checkpointed_attention_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ attention_bias=attention_bias,
+ packed_seq_params=packed_seq_params,
+ )
+ else:
+ core_attn_out = self.core_attention(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ attention_bias=attention_bias,
+ packed_seq_params=packed_seq_params,
+ )
+ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
+ # reshape to same output shape as unpacked case
+ # (t, np, hn) -> (t, b=1, h=np*hn)
+ # t is the pack size = sum (sq_i)
+ # note that batch is a dummy dimension in the packed case
+ core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1)
+ nvtx_range_pop(suffix='core_attention')
+ return core_attn_out
+
+ def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params):
+ nvtx_range_push(suffix='rotary_pos_emb')
+ q_pos_emb, k_pos_emb = rotary_pos_emb
+
+ if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
+ if packed_seq_params.cu_seqlens_q_padded is not None:
+ cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded
+ else:
+ cu_seqlens_q = packed_seq_params.cu_seqlens_q
+ if packed_seq_params.cu_seqlens_kv_padded is not None:
+ cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded
+ else:
+ cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
+ else:
+ cu_seqlens_q = cu_seqlens_kv = None
+
+ if q_pos_emb is not None:
+ # TODO VIJAY: simplify
+ query = apply_rotary_pos_emb(
+ query,
+ q_pos_emb,
+ config=self.config,
+ cu_seqlens=cu_seqlens_q,
+ mscale=_yarn_get_concentration_factor_from_config(self.config),
+ cp_group=self.pg_collection.cp,
+ )
+ if not self.is_kv_shared_layer and k_pos_emb is not None:
+ key = apply_rotary_pos_emb(
+ key,
+ k_pos_emb,
+ config=self.config,
+ cu_seqlens=cu_seqlens_kv,
+ mscale=_yarn_get_concentration_factor_from_config(self.config),
+ cp_group=self.pg_collection.cp,
+ )
+ nvtx_range_pop(suffix='rotary_pos_emb')
+ return query, key
+
+ def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]:
+ shared_kv_states = kwargs['shared_kv_states']
+ rotary_pos_emb = kwargs.get('rotary_pos_emb')
+ packed_seq_params = kwargs.get('packed_seq_params')
+ attention_bias = kwargs.get('attention_bias')
+ mixed_qkv, _ = self.linear_qkv(hidden_states)
+ if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size:
+ mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(mixed_qkv)
+ idx = get_tensor_model_parallel_rank() // (self.world_size // self.config.num_query_groups)
+ size = mixed_qkv.size()[-1] // self.config.num_query_groups
+ mixed_qkv = mixed_qkv[:, :, idx * size:(idx + 1) * size]
+
+ if self.is_kv_shared_layer:
+ query = mixed_qkv
+ key, value = shared_kv_states[self.layer_type]
+ else:
+ kv_heads_per_group = 1 if self.use_alternative_attention else 2
+ num_query_heads_per_group = (self.num_attention_heads_per_partition // self.num_query_groups_per_partition)
+ new_tensor_shape = mixed_qkv.size()[:-1] + (
+ self.num_query_groups_per_partition,
+ (num_query_heads_per_group + kv_heads_per_group) * self.hidden_size_per_attention_head,
+ )
+ mixed_qkv = mixed_qkv.view(*new_tensor_shape)
+ split_arg_list = [num_query_heads_per_group * self.hidden_size_per_attention_head
+ ] + [self.hidden_size_per_attention_head] * kv_heads_per_group
+ if SplitAlongDim is not None:
+ qkv = SplitAlongDim(mixed_qkv, len(split_arg_list), split_arg_list)
+ else:
+ qkv = torch.split(mixed_qkv, split_arg_list, dim=3)
+ if self.use_alternative_attention:
+ query, key = qkv
+ value = key
+ else:
+ query, key, value = qkv
+ key = self.k_layernorm(key)
+ value = self.v_norm(value)
+ # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
+ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
+ if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size:
+ idx = get_tensor_model_parallel_rank() % (self.world_size // self.config.num_query_groups)
+ size = query.shape[2] // (self.world_size // self.config.num_query_groups)
+ query = query[:, :, idx * size:(idx + 1) * size, :]
+ query = self.q_layernorm(query)
+ if isinstance(rotary_pos_emb, torch.Tensor):
+ rotary_pos_emb = (rotary_pos_emb, ) * 2
+
+ query, key = self._apply_rotary(query, key, rotary_pos_emb, packed_seq_params)
+ if self.store_full_length_kv:
+ shared_kv_states[self.layer_type] = key, value
+ core_attn_out = self._forward_core_attention(query, key, value, attention_mask, attention_bias,
+ packed_seq_params)
+
+ nvtx_range_push(suffix='linear_proj')
+ output, bias = self.linear_proj(core_attn_out)
+ nvtx_range_pop(suffix='linear_proj')
+ return output, bias
+
+
+class Gemma4MLP(MLP):
+
+ def __init__(
+ self,
+ config: ModelConfig,
+ submodules: SelfAttentionSubmodules,
+ layer_number: int,
+ *args,
+ **kwargs,
+ ):
+ self.layer_number = layer_number
+ text_config = config.hf_config.text_config
+ first_kv_shared_layer_idx = config.num_layers - text_config.num_kv_shared_layers
+ is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0
+ use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer
+ ffn_hidden_size = config.ffn_hidden_size
+ config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1)
+ try:
+ super().__init__(config, submodules, *args, **kwargs)
+ finally:
+ config.ffn_hidden_size = ffn_hidden_size
+
+
+class Gemma4MoELayer(MoELayer):
+
+ def __init__(self, config, *args, **kwargs):
+ require_version('megatron-core>=0.16.0.dev', 'Gemma4MoELayer requires megatron-core>=0.16.0')
+ super().__init__(config, *args, **kwargs)
+ self.pre_feedforward_layernorm_2 = build_module(
+ TENorm, hidden_size=config.hidden_size, config=config, eps=config.layernorm_epsilon)
+ self.norm = Gemma4RMSNormNoScale(config.hidden_size, eps=self.config.layernorm_epsilon)
+ self.scalar_root_size = config.hidden_size**-0.5
+ self.scale = nn.Parameter(torch.ones(config.hidden_size))
+ self.per_expert_scale = nn.Parameter(torch.ones(config.num_moe_experts))
+ self.scale.sequence_parallel = config.sequence_parallel
+ self.per_expert_scale.sequence_parallel = config.sequence_parallel
+
+ def route(self, hidden_states: torch.Tensor, *args, **kwargs):
+ hidden_states = self.norm(hidden_states)
+ hidden_states = hidden_states * self.scale * self.scalar_root_size
+ probs, routing_map = super().route(hidden_states)
+ probs = probs * self.per_expert_scale
+ return probs, routing_map
+
+ def preprocess(self, hidden_states: torch.Tensor, *args, **kwargs):
+ hidden_states = self.pre_feedforward_layernorm_2(hidden_states)
+ return super().preprocess(hidden_states, *args, **kwargs)
+
+
+class Gemma4Bridge(MultimodalGPTBridge):
+ hf_post_attention_layernorm = 'pre_feedforward_layernorm'
+ additional_dim0_keys = {'embed_tokens_per_layer', 'per_layer_input_gate', 'per_layer_model_projection'}
+ additional_dim1_keys = {'per_layer_projection'}
+
+ def __init__(self, config: ModelConfig):
+ super().__init__(config)
+ self.text_config = config.hf_config.text_config
+ self.hidden_size_per_layer_input = self.text_config.hidden_size_per_layer_input
+
+ def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs):
+ layer_idx = kwargs['layer_idx']
+ is_kv_shared_layer = self._get_is_kv_shared_layer(layer_idx)
+ self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore)
+ if not is_kv_shared_layer:
+ self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore)
+
+ def _get_is_kv_shared_layer(self, layer_idx):
+ text_config = self.text_config
+ num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0)
+ first_kv_shared_layer_idx = self.config.num_layers - num_kv_shared_layers
+ is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0
+ return is_kv_shared_layer
+
+ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs):
+ text_config = self.text_config
+ layer_idx = kwargs['layer_idx']
+ is_sliding = text_config.layer_types[layer_idx] == 'sliding_attention'
+ head_dim = (
+ text_config.global_head_dim if not is_sliding and text_config.global_head_dim else text_config.head_dim)
+ is_kv_shared_layer = self._get_is_kv_shared_layer(layer_idx)
+ if is_kv_shared_layer:
+ self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore)
+ return hf_state_dict
+ else:
+ kwargs['kv_channels'] = head_dim
+ use_alternative_attention = text_config.attention_k_eq_v and not is_sliding
+ kwargs['attention_k_eq_v'] = use_alternative_attention
+ kwargs['num_query_groups'] = (
+ text_config.num_global_key_value_heads
+ if use_alternative_attention else text_config.num_key_value_heads)
+ return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs)
+
+ def _set_router(self, mg_mlp, hf_state_dict, to_mcore):
+ self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, 'router.proj.weight', to_mcore)
+ for key in ['per_expert_scale', 'scale']:
+ self._set_state_dict(mg_mlp, key, hf_state_dict, f'router.{key}', to_mcore)
+
+ def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False):
+ mg_mlp = None if mg_layer is None else mg_layer.mlp
+ hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore))
+ self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict,
+ f'{self.hf_post_attention_layernorm}.weight', to_mcore)
+ if self.text_config.enable_moe_block:
+ mg_experts = None if mg_layer is None else mg_layer.experts_mlp
+ hf_state_dict.update(self._set_moe_state(mg_experts, hf_state_dict, '', layer_idx, to_mcore, is_mtp=is_mtp))
+ return hf_state_dict
+
+ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool):
+ hf_prefix = f'{hf_prefix}{layer_idx}.'
+ if to_mcore:
+ hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix)
+ else:
+ hf_state_dict = {}
+ hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore))
+ hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore))
+ for key in ['post_attention_layernorm', 'post_feedforward_layernorm']:
+ self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore)
+ if self.hidden_size_per_layer_input:
+ for key in ['per_layer_input_gate', 'per_layer_projection', 'post_per_layer_input_norm']:
+ self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore)
+ self._set_state_dict(mg_layer, 'layer_scalar', hf_state_dict, 'layer_scalar', to_mcore)
+ if self.text_config.enable_moe_block:
+ for key in ['post_feedforward_layernorm_1', 'post_feedforward_layernorm_2']:
+ self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore)
+ self._set_state_dict(mg_layer, 'experts_mlp.pre_feedforward_layernorm_2.weight', hf_state_dict,
+ 'pre_feedforward_layernorm_2.weight', to_mcore)
+ if to_mcore:
+ hf_state_dict = {}
+ else:
+ hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix)
+ return hf_state_dict
+
+ def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore):
+ lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model
+ self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore)
+ if self.hidden_size_per_layer_input:
+ for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']:
+ self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight',
+ to_mcore)
+
+
+class Gemma4TextGPTModel(GPTModel):
+ extra_forward_keys = ['mm_token_type_ids']
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition
+ text_config = self.config.hf_config.text_config
+ self.text_config = text_config
+ self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0)
+ self.unique_layer_types = set(text_config.layer_types)
+ self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input
+ self.final_logit_softcapping = text_config.final_logit_softcapping
+ if self.hidden_size_per_layer_input and self.pre_process:
+ total_dim = self.config.num_layers * self.hidden_size_per_layer_input
+ self.embed_tokens_per_layer = VocabParallelEmbedding(
+ num_embeddings=self.vocab_size,
+ embedding_dim=total_dim,
+ init_method=self.config.init_method,
+ config=self.config,
+ tp_group=self.pg_collection.tp,
+ )
+ self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5
+ self.per_layer_input_scale = 2.0**-0.5
+ self.per_layer_model_projection = build_module(
+ TEColumnParallelLinear,
+ self.config.hidden_size,
+ total_dim,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=False,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='per_layer_model_projection',
+ tp_group=self.pg_collection.tp,
+ )
+ self.per_layer_model_projection_scale = self.config.hidden_size**-0.5
+ self.per_layer_projection_norm = build_module(
+ TENorm,
+ hidden_size=self.hidden_size_per_layer_input,
+ config=self.config,
+ eps=self.config.layernorm_epsilon,
+ )
+
+ def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None):
+ rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, decoder_input,
+ self.config, packed_seq_params)
+ packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
+ rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
+ full_rotary_pos_emb = self.full_rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
+ rotary_pos_emb = {'sliding_attention': rotary_pos_emb, 'full_attention': full_rotary_pos_emb}
+ return rotary_pos_emb, None, None
+
+ def _set_inv_freq(self):
+ rope_scaling = self.config.rope_scaling
+ self.config.rope_scaling = rope_scaling['sliding_attention']
+ new_inv_freq, attention_scaling = get_rope_inv_freq(self.config)
+ assert attention_scaling == 1, 'not support'
+ self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device)
+ # full
+ self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb)
+ self.config.rope_scaling = rope_scaling['full_attention']
+ kwargs = {'layer_type': 'full_attention'}
+ if self.config.rope_scaling['rope_type'] == 'proportional':
+ kwargs['head_dim_key'] = 'global_head_dim'
+ new_inv_freq, attention_scaling = get_rope_inv_freq(
+ self.config, text_config=self.config.hf_config.text_config, **kwargs)
+ assert attention_scaling == 1, 'not support'
+ self.full_rotary_pos_emb.inv_freq = new_inv_freq
+ self.attention_scaling = attention_scaling
+
+ self.config.rope_scaling = rope_scaling
+
+ def forward(self, *args, **kwargs):
+ extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {}
+ llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None)
+ decoder_input = kwargs.get('decoder_input')
+ mm_token_type_ids = extra_block_kwargs.pop('mm_token_type_ids', None)
+ shared_kv_states = {}
+ if self.hidden_size_per_layer_input:
+ assert self.num_kv_shared_layers > 0, 'not support'
+ if decoder_input is None:
+ per_layer_inputs, shared_kv_states = self.unpack_pp_input()
+ else:
+ inputs_embeds = decoder_input
+ per_layer_inputs = self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale
+ per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs.shape[:-1], self.config.num_layers,
+ -1).transpose(0, 1)
+ per_layer_projection = self.per_layer_model_projection(
+ inputs_embeds)[0] * self.per_layer_model_projection_scale
+ per_layer_projection = gather_from_tensor_model_parallel_region(per_layer_projection)
+ per_layer_projection = per_layer_projection.reshape(*per_layer_projection.shape[:-1],
+ self.config.num_layers, -1)
+ per_layer_projection = self.per_layer_projection_norm(per_layer_projection)
+ per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale
+ per_layer_inputs = scatter_to_tensor_model_parallel_region(per_layer_inputs)
+ extra_block_kwargs['per_layer_inputs'] = per_layer_inputs
+ else:
+ assert self.num_kv_shared_layers == 0, 'not support'
+ extra_block_kwargs['shared_kv_states'] = shared_kv_states
+ kwargs['extra_block_kwargs'] = extra_block_kwargs
+ attention_mask = kwargs.get('attention_mask')
+ kwargs['attention_mask'] = {'sliding_attention': attention_mask, 'full_attention': attention_mask}
+ if self.text_config.use_bidirectional_attention == 'vision':
+ kwargs['attention_mask']['sliding_attention'] = self._create_sliding_attention_mask(
+ attention_mask, mm_token_type_ids)
+ hidden_states = super().forward(*args, **kwargs)
+ if self.hidden_size_per_layer_input and not self.post_process:
+ hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states)
+ return hidden_states
+
+ def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids):
+ window_size = self.text_config.sliding_window - 1
+ seq_len = attention_mask.shape[-1]
+
+ window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)
+ window_mask = ~torch.triu(window_mask, diagonal=-window_size)
+
+ attention_mask = attention_mask | window_mask
+ if mm_token_type_ids is not None:
+ is_vision = mm_token_type_ids > 0
+ is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1)
+ is_prev_vision[:, 0] = False
+ vision_group_ids = torch.cumsum((is_vision & ~is_prev_vision).int(), dim=1) - 1
+ vision_group_ids = torch.where(is_vision, vision_group_ids, torch.full_like(vision_group_ids, -1))
+ q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1)
+ k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2)
+ same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0)
+ attention_mask = attention_mask & ~same_vision_group
+ return attention_mask
+
+ def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states):
+ per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1)
+ hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1)
+ flag = per_layer_inputs.new_zeros(*hidden_states.shape[:2], 2)
+ if 'sliding_attention' in shared_kv_states:
+ flag[0, 0, 0] = 1
+ sliding_states = torch.concat(shared_kv_states['sliding_attention'], -1)
+ sliding_states = sliding_states.view(*hidden_states.shape[:2], -1)
+ hidden_states = torch.concat([hidden_states, sliding_states], dim=-1)
+ if 'full_attention' in shared_kv_states:
+ flag[0, 0, 1] = 1
+ full_states = torch.concat(shared_kv_states['full_attention'], -1)
+ full_states = full_states.view(*hidden_states.shape[:2], -1)
+ hidden_states = torch.concat([hidden_states, full_states], dim=-1)
+ hidden_states = torch.concat([hidden_states, flag], dim=-1)
+ return hidden_states
+
+ def unpack_pp_input(self):
+ tp_size = self.config.tensor_model_parallel_size
+ shared_kv_states = {}
+ input_tensor = self.get_input_tensor()
+ sequence_len = input_tensor.shape[0] * tp_size if self.config.sequence_parallel else input_tensor.shape[0]
+ input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 2, 2], dim=-1)
+ flag = flag.detach()
+ per_layer_inputs_shape = [
+ sequence_len, input_tensor.shape[1], self.config.num_layers, self.hidden_size_per_layer_input // tp_size
+ ]
+ full_head_dim = self.text_config.global_head_dim
+ full_states_shape = per_layer_inputs_shape[:2] + [self.num_query_groups_per_partition, full_head_dim * 2]
+ sliding_states_shape = full_states_shape[:3] + [self.config.kv_channels * 2]
+ per_layer_inputs_dim = math.prod(per_layer_inputs_shape) // math.prod(input_tensor.shape[:2])
+ full_states_dim = math.prod(full_states_shape) // math.prod(input_tensor.shape[:2])
+ sliding_states_dim = math.prod(sliding_states_shape) // math.prod(input_tensor.shape[:2])
+ if flag[0, 0, 1].item() != 0:
+ input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim],
+ dim=-1)
+ full_states = full_states.reshape(*full_states_shape)
+ shared_kv_states['full_attention'] = full_states.chunk(2, -1)
+ if flag[0, 0, 0].item() != 0:
+ input_tensor, sliding_states = input_tensor.split(
+ [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1)
+ sliding_states = sliding_states.reshape(*sliding_states_shape)
+ shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1)
+ input_tensor, per_layer_inputs = input_tensor.split(
+ [input_tensor.shape[-1] - per_layer_inputs_dim, per_layer_inputs_dim], dim=-1)
+ self.set_input_tensor(input_tensor)
+ per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs_shape)
+ return per_layer_inputs, shared_kv_states
+
+ def _forward_output_layer(self, hidden_states, *args, **kwargs):
+ logits, _ = self.output_layer(hidden_states, *args, **kwargs)
+ if self.final_logit_softcapping is not None:
+ logits = logits / self.final_logit_softcapping
+ logits = torch.tanh(logits)
+ logits = logits * self.final_logit_softcapping
+ return logits
+
+
+class Gemma4TransformerLayer(TransformerLayer):
+
+ def __init__(self, config, submodules, *args, **kwargs):
+ super().__init__(config, submodules, *args, **kwargs)
+ text_config = config.hf_config.text_config
+ self.enable_moe_block = text_config.enable_moe_block
+ if self.enable_moe_block:
+ self.experts_mlp = self._build_mlp(submodules.experts_mlp)
+ hidden_size = self.config.hidden_size
+ eps = self.config.layernorm_epsilon
+
+ self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps)
+ self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps)
+
+ self.register_buffer('layer_scalar', torch.ones(1))
+
+ self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input
+ if self.hidden_size_per_layer_input:
+ from transformers.activations import ACT2FN
+ self.act_fn = ACT2FN[text_config.hidden_activation]
+ self.per_layer_input_gate = build_module(
+ TEColumnParallelLinear,
+ hidden_size,
+ self.hidden_size_per_layer_input,
+ config=self.config,
+ init_method=self.config.init_method,
+ gather_output=False,
+ bias=False,
+ skip_bias_add=False,
+ is_expert=False,
+ tp_comm_buffer_name='per_layer_input_gate',
+ tp_group=self.pg_collection.tp,
+ )
+ self.per_layer_projection = build_module(
+ TERowParallelLinear,
+ self.hidden_size_per_layer_input,
+ hidden_size,
+ config=self.config,
+ init_method=self.config.output_layer_init_method,
+ bias=False,
+ input_is_parallel=True,
+ skip_bias_add=True,
+ is_expert=False,
+ tp_comm_buffer_name='per_layer_projection',
+ tp_group=self.pg_collection.tp,
+ )
+ self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps)
+
+ if self.enable_moe_block:
+ self.post_feedforward_layernorm_1 = build_module(
+ TENorm, hidden_size=hidden_size, config=self.config, eps=eps)
+ self.post_feedforward_layernorm_2 = build_module(
+ TENorm, hidden_size=hidden_size, config=self.config, eps=eps)
+
+ def _forward_attention(self, hidden_states: Tensor, **kwargs):
+ context = kwargs.pop('context', None)
+ residual = hidden_states
+ input_layernorm_output = self.input_layernorm(hidden_states)
+ # Self attention.
+ nvtx_range_push(suffix='self_attention')
+ attention_output, bias = self.self_attention(input_layernorm_output, **kwargs)
+ nvtx_range_pop(suffix='self_attention')
+ attention_output = self.post_attention_layernorm(attention_output)
+
+ hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)(
+ (attention_output, bias), residual, self.hidden_dropout)
+ return hidden_states, context
+
+ def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None):
+ # Residual connection.
+ residual = hidden_states
+
+ # Optional Layer norm post the cross-attention.
+ pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states)
+ mlp_output, bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask)
+ if self.enable_moe_block:
+ mlp_output_1 = self.post_feedforward_layernorm_1(mlp_output)
+ mlp_output_2, bias = self.experts_mlp(residual, padding_mask=padding_mask)
+ mlp_output_2 = self.post_feedforward_layernorm_2(mlp_output_2)
+
+ # Combine mlp and moe outputs
+ mlp_output = mlp_output_1 + mlp_output_2
+
+ mlp_output = self.post_feedforward_layernorm(mlp_output)
+ hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)((mlp_output, bias), residual,
+ self.hidden_dropout)
+ output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)
+ return output
+
+ def forward(self, hidden_states, *args, **kwargs):
+ per_layer_input = kwargs.pop('per_layer_input', None)
+ hidden_states, context = super().forward(hidden_states, *args, **kwargs)
+ if self.hidden_size_per_layer_input:
+ residual = hidden_states
+ hidden_states, _ = self.per_layer_input_gate(hidden_states)
+ hidden_states = self.act_fn(hidden_states)
+ hidden_states = hidden_states * per_layer_input
+ hidden_states, _ = self.per_layer_projection(hidden_states)
+ hidden_states = self.post_per_layer_input_norm(hidden_states)
+ hidden_states = residual + hidden_states
+
+ hidden_states *= self.layer_scalar
+ return hidden_states, context
+
+
+class Gemma4GPTModel(MultimodalGPTModel):
+ language_model_cls = Gemma4TextGPTModel
+
+
+class Gemma4TransformerBlock(TransformerBlock):
+
+ def _layer_forward(self, layer, hidden_states, **kwargs):
+ layer_number = layer.layer_number - 1
+ per_layer_inputs = kwargs.pop('per_layer_inputs', None)
+ if per_layer_inputs is not None:
+ kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number]
+ layer_type = self.config.hf_config.text_config.layer_types[layer_number]
+ kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type]
+ kwargs['attention_mask'] = kwargs['attention_mask'][layer_type]
+ return super()._layer_forward(layer, hidden_states, **kwargs)
+
+
+class Gemma4Loader(ModelLoader):
+ model_cls = Gemma4GPTModel
+ transformer_block = Gemma4TransformerBlock
+
+ def get_transformer_layer_spec(self, vp_stage: Optional[int] = None):
+ num_moe_experts = self.config.num_moe_experts
+ self.config.num_moe_experts = None
+ layer_specs = get_gpt_decoder_block_spec(
+ self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
+ for layer_spec in layer_specs.layer_specs:
+ layer_spec.submodules.self_attention.module = Gemma4SelfAttention
+ layer_spec.submodules.mlp.module = Gemma4MLP
+ if num_moe_experts is not None:
+ self.config.num_moe_experts = num_moe_experts
+ moe_layer_specs = get_gpt_decoder_block_spec(
+ self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage)
+ for layer_spec, moe_layer_spec in zip(layer_specs.layer_specs, moe_layer_specs.layer_specs):
+ layer_spec.submodules.experts_mlp = moe_layer_spec.submodules.mlp
+ layer_spec.submodules.experts_mlp.module = Gemma4MoELayer
+ return layer_specs
+
+ def _set_transformer_layer(self, transformer_layer_spec):
+ for layer_spec in transformer_layer_spec.layer_specs:
+ layer_spec.module = Gemma4TransformerLayer
+
+
+register_model(
+ ModelMeta(
+ ModelType.gemma4,
+ ['gemma4'],
+ bridge_cls=Gemma4Bridge,
+ visual_cls=Gemma4Vit,
+ loader=Gemma4Loader,
+ ))
diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py
index 073f671..5d36431 100644
--- a/src/mcore_bridge/model/modules/transformer_block.py
+++ b/src/mcore_bridge/model/modules/transformer_block.py
@@ -19,6 +19,51 @@
apply_module = None
+class _TensorIdx:
+ """Sentinel that marks a position in the flatten schema as a tensor index."""
+ __slots__ = ('idx', )
+
+ def __init__(self, idx):
+ self.idx = idx
+
+
+def _checkpoint_flatten(obj, tensors):
+ """Recursively flatten a nested structure (dict/tuple/list/Tensor) into a schema.
+
+ Tensors are appended to `tensors` and replaced in the schema by a _TensorIdx sentinel.
+ Non-tensor leaves (int, bool, None, str, ...) are stored as-is.
+ The schema mirrors the original structure and is captured in the checkpoint closure.
+ """
+ if torch.is_tensor(obj):
+ idx = len(tensors)
+ tensors.append(obj)
+ return _TensorIdx(idx)
+ elif isinstance(obj, dict):
+ # inplace (gemma4 shared_kv_states)
+ for k, v in obj.items():
+ obj[k] = _checkpoint_flatten(v, tensors)
+ return obj
+ elif isinstance(obj, (tuple, list)):
+ return type(obj)(_checkpoint_flatten(v, tensors) for v in obj)
+ else:
+ return obj # non-tensor leaf: stored directly in schema
+
+
+def _checkpoint_unflatten(schema, tensors):
+ """Reconstruct the original structure from a schema and a flat tensors list."""
+ if isinstance(schema, _TensorIdx):
+ return tensors[schema.idx]
+ elif isinstance(schema, dict):
+ # inplace (gemma4 shared_kv_states)
+ for k, v in schema.items():
+ schema[k] = _checkpoint_unflatten(v, tensors)
+ return schema
+ elif isinstance(schema, (tuple, list)):
+ return type(schema)(_checkpoint_unflatten(v, tensors) for v in schema)
+ else:
+ return schema # non-tensor leaf
+
+
# Code borrowed from NVIDIA/Megatron-LM
class TransformerBlock(McoreTransformerBlock):
@@ -99,26 +144,25 @@ def custom_forward(
return custom_forward
- # `tensor_parallel.checkpoint` / `te_checkpoint` only forward *args to the
- # wrapped function (torch.utils.checkpoint limitation). Convert kwargs to
- # positional args by capturing the keys in closure so tensor kwargs (e.g.
- # qwen3-vl's visual_pos_masks / deepstack_visual_embeds) can flow through
- # activation recompute and remain in the autograd graph.
+ # Variables that don't require gradients can be captured via closure.
+ _ckpt_attention_mask = attention_mask
+ _ckpt_rotary_pos_emb = rotary_pos_emb
extra_kwargs_keys = tuple(kwargs.keys())
- extra_kwargs_values = tuple(kwargs.values())
+ _extra_flat_tensors = []
+ _extra_schemas = [_checkpoint_flatten(v, _extra_flat_tensors) for v in kwargs.values()]
def checkpoint_handler(forward_func):
"""Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`"""
- def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, padding_mask,
- *extra_args):
- extra_kwargs = dict(zip(extra_kwargs_keys, extra_args))
+ def wrapped_forward(hidden_states, context, context_mask, padding_mask, *extra_flat):
+ rebuilt = [_checkpoint_unflatten(s, extra_flat) for s in _extra_schemas]
+ extra_kwargs = dict(zip(extra_kwargs_keys, rebuilt))
return forward_func(
hidden_states,
- attention_mask,
+ _ckpt_attention_mask,
context,
context_mask,
- rotary_pos_emb,
+ _ckpt_rotary_pos_emb,
padding_mask,
**extra_kwargs,
)
@@ -131,24 +175,20 @@ def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary
tensor_parallel.random.get_cuda_rng_tracker,
self.pg_collection.tp,
hidden_states,
- attention_mask,
context,
context_mask,
- rotary_pos_emb,
padding_mask,
- *extra_kwargs_values,
+ *_extra_flat_tensors,
)
else:
return tensor_parallel.checkpoint(
wrapped_forward,
self.config.distribute_saved_activations,
hidden_states,
- attention_mask,
context,
context_mask,
- rotary_pos_emb,
padding_mask,
- *extra_kwargs_values,
+ *_extra_flat_tensors,
)
if self.config.recompute_method == 'uniform':
diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py
index 21b8cfa..6d500e0 100644
--- a/src/mcore_bridge/model/modules/transformer_layer.py
+++ b/src/mcore_bridge/model/modules/transformer_layer.py
@@ -134,8 +134,6 @@ def __init__(
# [Module 8: MLP block]
self.mlp = self._build_mlp(submodules.mlp)
- if hasattr(self.mlp, 'set_layer_number'):
- self.mlp.set_layer_number(self.layer_number)
# [Module 9: BiasDropoutFusion]
self.mlp_bda = build_module(submodules.mlp_bda)
self.is_moe_layer = isinstance(self.mlp, MoELayer)
@@ -216,13 +214,14 @@ def _build_mlp(self, mlp_spec):
additional_mlp_kwargs = {}
# import here to avoid circular import
from mcore_bridge.model.gpts.glm4 import Glm4MLP
+ from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP, Gemma4MoELayer
# MLP expects tp_group but MoELayer expects pg_collection to be passed in.
# We can change MLP to accept pg_collection but it makes the logic implicit
# The conditional below is to make the logic explicit
# if smlp_spec is not a ModuleSpec,we dont have to handle passing additional kwargs
if isinstance(mlp_spec, ModuleSpec):
- if mlp_spec.module in (MoELayer, TEGroupedMLP, SequentialMLP):
+ if mlp_spec.module in (MoELayer, Gemma4MoELayer, TEGroupedMLP, SequentialMLP):
additional_mlp_kwargs['pg_collection'] = pg_collection
# Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers.
if mlp_spec.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters:
@@ -230,12 +229,17 @@ def _build_mlp(self, mlp_spec):
elif mlp_spec.module in (MLP, Glm4MLP):
assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer'
additional_mlp_kwargs['tp_group'] = pg_collection.tp
+ elif mlp_spec.module == Gemma4MLP:
+ additional_mlp_kwargs['layer_number'] = self.layer_number
elif TEFusedMLP is not None and mlp_spec.module == TEFusedMLP:
assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer'
additional_mlp_kwargs['tp_group'] = pg_collection.tp
else:
logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. Using default kwargs.')
- return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs)
+ mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs)
+ if hasattr(mlp, 'set_layer_number'):
+ mlp.set_layer_number(self.layer_number)
+ return mlp
def forward(self, *args, **kwargs):
"""
diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py
index 84bfdb4..22ad26d 100644
--- a/src/mcore_bridge/model/register.py
+++ b/src/mcore_bridge/model/register.py
@@ -118,7 +118,7 @@ def _set_shared_expert_gate(self, transformer_layer_spec):
if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'):
layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True}
- def _set_custom_layer(self, transformer_layer_spec):
+ def _set_transformer_layer(self, transformer_layer_spec):
for layer_spec in transformer_layer_spec.layer_specs:
layer_spec.module = TransformerLayer
@@ -130,7 +130,7 @@ def build_model(
) -> Union['GPTModel', 'MultimodalGPTModel']:
transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage)
self._set_shared_expert_gate(transformer_layer_spec)
- self._set_custom_layer(transformer_layer_spec)
+ self._set_transformer_layer(transformer_layer_spec)
mtp_block_spec = None
if self.config.mtp_num_layers is not None:
mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage)
diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py
index e7db3c3..a514f43 100644
--- a/src/mcore_bridge/model/rope.py
+++ b/src/mcore_bridge/model/rope.py
@@ -106,12 +106,13 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]):
return rope_type
-def get_rope_inv_freq(config, seq_len=None, **kwargs):
+def get_rope_inv_freq(config, seq_len=None, text_config=None, **kwargs):
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS)
- dummy_config = _get_dummy_config(config)
+ if text_config is None:
+ text_config = _get_dummy_config(config)
rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)]
- inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs)
+ inv_freq, attention_scaling = rope_init_fn(text_config, 'cpu', seq_len=seq_len, **kwargs)
if attention_scaling is None:
attention_scaling = 1.
return inv_freq, attention_scaling
diff --git a/tests/test_mllm.py b/tests/test_mllm.py
index 1fe103e..5eb07d2 100644
--- a/tests/test_mllm.py
+++ b/tests/test_mllm.py
@@ -116,6 +116,10 @@ def test_qwen3_asr():
_test_model('Qwen/Qwen3-ASR-1.7B')
+def test_gemma4():
+ _test_model('google/gemma-4-E2B-it')
+
+
if __name__ == '__main__':
# test_qwen2_5_vl()
# test_qwen2_vl()
@@ -136,4 +140,5 @@ def test_qwen3_asr():
# test_llama4()
# test_qwen3_5()
# test_llava_onevision1_5()
- test_qwen3_asr()
+ # test_qwen3_asr()
+ test_gemma4()