| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
import copy
import warnings
from typing import Any, Optional

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
|
|
# Names exported by ``from <module> import *``. RMSNorm and
# REGISTERED_NORMALIZATION_DICT are defined here but not listed.
__all__ = [
    "LayerNorm2d",
    "build_norm",
    "get_norm_name",
    "reset_bn",
    "remove_bn",
    "set_norm_eps",
]
|
|
|
|
class LayerNorm2d(nn.LayerNorm):
    """Layer normalization across dim 1 (channels) of a channel-first tensor.

    Mean and mean-square are taken over dim 1 only, so each spatial position is
    normalized independently. The per-channel affine parameters are reshaped to
    ``(1, C, 1, 1)``, so a 4-D ``(N, C, H, W)`` input is expected.
    """

    # Global switch: ``forward`` reads it through the class (``LayerNorm2d.rmsnorm``),
    # so setting it on an instance does not change the computation. When True the
    # mean subtraction is skipped and the layer behaves like an RMS norm.
    rmsnorm = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over dim 1 and apply the optional affine transform."""
        out = x if LayerNorm2d.rmsnorm else x - torch.mean(x, dim=1, keepdim=True)
        out = out / torch.sqrt(torch.square(out).mean(dim=1, keepdim=True) + self.eps)
        if self.elementwise_affine:
            out = out * self.weight.view(1, -1, 1, 1)
            # nn.LayerNorm(..., bias=False) (newer PyTorch) registers ``bias`` as None.
            if self.bias is not None:
                out = out + self.bias.view(1, -1, 1, 1)
        return out

    def extra_repr(self) -> str:
        return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, rmsnorm={self.rmsnorm}"
|
|
|
|
| |
| |
# Registry used by build_norm / get_norm_name: short name -> (normalization class,
# default constructor kwargs). The ``None`` placeholders (num_features /
# normalized_shape) are always overwritten by build_norm from its ``num_features``.
REGISTERED_NORMALIZATION_DICT: dict[str, tuple[type, dict[str, Any]]] = {
    "bn2d": (nn.BatchNorm2d, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}),
    "syncbn": (nn.SyncBatchNorm, {"num_features": None, "eps": 1e-5, "momentum": 0.1, "affine": True}),
    "ln": (nn.LayerNorm, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}),
    "ln2d": (LayerNorm2d, {"normalized_shape": None, "eps": 1e-5, "elementwise_affine": True}),
}
|
|
|
|
def build_norm(name="bn2d", num_features=None, affine=True, **kwargs) -> Optional[nn.Module]:
    """Instantiate a registered normalization layer by its short name.

    Args:
        name: Key of ``REGISTERED_NORMALIZATION_DICT`` (e.g. "bn2d", "ln2d"), or
            ``None`` / any casing of "none" to request no normalization.
        num_features: Channel count; passed as ``normalized_shape`` for "ln"/"ln2d"
            and as ``num_features`` for every other name.
        affine: Passed as ``elementwise_affine`` for "ln"/"ln2d" and as ``affine``
            for every other name.
        **kwargs: Overrides for the registered default kwargs. Keys that the
            registered defaults do not list are silently ignored.

    Returns:
        The constructed module, or ``None`` when ``name`` means "no normalization".

    Raises:
        ValueError: if ``name`` is neither registered nor "none".
    """
    if name is None or name.lower() == "none":
        return None

    # LayerNorm variants use different constructor argument names than BatchNorm.
    if name in ("ln", "ln2d"):
        kwargs["normalized_shape"] = num_features
        kwargs["elementwise_affine"] = affine
    else:
        kwargs["num_features"] = num_features
        kwargs["affine"] = affine

    if name not in REGISTERED_NORMALIZATION_DICT:
        raise ValueError("do not support: %s" % name)

    norm_cls, default_args = REGISTERED_NORMALIZATION_DICT[name]
    # Build a fresh kwargs dict; the shared registry entry is never mutated.
    init_args = {key: kwargs.get(key, value) for key, value in default_args.items()}
    return norm_cls(**init_args)
|
|
|
|
def get_norm_name(norm: Optional[nn.Module]) -> Optional[str]:
    """Return the registry key whose class matches ``norm``'s class name.

    Args:
        norm: A normalization module, or ``None``.

    Returns:
        ``None`` if ``norm`` is ``None``; otherwise the key of
        ``REGISTERED_NORMALIZATION_DICT`` whose class ``__name__`` equals
        ``type(norm).__name__``, or ``"unknown"`` if there is no match.
    """
    if norm is None:
        return None
    # Matching is by class *name*, not class identity; if two registered classes
    # shared a name, the later registry entry would win.
    class_name_to_key = {norm_cls.__name__: key for key, (norm_cls, _) in REGISTERED_NORMALIZATION_DICT.items()}
    return class_name_to_key.get(type(norm).__name__, "unknown")
|
|
|
|
def reset_bn(
    model: nn.Module,
    data_loader: list,
    sync=True,
    progress_bar=False,
) -> None:
    """Re-estimate the BatchNorm running statistics of ``model`` from data.

    A deep copy of ``model`` is made, and each ``_BatchNorm`` layer in the copy
    gets a patched ``forward``. The patched forward records the per-channel batch
    mean/variance and normalizes with those batch statistics. The copy is run over
    every item of ``data_loader`` under ``torch.inference_mode()``. The averaged
    statistics are then copied into ``running_mean`` / ``running_var`` of the
    matching layers of the original ``model``.

    Args:
        model: Network whose BatchNorm buffers are updated in place.
        data_loader: Sized iterable of input tensors. Each item is moved to the
            model's device and fed directly to the model.
        sync: If True, per-process statistics are gathered with
            ``sync_tensor(..., reduce="cat")`` and averaged before being recorded
            (presumably for distributed runs).
        progress_bar: Show a tqdm bar; it is always hidden on non-master processes.
    """
    import torch.nn.functional as F
    from packages.apps.utils import AverageMeter, is_master, sync_tensor
    from packages.models.utils import get_device, list_join
    from tqdm import tqdm

    bn_mean = {}
    bn_var = {}

    tmp_model = copy.deepcopy(model)
    for name, m in tmp_model.named_modules():
        if isinstance(m, _BatchNorm):
            bn_mean[name] = AverageMeter(is_distributed=False)
            bn_var[name] = AverageMeter(is_distributed=False)

            def new_forward(bn, mean_est, var_est):
                def lambda_forward(x):
                    x = x.contiguous()
                    # Reduce over every dim except channels (dim 1), so that
                    # BatchNorm1d/2d/3d inputs are all handled, not only 4-D ones.
                    reduce_dims = [0, *range(2, x.dim())]

                    batch_mean = x.mean(dim=reduce_dims, keepdim=True)
                    if sync:
                        batch_mean = sync_tensor(batch_mean, reduce="cat")
                        batch_mean = torch.mean(batch_mean, dim=0, keepdim=True)

                    # Variance is taken around the (possibly synced) mean above.
                    batch_var = ((x - batch_mean) * (x - batch_mean)).mean(dim=reduce_dims, keepdim=True)
                    if sync:
                        batch_var = sync_tensor(batch_var, reduce="cat")
                        batch_var = torch.mean(batch_var, dim=0, keepdim=True)

                    # flatten (not squeeze) keeps a 1-D (C,) tensor even when C == 1.
                    batch_mean = batch_mean.flatten()
                    batch_var = batch_var.flatten()

                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # Normalize with the batch statistics. Parameters are sliced to
                    # the observed channel count, presumably so that layers wider than
                    # the input still work. affine=False layers have weight/bias None.
                    feature_dim = batch_mean.shape[0]
                    weight = bn.weight[:feature_dim] if bn.weight is not None else None
                    bias = bn.bias[:feature_dim] if bn.bias is not None else None
                    return F.batch_norm(x, batch_mean, batch_var, weight, bias, False, 0.0, bn.eps)

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    # Nothing to recalibrate when the network has no BatchNorm layers.
    if len(bn_mean) == 0:
        return

    tmp_model.eval()
    with torch.inference_mode():
        with tqdm(total=len(data_loader), desc="reset bn", disable=not progress_bar or not is_master()) as t:
            for images in data_loader:
                images = images.to(get_device(tmp_model))
                tmp_model(images)
                t.set_postfix(
                    {
                        "bs": images.size(0),
                        "res": list_join(images.shape[-2:], "x"),
                    }
                )
                t.update()

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            assert isinstance(m, _BatchNorm)
            # Layers built with track_running_stats=False have no running buffers.
            if m.running_mean is None or m.running_var is None:
                continue
            feature_dim = bn_mean[name].avg.size(0)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)
|
|
|
|
def remove_bn(model: nn.Module) -> None:
    """Turn every BatchNorm layer inside ``model`` into an identity op, in place.

    Each layer's ``weight`` and ``bias`` are set to ``None``, and its ``forward``
    is replaced on the instance with a function that returns its input unchanged.
    """
    batch_norms = [module for module in model.modules() if isinstance(module, _BatchNorm)]
    for bn in batch_norms:
        bn.weight = None
        bn.bias = None
        bn.forward = lambda x: x
|
|
|
|
def set_norm_eps(model: nn.Module, eps: Optional[float] = None, momentum: Optional[float] = None) -> None:
    """Overwrite ``eps`` and/or ``momentum`` on every normalization layer in place.

    Applies to ``nn.GroupNorm``, ``nn.LayerNorm`` (including subclasses) and all
    ``_BatchNorm`` layers found by ``model.modules()``.

    Args:
        model: Network to modify.
        eps: New epsilon; left unchanged when ``None``.
        momentum: New momentum; left unchanged when ``None``.
    """
    norm_types = (nn.GroupNorm, nn.LayerNorm, _BatchNorm)
    for m in model.modules():
        if not isinstance(m, norm_types):
            continue
        if eps is not None:
            m.eps = eps
        if momentum is not None:
            # PyTorch's GroupNorm/LayerNorm have no momentum parameter; for them this
            # only sets an extra attribute. It takes effect for BatchNorm layers.
            m.momentum = momentum
|
|
|
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization over the last dimension with a learnable gain.

    Args:
        dim: Size of the input's last dimension (the length of ``weight``).
        scale_factor: Initial value of every entry of ``weight``. Default 1.0.
        eps: Added to the mean square before the reciprocal square root, for
            numerical stability. Default 1e-6.

    Attributes:
        eps (float): The stability constant.
        weight (nn.Parameter): Per-feature gain of shape ``(dim,)``.
    """

    def __init__(self, dim: int, scale_factor=1.0, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        initial_gain = torch.ones(dim) * scale_factor
        self.weight = nn.Parameter(initial_gain)

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize in float32, apply ``weight``, then cast back to ``x``'s dtype."""
        normalized = self._norm(x.float())
        scaled = self.weight * normalized
        return scaled.type_as(x)
|
|