#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import math
from typing import Dict, Optional, Sequence, Tuple, Union
import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F
from cvnets.layers import ConvLayer2d, get_normalization_layer
from cvnets.modules.base_module import BaseModule
from cvnets.modules.transformer import LinearAttnFFN, TransformerEncoder
class MobileViTBlock(BaseModule):
"""
This class defines the `MobileViT block <https://arxiv.org/abs/2110.02178?context=cs.LG>`_
Args:
opts: command line arguments
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
transformer_dim (int): Input dimension to the transformer unit
ffn_dim (int): Dimension of the FFN block
n_transformer_blocks (Optional[int]): Number of transformer blocks. Default: 2
head_dim (Optional[int]): Head dimension in the multi-head attention. Default: 32
attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0
dropout (Optional[float]): Dropout rate. Default: 0.0
ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: 0.0
patch_h (Optional[int]): Patch height for unfolding operation. Default: 8
patch_w (Optional[int]): Patch width for unfolding operation. Default: 8
transformer_norm_layer (Optional[str]): Normalization layer in the transformer block. Default: layer_norm
conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3
dilation (Optional[int]): Dilation rate in convolutions. Default: 1
no_fusion (Optional[bool]): Do not combine the input and output feature maps. Default: False
"""
def __init__(
self,
opts,
in_channels: int,
transformer_dim: int,
ffn_dim: int,
n_transformer_blocks: Optional[int] = 2,
head_dim: Optional[int] = 32,
attn_dropout: Optional[float] = 0.0,
dropout: Optional[float] = 0.0,
ffn_dropout: Optional[float] = 0.0,
patch_h: Optional[int] = 8,
patch_w: Optional[int] = 8,
transformer_norm_layer: Optional[str] = "layer_norm",
conv_ksize: Optional[int] = 3,
dilation: Optional[int] = 1,
no_fusion: Optional[bool] = False,
*args,
**kwargs
) -> None:
conv_3x3_in = ConvLayer2d(
opts=opts,
in_channels=in_channels,
out_channels=in_channels,
kernel_size=conv_ksize,
stride=1,
use_norm=True,
use_act=True,
dilation=dilation,
)
conv_1x1_in = ConvLayer2d(
opts=opts,
in_channels=in_channels,
out_channels=transformer_dim,
kernel_size=1,
stride=1,
use_norm=False,
use_act=False,
)
conv_1x1_out = ConvLayer2d(
opts=opts,
in_channels=transformer_dim,
out_channels=in_channels,
kernel_size=1,
stride=1,
use_norm=True,
use_act=True,
)
conv_3x3_out = None
if not no_fusion:
conv_3x3_out = ConvLayer2d(
opts=opts,
in_channels=2 * in_channels,
out_channels=in_channels,
kernel_size=conv_ksize,
stride=1,
use_norm=True,
use_act=True,
)
super().__init__()
self.local_rep = nn.Sequential()
self.local_rep.add_module(name="conv_3x3", module=conv_3x3_in)
self.local_rep.add_module(name="conv_1x1", module=conv_1x1_in)
assert transformer_dim % head_dim == 0, "transformer_dim must be divisible by head_dim"
num_heads = transformer_dim // head_dim
global_rep = [
TransformerEncoder(
opts=opts,
embed_dim=transformer_dim,
ffn_latent_dim=ffn_dim,
num_heads=num_heads,
attn_dropout=attn_dropout,
dropout=dropout,
ffn_dropout=ffn_dropout,
transformer_norm_layer=transformer_norm_layer,
)
for _ in range(n_transformer_blocks)
]
global_rep.append(
get_normalization_layer(
opts=opts,
norm_type=transformer_norm_layer,
num_features=transformer_dim,
)
)
self.global_rep = nn.Sequential(*global_rep)
self.conv_proj = conv_1x1_out
self.fusion = conv_3x3_out
self.patch_h = patch_h
self.patch_w = patch_w
self.patch_area = self.patch_w * self.patch_h
self.cnn_in_dim = in_channels
self.cnn_out_dim = transformer_dim
self.n_heads = num_heads
self.ffn_dim = ffn_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.ffn_dropout = ffn_dropout
self.dilation = dilation
self.n_blocks = n_transformer_blocks
self.conv_ksize = conv_ksize
def __repr__(self) -> str:
repr_str = "{}(".format(self.__class__.__name__)
repr_str += "\n\t Local representations"
if isinstance(self.local_rep, nn.Sequential):
for m in self.local_rep:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.local_rep)
repr_str += "\n\t Global representations with patch size of {}x{}".format(
self.patch_h, self.patch_w
)
if isinstance(self.global_rep, nn.Sequential):
for m in self.global_rep:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.global_rep)
if isinstance(self.conv_proj, nn.Sequential):
for m in self.conv_proj:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.conv_proj)
if self.fusion is not None:
repr_str += "\n\t Feature fusion"
if isinstance(self.fusion, nn.Sequential):
for m in self.fusion:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.fusion)
repr_str += "\n)"
return repr_str
def unfolding(self, feature_map: Tensor) -> Tuple[Tensor, Dict]:
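"""
Convert a feature map into flattened, non-overlapping patches for the transformer.

The input of shape [B, C, H, W] is bilinearly resized if H or W is not a multiple
of the patch size, then rearranged into a tensor of shape [B * P, N, C], where
P = patch_h * patch_w and N is the number of patches. An info dict with the
original size and patch counts is returned so that `folding` can invert the
operation.
"""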
patch_w, patch_h = self.patch_w, self.patch_h
patch_area = int(patch_w * patch_h)
batch_size, in_channels, orig_h, orig_w = feature_map.shape
new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
interpolate = False
if new_w != orig_w or new_h != orig_h:
# Note: Padding can be done, but then it needs to be handled in attention function.
feature_map = F.interpolate(
feature_map, size=(new_h, new_w), mode="bilinear", align_corners=False
)
interpolate = True
# number of patches along width and height
num_patch_w = new_w // patch_w # n_w
num_patch_h = new_h // patch_h # n_h
num_patches = num_patch_h * num_patch_w # N
# [B, C, H, W] --> [B * C * n_h, p_h, n_w, p_w]
reshaped_fm = feature_map.reshape(
batch_size * in_channels * num_patch_h, patch_h, num_patch_w, patch_w
)
# [B * C * n_h, p_h, n_w, p_w] --> [B * C * n_h, n_w, p_h, p_w]
transposed_fm = reshaped_fm.transpose(1, 2)
# [B * C * n_h, n_w, p_h, p_w] --> [B, C, N, P] where P = p_h * p_w and N = n_h * n_w
reshaped_fm = transposed_fm.reshape(
batch_size, in_channels, num_patches, patch_area
)
# [B, C, N, P] --> [B, P, N, C]
transposed_fm = reshaped_fm.transpose(1, 3)
# [B, P, N, C] --> [BP, N, C]
patches = transposed_fm.reshape(batch_size * patch_area, num_patches, -1)
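# Illustrative shapes (hypothetical numbers): with B=1, C=64, H=W=32 and the
# default 8x8 patches, n_h = n_w = 4, N = 16 and P = 64, so `patches` has
# shape [B * P, N, C] = [64, 16, 64].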
info_dict = {
"orig_size": (orig_h, orig_w),
"batch_size": batch_size,
"interpolate": interpolate,
"total_patches": num_patches,
"num_patches_w": num_patch_w,
"num_patches_h": num_patch_h,
}
return patches, info_dict
def folding(self, patches: Tensor, info_dict: Dict) -> Tensor:
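"""
Inverse of `unfolding`: reassemble patches of shape [B * P, N, C] into a
feature map of shape [B, C, H, W], resizing back to the original spatial size
if the input was interpolated during unfolding.
"""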
n_dim = patches.dim()
assert n_dim == 3, "Tensor should be of shape BPxNxC. Got: {}".format(
patches.shape
)
# [BP, N, C] --> [B, P, N, C]
patches = patches.contiguous().view(
info_dict["batch_size"], self.patch_area, info_dict["total_patches"], -1
)
batch_size, pixels, num_patches, channels = patches.size()
num_patch_h = info_dict["num_patches_h"]
num_patch_w = info_dict["num_patches_w"]
# [B, P, N, C] --> [B, C, N, P]
patches = patches.transpose(1, 3)
# [B, C, N, P] --> [B*C*n_h, n_w, p_h, p_w]
feature_map = patches.reshape(
batch_size * channels * num_patch_h, num_patch_w, self.patch_h, self.patch_w
)
# [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w]
feature_map = feature_map.transpose(1, 2)
# [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
feature_map = feature_map.reshape(
batch_size, channels, num_patch_h * self.patch_h, num_patch_w * self.patch_w
)
if info_dict["interpolate"]:
feature_map = F.interpolate(
feature_map,
size=info_dict["orig_size"],
mode="bilinear",
align_corners=False,
)
return feature_map
def forward_spatial(self, x: Tensor) -> Tensor:
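"""Apply the local conv block, run transformer layers on unfolded patches, fold back, project, and optionally fuse with the input."""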
res = x
fm = self.local_rep(x)
# convert feature map to patches
patches, info_dict = self.unfolding(fm)
# learn global representations
for transformer_layer in self.global_rep:
patches = transformer_layer(patches)
# fold patches [B * P, N, C] back into a feature map of shape [B, C, H, W]
fm = self.folding(patches=patches, info_dict=info_dict)
fm = self.conv_proj(fm)
if self.fusion is not None:
fm = self.fusion(torch.cat((res, fm), dim=1))
return fm
def forward_temporal(
self, x: Tensor, x_prev: Optional[Tensor] = None
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
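"""
Variant of `forward_spatial` for spatio-temporal inputs: TransformerEncoder
layers additionally attend to `x_prev` (patches from a previous time step).
Returns the fused feature map together with the current patches.
"""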
res = x
fm = self.local_rep(x)
# convert feature map to patches
patches, info_dict = self.unfolding(fm)
# learn global representations
for global_layer in self.global_rep:
if isinstance(global_layer, TransformerEncoder):
patches = global_layer(x=patches, x_prev=x_prev)
else:
patches = global_layer(patches)
# fold patches [B * P, N, C] back into a feature map of shape [B, C, H, W]
fm = self.folding(patches=patches, info_dict=info_dict)
fm = self.conv_proj(fm)
if self.fusion is not None:
fm = self.fusion(torch.cat((res, fm), dim=1))
return fm, patches
def forward(
self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
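"""Dispatch to `forward_temporal` for an (x, x_prev) tuple or to `forward_spatial` for a single tensor."""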
if isinstance(x, tuple) and len(x) == 2:
# for spatio-temporal MobileViT
return self.forward_temporal(x=x[0], x_prev=x[1])
elif isinstance(x, Tensor):
# For image data
return self.forward_spatial(x)
else:
raise NotImplementedError
class MobileViTBlockv2(BaseModule):
"""
This class defines the `MobileViTv2 <https://arxiv.org/abs/2206.02680>`_ block
Args:
opts: command line arguments
in_channels (int): :math:`C_{in}` from an expected input of size :math:`(N, C_{in}, H, W)`
attn_unit_dim (int): Input dimension to the attention unit
ffn_multiplier (Optional[Union[Sequence[Union[int, float]], int, float]]): Expand the input dimensions by this factor in the FFN. If a sequence of two values is given, the FFN dimension is linearly interpolated across the attention blocks. Default: 2.0
n_attn_blocks (Optional[int]): Number of attention units. Default: 2
attn_dropout (Optional[float]): Dropout in multi-head attention. Default: 0.0
dropout (Optional[float]): Dropout rate. Default: 0.0
ffn_dropout (Optional[float]): Dropout between FFN layers in transformer. Default: 0.0
patch_h (Optional[int]): Patch height for unfolding operation. Default: 8
patch_w (Optional[int]): Patch width for unfolding operation. Default: 8
conv_ksize (Optional[int]): Kernel size to learn local representations in MobileViT block. Default: 3
dilation (Optional[int]): Dilation rate in convolutions. Default: 1
attn_norm_layer (Optional[str]): Normalization layer in the attention block. Default: layer_norm_2d
"""
def __init__(
self,
opts,
in_channels: int,
attn_unit_dim: int,
ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0,
n_attn_blocks: Optional[int] = 2,
attn_dropout: Optional[float] = 0.0,
dropout: Optional[float] = 0.0,
ffn_dropout: Optional[float] = 0.0,
patch_h: Optional[int] = 8,
patch_w: Optional[int] = 8,
conv_ksize: Optional[int] = 3,
dilation: Optional[int] = 1,
attn_norm_layer: Optional[str] = "layer_norm_2d",
*args,
**kwargs
) -> None:
cnn_out_dim = attn_unit_dim
conv_3x3_in = ConvLayer2d(
opts=opts,
in_channels=in_channels,
out_channels=in_channels,
kernel_size=conv_ksize,
stride=1,
use_norm=True,
use_act=True,
dilation=dilation,
groups=in_channels,
)
conv_1x1_in = ConvLayer2d(
opts=opts,
in_channels=in_channels,
out_channels=cnn_out_dim,
kernel_size=1,
stride=1,
use_norm=False,
use_act=False,
)
super(MobileViTBlockv2, self).__init__()
self.local_rep = nn.Sequential(conv_3x3_in, conv_1x1_in)
self.global_rep, attn_unit_dim = self._build_attn_layer(
opts=opts,
d_model=attn_unit_dim,
ffn_mult=ffn_multiplier,
n_layers=n_attn_blocks,
attn_dropout=attn_dropout,
dropout=dropout,
ffn_dropout=ffn_dropout,
attn_norm_layer=attn_norm_layer,
)
self.conv_proj = ConvLayer2d(
opts=opts,
in_channels=cnn_out_dim,
out_channels=in_channels,
kernel_size=1,
stride=1,
use_norm=True,
use_act=False,
)
self.patch_h = patch_h
self.patch_w = patch_w
self.patch_area = self.patch_w * self.patch_h
self.cnn_in_dim = in_channels
self.cnn_out_dim = cnn_out_dim
self.transformer_in_dim = attn_unit_dim
self.dropout = dropout
self.attn_dropout = attn_dropout
self.ffn_dropout = ffn_dropout
self.n_blocks = n_attn_blocks
self.conv_ksize = conv_ksize
self.enable_coreml_compatible_fn = getattr(
opts, "common.enable_coreml_compatible_module", False
)
if self.enable_coreml_compatible_fn:
# we set persistent to False so that these weights are not part of the model's state_dict
self.register_buffer(
name="unfolding_weights",
tensor=self._compute_unfolding_weights(),
persistent=False,
)
def _compute_unfolding_weights(self) -> Tensor:
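"""Build identity weights for the grouped conv2d that emulates F.unfold in the Core ML-compatible path."""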
# [P_h * P_w, P_h * P_w]
weights = torch.eye(self.patch_h * self.patch_w, dtype=torch.float)
# [P_h * P_w, P_h * P_w] --> [P_h * P_w, 1, P_h, P_w]
weights = weights.reshape(
(self.patch_h * self.patch_w, 1, self.patch_h, self.patch_w)
)
# [P_h * P_w, 1, P_h, P_w] --> [P_h * P_w * C, 1, P_h, P_w]
weights = weights.repeat(self.cnn_out_dim, 1, 1, 1)
return weights
def _build_attn_layer(
self,
opts,
d_model: int,
ffn_mult: Union[Sequence, int, float],
n_layers: int,
attn_dropout: float,
dropout: float,
ffn_dropout: float,
attn_norm_layer: str,
*args,
**kwargs
) -> Tuple[nn.Module, int]:
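"""
Stack `n_layers` LinearAttnFFN blocks (separable self-attention) followed by a
normalization layer and return the stack together with its embedding dimension.
"""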
if isinstance(ffn_mult, Sequence) and len(ffn_mult) == 2:
ffn_dims = (
np.linspace(ffn_mult[0], ffn_mult[1], n_layers, dtype=float) * d_model
)
elif isinstance(ffn_mult, Sequence) and len(ffn_mult) == 1:
ffn_dims = [ffn_mult[0] * d_model] * n_layers
elif isinstance(ffn_mult, (int, float)):
ffn_dims = [ffn_mult * d_model] * n_layers
else:
raise NotImplementedError
# round FFN dims down to the nearest multiple of 16
ffn_dims = [int((d // 16) * 16) for d in ffn_dims]
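# Illustrative (hypothetical numbers): ffn_mult=2.0 with d_model=128 and n_layers=2
# gives ffn_dims=[256, 256]; a two-value sequence such as ffn_mult=(2, 4) interpolates
# the multiplier linearly across layers before rounding down to a multiple of 16.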
global_rep = [
LinearAttnFFN(
opts=opts,
embed_dim=d_model,
ffn_latent_dim=ffn_dims[block_idx],
attn_dropout=attn_dropout,
dropout=dropout,
ffn_dropout=ffn_dropout,
norm_layer=attn_norm_layer,
)
for block_idx in range(n_layers)
]
global_rep.append(
get_normalization_layer(
opts=opts, norm_type=attn_norm_layer, num_features=d_model
)
)
return nn.Sequential(*global_rep), d_model
def __repr__(self) -> str:
repr_str = "{}(".format(self.__class__.__name__)
repr_str += "\n\t Local representations"
if isinstance(self.local_rep, nn.Sequential):
for m in self.local_rep:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.local_rep)
repr_str += "\n\t Global representations with patch size of {}x{}".format(
self.patch_h,
self.patch_w,
)
if isinstance(self.global_rep, nn.Sequential):
for m in self.global_rep:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.global_rep)
if isinstance(self.conv_proj, nn.Sequential):
for m in self.conv_proj:
repr_str += "\n\t\t {}".format(m)
else:
repr_str += "\n\t\t {}".format(self.conv_proj)
repr_str += "\n)"
return repr_str
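# NOTE: `forward_spatial` and `forward_temporal` below call `resize_input_if_needed`,
# which is not shown in this excerpt. The sketch below is a reconstruction under the
# assumption that it only needs to make H and W multiples of the patch size,
# mirroring the bilinear interpolation used in MobileViTBlock.unfolding.
def resize_input_if_needed(self, x: Tensor) -> Tensor:
    batch_size, in_channels, orig_h, orig_w = x.shape
    if orig_h % self.patch_h != 0 or orig_w % self.patch_w != 0:
        # resize so that unfolding yields an integer number of patches
        new_h = int(math.ceil(orig_h / self.patch_h) * self.patch_h)
        new_w = int(math.ceil(orig_w / self.patch_w) * self.patch_w)
        x = F.interpolate(
            x, size=(new_h, new_w), mode="bilinear", align_corners=False
        )
    return x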
def unfolding_pytorch(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
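"""Rearrange a [B, C, H, W] feature map into non-overlapping patches of shape [B, C, P, N] via F.unfold, where P = patch_h * patch_w and N is the number of patches."""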
batch_size, in_channels, img_h, img_w = feature_map.shape
# [B, C, H, W] --> [B, C, P, N]
patches = F.unfold(
feature_map,
kernel_size=(self.patch_h, self.patch_w),
stride=(self.patch_h, self.patch_w),
)
patches = patches.reshape(
batch_size, in_channels, self.patch_h * self.patch_w, -1
)
return patches, (img_h, img_w)
def folding_pytorch(self, patches: Tensor, output_size: Tuple[int, int]) -> Tensor:
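"""Inverse of `unfolding_pytorch`: reassemble [B, C, P, N] patches into a [B, C, H, W] feature map via F.fold."""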
batch_size, in_dim, patch_size, n_patches = patches.shape
# [B, C, P, N] --> [B, C * P, N] for F.fold
patches = patches.reshape(batch_size, in_dim * patch_size, n_patches)
feature_map = F.fold(
patches,
output_size=output_size,
kernel_size=(self.patch_h, self.patch_w),
stride=(self.patch_h, self.patch_w),
)
return feature_map
def unfolding_coreml(self, feature_map: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
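"""Core ML-friendly equivalent of `unfolding_pytorch` with the same inputs and outputs."""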
# im2col is not implemented in Core ML, so we emulate it with a grouped conv2d
# whose identity weights (`unfolding_weights`) extract the pixels of each patch
# [B, C, H, W] --> [B, C, P, N]
batch_size, in_channels, img_h, img_w = feature_map.shape
patches = F.conv2d(
feature_map,
self.unfolding_weights,
bias=None,
stride=(self.patch_h, self.patch_w),
padding=0,
dilation=1,
groups=in_channels,
)
patches = patches.reshape(
batch_size, in_channels, self.patch_h * self.patch_w, -1
)
return patches, (img_h, img_w)
def folding_coreml(self, patches: Tensor, output_size: Tuple[int, int]) -> Tensor:
# col2im is not supported in Core ML, so tracing fails.
# We emulate the folding function via pixel_shuffle to enable Core ML tracing.
batch_size, in_dim, patch_size, n_patches = patches.shape
n_patches_h = output_size[0] // self.patch_h
n_patches_w = output_size[1] // self.patch_w
feature_map = patches.reshape(
batch_size, in_dim * self.patch_h * self.patch_w, n_patches_h, n_patches_w
)
assert (
self.patch_h == self.patch_w
), "For Core ML, patch_h and patch_w must be equal"
feature_map = F.pixel_shuffle(feature_map, upscale_factor=self.patch_h)
return feature_map
def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor:
x = self.resize_input_if_needed(x)
fm = self.local_rep(x)
# convert feature map to patches
if self.enable_coreml_compatible_fn:
patches, output_size = self.unfolding_coreml(fm)
else:
patches, output_size = self.unfolding_pytorch(fm)
# learn global representations on all patches
patches = self.global_rep(patches)
# fold patches [B, C, P, N] back into a feature map of shape [B, C, H, W]
if self.enable_coreml_compatible_fn:
fm = self.folding_coreml(patches=patches, output_size=output_size)
else:
fm = self.folding_pytorch(patches=patches, output_size=output_size)
fm = self.conv_proj(fm)
return fm
def forward_temporal(
self, x: Tensor, x_prev: Tensor, *args, **kwargs
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
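"""
Variant of `forward_spatial` for spatio-temporal inputs: LinearAttnFFN layers
additionally attend to `x_prev`. Returns the projected feature map together with
the current patches.
"""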
x = self.resize_input_if_needed(x)
fm = self.local_rep(x)
# convert feature map to patches
if self.enable_coreml_compatible_fn:
patches, output_size = self.unfolding_coreml(fm)
else:
patches, output_size = self.unfolding_pytorch(fm)
# learn global representations
for global_layer in self.global_rep:
if isinstance(global_layer, LinearAttnFFN):
patches = global_layer(x=patches, x_prev=x_prev)
else:
patches = global_layer(patches)
# fold patches [B, C, P, N] back into a feature map of shape [B, C, H, W]
if self.enable_coreml_compatible_fn:
fm = self.folding_coreml(patches=patches, output_size=output_size)
else:
fm = self.folding_pytorch(patches=patches, output_size=output_size)
fm = self.conv_proj(fm)
return fm, patches
def forward(
self, x: Union[Tensor, Tuple[Tensor]], *args, **kwargs
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
if isinstance(x, tuple) and len(x) == 2:
# for spatio-temporal data (e.g., videos)
return self.forward_temporal(x=x[0], x_prev=x[1])
elif isinstance(x, Tensor):
# for image data
return self.forward_spatial(x)
else:
raise NotImplementedError