Source code for cvnets.modules.mobilevit_block

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import math
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F

from cvnets.layers import ConvLayer2d, get_normalization_layer
from cvnets.modules.base_module import BaseModule
from cvnets.modules.transformer import LinearAttnFFN, TransformerEncoder


class MobileViTBlock(BaseModule):
    """
    This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_

    Args:
        opts: command line arguments
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        transformer_dim (int): Input dimension to the transformer unit
        ffn_dim (int): Dimension of the FFN block
        n_transformer_blocks (Optional[int]): Number of transformer blocks. Default: 2
        head_dim (Optional[int]): Head dimension in the multi-head attention. Default: 32
        attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0
        dropout (Optional[float]): Dropout rate. Default: 0.0
        ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: 0.0
        patch_h (Optional[int]): Patch height for unfolding operation. Default: 8
        patch_w (Optional[int]): Patch width for unfolding operation. Default: 8
        transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
        conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3
        dilation (Optional[int]): Dilation rate in convolutions. Default: 1
        no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
    """

    def __init__(
        self,
        opts,
        in_channels: int,
        transformer_dim: int,
        ffn_dim: int,
        n_transformer_blocks: Optional[int] = 2,
        head_dim: Optional[int] = 32,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        patch_h: Optional[int] = 8,
        patch_w: Optional[int] = 8,
        transformer_norm_layer: Optional[str] = "layer_norm",
        conv_ksize: Optional[int] = 3,
        dilation: Optional[int] = 1,
        no_fusion: Optional[bool] = False,
        *args,
        **kwargs
    ) -> None:
        conv_3x3_in = ConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1,
            use_norm=True,
            use_act=True,
            dilation=dilation,
        )
        conv_1x1_in = ConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=transformer_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
        )
        conv_1x1_out = ConvLayer2d(
            opts=opts,
            in_channels=transformer_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            use_norm=True,
            use_act=True,
        )
        conv_3x3_out = None
        if not no_fusion:
            conv_3x3_out = ConvLayer2d(
                opts=opts,
                in_channels=2 * in_channels,
                out_channels=in_channels,
                kernel_size=conv_ksize,
                stride=1,
                use_norm=True,
                use_act=True,
            )
        super().__init__()
        self.local_rep = nn.Sequential()
        self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
        self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)

        assert transformer_dim % head_dim == 0
        num_heads = transformer_dim // head_dim

        global_rep = [
            TransformerEncoder(
                opts=opts,
                embed_dim=transformer_dim,
                ffn_latent_dim=ffn_dim,
                num_heads=num_heads,
                attn_dropout=attn_dropout,
                dropout=dropout,
                ffn_dropout=ffn_dropout,
                transformer_norm_layer=transformer_norm_layer,
            )
            for _ in range(n_transformer_blocks)
        ]
        global_rep.append(
            get_normalization_layer(
                opts=opts,
                norm_type=transformer_norm_layer,
                num_features=transformer_dim,
            )
        )
        self.global_rep = nn.Sequential(*global_rep)

        self.conv_proj = conv_1x1_out
        self.fusion = conv_3x3_out

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        self.cnn_in_dim = in_channels
        self.cnn_out_dim = transformer_dim
        self.n_heads = num_heads
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.ffn_dropout = ffn_dropout
        self.dilation = dilation
        self.n_blocks = n_transformer_blocks
        self.conv_ksize = conv_ksize

    def __repr__(self) -> str:
        repr_str = "{}(".format(self.__class__.__name__)

        repr_str += "\n\t Local representations"
        if isinstance(self.local_rep, nn.Sequential):
            for m in self.local_rep:
                repr_str += "\n\t\t {}".format(m)
        else:
            repr_str += "\n\t\t {}".format(self.local_rep)

        repr_str += "\n\t Global representations with patch size of {}x{}".format(
            self.patch_h, self.patch_w
        )
        if isinstance(self.global_rep, nn.Sequential):
            for m in self.global_rep:
                repr_str += "\n\t\t {}".format(m)
        else:
            repr_str += "\n\t\t {}".format(self.global_rep)

        if isinstance(self.conv_proj, nn.Sequential):
            for m in self.conv_proj:
                repr_str += "\n\t\t {}".format(m)
        else:
            repr_str += "\n\t\t {}".format(self.conv_proj)

        if self.fusion is not None:
            repr_str += "\n\t Feature fusion"
            if isinstance(self.fusion, nn.Sequential):
                for m in self.fusion:
                    repr_str += "\n\t\t {}".format(m)
            else:
                repr_str += "\n\t\t {}".format(self.fusion)

        repr_str += "\n)"
        return repr_str

    def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]:
        patch_w, patch_h = self.patch_w, self.patch_h
        patch_area = int(patch_w * patch_h)
        batch_size, in_channels, orig_h, orig_w = feature_map.shape

        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)

        interpolate = False
        if new_w != orig_w or new_h != orig_h:
            # Note: Padding can be done, but then it needs to be handled in the attention function.
            feature_map = F.interpolate(
                feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False
            )
            interpolate = True

        # number of patches along width and height
        num_patch_w = new_w // patch_w  # n_w
        num_patch_h = new_h // patch_h  # n_h
        num_patches = num_patch_h * num_patch_w  # N

        # [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w]
        reshaped_fm = feature_map.reshape(
            batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w
        )
        # [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w]
        transposed_fm = reshaped_fm.transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
        reshaped_fm = transposed_fm.reshape(
            batch_size, in_channels, num_patches, patch_area
        )
        # [B, C, N, P] --> [B, P, N, C]
        transposed_fm = reshaped_fm.transpose(1, 3)
        # [B, P, N, C] --> [BP, N, C]
        patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1)

        info_dict = {
            "orig_size": (orig_h, orig_w),
            "batch_size": batch_size,
            "interpolate": interpolate,
            "total_patches": num_patches,
            "num_patches_w": num_patch_w,
            "num_patches_h": num_patch_h,
        }

        return patches, info_dict

    def folding(self, patches: Tensor, info_dict: Dict) -> Tensor:
        n_dim = patches.dim()
        assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
            patches.shape
        )
        # [BP, N, C] --> [B, P, N, C]
        patches = patches.contiguous().view(
            info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
        )

        batch_size, pixels, num_patches, channels = patches.size()
        num_patch_h = info_dict["num_patches_h"]
        num_patch_w = info_dict["num_patches_w"]

        # [B, P, N, C] --> [B, C, N, P]
        patches = patches.transpose(1, 3)

        # [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w]
        feature_map = patches.reshape(
            batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w
        )
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
        feature_map = feature_map.transpose(1, 2)
        # [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        feature_map = feature_map.reshape(
            batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w
        )
        if info_dict["interpolate"]:
            feature_map = F.interpolate(
                feature_map,
                size=info_dict["orig_size"],
                mode="bilinear",
                align_corners=False,
            )
        return feature_map
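
    # A worked shape example for the unfolding/folding pair above (illustrative values,
    # not a specific model configuration):
    #   input feature map: [B, C, H, W] = [2, 64, 32, 32] with patch_h = patch_w = 8
    #   => n_h = n_w = 32 / 8 = 4, N = n_h * n_w = 16, P = p_h * p_w = 64
    #   unfolding: [2, 64, 32, 32] --> [B * P, N, C] = [128, 16, 64]
    #   folding:   [128, 16, 64]   --> [2, 64, 32, 32]
    # When H or W is not a multiple of the patch size, unfolding first resizes the map
    # and folding resizes it back to info_dict["orig_size"].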

    def forward_spatial(self, x: Tensor) -> Tensor:
        res = x

        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)

        # learn global representations
        for transformer_layer in self.global_rep:
            patches = transformer_layer(patches)

        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        fm = self.folding(patches=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        if self.fusion is not None:
            fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm

    def forward_temporal(
        self, x: Tensor, x_prev: Optional[Tensor] = None
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        res = x
        fm = self.local_rep(x)

        # convert feature map to patches
        patches, info_dict = self.unfolding(fm)

        # learn global representations
        for global_layer in self.global_rep:
            if isinstance(global_layer, TransformerEncoder):
                patches = global_layer(x=patches, x_prev=x_prev)
            else:
                patches = global_layer(patches)

        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        fm = self.folding(patches=patches, info_dict=info_dict)

        fm = self.conv_proj(fm)

        if self.fusion is not None:
            fm = self.fusion(torch.cat((res, fm), dim=1))
        return fm, patches

    def forward(
        self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if isinstance(x, Tuple) and len(x) == 2:
            # for spatio-temporal MobileViT
            return self.forward_temporal(x=x[0], x_prev=x[1])
        elif isinstance(x, Tensor):
            # for image data
            return self.forward_spatial(x)
        else:
            raise NotImplementedError
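

# A minimal, self-contained sketch of the [B, C, H, W] <--> [BP, N, C] rearrangement
# that MobileViTBlock.unfolding/folding implement, written with plain tensor ops so the
# round trip can be checked in isolation. The helper name and tensor sizes are
# illustrative assumptions, not part of the model; H and W are chosen as multiples of
# the patch size so no interpolation is needed.
def _patch_round_trip_demo() -> bool:
    b, c, h, w, p_h, p_w = 2, 16, 32, 32, 8, 8
    n_h, n_w = h // p_h, w // p_w
    num_patches, patch_area = n_h * n_w, p_h * p_w  # N and P

    x = torch.randn(b, c, h, w)

    # unfold: [B, C, H, W] --> [BP, N, C]
    patches = x.reshape(b * c * n_h, p_h, n_w, p_w).transpose(1, 2)
    patches = patches.reshape(b, c, num_patches, patch_area).transpose(1, 3)
    patches = patches.reshape(b * patch_area, num_patches, c)

    # fold: [BP, N, C] --> [B, C, H, W]
    fm = patches.reshape(b, patch_area, num_patches, c).transpose(1, 3)
    fm = fm.reshape(b * c * n_h, n_w, p_h, p_w).transpose(1, 2)
    fm = fm.reshape(b, c, h, w)

    # the rearrangement is lossless, so the round trip reproduces the input exactly
    return bool(torch.equal(x, fm))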


class MobileViTBlockv2(BaseModule):
    """
    This class defines the `MobileViTv2 <https://arxiv.org/abs/2206.02680>`_ block

    Args:
        opts: command line arguments
        in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
        attn_unit_dim (int): Input dimension to the attention unit
        ffn_multiplier (int): Expand the input dimensions by this factor in FFN. Default is 2.
        n_attn_blocks (Optional[int]): Number of attention units. Default: 2
        attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0
        dropout (Optional[float]): Dropout rate. Default: 0.0
        ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: 0.0
        patch_h (Optional[int]): Patch height for unfolding operation. Default: 8
        patch_w (Optional[int]): Patch width for unfolding operation. Default: 8
        conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3
        dilation (Optional[int]): Dilation rate in convolutions. Default: 1
        attn_norm_layer (Optional[str]): Normalization layer in the attention block. Default: layer_norm_2d
    """

    def __init__(
        self,
        opts,
        in_channels: int,
        attn_unit_dim: int,
        ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0,
        n_attn_blocks: Optional[int] = 2,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        patch_h: Optional[int] = 8,
        patch_w: Optional[int] = 8,
        conv_ksize: Optional[int] = 3,
        dilation: Optional[int] = 1,
        attn_norm_layer: Optional[str] = "layer_norm_2d",
        *args,
        **kwargs
    ) -> None:
        cnn_out_dim = attn_unit_dim

        conv_3x3_in = ConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=conv_ksize,
            stride=1,
            use_norm=True,
            use_act=True,
            dilation=dilation,
            groups=in_channels,
        )
        conv_1x1_in = ConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=cnn_out_dim,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
        )

        super(MobileViTBlockv2, self).__init__()
        self.local_rep = nn.Sequential(conv_3x3_in, conv_1x1_in)

        self.global_rep, attn_unit_dim = self._build_attn_layer(
            opts=opts,
            d_model=attn_unit_dim,
            ffn_mult=ffn_multiplier,
            n_layers=n_attn_blocks,
            attn_dropout=attn_dropout,
            dropout=dropout,
            ffn_dropout=ffn_dropout,
            attn_norm_layer=attn_norm_layer,
        )

        self.conv_proj = ConvLayer2d(
            opts=opts,
            in_channels=cnn_out_dim,
            out_channels=in_channels,
            kernel_size=1,
            stride=1,
            use_norm=True,
            use_act=False,
        )

        self.patch_h = patch_h
        self.patch_w = patch_w
        self.patch_area = self.patch_w * self.patch_h

        self.cnn_in_dim = in_channels
        self.cnn_out_dim = cnn_out_dim
        self.transformer_in_dim = attn_unit_dim
        self.dropout = dropout
        self.attn_dropout = attn_dropout
        self.ffn_dropout = ffn_dropout
        self.n_blocks = n_attn_blocks
        self.conv_ksize = conv_ksize
        self.enable_coreml_compatible_fn = getattr(
            opts, "common.enable_coreml_compatible_module", False
        )

        if self.enable_coreml_compatible_fn:
            # we set persistent to False so that these weights are not part of the model's state_dict
            self.register_buffer(
                name="unfolding_weights",
                tensor=self._compute_unfolding_weights(),
                persistent=False,
            )

    def _compute_unfolding_weights(self) -> Tensor:
        # [P_h * P_w, P_h * P_w]
        weights = torch.eye(self.patch_h * self.patch_w, dtype=torch.float)
        # [P_h * P_w, P_h * P_w] --> [P_h * P_w, 1, P_h, P_w]
        weights = weights.reshape(
            (self.patch_h * self.patch_w, 1, self.patch_h, self.patch_w)
        )
        # [P_h * P_w, 1, P_h, P_w] --> [P_h * P_w * C, 1, P_h, P_w]
        weights = weights.repeat(self.cnn_out_dim, 1, 1, 1)
        return weights

    def _build_attn_layer(
        self,
        opts,
        d_model: int,
        ffn_mult: Union[Sequence, int, float],
        n_layers: int,
        attn_dropout: float,
        dropout: float,
        ffn_dropout: float,
        attn_norm_layer: str,
        *args,
        **kwargs
    ) -> Tuple[nn.Module, int]:

        if isinstance(ffn_mult, Sequence) and len(ffn_mult) == 2:
            ffn_dims = (
                np.linspace(ffn_mult[0], ffn_mult[1], n_layers, dtype=float) * d_model
            )
        elif isinstance(ffn_mult, Sequence) and len(ffn_mult) == 1:
            ffn_dims = [ffn_mult[0] * d_model] * n_layers
        elif isinstance(ffn_mult, (int, float)):
            ffn_dims = [ffn_mult * d_model] * n_layers
        else:
            raise NotImplementedError

        # ensure that dims are multiple of 16
        ffn_dims = [int((d // 16) * 16) for d in ffn_dims]

        global_rep = [
            LinearAttnFFN(
                opts=opts,
                embed_dim=d_model,
                ffn_latent_dim=ffn_dims[block_idx],
                attn_dropout=attn_dropout,
                dropout=dropout,
                ffn_dropout=ffn_dropout,
                norm_layer=attn_norm_layer,
            )
            for block_idx in range(n_layers)
        ]
        global_rep.append(
            get_normalization_layer(
                opts=opts, norm_type=attn_norm_layer, num_features=d_model
            )
        )

        return nn.Sequential(*global_rep), d_model

    def __repr__(self) -> str:
        repr_str = "{}(".format(self.__class__.__name__)

        repr_str += "\n\t Local representations"
        if isinstance(self.local_rep, nn.Sequential):
            for m in self.local_rep:
                repr_str += "\n\t\t {}".format(m)
        else:
            repr_str += "\n\t\t {}".format(self.local_rep)

        repr_str += "\n\t Global representations with patch size of {}x{}".format(
            self.patch_h,
            self.patch_w,
        )
        if isinstance(self.global_rep, nn.Sequential):
            for m in self.global_rep:
                repr_str += "\n\t\t {}".format(m)
        else:
            repr_str += "\n\t\t {}".format(self.global_rep)

        if isinstance(self.conv_proj, nn.Sequential):
            for m in self.conv_proj:
                repr_str += "\n\t\t {}".format(m)
        else:
            repr_str += "\n\t\t {}".format(self.conv_proj)

        repr_str += "\n)"
        return repr_str
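
    # A worked example of the FFN-dimension schedule built by _build_attn_layer above
    # (the numbers are illustrative, not taken from a specific model configuration):
    #   d_model = 128, ffn_mult = (2.0, 4.0), n_layers = 4
    #   np.linspace(2.0, 4.0, 4) * 128 --> [256.0, 341.33, 426.67, 512.0]
    #   rounding down to multiples of 16 --> ffn_dims = [256, 336, 416, 512]
    # A scalar ffn_mult (e.g. 2.0) simply uses 2.0 * d_model for every layer.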

    def unfolding_pytorch(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
        batch_size, in_channels, img_h, img_w = feature_map.shape

        # [B, C, H, W] --> [B, C, P, N]
        patches = F.unfold(
            feature_map,
            kernel_size=(self.patch_h, self.patch_w),
            stride=(self.patch_h, self.patch_w),
        )
        patches = patches.reshape(
            batch_size, in_channels, self.patch_h * self.patch_w, -1
        )

        return patches, (img_h, img_w)

    def folding_pytorch(self, patches: Tensor, output_size: Tuple[int, int]) -> Tensor:
        batch_size, in_dim, patch_size, n_patches = patches.shape

        # [B, C, P, N] --> [B, C * P, N]
        patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)

        feature_map = F.fold(
            patches,
            output_size=output_size,
            kernel_size=(self.patch_h, self.patch_w),
            stride=(self.patch_h, self.patch_w),
        )

        return feature_map

    def unfolding_coreml(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
        # im2col is not implemented in CoreML, so we emulate it with a grouped conv2d
        # whose identity weights (see _compute_unfolding_weights) extract each pixel of a patch.
        # [B, C, H, W] --> [B, C, P, N]
        batch_size, in_channels, img_h, img_w = feature_map.shape

        patches = F.conv2d(
            feature_map,
            self.unfolding_weights,
            bias=None,
            stride=(self.patch_h, self.patch_w),
            padding=0,
            dilation=1,
            groups=in_channels,
        )
        patches = patches.reshape(
            batch_size, in_channels, self.patch_h * self.patch_w, -1
        )

        return patches, (img_h, img_w)

    def folding_coreml(self, patches: Tensor, output_size: Tuple[int, int]) -> Tensor:
        # col2im is not supported in CoreML, so tracing fails.
        # We therefore implement folding via pixel_shuffle to enable CoreML tracing.
        batch_size, in_dim, patch_size, n_patches = patches.shape

        n_patches_h = output_size[0] // self.patch_h
        n_patches_w = output_size[1] // self.patch_w

        feature_map = patches.reshape(
            batch_size, in_dim * self.patch_h * self.patch_w, n_patches_h, n_patches_w
        )
        assert (
            self.patch_h == self.patch_w
        ), "For CoreML, patch_h and patch_w must be the same"
        feature_map = F.pixel_shuffle(feature_map, upscale_factor=self.patch_h)

        return feature_map

    def resize_input_if_needed(self, x):
        batch_size, in_channels, orig_h, orig_w = x.shape
        if orig_h % self.patch_h != 0 or orig_w % self.patch_w != 0:
            new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
            new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
            x = F.interpolate(
                x, size=(new_h, new_w), mode="bilinear", align_corners=True
            )
        return x

    def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor:
        x = self.resize_input_if_needed(x)

        fm = self.local_rep(x)

        # convert feature map to patches
        if self.enable_coreml_compatible_fn:
            patches, output_size = self.unfolding_coreml(fm)
        else:
            patches, output_size = self.unfolding_pytorch(fm)

        # learn global representations on all patches
        patches = self.global_rep(patches)

        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        if self.enable_coreml_compatible_fn:
            fm = self.folding_coreml(patches=patches, output_size=output_size)
        else:
            fm = self.folding_pytorch(patches=patches, output_size=output_size)
        fm = self.conv_proj(fm)

        return fm

    def forward_temporal(
        self, x: Tensor, x_prev: Tensor, *args, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        x = self.resize_input_if_needed(x)

        fm = self.local_rep(x)

        # convert feature map to patches
        if self.enable_coreml_compatible_fn:
            patches, output_size = self.unfolding_coreml(fm)
        else:
            patches, output_size = self.unfolding_pytorch(fm)

        # learn global representations
        for global_layer in self.global_rep:
            if isinstance(global_layer, LinearAttnFFN):
                patches = global_layer(x=patches, x_prev=x_prev)
            else:
                patches = global_layer(patches)

        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        if self.enable_coreml_compatible_fn:
            fm = self.folding_coreml(patches=patches, output_size=output_size)
        else:
            fm = self.folding_pytorch(patches=patches, output_size=output_size)
        fm = self.conv_proj(fm)

        return fm, patches

    def forward(
        self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if isinstance(x, Tuple) and len(x) == 2:
            # for spatio-temporal data (e.g., videos)
            return self.forward_temporal(x=x[0], x_prev=x[1])
        elif isinstance(x, Tensor):
            # for image data
            return self.forward_spatial(x)
        else:
            raise NotImplementedError
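

# A minimal, self-contained sketch (independent of the classes above and of any CVNets
# opts) of two facts the MobileViTBlockv2 unfolding/folding paths rely on: (1) with
# non-overlapping patches (stride == kernel_size), F.fold exactly inverts F.unfold, and
# (2) a grouped conv2d with identity weights produces the same patches as F.unfold,
# which is the CoreML-compatible path. The helper name and tensor sizes are
# illustrative assumptions only.
def _unfolding_demo() -> None:
    b, c, h, w, p = 2, 8, 16, 16, 4  # H and W must be multiples of the patch size
    x = torch.randn(b, c, h, w)

    # PyTorch path: [B, C, H, W] --> [B, C * p * p, N] --> [B, C, P, N]
    cols = F.unfold(x, kernel_size=(p, p), stride=(p, p))
    patches = cols.reshape(b, c, p * p, -1)

    # fold exactly inverts unfold when patches do not overlap
    y = F.fold(
        patches.reshape(b, c * p * p, -1),
        output_size=(h, w),
        kernel_size=(p, p),
        stride=(p, p),
    )
    assert torch.equal(x, y)

    # CoreML path: identity weights turn a grouped conv2d into im2col
    weights = torch.eye(p * p).reshape(p * p, 1, p, p).repeat(c, 1, 1, 1)
    conv_patches = F.conv2d(x, weights, bias=None, stride=(p, p), groups=c)
    conv_patches = conv_patches.reshape(b, c, p * p, -1)
    assert torch.equal(patches, conv_patches)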