
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import List, Optional, Union

import torch
from torch import Size, Tensor, nn

from cvnets.layers.normalization import register_norm_fn


[docs]@register_norm_fn(name="layer_norm") class LayerNorm(nn.LayerNorm): r""" Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over a input tensor Args: normalized_shape (int or list or torch.Size): input shape from an expected input of size .. math:: [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1] \times \ldots \times \text{normalized\_shape}[-1]] If a single integer is used, it is treated as a singleton list, and this module will normalize over the last dimension which is expected to be of that specific size. eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 elementwise_affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, *)` where :math:`N` is the batch size - Output: same shape as the input """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )

    def forward(self, x: Tensor) -> Tensor:
        n_dim = x.ndim
        if x.shape[1] == self.normalized_shape[0] and n_dim > 2:
            # channel-first format
            s, u = torch.std_mean(x, dim=1, keepdim=True, unbiased=False)
            x = (x - u) / (s + self.eps)
            if self.weight is not None:
                # Using a fused operation for the affine transformation:
                # x = (x * weight) + bias
                n_dim = x.ndim - 2
                new_shape = [1, self.normalized_shape[0]] + [1] * n_dim
                x = torch.addcmul(
                    input=self.bias.reshape(new_shape),
                    value=1.0,
                    tensor1=x,
                    tensor2=self.weight.reshape(new_shape),
                )
            return x
        elif x.shape[-1] == self.normalized_shape[0]:
            # channel-last format
            return super().forward(x)
        else:
            raise NotImplementedError(
                "LayerNorm is supported for channel-first and channel-last formats only"
            )
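

# A minimal usage sketch (illustrative; not part of the original module). It
# exercises both input layouts that ``LayerNorm.forward`` supports; the tensor
# shapes below are arbitrary examples.
def _example_layer_norm() -> None:
    norm = LayerNorm(normalized_shape=64)
    # Channel-last input (N, ..., C): delegated to ``nn.LayerNorm.forward``.
    y_last = norm(torch.randn(2, 16, 64))
    # Channel-first input (N, C, H, W): handled by the custom std/mean path
    # that normalizes over dim=1.
    y_first = norm(torch.randn(2, 64, 16, 16))
    assert y_last.shape == (2, 16, 64)
    assert y_first.shape == (2, 64, 16, 16)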
[docs]@register_norm_fn(name="layer_norm_2d") @register_norm_fn(name="layer_norm_nchw") class LayerNorm2D_NCHW(nn.GroupNorm): """ Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over a 4D input tensor Args: num_features (int): :math:`C` from an expected input of size :math:`(N, C, H, W)` eps (Optional, float): Value added to the denominator for numerical stability. Default: 1e-5 elementwise_affine (bool): If ``True``, use learnable affine parameters. Default: ``True`` Shape: - Input: :math:`(N, C, H, W)` where :math:`N` is the batch size, :math:`C` is the number of input channels, :math:`H` is the input height, and :math:`W` is the input width - Output: same shape as the input """

    def __init__(
        self,
        num_features: int,
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__(
            num_channels=num_features, eps=eps, affine=elementwise_affine, num_groups=1
        )
        self.num_channels = num_features

    def __repr__(self):
        return "{}(num_channels={}, eps={}, affine={})".format(
            self.__class__.__name__, self.num_channels, self.eps, self.affine
        )
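

# A minimal usage sketch (illustrative; not part of the original module).
# ``LayerNorm2D_NCHW`` is a single-group ``nn.GroupNorm``, so statistics are
# computed jointly over the (C, H, W) dimensions of a 4D NCHW tensor.
def _example_layer_norm_2d() -> None:
    norm = LayerNorm2D_NCHW(num_features=32)
    y = norm(torch.randn(2, 32, 8, 8))  # arbitrary example shape
    assert y.shape == (2, 32, 8, 8)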
[docs]@register_norm_fn(name="layer_norm_fp32") class LayerNormFP32(LayerNorm): """ Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over a input tensor with FP32 precision """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            *args,
            **kwargs
        )

    def forward(self, x: Tensor) -> Tensor:
        # Convert the input to FP32, perform the normalization, and convert the
        # result back to the input dtype. This may help with the underflow/overflow
        # issues that we typically see in normalization layers.
        inp_dtype = x.dtype
        return super().forward(x.to(torch.float32)).to(inp_dtype)
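

# A minimal usage sketch (illustrative; not part of the original module).
# ``LayerNormFP32`` upcasts to float32 for the normalization itself and casts
# the result back, which can help under reduced-precision (e.g. FP16) training.
def _example_layer_norm_fp32() -> None:
    norm = LayerNormFP32(normalized_shape=64)
    x = torch.randn(2, 16, 64, dtype=torch.float16)
    y = norm(x)
    assert y.dtype == torch.float16  # output dtype matches the input dtype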