#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Dict
from torch import nn
from cvnets.misc.init_utils import initialize_weights
from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel, get_model
from cvnets.models.classification.base_image_encoder import BaseImageEncoder
from utils import logger
@MODEL_REGISTRY.register(name="__base__", type="detection")
class BaseDetection(BaseAnyNNModel):
    """Base class for the task of object detection.

    Stores the image encoder together with the number of output channels that
    the encoder's ``model_conf_dict`` reports for ``layer1`` ... ``layer5`` and
    for ``exp_before_cls``.

    Args:
        opts: Command-line arguments
        encoder: Image-encoder model (e.g., MobileNet or ResNet)
    """

    def __init__(
        self, opts: argparse.Namespace, encoder: BaseImageEncoder, *args, **kwargs
    ) -> None:
        super().__init__(opts, *args, **kwargs)
        assert isinstance(encoder, BaseImageEncoder)
        self.encoder: BaseImageEncoder = encoder
        # No fallback here on purpose: a missing class count should surface
        # immediately instead of being silently replaced by a guess.
        self.n_detection_classes = getattr(opts, "model.detection.n_classes")

        enc_conf = self.encoder.model_conf_dict

        # Look-ups are kept in the original order (expansion layer first, then
        # layer5 down to layer1) so that any configuration problem is reported
        # for the same entry as before.
        enc_ch_l5_out_proj = check_feature_map_output_channels(
            enc_conf, "exp_before_cls"
        )
        enc_ch_l5_out = check_feature_map_output_channels(enc_conf, "layer5")
        enc_ch_l4_out = check_feature_map_output_channels(enc_conf, "layer4")
        enc_ch_l3_out = check_feature_map_output_channels(enc_conf, "layer3")
        enc_ch_l2_out = check_feature_map_output_channels(enc_conf, "layer2")
        enc_ch_l1_out = check_feature_map_output_channels(enc_conf, "layer1")

        self.enc_l5_channels = enc_ch_l5_out
        self.enc_l5_channels_exp = enc_ch_l5_out_proj
        self.enc_l4_channels = enc_ch_l4_out
        self.enc_l3_channels = enc_ch_l3_out
        self.enc_l2_channels = enc_ch_l2_out
        self.enc_l1_channels = enc_ch_l1_out

        self.opts = opts

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add model specific arguments.

        Args:
            parser: Parser to which the ``--model.detection.*`` options are added.

        Returns:
            The same parser. It is returned unchanged when ``cls`` is a subclass.
        """
        if cls is not BaseDetection:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser

        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--model.detection.name",
            type=str,
            default=None,
            help="Detection model name",
        )
        group.add_argument(
            "--model.detection.n-classes",
            type=int,
            default=80,
            help="Number of classes in the dataset. Defaults to 80.",
        )
        group.add_argument(
            "--model.detection.pretrained",
            type=str,
            default=None,
            help="Path of the pretrained detection model. Defaults to None.",
        )
        group.add_argument(
            "--model.detection.output-stride",
            type=int,
            default=None,
            help="Output stride of the classification network. Defaults to None.",
        )
        group.add_argument(
            "--model.detection.replace-stride-with-dilation",
            action="store_true",
            default=False,
            help="Replace stride with dilation",
        )
        group.add_argument(
            "--model.detection.freeze-batch-norm",
            action="store_true",
            default=False,
            help="Freeze batch norm layers in detection model. Defaults to False.",
        )
        return parser

    @staticmethod
    def reset_layer_parameters(layer: nn.Module, opts: argparse.Namespace) -> None:
        """Initialize weights of a given layer.

        Args:
            layer: Module whose sub-modules (``layer.modules()``) are passed to
                ``initialize_weights``.
            opts: Command-line arguments forwarded to ``initialize_weights``.
        """
        initialize_weights(opts=opts, modules=layer.modules())

    @classmethod
    def build_model(cls, opts: argparse.Namespace, *args, **kwargs) -> BaseAnyNNModel:
        """Build the image encoder and wrap it in a detection model of type ``cls``.

        Args:
            opts: Command-line arguments.
            *args: Extra positional arguments, forwarded to both ``get_model``
                and the ``cls`` constructor.
            **kwargs: Extra keyword arguments, forwarded to both ``get_model``
                and the ``cls`` constructor.

        Returns:
            The constructed detection model. ``cls.freeze_norm_layers`` is
            applied to it when ``model.detection.freeze_batch_norm`` is set.
        """
        output_stride = getattr(opts, "model.detection.output_stride", None)

        # `*args` is written before the keyword arguments: positional arguments
        # are always bound first, so this ordering states what actually happens
        # (the binding is identical to the previous keyword-first spelling).
        image_encoder = get_model(
            *args,
            opts=opts,
            category="classification",
            output_stride=output_stride,
            **kwargs,
        )
        detection_model = cls(*args, opts=opts, encoder=image_encoder, **kwargs)

        # Fall back to the argparse default (False) when the option is absent,
        # consistent with how `output_stride` is read above.
        if getattr(opts, "model.detection.freeze_batch_norm", False):
            # NOTE(review): `freeze_norm_layers` is not defined in this class;
            # presumably inherited from BaseAnyNNModel - confirm.
            cls.freeze_norm_layers(opts, model=detection_model)
        return detection_model
def check_feature_map_output_channels(config: Dict, layer_name: str) -> int:
    """Return the ``"out"`` entry that ``config`` stores for ``layer_name``.

    Args:
        config: Mapping from layer name to a per-layer mapping that is
            expected to hold an ``"out"`` key.
        layer_name: Key to look up in ``config``.

    Returns:
        The value stored under ``config[layer_name]["out"]``.

    Problems are reported through ``logger.error``: once when the layer entry
    is missing or empty, and once when its ``"out"`` value is missing or falsy.
    NOTE(review): the code after the first report dereferences the layer entry,
    so ``logger.error`` is presumably expected not to return - confirm.
    """
    layer_cfg: Dict = config.get(layer_name)
    # `not` covers both a missing (None) entry and an empty mapping.
    if not layer_cfg:
        logger.error(
            "Encoder does not define input-output mapping for {}: Got: {}".format(
                layer_name, config
            )
        )

    out_channels = layer_cfg.get("out")
    # Likewise rejects None as well as a zero / empty channel count.
    if not out_channels:
        logger.error(
            "Output channels are not defined in {} of the encoder. Got: {}".format(
                layer_name, layer_cfg
            )
        )
    return out_channels