Source code for cvnets.layers.positional_embedding

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import math
from typing import Optional

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from cvnets.layers import BaseLayer


class PositionalEmbedding(BaseLayer):
    """Wrapper that builds either a learnable or a sinusoidal positional embedding."""

    def __init__(
        self,
        opts,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        is_learnable: Optional[bool] = False,
        sequence_first: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        module = (
            LearnablePositionalEmbedding
            if is_learnable
            else SinusoidalPositionalEmbedding
        )
        self.pos_embed = module(
            opts,
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            sequence_first=sequence_first,
            interpolation_mode=interpolation_mode,
            *args,
            **kwargs
        )

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        return self.pos_embed(seq_len, *args, **kwargs)

    def __repr__(self):
        return self.pos_embed.__repr__()
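
# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how PositionalEmbedding might be constructed and
# queried. The empty `opts` namespace and the concrete sizes below are
# assumptions made purely for illustration.
def _example_positional_embedding() -> None:
    import argparse

    opts = argparse.Namespace()  # hypothetical, empty options namespace
    pos_layer = PositionalEmbedding(
        opts,
        num_embeddings=64,
        embedding_dim=128,
        is_learnable=False,  # dispatches to SinusoidalPositionalEmbedding
    )
    # The layer is queried with a sequence length, not with token tensors.
    pos = pos_layer(seq_len=64)
    assert pos.shape == (1, 64, 128)  # [Batch, Seq_len, Embedding_dim]
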
class LearnablePositionalEmbedding(nn.Module):
    """Learnable Positional embedding"""

    def __init__(
        self,
        opts,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        sequence_first: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs
    ):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.empty(1, 1, num_embeddings, embedding_dim))
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.sequence_first = sequence_first
        self.interpolation_mode = interpolation_mode

        self.reset_parameters()

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.pos_embed, mean=0, std=self.embedding_dim**-0.5)
        if self.padding_idx is not None:
            with torch.no_grad():
                self.pos_embed[:, :, self.padding_idx, ...] = 0.0

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        # scale pos embedding
        pos_embed = self.pos_embed
        if self.padding_idx is not None:
            with torch.no_grad():
                pos_embed[:, :, self.padding_idx, ...] = 0.0

        if seq_len != self.num_embeddings:
            pos_embed = F.interpolate(
                pos_embed,
                size=(seq_len, self.embedding_dim),
                mode=self.interpolation_mode,
            )

        # add dummy batch dimension
        if self.sequence_first:
            # Input is of the form [Seq_len, Batch, Embedding_dim]
            return pos_embed.reshape(seq_len, 1, self.embedding_dim)
        else:
            # Input is of the form [Batch, Seq_len, Embedding_dim]
            return pos_embed.reshape(1, seq_len, self.embedding_dim)

    def __repr__(self):
        return "{}(num_embeddings={}, embedding_dim={}, padding_idx={}, sequence_first={})".format(
            self.__class__.__name__,
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
            self.sequence_first,
        )
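
# --- Illustrative interpolation sketch (not part of the original module) ---
# A minimal check of the resize path above: when the requested sequence
# length differs from `num_embeddings`, the [1, 1, N, D] parameter is
# bilinearly interpolated to [1, 1, seq_len, D]. The sizes below are
# assumptions for illustration.
def _example_learnable_interpolation() -> None:
    layer = LearnablePositionalEmbedding(
        opts=None,  # hypothetical; `opts` is not used by this class
        num_embeddings=16,
        embedding_dim=32,
    )
    # Querying a longer sequence than the table was built for triggers
    # F.interpolate over the embedding grid.
    pos = layer(seq_len=24)
    assert pos.shape == (1, 24, 32)  # [Batch, Seq_len, Embedding_dim]
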
class SinusoidalPositionalEmbedding(nn.Module):
    """Non-learnable sinusoidal positional embedding."""

    def __init__(
        self,
        opts,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        sequence_first: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs
    ):
        super().__init__()
        self.padding_idx = padding_idx
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.sequence_first = sequence_first
        self.interpolation_mode = interpolation_mode

        self.register_buffer("pos_embed", self.get_weights())

    def get_weights(self) -> Tensor:
        """Build sinusoidal embeddings. Adapted from Fairseq."""
        half_dim = self.embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(self.num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).reshape(
            self.num_embeddings, -1
        )
        if self.embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(self.num_embeddings, 1)], dim=1)
        # set embeddings corresponding to padding index to 0
        if self.padding_idx is not None:
            emb[self.padding_idx, :] = 0
        return emb.unsqueeze(0).unsqueeze(0)

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        # scale pos embedding
        pos_embed = self.pos_embed
        if seq_len != self.num_embeddings:
            pos_embed = F.interpolate(
                pos_embed,
                size=(seq_len, self.embedding_dim),
                mode=self.interpolation_mode,
            )

        if self.sequence_first:
            # Input is of the form [Seq_len, Batch, Embedding_dim]
            return pos_embed.reshape(seq_len, 1, self.embedding_dim)
        else:
            # Input is of the form [Batch, Seq_len, Embedding_dim]
            return pos_embed.reshape(1, seq_len, self.embedding_dim)

    def __repr__(self):
        return "{}(num_embeddings={}, embedding_dim={}, padding_idx={}, sequence_first={})".format(
            self.__class__.__name__,
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
            self.sequence_first,
        )
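
# --- Illustrative sanity check (not part of the original module) ---
# A small sketch of the table structure produced by get_weights: for each
# position, the first half of the row holds sines and the second half holds
# cosines, so position 0 is [0, ..., 0, 1, ..., 1] when no padding index is
# set. The sizes below are assumptions for illustration.
def _example_sinusoidal_table() -> None:
    layer = SinusoidalPositionalEmbedding(
        opts=None,  # hypothetical; `opts` is not used by this class
        num_embeddings=8,
        embedding_dim=16,
    )
    table = layer.pos_embed[0, 0]  # [num_embeddings, embedding_dim]
    assert torch.all(table[0, :8] == 0.0)  # sin(0) = 0
    assert torch.all(table[0, 8:] == 1.0)  # cos(0) = 1
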