""" modeling_axion.py — Config + Model para Axion/DeepSeek-Nano Carregar com: AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) """ from __future__ import annotations from typing import Optional import torch import torch.nn.functional as F from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import CausalLMOutputWithPast from model import RMSNorm, DeepSeekBlock # ─── Config ────────────────────────────────────────────────────────────────── class AxionConfig(PretrainedConfig): model_type = "deepseek_nano" def __init__( self, vocab_size=1024, d_model=64, n_layers=4, n_heads=4, d_head=16, kv_lora_rank=8, q_lora_rank=16, rope_theta=10000.0, rope_scaling=None, n_shared_experts=1, n_routed_experts=4, n_active_experts=2, d_ff=64, moe_aux_loss_coef=0.0, expert_bias_init=0.0, max_seq_len=512, dropout=0.0, norm_eps=1e-6, tie_embeddings=True, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs, ): super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) self.vocab_size=vocab_size; self.d_model=d_model; self.n_layers=n_layers self.n_heads=n_heads; self.d_head=d_head; self.kv_lora_rank=kv_lora_rank self.q_lora_rank=q_lora_rank; self.rope_theta=rope_theta self.rope_scaling=rope_scaling; self.n_shared_experts=n_shared_experts self.n_routed_experts=n_routed_experts; self.n_active_experts=n_active_experts self.d_ff=d_ff; self.moe_aux_loss_coef=moe_aux_loss_coef self.expert_bias_init=expert_bias_init; self.max_seq_len=max_seq_len self.dropout=dropout; self.norm_eps=norm_eps; self.tie_embeddings=tie_embeddings self.tokenizer=None; self.tokenizer_class=None # Aliases que o HuggingFace acessa internamente self.num_hidden_layers = n_layers self.hidden_size = d_model self.num_attention_heads = n_heads # ─── Model ─────────────────────────────────────────────────────────────────── class DeepSeekNanoForCausalLM(PreTrainedModel): config_class = AxionConfig supports_gradient_checkpointing = False def __init__(self, config: AxionConfig): super().__init__(config) cfg = config.to_dict() self.embed = torch.nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id) self.blocks = torch.nn.ModuleList( [DeepSeekBlock(cfg) for _ in range(config.n_layers)] ) self.norm = RMSNorm(config.d_model, eps=config.norm_eps) self.post_init() def get_input_embeddings(self): return self.embed def set_input_embeddings(self, v): self.embed = v def get_output_embeddings(self): return self.embed def forward(self, input_ids, attention_mask=None, past_key_values=None, labels=None, use_cache=False, **kwargs): x = self.embed(input_ids) new_caches = [] if use_cache else None # Compatibilidade com DynamicCache (Transformers >= 4.36) # Converte para lista simples que o nosso MLA entende if past_key_values is not None and not isinstance(past_key_values, list): try: past_key_values = [ (past_key_values.key_cache[i], past_key_values.value_cache[i]) if i < len(past_key_values.key_cache) else None for i in range(len(self.blocks)) ] except Exception: past_key_values = None for i, block in enumerate(self.blocks): cache = past_key_values[i] if past_key_values else None x, nc = block(x, kv_cache=cache, use_cache=use_cache) if use_cache: new_caches.append(nc) logits = F.linear(self.norm(x), self.embed.weight) loss = None if labels is not None: loss = F.cross_entropy(logits.view(-1, self.config.vocab_size), labels.view(-1), ignore_index=self.config.pad_token_id) return CausalLMOutputWithPast(loss=loss, logits=logits, 
past_key_values=new_caches) def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): # Se tem cache (qualquer tipo), usa só o último token if past_key_values is not None: input_ids = input_ids[:, -1:] return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}
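
# ─── Usage sketch ─────────────────────────────────────────────────────────────
# A minimal smoke test, not part of the published API. It assumes the companion
# model.py (RMSNorm, DeepSeekBlock) is importable and that the installed
# Transformers version still provides generate() via PreTrainedModel. The model
# is built from a fresh AxionConfig with random weights, so this only exercises
# the forward pass and the KV-cache path; for a trained checkpoint, use the
# AutoModelForCausalLM.from_pretrained(..., trust_remote_code=True) route from
# the module docstring instead.

if __name__ == "__main__":
    cfg = AxionConfig()
    model = DeepSeekNanoForCausalLM(cfg).eval()
    # Arbitrary in-vocab token ids standing in for a tokenized prompt
    prompt = torch.tensor([[cfg.bos_token_id, 5, 42, 7]])
    with torch.no_grad():
        out = model.generate(
            prompt,
            attention_mask=torch.ones_like(prompt),
            max_new_tokens=8,
            do_sample=False,
        )
    print(out.shape)  # expected: (1, prompt_len + 8) = (1, 12)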