#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F

from cvnets.layers.base_layer import BaseLayer
from cvnets.layers.conv_layer import ConvLayer2d
from cvnets.layers.dropout import Dropout


class LinearSelfAttention(BaseLayer):
    """
    This layer applies self-attention with linear complexity, as described in the
    `MobileViTv2 <https://arxiv.org/abs/2206.02680>`_ paper.
    This layer can be used for self- as well as cross-attention.

    Args:
        opts: command line arguments
        embed_dim (int): :math:`C` from an expected input of size :math:`(B, C, P, N)`
        attn_dropout (Optional[float]): Dropout value for context scores. Default: 0.0
        bias (Optional[bool]): Use bias in learnable layers. Default: True

    Shape:
        - Input: :math:`(B, C, P, N)` where :math:`B` is the batch size, :math:`C` is the input channels,
          :math:`P` is the number of pixels in a patch, and :math:`N` is the number of patches
        - Output: same as the input

    .. note::
        For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
        in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
        we use a point-wise convolution (instead of a linear layer). This avoids the transpose operation (which may be
        expensive on resource-constrained devices) that would otherwise be required to convert the unfolded tensor
        from channel-first to channel-last format for a linear layer.
    """

    def __init__(
        self,
        opts,
        embed_dim: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
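
        # Fused 1x1 projection producing 1 + 2 * embed_dim channels:
        # 1 channel for the (scalar) query, embed_dim for the key, and
        # embed_dim for the value (see the split in _forward_self_attn).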
        self.qkv_proj = ConvLayer2d(
            opts=opts,
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
            use_norm=False,
            use_act=False,
        )

        self.attn_dropout = Dropout(p=attn_dropout)
        self.out_proj = ConvLayer2d(
            opts=opts,
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=bias,
            kernel_size=1,
            use_norm=False,
            use_act=False,
        )
        self.embed_dim = embed_dim

    def __repr__(self):
        return "{}(embed_dim={}, attn_dropout={})".format(
            self.__class__.__name__, self.embed_dim, self.attn_dropout.p
        )

    @staticmethod
    def visualize_context_scores(context_scores):
        # [B, 1, P, N]
        batch_size, channels, num_pixels, num_patches = context_scores.shape
        assert batch_size == 1, "For visualization purposes, use a batch size of 1"
        assert (
            channels == 1
        ), "The inner-product between input and latent node (query) is a scalar"

        up_scale_factor = int(num_pixels**0.5)
        patch_h = patch_w = int(num_patches**0.5)
        # [1, 1, P, N] --> [1, P, h, w]
        context_scores = context_scores.reshape(1, num_pixels, patch_h, patch_w)
        # Fold context scores [1, P, h, w] using pixel shuffle to obtain [1, 1, H, W]
        context_map = F.pixel_shuffle(context_scores, upscale_factor=up_scale_factor)
        # [1, 1, H, W] --> [H, W]
        context_map = context_map.squeeze()

        # For ease of visualization, we do min-max normalization
        min_val = torch.min(context_map)
        max_val = torch.max(context_map)
        context_map = (context_map - min_val) / (max_val - min_val)

        try:
            import os
            from glob import glob

            import cv2

            # convert from float to byte
            context_map = (context_map * 255).byte().cpu().numpy()
            context_map = cv2.resize(
                context_map, (80, 80), interpolation=cv2.INTER_NEAREST
            )

            colored_context_map = cv2.applyColorMap(context_map, cv2.COLORMAP_JET)
            # Lazy way to dump feature maps in the attn_res folder. Make sure that the
            # directory is empty and copy the context maps before running on a different
            # image. Otherwise, the attention maps will be overwritten.
            res_dir_name = "attn_res"
            if not os.path.isdir(res_dir_name):
                os.makedirs(res_dir_name)
            f_name = "{}/h_{}_w_{}_index_".format(res_dir_name, patch_h, patch_w)

            files_cmap = glob(
                "{}/h_{}_w_{}_index_*.png".format(res_dir_name, patch_h, patch_w)
            )
            idx = len(files_cmap)
            f_name += str(idx)

            cv2.imwrite("{}.png".format(f_name), colored_context_map)
            return colored_context_map
        except ModuleNotFoundError:
            print("Please install OpenCV to visualize context maps")
            return context_map
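
    # The two forward paths below implement the MobileViTv2 linear attention:
    #   scores  = softmax over patches of a 1-channel (scalar) query  -> [B, 1, P, .]
    #   context = sum over patches of (key * scores)                  -> [B, d, P, 1]
    #   out     = out_proj(ReLU(value) * context)                     -> [B, d, P, .]
    # The cost is therefore linear in the number of patches, rather than quadratic
    # as in standard multi-head attention.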

    def _forward_self_attn(self, x: Tensor, *args, **kwargs) -> Tensor:
        # [B, C, P, N] --> [B, 1 + 2d, P, N]
        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # value, key --> [B, d, P, N]
        query, key, value = torch.split(
            qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1
        )

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        # Uncomment below line to visualize context scores
        # self.visualize_context_scores(context_scores=context_scores)
        context_scores = self.attn_dropout(context_scores)

        # Compute context vector
        # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N]
        context_vector = key * context_scores
        # [B, d, P, N] --> [B, d, P, 1]
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out
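
    # Cross-attention reuses slices of the shared qkv_proj weights: the first
    # 1 + embed_dim output channels (query and key) are applied to x_prev, while
    # the last embed_dim output channels (value) are applied to x, so the same
    # learned projection serves both the self- and cross-attention paths.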

    def _forward_cross_attn(
        self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        # x --> [B, C, P, N]
        # x_prev --> [B, C, P, M]
        batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape

        q_patch_area, q_num_patches = x_prev.shape[-2:]

        assert (
            kv_patch_area == q_patch_area
        ), "The number of pixels in a patch for query and key_value should be the same"

        # compute query and key from x_prev
        # [B, C, P, M] --> [B, 1 + d, P, M]
        qk = F.conv2d(
            x_prev,
            weight=self.qkv_proj.block.conv.weight[: self.embed_dim + 1, ...],
            bias=self.qkv_proj.block.conv.bias[: self.embed_dim + 1, ...],
        )

        # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
        query, key = torch.split(qk, split_size_or_sections=[1, self.embed_dim], dim=1)

        # compute value from x
        # [B, C, P, N] --> [B, d, P, N]
        value = F.conv2d(
            x,
            weight=self.qkv_proj.block.conv.weight[self.embed_dim + 1 :, ...],
            bias=self.qkv_proj.block.conv.bias[self.embed_dim + 1 :, ...],
        )

        # apply softmax along M dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_dropout(context_scores)

        # compute context vector
        # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M]
        context_vector = key * context_scores
        # [B, d, P, M] --> [B, d, P, 1]
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out

    def forward(
        self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        if x_prev is None:
            return self._forward_self_attn(x, *args, **kwargs)
        else:
            return self._forward_cross_attn(x, x_prev=x_prev, *args, **kwargs)
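

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the self-attention path above using plain
# torch ops (no cvnets dependencies). The 1x1-conv projections are replaced by
# random weight matrices purely for illustration; the shapes and the
# softmax-over-N trick mirror _forward_self_attn.
if __name__ == "__main__":
    B, d, P, N = 2, 32, 4, 64  # batch, embed_dim, pixels per patch, patches
    x = torch.randn(B, d, P, N)

    # Stand-ins for the fused qkv_proj (1 + 2d output channels)
    w_q = torch.randn(1, d)  # query: d -> 1 channel
    w_k = torch.randn(d, d)  # key:   d -> d channels
    w_v = torch.randn(d, d)  # value: d -> d channels

    query = torch.einsum("oc,bcpn->bopn", w_q, x)  # [B, 1, P, N]
    key = torch.einsum("oc,bcpn->bopn", w_k, x)  # [B, d, P, N]
    value = torch.einsum("oc,bcpn->bopn", w_v, x)  # [B, d, P, N]

    # Softmax over the patch dimension N gives the context scores
    context_scores = F.softmax(query, dim=-1)  # [B, 1, P, N]
    # Weighted sum of keys -> one context vector per pixel position
    context_vector = (key * context_scores).sum(dim=-1, keepdim=True)  # [B, d, P, 1]
    # Broadcast the context back over the (gated) values
    out = F.relu(value) * context_vector  # [B, d, P, N]
    print(out.shape)  # torch.Size([2, 32, 4, 64])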