# Source code for cvnets.text_encoders.transformer

# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
import math
from typing import Optional, Sequence

import torch
from torch import Tensor, nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as gradient_checkpoint_fn

from cvnets.layers import (
    Dropout,
    Embedding,
    PositionalEmbedding,
    get_normalization_layer,
)
from cvnets.modules import TransformerEncoder
from cvnets.text_encoders import TEXT_ENCODER_REGISTRY, BaseTextEncoder
from utils import logger


@TEXT_ENCODER_REGISTRY.register(name="transformer")
class TextTransformer(BaseTextEncoder):
    def __init__(self, opts, projection_dim: int, *args, **kwargs) -> None:
        model_dim = getattr(opts, "model.text.transformer.model_dim", 512)
        no_scale_embedding = getattr(
            opts, "model.text.transformer.no_scale_embedding", False
        )
        no_pos_embedding = getattr(
            opts, "model.text.transformer.no_pos_embedding", False
        )
        embed_dropout = getattr(opts, "model.text.transformer.embed_dropout", 0.0)
        dropout = getattr(opts, "model.text.transformer.dropout", 0.0)
        attn_dropout = getattr(opts, "model.text.transformer.attn_dropout", 0.0)
        ffn_dropout = getattr(opts, "model.text.transformer.ffn_dropout", 0.0)
        norm_layer = getattr(opts, "model.text.transformer.norm_layer", None)
        gradient_ckpt = getattr(
            opts, "model.text.transformer.gradient_checkpoint", False
        )

        if norm_layer is None:
            logger.error(
                "Normalization layer can not be None in {}".format(
                    self.__class__.__name__
                )
            )

        super().__init__(opts=opts, projection_dim=projection_dim, *args, **kwargs)

        # token embedding layer
        padding_index = getattr(opts, "dataset.padding_index", None)
        self.embedding_layer = Embedding(
            opts=opts,
            embedding_dim=model_dim,
            padding_idx=padding_index,
            num_embeddings=self.vocab_size,
        )
        self.embed_scale = 1.0 if no_scale_embedding else model_dim**-0.5

        context_length = getattr(opts, "dataset.text_context_length", None)
        if getattr(opts, "common.debug_mode", False):
            context_length = 77
        assert context_length is not None, (
            "Context length can't be None. Please set dataset.text_context_length "
            "argument in your dataset class"
        )

        self.positional_embedding = (
            None
            if no_pos_embedding
            else PositionalEmbedding(
                opts=opts,
                num_embeddings=context_length,
                embedding_dim=model_dim,
                padding_idx=getattr(opts, "dataset.padding_index", None),
                is_learnable=not getattr(
                    opts, "model.text.transformer.sinusoidal_pos_emb", False
                ),
            )
        )

        self.embedding_dropout = Dropout(p=embed_dropout)

        # Transformer layers
        n_transformer_layers = getattr(
            opts, "model.text.transformer.n_transformer_layers", 6
        )

        # FFN multipliers for transformer layers
        ffn_multipliers = getattr(
            opts, "model.text.transformer.ffn_multiplier_per_layer", 4.0
        )
        if isinstance(ffn_multipliers, (float, int)):
            ffn_multipliers = [ffn_multipliers] * n_transformer_layers

        if not isinstance(ffn_multipliers, Sequence):
            logger.error(
                "{} expects FFN multipliers as a list, whose length is the same as number of "
                "transformer layers. Got: {}".format(
                    self.__class__.__name__, type(ffn_multipliers)
                )
            )
        elif (
            isinstance(ffn_multipliers, Sequence)
            and len(ffn_multipliers) != n_transformer_layers
        ):
            logger.error(
                "We need FFN multiplier for each transformer layer. Got {} ffn multipliers while number of "
                "transformer layers = {}".format(
                    len(ffn_multipliers), n_transformer_layers
                )
            )
        ffn_dims = [
            int(math.ceil(model_dim * ffn_mult / 16.0) * 16.0)
            for ffn_mult in ffn_multipliers
        ]

        # Heads for transformer layers
        mha_heads = getattr(opts, "model.text.transformer.n_heads_per_layer", 8)
        if isinstance(mha_heads, int):
            mha_heads = [mha_heads] * n_transformer_layers

        if not isinstance(mha_heads, Sequence):
            logger.error(
                "{} expects MHA heads as a list, whose length is the same as number of "
                "transformer layers. Got: {}".format(
                    self.__class__.__name__, type(mha_heads)
                )
            )
        elif (
            isinstance(mha_heads, Sequence)
            and len(mha_heads) != n_transformer_layers
        ):
            logger.error(
                "{} needs MHA heads for each transformer layer. Got {} mha heads while number of "
                "transformer layers = {}".format(
                    self.__class__.__name__, len(mha_heads), n_transformer_layers
                )
            )

        self.transformer = nn.ModuleList(
            [
                TransformerEncoder(
                    opts=opts,
                    embed_dim=model_dim,
                    num_heads=mha_heads[layer_idx],
                    ffn_latent_dim=ffn_dims[layer_idx],
                    attn_dropout=attn_dropout,
                    ffn_dropout=ffn_dropout,
                    dropout=dropout,
                    transformer_norm_layer=norm_layer,
                )
                for layer_idx in range(n_transformer_layers)
            ]
        )
        self.final_layer_norm = get_normalization_layer(
            opts, num_features=model_dim, norm_type=norm_layer
        )

        self.projection_layer = nn.Parameter(
            torch.empty(model_dim, self.projection_dim)
        )
        self.model_dim = model_dim
        self.reset_parameters_clip_style()
        self.gradient_ckpt = gradient_ckpt
        self.use_pytorch_mha = False

        self.causal_masking = getattr(
            opts, "model.text.transformer.causal_masking", False
        )
        self.classes_per_split_zero_shot = max(
            1,
            int(
                getattr(opts, "model.text.transformer.classes_per_split_zero_shot", 1)
            ),
        )

    def reset_parameters_clip_style(self):
        """Resets the weights of the Transformer model as done in the CLIP paper."""
        # reset the weights of the embedding and positional embedding layers
        nn.init.normal_(self.embedding_layer.weight, mean=0.0, std=0.02)
        # if self.positional_embedding is not None and not getattr(
        #     self.opts, "model.text.transformer.sinusoidal_pos_emb", False
        # ):
        #     nn.init.normal_(
        #         self.positional_embedding.pos_embed.weight, mean=0.0, std=0.01
        #     )

        # compute standard deviations for the different linear layers in the transformer model
        attn_std = self.model_dim**-0.5
        proj_std = attn_std * ((2 * len(self.transformer)) ** -0.5)
        fc_std = (2 * self.model_dim) ** -0.5

        for block in self.transformer:
            # multi-head attention QKV projection layer
            nn.init.normal_(
                block.pre_norm_mha[1].qkv_proj.weight, mean=0.0, std=attn_std
            )
            # multi-head attention output projection layer
            nn.init.normal_(
                block.pre_norm_mha[1].out_proj.weight, mean=0.0, std=proj_std
            )
            # FFN expansion layer
            nn.init.normal_(block.pre_norm_ffn[1].weight, mean=0.0, std=fc_std)
            # FFN reduction layer
            nn.init.normal_(block.pre_norm_ffn[4].weight, mean=0.0, std=proj_std)

        nn.init.normal_(self.projection_layer, mean=0.0, std=attn_std)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        if cls != TextTransformer:
            return parser

        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--model.text.transformer.model-dim",
            type=int,
            default=512,
            help="Model dimension of the transformer model",
        )
        group.add_argument(
            "--model.text.transformer.no-scale-embedding",
            action="store_true",
            help="Do not scale the output of embedding layer in {}".format(
                cls.__name__
            ),
        )
        group.add_argument(
            "--model.text.transformer.no-pos-embedding",
            action="store_true",
            help="Do not add positional embeddings to the output of embedding layer in {}".format(
                cls.__name__
            ),
        )
        group.add_argument(
            "--model.text.transformer.embed-dropout",
            type=float,
            default=0.0,
            help="Dropout in embedding layer",
        )

        # transformer layer parameters
        default_layers = 6
        group.add_argument(
            "--model.text.transformer.n-transformer-layers",
            type=int,
            default=default_layers,
            help="Number of transformer layers in {}".format(cls.__name__),
        )
        group.add_argument(
            "--model.text.transformer.n-heads-per-layer",
            type=int,
            default=[8] * default_layers,
            nargs="+",
            help="Number of transformer heads per transformer layer",
        )
        group.add_argument(
            "--model.text.transformer.ffn-multiplier-per-layer",
            type=float,
            default=[4.0] * default_layers,
            nargs="+",
            help="FFN multiplier for each transformer layer",
        )
        group.add_argument(
            "--model.text.transformer.attn-dropout",
            type=float,
            default=0.0,
            help="Dropout in multi-head attention",
        )
        group.add_argument(
            "--model.text.transformer.ffn-dropout",
            type=float,
            default=0.0,
            help="Dropout between linear layers in FFN",
        )
        group.add_argument(
            "--model.text.transformer.dropout",
            type=float,
            default=0.0,
            help="Dropout in transformer",
        )
        group.add_argument(
            "--model.text.transformer.norm-layer",
            type=str,
            default="layer_norm",
            help="Normalization layer",
        )
        group.add_argument(
            "--model.text.transformer.sinusoidal-pos-emb",
            action="store_true",
            help="Use sinusoidal positional embedding",
        )
        group.add_argument(
            "--model.text.transformer.gradient-checkpoint",
            action="store_true",
            help="Use gradient checkpointing",
        )
        group.add_argument(
            "--model.text.transformer.num-checkpoint-segments",
            type=int,
            default=1,
            help="Number of gradient checkpoint segments",
        )
        group.add_argument(
            "--model.text.transformer.causal-masking",
            action="store_true",
            help="Use causal masking",
        )
        group.add_argument(
            "--model.text.transformer.classes-per-split-zero-shot",
            type=int,
            default=20,
            help="Divide zero-shot classes into these many chunks, for faster processing",
        )

        return parser

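    # Example (illustrative, not from the original source): the dotted flags registered
    # above are parsed by argparse into attributes whose names keep the dots and replace
    # the remaining dashes with underscores, which is why the constructor reads them with
    # `getattr`. For instance,
    #   --model.text.transformer.model-dim 768 --model.text.transformer.n-transformer-layers 12
    # yields getattr(opts, "model.text.transformer.model_dim") == 768 and
    # getattr(opts, "model.text.transformer.n_transformer_layers") == 12.
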
    def forward_embedding(
        self,
        text_tokens: Tensor,
    ):
        # [Batch, Seq_len] --> [Batch, Seq_len, hidden_dim]
        token_emb = self.embedding_layer(text_tokens)
        # token_emb = self.embed_scale * token_emb

        seq_len = token_emb.shape[1]
        if self.positional_embedding is not None:
            token_emb = token_emb + self.positional_embedding(seq_len).to(
                token_emb.dtype
            )
        token_emb = self.embedding_dropout(token_emb)
        return token_emb

    def build_attention_mask(self, context_length: int, batch_size: int):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(context_length, context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        if not self.use_pytorch_mha:
            mask = mask.unsqueeze(0)  # add dummy batch dimension
            mask = mask.expand(batch_size, -1, -1)
        return mask

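    # Illustration of the additive mask built above: for context_length=3 it is
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    # so position i may only attend to positions <= i. When `use_pytorch_mha` is False,
    # the mask is additionally expanded to shape [batch_size, 3, 3].
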
    def encode_text(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        return_all_tokens: bool = False,
        *args,
        **kwargs
    ) -> Tensor:
        """
        Returns token embeddings.

        :param text_tokens: a tensor of token indices. ([Batch, Seq_len])
        :param key_padding_mask: a tensor of boolean values as the padding mask.
        :param return_all_tokens: a boolean flag to return all tokens, defaults to False
            to return only the EOT token embedding.
        :return: a tensor of [Batch, Seq_len, hidden_dim] if return_all_tokens is True,
            otherwise a tensor of [Batch, hidden_dim].
        """
        # discrete tokens to continuous embeddings
        # [Batch, Seq_len] --> [Batch, Seq_len, hidden_dim]
        token_emb = self.forward_embedding(text_tokens)

        # [1, Seq_len, Seq_len]
        attn_mask = None
        if self.causal_masking:
            attn_mask = self.build_attention_mask(
                context_length=text_tokens.shape[1], batch_size=text_tokens.shape[0]
            )
            attn_mask = attn_mask.to(device=token_emb.device, dtype=token_emb.dtype)
            key_padding_mask = None

        if self.use_pytorch_mha:
            # [Batch, Seq_len, hidden_dim] --> [Seq_len, Batch, hidden_dim]
            # we will use PyTorch's multi-head attention, which uses sequence-first format
            token_emb = token_emb.transpose(0, 1)

        for layer in self.transformer:
            if self.gradient_ckpt:
                token_emb = gradient_checkpoint_fn(
                    layer, token_emb, None, key_padding_mask, attn_mask
                )
            else:
                token_emb = layer(
                    token_emb,
                    key_padding_mask=key_padding_mask,
                    attn_mask=attn_mask,
                    use_pytorch_mha=self.use_pytorch_mha,
                )

        # Apply layer norm
        token_emb = self.final_layer_norm(token_emb)

        if return_all_tokens:
            if self.use_pytorch_mha:
                # [Seq_len, Batch, hidden_dim] --> [Batch, Seq_len, hidden_dim]
                token_emb = token_emb.transpose(0, 1)
            return token_emb

        # take features from the eot embedding (eot_token is the highest number in each sequence)
        if self.use_pytorch_mha:
            token_emb = token_emb[
                text_tokens.argmax(dim=-1), torch.arange(text_tokens.shape[0])
            ]
        else:
            token_emb = token_emb[
                torch.arange(text_tokens.shape[0]), text_tokens.argmax(dim=-1)
            ]

        token_emb = token_emb @ self.projection_layer

        # normalize text features
        token_emb = F.normalize(token_emb, dim=-1)
        return token_emb

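    # Note on the EOT lookup above: `text_tokens.argmax(dim=-1)` recovers the position of
    # the end-of-text token under the assumption (as in CLIP) that the EOT token has the
    # largest id in the vocabulary, so its index dominates the argmax over each sequence.
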
    def forward_zero_shot(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        *args,
        **kwargs
    ) -> Tensor:
        # In case of zero-shot evaluation, text_tokens is of shape
        # [Batch, num_classes, num_captions, context_length].
        # For example, in the ImageNet dataset, we have 1000 classes, and for each class
        # we generate a certain number of captions (each caption with context_length tokens).
        if self.training:
            raise NotImplementedError(
                "Zero-shot evaluation is only supported in eval mode"
            )

        if text_tokens.ndim != 4:
            logger.error(
                "For zero-shot evaluation, expected size of text is "
                "[Batch, Num_classes, num_captions, context_len]"
            )

        batch_size, num_classes, num_captions, context_len = text_tokens.shape

        # for zero-shot evaluation, text templates are the same across all images in the batch
        # Therefore, the batch size should be 1.
        if batch_size > 1:
            logger.warning(
                "For zero-shot evaluation, text templates are the same across all images in the batch. "
                "Got: {}. Please consider adjusting the collate function.".format(
                    batch_size
                )
            )
            text_tokens = text_tokens[0:1]
            batch_size = 1

        text_features = []
        for start_idx in range(0, num_classes, self.classes_per_split_zero_shot):
            end_idx = min(start_idx + self.classes_per_split_zero_shot, num_classes)

            text_tokens_split = text_tokens[0, start_idx:end_idx, ...]
            num_classes_split = text_tokens_split.shape[0]
            text_tokens_split = text_tokens_split.reshape(
                num_classes_split * num_captions, context_len
            )

            key_padding_mask_split = None
            if key_padding_mask is not None:
                key_padding_mask_split = key_padding_mask[0, start_idx:end_idx, ...]
                key_padding_mask_split = key_padding_mask_split.reshape(
                    num_classes_split * num_captions, context_len
                )

            # [num_classes_per_split * num_captions, context_len] --> [num_classes_per_split * num_captions, latent_dim]
            class_embedding_split = self.encode_text(
                text_tokens=text_tokens_split, key_padding_mask=key_padding_mask_split
            )

            # [num_classes_per_split * num_captions, latent_dim] --> [num_classes_per_split, num_captions, latent_dim]
            class_embedding_split = class_embedding_split.reshape(
                num_classes_split, num_captions, class_embedding_split.shape[-1]
            )

            # Compute the mean over captions for each class
            # [num_classes_per_split, num_captions, latent_dim] --> [num_classes_per_split, latent_dim]
            mean_class_embedding_split = class_embedding_split.mean(dim=1)

            # Normalize the embeddings
            mean_class_embedding_split = F.normalize(mean_class_embedding_split, dim=-1)

            text_features.append(mean_class_embedding_split)

        # [num_classes_per_split, latent_dim] * num_splits --> [num_classes, latent_dim]
        text_features = torch.cat(text_features, dim=0)
        # [num_classes, latent_dim] --> [latent_dim, num_classes]
        text_features = text_features.transpose(0, 1)
        return text_features.contiguous()

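    # Worked example (illustrative values): with num_classes=1000, num_captions=7, and the
    # default classes_per_split_zero_shot=20, `encode_text` above is called on chunks of
    # 20 * 7 = 140 sequences at a time; the 7 caption embeddings of each class are averaged,
    # re-normalized, and concatenated into a [1000, projection_dim] matrix, which is finally
    # returned as [projection_dim, 1000].
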
    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        *args,
        **kwargs
    ) -> Tensor:
        if text_tokens.dim() == 4:
            # Zero-shot evaluation: each class in the dataset has multiple captions.
            # Encoding happens separately for each class/caption split to avoid OOM issues.
            return self.forward_zero_shot(
                text_tokens=text_tokens,
                key_padding_mask=key_padding_mask,
                *args,
                **kwargs
            )
        elif text_tokens.dim() == 2:
            # Image-text pair data with a single caption
            # [B, CL] --> [B, d]
            text_tokens = self.encode_text(
                text_tokens=text_tokens,
                key_padding_mask=key_padding_mask,
                *args,
                **kwargs
            )
            return text_tokens
        elif text_tokens.dim() == 3:
            # Image-text pairs with multiple captions per image (e.g., Flickr-30k).
            # Treat them as separate captions by folding them into the batch dimension.
            # [B, N, CL] --> [B*N, CL] --encode--> [B*N, d] --> [B, N, d]
            b, n, _ = text_tokens.shape
            text_tokens = text_tokens.reshape(b * n, -1)
            if key_padding_mask is not None:
                key_padding_mask = key_padding_mask.reshape(b * n, -1)
            text_tokens = self.encode_text(
                text_tokens=text_tokens,
                key_padding_mask=key_padding_mask,
                *args,
                **kwargs
            )
            text_tokens = text_tokens.reshape(b, n, -1)
            return text_tokens
        else:
            raise NotImplementedError
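

# Usage sketch (illustrative only; `projection_dim=256` and the tensor shapes are assumed
# values, and `opts` must be populated with the options registered in `add_arguments`
# above plus the dataset options read in `__init__`, e.g. `dataset.text_context_length`
# and `dataset.padding_index`, along with whatever `BaseTextEncoder` expects):
#
#   model = TextTransformer(opts, projection_dim=256).eval()
#
#   tokens = torch.randint(0, model.vocab_size, (4, 77))    # [Batch, Seq_len]
#   out = model(tokens)                                      # [4, 256], L2-normalized
#
#   multi_cap = torch.randint(0, model.vocab_size, (4, 5, 77))        # [B, N, CL]
#   out = model(multi_cap)                                            # [4, 5, 256]
#
#   zero_shot = torch.randint(0, model.vocab_size, (1, 1000, 7, 77))  # [B, C, caps, CL]
#   out = model(zero_shot)                                            # [256, 1000]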