#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Dict, Optional
from torch import Tensor
from cvnets.layers import ConvLayer2d
from cvnets.misc.init_utils import initialize_weights
from cvnets.models import MODEL_REGISTRY
from cvnets.models.segmentation.heads.base_seg_head import BaseSegHead
from cvnets.modules import PSP
@MODEL_REGISTRY.register(name="pspnet", type="segmentation_head")
class PSPNet(BaseSegHead):
    """
    This class defines the segmentation head in `PSPNet architecture <https://arxiv.org/abs/1612.01105>`_

    The head feeds the encoder's level-5 features (or the level-5 expansion
    features when ``use_l5_exp`` is set) through a ``PSP`` module, then through
    a 1x1 convolution whose output has one channel per segmentation class.

    Args:
        opts: command-line arguments
        enc_conf (Dict): Encoder input-output configuration at each spatial level
        use_l5_exp (Optional[bool]): Use features from expansion layer in Level5 in the encoder
    """

    # Defaults shared by ``__init__`` (fallbacks when an option is missing from
    # ``opts``) and ``add_arguments`` (CLI defaults), so the two cannot drift.
    _DEFAULT_PSP_OUT_CHANNELS = 512
    _DEFAULT_PSP_POOL_SIZES = (1, 2, 3, 6)
    _DEFAULT_PSP_DROPOUT = 0.1

    def __init__(
        self, opts, enc_conf: dict, use_l5_exp: Optional[bool] = False, *args, **kwargs
    ) -> None:
        # NOTE: ``*args`` / ``**kwargs`` are accepted for interface compatibility
        # but are not forwarded to ``BaseSegHead.__init__``.
        psp_out_channels = getattr(
            opts,
            "model.segmentation.pspnet.psp_out_channels",
            self._DEFAULT_PSP_OUT_CHANNELS,
        )
        psp_pool_sizes = getattr(
            opts,
            "model.segmentation.pspnet.psp_pool_sizes",
            # A fresh list per call, matching the original list-literal default.
            list(self._DEFAULT_PSP_POOL_SIZES),
        )
        psp_dropout = getattr(
            opts,
            "model.segmentation.pspnet.psp_dropout",
            self._DEFAULT_PSP_DROPOUT,
        )

        super().__init__(opts=opts, enc_conf=enc_conf, use_l5_exp=use_l5_exp)

        # Same selection as ``forward_seg_head``: the PSP input width must match
        # whichever level-5 feature map is consumed at runtime.
        psp_in_channels = (
            self.enc_l5_exp_channels if self.use_l5_exp else self.enc_l5_channels
        )

        self.psp_layer = PSP(
            opts=opts,
            in_channels=psp_in_channels,
            out_channels=psp_out_channels,
            pool_sizes=psp_pool_sizes,
            dropout=psp_dropout,
        )
        # 1x1 projection from PSP features to per-class outputs (no norm/activation).
        self.classifier = ConvLayer2d(
            opts=opts,
            in_channels=psp_out_channels,
            out_channels=self.n_seg_classes,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )

        self.reset_head_parameters(opts=opts)

    def update_classifier(self, opts, n_classes: int) -> None:
        """
        This function updates the classification layer in a model. Useful for finetuning purposes.

        A new 1x1 ``ConvLayer2d`` is built with the same input width as the
        current classifier and ``n_classes`` output channels. It is initialized
        via ``initialize_weights`` and then replaces ``self.classifier``.

        Args:
            opts: command-line arguments
            n_classes (int): Number of output classes for the new classifier
        """
        # NOTE(review): ``self.n_seg_classes`` is not updated here; confirm no
        # caller relies on it matching the replaced classifier.
        in_channels = self.classifier.in_channels
        conv_layer = ConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=n_classes,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )
        initialize_weights(opts, modules=conv_layer)
        self.classifier = conv_layer

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Add PSPNet-head-specific command-line arguments to ``parser``.

        Args:
            parser (argparse.ArgumentParser): Parser to extend

        Returns:
            The same ``parser``, with a ``PSPNet`` argument group added.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--model.segmentation.pspnet.psp-pool-sizes",
            type=int,
            nargs="+",
            default=list(cls._DEFAULT_PSP_POOL_SIZES),
            help="Pool sizes in the PSPNet module",
        )
        group.add_argument(
            "--model.segmentation.pspnet.psp-out-channels",
            type=int,
            default=cls._DEFAULT_PSP_OUT_CHANNELS,
            help="Output channels of PSPNet module",
        )
        group.add_argument(
            "--model.segmentation.pspnet.psp-dropout",
            type=float,
            default=cls._DEFAULT_PSP_DROPOUT,
            help="Dropout in the PSPNet module",
        )
        return parser

    def forward_seg_head(self, enc_out: Dict) -> Tensor:
        """
        Compute segmentation outputs from encoder features.

        Args:
            enc_out (Dict): Encoder outputs. Must contain ``"out_l5_exp"`` when
                ``use_l5_exp`` is set, otherwise ``"out_l5"``.

        Returns:
            The classifier output, with ``n_seg_classes`` channels at the
            spatial resolution produced by the PSP layer.
        """
        # Low-resolution (level-5) features feed the pyramid pooling module.
        x = enc_out["out_l5_exp"] if self.use_l5_exp else enc_out["out_l5"]
        x = self.psp_layer(x)
        return self.classifier(x)