# Source code for cvnets.modules.fastvit

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from cvnets.layers import ConvLayer2d, MultiHeadAttention, StochasticDepth
from cvnets.layers.normalization.batch_norm import BatchNorm2d
from cvnets.modules import BaseModule
from cvnets.modules.mobileone_block import MobileOneBlock, RepLKBlock


def convolutional_stem(
    opts: argparse.Namespace, in_channels: int, out_channels: int
) -> nn.Sequential:
    """
    Build convolutional stem with MobileOne blocks.

    Args:
        opts: Command line arguments.
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Returns:
        nn.Sequential object with stem elements.
    """
    inference_mode = getattr(opts, "model.classification.fastvit.inference_mode")
    return nn.Sequential(
        MobileOneBlock(
            opts,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        ),
        MobileOneBlock(
            opts,
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        ),
        MobileOneBlock(
            opts,
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            inference_mode=inference_mode,
            use_se=False,
            num_conv_branches=1,
        ),
    )


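# Illustrative sketch (not part of the original cvnets module): the stem above
# stacks a stride-2 3x3 conv, a stride-2 depthwise 3x3 conv and a pointwise
# conv, so it reduces spatial resolution by 4x while projecting to
# `out_channels`. Plain convolutions are used here in place of the
# reparameterizable MobileOne blocks purely to show the shape arithmetic.
def _example_stem_shapes() -> None:
    in_channels, out_channels = 3, 64
    stem = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
        nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=out_channels,
        ),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
    )
    x = torch.randn(1, in_channels, 224, 224)
    assert stem(x).shape == (1, out_channels, 56, 56)  # 224 / 4 = 56

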
class PatchEmbed(BaseModule):
    """
    Convolutional Patch embedding layer.

    Args:
        opts: Command line arguments.
        patch_size: Patch size for embedding computation.
        stride: Stride for convolutional embedding layer.
        in_channels: Number of channels of input tensor.
        embed_dim: Number of embedding dimensions.
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        patch_size: int,
        stride: int,
        in_channels: int,
        embed_dim: int,
    ):
        super().__init__()
        inference_mode = getattr(opts, "model.classification.fastvit.inference_mode")
        block = list()
        block.append(
            RepLKBlock(
                opts,
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=patch_size,
                stride=stride,
                groups=in_channels,
                small_kernel_size=3,
                inference_mode=inference_mode,
            )
        )
        block.append(
            MobileOneBlock(
                opts,
                in_channels=embed_dim,
                out_channels=embed_dim,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                inference_mode=inference_mode,
                use_se=False,
                num_conv_branches=1,
            )
        )
        self.proj = nn.Sequential(*block)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor of shape :math:`(B, C, H//s, W//s)`, where `s` is the
            stride provided while instantiating the layer.
        """
        x = self.proj(x)
        return x


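# Illustrative sketch (not part of the original cvnets module): after
# reparameterization, the patch embedding above behaves like a strided
# depthwise large-kernel conv into `embed_dim` followed by a pointwise conv.
# The plain convolutions below only demonstrate the shape behaviour; the
# channel counts are assumptions chosen so that `embed_dim` is a multiple of
# `in_channels`, as required by the grouped convolution.
def _example_patch_embed_shapes() -> None:
    in_channels, embed_dim = 64, 128
    proj = nn.Sequential(
        nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=7,
            stride=2,
            padding=3,
            groups=in_channels,
        ),
        nn.Conv2d(embed_dim, embed_dim, kernel_size=1),
    )
    x = torch.randn(1, in_channels, 56, 56)
    assert proj(x).shape == (1, embed_dim, 28, 28)  # H and W halved by the stride

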
class RepMixer(BaseModule):
    """
    Reparameterizable token mixer.

    For more details, please refer to our paper:
    `FastViT: A Fast Hybrid Vision Transformer using Structural Reparameterization`

    Args:
        opts: Command line arguments.
        dim: Input feature map dimension. :math:`C_{in}` from an expected input
            of size :math:`(B, C_{in}, H, W)`.
        kernel_size: Kernel size for spatial mixing. Default: 3
        use_layer_scale: If True, learnable layer scale is used. Default: ``True``
        layer_scale_init_value: Initial value for layer scale. Default: 1e-5
        inference_mode: If True, instantiates model in inference mode. Default: ``False``
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        dim: int,
        kernel_size: int = 3,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
    ):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.dim,
                out_channels=self.dim,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                groups=self.dim,
                bias=True,
            )
        else:
            self.norm = MobileOneBlock(
                opts,
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
                use_scale_branch=False,
                num_conv_branches=0,
            )
            self.mixer = MobileOneBlock(
                opts,
                dim,
                dim,
                kernel_size,
                padding=kernel_size // 2,
                groups=dim,
                use_act=False,
            )
            self.use_layer_scale = use_layer_scale
            if use_layer_scale:
                self.layer_scale = nn.Parameter(
                    layer_scale_init_value * torch.ones((dim, 1, 1)),
                )

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass implements inference logic for module
        before and after reparameterization.

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor of shape :math:`(B, C, H, W)`.
        """
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            if self.use_layer_scale:
                x = x + self.layer_scale * (self.mixer(x) - self.norm(x))
            else:
                x = x + self.mixer(x) - self.norm(x)
            return x

    def reparameterize(self) -> None:
        """
        Reparameterize mixer and norm into a single convolutional layer
        for efficient inference.
        """
        self.mixer.reparameterize()
        self.norm.reparameterize()

        if self.use_layer_scale:
            w = self.mixer.id_tensor + self.layer_scale.unsqueeze(-1) * (
                self.mixer.reparam_conv.weight - self.norm.reparam_conv.weight
            )
            b = torch.squeeze(self.layer_scale) * (
                self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias
            )
        else:
            w = (
                self.mixer.id_tensor
                + self.mixer.reparam_conv.weight
                - self.norm.reparam_conv.weight
            )
            b = self.mixer.reparam_conv.bias - self.norm.reparam_conv.bias

        self.reparam_conv = nn.Conv2d(
            in_channels=self.dim,
            out_channels=self.dim,
            kernel_size=self.kernel_size,
            stride=1,
            padding=self.kernel_size // 2,
            groups=self.dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w
        self.reparam_conv.bias.data = b

        for para in self.parameters():
            para.detach_()
        self.__delattr__("mixer")
        self.__delattr__("norm")


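# Illustrative sketch (not part of the original cvnets module): because the
# identity, `mixer` and `norm` branches are all linear depthwise operations
# with the same kernel size, the training-time update
# x + gamma * (mixer(x) - norm(x)) folds into a single depthwise conv with
# kernel I + gamma * (W_mixer - W_norm) and bias gamma * (b_mixer - b_norm).
# Plain depthwise convs stand in for the already-reparameterized MobileOne
# branches here.
def _example_repmixer_folding() -> None:
    dim, k = 8, 3
    mixer = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
    norm = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
    gamma = 1e-1 * torch.ones(dim, 1, 1)  # stands in for the layer scale

    # Identity expressed as a depthwise kernel: a single 1 at the kernel centre.
    identity = torch.zeros(dim, 1, k, k)
    identity[:, 0, k // 2, k // 2] = 1.0

    folded = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
    folded.weight.data = identity + gamma.unsqueeze(-1) * (
        mixer.weight.data - norm.weight.data
    )
    folded.bias.data = gamma.squeeze() * (mixer.bias.data - norm.bias.data)

    x = torch.randn(2, dim, 16, 16)
    y_train = x + gamma * (mixer(x) - norm(x))
    assert torch.allclose(y_train, folded(x), atol=1e-5)

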
class ConvFFN(BaseModule):
    """
    Convolutional FFN Module.

    Args:
        opts: Command line arguments.
        in_channels: Number of input channels.
        hidden_channels: Number of channels after expansion. Default: None
        out_channels: Number of output channels. Default: None
        drop: Dropout rate. Default: ``0.0``.
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        in_channels: int,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        drop: float = 0.0,
    ):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = hidden_channels or in_channels
        self.conv = ConvLayer2d(
            opts,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=7,
            padding=3,
            groups=in_channels,
            bias=False,
            use_act=False,
        )
        self.fc1 = ConvLayer2d(
            opts, in_channels, hidden_channels, kernel_size=1, use_norm=False, bias=True
        )
        self.fc2 = ConvLayer2d(
            opts,
            hidden_channels,
            out_channels,
            kernel_size=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor of shape :math:`(B, C, H, W)`.
        """
        x = self.conv(x)
        x = self.fc1(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


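# Illustrative sketch (not part of the original cvnets module): the ConvFFN
# above is a 7x7 depthwise conv (with normalization), a 1x1 expansion with an
# activation, dropout, and a 1x1 projection back to the input width, all of
# which preserve spatial resolution. The GELU and BatchNorm below are
# assumptions; the actual choices for ConvLayer2d are configured through `opts`.
def _example_convffn_shapes() -> None:
    dim, hidden_dim = 64, 256  # hidden_dim = dim * mlp_ratio with mlp_ratio = 4.0
    ffn = nn.Sequential(
        nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim, bias=False),
        nn.BatchNorm2d(dim),
        nn.Conv2d(dim, hidden_dim, kernel_size=1),
        nn.GELU(),
        nn.Dropout(0.0),
        nn.Conv2d(hidden_dim, dim, kernel_size=1),
        nn.Dropout(0.0),
    )
    x = torch.randn(1, dim, 28, 28)
    assert ffn(x).shape == x.shape

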
class RepMixerBlock(BaseModule):
    """
    Implementation of Metaformer block with RepMixer as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_

    Args:
        opts: Command line arguments.
        dim: Number of embedding dimensions.
        kernel_size: Kernel size for repmixer. Default: 3
        mlp_ratio: MLP expansion ratio. Default: 4.0
        drop: Dropout rate. Default: 0.0
        drop_path: Drop path rate. Default: 0.0
        use_layer_scale: Flag to turn on layer scale. Default: ``True``
        layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
        inference_mode: Flag to instantiate block in inference mode. Default: ``False``
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        dim: int,
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
        inference_mode: bool = False,
    ):
        super().__init__()

        self.token_mixer = RepMixer(
            opts,
            dim=dim,
            kernel_size=kernel_size,
            use_layer_scale=use_layer_scale,
            layer_scale_init_value=layer_scale_init_value,
            inference_mode=inference_mode,
        )

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            opts, in_channels=dim, hidden_channels=hidden_dim, drop=drop
        )

        self.drop_path = (
            StochasticDepth(drop_path, mode="row") if drop_path > 0.0 else nn.Identity()
        )

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)),
            )

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor of shape :math:`(B, C, H, W)`.
        """
        if self.use_layer_scale:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.layer_scale * self.convffn(x))
        else:
            x = self.token_mixer(x)
            x = x + self.drop_path(self.convffn(x))
        return x


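# Illustrative sketch (not part of the original cvnets module): the `drop_path`
# used above is stochastic depth in "row" mode, i.e. during training the FFN
# residual is zeroed independently for each sample with probability `drop_path`
# and the surviving samples are rescaled by 1 / (1 - p). The helper below is a
# hypothetical, simplified re-implementation for illustration only; the module
# itself relies on the imported StochasticDepth layer.
def _example_row_stochastic_depth(residual: torch.Tensor, p: float) -> torch.Tensor:
    if p == 0.0:
        return residual
    keep_prob = 1.0 - p
    # One Bernoulli draw per sample, broadcast over channel and spatial dims.
    mask_shape = (residual.shape[0],) + (1,) * (residual.dim() - 1)
    mask = torch.bernoulli(
        torch.full(mask_shape, keep_prob, device=residual.device, dtype=residual.dtype)
    )
    return residual * mask / keep_prob

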
class AttentionBlock(BaseModule):
    """
    Implementation of Metaformer block with MHSA as token mixer.

    For more details on Metaformer structure, please refer to:
    `MetaFormer Is Actually What You Need for Vision <https://arxiv.org/pdf/2111.11418.pdf>`_

    Args:
        opts: Command line arguments.
        dim: Number of embedding dimensions.
        mlp_ratio: MLP expansion ratio. Default: 4.0
        drop: Dropout rate. Default: 0.0
        drop_path: Drop path rate. Default: 0.0
        use_layer_scale: Flag to turn on layer scale. Default: ``True``
        layer_scale_init_value: Layer scale value at initialization. Default: 1e-5
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        dim: int,
        mlp_ratio: float = 4.0,
        drop: float = 0.0,
        drop_path: float = 0.0,
        use_layer_scale: bool = True,
        layer_scale_init_value: float = 1e-5,
    ):
        super().__init__()

        self.norm = BatchNorm2d(num_features=dim)
        self.head_dim = 32
        num_heads = dim // self.head_dim
        self.token_mixer = MultiHeadAttention(
            embed_dim=dim, num_heads=num_heads, bias=False
        )

        assert mlp_ratio > 0, "MLP ratio should be greater than 0, found: {}".format(
            mlp_ratio
        )
        hidden_dim = int(dim * mlp_ratio)
        self.convffn = ConvFFN(
            opts, in_channels=dim, hidden_channels=hidden_dim, drop=drop
        )

        self.drop_path = (
            StochasticDepth(drop_path, mode="row") if drop_path > 0.0 else nn.Identity()
        )

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)),
            )
            self.layer_scale_2 = nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)),
            )

    def _apply_mhsa(self, x: torch.Tensor) -> torch.Tensor:
        """
        Perform appropriate reshaping before and after MHSA block.

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor of shape :math:`(B, C, H, W)`.
        """
        x_norm = self.norm(x)
        B, C, H, W = x_norm.shape
        x_norm_reshaped = torch.flatten(x_norm, start_dim=2).transpose(
            -2, -1
        )  # (B, N, C), where N = H * W
        out = self.token_mixer(x_norm_reshaped)
        out = out.transpose(-2, -1).reshape(B, C, H, W)
        return out

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor output from the attention block.
        """
        if self.use_layer_scale:
            x = x + self.drop_path(self.layer_scale_1 * self._apply_mhsa(x))
            x = x + self.drop_path(self.layer_scale_2 * self.convffn(x))
        else:
            x = x + self.drop_path(self._apply_mhsa(x))
            x = x + self.drop_path(self.convffn(x))
        return x


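# Illustrative sketch (not part of the original cvnets module): the attention
# token mixer operates on flattened spatial positions, so feature maps are
# reshaped from (B, C, H, W) to (B, N, C) with N = H * W before attention and
# back afterwards. torch.nn.MultiheadAttention stands in for the cvnets
# MultiHeadAttention layer here; the head width of 32 mirrors `self.head_dim`
# above.
def _example_mhsa_reshape() -> None:
    B, C, H, W = 2, 64, 14, 14
    x = torch.randn(B, C, H, W)
    attn = nn.MultiheadAttention(embed_dim=C, num_heads=C // 32, batch_first=True)

    tokens = torch.flatten(x, start_dim=2).transpose(-2, -1)  # (B, N, C)
    out, _ = attn(tokens, tokens, tokens, need_weights=False)
    out = out.transpose(-2, -1).reshape(B, C, H, W)
    assert out.shape == x.shape

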
class RepCPE(BaseModule):
    """
    Implementation of reparameterizable conditional positional encoding.

    For more details, refer to the paper:
    `Conditional Positional Encodings for Vision Transformers <https://arxiv.org/pdf/2102.10882.pdf>`_

    Args:
        opts: Command line arguments.
        in_channels: Number of input channels.
        embed_dim: Number of embedding dimensions. Default: 768
        spatial_shape: Spatial shape of kernel for positional encoding. Default: (7, 7)
        inference_mode: Flag to instantiate block in inference mode. Default: ``False``
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        in_channels: int,
        embed_dim: int = 768,
        spatial_shape: Union[int, Tuple[int, int]] = (7, 7),
        inference_mode: bool = False,
    ):
        super(RepCPE, self).__init__()
        if isinstance(spatial_shape, int):
            spatial_shape = tuple([spatial_shape] * 2)
        assert isinstance(spatial_shape, Tuple), (
            f'"spatial_shape" must be a sequence or int, '
            f"got {type(spatial_shape)} instead."
        )
        assert len(spatial_shape) == 2, (
            f'Length of "spatial_shape" should be 2, '
            f"got {len(spatial_shape)} instead."
        )

        self.spatial_shape = spatial_shape
        self.embed_dim = embed_dim
        self.in_channels = in_channels
        self.groups = embed_dim

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.embed_dim,
                kernel_size=self.spatial_shape,
                stride=1,
                padding=int(self.spatial_shape[0] // 2),
                groups=self.embed_dim,
                bias=True,
            )
        else:
            self.pe = ConvLayer2d(
                opts,
                in_channels=in_channels,
                out_channels=embed_dim,
                kernel_size=spatial_shape,
                stride=1,
                padding=int(spatial_shape[0] // 2),
                use_norm=False,
                use_act=False,
                bias=True,
                groups=embed_dim,
            )

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Forward pass implements inference logic for module
        before and after reparameterization.

        Args:
            x: Input tensor of shape :math:`(B, C, H, W)`.

        Returns:
            torch.Tensor of shape :math:`(B, C, H, W)`.
        """
        if hasattr(self, "reparam_conv"):
            x = self.reparam_conv(x)
            return x
        else:
            x = self.pe(x) + x
            return x

    def reparameterize(self) -> None:
        """Reparameterize linear branches."""
        # Build equivalent Id tensor
        input_dim = self.in_channels // self.groups
        kernel_value = torch.zeros(
            (
                self.in_channels,
                input_dim,
                self.spatial_shape[0],
                self.spatial_shape[1],
            ),
            dtype=self.pe.block.conv.weight.dtype,
            device=self.pe.block.conv.weight.device,
        )
        for i in range(self.in_channels):
            kernel_value[
                i,
                i % input_dim,
                self.spatial_shape[0] // 2,
                self.spatial_shape[1] // 2,
            ] = 1
        id_tensor = kernel_value

        # Reparameterize Id tensor and conv
        w_final = id_tensor + self.pe.block.conv.weight
        b_final = self.pe.block.conv.bias

        # Introduce reparam conv
        self.reparam_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embed_dim,
            kernel_size=self.spatial_shape,
            stride=1,
            padding=int(self.spatial_shape[0] // 2),
            groups=self.embed_dim,
            bias=True,
        )
        self.reparam_conv.weight.data = w_final
        self.reparam_conv.bias.data = b_final

        for para in self.parameters():
            para.detach_()
        self.__delattr__("pe")


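# Illustrative sketch (not part of the original cvnets module): the folding in
# `RepCPE.reparameterize` relies on the fact that adding the input back to a
# depthwise conv output equals running a single conv whose kernel is the
# original kernel plus a centred identity kernel. A plain depthwise conv with
# equal input and output channels stands in for `self.pe` here.
def _example_repcpe_folding() -> None:
    dim, k = 8, 7
    pe = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)

    # Depthwise identity kernel: a single 1 at the centre of each filter.
    identity = torch.zeros(dim, 1, k, k)
    identity[:, 0, k // 2, k // 2] = 1.0

    folded = nn.Conv2d(dim, dim, k, padding=k // 2, groups=dim, bias=True)
    folded.weight.data = pe.weight.data + identity
    folded.bias.data = pe.bias.data.clone()

    x = torch.randn(2, dim, 16, 16)
    assert torch.allclose(pe(x) + x, folded(x), atol=1e-5)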