From e0ec17a56e1b3eaa932f4e63c56c82eb4984ee19 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 20 May 2026 17:03:02 +0800 Subject: [PATCH 1/2] compat megatron dev --- src/mcore_bridge/config/parser.py | 1 - src/mcore_bridge/model/gpt_model.py | 4 +++- .../model/modules/transformer_layer.py | 20 ++++++++++++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/mcore_bridge/config/parser.py b/src/mcore_bridge/config/parser.py index 842c1bb..1bd4b91 100644 --- a/src/mcore_bridge/config/parser.py +++ b/src/mcore_bridge/config/parser.py @@ -139,7 +139,6 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]: if llm_model_type != 'deepseek': res['qk_layernorm'] = True res['moe_router_load_balancing_type'] = 'seq_aux_loss' - res.pop('num_query_groups', None) # https://github.com/NVIDIA/Megatron-LM/issues/1475 if llm_model_type == 'dots1': res['moe_router_score_function'] = 'sigmoid' elif llm_model_type == 'deepseek_v32': diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 84bf56d..2c14961 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -149,7 +149,7 @@ def _apply_rotary_pos_emb_bshd( t: torch.Tensor, freqs: torch.Tensor, rotary_interleaved: bool = False, - multi_latent_attention: bool = False, # not use + multi_latent_attention: Optional[bool] = None, mscale: float = 1.0, **kwargs, ) -> torch.Tensor: @@ -169,6 +169,8 @@ def _apply_rotary_pos_emb_bshd( # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] + if multi_latent_attention is None: + multi_latent_attention = self.config.multi_latent_attention if multi_latent_attention: x1 = t[..., 0::2] x2 = t[..., 1::2] diff --git a/src/mcore_bridge/model/modules/transformer_layer.py b/src/mcore_bridge/model/modules/transformer_layer.py index 6d500e0..ce742c4 100644 --- a/src/mcore_bridge/model/modules/transformer_layer.py +++ b/src/mcore_bridge/model/modules/transformer_layer.py @@ -191,13 +191,19 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph(): if 'mlp' in self.config.recompute_modules: if not self.is_moe_layer: self.recompute_mlp = True - if hasattr(self.config, 'fine_grained_activation_offloading'): - self.offload_attn_norm = ( - self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules - and not isinstance(self.input_layernorm, IdentityOp)) - self.offload_mlp_norm = ( - self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules - and not isinstance(self.pre_mlp_layernorm, IdentityOp)) + if hasattr(self, '_set_offload_modules'): + from megatron.core.transformer.transformer_layer import _get_offloading_interface + self._set_offload_modules() + self.off_interface = _get_offloading_interface() + self.mlp_norm_manager = None + else: + if hasattr(self.config, 'fine_grained_activation_offloading'): + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp)) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp)) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. From ed6b565dba8a2cdd36cf5600e512522f865410ae Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Wed, 20 May 2026 17:05:49 +0800 Subject: [PATCH 2/2] lint pass --- src/mcore_bridge/model/gpt_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py index 2c14961..04abb4d 100644 --- a/src/mcore_bridge/model/gpt_model.py +++ b/src/mcore_bridge/model/gpt_model.py @@ -170,7 +170,7 @@ def _apply_rotary_pos_emb_bshd( # ideally t_pass is empty so rotary pos embedding is applied to all tensor t t, t_pass = t[..., :rot_dim], t[..., rot_dim:] if multi_latent_attention is None: - multi_latent_attention = self.config.multi_latent_attention + multi_latent_attention = self.config.multi_latent_attention if multi_latent_attention: x1 = t[..., 0::2] x2 = t[..., 1::2]