[model] Support bailing v2 5#85
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for the bailing_hybrid model, including its configuration mapping and a specialized loader that handles hybrid attention layers. Review feedback highlights the need for safer and more efficient logic when retrieving transformer layer specifications, specifically recommending a try...finally block to ensure configuration state is restored. Additionally, it was suggested to remove redundant method overrides in the LinearAttention class and clean up several unused imports.
| layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | ||
| multi_latent_attention = self.config.multi_latent_attention | ||
| self.config.multi_latent_attention = False | ||
| linear_layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | ||
| self.config.multi_latent_attention = multi_latent_attention |
There was a problem hiding this comment.
The current implementation for getting linear_layer_specs by temporarily modifying self.config.multi_latent_attention has a couple of issues:
- Safety: If
super().get_transformer_layer_spec()raises an exception,self.config.multi_latent_attentionwill not be restored to its original value. This could lead to unexpected behavior in subsequent operations. Using atry...finallyblock is recommended for safety. - Efficiency:
super().get_transformer_layer_spec()is called twice. Ifself.config.multi_latent_attentionisFalseto begin with, both calls are identical, which is redundant and inefficient.
Consider refactoring this logic to be safer and more efficient.
| layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | |
| multi_latent_attention = self.config.multi_latent_attention | |
| self.config.multi_latent_attention = False | |
| linear_layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | |
| self.config.multi_latent_attention = multi_latent_attention | |
| multi_latent_attention = self.config.multi_latent_attention | |
| if multi_latent_attention: | |
| layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | |
| try: | |
| self.config.multi_latent_attention = False | |
| linear_layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | |
| finally: | |
| self.config.multi_latent_attention = multi_latent_attention | |
| else: | |
| layer_specs = super().get_transformer_layer_spec(vp_stage=vp_stage) | |
| linear_layer_specs = layer_specs |
| class LinearAttention(SelfAttention): | ||
| def __init__( | ||
| self, | ||
| config: TransformerConfig, | ||
| *args, **kwargs, | ||
| ): | ||
| super().__init__(config, *args, **kwargs) | ||
|
|
||
| def forward( | ||
| self, | ||
| hidden_states: Tensor, | ||
| attention_mask: Tensor, | ||
| **kwargs, | ||
| ) -> Tuple[Tensor, Tensor]: | ||
| return super().forward(hidden_states, attention_mask, **kwargs) |
There was a problem hiding this comment.
The __init__ and forward methods in the LinearAttention class are redundant as they just call the superclass methods with the same arguments. You can remove them for cleaner and more concise code.
After this change, TransformerConfig and Tuple will become unused imports and should also be removed, along with other unused imports in this file (BaseInferenceContext, PackedSeqParams, Union, and SelfAttentionSubmodules).
class LinearAttention(SelfAttention):
pass
No description provided.