Source code for cvnets.layers.linear_attention

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
from typing import Optional

import torch
from torch import Tensor
from torch.nn import functional as F

from cvnets.layers.base_layer import BaseLayer
from cvnets.layers.conv_layer import ConvLayer2d
from cvnets.layers.dropout import Dropout


class LinearSelfAttention(BaseLayer):
    """
    This layer applies a self-attention with linear complexity, as described in the
    `MobileViTv2 <https://arxiv.org/abs/2206.02680>`_ paper. This layer can be used for
    self- as well as cross-attention.

    Args:
        opts: Command line arguments.
        embed_dim (int): :math:`C` from an expected input of size :math:`(B, C, H, W)`.
        attn_dropout (Optional[float]): Dropout value for context scores. Default: 0.0
        bias (Optional[bool]): Use bias in learnable layers. Default: True

    Shape:
        - Input: :math:`(B, C, P, N)` where :math:`B` is the batch size, :math:`C` is the
          number of input channels, :math:`P` is the number of pixels in a patch, and
          :math:`N` is the number of patches.
        - Output: same shape as the input.

    .. note::
        For MobileViTv2, we unfold the feature map of shape :math:`(B, C, H, W)` into
        :math:`(B, C, P, N)`, where :math:`P` is the number of pixels in a patch and
        :math:`N` is the number of patches (a standalone sketch of this unfolding is
        included at the end of this listing). Because the channel dimension comes first
        in this unfolded tensor, we use a point-wise convolution instead of a linear
        layer. This avoids a transpose operation (which may be expensive on
        resource-constrained devices) that would otherwise be required to convert the
        unfolded tensor from channel-first to channel-last format for a linear layer.
    """
    def __init__(
        self,
        opts,
        embed_dim: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        self.qkv_proj = ConvLayer2d(
            opts=opts,
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
            use_norm=False,
            use_act=False,
        )

        self.attn_dropout = Dropout(p=attn_dropout)
        self.out_proj = ConvLayer2d(
            opts=opts,
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=bias,
            kernel_size=1,
            use_norm=False,
            use_act=False,
        )
        self.embed_dim = embed_dim
    def __repr__(self):
        return "{}(embed_dim={}, attn_dropout={})".format(
            self.__class__.__name__, self.embed_dim, self.attn_dropout.p
        )
    @staticmethod
    def visualize_context_scores(context_scores):
        # [B, 1, P, N]
        batch_size, channels, num_pixels, num_patches = context_scores.shape
        assert batch_size == 1, "For visualization purposes, use batch size of 1"
        assert (
            channels == 1
        ), "The inner-product between input and latent node (query) is a scalar"

        up_scale_factor = int(num_pixels**0.5)
        patch_h = patch_w = int(context_scores.shape[-1] ** 0.5)
        # [1, 1, P, N] --> [1, P, h, w]
        context_scores = context_scores.reshape(1, num_pixels, patch_h, patch_w)
        # Fold context scores [1, P, h, w] using pixel shuffle to obtain [1, 1, H, W]
        context_map = F.pixel_shuffle(context_scores, upscale_factor=up_scale_factor)
        # [1, 1, H, W] --> [H, W]
        context_map = context_map.squeeze()

        # For ease of visualization, we do min-max normalization
        min_val = torch.min(context_map)
        max_val = torch.max(context_map)
        context_map = (context_map - min_val) / (max_val - min_val)

        try:
            import os
            from glob import glob

            import cv2

            # convert from float to byte
            context_map = (context_map * 255).byte().cpu().numpy()
            context_map = cv2.resize(
                context_map, (80, 80), interpolation=cv2.INTER_NEAREST
            )

            colored_context_map = cv2.applyColorMap(context_map, cv2.COLORMAP_JET)
            # Lazy way to dump feature maps in the attn_res folder. Make sure that the
            # directory is empty and copy the context maps before running on a
            # different image; otherwise, the attention maps will be overwritten.
            res_dir_name = "attn_res"
            if not os.path.isdir(res_dir_name):
                os.makedirs(res_dir_name)
            f_name = "{}/h_{}_w_{}_index_".format(res_dir_name, patch_h, patch_w)

            files_cmap = glob(
                "{}/h_{}_w_{}_index_*.png".format(res_dir_name, patch_h, patch_w)
            )
            idx = len(files_cmap)
            f_name += str(idx)

            cv2.imwrite("{}.png".format(f_name), colored_context_map)
            return colored_context_map
        except ModuleNotFoundError:
            print("Please install OpenCV to visualize context maps")
            return context_map
    def _forward_self_attn(self, x: Tensor, *args, **kwargs) -> Tensor:
        # A plain-tensor sketch of this computation is included at the end of this
        # listing.

        # [B, C, P, N] --> [B, 1 + 2d, P, N]
        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # Key, Value --> [B, d, P, N]
        query, key, value = torch.split(
            qkv, split_size_or_sections=[1, self.embed_dim, self.embed_dim], dim=1
        )

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        # Uncomment below line to visualize context scores
        # self.visualize_context_scores(context_scores=context_scores)
        context_scores = self.attn_dropout(context_scores)

        # Compute context vector
        # [B, d, P, N] * [B, 1, P, N] --> [B, d, P, N]
        context_vector = key * context_scores
        # [B, d, P, N] --> [B, d, P, 1]
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out

    def _forward_cross_attn(
        self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        # x --> [B, C, P, N]
        # x_prev --> [B, C, P, M]

        batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape

        q_patch_area, q_num_patches = x_prev.shape[-2:]

        assert (
            kv_patch_area == q_patch_area
        ), "The number of pixels in a patch for query and key_value should be the same"

        # compute query and key
        # [B, C, P, M] --> [B, 1 + d, P, M]
        qk = F.conv2d(
            x_prev,
            weight=self.qkv_proj.block.conv.weight[: self.embed_dim + 1, ...],
            bias=self.qkv_proj.block.conv.bias[: self.embed_dim + 1, ...],
        )

        # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
        query, key = torch.split(qk, split_size_or_sections=[1, self.embed_dim], dim=1)

        # compute value
        # [B, C, P, N] --> [B, d, P, N]
        value = F.conv2d(
            x,
            weight=self.qkv_proj.block.conv.weight[self.embed_dim + 1 :, ...],
            bias=self.qkv_proj.block.conv.bias[self.embed_dim + 1 :, ...],
        )

        # apply softmax along M dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_dropout(context_scores)

        # compute context vector
        # [B, d, P, M] * [B, 1, P, M] --> [B, d, P, M]
        context_vector = key * context_scores
        # [B, d, P, M] --> [B, d, P, 1]
        context_vector = torch.sum(context_vector, dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        return out
    def forward(
        self, x: Tensor, x_prev: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        if x_prev is None:
            return self._forward_self_attn(x, *args, **kwargs)
        else:
            return self._forward_cross_attn(x, x_prev=x_prev, *args, **kwargs)
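

# ---------------------------------------------------------------------------
# The sketch below is NOT part of the original module. It is a minimal,
# self-contained re-implementation of the shape bookkeeping in
# `_forward_self_attn`, using random 1x1-convolution weights as stand-ins for
# `qkv_proj` and `out_proj`; the helper name and all sizes are illustrative
# assumptions.
def _linear_self_attention_sketch() -> Tensor:
    B, d, P, N = 2, 8, 4, 16  # batch, embed dim, pixels per patch, patches
    x = torch.randn(B, d, P, N)

    # Stand-in for qkv_proj: a 1x1 convolution producing 1 + 2d channels.
    w_qkv = torch.randn(1 + 2 * d, d, 1, 1)
    qkv = F.conv2d(x, w_qkv)

    # [B, 1, P, N], [B, d, P, N], [B, d, P, N]
    query, key, value = torch.split(qkv, [1, d, d], dim=1)

    # Softmax over the patch dimension N gives the context scores.
    context_scores = F.softmax(query, dim=-1)

    # Weighted sum of the keys over N yields one context vector per pixel.
    context_vector = (key * context_scores).sum(dim=-1, keepdim=True)  # [B, d, P, 1]

    # Broadcast the context vector back over all patches via the values.
    out = F.relu(value) * context_vector.expand_as(value)  # [B, d, P, N]

    # Stand-in for out_proj: another 1x1 convolution back to d channels.
    w_out = torch.randn(d, d, 1, 1)
    return F.conv2d(out, w_out)  # same shape as the input: [B, d, P, N]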
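

# ---------------------------------------------------------------------------
# The sketch below is also NOT part of the original module. It illustrates the
# unfolding mentioned in the class docstring, [B, C, H, W] --> [B, C, P, N],
# and its inverse, expressed with F.unfold / F.fold under the assumption of
# non-overlapping patches. In MobileViTv2, this unfolding is performed by the
# surrounding block, not by this layer.
def _unfold_sketch() -> Tensor:
    B, C, H, W = 1, 32, 8, 8
    patch_h = patch_w = 2  # P = patch_h * patch_w = 4
    x = torch.randn(B, C, H, W)

    # [B, C, H, W] --> [B, C * P, N] with non-overlapping patches.
    patches = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
    num_patches = patches.shape[-1]

    # [B, C * P, N] --> [B, C, P, N]: channel-first, so 1x1 convolutions can
    # operate on C without a transpose.
    patches = patches.reshape(B, C, patch_h * patch_w, num_patches)

    # Inverse: [B, C, P, N] --> [B, C * P, N] --> [B, C, H, W]
    folded = F.fold(
        patches.reshape(B, C * patch_h * patch_w, num_patches),
        output_size=(H, W),
        kernel_size=(patch_h, patch_w),
        stride=(patch_h, patch_w),
    )
    assert torch.allclose(folded, x)
    return patches  # [1, 32, 4, 16]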