# Source code for cvnets.text_encoders.base_text_encoder

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Any, Dict, Optional

import torch
from torch import Tensor, nn

from cvnets import parameter_list
from cvnets.layers import norm_layers_tuple
from cvnets.misc.init_utils import initialize_weights
from utils import logger
from utils.ddp_utils import is_master


class BaseTextEncoder(nn.Module):
    """Base class for text encoders.

    Stores the options, projection dimension and vocabulary size shared by
    concrete text encoders. Subclasses must implement :meth:`forward`.

    Args:
        opts: Options object; values are read with ``getattr`` using dotted
            keys such as ``"dataset.text_vocab_size"``.
        projection_dim: Projection dimension, stored as ``self.projection_dim``.
    """

    def __init__(self, opts, projection_dim: int, *args, **kwargs) -> None:
        is_master_node = is_master(opts)

        # Debug mode forces a vocabulary size of 100, so the dataset option is
        # not required there. Outside debug mode a missing option is treated
        # like an unset (None) one so it reaches the explicit error below
        # instead of surfacing as an AttributeError.
        if getattr(opts, "common.debug_mode", False):
            vocab_size = 100
        else:
            vocab_size = getattr(opts, "dataset.text_vocab_size", None)

        # The problem is only reported on the node for which is_master() is true.
        if vocab_size is None and is_master_node:
            logger.error(
                "Vocabulary size can't be None or -1 in {}. Got: {}".format(
                    self.__class__.__name__, vocab_size
                )
            )

        super().__init__()
        self.opts = opts
        self.projection_dim = projection_dim
        self.is_master_node = is_master_node
        self.vocab_size = vocab_size

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add model specific arguments.

        Registers ``--model.text.name`` (str, default ``None``) in an argument
        group named after the class.

        Args:
            parser: Parser to extend.

        Returns:
            The same ``parser`` instance.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--model.text.name",
            type=str,
            default=None,
            help="Name of the text encoder",
        )
        return parser

    def reset_parameters(self):
        """Initialize model weights.

        Delegates to ``initialize_weights`` with this encoder's options and
        all of its sub-modules.
        """
        initialize_weights(opts=self.opts, modules=self.modules())

    def get_trainable_parameters(
        self,
        weight_decay: Optional[float] = 0.0,
        no_decay_bn_filter_bias: Optional[bool] = False,
        *args,
        **kwargs
    ):
        """Build the parameter list for the optimizer.

        Args:
            weight_decay: Forwarded to ``parameter_list``.
            no_decay_bn_filter_bias: Forwarded to ``parameter_list``.

        Returns:
            A tuple ``(param_list, multipliers)`` where ``multipliers`` is a
            list of ``1.0`` with one entry per element of ``param_list``
            (presumably per-group learning-rate multipliers; confirm against
            the optimizer builder).
        """
        # NOTE: the bound method ``self.named_parameters`` is passed, not its
        # result; parameter_list is expected to call it.
        param_list = parameter_list(
            named_parameters=self.named_parameters,
            weight_decay=weight_decay,
            no_decay_bn_filter_bias=no_decay_bn_filter_bias,
            *args,
            **kwargs
        )
        return param_list, [1.0] * len(param_list)

    def freeze_norm_layers(self) -> None:
        """Freeze every normalization layer in this encoder.

        Each sub-module that is an instance of ``norm_layers_tuple`` is put in
        evaluation mode and its ``weight``/``bias`` stop requiring gradients.
        Layers without affine parameters (attribute missing or ``None``) are
        still put in evaluation mode instead of raising ``AttributeError``.
        """
        for module in self.modules():
            if not isinstance(module, norm_layers_tuple):
                continue
            # eval() also sets module.training = False.
            module.eval()
            for param_name in ("weight", "bias"):
                param = getattr(module, param_name, None)
                if param is not None:
                    param.requires_grad = False

    def forward(
        self,
        text_tokens: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs
    ) -> Any:
        """Encode ``text_tokens``; must be implemented by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            "{} must implement forward()".format(self.__class__.__name__)
        )

    def dummy_input_and_label(self, batch_size: int) -> Dict:
        """Create dummy input and labels for CI/CD purposes.

        Child classes must override it if functionality is different.

        Args:
            batch_size: Number of sequences to generate.

        Returns:
            A dict with the single key ``"text"`` holding a ``torch.long``
            tensor of shape ``(batch_size, 77)`` with token ids drawn from
            ``[0, 10)``. No label entry is included.
        """
        seq_length = 77
        # Token ids stay below 10 regardless of self.vocab_size.
        vocab_size = 10
        text_tensor = torch.randint(
            low=0, high=vocab_size, size=(batch_size, seq_length)
        ).long()
        return {"text": text_tensor}