#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
from typing import List, Tuple
import torch
from torch import Tensor, nn
from cvnets.layers import (
ConvLayer2d,
LinearLayer,
TransposeConvLayer2d,
get_normalization_layer,
)
from cvnets.misc.init_utils import initialize_conv_layer, initialize_fc_layer
# The classes below are adapted from Torchvision v0.12 so that the code remains compatible with earlier torch versions.
class FastRCNNConvFCHead(nn.Sequential):
    """Sequential head: 3x3 conv blocks, a flatten, then fully-connected blocks.

    Every conv block is ``ConvLayer2d`` (created with ``use_norm=False`` and
    ``use_act=False``), followed by the normalization layer returned by
    ``replace_syncbn_with_syncbnfp32`` and an ``nn.ReLU``. Every
    fully-connected block is ``LinearLayer`` followed by an ``nn.ReLU``.
    """

    def __init__(
        self,
        opts,
        input_size: Tuple[int, int, int],
        conv_layers: List[int],
        fc_layers: List[int],
        *args,
        **kwargs,
    ):
        """
        Args:
            opts: options object forwarded to the layer builders.
            input_size (Tuple[int, int, int]): the input size in CHW format.
            conv_layers (list): feature dimensions of each Convolution layer
            fc_layers (list): feature dimensions of each FCN layer
            *args, **kwargs: accepted but not used by this constructor.
        """
        channels, height, width = input_size

        modules = []
        for out_channels in conv_layers:
            modules.append(
                ConvLayer2d(
                    opts,
                    in_channels=channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    use_norm=False,
                    use_act=False,
                )
            )
            modules.append(
                replace_syncbn_with_syncbnfp32(opts, num_features=out_channels)
            )
            modules.append(nn.ReLU(inplace=False))
            channels = out_channels

        modules.append(nn.Flatten())

        # The flattened width is derived from the *input* height and width, so
        # the conv stack is presumably expected to keep the spatial size
        # unchanged -- TODO confirm against ConvLayer2d's padding behaviour.
        in_features = channels * height * width
        for out_features in fc_layers:
            modules += [
                LinearLayer(in_features, out_features, bias=True),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        super().__init__(*modules)

        # Weight init: Kaiming-normal for convolutions, Kaiming-uniform for FC.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                initialize_conv_layer(module=module, init_method="kaiming_normal")
            elif isinstance(module, LinearLayer):
                initialize_fc_layer(module=module, init_method="kaiming_uniform")
class RPNHead(nn.Module):
    """
    Region-proposal head: a shared conv tower feeding a classification branch
    and a box-regression branch.

    Args:
        opts: options object forwarded to the layer builders
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted
        conv_depth (int, optional): number of convolutions
    """

    def __init__(self, opts, in_channels: int, num_anchors: int, conv_depth=1) -> None:
        super().__init__()

        # Shared tower: `conv_depth` repetitions of conv3x3 -> norm -> ReLU,
        # all keeping the channel count at `in_channels`.
        tower = []
        for _ in range(conv_depth):
            tower.append(
                ConvLayer2d(
                    opts,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=3,
                    stride=1,
                    use_norm=False,
                    use_act=False,
                    bias=False,
                )
            )
            tower.append(replace_syncbn_with_syncbnfp32(opts, num_features=in_channels))
            tower.append(nn.ReLU(inplace=False))
        self.conv = nn.Sequential(*tower)

        # Both prediction branches are plain 1x1 convolutions with bias and
        # differ only in their number of output channels.
        branch_kwargs = dict(
            in_channels=in_channels,
            kernel_size=1,
            stride=1,
            use_norm=False,
            use_act=False,
            bias=True,
        )
        self.cls_logits = ConvLayer2d(opts, out_channels=num_anchors, **branch_kwargs)
        self.bbox_pred = ConvLayer2d(
            opts, out_channels=num_anchors * 4, **branch_kwargs
        )

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                initialize_conv_layer(module=module, init_method="normal", std_val=0.01)

    def forward(self, x: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """Apply the head to every feature map in ``x``.

        Returns:
            Two lists with one entry per input feature map, in input order:
            the ``cls_logits`` outputs and the ``bbox_pred`` outputs.
        """
        per_level = []
        for feature in x:
            shared = self.conv(feature)
            per_level.append((self.cls_logits(shared), self.bbox_pred(shared)))

        logits = [cls_out for cls_out, _ in per_level]
        bbox_reg = [box_out for _, box_out in per_level]
        return logits, bbox_reg
class MaskRCNNHeads(nn.Sequential):
    """Stack of dilated 3x3 conv -> norm -> ReLU blocks for the mask branch."""

    def __init__(self, opts, in_channels: int, layers: List, dilation: int):
        """
        Args:
            opts: options object forwarded to the layer builders
            in_channels (int): number of input channels
            layers (list): feature dimensions of each FCN layer
            dilation (int): dilation rate of kernel
        """
        # Pair every layer width with the width that feeds into it.
        widths = [in_channels, *layers]

        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.append(
                ConvLayer2d(
                    opts=opts,
                    in_channels=fan_in,
                    out_channels=fan_out,
                    kernel_size=3,
                    stride=1,
                    dilation=dilation,
                    use_norm=False,
                    use_act=False,
                    bias=False,
                )
            )
            modules.append(
                replace_syncbn_with_syncbnfp32(opts=opts, num_features=fan_out)
            )
            modules.append(nn.ReLU(inplace=False))

        super().__init__(*modules)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                initialize_conv_layer(module=module, init_method="kaiming_normal")
class MaskRCNNPredictor(nn.Sequential):
    """Mask predictor: 2x2 stride-2 transposed conv -> norm -> ReLU -> 1x1 conv."""

    def __init__(
        self, opts, in_channels: int, dim_reduced: int, num_classes: int
    ) -> None:
        """
        Args:
            opts: options object forwarded to the layer builders
            in_channels (int): number of input channels
            dim_reduced (int): number of channels produced by the transposed
                convolution and consumed by the final 1x1 convolution
            num_classes (int): number of output channels of the final 1x1
                convolution
        """
        upsample = TransposeConvLayer2d(
            opts,
            in_channels=in_channels,
            out_channels=dim_reduced,
            kernel_size=2,
            stride=2,
            padding=0,
            output_padding=0,
            use_norm=False,
            use_act=False,
            bias=False,
            groups=1,
        )
        norm = replace_syncbn_with_syncbnfp32(opts, num_features=dim_reduced)
        activation = nn.ReLU(inplace=False)
        mask_logits = ConvLayer2d(
            opts,
            in_channels=dim_reduced,
            out_channels=num_classes,
            kernel_size=1,
            stride=1,
            bias=True,
            use_norm=False,
            use_act=False,
        )

        super().__init__(upsample, norm, activation, mask_logits)

        conv_types = (nn.Conv2d, nn.ConvTranspose2d)
        for module in self.modules():
            if isinstance(module, conv_types):
                initialize_conv_layer(module=module, init_method="kaiming_normal")
class FastRCNNPredictor(nn.Module):
    """
    Two parallel linear layers for Fast R-CNN: one producing class scores and
    one producing bounding-box regression values.

    Args:
        in_channels (int): number of input channels
        num_classes (int): number of output classes (including background)
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        self.cls_score = LinearLayer(in_channels, num_classes, bias=True)
        # Four regression outputs per class.
        self.bbox_pred = LinearLayer(in_channels, num_classes * 4, bias=True)

        linear_layers = (m for m in self.modules() if isinstance(m, LinearLayer))
        for module in linear_layers:
            initialize_fc_layer(module=module, init_method="kaiming_uniform")

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(scores, bbox_deltas)`` for the input ``x``.

        A 4-D input must have its last two dimensions equal to ``[1, 1]``;
        this is checked with ``torch._assert``. The input is flattened from
        dimension 1 onwards before both linear layers are applied.
        """
        is_feature_map = x.dim() == 4
        if is_feature_map:
            torch._assert(
                list(x.shape[2:]) == [1, 1],
                f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}",
            )

        flat = x.flatten(start_dim=1)
        return self.cls_score(flat), self.bbox_pred(flat)
def replace_syncbn_with_syncbnfp32(opts, num_features: int) -> nn.Module:
    """Build the normalization layer used inside the detection heads.

    Reads the ``model.normalization.name`` attribute of ``opts``. If that name
    contains ``"sync"``, the layer is requested from ``get_normalization_layer``
    with ``norm_type="sync_batch_norm_fp32"``; otherwise (including when the
    attribute is missing or ``None``) ``get_normalization_layer`` is called
    without an explicit ``norm_type`` and picks the layer from ``opts`` itself.

    Args:
        opts: options object; only ``model.normalization.name`` is read here.
        num_features (int): number of features passed to the normalization layer.

    Returns:
        The module returned by ``get_normalization_layer``.
    """
    # Sync-BN with 0 batch size does not work well with AMP. To avoid that,
    # we replace all sync_bn in mask rcnn head with FP32 ones.
    norm_name = getattr(opts, "model.normalization.name", None)

    # The option may be unset (None). Previously this called `.find` on None
    # and crashed with AttributeError; now it falls through to the default.
    if norm_name is not None and "sync" in norm_name:
        return get_normalization_layer(
            opts, num_features=num_features, norm_type="sync_batch_norm_fp32"
        )

    return get_normalization_layer(opts=opts, num_features=num_features)