fix: align _init_weights with Qwen2Moe using nn.init API
Browse filesUse @torch .no_grad() decorator, call super()._init_weights(), and only init MoE gate weights (nn.Linear/nn.Embedding handled by PreTrainedModel base class in transformers v5).
- modeling_llada2_moe.py +3 -8
modeling_llada2_moe.py
CHANGED
|
@@ -686,17 +686,12 @@ class LLaDA2MoePreTrainedModel(PreTrainedModel):
|
|
| 686 |
_supports_flex_attn = True
|
| 687 |
_supports_cache_class = True
|
| 688 |
|
|
|
|
| 689 |
def _init_weights(self, module):
|
| 690 |
super()._init_weights(module)
|
| 691 |
std = self.config.initializer_range
|
| 692 |
-
if isinstance(module,
|
| 693 |
-
|
| 694 |
-
if module.bias is not None:
|
| 695 |
-
module.bias.data.zero_()
|
| 696 |
-
elif isinstance(module, nn.Embedding):
|
| 697 |
-
module.weight.data.normal_(mean=0.0, std=std)
|
| 698 |
-
if module.padding_idx is not None:
|
| 699 |
-
module.weight.data[module.padding_idx].zero_()
|
| 700 |
|
| 701 |
|
| 702 |
LLADA2MOE_INPUTS_DOCSTRING = r"""
|
|
|
|
| 686 |
_supports_flex_attn = True
|
| 687 |
_supports_cache_class = True
|
| 688 |
|
| 689 |
+
@torch.no_grad()
|
| 690 |
def _init_weights(self, module):
|
| 691 |
super()._init_weights(module)
|
| 692 |
std = self.config.initializer_range
|
| 693 |
+
if isinstance(module, LLaDA2MoeGate):
|
| 694 |
+
nn.init.normal_(module.weight, mean=0.0, std=std)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 695 |
|
| 696 |
|
| 697 |
LLADA2MOE_INPUTS_DOCSTRING = r"""
|