#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
from typing import Dict
from utils import logger
# Widths for each supported "model.classification.mit.mode". Each entry is
#   (mv2 expansion ratio,
#    (layer1 out_channels, layer2 out_channels),
#    ((out_channels, transformer_channels, ffn_dim) for layer3, layer4, layer5)).
_MIT_MODE_SPECS = {
    "xx_small": (2, (16, 24), ((48, 64, 128), (64, 80, 160), (80, 96, 192))),
    "x_small": (4, (32, 48), ((64, 96, 192), (80, 120, 240), (96, 144, 288))),
    "small": (4, (32, 64), ((96, 144, 288), (128, 192, 384), (160, 240, 480))),
}

# Stage names and transformer depths of the "mobilevit" stages; identical for
# every mode. Nominal feature-map sizes are 28x28 / 14x14 / 7x7
# (presumably for a 224x224 input -- confirm).
_MOBILEVIT_LAYER_NAMES = ("layer3", "layer4", "layer5")
_MOBILEVIT_TRANSFORMER_BLOCKS = (2, 4, 3)
_LAST_LAYER_EXP_FACTOR = 4


def _mv2_layer(out_channels: int, expand_ratio: int, num_blocks: int, stride: int) -> Dict:
    """Build the config dict for one stage made of "mv2" blocks."""
    return {
        "out_channels": out_channels,
        "expand_ratio": expand_ratio,
        "num_blocks": num_blocks,
        "stride": stride,
        "block_type": "mv2",
    }


def _mobilevit_layer(
    out_channels: int,
    transformer_channels: int,
    ffn_dim: int,
    transformer_blocks: int,
    mv_expand_ratio: int,
    head_dim,
    num_heads,
) -> Dict:
    """Build the config dict for one "mobilevit" stage (2x2 patches, stride 2)."""
    return {
        "out_channels": out_channels,
        "transformer_channels": transformer_channels,
        "ffn_dim": ffn_dim,
        "transformer_blocks": transformer_blocks,
        "patch_h": 2,
        "patch_w": 2,
        "stride": 2,
        "mv_expand_ratio": mv_expand_ratio,
        "head_dim": head_dim,
        "num_heads": num_heads,
        "block_type": "mobilevit",
    }


def get_configuration(opts) -> Dict:
    """Return the per-stage layer configuration for the "mit" classification model.

    Options read from ``opts`` (looked up with ``getattr`` using dotted names):
        model.classification.mit.mode: "xx_small", "x_small" or "small",
            compared case-insensitively; "small" if the attribute is absent.
        model.classification.mit.head_dim: copied into every "mobilevit"
            stage; None if the attribute is absent.
        model.classification.mit.number_heads: copied into every "mobilevit"
            stage; 4 if the attribute is absent.

    ``logger.error`` is called when the mode is None, or when both head_dim
    and number_heads are non-None. Whether ``logger.error`` aborts is not
    visible here. If it returns, a None mode fails at ``mode.lower()``.

    Args:
        opts: Options object carrying the attributes above.

    Returns:
        A freshly built dict with keys "layer1" and "layer2" ("mv2" stages),
        "layer3" to "layer5" ("mobilevit" stages) and "last_layer_exp_factor".

    Raises:
        NotImplementedError: If the mode is not one of the supported values.
    """
    mode = getattr(opts, "model.classification.mit.mode", "small")
    if mode is None:
        logger.error("Please specify mode")

    head_dim = getattr(opts, "model.classification.mit.head_dim", None)
    # NOTE(review): the fallback of 4 means that, when number_heads is absent
    # from opts, setting head_dim alone still trips the check below -- confirm
    # this is intended.
    num_heads = getattr(opts, "model.classification.mit.number_heads", 4)
    if head_dim is not None and num_heads is not None:
        logger.error(
            "--model.classification.mit.head-dim and --model.classification.mit.number-heads "
            "are mutually exclusive."
        )

    mode = mode.lower()
    try:
        mv2_exp_mult, (layer1_channels, layer2_channels), mobilevit_widths = _MIT_MODE_SPECS[mode]
    except KeyError:
        raise NotImplementedError(
            "Unsupported model.classification.mit.mode '{}'. Supported modes: {}".format(
                mode, ", ".join(_MIT_MODE_SPECS)
            )
        ) from None

    config: Dict = {
        "layer1": _mv2_layer(layer1_channels, mv2_exp_mult, num_blocks=1, stride=1),
        "layer2": _mv2_layer(layer2_channels, mv2_exp_mult, num_blocks=3, stride=2),
    }
    for layer_name, (out_channels, transformer_channels, ffn_dim), transformer_blocks in zip(
        _MOBILEVIT_LAYER_NAMES, mobilevit_widths, _MOBILEVIT_TRANSFORMER_BLOCKS
    ):
        config[layer_name] = _mobilevit_layer(
            out_channels=out_channels,
            transformer_channels=transformer_channels,
            ffn_dim=ffn_dim,
            transformer_blocks=transformer_blocks,
            mv_expand_ratio=mv2_exp_mult,
            head_dim=head_dim,
            num_heads=num_heads,
        )
    config["last_layer_exp_factor"] = _LAST_LAYER_EXP_FACTOR
    return config