# Source code for cvnets.models.classification.swin_transformer
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Dict, List, Optional
import torch
from torch import Tensor, nn
from cvnets.layers import (
ConvLayer2d,
Dropout,
GlobalPool,
Identity,
LinearLayer,
get_normalization_layer,
)
from cvnets.models import MODEL_REGISTRY
from cvnets.models.classification.base_image_encoder import BaseImageEncoder
from cvnets.models.classification.config.swin_transformer import get_configuration
from cvnets.modules import PatchMerging, Permute, SwinTransformerBlock
from utils import logger
@MODEL_REGISTRY.register(name="swin", type="classification")
class SwinTransformer(BaseImageEncoder):
"""
Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using
Shifted Windows" <https://arxiv.org/pdf/2103.14030>`_ paper.
The code is adapted from `"Torchvision repository" <https://github.com/pytorch/vision/blob/main/torchvision/models/swin_transformer.py>`_
"""
def __init__(self, opts, *args, **kwargs) -> None:
image_channels = 3
num_classes = getattr(opts, "model.classification.n_classes", 1000)
classifier_dropout = getattr(
opts, "model.classification.classifier_dropout", 0.0
)
pool_type = getattr(opts, "model.layer.global_pool", "mean")
super().__init__(opts, *args, **kwargs)
cfg = get_configuration(opts=opts)
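# The mode-specific Swin configuration provides the patch size, embedding dimension,
# per-stage depths and attention heads, window size, and dropout/stochastic-depth rates
# that are unpacked below.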
patch_size = cfg["patch_size"]
embed_dim = cfg["embed_dim"]
depths = cfg["depths"]
window_size = cfg["window_size"]
mlp_ratio = cfg["mlp_ratio"]
num_heads = cfg["num_heads"]
dropout = cfg["dropout"]
attn_dropout = cfg["attn_dropout"]
ffn_dropout = cfg["ffn_dropout"]
stochastic_depth_prob = cfg["stochastic_depth_prob"]
norm_layer = cfg["norm_layer"]
# store model configuration in a dictionary
self.model_conf_dict = dict()
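# Patch embedding: a non-overlapping convolution (kernel size and stride equal to the
# patch size), followed by a permute from NCHW to NHWC and a normalization layer.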
self.conv_1 = nn.Sequential(
*[
ConvLayer2d(
opts=opts,
in_channels=image_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
use_norm=False,
use_act=False,
),
Permute([0, 2, 3, 1]),
get_normalization_layer(
opts=opts, norm_type=norm_layer, num_features=embed_dim
),
]
)
self.model_conf_dict["conv1"] = {"in": image_channels, "out": embed_dim}
in_channels = embed_dim
self.model_conf_dict["layer1"] = {"in": embed_dim, "out": embed_dim}
# build SwinTransformer blocks
layers: List[nn.Module] = []
total_stage_blocks = sum(depths)
stage_block_id = 0
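# Build the four Swin stages; the channel dimension doubles at every stage.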
for i_stage in range(len(depths)):
stage: List[nn.Module] = []
dim = embed_dim * 2**i_stage
for i_layer in range(depths[i_stage]):
# adjust stochastic depth probability based on the depth of the stage block
sd_prob = (
stochastic_depth_prob
* float(stage_block_id)
/ (total_stage_blocks - 1)
)
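# The drop probability grows linearly from 0 for the first block to
# stochastic_depth_prob for the last block across all stages. Blocks within a stage
# alternate between regular and shifted window attention: shift_size is 0 for
# even-indexed blocks and window_size // 2 for odd-indexed blocks.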
stage.append(
SwinTransformerBlock(
opts,
dim,
num_heads[i_stage],
window_size=window_size,
shift_size=[
0 if i_layer % 2 == 0 else w // 2 for w in window_size
],
mlp_ratio=mlp_ratio,
dropout=dropout,
attn_dropout=attn_dropout,
ffn_dropout=ffn_dropout,
stochastic_depth_prob=sd_prob,
norm_layer=norm_layer,
)
)
stage_block_id += 1
# add a patch merging layer between stages (downsamples spatially by 2x and doubles channels)
if i_stage < (len(depths) - 1):
stage += [PatchMerging(opts, dim, norm_layer)]
layers.append(nn.Sequential(*stage))
self.model_conf_dict["layer{}".format(i_stage + 2)] = {
"in": in_channels,
"out": dim,
}
in_channels = dim
self.layer_1, self.layer_2, self.layer_3, self.layer_4 = layers
# For segmentation architectures, striding needs to be disabled at an output stride of 8 or 16.
# Depending on the output stride, we disable striding in the corresponding PatchMerging layers below.
if self.dilate_l5:
for m in self.layer_3.modules():
if isinstance(m, PatchMerging):
m.strided = False
if self.dilate_l4:
for m in self.layer_2.modules():
if isinstance(m, PatchMerging):
m.strided = False
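# Final normalization over the NHWC feature map, followed by a permute back to
# channels-first (NCHW).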
self.layer_5 = nn.Sequential(
*[
get_normalization_layer(
opts=opts, norm_type=norm_layer, num_features=in_channels
),
Permute([0, 3, 1, 2]),
]
)
self.conv_1x1_exp = Identity()
self.model_conf_dict["exp_before_cls"] = {
"in": in_channels,
"out": in_channels,
}
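# Classification head: global pooling, optional dropout, and a linear classifier.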
self.classifier = nn.Sequential()
self.classifier.add_module(
name="global_pool", module=GlobalPool(pool_type=pool_type, keep_dim=False)
)
if 0.0 < classifier_dropout < 1.0:
self.classifier.add_module(
name="classifier_dropout", module=Dropout(p=classifier_dropout)
)
self.classifier.add_module(
name="classifier_fc",
module=LinearLayer(
in_features=in_channels, out_features=num_classes, bias=True
),
)
self.model_conf_dict["cls"] = {"in": in_channels, "out": num_classes}
extract_enc_point_format = getattr(
opts, "model.classification.swin.extract_end_point_format", "nchw"
)
if extract_enc_point_format not in ["nchw", "nhwc"]:
logger.error(
"End point extraction format should be either nchw or nhwc. Got: {}".format(
extract_enc_point_format
)
)
self.extract_end_point_nchw_format = extract_enc_point_format == "nchw"
# check model
self.check_model()
# weight initialization
self.reset_parameters(opts=opts)
def extract_end_points_all(
self,
x: Tensor,
use_l5: Optional[bool] = True,
use_l5_exp: Optional[bool] = False,
*args,
**kwargs
) -> Dict[str, Tensor]:
# The first conv layer in SwinTransformer downsamples by a factor of 4, so the end-point
# extraction function is adjusted to keep the model compatible with down-stream heads (e.g., Mask-RCNN)
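# The returned dictionary maps end-point names to feature tensors: out_l1 ... out_l4 always,
# out_l5 when use_l5 is True, out_l5_exp when use_l5_exp is True, and augmented_tensor when
# neural augmentation is active during training. Tensors are in NCHW or NHWC format,
# depending on the extract-end-point-format option.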
out_dict = {} # Use dictionary over NamedTuple so that JIT is happy
if self.training and self.neural_augmentor is not None:
x = self.neural_augmentor(x)
out_dict["augmented_tensor"] = x
# [N, C, H, W] --> [N, H/4, W/4, C]
x = self.conv_1(x)
# the first conv layer downsamples by 4, so the l1 and l2 end points are identical to the patch-embedding output
if self.extract_end_point_nchw_format:
x_nchw = torch.permute(x, dims=(0, 3, 1, 2))
out_dict["out_l1"] = x_nchw
out_dict["out_l2"] = x_nchw
else:
out_dict["out_l1"] = x
out_dict["out_l2"] = x
# [N, H/4, W/4, C] --> [N, H/8, W/8, C]
x = self.layer_1(x)
out_dict["out_l3"] = (
torch.permute(x, dims=(0, 3, 1, 2))
if self.extract_end_point_nchw_format
else x
)
# [N, H/8, W/8, C] --> [N, H/16, W/16, C]
x = self.layer_2(x)
out_dict["out_l4"] = (
torch.permute(x, dims=(0, 3, 1, 2))
if self.extract_end_point_nchw_format
else x
)
if use_l5:
# [N, H/16, W/16, C] --> [N, H/32, W/32, C]
x = self.layer_3(x)
x = self.layer_4(x)
# [N, H/32, W/32, C] --> [N, C, H/32, W/32]
x = self.layer_5(x)
out_dict["out_l5"] = (
x
if self.extract_end_point_nchw_format
else torch.permute(x, dims=(0, 2, 3, 1))
)
if use_l5_exp:
x = self.conv_1x1_exp(x)
out_dict["out_l5_exp"] = x
return out_dict
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
group = parser.add_argument_group(title=cls.__name__)
group.add_argument(
"--model.classification.swin.mode",
type=str,
default="tiny",
help="SwinTransformer mode. Default is swin_t",
)
group.add_argument(
"--model.classification.swin.stochastic-depth-prob",
type=float,
default=None,
help="Stochastic depth drop probability in SwinTransformer. Defaults to None.",
)
group.add_argument(
"--model.classification.swin.extract-end-point-format",
type=str,
default="nchw",
choices=["nchw", "nhwc"],
help="End point extraction format in Swin Transformer. This is useful for down-stream tasks where "
"task-specific heads are either in nhwc format or nchw format. Defaults to nchw.",
)
return parser
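# Usage note: assuming the standard cvnets convention of selecting a registered
# classification model by its registry name, this model would typically be chosen via
# options along the lines of the sketch below (the exact entry point and option
# plumbing depend on the cvnets training scripts):
#
#   --model.classification.name swin
#   --model.classification.swin.mode tiny
#   --model.classification.swin.extract-end-point-format nchw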