diff --git a/README.md b/README.md index 1473a16..bf9423a 100644 --- a/README.md +++ b/README.md @@ -145,6 +145,7 @@ The following is the list of models supported by MCore-Bridge: | Series | model_type | | -------- | ------------------------------------------------------------ | | Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni
qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr
qwen3_5, qwen3_5_moe | +| Gemma | gemma4 | | GLM | glm4v, glm4v_moe | | Kimi | kimi_vl | | InternVL | internvl_chat, internvl | diff --git a/README_zh.md b/README_zh.md index 5f170e7..b7ee63d 100644 --- a/README_zh.md +++ b/README_zh.md @@ -142,6 +142,7 @@ uv pip install -e . --torch-backend=auto | 系列 | model_type | | -------- | ------------------------------------------------------------ | | Qwen | qwen2_vl, qwen2_5_vl, qwen2_5_omni
qwen3_vl, qwen3_vl_moe, qwen3_omni_moe, qwen3_asr
qwen3_5, qwen3_5_moe | +| Gemma | gemma4 | | GLM | glm4v, glm4v_moe | | Kimi | kimi_vl | | InternVL | internvl_chat, internvl | diff --git a/src/mcore_bridge/bridge/gpt_bridge.py b/src/mcore_bridge/bridge/gpt_bridge.py index 0eceb23..1d95f31 100644 --- a/src/mcore_bridge/bridge/gpt_bridge.py +++ b/src/mcore_bridge/bridge/gpt_bridge.py @@ -360,7 +360,6 @@ def _broadcast_ep_pp(self, tensor, is_expert): dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} if tensor is None: dist.broadcast(meta_data, src=src_rank, group=pp_group) - assert meta_data[0].item() > 0, f'meta_data: {meta_data}' shape = meta_data[1:1 + meta_data[0]].tolist() dtype = dtype_mapping_r[meta_data[-1].item()] tensor = torch.empty(shape, device='cuda', dtype=dtype) @@ -383,7 +382,8 @@ def _get_weight( # tp/etp mg_scale_inv = None tensor = mg_weight - if tensor is not None: + is_scalar = isinstance(tensor, torch.Tensor) and tensor.ndim == 0 + if tensor is not None and not is_scalar: if not isinstance(tensor, (list, tuple)): tensor = [tensor] if self._is_fp8_param(tensor[0]): @@ -392,8 +392,7 @@ def _get_weight( for t in tensor ] tensor = [t._rowwise_data for t in tensor] - del mg_weight - if tensor is not None: + del mg_weight assert isinstance(tensor, (list, tuple)), f'mg_key: {mg_key}' tensor = torch.concat(tensor, dim=0) if mg_scale_inv is not None: @@ -529,22 +528,37 @@ def _filter_prefix(state_dict, prefix: str): return state_dict return {k: v for k, v in state_dict.items() if k.startswith(prefix)} - def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + def _reduce_tensor_pp_group(self, tensor, to_mcore, dtype=torch.long, op=dist.ReduceOp.MAX): + if to_mcore: + return tensor + tensor = torch.tensor([tensor], dtype=dtype, device='cuda') + if self.pp_size > 1: + dist.all_reduce(tensor, group=self.pp_group, op=op) + tensor = tensor.item() + return tensor + + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): config = self.config - num_query_groups = ( - config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) + num_query_groups = kwargs.get('num_query_groups') + if num_query_groups is None: + num_query_groups = ( + config.num_query_groups if config.num_query_groups is not None else config.num_attention_heads) hidden_size_block = config.hidden_size // self.fp8_block_size + attention_k_eq_v = kwargs.get('attention_k_eq_v', False) + kv_proj_list = ['k_proj'] if attention_k_eq_v else ['k_proj', 'v_proj'] + kv_channels = kwargs.get('kv_channels') + if kv_channels is None: + kv_channels = self.config.kv_channels if to_mcore: if isinstance(mg_attn.linear_qkv, LoraParallelLinear): lora_A = hf_state_dict['q_proj.lora_A.weight'].load() - assert (lora_A == hf_state_dict['k_proj.lora_A.weight'].load()).all() and ( - lora_A == hf_state_dict['v_proj.lora_A.weight'].load() - ).all(), 'Need to ensure QKV\'s lora_A are consistent' + assert all((lora_A == hf_state_dict[f'{k}.lora_A.weight'].load()).all() + for k in kv_proj_list), 'Need to ensure QKV\'s lora_A are consistent' q_lora_B = hf_state_dict['q_proj.lora_B.weight'].load() lora_B = torch.cat([ q_lora_B.reshape((num_query_groups, -1, q_lora_B.shape[-1])), - hf_state_dict['k_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), - hf_state_dict['v_proj.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])), + *(hf_state_dict[f'{k}.lora_B.weight'].load().reshape((num_query_groups, -1, q_lora_B.shape[-1])) + for k in kv_proj_list), ], dim=1).reshape((-1, q_lora_B.shape[-1])) self._set_weight(mg_attn.linear_qkv.lora_A[self._adapter_name].weight, lora_A, @@ -553,29 +567,24 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): 'linear_qkv.lora_B.weight') elif not self._peft_format: linear_qkv_weight = torch.cat([ - hf_state_dict['q_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), - hf_state_dict['k_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), - hf_state_dict['v_proj.weight'].load().reshape((num_query_groups, -1, config.hidden_size)), + hf_state_dict[f'{k}.weight'].load().reshape((num_query_groups, -1, config.hidden_size)) + for k in ['q_proj'] + kv_proj_list ], dim=1).reshape((-1, config.hidden_size)) qkv_scale_inv = None if 'q_proj.weight_scale_inv' in hf_state_dict: qkv_scale_inv = torch.cat([ - hf_state_dict['q_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - hf_state_dict['k_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), - hf_state_dict['v_proj.weight_scale_inv'].load().reshape( - (num_query_groups, -1, hidden_size_block)), + hf_state_dict[f'{k}.weight_scale_inv'].load().reshape((num_query_groups, -1, hidden_size_block)) + for k in ['q_proj'] + kv_proj_list ], dim=1).reshape((-1, hidden_size_block)) self._set_weight( mg_attn.linear_qkv.weight, linear_qkv_weight, 'linear_qkv.weight', hf_scale_inv=qkv_scale_inv) else: - q_dim = self.config.kv_channels * self.config.num_attention_heads // self.config.num_query_groups + q_dim = kv_channels * self.config.num_attention_heads // num_query_groups if self.config.attention_output_gate: q_dim *= 2 - kv_dim = self.config.kv_channels + kv_dim = kv_channels q_block = q_dim // self.fp8_block_size kv_block = kv_dim // self.fp8_block_size is_lora = False if mg_attn is None else isinstance(mg_attn.linear_qkv, @@ -591,15 +600,17 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): None if mg_attn is None else mg_attn.linear_qkv.lora_B[self._adapter_name].weight.data, f'linear_qkv.lora_B.{self._adapter_name}.weight') if lora_A is not None: - self._peft_target_modules.update({'q_proj', 'k_proj', 'v_proj'}) - for key in ['q_proj', 'k_proj', 'v_proj']: + self._peft_target_modules.update({'q_proj'} | set(kv_proj_list)) + for key in ['q_proj'] + kv_proj_list: hf_state_dict[f'{key}.lora_A.weight'] = lora_A.clone() lora_B = lora_B.reshape((num_query_groups, -1, lora_B.shape[-1])) hf_state_dict['q_proj.lora_B.weight'] = lora_B[:, :q_dim, :].reshape(-1, lora_B.shape[-1]).clone() - hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, - q_dim:-kv_dim, :].reshape(-1, - lora_B.shape[-1]).clone() - hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, -kv_dim:, :].reshape(-1, lora_B.shape[-1]).clone() + hf_state_dict['k_proj.lora_B.weight'] = lora_B[:, q_dim:q_dim + kv_dim, :].reshape( + -1, lora_B.shape[-1]).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.lora_B.weight'] = lora_B[:, + -kv_dim:, :].reshape(-1, + lora_B.shape[-1]).clone() elif not self._peft_format: mg_attn_weight, scale_inv = self._get_weight( None if mg_attn is None else mg_attn.linear_qkv.weight.data, 'linear_qkv.weight') @@ -607,27 +618,27 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): mg_attn_weight = mg_attn_weight.reshape((num_query_groups, -1, config.hidden_size)) hf_state_dict['q_proj.weight'] = mg_attn_weight[:, :q_dim, :].reshape(-1, config.hidden_size).clone() - hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:-kv_dim, :].reshape( + hf_state_dict['k_proj.weight'] = mg_attn_weight[:, q_dim:q_dim + kv_dim, :].reshape( -1, config.hidden_size).clone() - hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape(-1, - config.hidden_size).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.weight'] = mg_attn_weight[:, -kv_dim:, :].reshape( + -1, config.hidden_size).clone() if scale_inv is not None: scale_inv = scale_inv.reshape((num_query_groups, -1, hidden_size_block)) hf_state_dict['q_proj.weight_scale_inv'] = scale_inv[:, :q_block, :].reshape( -1, hidden_size_block).clone() - hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:-kv_block, :].reshape( - -1, hidden_size_block).clone() - hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape( + hf_state_dict['k_proj.weight_scale_inv'] = scale_inv[:, q_block:q_block + kv_block:, :].reshape( -1, hidden_size_block).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.weight_scale_inv'] = scale_inv[:, -kv_block:, :].reshape( + -1, hidden_size_block).clone() del mg_attn_weight # Copy bias if (config.add_bias_linear or config.add_qkv_bias) and not self._peft_format: if to_mcore: linear_qkv_bias = torch.cat([ - hf_state_dict['q_proj.bias'].load().reshape((num_query_groups, -1)), - hf_state_dict['k_proj.bias'].load().reshape((num_query_groups, -1)), - hf_state_dict['v_proj.bias'].load().reshape((num_query_groups, -1)), + hf_state_dict[f'{k}.bias'].load().reshape((num_query_groups, -1)) for k in ['q_proj'] + kv_proj_list ], dim=1).reshape(-1) self._set_weight(mg_attn.linear_qkv.bias, linear_qkv_bias, 'linear_qkv.bias') @@ -637,8 +648,9 @@ def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): if mg_attn_bias is not None: mg_attn_bias = mg_attn_bias.reshape((num_query_groups, -1)) hf_state_dict['q_proj.bias'] = mg_attn_bias[:, :q_dim].reshape(-1).clone() - hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:-kv_dim].reshape(-1).clone() - hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() + hf_state_dict['k_proj.bias'] = mg_attn_bias[:, q_dim:q_dim + kv_dim].reshape(-1).clone() + if not attention_k_eq_v: + hf_state_dict['v_proj.bias'] = mg_attn_bias[:, -kv_dim:].reshape(-1).clone() return hf_state_dict def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): @@ -647,24 +659,34 @@ def _set_attn_state(self, mg_attn, hf_state_dict, hf_prefix: str, layer_idx: int else: hf_state_dict = {} config = self.config - hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore)) + hf_state_dict.update(self._set_qkv(mg_attn, hf_state_dict, to_mcore, layer_idx=layer_idx)) self._set_state_dict(mg_attn, 'linear_proj.weight', hf_state_dict, f'{self.hf_o_proj_key}.weight', to_mcore) if config.add_bias_linear: self._set_state_dict(mg_attn, 'linear_proj.bias', hf_state_dict, f'{self.hf_o_proj_key}.bias', to_mcore) if getattr(config, 'softmax_type', 'vanilla') == 'learnable': self._set_state_dict(mg_attn, 'core_attention.softmax_offset', hf_state_dict, 'sinks', to_mcore) if config.qk_layernorm: - self._set_qk_layernorm(mg_attn, hf_state_dict, to_mcore) + self._set_qk_layernorm(mg_attn, hf_state_dict, to_mcore, layer_idx=layer_idx) if to_mcore: hf_state_dict = {} else: hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict - def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) + def _set_router(self, mg_mlp, hf_state_dict, to_mcore): + hf_gate_key = self.hf_gate_key + if self.llm_model_type == 'gpt_oss': + hf_gate_key = 'router.weight' + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) + if self.config.add_bias_linear: + self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) + if self.config.moe_router_enable_expert_bias: + self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore) + def _set_moe_state( self, mg_mlp, @@ -678,17 +700,8 @@ def _set_moe_state( hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - config = self.config - hf_gate_key = self.hf_gate_key - if self.llm_model_type == 'gpt_oss': - hf_gate_key = 'router.weight' - self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, hf_gate_key, to_mcore) - if config.add_bias_linear: - self._set_state_dict(mg_mlp, 'router.bias', hf_state_dict, hf_gate_key.replace('weight', 'bias'), to_mcore) - if config.moe_router_enable_expert_bias: - self._set_state_dict(mg_mlp, 'router.expert_bias', hf_state_dict, self.hf_expert_bias_key, to_mcore) - - if config.moe_shared_expert_intermediate_size: + self._set_router(mg_mlp, hf_state_dict, to_mcore) + if self.config.moe_shared_expert_intermediate_size: hf_shared_expert_key = self.hf_shared_expert_key if hf_shared_expert_key is None: if 'qwen' in self.llm_model_type or self.model_type == 'llama4': @@ -698,7 +711,7 @@ def _set_moe_state( hf_state_dict.update( self._set_mlp_state(None if mg_mlp is None else mg_mlp.shared_experts, hf_state_dict, f'{hf_shared_expert_key}.', layer_idx, to_mcore)) - if config.moe_shared_expert_gate: + if self.config.moe_shared_expert_gate: self._set_state_dict(mg_mlp, 'shared_experts.gate_weight', hf_state_dict, 'shared_expert_gate.weight', to_mcore) for ep_rank in range(self.ep_size): @@ -735,7 +748,7 @@ def _get_hf_experts_attr(self, is_mtp: bool = False): 'glm4_moe_lite', 'minimax_m2', 'olmoe', 'qwen3_next', 'glm_moe_dsa', 'deepseek_v32' }: return False, False - elif self.model_type in {'qwen3_vl_moe', 'llama4'} or self.llm_model_type in {'gpt_oss'}: + elif self.model_type in {'qwen3_vl_moe', 'llama4', 'gemma4'} or self.llm_model_type in {'gpt_oss'}: return True, True else: # default @@ -1618,13 +1631,16 @@ def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: i hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) return hf_state_dict + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + def _convert_pre_process(self, mg_model, hf_state_dict, hf_prefix: str, to_mcore): if to_mcore: hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) else: hf_state_dict = {} - lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model - self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + self._set_word_embeddings(mg_model, hf_state_dict, to_mcore) if self.is_multimodal: for prefix, mg_prefix in self.module_mapping.items(): mg_module = deep_getattr(mg_model, f'visual.{mg_prefix}') diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 869f876..842c1bb 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -1,4 +1,6 @@ # Copyright (c) ModelScope Contributors. All rights reserved. +import torch.nn.functional as F +from functools import partial from transformers import PretrainedConfig from typing import Any, Dict @@ -25,7 +27,7 @@ # moe 'moe_ffn_hidden_size': ['moe_intermediate_size'], 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], - 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'], + 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k', 'top_k_experts'], 'moe_router_num_groups': ['n_group'], 'moe_router_group_topk': ['topk_group'], 'num_moe_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], @@ -153,6 +155,17 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: n_shared_experts = res.pop('n_shared_experts') elif llm_model_type in {'ernie4_5', 'ernie4_5_moe', 'glm4'}: res['rotary_interleaved'] = True + elif hf_model_type in {'gemma4'}: + res['qk_layernorm'] = True + # If set to "vision", pass attention_mask manually. + if hf_config.text_config.use_bidirectional_attention is None: + res['window_size'] = f'{window_size - 1},0' + window_attn_skip_freq = ','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types]) + res['window_attn_skip_freq'] = f'[{window_attn_skip_freq}]' + res['softmax_scale'] = 1. + res['swiglu'] = False + res['gated_linear_unit'] = True + res['activation_func'] = partial(F.gelu, approximate='tanh') elif llm_model_type == 'gpt_oss': res['add_bias_linear'] = True res['bias_dropout_fusion'] = False @@ -161,7 +174,7 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: res['quick_geglu'] = True res['activation_func_clamp_value'] = 7 res['glu_linear_offset'] = 1 - res['window_size'] = f'{window_size},0' + res['window_size'] = f'{window_size - 1},0' if layer_types is None: res['window_attn_skip_freq'] = '2' else: diff --git a/src/mcore_bridge/model/constant.py b/src/mcore_bridge/model/constant.py index 58b0ee3..9b8dc1b 100644 --- a/src/mcore_bridge/model/constant.py +++ b/src/mcore_bridge/model/constant.py @@ -29,6 +29,7 @@ class MLLMModelType: glm4v_moe = 'glm4v_moe' kimi_vl = 'kimi_vl' llama4 = 'llama4' + gemma4 = 'gemma4' kimi_k25 = 'kimi_k25' diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 02f8bb0..84bf56d 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -55,6 +55,7 @@ def sharded_state_dict( class GPTModel(McoreGPTModel): config: ModelConfig + extra_forward_keys = [] def __init__( self, @@ -105,9 +106,7 @@ def __init__( for i in range(len(self.decoder.layers)): if hasattr(self.decoder.layers[i].self_attention, 'rotary_pos_emb'): del self.decoder.layers[i].self_attention.rotary_pos_emb - self.attention_scaling = 1. - new_inv_freq, self.attention_scaling = get_rope_inv_freq(config) - self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + self._set_inv_freq() if self.config.task_type == 'seq_cls' and self.post_process: self.output_layer = OutputLayerLinear( config.hidden_size, @@ -216,7 +215,35 @@ def _preprocess( if decoder_input is not None and self.training and torch.is_grad_enabled() and not decoder_input.requires_grad: # fix LoRA incompatibility with gradient checkpointing decoder_input = decoder_input.requires_grad_(True) + rotary_pos_emb, rotary_pos_cos, rotary_pos_sin = self._get_rotary_pos_emb( + decoder_input, position_ids, packed_seq_params=packed_seq_params) + if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') + or self.config.flash_decode) and rotary_pos_cos is not None + and inference_context.is_static_batching()): + current_batch_size = input_ids.shape[0] + sequence_len_offset = torch.tensor( + [inference_context.sequence_len_offset] * current_batch_size, + dtype=torch.int32, + device=rotary_pos_cos.device, # Co-locate this with the rotary tensors + ) + else: + sequence_len_offset = None + + # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the + # reference held by this caller function, enabling early garbage collection for + # inference. Skip wrapping if decoder_input is logged after decoder completion. + if in_inference_mode and not has_config_logger_enabled(self.config): + decoder_input = WrappedTensor(decoder_input) + + return (decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset) + + def _set_inv_freq(self): + self.attention_scaling = 1. + new_inv_freq, self.attention_scaling = get_rope_inv_freq(self.config) + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None @@ -251,26 +278,7 @@ def _preprocess( rotary_seq_len, packed_seq=packed_seq, ) - - if (in_inference_mode and ((self.config.enable_cuda_graph and self.config.cuda_graph_scope != 'full_iteration') - or self.config.flash_decode) and rotary_pos_cos is not None - and inference_context.is_static_batching()): - current_batch_size = input_ids.shape[0] - sequence_len_offset = torch.tensor( - [inference_context.sequence_len_offset] * current_batch_size, - dtype=torch.int32, - device=rotary_pos_cos.device, # Co-locate this with the rotary tensors - ) - else: - sequence_len_offset = None - - # Wrap decoder_input to allow the decoder (TransformerBlock) to delete the - # reference held by this caller function, enabling early garbage collection for - # inference. Skip wrapping if decoder_input is logged after decoder completion. - if in_inference_mode and not has_config_logger_enabled(self.config): - decoder_input = WrappedTensor(decoder_input) - - return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + return rotary_pos_emb, rotary_pos_cos, rotary_pos_sin # Code borrowed from NVIDIA/Megatron-LM def forward( @@ -314,7 +322,11 @@ def forward( packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' if self.position_embedding_type == 'rope' and packed_seq and not self.config.apply_rope_fusion: assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}' - decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] + if isinstance(rotary_pos_emb, dict): + for k, v in rotary_pos_emb.items(): + decoder_rotary_pos_emb[k] = v[position_ids[0]] + else: + decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]] mtp_decoder_input = decoder_input if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None: @@ -322,9 +334,12 @@ def forward( input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0) self.set_input_tensor(input_tensor) kwargs = {} + full_attention_mask = attention_mask + if isinstance(full_attention_mask, dict): + full_attention_mask = full_attention_mask['full_attention'] if mcore_016 and attention_mask is not None: assert packed_seq_params is None - padding_mask = ~((~attention_mask).sum(dim=(1, 2)) > 0) + padding_mask = ~((~full_attention_mask).sum(dim=(1, 2)) > 0) if self.config.context_parallel_size > 1: padding_mask = split_cp_inputs(padding_mask, None, 1) tp_size = self.config.tensor_model_parallel_size @@ -366,6 +381,9 @@ def forward( inference_context=inference_context, ) + def _forward_output_layer(self, hidden_states, *args, **kwargs): + return self.output_layer(hidden_states, *args, **kwargs)[0] + def _postprocess( self, hidden_states, @@ -432,7 +450,7 @@ def _postprocess( loss_mask = torch.ones_like(mtp_labels) for mtp_layer_number in range(self.config.mtp_unroll_steps): # output - mtp_logits, _ = self.output_layer( + mtp_logits = self._forward_output_layer( hidden_states_list[mtp_layer_number + 1], weight=output_weight, runtime_gather_output=runtime_gather_output, @@ -493,7 +511,7 @@ def _postprocess( if self.config.task_type == 'embedding': logits = F.normalize(hidden_states, p=2, dim=-1) else: - logits, _ = self.output_layer( + logits = self._forward_output_layer( hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output) if self.config.task_type == 'generative_reranker': logits = gather_from_tensor_model_parallel_region(logits) diff --git a/src/mcore_bridge/model/gpts/bailing_moe.py b/src/mcore_bridge/model/gpts/bailing_moe.py index f3c383a..7e2a70d 100644 --- a/src/mcore_bridge/model/gpts/bailing_moe.py +++ b/src/mcore_bridge/model/gpts/bailing_moe.py @@ -71,7 +71,7 @@ class BailingMoeBridge(GPTBridge): hf_expert_bias_key = 'gate.expert_bias' hf_o_proj_key = 'dense' - def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool): + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'query_key_value.weight', to_mcore) assert not self.config.add_bias_linear return hf_state_dict diff --git a/src/mcore_bridge/model/gpts/minimax_m2.py b/src/mcore_bridge/model/gpts/minimax_m2.py index 42aa4c6..853e3a5 100644 --- a/src/mcore_bridge/model/gpts/minimax_m2.py +++ b/src/mcore_bridge/model/gpts/minimax_m2.py @@ -74,7 +74,7 @@ class MinimaxM2Bridge(GPTBridge): hf_mlp_prefix = 'block_sparse_moe' hf_expert_bias_key = 'e_score_correction_bias' - def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore): + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): self._set_state_dict(mg_attn, 'q_norm.weight', hf_state_dict, 'q_norm.weight', to_mcore) self._set_state_dict(mg_attn, 'k_norm.weight', hf_state_dict, 'k_norm.weight', to_mcore) diff --git a/src/mcore_bridge/model/mm_gpt_model.py b/src/mcore_bridge/model/mm_gpt_model.py index 2a5f659..2a80c04 100644 --- a/src/mcore_bridge/model/mm_gpt_model.py +++ b/src/mcore_bridge/model/mm_gpt_model.py @@ -79,6 +79,7 @@ def forward( packed_seq_params: PackedSeqParams = None, **kwargs, ) -> torch.Tensor: + extra_kwargs = {k: kwargs[k] for k in self.language_model.extra_forward_keys} if decoder_input is not None: pass elif self.pre_process: @@ -90,6 +91,7 @@ def forward( # decoder will get hidden_states from encoder.input_tensor decoder_input = None kwargs = {} + kwargs.update(extra_kwargs) return self.language_model( input_ids=input_ids, position_ids=position_ids, diff --git a/src/mcore_bridge/model/mm_gpts/__init__.py b/src/mcore_bridge/model/mm_gpts/__init__.py index 9009edb..b862ec6 100644 --- a/src/mcore_bridge/model/mm_gpts/__init__.py +++ b/src/mcore_bridge/model/mm_gpts/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) ModelScope Contributors. All rights reserved. -from . import glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl +from . import gemma4, glm, internvl, kimi_vl, llama4, llava, qwen, qwen3_5, qwen3_5_gdn, qwen3_asr, qwen3_omni, qwen3_vl diff --git a/src/mcore_bridge/model/mm_gpts/gemma4.py b/src/mcore_bridge/model/mm_gpts/gemma4.py new file mode 100644 index 0000000..2c68a73 --- /dev/null +++ b/src/mcore_bridge/model/mm_gpts/gemma4.py @@ -0,0 +1,814 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import copy +import math +import torch +from megatron.core.extensions.transformer_engine import (SplitAlongDim, TEColumnParallelLinear, TENorm, + TERowParallelLinear) +from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb +from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding +from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import _yarn_get_concentration_factor_from_config +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import get_tensor_model_parallel_rank +from megatron.core.tensor_parallel import VocabParallelEmbedding, all_gather_last_dim_from_tensor_parallel_region +from megatron.core.tensor_parallel.mappings import (gather_from_tensor_model_parallel_region, + scatter_to_tensor_model_parallel_region) +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.spec_utils import build_module +from megatron.core.utils import make_viewless_tensor, nvtx_range_pop, nvtx_range_push +from torch import Tensor, nn +from transformers import AutoModel, PretrainedConfig +from transformers.utils.versions import require_version +from typing import Optional, Tuple + +from mcore_bridge.bridge import MultimodalGPTBridge +from mcore_bridge.config import ModelConfig + +from ..constant import ModelType +from ..gpt_model import GPTModel +from ..mm_gpt_model import MultimodalGPTModel +from ..modules import TransformerBlock, TransformerLayer +from ..register import ModelLoader, ModelMeta, register_model +from ..rope import get_rope_inv_freq +from .utils import HuggingFaceVit + + +class Gemma4RMSNormNoScale(torch.nn.Module): + """RMSNorm without learnable scale, mirroring HF `Gemma4RMSNorm(with_scale=False)`.""" + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.to(torch.float32) + variance = x.pow(2).mean(-1, keepdim=True) + return (x * torch.rsqrt(variance + self.eps)).to(orig_dtype) + + +class Gemma4Vit(HuggingFaceVit): + module_mapping = { + 'model.vision_tower': 'vision_tower', + 'model.embed_vision': 'embed_vision', + 'model.audio_tower': 'audio_tower', + 'model.embed_audio': 'embed_audio', + } + _vision_tower = ['vision_tower', 'audio_tower'] + _aligner = ['embed_vision', 'embed_audio'] + + def prepare_model(self, hf_config: PretrainedConfig): + from transformers.models.gemma4.modeling_gemma4 import Gemma4Model, Gemma4MultimodalEmbedder + self.vision_tower = AutoModel.from_config(hf_config.vision_config) + dtype = self.vision_tower.dtype + self.audio_tower = AutoModel.from_config(hf_config.audio_config) if hf_config.audio_config is not None else None + self.embed_vision = Gemma4MultimodalEmbedder(hf_config.vision_config, hf_config.text_config).to(dtype) + self.embed_audio = ( + Gemma4MultimodalEmbedder(hf_config.audio_config, hf_config.text_config).to(dtype) + if hf_config.audio_config is not None else None) + self.register_buffer('embed_scale', torch.tensor(hf_config.hidden_size**0.5).to(dtype), persistent=False) + self.model_cls = Gemma4Model + + def get_inputs_embeds(self, inputs_embeds, **kwargs): + input_ids = kwargs.get('input_ids') + inputs_embeds = inputs_embeds * self.embed_scale.to(inputs_embeds.dtype) + + hf_config = self.hf_config + pixel_values = kwargs.get('pixel_values') + pixel_values_videos = kwargs.get('pixel_values_videos') + input_features = kwargs.get('input_features') + input_features_mask = kwargs.get('input_features_mask') + image_position_ids = kwargs.get('image_position_ids') + video_position_ids = kwargs.get('video_position_ids') + + image_mask = input_ids == hf_config.image_token_id + video_mask = input_ids == hf_config.video_token_id + audio_mask = input_ids == hf_config.audio_token_id + multimodal_mask = image_mask | video_mask | audio_mask + llm_input_ids = input_ids.clone() + llm_input_ids[multimodal_mask] = hf_config.text_config.pad_token_id + + if pixel_values is not None: + with self.patch_hf_config(): + image_features = self.model_cls.get_image_features( + self, pixel_values, image_position_ids, return_dict=True).pooler_output + image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + image_mask_e = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(image_mask_e, image_features) + + if pixel_values_videos is not None: + with self.patch_hf_config(): + video_features = self.model_cls.get_video_features( + self, pixel_values_videos, video_position_ids, return_dict=True).pooler_output + video_features = video_features.to(inputs_embeds.device, inputs_embeds.dtype) + video_mask_e = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(video_mask_e, video_features) + + if (input_features is not None and input_features_mask is not None and self.audio_tower is not None): + with self.patch_hf_config(): + audio_output = self.model_cls.get_audio_features( + self, input_features, input_features_mask, return_dict=True) + audio_features = audio_output.pooler_output + audio_features = audio_features[audio_output.attention_mask] + audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype) + audio_mask_e = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + inputs_embeds = inputs_embeds.masked_scatter(audio_mask_e, audio_features) + return {'inputs_embeds': inputs_embeds, 'llm_input_ids': llm_input_ids} + + +class Gemma4SelfAttention(SelfAttention): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + text_config = config.hf_config.text_config + layer_idx = layer_number - 1 + + # Layer type / sliding attention + self.layer_type = text_config.layer_types[layer_idx] + self.is_sliding = self.layer_type == 'sliding_attention' + + # Head dim: global layers may use a different head dim than sliding ones + self.head_dim = ( + text_config.global_head_dim + if not self.is_sliding and text_config.global_head_dim else text_config.head_dim) + + self.use_alternative_attention = (text_config.attention_k_eq_v and not self.is_sliding) + self.num_key_value_heads = ( + text_config.num_global_key_value_heads + if self.use_alternative_attention else text_config.num_key_value_heads) + # Shared KV across the trailing layers + self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = config.num_layers - self.num_kv_shared_layers + self.is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + prev_layers = text_config.layer_types[:first_kv_shared_layer_idx] + self.store_full_length_kv = not self.is_kv_shared_layer and layer_idx == len( + prev_layers) - 1 - prev_layers[::-1].index(text_config.layer_types[layer_idx]) + + orig_kv_channels = config.kv_channels + orig_num_query_groups = config.num_query_groups + orig_k_layernorm = submodules.k_layernorm + config.kv_channels = self.head_dim + config.num_query_groups = self.num_key_value_heads + if self.is_sliding and config.window_size is None: + kwargs['attn_mask_type'] = AttnMaskType.arbitrary + if self.is_kv_shared_layer: + submodules.k_layernorm = IdentityOp + try: + super().__init__(config, submodules, layer_number, *args, **kwargs) + finally: + config.kv_channels = orig_kv_channels + config.num_query_groups = orig_num_query_groups + submodules.k_layernorm = orig_k_layernorm + + if self.is_kv_shared_layer or self.use_alternative_attention: + linear_qkv_dim = self.query_projection_size + if not self.is_kv_shared_layer: + linear_qkv_dim += self.kv_projection_size + self.linear_qkv = submodules.linear_qkv( + self.config.hidden_size, + linear_qkv_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=self.config.add_bias_linear or self.config.add_qkv_bias, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='qkv', + tp_group=self.pg_collection.tp, + ) + + if not self.is_kv_shared_layer: + self.v_norm = Gemma4RMSNormNoScale(self.head_dim, eps=self.config.layernorm_epsilon) + + def _forward_core_attention( + self, + query, + key, + value, + attention_mask, + attention_bias: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + ): + nvtx_range_push(suffix='core_attention') + attn_mask_type = self.attn_mask_type + if self.checkpoint_core_attention and self.training: + core_attn_out = self._checkpointed_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + # reshape to same output shape as unpacked case + # (t, np, hn) -> (t, b=1, h=np*hn) + # t is the pack size = sum (sq_i) + # note that batch is a dummy dimension in the packed case + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + nvtx_range_pop(suffix='core_attention') + return core_attn_out + + def _apply_rotary(self, query, key, rotary_pos_emb, packed_seq_params): + nvtx_range_push(suffix='rotary_pos_emb') + q_pos_emb, k_pos_emb = rotary_pos_emb + + if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': + if packed_seq_params.cu_seqlens_q_padded is not None: + cu_seqlens_q = packed_seq_params.cu_seqlens_q_padded + else: + cu_seqlens_q = packed_seq_params.cu_seqlens_q + if packed_seq_params.cu_seqlens_kv_padded is not None: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv_padded + else: + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + else: + cu_seqlens_q = cu_seqlens_kv = None + + if q_pos_emb is not None: + # TODO VIJAY: simplify + query = apply_rotary_pos_emb( + query, + q_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_q, + mscale=_yarn_get_concentration_factor_from_config(self.config), + cp_group=self.pg_collection.cp, + ) + if not self.is_kv_shared_layer and k_pos_emb is not None: + key = apply_rotary_pos_emb( + key, + k_pos_emb, + config=self.config, + cu_seqlens=cu_seqlens_kv, + mscale=_yarn_get_concentration_factor_from_config(self.config), + cp_group=self.pg_collection.cp, + ) + nvtx_range_pop(suffix='rotary_pos_emb') + return query, key + + def forward(self, hidden_states: Tensor, attention_mask: Tensor, **kwargs) -> Tuple[Tensor, Tensor]: + shared_kv_states = kwargs['shared_kv_states'] + rotary_pos_emb = kwargs.get('rotary_pos_emb') + packed_seq_params = kwargs.get('packed_seq_params') + attention_bias = kwargs.get('attention_bias') + mixed_qkv, _ = self.linear_qkv(hidden_states) + if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size: + mixed_qkv = all_gather_last_dim_from_tensor_parallel_region(mixed_qkv) + idx = get_tensor_model_parallel_rank() // (self.world_size // self.config.num_query_groups) + size = mixed_qkv.size()[-1] // self.config.num_query_groups + mixed_qkv = mixed_qkv[:, :, idx * size:(idx + 1) * size] + + if self.is_kv_shared_layer: + query = mixed_qkv + key, value = shared_kv_states[self.layer_type] + else: + kv_heads_per_group = 1 if self.use_alternative_attention else 2 + num_query_heads_per_group = (self.num_attention_heads_per_partition // self.num_query_groups_per_partition) + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + (num_query_heads_per_group + kv_heads_per_group) * self.hidden_size_per_attention_head, + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + split_arg_list = [num_query_heads_per_group * self.hidden_size_per_attention_head + ] + [self.hidden_size_per_attention_head] * kv_heads_per_group + if SplitAlongDim is not None: + qkv = SplitAlongDim(mixed_qkv, len(split_arg_list), split_arg_list) + else: + qkv = torch.split(mixed_qkv, split_arg_list, dim=3) + if self.use_alternative_attention: + query, key = qkv + value = key + else: + query, key, value = qkv + key = self.k_layernorm(key) + value = self.v_norm(value) + # Query [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + if getattr(self, 'world_size', None) is not None and self.config.num_query_groups < self.world_size: + idx = get_tensor_model_parallel_rank() % (self.world_size // self.config.num_query_groups) + size = query.shape[2] // (self.world_size // self.config.num_query_groups) + query = query[:, :, idx * size:(idx + 1) * size, :] + query = self.q_layernorm(query) + if isinstance(rotary_pos_emb, torch.Tensor): + rotary_pos_emb = (rotary_pos_emb, ) * 2 + + query, key = self._apply_rotary(query, key, rotary_pos_emb, packed_seq_params) + if self.store_full_length_kv: + shared_kv_states[self.layer_type] = key, value + core_attn_out = self._forward_core_attention(query, key, value, attention_mask, attention_bias, + packed_seq_params) + + nvtx_range_push(suffix='linear_proj') + output, bias = self.linear_proj(core_attn_out) + nvtx_range_pop(suffix='linear_proj') + return output, bias + + +class Gemma4MLP(MLP): + + def __init__( + self, + config: ModelConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + *args, + **kwargs, + ): + self.layer_number = layer_number + text_config = config.hf_config.text_config + first_kv_shared_layer_idx = config.num_layers - text_config.num_kv_shared_layers + is_kv_shared_layer = layer_number > first_kv_shared_layer_idx > 0 + use_double_wide_mlp = text_config.use_double_wide_mlp and is_kv_shared_layer + ffn_hidden_size = config.ffn_hidden_size + config.ffn_hidden_size = config.ffn_hidden_size * (2 if use_double_wide_mlp else 1) + try: + super().__init__(config, submodules, *args, **kwargs) + finally: + config.ffn_hidden_size = ffn_hidden_size + + +class Gemma4MoELayer(MoELayer): + + def __init__(self, config, *args, **kwargs): + require_version('megatron-core>=0.16.0.dev', 'Gemma4MoELayer requires megatron-core>=0.16.0') + super().__init__(config, *args, **kwargs) + self.pre_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=config.hidden_size, config=config, eps=config.layernorm_epsilon) + self.norm = Gemma4RMSNormNoScale(config.hidden_size, eps=self.config.layernorm_epsilon) + self.scalar_root_size = config.hidden_size**-0.5 + self.scale = nn.Parameter(torch.ones(config.hidden_size)) + self.per_expert_scale = nn.Parameter(torch.ones(config.num_moe_experts)) + self.scale.sequence_parallel = config.sequence_parallel + self.per_expert_scale.sequence_parallel = config.sequence_parallel + + def route(self, hidden_states: torch.Tensor, *args, **kwargs): + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states * self.scale * self.scalar_root_size + probs, routing_map = super().route(hidden_states) + probs = probs * self.per_expert_scale + return probs, routing_map + + def preprocess(self, hidden_states: torch.Tensor, *args, **kwargs): + hidden_states = self.pre_feedforward_layernorm_2(hidden_states) + return super().preprocess(hidden_states, *args, **kwargs) + + +class Gemma4Bridge(MultimodalGPTBridge): + hf_post_attention_layernorm = 'pre_feedforward_layernorm' + additional_dim0_keys = {'embed_tokens_per_layer', 'per_layer_input_gate', 'per_layer_model_projection'} + additional_dim1_keys = {'per_layer_projection'} + + def __init__(self, config: ModelConfig): + super().__init__(config) + self.text_config = config.hf_config.text_config + self.hidden_size_per_layer_input = self.text_config.hidden_size_per_layer_input + + def _set_qk_layernorm(self, mg_attn, hf_state_dict, to_mcore, **kwargs): + layer_idx = kwargs['layer_idx'] + is_kv_shared_layer = self._get_is_kv_shared_layer(layer_idx) + self._set_state_dict(mg_attn, 'q_layernorm.weight', hf_state_dict, self.hf_q_norm_key, to_mcore) + if not is_kv_shared_layer: + self._set_state_dict(mg_attn, 'k_layernorm.weight', hf_state_dict, self.hf_k_norm_key, to_mcore) + + def _get_is_kv_shared_layer(self, layer_idx): + text_config = self.text_config + num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + first_kv_shared_layer_idx = self.config.num_layers - num_kv_shared_layers + is_kv_shared_layer = layer_idx >= first_kv_shared_layer_idx > 0 + return is_kv_shared_layer + + def _set_qkv(self, mg_attn, hf_state_dict, to_mcore: bool, **kwargs): + text_config = self.text_config + layer_idx = kwargs['layer_idx'] + is_sliding = text_config.layer_types[layer_idx] == 'sliding_attention' + head_dim = ( + text_config.global_head_dim if not is_sliding and text_config.global_head_dim else text_config.head_dim) + is_kv_shared_layer = self._get_is_kv_shared_layer(layer_idx) + if is_kv_shared_layer: + self._set_state_dict(mg_attn, 'linear_qkv.weight', hf_state_dict, 'q_proj.weight', to_mcore) + return hf_state_dict + else: + kwargs['kv_channels'] = head_dim + use_alternative_attention = text_config.attention_k_eq_v and not is_sliding + kwargs['attention_k_eq_v'] = use_alternative_attention + kwargs['num_query_groups'] = ( + text_config.num_global_key_value_heads + if use_alternative_attention else text_config.num_key_value_heads) + return super()._set_qkv(mg_attn, hf_state_dict, to_mcore, **kwargs) + + def _set_router(self, mg_mlp, hf_state_dict, to_mcore): + self._set_state_dict(mg_mlp, 'router.weight', hf_state_dict, 'router.proj.weight', to_mcore) + for key in ['per_expert_scale', 'scale']: + self._set_state_dict(mg_mlp, key, hf_state_dict, f'router.{key}', to_mcore) + + def _set_layer_mlp(self, mg_layer, hf_state_dict, layer_idx: int, to_mcore: bool, is_mtp: bool = False): + mg_mlp = None if mg_layer is None else mg_layer.mlp + hf_state_dict.update(self._set_mlp_state(mg_mlp, hf_state_dict, f'{self.hf_mlp_prefix}.', layer_idx, to_mcore)) + self._set_state_dict(mg_layer, 'mlp.linear_fc1.layer_norm_weight', hf_state_dict, + f'{self.hf_post_attention_layernorm}.weight', to_mcore) + if self.text_config.enable_moe_block: + mg_experts = None if mg_layer is None else mg_layer.experts_mlp + hf_state_dict.update(self._set_moe_state(mg_experts, hf_state_dict, '', layer_idx, to_mcore, is_mtp=is_mtp)) + return hf_state_dict + + def _set_layer_state(self, mg_layer, hf_state_dict, hf_prefix: str, layer_idx: int, to_mcore: bool): + hf_prefix = f'{hf_prefix}{layer_idx}.' + if to_mcore: + hf_state_dict = self._remove_prefix(hf_state_dict, hf_prefix) + else: + hf_state_dict = {} + hf_state_dict.update(self._set_layer_attn(mg_layer, hf_state_dict, layer_idx, to_mcore)) + hf_state_dict.update(self._set_layer_mlp(mg_layer, hf_state_dict, layer_idx, to_mcore)) + for key in ['post_attention_layernorm', 'post_feedforward_layernorm']: + self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) + if self.hidden_size_per_layer_input: + for key in ['per_layer_input_gate', 'per_layer_projection', 'post_per_layer_input_norm']: + self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) + self._set_state_dict(mg_layer, 'layer_scalar', hf_state_dict, 'layer_scalar', to_mcore) + if self.text_config.enable_moe_block: + for key in ['post_feedforward_layernorm_1', 'post_feedforward_layernorm_2']: + self._set_state_dict(mg_layer, f'{key}.weight', hf_state_dict, f'{key}.weight', to_mcore) + self._set_state_dict(mg_layer, 'experts_mlp.pre_feedforward_layernorm_2.weight', hf_state_dict, + 'pre_feedforward_layernorm_2.weight', to_mcore) + if to_mcore: + hf_state_dict = {} + else: + hf_state_dict = self._add_prefix(hf_state_dict, hf_prefix) + return hf_state_dict + + def _set_word_embeddings(self, mg_model, hf_state_dict, to_mcore): + lm_model = getattr(mg_model, 'language_model') if self.is_multimodal else mg_model + self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key, to_mcore) + if self.hidden_size_per_layer_input: + for key in ['embed_tokens_per_layer', 'per_layer_model_projection', 'per_layer_projection_norm']: + self._set_state_dict(lm_model, f'{key}.weight', hf_state_dict, f'model.language_model.{key}.weight', + to_mcore) + + +class Gemma4TextGPTModel(GPTModel): + extra_forward_keys = ['mm_token_type_ids'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.num_query_groups_per_partition = self.decoder.layers[0].self_attention.num_query_groups_per_partition + text_config = self.config.hf_config.text_config + self.text_config = text_config + self.num_kv_shared_layers = getattr(text_config, 'num_kv_shared_layers', 0) + self.unique_layer_types = set(text_config.layer_types) + self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input + self.final_logit_softcapping = text_config.final_logit_softcapping + if self.hidden_size_per_layer_input and self.pre_process: + total_dim = self.config.num_layers * self.hidden_size_per_layer_input + self.embed_tokens_per_layer = VocabParallelEmbedding( + num_embeddings=self.vocab_size, + embedding_dim=total_dim, + init_method=self.config.init_method, + config=self.config, + tp_group=self.pg_collection.tp, + ) + self.embed_tokens_per_layer_scale = self.hidden_size_per_layer_input**0.5 + self.per_layer_input_scale = 2.0**-0.5 + self.per_layer_model_projection = build_module( + TEColumnParallelLinear, + self.config.hidden_size, + total_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='per_layer_model_projection', + tp_group=self.pg_collection.tp, + ) + self.per_layer_model_projection_scale = self.config.hidden_size**-0.5 + self.per_layer_projection_norm = build_module( + TENorm, + hidden_size=self.hidden_size_per_layer_input, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + + def _get_rotary_pos_emb(self, decoder_input, position_ids, packed_seq_params, inference_context=None): + rotary_seq_len = RotaryEmbedding.get_rotary_seq_len(self, inference_context, self.decoder, decoder_input, + self.config, packed_seq_params) + packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd' + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + full_rotary_pos_emb = self.full_rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq) + rotary_pos_emb = {'sliding_attention': rotary_pos_emb, 'full_attention': full_rotary_pos_emb} + return rotary_pos_emb, None, None + + def _set_inv_freq(self): + rope_scaling = self.config.rope_scaling + self.config.rope_scaling = rope_scaling['sliding_attention'] + new_inv_freq, attention_scaling = get_rope_inv_freq(self.config) + assert attention_scaling == 1, 'not support' + self.rotary_pos_emb.inv_freq = new_inv_freq.to(self.rotary_pos_emb.inv_freq.device) + # full + self.full_rotary_pos_emb = copy.copy(self.rotary_pos_emb) + self.config.rope_scaling = rope_scaling['full_attention'] + kwargs = {'layer_type': 'full_attention'} + if self.config.rope_scaling['rope_type'] == 'proportional': + kwargs['head_dim_key'] = 'global_head_dim' + new_inv_freq, attention_scaling = get_rope_inv_freq( + self.config, text_config=self.config.hf_config.text_config, **kwargs) + assert attention_scaling == 1, 'not support' + self.full_rotary_pos_emb.inv_freq = new_inv_freq + self.attention_scaling = attention_scaling + + self.config.rope_scaling = rope_scaling + + def forward(self, *args, **kwargs): + extra_block_kwargs = kwargs.pop('extra_block_kwargs', None) or {} + llm_input_ids = extra_block_kwargs.pop('llm_input_ids', None) + decoder_input = kwargs.get('decoder_input') + mm_token_type_ids = extra_block_kwargs.pop('mm_token_type_ids', None) + shared_kv_states = {} + if self.hidden_size_per_layer_input: + assert self.num_kv_shared_layers > 0, 'not support' + if decoder_input is None: + per_layer_inputs, shared_kv_states = self.unpack_pp_input() + else: + inputs_embeds = decoder_input + per_layer_inputs = self.embed_tokens_per_layer(llm_input_ids) * self.embed_tokens_per_layer_scale + per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs.shape[:-1], self.config.num_layers, + -1).transpose(0, 1) + per_layer_projection = self.per_layer_model_projection( + inputs_embeds)[0] * self.per_layer_model_projection_scale + per_layer_projection = gather_from_tensor_model_parallel_region(per_layer_projection) + per_layer_projection = per_layer_projection.reshape(*per_layer_projection.shape[:-1], + self.config.num_layers, -1) + per_layer_projection = self.per_layer_projection_norm(per_layer_projection) + per_layer_inputs = (per_layer_projection + per_layer_inputs) * self.per_layer_input_scale + per_layer_inputs = scatter_to_tensor_model_parallel_region(per_layer_inputs) + extra_block_kwargs['per_layer_inputs'] = per_layer_inputs + else: + assert self.num_kv_shared_layers == 0, 'not support' + extra_block_kwargs['shared_kv_states'] = shared_kv_states + kwargs['extra_block_kwargs'] = extra_block_kwargs + attention_mask = kwargs.get('attention_mask') + kwargs['attention_mask'] = {'sliding_attention': attention_mask, 'full_attention': attention_mask} + if self.text_config.use_bidirectional_attention == 'vision': + kwargs['attention_mask']['sliding_attention'] = self._create_sliding_attention_mask( + attention_mask, mm_token_type_ids) + hidden_states = super().forward(*args, **kwargs) + if self.hidden_size_per_layer_input and not self.post_process: + hidden_states = self._pack_pp_output(hidden_states, per_layer_inputs, shared_kv_states) + return hidden_states + + def _create_sliding_attention_mask(self, attention_mask, mm_token_type_ids): + window_size = self.text_config.sliding_window - 1 + seq_len = attention_mask.shape[-1] + + window_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device) + window_mask = ~torch.triu(window_mask, diagonal=-window_size) + + attention_mask = attention_mask | window_mask + if mm_token_type_ids is not None: + is_vision = mm_token_type_ids > 0 + is_prev_vision = torch.roll(is_vision, shifts=1, dims=-1) + is_prev_vision[:, 0] = False + vision_group_ids = torch.cumsum((is_vision & ~is_prev_vision).int(), dim=1) - 1 + vision_group_ids = torch.where(is_vision, vision_group_ids, torch.full_like(vision_group_ids, -1)) + q_group = vision_group_ids.unsqueeze(1).unsqueeze(-1) + k_group = vision_group_ids.unsqueeze(1).unsqueeze(-2) + same_vision_group = (q_group == k_group) & (q_group >= 0) & (k_group >= 0) + attention_mask = attention_mask & ~same_vision_group + return attention_mask + + def _pack_pp_output(self, hidden_states, per_layer_inputs, shared_kv_states): + per_layer_inputs = per_layer_inputs.view(*hidden_states.shape[:2], -1) + hidden_states = torch.concat([hidden_states, per_layer_inputs], dim=-1) + flag = per_layer_inputs.new_zeros(*hidden_states.shape[:2], 2) + if 'sliding_attention' in shared_kv_states: + flag[0, 0, 0] = 1 + sliding_states = torch.concat(shared_kv_states['sliding_attention'], -1) + sliding_states = sliding_states.view(*hidden_states.shape[:2], -1) + hidden_states = torch.concat([hidden_states, sliding_states], dim=-1) + if 'full_attention' in shared_kv_states: + flag[0, 0, 1] = 1 + full_states = torch.concat(shared_kv_states['full_attention'], -1) + full_states = full_states.view(*hidden_states.shape[:2], -1) + hidden_states = torch.concat([hidden_states, full_states], dim=-1) + hidden_states = torch.concat([hidden_states, flag], dim=-1) + return hidden_states + + def unpack_pp_input(self): + tp_size = self.config.tensor_model_parallel_size + shared_kv_states = {} + input_tensor = self.get_input_tensor() + sequence_len = input_tensor.shape[0] * tp_size if self.config.sequence_parallel else input_tensor.shape[0] + input_tensor, flag = input_tensor.split([input_tensor.shape[-1] - 2, 2], dim=-1) + flag = flag.detach() + per_layer_inputs_shape = [ + sequence_len, input_tensor.shape[1], self.config.num_layers, self.hidden_size_per_layer_input // tp_size + ] + full_head_dim = self.text_config.global_head_dim + full_states_shape = per_layer_inputs_shape[:2] + [self.num_query_groups_per_partition, full_head_dim * 2] + sliding_states_shape = full_states_shape[:3] + [self.config.kv_channels * 2] + per_layer_inputs_dim = math.prod(per_layer_inputs_shape) // math.prod(input_tensor.shape[:2]) + full_states_dim = math.prod(full_states_shape) // math.prod(input_tensor.shape[:2]) + sliding_states_dim = math.prod(sliding_states_shape) // math.prod(input_tensor.shape[:2]) + if flag[0, 0, 1].item() != 0: + input_tensor, full_states = input_tensor.split([input_tensor.shape[-1] - full_states_dim, full_states_dim], + dim=-1) + full_states = full_states.reshape(*full_states_shape) + shared_kv_states['full_attention'] = full_states.chunk(2, -1) + if flag[0, 0, 0].item() != 0: + input_tensor, sliding_states = input_tensor.split( + [input_tensor.shape[-1] - sliding_states_dim, sliding_states_dim], dim=-1) + sliding_states = sliding_states.reshape(*sliding_states_shape) + shared_kv_states['sliding_attention'] = sliding_states.chunk(2, -1) + input_tensor, per_layer_inputs = input_tensor.split( + [input_tensor.shape[-1] - per_layer_inputs_dim, per_layer_inputs_dim], dim=-1) + self.set_input_tensor(input_tensor) + per_layer_inputs = per_layer_inputs.reshape(*per_layer_inputs_shape) + return per_layer_inputs, shared_kv_states + + def _forward_output_layer(self, hidden_states, *args, **kwargs): + logits, _ = self.output_layer(hidden_states, *args, **kwargs) + if self.final_logit_softcapping is not None: + logits = logits / self.final_logit_softcapping + logits = torch.tanh(logits) + logits = logits * self.final_logit_softcapping + return logits + + +class Gemma4TransformerLayer(TransformerLayer): + + def __init__(self, config, submodules, *args, **kwargs): + super().__init__(config, submodules, *args, **kwargs) + text_config = config.hf_config.text_config + self.enable_moe_block = text_config.enable_moe_block + if self.enable_moe_block: + self.experts_mlp = self._build_mlp(submodules.experts_mlp) + hidden_size = self.config.hidden_size + eps = self.config.layernorm_epsilon + + self.post_attention_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + self.register_buffer('layer_scalar', torch.ones(1)) + + self.hidden_size_per_layer_input = text_config.hidden_size_per_layer_input + if self.hidden_size_per_layer_input: + from transformers.activations import ACT2FN + self.act_fn = ACT2FN[text_config.hidden_activation] + self.per_layer_input_gate = build_module( + TEColumnParallelLinear, + hidden_size, + self.hidden_size_per_layer_input, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name='per_layer_input_gate', + tp_group=self.pg_collection.tp, + ) + self.per_layer_projection = build_module( + TERowParallelLinear, + self.hidden_size_per_layer_input, + hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=False, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name='per_layer_projection', + tp_group=self.pg_collection.tp, + ) + self.post_per_layer_input_norm = build_module(TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + if self.enable_moe_block: + self.post_feedforward_layernorm_1 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + self.post_feedforward_layernorm_2 = build_module( + TENorm, hidden_size=hidden_size, config=self.config, eps=eps) + + def _forward_attention(self, hidden_states: Tensor, **kwargs): + context = kwargs.pop('context', None) + residual = hidden_states + input_layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + nvtx_range_push(suffix='self_attention') + attention_output, bias = self.self_attention(input_layernorm_output, **kwargs) + nvtx_range_pop(suffix='self_attention') + attention_output = self.post_attention_layernorm(attention_output) + + hidden_states = self.self_attn_bda(self.training, self.config.bias_dropout_fusion)( + (attention_output, bias), residual, self.hidden_dropout) + return hidden_states, context + + def _forward_mlp(self, hidden_states, inference_context=None, padding_mask=None): + # Residual connection. + residual = hidden_states + + # Optional Layer norm post the cross-attention. + pre_mlp_layernorm_output = self._forward_pre_mlp_layernorm(hidden_states) + mlp_output, bias = self.mlp(pre_mlp_layernorm_output, padding_mask=padding_mask) + if self.enable_moe_block: + mlp_output_1 = self.post_feedforward_layernorm_1(mlp_output) + mlp_output_2, bias = self.experts_mlp(residual, padding_mask=padding_mask) + mlp_output_2 = self.post_feedforward_layernorm_2(mlp_output_2) + + # Combine mlp and moe outputs + mlp_output = mlp_output_1 + mlp_output_2 + + mlp_output = self.post_feedforward_layernorm(mlp_output) + hidden_states = self.mlp_bda(self.training, self.config.bias_dropout_fusion)((mlp_output, bias), residual, + self.hidden_dropout) + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + return output + + def forward(self, hidden_states, *args, **kwargs): + per_layer_input = kwargs.pop('per_layer_input', None) + hidden_states, context = super().forward(hidden_states, *args, **kwargs) + if self.hidden_size_per_layer_input: + residual = hidden_states + hidden_states, _ = self.per_layer_input_gate(hidden_states) + hidden_states = self.act_fn(hidden_states) + hidden_states = hidden_states * per_layer_input + hidden_states, _ = self.per_layer_projection(hidden_states) + hidden_states = self.post_per_layer_input_norm(hidden_states) + hidden_states = residual + hidden_states + + hidden_states *= self.layer_scalar + return hidden_states, context + + +class Gemma4GPTModel(MultimodalGPTModel): + language_model_cls = Gemma4TextGPTModel + + +class Gemma4TransformerBlock(TransformerBlock): + + def _layer_forward(self, layer, hidden_states, **kwargs): + layer_number = layer.layer_number - 1 + per_layer_inputs = kwargs.pop('per_layer_inputs', None) + if per_layer_inputs is not None: + kwargs['per_layer_input'] = per_layer_inputs[:, :, layer_number] + layer_type = self.config.hf_config.text_config.layer_types[layer_number] + kwargs['rotary_pos_emb'] = kwargs['rotary_pos_emb'][layer_type] + kwargs['attention_mask'] = kwargs['attention_mask'][layer_type] + return super()._layer_forward(layer, hidden_states, **kwargs) + + +class Gemma4Loader(ModelLoader): + model_cls = Gemma4GPTModel + transformer_block = Gemma4TransformerBlock + + def get_transformer_layer_spec(self, vp_stage: Optional[int] = None): + num_moe_experts = self.config.num_moe_experts + self.config.num_moe_experts = None + layer_specs = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + for layer_spec in layer_specs.layer_specs: + layer_spec.submodules.self_attention.module = Gemma4SelfAttention + layer_spec.submodules.mlp.module = Gemma4MLP + if num_moe_experts is not None: + self.config.num_moe_experts = num_moe_experts + moe_layer_specs = get_gpt_decoder_block_spec( + self.config, use_transformer_engine=True, normalization=self.config.normalization, vp_stage=vp_stage) + for layer_spec, moe_layer_spec in zip(layer_specs.layer_specs, moe_layer_specs.layer_specs): + layer_spec.submodules.experts_mlp = moe_layer_spec.submodules.mlp + layer_spec.submodules.experts_mlp.module = Gemma4MoELayer + return layer_specs + + def _set_transformer_layer(self, transformer_layer_spec): + for layer_spec in transformer_layer_spec.layer_specs: + layer_spec.module = Gemma4TransformerLayer + + +register_model( + ModelMeta( + ModelType.gemma4, + ['gemma4'], + bridge_cls=Gemma4Bridge, + visual_cls=Gemma4Vit, + loader=Gemma4Loader, + )) diff --git a/src/mcore_bridge/model/modules/transformer_block.py b/src/mcore_bridge/model/modules/transformer_block.py index 073f671..5d36431 100644 --- a/src/mcore_bridge/model/modules/transformer_block.py +++ b/src/mcore_bridge/model/modules/transformer_block.py @@ -19,6 +19,51 @@ apply_module = None +class _TensorIdx: + """Sentinel that marks a position in the flatten schema as a tensor index.""" + __slots__ = ('idx', ) + + def __init__(self, idx): + self.idx = idx + + +def _checkpoint_flatten(obj, tensors): + """Recursively flatten a nested structure (dict/tuple/list/Tensor) into a schema. + + Tensors are appended to `tensors` and replaced in the schema by a _TensorIdx sentinel. + Non-tensor leaves (int, bool, None, str, ...) are stored as-is. + The schema mirrors the original structure and is captured in the checkpoint closure. + """ + if torch.is_tensor(obj): + idx = len(tensors) + tensors.append(obj) + return _TensorIdx(idx) + elif isinstance(obj, dict): + # inplace (gemma4 shared_kv_states) + for k, v in obj.items(): + obj[k] = _checkpoint_flatten(v, tensors) + return obj + elif isinstance(obj, (tuple, list)): + return type(obj)(_checkpoint_flatten(v, tensors) for v in obj) + else: + return obj # non-tensor leaf: stored directly in schema + + +def _checkpoint_unflatten(schema, tensors): + """Reconstruct the original structure from a schema and a flat tensors list.""" + if isinstance(schema, _TensorIdx): + return tensors[schema.idx] + elif isinstance(schema, dict): + # inplace (gemma4 shared_kv_states) + for k, v in schema.items(): + schema[k] = _checkpoint_unflatten(v, tensors) + return schema + elif isinstance(schema, (tuple, list)): + return type(schema)(_checkpoint_unflatten(v, tensors) for v in schema) + else: + return schema # non-tensor leaf + + # Code borrowed from NVIDIA/Megatron-LM class TransformerBlock(McoreTransformerBlock): @@ -99,26 +144,25 @@ def custom_forward( return custom_forward - # `tensor_parallel.checkpoint` / `te_checkpoint` only forward *args to the - # wrapped function (torch.utils.checkpoint limitation). Convert kwargs to - # positional args by capturing the keys in closure so tensor kwargs (e.g. - # qwen3-vl's visual_pos_masks / deepstack_visual_embeds) can flow through - # activation recompute and remain in the autograd graph. + # Variables that don't require gradients can be captured via closure. + _ckpt_attention_mask = attention_mask + _ckpt_rotary_pos_emb = rotary_pos_emb extra_kwargs_keys = tuple(kwargs.keys()) - extra_kwargs_values = tuple(kwargs.values()) + _extra_flat_tensors = [] + _extra_schemas = [_checkpoint_flatten(v, _extra_flat_tensors) for v in kwargs.values()] def checkpoint_handler(forward_func): """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" - def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb, padding_mask, - *extra_args): - extra_kwargs = dict(zip(extra_kwargs_keys, extra_args)) + def wrapped_forward(hidden_states, context, context_mask, padding_mask, *extra_flat): + rebuilt = [_checkpoint_unflatten(s, extra_flat) for s in _extra_schemas] + extra_kwargs = dict(zip(extra_kwargs_keys, rebuilt)) return forward_func( hidden_states, - attention_mask, + _ckpt_attention_mask, context, context_mask, - rotary_pos_emb, + _ckpt_rotary_pos_emb, padding_mask, **extra_kwargs, ) @@ -131,24 +175,20 @@ def wrapped_forward(hidden_states, attention_mask, context, context_mask, rotary tensor_parallel.random.get_cuda_rng_tracker, self.pg_collection.tp, hidden_states, - attention_mask, context, context_mask, - rotary_pos_emb, padding_mask, - *extra_kwargs_values, + *_extra_flat_tensors, ) else: return tensor_parallel.checkpoint( wrapped_forward, self.config.distribute_saved_activations, hidden_states, - attention_mask, context, context_mask, - rotary_pos_emb, padding_mask, - *extra_kwargs_values, + *_extra_flat_tensors, ) if self.config.recompute_method == 'uniform': diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 21b8cfa..6d500e0 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -134,8 +134,6 @@ def __init__( # [Module 8: MLP block] self.mlp = self._build_mlp(submodules.mlp) - if hasattr(self.mlp, 'set_layer_number'): - self.mlp.set_layer_number(self.layer_number) # [Module 9: BiasDropoutFusion] self.mlp_bda = build_module(submodules.mlp_bda) self.is_moe_layer = isinstance(self.mlp, MoELayer) @@ -216,13 +214,14 @@ def _build_mlp(self, mlp_spec): additional_mlp_kwargs = {} # import here to avoid circular import from mcore_bridge.model.gpts.glm4 import Glm4MLP + from mcore_bridge.model.mm_gpts.gemma4 import Gemma4MLP, Gemma4MoELayer # MLP expects tp_group but MoELayer expects pg_collection to be passed in. # We can change MLP to accept pg_collection but it makes the logic implicit # The conditional below is to make the logic explicit # if smlp_spec is not a ModuleSpec,we dont have to handle passing additional kwargs if isinstance(mlp_spec, ModuleSpec): - if mlp_spec.module in (MoELayer, TEGroupedMLP, SequentialMLP): + if mlp_spec.module in (MoELayer, Gemma4MoELayer, TEGroupedMLP, SequentialMLP): additional_mlp_kwargs['pg_collection'] = pg_collection # Pass is_mtp_layer flag to MoELayer to distinguish MTP MoE layers. if mlp_spec.module == MoELayer and 'is_mtp_layer' in inspect.signature(MoELayer).parameters: @@ -230,12 +229,17 @@ def _build_mlp(self, mlp_spec): elif mlp_spec.module in (MLP, Glm4MLP): assert hasattr(pg_collection, 'tp'), 'TP process group is required for MLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp + elif mlp_spec.module == Gemma4MLP: + additional_mlp_kwargs['layer_number'] = self.layer_number elif TEFusedMLP is not None and mlp_spec.module == TEFusedMLP: assert hasattr(pg_collection, 'tp'), 'TP process group is required for TEFusedMLP in TransformerLayer' additional_mlp_kwargs['tp_group'] = pg_collection.tp else: logger.warning_once(f'Unknown MLP type: {mlp_spec.module}. Using default kwargs.') - return build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + mlp = build_module(mlp_spec, config=self.config, **additional_mlp_kwargs) + if hasattr(mlp, 'set_layer_number'): + mlp.set_layer_number(self.layer_number) + return mlp def forward(self, *args, **kwargs): """ diff --git a/src/mcore_bridge/model/register.py b/src/mcore_bridge/model/register.py index 84bfdb4..22ad26d 100644 --- a/src/mcore_bridge/model/register.py +++ b/src/mcore_bridge/model/register.py @@ -118,7 +118,7 @@ def _set_shared_expert_gate(self, transformer_layer_spec): if hasattr(layer_spec.submodules.mlp.submodules, 'shared_experts'): layer_spec.submodules.mlp.submodules.shared_experts.params = {'gate': True} - def _set_custom_layer(self, transformer_layer_spec): + def _set_transformer_layer(self, transformer_layer_spec): for layer_spec in transformer_layer_spec.layer_specs: layer_spec.module = TransformerLayer @@ -130,7 +130,7 @@ def build_model( ) -> Union['GPTModel', 'MultimodalGPTModel']: transformer_layer_spec = self.get_transformer_layer_spec(vp_stage=vp_stage) self._set_shared_expert_gate(transformer_layer_spec) - self._set_custom_layer(transformer_layer_spec) + self._set_transformer_layer(transformer_layer_spec) mtp_block_spec = None if self.config.mtp_num_layers is not None: mtp_block_spec = self.get_mtp_block_spec(transformer_layer_spec, vp_stage=vp_stage) diff --git a/src/mcore_bridge/model/rope.py b/src/mcore_bridge/model/rope.py index e7db3c3..a514f43 100644 --- a/src/mcore_bridge/model/rope.py +++ b/src/mcore_bridge/model/rope.py @@ -106,12 +106,13 @@ def _get_rope_type(rope_scaling: Optional[Dict[str, Any]]): return rope_type -def get_rope_inv_freq(config, seq_len=None, **kwargs): +def get_rope_inv_freq(config, seq_len=None, text_config=None, **kwargs): from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS ROPE_INIT_FUNCTIONS.update(EXTENDED_ROPE_INIT_FUNCTIONS) - dummy_config = _get_dummy_config(config) + if text_config is None: + text_config = _get_dummy_config(config) rope_init_fn = ROPE_INIT_FUNCTIONS[_get_rope_type(config.rope_scaling)] - inv_freq, attention_scaling = rope_init_fn(dummy_config, 'cpu', seq_len=seq_len, **kwargs) + inv_freq, attention_scaling = rope_init_fn(text_config, 'cpu', seq_len=seq_len, **kwargs) if attention_scaling is None: attention_scaling = 1. return inv_freq, attention_scaling diff --git a/tests/test_mllm.py b/tests/test_mllm.py index 1fe103e..5eb07d2 100644 --- a/tests/test_mllm.py +++ b/tests/test_mllm.py @@ -116,6 +116,10 @@ def test_qwen3_asr(): _test_model('Qwen/Qwen3-ASR-1.7B') +def test_gemma4(): + _test_model('google/gemma-4-E2B-it') + + if __name__ == '__main__': # test_qwen2_5_vl() # test_qwen2_vl() @@ -136,4 +140,5 @@ def test_qwen3_asr(): # test_llama4() # test_qwen3_5() # test_llava_onevision1_5() - test_qwen3_asr() + # test_qwen3_asr() + test_gemma4()