Source code for cvnets.layers.positional_encoding

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import math
from typing import Optional

import torch
from torch import Tensor, nn

from cvnets.layers.base_layer import BaseLayer
from cvnets.layers.dropout import Dropout


class SinusoidalPositionalEncoding(BaseLayer):
    """
    This layer adds sinusoidal positional embeddings to a 3D input tensor. The code has been adapted from the
    `PyTorch tutorial <https://pytorch.org/tutorials/beginner/transformer_tutorial.html>`_

    Args:
        d_model (int): dimension of the input tensor
        dropout (Optional[float]): Dropout rate. Default: 0.0
        max_len (Optional[int]): Max. number of patches (or seq. length). Default: 5000
        channels_last (Optional[bool]): Channels dimension is the last in the input tensor

    Shape:
        - Input: :math:`(N, C, P)` or :math:`(N, P, C)` where :math:`N` is the batch size, :math:`C` is the embedding dimension,
          and :math:`P` is the number of patches
        - Output: same shape as the input
    """
    def __init__(
        self,
        d_model: int,
        dropout: Optional[float] = 0.0,
        max_len: Optional[int] = 5000,
        channels_last: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        position_last = not channels_last

        pos_encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pos_encoding[:, 0::2] = torch.sin(position * div_term)
        pos_encoding[:, 1::2] = torch.cos(position * div_term)
        # add dummy batch dimension
        pos_encoding = pos_encoding.unsqueeze(0)  # [1 x P_max x C]

        patch_dim = -2  # patch dimension is second last (N, P, C)
        if position_last:
            pos_encoding = pos_encoding.transpose(
                1, 2
            )  # [1 x C x P_max]
            patch_dim = -1  # patch dimension is last (N, C, P)

        super().__init__()
        self.dropout = Dropout(p=dropout)
        self.patch_dim = patch_dim
        self.register_buffer("pe", pos_encoding)
    def forward_patch_last(
        self, x, indices: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        # seq_length should be the last dim
        if indices is None:
            x = x + self.pe[..., : x.shape[-1]]
        else:
            ndim = x.ndim
            repeat_size = [x.shape[0]] + [-1] * (ndim - 1)

            pe = self.pe.expand(repeat_size)
            selected_pe = torch.gather(pe, index=indices, dim=-1)
            x = x + selected_pe
        return self.dropout(x)
    def forward_others(
        self, x, indices: Optional[Tensor] = None, *args, **kwargs
    ) -> Tensor:
        # seq_length should be the second last dim
        if indices is None:
            x = x + self.pe[..., : x.shape[-2], :]
        else:
            ndim = x.ndim
            repeat_size = [x.shape[0]] + [-1] * (ndim - 1)

            pe = self.pe.expand(repeat_size)
            selected_pe = torch.gather(pe, index=indices, dim=-2)
            x = x + selected_pe
        return self.dropout(x)
    def forward(self, x, indices: Optional[Tensor] = None, *args, **kwargs) -> Tensor:
        if self.patch_dim == -1:
            return self.forward_patch_last(x, indices=indices)
        else:
            return self.forward_others(x, indices=indices)
    def __repr__(self):
        return "{}(dropout={})".format(self.__class__.__name__, self.dropout.p)
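
A minimal usage sketch (illustrative, not part of the original module), assuming the layer is constructed directly with keyword arguments and fed a channels-last input of shape ``(N, P, C)``; the tensor sizes below are placeholders:

    pos_enc = SinusoidalPositionalEncoding(d_model=64, dropout=0.1, max_len=256)
    x = torch.randn(2, 100, 64)  # batch of 2, 100 patches, 64-dim embeddings
    out = pos_enc(x)             # same shape as x: (2, 100, 64)

With ``channels_last=False`` the expected layout is ``(N, C, P)`` instead; the precomputed table is then transposed once in ``__init__`` and sliced along the last (patch) dimension in the forward pass.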
class LearnablePositionEncoding(BaseLayer):
    """
    This layer adds learnable positional embeddings to a 3D input tensor.

    Args:
        embed_dim (int): dimension of the input tensor
        num_embeddings (int): number of input embeddings. This is similar to vocab size in NLP.
        dropout (Optional[float]): Dropout rate. Default: 0.0
        channels_last (Optional[bool]): Channels dimension is the last in the input tensor

    Shape:
        - Input: :math:`(N, *, C, P)` or :math:`(N, *, P, C)` where :math:`N` is the batch size, :math:`C` is the embedding dimension,
          and :math:`P` is the number of patches
        - Output: same shape as the input
    """
    def __init__(
        self,
        embed_dim: int,
        num_embeddings: int,
        dropout: Optional[float] = 0.0,
        channels_last: Optional[bool] = True,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        self.pos_emb = nn.Embedding(
            num_embeddings=num_embeddings, embedding_dim=embed_dim
        )
        self.channel_last = channels_last
        self.dropout = Dropout(p=dropout)
    def forward(self, x, *args, **kwargs) -> Tensor:
        num_embeddings = x.shape[-2] if self.channel_last else x.shape[-1]
        positions = torch.arange(num_embeddings, dtype=torch.int64, device=x.device)
        position_emb = self.pos_emb(positions)
        position_emb = position_emb.expand_as(x)
        x = x + position_emb
        return self.dropout(x)
    def __repr__(self):
        return "{}(embed_dim={}, vocab_size={}, dropout={})".format(
            self.__class__.__name__,
            self.pos_emb.embedding_dim,
            self.pos_emb.num_embeddings,
            self.dropout.p,
        )
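
A minimal usage sketch for the learnable variant (illustrative, not part of the original module), again assuming direct construction and a channels-last input of shape ``(N, P, C)``; ``num_embeddings`` must be at least the number of patches, and all sizes below are placeholders:

    pos_enc = LearnablePositionEncoding(embed_dim=64, num_embeddings=256, dropout=0.1)
    x = torch.randn(2, 100, 64)  # 100 patches <= num_embeddings
    out = pos_enc(x)             # same shape as x: (2, 100, 64)

Unlike the sinusoidal layer, the position table here is an ``nn.Embedding`` whose weights are updated during training, so the maximum sequence length is fixed by ``num_embeddings`` at construction time.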