kashif HF Staff commited on
Commit
c555c2f
·
verified ·
1 Parent(s): 7729892

fix: align _init_weights with Qwen2Moe using nn.init API

Browse files

Use the @torch.no_grad() decorator, call super()._init_weights(), and only init MoE gate weights (nn.Linear/nn.Embedding handled by PreTrainedModel base class in transformers v5).

Files changed (1) hide show
  1. modeling_llada2_moe.py +3 -8
modeling_llada2_moe.py CHANGED
@@ -686,17 +686,12 @@ class LLaDA2MoePreTrainedModel(PreTrainedModel):
686
  _supports_flex_attn = True
687
  _supports_cache_class = True
688
 
 
689
  def _init_weights(self, module):
690
  super()._init_weights(module)
691
  std = self.config.initializer_range
692
- if isinstance(module, nn.Linear):
693
- module.weight.data.normal_(mean=0.0, std=std)
694
- if module.bias is not None:
695
- module.bias.data.zero_()
696
- elif isinstance(module, nn.Embedding):
697
- module.weight.data.normal_(mean=0.0, std=std)
698
- if module.padding_idx is not None:
699
- module.weight.data[module.padding_idx].zero_()
700
 
701
 
702
  LLADA2MOE_INPUTS_DOCSTRING = r"""
 
686
  _supports_flex_attn = True
687
  _supports_cache_class = True
688
 
689
+ @torch.no_grad()
690
  def _init_weights(self, module):
691
  super()._init_weights(module)
692
  std = self.config.initializer_range
693
+ if isinstance(module, LLaDA2MoeGate):
694
+ nn.init.normal_(module.weight, mean=0.0, std=std)
 
 
 
 
 
 
695
 
696
 
697
  LLADA2MOE_INPUTS_DOCSTRING = r"""