Source code for cvnets.models.segmentation.heads.simple_seg_head

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Dict, Optional

from torch import Tensor

from cvnets.layers import ConvLayer2d
from cvnets.models import MODEL_REGISTRY
from cvnets.models.segmentation.heads.base_seg_head import BaseSegHead


@MODEL_REGISTRY.register(name="simple_seg_head", type="segmentation_head")
class SimpleSegHead(BaseSegHead):
    """Segmentation head made of a single 1x1 classification layer.

    No extra feature processing is applied before classification, which makes
    this head suitable for linear probing on segmentation tasks.

    Args:
        opts: command-line arguments
        enc_conf (Dict): Encoder input-output configuration at each spatial level
        use_l5_exp (Optional[bool]): When truthy, classify the features from the
            Level-5 expansion layer of the encoder instead of the plain Level-5
            output.
    """

    def __init__(
        self,
        opts,
        enc_conf: Dict,
        use_l5_exp: Optional[bool] = False,
        *args,
        **kwargs
    ) -> None:
        super().__init__(opts=opts, enc_conf=enc_conf, use_l5_exp=use_l5_exp)

        # Match the classifier's input width to whichever encoder level
        # forward_seg_head will read from.
        if self.use_l5_exp:
            classifier_in_channels = self.enc_l5_exp_channels
        else:
            classifier_in_channels = self.enc_l5_channels

        # A bare 1x1 conv (bias only, no normalization, no activation) maps
        # encoder channels directly to segmentation classes.
        self.classifier = ConvLayer2d(
            opts=opts,
            in_channels=classifier_in_channels,
            out_channels=self.n_seg_classes,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )

        self.reset_head_parameters(opts=opts)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register head-specific CLI options; this head adds none.

        Args:
            parser: argument parser to extend.

        Returns:
            The same ``parser``, unmodified.
        """
        return parser

    def forward_seg_head(self, enc_out: Dict) -> Tensor:
        """Classify the selected Level-5 encoder features.

        Args:
            enc_out (Dict): encoder outputs. The ``"out_l5_exp"`` entry is read
                when ``self.use_l5_exp`` is truthy; otherwise ``"out_l5"`` is read.

        Returns:
            Tensor: output of the 1x1 classifier, with ``self.n_seg_classes``
            output channels.
        """
        feature_key = "out_l5_exp" if self.use_l5_exp else "out_l5"
        level5_features = enc_out[feature_key]
        return self.classifier(level5_features)