Source code for cvnets.models.segmentation.heads.base_seg_head

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Dict, Optional, Tuple, Union

from torch import Tensor, nn

from cvnets.layers import ConvLayer2d, Dropout2d, UpSample
from cvnets.misc.common import parameter_list
from cvnets.misc.init_utils import initialize_weights
from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel
from utils import logger


@MODEL_REGISTRY.register(name="__base__", type="segmentation_head")
class BaseSegHead(BaseAnyNNModel):
    """Base class for segmentation heads."""
    def __init__(
        self, opts, enc_conf: dict, use_l5_exp: Optional[bool] = False, *args, **kwargs
    ):
        enc_ch_l5_exp_out = _check_out_channels(enc_conf, "exp_before_cls")
        enc_ch_l5_out = _check_out_channels(enc_conf, "layer5")
        enc_ch_l4_out = _check_out_channels(enc_conf, "layer4")
        enc_ch_l3_out = _check_out_channels(enc_conf, "layer3")
        enc_ch_l2_out = _check_out_channels(enc_conf, "layer2")
        enc_ch_l1_out = _check_out_channels(enc_conf, "layer1")

        n_seg_classes = getattr(opts, "model.segmentation.n_classes")
        if n_seg_classes is None:
            logger.error(
                "Please specify number of segmentation classes using "
                "--model.segmentation.n-classes. Got None."
            )

        super().__init__(opts, *args, **kwargs)
        self.use_l5_exp = use_l5_exp
        self.enc_l5_exp_channels = enc_ch_l5_exp_out
        self.enc_l5_channels = enc_ch_l5_out
        self.enc_l4_channels = enc_ch_l4_out
        self.enc_l3_channels = enc_ch_l3_out
        self.enc_l2_channels = enc_ch_l2_out
        self.enc_l1_channels = enc_ch_l1_out

        self.n_seg_classes = n_seg_classes
        self.lr_multiplier = getattr(opts, "model.segmentation.lr_multiplier", 1.0)
        self.classifier_dropout = getattr(
            opts, "model.segmentation.classifier_dropout", 0.1
        )
        self.output_stride = getattr(opts, "model.segmentation.output_stride", 16)

        self.aux_head = None
        if getattr(opts, "model.segmentation.use_aux_head", False):
            drop_aux = getattr(opts, "model.segmentation.aux_dropout", 0.1)
            inner_channels = max(int(self.enc_l4_channels // 4), 128)
            self.aux_head = nn.Sequential(
                ConvLayer2d(
                    opts=opts,
                    in_channels=self.enc_l4_channels,
                    out_channels=inner_channels,
                    kernel_size=3,
                    stride=1,
                    use_norm=True,
                    use_act=True,
                    bias=False,
                    groups=1,
                ),
                Dropout2d(drop_aux),
                ConvLayer2d(
                    opts=opts,
                    in_channels=inner_channels,
                    out_channels=self.n_seg_classes,
                    kernel_size=1,
                    stride=1,
                    use_norm=False,
                    use_act=False,
                    bias=True,
                    groups=1,
                ),
            )

        self.upsample_seg_out = None
        if self.output_stride != 1.0:
            self.upsample_seg_out = UpSample(
                scale_factor=self.output_stride, mode="bilinear", align_corners=True
            )
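    # Illustrative sketch of the `enc_conf` mapping consumed above and
    # validated by `_check_out_channels` at the bottom of this file. Only the
    # "out" entry of each layer is read here; the layer names are the ones
    # queried in `__init__`, while the channel counts below are made-up
    # placeholders, not values from any particular encoder:
    #
    #   enc_conf = {
    #       "layer1": {"out": 32},
    #       "layer2": {"out": 64},
    #       "layer3": {"out": 128},
    #       "layer4": {"out": 256},
    #       "layer5": {"out": 512},
    #       "exp_before_cls": {"out": 1024},
    #   }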
    def forward_aux_head(self, enc_out: Dict) -> Tensor:
        aux_out = self.aux_head(enc_out["out_l4"])
        return aux_out
    def forward_seg_head(self, enc_out: Dict) -> Tensor:
        raise NotImplementedError
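    # A minimal sketch of a concrete subclass, showing which hooks a subclass
    # is expected to fill in. `SimpleSegHead` and its registry name are
    # hypothetical, and the "out_l5"/"out_l5_exp" keys are assumed to follow
    # the "out_l4" naming used by `forward_aux_head` above:
    #
    #   @MODEL_REGISTRY.register(name="simple", type="segmentation_head")
    #   class SimpleSegHead(BaseSegHead):
    #       def __init__(self, opts, enc_conf: dict, *args, **kwargs):
    #           super().__init__(opts, enc_conf=enc_conf, *args, **kwargs)
    #           in_ch = (
    #               self.enc_l5_exp_channels
    #               if self.use_l5_exp
    #               else self.enc_l5_channels
    #           )
    #           self.classifier = ConvLayer2d(
    #               opts=opts,
    #               in_channels=in_ch,
    #               out_channels=self.n_seg_classes,
    #               kernel_size=1,
    #               stride=1,
    #               use_norm=False,
    #               use_act=False,
    #               bias=True,
    #               groups=1,
    #           )
    #
    #       def forward_seg_head(self, enc_out: Dict) -> Tensor:
    #           x = enc_out["out_l5_exp"] if self.use_l5_exp else enc_out["out_l5"]
    #           return self.classifier(x)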
    def forward(
        self, enc_out: Dict, *args, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        out = self.forward_seg_head(enc_out=enc_out)

        if self.upsample_seg_out is not None:
            # resize the mask to the given size, if specified
            mask_size = kwargs.get("orig_size", None)
            if mask_size is not None:
                self.upsample_seg_out.scale_factor = None
                self.upsample_seg_out.size = mask_size

            out = self.upsample_seg_out(out)

        if self.aux_head is not None and self.training:
            aux_out = self.forward_aux_head(enc_out=enc_out)
            return out, aux_out
        return out
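    # Usage sketch (illustrative; `seg_head`, shapes, and sizes are made up).
    # When `orig_size` is passed, the logits are resized to that exact size
    # instead of being scaled by the output stride:
    #
    #   seg_head.eval()
    #   logits = seg_head(enc_out, orig_size=(512, 512))  # [N, n_classes, 512, 512]
    #
    #   seg_head.train()
    #   out = seg_head(enc_out)  # (logits, aux_logits) tuple when aux head is enabled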
    def reset_head_parameters(self, opts) -> None:
        # weight initialization
        initialize_weights(opts=opts, modules=self.modules())
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add segmentation head specific arguments"""
        group = parser.add_argument_group(
            title="Segmentation head arguments",
            description="Segmentation head arguments",
        )
        group.add_argument(
            "--model.segmentation.seg-head",
            type=str,
            default=None,
            help="Segmentation head",
        )
        return parser
    def get_trainable_parameters(
        self,
        weight_decay: float = 0.0,
        no_decay_bn_filter_bias: bool = False,
        *args,
        **kwargs
    ):
        param_list = parameter_list(
            named_parameters=self.named_parameters,
            weight_decay=weight_decay,
            no_decay_bn_filter_bias=no_decay_bn_filter_bias,
            *args,
            **kwargs
        )
        return param_list, [self.lr_multiplier] * len(param_list)
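    # Sketch of how the two return values pair up when building optimizer
    # parameter groups (hypothetical driver code; `base_lr` and `seg_head`
    # are placeholders, not part of this module):
    #
    #   param_groups, lr_mults = seg_head.get_trainable_parameters(
    #       weight_decay=1e-4, no_decay_bn_filter_bias=True
    #   )
    #   for group, lr_mult in zip(param_groups, lr_mults):
    #       group["lr"] = base_lr * lr_mult  # per-group LR scaling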
    def update_classifier(self, opts, n_classes: int) -> None:
        """
        This function updates the classification layer in a model.
        Useful for finetuning purposes.
        """
        raise NotImplementedError
    @classmethod
    def build_model(cls, opts: argparse.Namespace, *args, **kwargs) -> BaseAnyNNModel:
        return cls(opts, *args, **kwargs)
def _check_out_channels(config: dict, layer_name: str) -> int:
    enc_ch_l: dict = config.get(layer_name, None)
    if enc_ch_l is None or not enc_ch_l:
        logger.error(
            "Encoder does not define input-output mapping for {}: Got: {}".format(
                layer_name, config
            )
        )

    enc_ch_l_out = enc_ch_l.get("out", None)
    if enc_ch_l_out is None or not enc_ch_l_out:
        logger.error(
            "Output channels are not defined in {} of the encoder. Got: {}".format(
                layer_name, enc_ch_l
            )
        )

    return enc_ch_l_out