Source code for cvnets.layers.single_head_attention

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from cvnets.layers.base_layer import BaseLayer
from cvnets.layers.dropout import Dropout
from cvnets.layers.linear_layer import LinearLayer


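# The layer below implements single-head scaled dot-product attention: the
# input (or the query / key-value pair in the cross-attention path) is mapped
# to Q, K, and V with one fused linear projection, and the result is
# out_proj(softmax(Q K^T / sqrt(embed_dim)) V).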
class SingleHeadAttention(BaseLayer):
    """
    This layer applies single-head attention as described in the
    `DeLighT <https://arxiv.org/abs/2008.00623>`_ paper.

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`
        attn_dropout (Optional[float]): Attention dropout. Default: 0.0
        bias (Optional[bool]): Use bias or not. Default: ``True``

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is the batch size, :math:`P` is the number
          of patches, and :math:`C_{in}` is the input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        self.qkv_proj = LinearLayer(
            in_features=embed_dim, out_features=3 * embed_dim, bias=bias
        )

        self.attn_dropout = Dropout(p=attn_dropout)
        self.out_proj = LinearLayer(
            in_features=embed_dim, out_features=embed_dim, bias=bias
        )

        self.softmax = nn.Softmax(dim=-1)
        self.embed_dim = embed_dim
        self.scaling = self.embed_dim**-0.5

    def __repr__(self) -> str:
        return "{}(embed_dim={}, attn_dropout={})".format(
            self.__class__.__name__, self.embed_dim, self.attn_dropout.p
        )

    def forward(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs
    ) -> Tensor:
        # [N, P, C] --> [N, P, 3C]
        if x_kv is None:
            qkv = self.qkv_proj(x_q)
            # [N, P, 3C] --> [N, P, C] x 3
            query, key, value = torch.chunk(qkv, chunks=3, dim=-1)
        else:
            # Cross-attention: queries come from x_q, keys/values from x_kv.
            # The projection may be bias-free, so guard against a None bias.
            query = F.linear(
                x_q,
                weight=self.qkv_proj.weight[: self.embed_dim, ...],
                bias=self.qkv_proj.bias[: self.embed_dim]
                if self.qkv_proj.bias is not None
                else None,
            )

            # [N, P, C] --> [N, P, 2C]
            kv = F.linear(
                x_kv,
                weight=self.qkv_proj.weight[self.embed_dim :, ...],
                bias=self.qkv_proj.bias[self.embed_dim :]
                if self.qkv_proj.bias is not None
                else None,
            )
            key, value = torch.chunk(kv, chunks=2, dim=-1)

        query = query * self.scaling

        # [N, P, C] --> [N, C, P]
        key = key.transpose(-2, -1)

        # QK^T
        # [N, P, C] x [N, C, P] --> [N, P, P]
        attn = torch.matmul(query, key)

        if attn_mask is not None:
            # attn_mask shape should be the same as attn
            assert list(attn_mask.shape) == list(
                attn.shape
            ), "Shape of attention mask and attn should be the same. Got: {} and {}".format(
                attn_mask.shape, attn.shape
            )
            attn = attn + attn_mask

        if key_padding_mask is not None:
            # Do not attend to padding positions
            # key_padding_mask size is [N, P]
            batch_size, num_src_tokens, num_tgt_tokens = attn.shape
            assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
                batch_size,
                num_tgt_tokens,
            ], "key_padding_mask should be 2-dimensional with shape [{}, {}]. Got: {}".format(
                batch_size, num_tgt_tokens, key_padding_mask.shape
            )
            attn = attn.masked_fill(
                key_padding_mask.unsqueeze(1).to(torch.bool),
                float("-inf"),
            )

        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, P, P] x [N, P, C] --> [N, P, C]
        out = torch.matmul(attn, value)
        out = self.out_proj(out)
        return out
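

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the batch size, patch count, and
# embedding dimension below are arbitrary values chosen for demonstration and
# are not prescribed by the layer itself.
if __name__ == "__main__":
    batch_size, num_patches, embed_dim = 2, 4, 8
    layer = SingleHeadAttention(embed_dim=embed_dim, attn_dropout=0.0)
    layer.eval()  # disable dropout so the sketch is deterministic

    # Self-attention: queries, keys, and values all come from x.
    x = torch.randn(batch_size, num_patches, embed_dim)
    out = layer(x)
    assert out.shape == (batch_size, num_patches, embed_dim)

    # Cross-attention: queries come from x, keys/values from x_kv.
    x_kv = torch.randn(batch_size, num_patches, embed_dim)
    out_cross = layer(x, x_kv=x_kv)
    assert out_cross.shape == (batch_size, num_patches, embed_dim)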