Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
if llm_model_type != 'deepseek':
res['qk_layernorm'] = True
res['moe_router_load_balancing_type'] = 'seq_aux_loss'
res.pop('num_query_groups', None) # https://github.com/NVIDIA/Megatron-LM/issues/1475
if llm_model_type == 'dots1':
res['moe_router_score_function'] = 'sigmoid'
elif llm_model_type == 'deepseek_v32':
Expand Down
4 changes: 3 additions & 1 deletion src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _apply_rotary_pos_emb_bshd(
t: torch.Tensor,
freqs: torch.Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False, # not use
multi_latent_attention: Optional[bool] = None,
mscale: float = 1.0,
**kwargs,
) -> torch.Tensor:
Expand All @@ -169,6 +169,8 @@ def _apply_rotary_pos_emb_bshd(

# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if multi_latent_attention is None:
multi_latent_attention = self.config.multi_latent_attention
Comment on lines +172 to +173
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The monkey-patched `_apply_rotary_pos_emb_bshd` function captures the `self` instance of the first `GPTModel` that triggers the patch. Since the patch is only applied once (due to the check at line 144), all subsequent `GPTModel` instances will use the `multi_latent_attention` configuration from the first instance, regardless of their own configuration. This will cause incorrect behavior in multi-model environments (e.g., different models in the same process or complex pipeline-parallel setups).

if multi_latent_attention:
x1 = t[..., 0::2]
x2 = t[..., 1::2]
Expand Down
20 changes: 13 additions & 7 deletions src/mcore_bridge/model/modules/transformer_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,19 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph():
if 'mlp' in self.config.recompute_modules:
if not self.is_moe_layer:
self.recompute_mlp = True
if hasattr(self.config, 'fine_grained_activation_offloading'):
self.offload_attn_norm = (
self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules
and not isinstance(self.input_layernorm, IdentityOp))
self.offload_mlp_norm = (
self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules
and not isinstance(self.pre_mlp_layernorm, IdentityOp))
if hasattr(self, '_set_offload_modules'):
from megatron.core.transformer.transformer_layer import _get_offloading_interface
self._set_offload_modules()
self.off_interface = _get_offloading_interface()
self.mlp_norm_manager = None
else:
if hasattr(self.config, 'fine_grained_activation_offloading'):
self.offload_attn_norm = (
self.config.fine_grained_activation_offloading and 'attn_norm' in self.config.offload_modules
and not isinstance(self.input_layernorm, IdentityOp))
self.offload_mlp_norm = (
self.config.fine_grained_activation_offloading and 'mlp_norm' in self.config.offload_modules
and not isinstance(self.pre_mlp_layernorm, IdentityOp))

# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
Expand Down
Loading