# Source code for cvnets.modules.ssd_heads

# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.

from typing import Optional, Tuple

import torch
from torch import Tensor, nn
from torchvision.ops.roi_align import RoIAlign

from cvnets.layers import ConvLayer2d, SeparableConv2d, TransposeConvLayer2d
from cvnets.misc.init_utils import initialize_conv_layer
from cvnets.modules import BaseModule

class SSDHead(BaseModule):
    """Prediction head of the `SSD object detection Head <https://arxiv.org/abs/1512.02325>`_.

    One convolution produces, at every spatial location, ``n_anchors`` groups of
    ``n_coordinates + n_classes`` values. :meth:`forward` flattens these over the
    spatial grid and anchors, then splits them into box and class tensors.

    Args:
        opts: command-line arguments
        in_channels (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
        n_anchors (int): Number of anchors per spatial location
        n_classes (int): Number of classes in the dataset
        n_coordinates (Optional[int]): Number of box coordinates. Default: 4 (x, y, w, h)
        proj_channels (Optional[int]): Width of an optional 1x1 projection applied
            before prediction. ``-1`` disables it. The projection is also skipped
            when this equals ``in_channels`` or when ``kernel_size`` is 1. Default: -1
        kernel_size (Optional[int]): Kernel size of the prediction convolution.
            ``1`` selects a standard (point-wise) convolution; anything larger selects
            a separable convolution. Default: 3
        stride (Optional[int]): Sampling rate of the prediction map. When > 1, only
            every ``stride``-th row and column (starting at ``stride // 2``) is kept,
            so predictions are made on fewer locations than the input has. Default: 1
    """

    def __init__(
        self,
        opts,
        in_channels: int,
        n_anchors: int,
        n_classes: int,
        n_coordinates: Optional[int] = 4,
        proj_channels: Optional[int] = -1,
        kernel_size: Optional[int] = 3,
        stride: Optional[int] = 1,
        *args,
        **kwargs
    ) -> None:
        super().__init__()

        # The 1x1 projection is only added in front of a spatial (separable)
        # convolution, and only when it actually changes the channel count.
        wants_projection = kernel_size > 1 and proj_channels not in (-1, in_channels)

        if wants_projection:
            self.proj_layer = ConvLayer2d(
                opts=opts,
                in_channels=in_channels,
                out_channels=proj_channels,
                kernel_size=1,
                stride=1,
                groups=1,
                bias=False,
                use_norm=True,
                use_act=True,
            )
            self.proj_channels = proj_channels
            # The prediction conv now consumes the projected features.
            in_channels = proj_channels
        else:
            self.proj_layer = None
            self.proj_channels = None

        if kernel_size == 1:
            conv_fn = ConvLayer2d
        else:
            conv_fn = SeparableConv2d

        if kernel_size > 1 and stride > 1:
            # Widen the kernel to at least the sampling stride (rounded up to an
            # odd value) so that the sampled locations still see their neighbourhood.
            odd_stride = stride if stride % 2 != 0 else stride + 1
            kernel_size = max(kernel_size, odd_stride)

        self.loc_cls_layer = conv_fn(
            opts=opts,
            in_channels=in_channels,
            out_channels=n_anchors * (n_coordinates + n_classes),
            kernel_size=kernel_size,
            stride=1,
            groups=1,
            bias=True,
            use_norm=False,
            use_act=False,
        )

        self.n_coordinates = n_coordinates
        self.n_classes = n_classes
        self.n_anchors = n_anchors
        self.k_size = kernel_size
        self.stride = stride
        # Input width of the prediction conv (the projected width when a projection is used).
        self.in_channel = in_channels

        self.reset_parameters()

    def __repr__(self) -> str:
        text = (
            f"{self.__class__.__name__}(in_channels={self.in_channel}, "
            f"n_anchors={self.n_anchors}, n_classes={self.n_classes}, "
            f"n_coordinates={self.n_coordinates}, kernel_size={self.k_size}, "
            f"stride={self.stride}"
        )
        if self.proj_layer is not None:
            text += f", proj=True, proj_channels={self.proj_channels}"
        return text + ")"

    def reset_parameters(self) -> None:
        """Initialise every ``nn.Conv2d`` in this head with Xavier-uniform weights."""
        conv_layers = [m for m in self.modules() if isinstance(m, nn.Conv2d)]
        for conv in conv_layers:
            initialize_conv_layer(module=conv, init_method="xavier_uniform")

    def _sample_fm(self, x: Tensor) -> Tensor:
        """Keep every ``self.stride``-th row and column of ``x``, starting at ``stride // 2``."""
        offset = max(0, self.stride // 2)

        def _positions(length: int) -> Tensor:
            return torch.arange(
                start=offset,
                end=length,
                step=self.stride,
                dtype=torch.int64,
                device=x.device,
            )

        rows = _positions(x.shape[-2])
        cols = _positions(x.shape[-1])

        columns_kept = torch.index_select(x, dim=-1, index=cols)
        return torch.index_select(columns_kept, dim=-2, index=rows)

    def forward(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        """Predict box coordinates and class scores for every anchor.

        Args:
            x: Feature map of shape :math:`(B, C, H, W)`.

        Returns:
            A ``(box_locations, box_classes)`` pair of shapes
            :math:`(B, H'W'A, n\\_coordinates)` and :math:`(B, H'W'A, n\\_classes)`.
            Here :math:`A` is ``n_anchors`` and :math:`H', W'` are the spatial sizes
            after optional stride sampling.
        """
        n_batch = x.shape[0]
        values_per_anchor = self.n_coordinates + self.n_classes

        feats = x if self.proj_layer is None else self.proj_layer(x)

        # [B, A * (coords + classes), H, W]
        preds = self.loc_cls_layer(feats)
        if self.stride > 1:
            preds = self._sample_fm(preds)

        # Channels-last, then fold H, W and anchors into one axis:
        # [B, H, W, A * (coords + classes)] -> [B, H * W * A, coords + classes]
        preds = preds.permute(0, 2, 3, 1).contiguous().view(n_batch, -1, values_per_anchor)

        box_locations, box_classes = torch.split(
            preds, [self.n_coordinates, self.n_classes], dim=-1
        )
        return box_locations, box_classes
class SSDInstanceHead(BaseModule):
    """Instance segmentation head for SSD model.

    Per-box features are cropped with ``RoIAlign``. They then pass through a
    transposed convolution (kernel 2, stride 2) and a 1x1 convolution that emits
    ``n_classes`` channels with no activation applied.
    """

    def __init__(
        self,
        opts,
        in_channels: int,
        n_classes: Optional[int] = 1,
        inner_dim: Optional[int] = 256,
        output_stride: Optional[int] = 1,
        output_size: Optional[int] = 8,
        *args,
        **kwargs
    ) -> None:
        """
        Args:
            opts: command-line arguments
            in_channels (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
            n_classes (Optional[int]): Number of mask output channels (classes). Default: 1
            inner_dim (Optional[int]): Channel width after the transposed convolution. Default: 256
            output_stride (Optional[int]): Ratio of input size to feature-map size. It is
                used as ``1 / output_stride`` for the RoIAlign spatial scale. Default: 1
            output_size (Optional[int]): Spatial size of each crop produced by
                RoIAlign. Default: 8
        """
        super().__init__()

        self.roi_align = RoIAlign(
            output_size=output_size,
            spatial_scale=1.0 / output_stride,
            sampling_ratio=2,
            aligned=True,
        )

        # kernel 2 / stride 2 / no padding: presumably doubles the crop size,
        # assuming TransposeConvLayer2d forwards these to nn.ConvTranspose2d.
        upsample = TransposeConvLayer2d(
            opts=opts,
            in_channels=in_channels,
            out_channels=inner_dim,
            kernel_size=2,
            stride=2,
            bias=True,
            use_norm=False,
            use_act=True,
            auto_padding=False,
            padding=0,
            output_padding=0,
        )
        classifier = ConvLayer2d(
            opts=opts,
            in_channels=inner_dim,
            out_channels=n_classes,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )
        self.seg_head = nn.Sequential(upsample, classifier)

        self.inner_channels = inner_dim
        self.in_channels = in_channels
        self.mask_classes = n_classes

        self.reset_parameters()

    def __repr__(self) -> str:
        name = self.__class__.__name__
        return (
            f"{name}(in_channels={self.in_channels}, "
            f"up_out_channels={self.inner_channels}, "
            f"n_classes={self.mask_classes})"
        )

    def reset_parameters(self) -> None:
        """Initialise every (transposed) 2-D convolution with Kaiming-normal weights."""
        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for conv in (m for m in self.modules() if isinstance(m, conv_types)):
            initialize_conv_layer(module=conv, init_method="kaiming_normal")

    def forward(self, x: Tensor, boxes: Tensor, *args, **kwargs) -> Tensor:
        """Crop ``boxes`` from ``x`` and run the segmentation head on the crops.

        Args:
            x: Feature map passed to ``RoIAlign``.
            boxes: Box specification accepted by ``torchvision.ops.RoIAlign``
                (format follows torchvision's convention; confirm with callers).

        Returns:
            The output of ``seg_head`` for each cropped region.
        """
        return self.seg_head(self.roi_align(x, boxes))