from typing import TypedDict

from torch import nn


class TransformerLayerCFG(TypedDict):
    """Configuration for a single transformer encoder layer."""

    d_model: int
    nhead: int
    batch_first: bool
    norm_first: bool
    bias: bool
    dim_feedforward: int
    dropout: float
    layer_norm_eps: float

    @classmethod
    def create(cls,
               d_model: int = 768,
               nhead: int = 12,
               batch_first: bool = True,
               norm_first: bool = False,
               bias: bool = True,
               mlp_ratio: float = 4.0,
               dropout: float = 0.0,
               layer_norm_eps: float = 1e-6) -> 'TransformerLayerCFG':
        # Factory that derives dim_feedforward from d_model * mlp_ratio instead
        # of storing the ratio itself. (Methods inside a TypedDict body work at
        # runtime in CPython, but strict PEP 589 type checkers may flag them.)
        return TransformerLayerCFG(d_model=d_model,
                                   nhead=nhead,
                                   batch_first=batch_first,
                                   norm_first=norm_first,
                                   bias=bias,
                                   dim_feedforward=int(d_model * mlp_ratio),
                                   dropout=dropout,
                                   layer_norm_eps=layer_norm_eps)
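

# Usage sketch (an assumption, not shown in this file): the fields above line
# up one-to-one with the keyword arguments of torch's nn.TransformerEncoderLayer,
# so a cfg can be unpacked straight into it (the `bias` argument needs a recent
# PyTorch release). For example:
#
#     layer_cfg = TransformerLayerCFG.create(d_model=256, nhead=8)
#     layer = nn.TransformerEncoderLayer(**layer_cfg)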


class TransformerEncoderCFG(TypedDict):
    """Configuration for a stack of transformer encoder layers."""

    num_layers: int
    enable_nested_tensor: bool
    mask_check: bool

    @classmethod
    def create(cls,
               num_layers: int = 12,
               enable_nested_tensor: bool = False,
               mask_check: bool = True) -> 'TransformerEncoderCFG':
        # Factory mirroring TransformerLayerCFG.create, with encoder-level defaults.
        return TransformerEncoderCFG(num_layers=num_layers,
                                     enable_nested_tensor=enable_nested_tensor,
                                     mask_check=mask_check)
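

# Minimal end-to-end sketch, assuming these cfgs are meant to configure torch's
# built-in encoder modules; the actual call site is not part of this file, so
# the wiring below is illustrative only.
if __name__ == "__main__":
    layer_cfg = TransformerLayerCFG.create()
    encoder_cfg = TransformerEncoderCFG.create()

    # The layer cfg maps onto nn.TransformerEncoderLayer keyword arguments,
    # the encoder cfg onto nn.TransformerEncoder keyword arguments.
    layer = nn.TransformerEncoderLayer(**layer_cfg)
    encoder = nn.TransformerEncoder(layer, **encoder_cfg)
    print(encoder)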