# Source code for cvnets.models.classification.config.mobilevit

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import Dict

from utils import logger


# Per-mode MobileViT architecture table. Every mode has the same layout: two
# "mv2" stages (layer1, layer2) followed by three "mobilevit" stages
# (layer3 @ 28x28, layer4 @ 14x14, layer5 @ 7x7, per the original layer notes).
#   mv2_exp_mult:  used as "expand_ratio" for the mv2 stages and as
#                  "mv_expand_ratio" for the mobilevit stages.
#   mv2:           (out_channels, num_blocks, stride) for layer1, layer2.
#   mobilevit:     (out_channels, transformer_channels, ffn_dim, transformer_blocks)
#                  for layer3, layer4, layer5.
#   last_layer_exp_factor: copied verbatim into the returned config.
_MODE_SPECS = {
    "xx_small": {
        "mv2_exp_mult": 2,
        "mv2": ((16, 1, 1), (24, 3, 2)),
        "mobilevit": ((48, 64, 128, 2), (64, 80, 160, 4), (80, 96, 192, 3)),
        "last_layer_exp_factor": 4,
    },
    "x_small": {
        "mv2_exp_mult": 4,
        "mv2": ((32, 1, 1), (48, 3, 2)),
        "mobilevit": ((64, 96, 192, 2), (80, 120, 240, 4), (96, 144, 288, 3)),
        "last_layer_exp_factor": 4,
    },
    "small": {
        "mv2_exp_mult": 4,
        "mv2": ((32, 1, 1), (64, 3, 2)),
        "mobilevit": ((96, 144, 288, 2), (128, 192, 384, 4), (160, 240, 480, 3)),
        "last_layer_exp_factor": 4,
    },
}


def get_configuration(opts) -> Dict:
    """Build the MobileViT classification backbone configuration.

    Reads these attributes from ``opts`` via ``getattr`` (dotted names):
      * ``model.classification.mit.mode`` (default ``"small"``): one of
        ``"xx_small"``, ``"x_small"``, ``"small"``; compared case-insensitively.
      * ``model.classification.mit.head_dim`` (default ``None``).
      * ``model.classification.mit.number_heads`` (default ``4``).

    Args:
        opts: Options namespace holding the attributes above.

    Returns:
        A newly built dict with keys ``layer1`` ... ``layer5`` and
        ``last_layer_exp_factor``. ``layer1``/``layer2`` have
        ``block_type == "mv2"``. ``layer3``-``layer5`` have
        ``block_type == "mobilevit"`` and carry the ``head_dim`` /
        ``num_heads`` values read from ``opts``.

    Raises:
        NotImplementedError: if the mode is not a supported value.
    """
    mode = getattr(opts, "model.classification.mit.mode", "small")
    if mode is None:
        # NOTE(review): if logger.error does not terminate the process,
        # mode.lower() below raises AttributeError — confirm logger semantics.
        logger.error("Please specify mode")

    head_dim = getattr(opts, "model.classification.mit.head_dim", None)
    num_heads = getattr(opts, "model.classification.mit.number_heads", 4)
    # head_dim and number_heads are mutually exclusive. Because number_heads
    # falls back to 4, setting head_dim only passes this check when
    # number_heads is explicitly None.
    if head_dim is not None and num_heads is not None:
        logger.error(
            "--model.classification.mit.head-dim and --model.classification.mit.number-heads "
            "are mutually exclusive."
        )

    mode = mode.lower()
    spec = _MODE_SPECS.get(mode)
    if spec is None:
        raise NotImplementedError(
            f"Unsupported MobileViT mode {mode!r}; expected one of {sorted(_MODE_SPECS)}"
        )

    mv2_exp_mult = spec["mv2_exp_mult"]
    mv2_stages = spec["mv2"]
    config = {}

    # Convolutional MobileNetV2 stages.
    for idx, (out_channels, num_blocks, stride) in enumerate(mv2_stages, start=1):
        config[f"layer{idx}"] = {
            "out_channels": out_channels,
            "expand_ratio": mv2_exp_mult,
            "num_blocks": num_blocks,
            "stride": stride,
            "block_type": "mv2",
        }

    # MobileViT stages; numbering continues after the mv2 stages. Patch size
    # and stride are 2 for every stage in every mode.
    for idx, (out_channels, transformer_channels, ffn_dim, transformer_blocks) in enumerate(
        spec["mobilevit"], start=len(mv2_stages) + 1
    ):
        config[f"layer{idx}"] = {
            "out_channels": out_channels,
            "transformer_channels": transformer_channels,
            "ffn_dim": ffn_dim,
            "transformer_blocks": transformer_blocks,
            "patch_h": 2,
            "patch_w": 2,
            "stride": 2,
            "mv_expand_ratio": mv2_exp_mult,
            "head_dim": head_dim,
            "num_heads": num_heads,
            "block_type": "mobilevit",
        }

    config["last_layer_exp_factor"] = spec["last_layer_exp_factor"]
    return config