#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Dict, Optional, Tuple, Union
from torch import Tensor, nn
from cvnets.layers import ConvLayer2d, Dropout2d, UpSample
from cvnets.misc.common import parameter_list
from cvnets.misc.init_utils import initialize_weights
from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel
from utils import logger


@MODEL_REGISTRY.register(name="__base__", type="segmentation_head")
class BaseSegHead(BaseAnyNNModel):
    """
    Base class for segmentation heads.
    """
    def __init__(
        self, opts, enc_conf: dict, use_l5_exp: Optional[bool] = False, *args, **kwargs
    ):
        enc_ch_l5_exp_out = _check_out_channels(enc_conf, "exp_before_cls")
        enc_ch_l5_out = _check_out_channels(enc_conf, "layer5")
        enc_ch_l4_out = _check_out_channels(enc_conf, "layer4")
        enc_ch_l3_out = _check_out_channels(enc_conf, "layer3")
        enc_ch_l2_out = _check_out_channels(enc_conf, "layer2")
        enc_ch_l1_out = _check_out_channels(enc_conf, "layer1")

        n_seg_classes = getattr(opts, "model.segmentation.n_classes")
        if n_seg_classes is None:
            logger.error(
                "Please specify the number of segmentation classes using "
                "--model.segmentation.n-classes. Got None."
            )

        super().__init__(opts, *args, **kwargs)
        self.use_l5_exp = use_l5_exp
        self.enc_l5_exp_channels = enc_ch_l5_exp_out
        self.enc_l5_channels = enc_ch_l5_out
        self.enc_l4_channels = enc_ch_l4_out
        self.enc_l3_channels = enc_ch_l3_out
        self.enc_l2_channels = enc_ch_l2_out
        self.enc_l1_channels = enc_ch_l1_out
        self.n_seg_classes = n_seg_classes
        self.lr_multiplier = getattr(opts, "model.segmentation.lr_multiplier", 1.0)
        self.classifier_dropout = getattr(
            opts, "model.segmentation.classifier_dropout", 0.1
        )
        self.output_stride = getattr(opts, "model.segmentation.output_stride", 16)

        # Optional auxiliary classification head on the layer-4 features,
        # used during training as an extra supervision signal.
        self.aux_head = None
        if getattr(opts, "model.segmentation.use_aux_head", False):
            drop_aux = getattr(opts, "model.segmentation.aux_dropout", 0.1)
            inner_channels = max(int(self.enc_l4_channels // 4), 128)
            self.aux_head = nn.Sequential(
                ConvLayer2d(
                    opts=opts,
                    in_channels=self.enc_l4_channels,
                    out_channels=inner_channels,
                    kernel_size=3,
                    stride=1,
                    use_norm=True,
                    use_act=True,
                    bias=False,
                    groups=1,
                ),
                Dropout2d(drop_aux),
                ConvLayer2d(
                    opts=opts,
                    in_channels=inner_channels,
                    out_channels=self.n_seg_classes,
                    kernel_size=1,
                    stride=1,
                    use_norm=False,
                    use_act=False,
                    bias=True,
                    groups=1,
                ),
            )

        # Upsample the low-resolution mask back to the input resolution.
        self.upsample_seg_out = None
        if self.output_stride != 1.0:
            self.upsample_seg_out = UpSample(
                scale_factor=self.output_stride, mode="bilinear", align_corners=True
            )

    def forward_aux_head(self, enc_out: Dict) -> Tensor:
        aux_out = self.aux_head(enc_out["out_l4"])
        return aux_out

    def forward_seg_head(self, enc_out: Dict) -> Tensor:
        """Head-specific decoding logic; must be implemented by subclasses."""
        raise NotImplementedError

    def forward(
        self, enc_out: Dict, *args, **kwargs
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        out = self.forward_seg_head(enc_out=enc_out)

        if self.upsample_seg_out is not None:
            # resize the mask to the given size, if one is provided
            mask_size = kwargs.get("orig_size", None)
            if mask_size is not None:
                self.upsample_seg_out.scale_factor = None
                self.upsample_seg_out.size = mask_size
            out = self.upsample_seg_out(out)

        # During training, also return the auxiliary prediction when enabled.
        if self.aux_head is not None and self.training:
            aux_out = self.forward_aux_head(enc_out=enc_out)
            return out, aux_out
        return out

    def reset_head_parameters(self, opts) -> None:
        # weight initialization
        initialize_weights(opts=opts, modules=self.modules())

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add segmentation-head-specific arguments."""
        group = parser.add_argument_group(
            title="Segmentation head arguments",
            description="Segmentation head arguments",
        )
        group.add_argument(
            "--model.segmentation.seg-head",
            type=str,
            default=None,
            help="Name of the segmentation head",
        )
        return parser
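
    # Hedged usage sketch (values illustrative, not part of the library): once
    # this group is registered, a head can be selected and configured from the
    # command line with flags such as:
    #
    #   --model.segmentation.seg-head deeplabv3 \
    #   --model.segmentation.n-classes 21 \
    #   --model.segmentation.use-aux-head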

    def get_trainable_parameters(
        self,
        weight_decay: float = 0.0,
        no_decay_bn_filter_bias: bool = False,
        *args,
        **kwargs
    ):
        param_list = parameter_list(
            named_parameters=self.named_parameters,
            weight_decay=weight_decay,
            no_decay_bn_filter_bias=no_decay_bn_filter_bias,
            *args,
            **kwargs
        )
        # Pair every parameter group with the head's LR multiplier.
        return param_list, [self.lr_multiplier] * len(param_list)
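
    # Hedged sketch (not part of the library) of how the returned pair could
    # be consumed when building optimizer parameter groups; `base_lr` and the
    # group wiring are illustrative assumptions:
    #
    #   params, lr_mults = head.get_trainable_parameters(weight_decay=1e-4)
    #   groups = [{**p, "lr": base_lr * m} for p, m in zip(params, lr_mults)]
    #   optimizer = torch.optim.SGD(groups, lr=base_lr)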

    def update_classifier(self, opts, n_classes: int) -> None:
        """
        Update the classification layer in the model. Useful for fine-tuning.
        """
        raise NotImplementedError

    @classmethod
    def build_model(cls, opts: argparse.Namespace, *args, **kwargs) -> BaseAnyNNModel:
        return cls(opts, *args, **kwargs)
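

# Hedged sketch (hypothetical, not part of the library) of how a concrete head
# could subclass BaseSegHead: a single 1x1 classifier on the last encoder
# stage. The name `SimpleSegHead` is illustrative, and the "out_l5" /
# "out_l5_exp" keys are assumed to mirror the "out_l4" key used by the
# auxiliary head above.
@MODEL_REGISTRY.register(name="simple_seg_head", type="segmentation_head")
class SimpleSegHead(BaseSegHead):
    def __init__(
        self, opts, enc_conf: dict, use_l5_exp: Optional[bool] = False, *args, **kwargs
    ):
        super().__init__(opts, enc_conf, use_l5_exp, *args, **kwargs)
        in_channels = (
            self.enc_l5_exp_channels if self.use_l5_exp else self.enc_l5_channels
        )
        # Project the last encoder feature map directly to class logits.
        self.classifier = ConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=self.n_seg_classes,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
            groups=1,
        )
        self.reset_head_parameters(opts=opts)

    def forward_seg_head(self, enc_out: Dict) -> Tensor:
        x = enc_out["out_l5_exp"] if self.use_l5_exp else enc_out["out_l5"]
        return self.classifier(x)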


def _check_out_channels(config: dict, layer_name: str) -> int:
    enc_ch_l: dict = config.get(layer_name, None)
    if enc_ch_l is None or not enc_ch_l:
        logger.error(
            "Encoder does not define an input-output mapping for {}. Got: {}".format(
                layer_name, config
            )
        )

    enc_ch_l_out = enc_ch_l.get("out", None)
    if enc_ch_l_out is None or not enc_ch_l_out:
        logger.error(
            "Output channels are not defined in {} of the encoder. Got: {}".format(
                layer_name, enc_ch_l
            )
        )

    return enc_ch_l_out
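
# For reference, a hedged sketch (illustrative channel counts, not part of the
# library) of the encoder configuration consumed by `_check_out_channels`:
# each stage name maps to a dict that must provide a non-zero "out" entry.
#
#   enc_conf = {
#       "layer1": {"in": 3, "out": 16},
#       "layer2": {"in": 16, "out": 32},
#       "layer3": {"in": 32, "out": 64},
#       "layer4": {"in": 64, "out": 128},
#       "layer5": {"in": 128, "out": 256},
#       "exp_before_cls": {"in": 256, "out": 1024},
#   }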