Source code for cvnets.models.classification.byteformer

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import math
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import init

from cvnets.layers import (
    LinearLayer,
    embedding,
    get_normalization_layer,
    normalization,
    positional_embedding,
    token_merging,
)
from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel
from cvnets.models.classification.config.byteformer import get_configuration
from cvnets.modules import WindowedTransformerEncoder


def unfold_tokens(t: Tensor, kernel_size: int) -> Tensor:
    """
    Group tokens from tensor @t using torch.Tensor.unfold, using the given
    kernel size. This amounts to windowing @t using overlapping windows of
    size @kernel_size, with overlap of @kernel_size // 2.

    Args:
        t: A tensor of shape [batch_size, sequence_length, num_channels].
        kernel_size: The kernel size.

    Returns:
        A tensor of shape
        [batch_size * ((sequence_length - kernel_size) // (kernel_size // 2) + 1),
        kernel_size, num_channels].
    """
    t = t.unfold(dimension=1, size=kernel_size, step=kernel_size // 2)
    B, L, C, _ = t.shape
    t = t.reshape(B * L, C, kernel_size)
    t = t.transpose(1, 2)
    return t

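# Illustrative shape check (a sketch, not part of the original module), assuming
# a toy input tensor:
#
#   >>> t = torch.randn(2, 10, 3)              # [batch, seq, channels]
#   >>> unfold_tokens(t, kernel_size=4).shape  # (10 - 4) // 2 + 1 = 4 windows
#   torch.Size([8, 4, 3])
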
@MODEL_REGISTRY.register(name="byteformer", type="classification")
class ByteFormer(BaseAnyNNModel):
    """
    This class defines the `ByteFormer <https://arxiv.org/pdf/2306.00238.pdf>`_
    architecture.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)
        byteformer_config = get_configuration(opts)
        embed_dim = byteformer_config["embed_dim"]
        ffn_dim = byteformer_config["ffn_dim"]
        n_transformer_layers = byteformer_config["n_transformer_layers"]
        num_heads = byteformer_config["n_attn_heads"]
        attn_dropout = byteformer_config["attn_dropout"]
        dropout = byteformer_config["dropout"]
        ffn_dropout = byteformer_config["ffn_dropout"]
        norm_layer = byteformer_config["norm_layer"]

        # This is usually 257 in the case of byte inputs (2**8 + 1 mask token).
        vocab_size = getattr(opts, "model.classification.byteformer.vocab_size")
        self.embeddings = embedding.Embedding(
            opts, num_embeddings=vocab_size, embedding_dim=embed_dim, padding_idx=-1
        )
        # Reinitialize everything except the padding index.
        init.trunc_normal_(self.embeddings.weight[:-1], std=math.sqrt(1.0 / embed_dim))

        self.dummy_input_token_length = getattr(
            opts, "model.classification.byteformer.dummy_input_token_length"
        )

        # Add the token reduction convolution.
        self.conv_kernel_size = getattr(
            opts, "model.classification.byteformer.conv_kernel_size"
        )
        if self.conv_kernel_size == 0:
            # We skip the convolution.
            self.token_reduction_net = None
        else:
            self.token_reduction_net = nn.Conv1d(
                embed_dim,
                embed_dim,
                kernel_size=self.conv_kernel_size,
                stride=self.conv_kernel_size // 2,
                bias=False,
            )

        # Add the positional embeddings.
        self.max_num_tokens = getattr(
            opts, "model.classification.byteformer.max_num_tokens"
        )
        self.sinusoidal_pos_embed = getattr(
            opts, "model.classification.byteformer.sinusoidal_pos_emb"
        )
        self.pos_embed = positional_embedding.PositionalEmbedding(
            opts=opts,
            num_embeddings=self.max_num_tokens,
            embedding_dim=embed_dim,
            sequence_first=False,
            padding_idx=None,
            is_learnable=not self.sinusoidal_pos_embed,
            interpolation_mode="bilinear",
        )
        pos_emb_drop_p = getattr(opts, "model.classification.byteformer.dropout")
        self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)

        # Build the transformer backbone.
        window_sizes = getattr(opts, "model.classification.byteformer.window_sizes")
        window_shifts = getattr(opts, "model.classification.byteformer.window_shifts")
        downsample = getattr(opts, "model.classification.byteformer.downsample")
        if len(window_sizes) == 1:
            window_sizes = window_sizes * n_transformer_layers
        for x in [window_sizes, window_shifts, downsample]:
            if len(x) != n_transformer_layers:
                raise ValueError(
                    f"Invalid argument length {len(x)} != {n_transformer_layers}"
                )

        stochastic_dropout = getattr(
            opts, "model.classification.byteformer.stochastic_dropout"
        )
        per_layer_stochastic_drop_rate = [
            round(x, 3)
            for x in np.linspace(0, stochastic_dropout, n_transformer_layers)
        ]

        blocks = []
        self.downsamplers = nn.ModuleDict()
        for layer_idx in range(n_transformer_layers):
            blocks.append(
                WindowedTransformerEncoder(
                    opts=opts,
                    embed_dim=embed_dim,
                    ffn_latent_dim=ffn_dim,
                    num_heads=num_heads,
                    attn_dropout=attn_dropout,
                    dropout=dropout,
                    ffn_dropout=ffn_dropout,
                    transformer_norm_layer=norm_layer,
                    stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
                    window_size=window_sizes[layer_idx],
                    window_shift=window_shifts[layer_idx],
                )
            )
            if downsample is not None and downsample[layer_idx]:
                self.downsamplers[
                    self.get_downsampler_name(layer_idx)
                ] = token_merging.TokenMerging(embed_dim)
        self.transformer = nn.Sequential(*blocks)
        self.post_transformer_norm = get_normalization_layer(
            opts=opts, num_features=embed_dim, norm_type=norm_layer
        )

        num_classes = getattr(opts, "model.classification.n_classes")
        self.classifier = LinearLayer(embed_dim, num_classes)

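    # Illustrative sketch (not part of the original module): the per-layer
    # stochastic depth rates above ramp linearly from 0 to `stochastic_dropout`.
    # For example, with stochastic_dropout=0.1 and 12 transformer layers:
    #
    #   >>> import numpy as np
    #   >>> [round(float(x), 3) for x in np.linspace(0, 0.1, 12)]
    #   [0.0, 0.009, 0.018, 0.027, 0.036, 0.045, 0.055, 0.064, 0.073, 0.082, 0.091, 0.1]
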
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        if cls != ByteFormer:
            # Don't re-register arguments in subclasses that don't override
            # `add_arguments()`.
            return parser
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--model.classification.byteformer.dropout",
            type=float,
            default=0.0,
            help="Dropout in Byteformer layers. Defaults to 0.0.",
        )
        group.add_argument(
            "--model.classification.byteformer.stochastic-dropout",
            type=float,
            default=0.0,
            help="Probability of applying stochastic dropout to "
            "TransformerEncoder submodules. Defaults to 0.0.",
        )
        group.add_argument(
            "--model.classification.byteformer.norm-layer",
            type=str,
            default="layer_norm",
            help="Normalization layer in Byteformer. Defaults to LayerNorm.",
            choices=list(normalization.NORM_LAYER_REGISTRY.keys()),
        )
        group.add_argument(
            "--model.classification.byteformer.sinusoidal-pos-emb",
            action="store_true",
            default=False,
            help="Use sinusoidal instead of learnable positional encoding. "
            "Defaults to False.",
        )
        group.add_argument(
            "--model.classification.byteformer.use-pytorch-mha",
            action="store_true",
            default=False,
            help="Use PyTorch's native multi-head attention. Defaults to False.",
        )
        group.add_argument(
            "--model.classification.byteformer.mode",
            type=str,
            default="tiny",
            help="Byteformer mode, which determines the model size. Defaults to tiny.",
            choices=("tiny", "small", "base", "huge"),
        )
        group.add_argument(
            "--model.classification.byteformer.vocab-size",
            type=int,
            default=257,
            help="The vocab size of the token embedding. Defaults to 257, "
            "corresponding to the number of unique bytes (256) plus 1 "
            "more for the mask token.",
        )
        group.add_argument(
            "--model.classification.byteformer.max-num-tokens",
            type=int,
            default=10000,
            help="The maximum number of tokens that can be input to the "
            "network. Defaults to 10000.",
        )
        group.add_argument(
            "--model.classification.byteformer.conv-kernel-size",
            type=int,
            default=16,
            help="The size of the kernel of the initial downsampling conv1d. "
            "Defaults to 16.",
        )
        group.add_argument(
            "--model.classification.byteformer.window-sizes",
            type=int,
            nargs="*",
            default=[128],
            help="A list of window sizes used in shifted window attention. If "
            "the list has length 1, the same window size is used for all "
            "windows. Defaults to 128 for all windows.",
        )
        group.add_argument(
            "--model.classification.byteformer.window-shifts",
            type=int,
            nargs="*",
            default=[0, 64] * 6,
            help="A list of shifts used in shifted window attention. Defaults "
            "to values that alternate between 0 and 64.",
        )
        default_downsampling = [True, True] + ([False, True] * 4) + [False, False]
        group.add_argument(
            "--model.classification.byteformer.downsample",
            type=bool,
            nargs="*",
            default=default_downsampling,
            help="A list of boolean values, where the i'th element specifies "
            "whether to downsample after the transformer block with index i. "
            f"Defaults to {default_downsampling}.",
        )
        group.add_argument(
            "--model.classification.byteformer.padding-index",
            type=int,
            default=-1,
            help="The index used for padding tokens. Defaults to -1.",
        )
        group.add_argument(
            "--model.classification.byteformer.dummy-input-token-length",
            type=int,
            default=48564,
            help="The token length to use for dummy inputs. Defaults to 48564, "
            "corresponding to the average length of 224x224 JPEG images from "
            "ImageNet.",
        )
        return parser

    def dummy_input_and_label(self, batch_size: int) -> Dict:
        """
        Get a dummy input and label that could be passed to the model.

        Args:
            batch_size: The batch size to use for the generated inputs.

        Returns:
            A dict with
                {
                    "samples": tensor of shape [batch_size, sequence_length],
                    "targets": tensor of shape [batch_size],
                }
        """
        n_labels = 10
        max_value = 257
        samples = torch.randint(
            0, max_value, [batch_size, self.dummy_input_token_length]
        )
        targets = torch.randint(low=0, high=n_labels, size=(batch_size,)).long()
        return {"samples": samples, "targets": targets}

    def apply_token_reduction_net(
        self, x: Tensor, x_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply the portion of the network used to reduce sequence lengths before
        the transformer backbone.

        Args:
            x: The input token embeddings of shape [batch_size, sequence_length,
                embed_dim].
            x_mask: The input mask of shape [batch_size, sequence_length].

        Returns:
            New versions of @x and @x_mask, downsampled along the sequence
            dimension by the token reduction net.
        """
        B, N, C = x.shape
        if self.token_reduction_net is None:
            return x, x_mask
        x = self.token_reduction_net(x.permute(0, 2, 1)).permute(0, 2, 1)
        if x_mask is not None:
            x_mask = unfold_tokens(
                x_mask.reshape(B, N, 1).float(), self.conv_kernel_size
            )
            # The mask is now [batch_size * new_sequence_length, kernel_size, 1].
            # It contains values in {0, -inf}.
            x_mask = x_mask.max(dim=1).values.view(x.shape[0], x.shape[1])

        assert x.shape[:2] == x_mask.shape
        return x, x_mask

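    # Illustrative sketch (not part of the original module): with the default
    # conv_kernel_size of 16 (stride 8), the token reduction conv shortens the
    # sequence roughly 8x, following the usual Conv1d output-length formula:
    #
    #   seq_len = 10000
    #   kernel, stride = 16, 16 // 2
    #   new_seq_len = (seq_len - kernel) // stride + 1   # -> 1249
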
    def get_backbone_inputs(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Convert input bytes into embeddings to be passed to the network's
        transformer backbone.

        Args:
            x: The input bytes as an integer tensor of shape [batch_size,
                sequence_length]. Integer tensors are expected (rather than byte
                tensors) since -1 is usually used for padding.

        Returns:
            The embeddings of shape [batch_size, new_sequence_length, embed_dim]
            and a mask tensor of shape [batch_size, new_sequence_length]. The
            mask contains 0 at unmasked positions and float(-inf) at masked
            positions.
        """
        mask = torch.zeros_like(x, dtype=torch.float)
        # Use boolean-mask assignment so the padded positions are written in place.
        mask[x == -1] = float("-inf")
        mask = mask.detach().requires_grad_(False)

        x[x == -1] = self.embeddings.padding_idx
        x = self.embeddings(x)
        x, mask = self.apply_token_reduction_net(x, mask)
        x = x + self.pos_embed(self.max_num_tokens)[:, : x.shape[1]]
        x = self.emb_dropout(x)
        return x, mask

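    # Illustrative sketch (not part of the original module): the key padding mask
    # uses 0 for real tokens and -inf for padding, matching the additive-mask
    # convention of attention layers:
    #
    #   >>> x = torch.tensor([[65, 66, -1]])
    #   >>> mask = torch.zeros_like(x, dtype=torch.float)
    #   >>> mask[x == -1] = float("-inf")
    #   >>> mask
    #   tensor([[0., 0., -inf]])
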
    def backbone_forward(
        self, x: Tensor, key_padding_mask: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Execute the forward pass of the network's transformer backbone.

        Args:
            x: The input embeddings as a [batch_size, sequence_length, embed_dim]
                tensor.
            key_padding_mask: The mask tensor of shape [batch_size,
                sequence_length].

        Returns:
            The outputs of the backbone as a tuple. The first element is the
            feature tensor, and the second element is the updated
            key_padding_mask.
        """
        B, S, _ = x.shape
        assert key_padding_mask.shape == (B, S)

        for layer_idx, elem in enumerate(self.transformer):
            x = elem(x, key_padding_mask=key_padding_mask)
            if self.get_downsampler(layer_idx) is not None:
                x, key_padding_mask = self.get_downsampler(layer_idx)(
                    x, key_padding_mask
                )
        x = self.post_transformer_norm(x)
        return x, key_padding_mask

    def get_downsampler_name(self, idx: int) -> str:
        """
        Get the name of the downsampling layer with index @idx.

        Args:
            idx: The index of the downsampling layer.

        Returns:
            A string representing the name of the downsampling layer.
        """
        return f"downsample_{idx}"

    def get_downsampler(self, idx: int) -> Optional[nn.Module]:
        """
        Get the module that performs downsampling after transformer layer @idx.
        If no downsampling occurs after that layer, return None.

        Args:
            idx: The desired index.

        Returns:
            The downsampling layer, or None.
        """
        name = self.get_downsampler_name(idx)
        if name not in self.downsamplers:
            return None
        return self.downsamplers[name]

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """
        Perform a forward pass on input bytes. The input is an integer tensor
        of shape [batch_size, sequence_length]. Integer tensors are used
        because @x usually contains mask tokens.

        Args:
            x: The input tensor of shape [batch_size, sequence_length].

        Returns:
            The output logits.
        """
        x, key_padding_mask = self.get_backbone_inputs(x)
        x, attn_mask = self.backbone_forward(x, key_padding_mask)

        # Zero out masked positions, then average only over unmasked tokens.
        attn_mask = attn_mask.view(x.shape[0], x.shape[1], 1)
        x[(attn_mask == float("-inf")).expand(-1, -1, x.shape[-1])] = 0
        norms = (attn_mask == 0).sum(dim=1)
        x = torch.sum(x, dim=1) / norms

        x = self.classifier(x)
        return x

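    # Illustrative sketch (not part of the original module): the pooling above is
    # a masked mean, i.e. padded positions (mask == -inf) are zeroed and the sum
    # is divided by the number of unmasked positions per example:
    #
    #   >>> feats = torch.tensor([[[2.0, 4.0], [6.0, 8.0], [100.0, 100.0]]])
    #   >>> mask = torch.tensor([[0.0, 0.0, float("-inf")]]).view(1, 3, 1)
    #   >>> feats[(mask == float("-inf")).expand(-1, -1, 2)] = 0
    #   >>> feats.sum(dim=1) / (mask == 0).sum(dim=1)
    #   tensor([[4., 6.]])
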
    @classmethod
    def build_model(cls, opts: argparse.Namespace, *args, **kwargs) -> BaseAnyNNModel:
        """
        Helper function to build a model.

        Args:
            opts: Command-line arguments.

        Returns:
            An instance of `cvnets.models.BaseAnyNNModel`.
        """
        model = cls(opts, *args, **kwargs)
        if getattr(opts, "model.classification.freeze_batch_norm"):
            cls.freeze_norm_layers(opts=opts, model=model)
        return model
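

# Rough usage sketch (hypothetical, not part of the original module; exactly how
# `opts` is populated depends on the surrounding cvnets option-parsing tooling):
#
#   opts = ...  # argparse.Namespace carrying the model.classification.byteformer.* options
#   model = ByteFormer.build_model(opts)
#   byte_ids = torch.randint(0, 256, (2, 1024))  # a batch of 2 byte sequences
#   logits = model(byte_ids)                     # [2, model.classification.n_classes]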