Source code for data.datasets.detection.coco_mask_rcnn

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import List, Mapping, Tuple, Union

import torch
from torch import Tensor

from data.collate_fns import COLLATE_FN_REGISTRY
from data.datasets import DATASET_REGISTRY
from data.datasets.detection.coco_base import COCODetection
from data.transforms import image_pil as T
from data.transforms.common import Compose


@DATASET_REGISTRY.register(name="coco_mask_rcnn", type="detection")
class COCODetectionMaskRCNN(COCODetection):
    """MS COCO object detection dataset that yields samples for Mask R-CNN.

    Every sample bundles the image with its box labels, box coordinates and
    instance masks (see `__getitem__`). Creating the dataset also sets the
    train/val/test collate function names on `opts` to
    `coco_mask_rcnn_collate_fn`.

    Args:
        opts: Command-line arguments
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """Initialize the parent dataset and select the Mask R-CNN collate function.

        Args:
            opts: Command-line arguments. Mutated in place: the collate function
                names for the train, val and test splits are overwritten.
        """
        super().__init__(opts=opts, *args, **kwargs)

        # All three splits share the same Mask R-CNN specific collate function.
        collate_fn_name = "coco_mask_rcnn_collate_fn"
        for option_name in (
            "dataset.collate_fn_name_train",
            "dataset.collate_fn_name_val",
            "dataset.collate_fn_name_test",
        ):
            setattr(opts, option_name, collate_fn_name)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add the Mask R-CNN specific dataset arguments to `parser`.

        Args:
            parser: Argument parser to extend.

        Returns:
            The same parser instance that was passed in.
        """
        if cls != COCODetectionMaskRCNN:
            # A subclass that merely inherits this method must not register the
            # very same arguments a second time.
            return parser

        parser.add_argument_group(title=cls.__name__).add_argument(
            "--dataset.detection.coco-mask-rcnn.use-lsj-aug",
            action="store_true",
            help="Use large scale jitter augmentation for training Mask RCNN model",
        )
        return parser

    def _training_transforms(
        self, size: Tuple[int, int], *args, **kwargs
    ) -> T.BaseTransformation:
        """Build the training-time augmentation pipeline.

        The default pipeline is Resize -> RandomHorizontalFlip -> ToTensor. When
        large scale jittering is enabled through
        `dataset.detection.coco_mask_rcnn.use_lsj_aug`, Resize is replaced by
        ScaleJitter followed by FixedSizeCrop.

        Args:
            size: Target (height, width) handed to Resize. Not used when large
                scale jittering is enabled.

        Returns:
            The composed transformation.
        """
        opts = self.opts
        if getattr(opts, "dataset.detection.coco_mask_rcnn.use_lsj_aug"):
            # Large scale jittering, following https://arxiv.org/abs/2012.07177
            geometric = [T.ScaleJitter(opts=opts), T.FixedSizeCrop(opts=opts)]
        else:
            # Standard augmentation for Mask-RCNN
            geometric = [T.Resize(opts=opts, img_size=size)]

        pipeline = geometric + [
            T.RandomHorizontalFlip(opts=opts),
            T.ToTensor(opts=opts),
        ]
        return Compose(opts=opts, img_transforms=pipeline)

    def _validation_transforms(
        self, size: Tuple[int, int], *args, **kwargs
    ) -> T.BaseTransformation:
        """Build the validation/evaluation pipeline: Resize -> ToTensor.

        Args:
            size: Requested (height, width). It is accepted for interface parity
                with `_training_transforms` but is not forwarded to Resize here.

        Returns:
            The composed transformation.
        """
        # NOTE(review): unlike the training pipeline, `size` is not passed as
        # `img_size`; Resize is built from `opts` alone. Confirm this is intended.
        return Compose(
            opts=self.opts,
            img_transforms=[T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)],
        )

    def __getitem__(
        self, sample_size_and_index: Tuple[int, int, int], *args, **kwargs
    ) -> Mapping[str, Union[Tensor, Mapping[str, Tensor]]]:
        """Load one sample, apply the split's transforms and package it for Mask R-CNN.

        Args:
            sample_size_and_index: Tuple of the form
                (crop_size_h, crop_size_w, sample_index).

        Returns:
            A dictionary with `samples` and `targets` as keys, holding the model
            input (image plus labels) and the image metadata, respectively.

        Shapes:
            Expected shapes after the transforms (they depend on the transforms
            and on the parent class helpers, which are defined elsewhere):
                output_data["samples"]["image"]: [Channels, Height, Width]
                output_data["samples"]["label"]["labels"]: [Num of boxes]
                output_data["samples"]["label"]["boxes"]: [Num of boxes, 4]
                output_data["samples"]["label"]["masks"]: [Num of boxes, Height, Width]
            The entries of output_data["targets"] ("image_id", "image_width",
            "image_height") are each created with `torch.tensor(value)`, i.e. one
            scalar value per sample when `value` is a plain number.
        """
        target_h, target_w, sample_index = sample_size_and_index
        transforms = self.get_augmentation_transforms(size=(target_h, target_w))

        image_id = self.ids[sample_index]
        # The image name returned alongside the image is not needed here.
        image, _ = self.get_image(image_id=image_id)
        # Width and height are recorded before any transform resizes the image.
        im_width, im_height = image.size

        boxes, labels, mask = self.get_boxes_and_labels(
            image_id=image_id,
            image_width=im_width,
            image_height=im_height,
            include_masks=True,
        )

        data = {
            "image": image,
            "box_labels": labels,
            "box_coordinates": boxes,
            "mask": mask,
        }
        if transforms is not None:
            data = transforms(data)

        # PyTorch's Mask RCNN implementation expects the labels as a model input.
        # To keep the training infrastructure unchanged, the labels travel inside
        # `samples` and are unpacked by the model.
        samples = {
            "image": data["image"],
            "label": {
                "labels": data["box_labels"],
                "boxes": data["box_coordinates"],
                "masks": data["mask"],
            },
        }
        targets = {
            "image_id": torch.tensor(image_id),
            "image_width": torch.tensor(im_width),
            "image_height": torch.tensor(im_height),
        }
        return {"samples": samples, "targets": targets}
@COLLATE_FN_REGISTRY.register(name="coco_mask_rcnn_collate_fn")
def coco_mask_rcnn_collate_fn(
    batch: List[Mapping[str, Union[Tensor, Mapping[str, Tensor]]]],
    opts: argparse.Namespace,
    *args,
    **kwargs
) -> Mapping[str, Union[List[Tensor], Mapping[str, List[Tensor]]]]:
    """Merge per-sample dictionaries into one batch dictionary of lists.

    Nothing is stacked or concatenated into a single tensor: the images, the
    label dictionaries and the target dictionaries of the samples are gathered
    into plain Python lists, preserving the order of `batch`. For the expected
    keys, see the output of `COCODetectionMaskRCNN.__getitem__`.

    Args:
        batch: A list of dictionaries, one per sample.
        opts: Command-line arguments (not used by this function).

    Returns:
        A dictionary with `samples` and `targets` as keys. `samples` maps
        `image` and `label` to lists with one entry per sample; `targets` is a
        list with one entry per sample.
    """
    images, labels, targets = [], [], []
    # Single pass over the batch so that all three lists stay aligned by sample.
    for sample in batch:
        images.append(sample["samples"]["image"])
        labels.append(sample["samples"]["label"])
        targets.append(sample["targets"])

    return {"samples": {"image": images, "label": labels}, "targets": targets}