import math
from typing import Union

from transformers import PretrainedConfig


class MambaConfig(PretrainedConfig):
    """Configuration for a Mamba model.

    Stores the model hyperparameters, derives `d_inner`, resolves
    `dt_rank="auto"`, and pads `vocab_size` up to a multiple of
    `pad_vocab_size_multiple`.
    """

    model_type = "mamba"

    def __init__(
        self,
        vocab_size=50277,
        d_state=16,
        d_model=2560,
        d_conv=4,
        expand=2,
        conv_bias=True,
        bias=False,
        n_layer=64,
        dt_rank: Union[int, str] = "auto",
        pad_vocab_size_multiple=8,
        initializer_range=0.02,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_layer = n_layer
        self.conv_bias = conv_bias
        self.expand = expand
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.d_conv = d_conv
        self.d_model = d_model
        self.d_state = d_state
        # Inner block width: the expansion factor times the model width.
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = dt_rank
        self.initializer_range = initializer_range
        self.bias = bias

        # Resolve dt_rank="auto" to ceil(d_model / 16).
        if self.dt_rank == "auto":
            self.dt_rank = math.ceil(self.d_model / 16)

        # Pad vocab_size up to the next multiple of pad_vocab_size_multiple.
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (
                self.pad_vocab_size_multiple
                - self.vocab_size % self.pad_vocab_size_multiple
            )

        super().__init__(**kwargs)
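

# A minimal usage sketch (an illustrative addition, not part of the original
# class): instantiate the config with its defaults and check the fields that
# __init__ derives. Requires the `transformers` package to be installed.
if __name__ == "__main__":
    config = MambaConfig()
    assert config.d_inner == 5120  # expand * d_model = 2 * 2560
    assert config.dt_rank == 160  # ceil(2560 / 16) when dt_rank="auto"
    assert config.vocab_size == 50280  # 50277 padded up to a multiple of 8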