Source code for cvnets.models.multi_modal_img_text.base_multi_modal_img_text

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse

from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel


@MODEL_REGISTRY.register(name="__base__", type="multi_modal_image_text")
class BaseMultiModalImageText(BaseAnyNNModel):
    """Base class for multi-modal image-text data

    Args:
        opts: Command-line arguments
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        """Initialize the parent model and read the per-encoder LR multipliers.

        Args:
            opts: Command-line arguments. Must carry the attributes
                ``model.multi_modal_image_text.lr_multiplier_img_encoder`` and
                ``model.multi_modal_image_text.lr_multiplier_text_encoder``.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.

        Raises:
            AttributeError: If either LR-multiplier attribute is missing
                from ``opts``.
        """
        super().__init__(opts, *args, **kwargs)

        # `getattr` is called without a fallback on purpose: a missing option
        # raises AttributeError instead of silently using some default.
        # The dotted names are not valid identifiers, hence `getattr` rather
        # than plain attribute access.
        self.lr_multiplier_img_encoder = getattr(
            opts, "model.multi_modal_image_text.lr_multiplier_img_encoder"
        )
        self.lr_multiplier_text_encoder = getattr(
            opts, "model.multi_modal_image_text.lr_multiplier_text_encoder"
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add model specific arguments

        The options are registered only when this method is invoked on
        ``BaseMultiModalImageText`` itself; for any other class the parser is
        returned untouched.

        Args:
            parser: Parser that receives the new argument group.

        Returns:
            The same ``parser`` object that was passed in.
        """
        if cls is not BaseMultiModalImageText:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser

        group = parser.add_argument_group(title=cls.__name__)

        # argparse derives each destination name from the option string by
        # dropping the leading "--" and turning "-" into "_", which yields the
        # "model.multi_modal_image_text.*" attribute names read in `__init__`.
        group.add_argument(
            "--model.multi-modal-image-text.name",
            type=str,
            default=None,
            help="Name of the multi-modal image-text model",
        )
        group.add_argument(
            "--model.multi-modal-image-text.lr-multiplier-img-encoder",
            type=float,
            default=1.0,
            help=f"LR multiplier for the image encoder in {cls.__name__}",
        )
        group.add_argument(
            "--model.multi-modal-image-text.lr-multiplier-text-encoder",
            type=float,
            default=1.0,
            help=f"LR multiplier for the text encoder in {cls.__name__}",
        )
        group.add_argument(
            "--model.multi-modal-image-text.pretrained",
            type=str,
            default=None,
            help="Path of the pretrained backbone",
        )
        group.add_argument(
            "--model.multi-modal-image-text.freeze-batch-norm",
            action="store_true",
            help="Freeze batch norm layers",
        )
        return parser