# Source code for cvnets.models.segmentation.enc_dec

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Dict, Optional, Tuple, Union

from torch import Tensor

from cvnets.models import MODEL_REGISTRY, BaseAnyNNModel, get_model
from cvnets.models.classification.base_image_encoder import BaseImageEncoder
from cvnets.models.segmentation.base_seg import (
    BaseSegmentation,
    set_model_specific_opts_before_model_building,
    unset_model_specific_opts_after_model_building,
)


@MODEL_REGISTRY.register(name="encoder_decoder", type="segmentation")
class SegEncoderDecoder(BaseSegmentation):
    """Encoder-decoder architecture for semantic segmentation.

    A classification backbone (e.g., MobileViT or ResNet) extracts multi-scale
    features, and a segmentation head (e.g., PSPNet or DeepLabv3) decodes them
    into per-pixel predictions.

    Args:
        opts: Command-line arguments.
        encoder: Backbone network producing the multi-scale end points consumed
            by the segmentation head.
        seg_head: Segmentation head module that maps encoder end points to the
            segmentation output.
    """

    def __init__(
        self, opts, encoder: BaseImageEncoder, seg_head, *args, **kwargs
    ) -> None:
        super().__init__(opts=opts, encoder=encoder)

        # The classification layer of the backbone is not used for
        # segmentation, so drop it.
        self.encoder.classifier = None
        use_l5_exp = getattr(opts, "model.segmentation.use_level5_exp")
        if not use_l5_exp:
            # The expanded 1x1 conv after level 5 is only needed when the
            # head consumes the expanded level-5 features.
            self.encoder.conv_1x1_exp = None

        self.maybe_seg_norm_layer()
        self.seg_head = seg_head
        self.use_l5_exp = use_l5_exp
        self.set_default_norm_layer()

    def get_trainable_parameters(
        self,
        weight_decay: Optional[float] = 0.0,
        no_decay_bn_filter_bias: Optional[bool] = False,
        *args,
        **kwargs,
    ):
        """Separate the parameters of the backbone and the segmentation head.

        This allows different learning rates (via the returned LR multipliers)
        for the backbone and the segmentation head.

        Args:
            weight_decay: Weight decay coefficient forwarded to the sub-modules.
            no_decay_bn_filter_bias: If True, norm/bias parameters are excluded
                from weight decay (handled by the sub-modules).

        Returns:
            A tuple of (parameter groups, per-group LR multipliers), with the
            encoder groups first followed by the segmentation-head groups.
        """
        # NOTE(review): the original code branched on
        # getattr(self.encoder, "enable_layer_wise_lr_decay"), but both the
        # `if` and `else` branches made the *identical* call, so the
        # conditional was dead code. Layer-wise LR decay is presumably handled
        # inside the encoder's own get_trainable_parameters — TODO confirm.
        encoder_params, enc_lr_mult = self.encoder.get_trainable_parameters(
            weight_decay=weight_decay,
            no_decay_bn_filter_bias=no_decay_bn_filter_bias,
            module_name="encoder.",
            *args,
            **kwargs,
        )
        decoder_params, dec_lr_mult = self.seg_head.get_trainable_parameters(
            weight_decay=weight_decay,
            no_decay_bn_filter_bias=no_decay_bn_filter_bias,
            module_name="seg_head.",
            *args,
            **kwargs,
        )

        # Sanity check: every parameter must belong to exactly one of the two
        # groups; otherwise some parameters would not receive an LR multiplier.
        total_params = sum(p.numel() for p in self.parameters())
        encoder_params_count = sum(p.numel() for p in self.encoder.parameters())
        decoder_params_count = sum(p.numel() for p in self.seg_head.parameters())

        assert total_params == encoder_params_count + decoder_params_count, (
            "Total network parameters are not equal to "
            "the sum of encoder and decoder. "
            "{} != {} + {}".format(
                total_params, encoder_params_count, decoder_params_count
            )
        )
        return encoder_params + decoder_params, enc_lr_mult + dec_lr_mult

    def forward(
        self, x: Tensor, *args, **kwargs
    ) -> Union[Tuple[Tensor, Tensor], Tensor, Dict]:
        """Run the encoder and decode its end points with the segmentation head.

        Args:
            x: Input image tensor.

        Returns:
            The segmentation head's output, or — when the encoder returns an
            "augmented_tensor" end point — a dict with keys
            "augmented_tensor" and "segmentation_output".
        """
        enc_end_points: Dict = self.encoder.extract_end_points_all(
            x, use_l5=True, use_l5_exp=self.use_l5_exp
        )

        if "augmented_tensor" in enc_end_points:
            # Surface the (presumably neurally-augmented) input alongside the
            # prediction so downstream losses can consume it — TODO confirm
            # against the encoder's augmentation path.
            return {
                "augmented_tensor": enc_end_points.pop("augmented_tensor"),
                "segmentation_output": self.seg_head(
                    enc_out=enc_end_points, *args, **kwargs
                ),
            }
        return self.seg_head(enc_out=enc_end_points, *args, **kwargs)

    def update_classifier(self, opts, n_classes: int) -> None:
        """Update the classification layer of the segmentation head.

        Useful for fine-tuning on a dataset with a different number of
        classes. A no-op when the head has no ``update_classifier`` method.
        """
        if hasattr(self.seg_head, "update_classifier"):
            self.seg_head.update_classifier(opts, n_classes)

    @classmethod
    def build_model(cls, opts: argparse.Namespace, *args, **kwargs) -> BaseAnyNNModel:
        """Build the backbone and the segmentation head, then assemble them.

        Args:
            opts: Command-line arguments.

        Returns:
            The assembled encoder-decoder segmentation model.
        """
        output_stride = getattr(opts, "model.segmentation.output_stride", None)
        image_encoder = get_model(
            opts,
            category="classification",
            output_stride=output_stride,
            *args,
            **kwargs,
        )

        # Some options must be set before (and restored after) head building.
        default_opt_info = set_model_specific_opts_before_model_building(opts)
        use_l5_exp = getattr(opts, "model.segmentation.use_level5_exp")
        seg_head = get_model(
            opts=opts,
            category="segmentation_head",
            model_name=getattr(opts, "model.segmentation.seg_head"),
            enc_conf=image_encoder.model_conf_dict,
            use_l5_exp=use_l5_exp,
            *args,
            **kwargs,
        )

        seg_model = cls(opts, encoder=image_encoder, seg_head=seg_head, *args, **kwargs)

        unset_model_specific_opts_after_model_building(
            opts, default_opts_info=default_opt_info
        )

        if getattr(opts, "model.segmentation.freeze_batch_norm"):
            cls.freeze_norm_layers(opts, model=seg_model)
        return seg_model