from transformers import PretrainedConfig
from torch import nn

from .types import TransformerLayerCFG, TransformerEncoderCFG


class WavJEPAConfig(PretrainedConfig):
    """Hugging Face configuration for the base WavJEPA model.

    The constructor takes flat hyper-parameters for three components and
    derives grouped config objects from them:

    * ``extractor_config`` -- a plain ``dict`` of keyword arguments for the
      feature extractor (``conv_layers_spec``, ``in_channels``, ``dropout``,
      ``mode``, ``conv_bias``, ``depthwise``).
    * ``encoder_cfg`` / ``encoder_layers_cfg`` -- stack-level and per-layer
      settings for the encoder, built with ``TransformerEncoderCFG.create``
      and ``TransformerLayerCFG.create``.
    * ``decoder_cfg`` / ``decoder_layers_cfg`` -- the same for the decoder.

    Every constructor argument is also stored as an attribute of the same
    name. ``PretrainedConfig`` serializes instance attributes, so a saved
    config carries the flat values and the grouped objects are rebuilt from
    them when it is loaded again.

    Args:
        extractor_layers_spec: String description of the extractor layers,
            stored unchanged in ``extractor_config["conv_layers_spec"]``. It
            looks like a Python expression for a list of
            ``(channels, kernel, stride)`` tuples -- NOTE(review): confirm
            against the extractor that consumes it.
        extractor_dropout, extractor_mode, extractor_conv_bias,
        extractor_depthwise: Forwarded into ``extractor_config`` as
            ``dropout``, ``mode``, ``conv_bias`` and ``depthwise``.
        encoder_d_model, encoder_nhead, encoder_batch_first,
        encoder_norm_first, encoder_bias, encoder_mlp_ratio, encoder_dropout:
            Forwarded to ``TransformerLayerCFG.create`` for the encoder layers
            (``layer_norm_eps`` is fixed to ``1e-6``).
        encoder_num_layers, encoder_enable_nested_tensor, encoder_mask_check:
            Forwarded to ``TransformerEncoderCFG.create`` for the encoder
            stack.
        decoder_*: Same as the corresponding ``encoder_*`` arguments, applied
            to the decoder.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "wavjepa-base"
    model_size = "base"
    # Number of input channels given to the feature extractor. This is a
    # class attribute, not a constructor argument.
    in_channels: int = 1

    # Attributes that __init__ derives from the flat arguments. When a saved
    # config is reloaded they come back through **kwargs in serialized form;
    # they are discarded so the objects rebuilt from the flat values win.
    _derived_keys = (
        "encoder_cfg",
        "decoder_cfg",
        "encoder_layers_cfg",
        "decoder_layers_cfg",
        "extractor_config",
    )

    def __init__(
        self,
        extractor_layers_spec: str = "[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)]",
        extractor_dropout: float = 0.0,
        extractor_mode: str = "default",
        extractor_conv_bias: bool = False,
        extractor_depthwise: bool = False,
        encoder_d_model: int = 768,
        encoder_nhead: int = 12,
        encoder_batch_first: bool = True,
        encoder_norm_first: bool = False,
        encoder_bias: bool = True,
        encoder_mlp_ratio: float = 4.0,
        encoder_dropout: float = 0.0,
        encoder_num_layers: int = 12,
        encoder_enable_nested_tensor: bool = False,
        encoder_mask_check: bool = True,
        decoder_d_model: int = 384,
        decoder_nhead: int = 12,
        decoder_batch_first: bool = True,
        decoder_norm_first: bool = False,
        decoder_bias: bool = True,
        decoder_mlp_ratio: float = 4.0,
        decoder_dropout: float = 0.0,
        decoder_num_layers: int = 12,
        decoder_enable_nested_tensor: bool = False,
        decoder_mask_check: bool = True,
        **kwargs,
    ):
        # Keep the flat hyper-parameters as attributes. They then appear in
        # the serialized config and are fed back into __init__ on reload;
        # previously they were lost and reloading fell back to the defaults.
        self.extractor_layers_spec = extractor_layers_spec
        self.extractor_dropout = extractor_dropout
        self.extractor_mode = extractor_mode
        self.extractor_conv_bias = extractor_conv_bias
        self.extractor_depthwise = extractor_depthwise

        self.encoder_d_model = encoder_d_model
        self.encoder_nhead = encoder_nhead
        self.encoder_batch_first = encoder_batch_first
        self.encoder_norm_first = encoder_norm_first
        self.encoder_bias = encoder_bias
        self.encoder_mlp_ratio = encoder_mlp_ratio
        self.encoder_dropout = encoder_dropout
        self.encoder_num_layers = encoder_num_layers
        self.encoder_enable_nested_tensor = encoder_enable_nested_tensor
        self.encoder_mask_check = encoder_mask_check

        self.decoder_d_model = decoder_d_model
        self.decoder_nhead = decoder_nhead
        self.decoder_batch_first = decoder_batch_first
        self.decoder_norm_first = decoder_norm_first
        self.decoder_bias = decoder_bias
        self.decoder_mlp_ratio = decoder_mlp_ratio
        self.decoder_dropout = decoder_dropout
        self.decoder_num_layers = decoder_num_layers
        self.decoder_enable_nested_tensor = decoder_enable_nested_tensor
        self.decoder_mask_check = decoder_mask_check

        # Stack-level settings for the encoder and the decoder.
        self.encoder_cfg = TransformerEncoderCFG.create(
            num_layers=encoder_num_layers,
            enable_nested_tensor=encoder_enable_nested_tensor,
            mask_check=encoder_mask_check,
        )
        self.decoder_cfg = TransformerEncoderCFG.create(
            num_layers=decoder_num_layers,
            enable_nested_tensor=decoder_enable_nested_tensor,
            mask_check=decoder_mask_check,
        )

        # Per-layer settings; layer_norm_eps is fixed rather than configurable.
        self.encoder_layers_cfg = TransformerLayerCFG.create(
            d_model=encoder_d_model,
            nhead=encoder_nhead,
            batch_first=encoder_batch_first,
            norm_first=encoder_norm_first,
            bias=encoder_bias,
            mlp_ratio=encoder_mlp_ratio,
            dropout=encoder_dropout,
            layer_norm_eps=1e-6,
        )
        self.decoder_layers_cfg = TransformerLayerCFG.create(
            d_model=decoder_d_model,
            nhead=decoder_nhead,
            batch_first=decoder_batch_first,
            norm_first=decoder_norm_first,
            bias=decoder_bias,
            mlp_ratio=decoder_mlp_ratio,
            dropout=decoder_dropout,
            layer_norm_eps=1e-6,
        )

        # Keyword arguments for the feature extractor; in_channels comes from
        # the class attribute above.
        self.extractor_config = dict(
            conv_layers_spec=extractor_layers_spec,
            in_channels=self.in_channels,
            dropout=extractor_dropout,
            mode=extractor_mode,
            conv_bias=extractor_conv_bias,
            depthwise=extractor_depthwise,
        )

        # PretrainedConfig.__init__ setattr()s any leftover kwargs, which would
        # overwrite the objects built above with their serialized form when
        # loading a saved config. Drop those keys so the rebuilt objects stay.
        for key in self._derived_keys:
            kwargs.pop(key, None)

        super().__init__(**kwargs)