---
library_name: transformers
tags: []
---

In transformers v5, Mixtral uses a different architecture from pre-v5 releases: the per-expert layers are fused. This LoRA checkpoint is based on the old architecture, and we want to ensure that it can still be loaded correctly post-v5.

Old architecture:

```
MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-1): 2 x MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=1024, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBlockSparseTop2MLP(
              (w1): Linear(in_features=1024, out_features=3584, bias=False)
              (w2): Linear(in_features=3584, out_features=1024, bias=False)
              (w3): Linear(in_features=1024, out_features=3584, bias=False)
              (act_fn): SiLUActivation()
            )
          )
        )
        (input_layernorm): MixtralRMSNorm((1024,), eps=1e-05)
        (post_attention_layernorm): MixtralRMSNorm((1024,), eps=1e-05)
      )
    )
    (norm): MixtralRMSNorm((1024,), eps=1e-05)
    (rotary_emb): MixtralRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)
```

New architecture:

```
MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 1024)
    (layers): ModuleList(
      (0-1): 2 x MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=256, bias=False)
          (v_proj): Linear(in_features=1024, out_features=256, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): MixtralSparseMoeBlock(
          (gate): MixtralTopKRouter()
          (experts): MixtralExperts(
            (act_fn): SiLUActivation()
          )
        )
        (input_layernorm): MixtralRMSNorm((1024,), eps=1e-05)
        (post_attention_layernorm): MixtralRMSNorm((1024,), eps=1e-05)
      )
    )
    (norm): MixtralRMSNorm((1024,), eps=1e-05)
    (rotary_emb): MixtralRotaryEmbedding()
  )
  (lm_head): Linear(in_features=1024, out_features=32000, bias=False)
)
```

The weight conversion spec is:

```python
"mixtral": [
    WeightRenaming(".block_sparse_moe.gate", ".mlp.gate"),
    WeightConverter(
        source_patterns=[
            "block_sparse_moe.experts.*.w1.weight",
            "block_sparse_moe.experts.*.w3.weight",
        ],  # you give me a list of 2 keys, I collect a list of a list of tensors
        target_patterns="mlp.experts.gate_up_proj",  # target key gets the list of two tensors
        operations=[
            MergeModulelist(
                dim=0
            ),  # each process has two lists of tensors, we cat each list. -> we end up with 2 tensors
            Concatenate(dim=1),  # each process has 2 tensors, gate and up, we concat them into gate_up
        ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
    ),
    WeightConverter(
        source_patterns=[
            "block_sparse_moe.experts.*.w2.weight",
        ],
        target_patterns="mlp.experts.down_proj",  # target key gets the single merged tensor
        operations=[
            MergeModulelist(
                dim=0
            ),  # each process has one list of tensors (one per expert), we merge it. -> we end up with 1 tensor
        ],  # we want the loading to add this shard operation here. Though we can't shard after concats and merge, needs to be first
    ),
],
```
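
To make the effect of the conversion spec concrete, here is a minimal sketch of what the two `WeightConverter` entries do to the base expert weights. It uses NumPy arrays as stand-ins for tensors and toy sizes instead of the real 1024/3584 dimensions. It assumes that `MergeModulelist(dim=0)` stacks the per-expert weights along a new leading expert axis and that `Concatenate(dim=1)` joins gate (`w1`) and up (`w3`) along the output-feature axis; neither detail is confirmed by the spec above. The helper names (`merge_experts`, `old_expert_forward`, `fused_expert_forward`) are hypothetical and not part of transformers. How the LoRA A/B matrices themselves are mapped is not covered here.

```python
import numpy as np

# Toy sizes; the checkpoint above uses hidden=1024, intermediate=3584, 8 experts.
num_experts, hidden, intermediate = 8, 16, 56
rng = np.random.default_rng(0)

# Old layout: one w1/w2/w3 per expert, each shaped like nn.Linear.weight,
# i.e. (out_features, in_features).
old_state = {}
for e in range(num_experts):
    prefix = f"block_sparse_moe.experts.{e}"
    old_state[f"{prefix}.w1.weight"] = rng.standard_normal((intermediate, hidden))
    old_state[f"{prefix}.w2.weight"] = rng.standard_normal((hidden, intermediate))
    old_state[f"{prefix}.w3.weight"] = rng.standard_normal((intermediate, hidden))


def merge_experts(state, name):
    """Stack one weight per expert along a new leading axis.

    Assumption: this mirrors MergeModulelist(dim=0).
    """
    keys = [f"block_sparse_moe.experts.{e}.{name}.weight" for e in range(num_experts)]
    return np.stack([state[k] for k in keys], axis=0)


# First converter: merge w1 and w3 separately, then concatenate on axis 1.
w1 = merge_experts(old_state, "w1")              # (E, I, H)
w3 = merge_experts(old_state, "w3")              # (E, I, H)
gate_up_proj = np.concatenate([w1, w3], axis=1)  # (E, 2*I, H)

# Second converter: merge w2 only.
down_proj = merge_experts(old_state, "w2")       # (E, H, I)


def silu(x):
    return x / (1.0 + np.exp(-x))


def old_expert_forward(state, e, x):
    """Per-expert MLP in the old layout: w2(silu(w1(x)) * w3(x))."""
    p = f"block_sparse_moe.experts.{e}"
    gate = x @ state[f"{p}.w1.weight"].T
    up = x @ state[f"{p}.w3.weight"].T
    return (silu(gate) * up) @ state[f"{p}.w2.weight"].T


def fused_expert_forward(gate_up, down, e, x):
    """Same computation using the fused tensors, indexed by expert."""
    gate, up = np.split(x @ gate_up[e].T, 2, axis=-1)
    return (silu(gate) * up) @ down[e].T


# The two layouts should produce the same output for every expert.
x = rng.standard_normal((4, hidden))
for e in range(num_experts):
    assert np.allclose(
        old_expert_forward(old_state, e, x),
        fused_expert_forward(gate_up_proj, down_proj, e, x),
    )
print(gate_up_proj.shape, down_proj.shape)
```

Under these assumptions, the first half of `gate_up_proj` along axis 1 holds the old `w1` weights and the second half holds `w3`, so a fused forward pass has to split the projection in that order to match the old per-expert computation.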