Source code for data.datasets.segmentation.base_segmentation

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
import os
from typing import List, Mapping, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from torch import Tensor

from data.datasets.dataset_base import BaseImageDataset
from data.transforms import image_pil as T
from data.transforms.common import Compose
from utils import logger
from utils.color_map import Colormap


class BaseImageSegmentationDataset(BaseImageDataset):
    """Base Dataset class for Image Segmentation datasets. Sub-classes must define
    the `ignore_label` and `background_idx` variables.

    Args:
        opts: Command-line arguments
    """
    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts=opts, *args, **kwargs)
        self.masks = None
        self.images = None
        # Ignore label and background indices are dataset specific. So, child
        # classes need to implement these.
        self.ignore_label = None
        self.background_idx = None
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        if cls != BaseImageSegmentationDataset:
            # Don't re-register arguments in subclasses that don't override
            # `add_arguments()`.
            return parser
        group = parser.add_argument_group(cls.__name__)

        # segmentation evaluation related arguments
        group.add_argument(
            "--evaluation.segmentation.apply-color-map",
            action="store_true",
            default=False,
            help="Apply a color map to the different classes in segmentation masks. "
            "Useful for visualization, and some competitions (e.g., PASCAL VOC) accept "
            "submissions with colored segmentation masks. Defaults to False.",
        )
        group.add_argument(
            "--evaluation.segmentation.save-overlay-rgb-pred",
            action="store_true",
            default=False,
            help="Enable this flag to visualize predicted masks on top of the input image. "
            "Defaults to False.",
        )
        group.add_argument(
            "--evaluation.segmentation.save-masks",
            action="store_true",
            default=False,
            help="Save predicted masks without color maps. Useful for submitting to "
            "competitions like Cityscapes. Defaults to False.",
        )
        group.add_argument(
            "--evaluation.segmentation.overlay-mask-weight",
            default=0.5,
            type=float,
            help="Contribution of the mask when overlaying it on top of the RGB image. "
            "Defaults to 0.5.",
        )
        group.add_argument(
            "--evaluation.segmentation.mode",
            type=str,
            default="validation_set",
            choices=["single_image", "image_folder", "validation_set"],
            help="Evaluation mode: a single image, a folder of images, or the "
            "validation set. Defaults to validation_set.",
        )
        group.add_argument(
            "--evaluation.segmentation.path",
            type=str,
            default=None,
            help="Path of the image or image folder (only required for the single_image "
            "and image_folder modes). Defaults to None.",
        )
        group.add_argument(
            "--evaluation.segmentation.resize-input-images",
            action="store_true",
            default=False,
            help="Enable resizing input images while maintaining aspect ratio during "
            "segmentation evaluation. Defaults to False.",
        )
        group.add_argument(
            "--evaluation.segmentation.resize-input-images-fixed-size",
            action="store_true",
            default=False,
            help="Enable resizing input images to a fixed size during segmentation "
            "evaluation. Defaults to False.",
        )
        return parser
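    # Illustrative sketch, not part of the upstream module: with a plain
    # `argparse.ArgumentParser`, the dotted option names registered above end up
    # as dotted attribute names on the parsed namespace, so they can only be read
    # back with `getattr`. The project may wrap argparse with extra processing,
    # so treat this as an assumption:
    #
    #     parser = argparse.ArgumentParser()
    #     parser = BaseImageSegmentationDataset.add_arguments(parser)
    #     opts = parser.parse_args(["--evaluation.segmentation.save-masks"])
    #     assert getattr(opts, "evaluation.segmentation.save_masks") is True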
    def check_dataset(self) -> None:
        # TODO: Remove this check in the future
        assert self.masks is not None, "Please specify masks for segmentation data"
        assert self.images is not None, "Please specify images for segmentation data"
        assert (
            self.ignore_label is not None
        ), "Please specify ignore label for segmentation dataset"
        assert (
            self.background_idx is not None
        ), "Please specify background index for segmentation dataset"
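    # Illustrative sketch, not part of the upstream module: a child class is
    # expected to populate the four attributes checked above before calling
    # `check_dataset()`. `DummySegDataset` and its paths are hypothetical:
    #
    #     class DummySegDataset(BaseImageSegmentationDataset):
    #         def __init__(self, opts, *args, **kwargs):
    #             super().__init__(opts=opts, *args, **kwargs)
    #             self.images = ["images/0001.jpg"]
    #             self.masks = ["masks/0001.png"]
    #             self.ignore_label = 255
    #             self.background_idx = 0
    #             self.check_dataset()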
    def _training_transforms(self, size: Tuple[int, int]) -> T.BaseTransformation:
        """Data augmentation during training.

        The order of transformations is RandomShortSizeResize, RandomHorizontalFlip,
        RandomCrop, Optional[RandomGaussianBlur], Optional[PhotometricDistort],
        Optional[RandomRotate]. If random order is enabled, then the order of
        transforms is shuffled, with the exception of RandomShortSizeResize. These
        transforms are followed by ToTensor.

        Args:
            size: Size for resizing the input image. Expected to be a tuple (height, width).

        Returns:
            An instance of `data.transforms.image_pil.BaseTransformation`.
        """
        first_aug = T.RandomShortSizeResize(opts=self.opts)
        aug_list = [
            T.RandomHorizontalFlip(opts=self.opts),
            T.RandomCrop(opts=self.opts, size=size, ignore_idx=self.ignore_label),
        ]

        if getattr(self.opts, "image_augmentation.random_gaussian_noise.enable"):
            aug_list.append(T.RandomGaussianBlur(opts=self.opts))

        if getattr(self.opts, "image_augmentation.photo_metric_distort.enable"):
            aug_list.append(T.PhotometricDistort(opts=self.opts))

        if getattr(self.opts, "image_augmentation.random_rotate.enable"):
            aug_list.append(T.RandomRotate(opts=self.opts))

        if getattr(self.opts, "image_augmentation.random_order.enable"):
            new_aug_list = [
                first_aug,
                T.RandomOrder(opts=self.opts, img_transforms=aug_list),
                T.ToTensor(opts=self.opts),
            ]
            return Compose(opts=self.opts, img_transforms=new_aug_list)
        else:
            aug_list.insert(0, first_aug)
            aug_list.append(T.ToTensor(opts=self.opts))
            return Compose(opts=self.opts, img_transforms=aug_list)

    def _validation_transforms(
        self, size: Tuple[int, int], *args, **kwargs
    ) -> T.BaseTransformation:
        """Data augmentation during validation.

        The order of transformations is Resize, ToTensor.

        Args:
            size: Size for resizing the input image. Expected to be a tuple (height, width).

        Returns:
            An instance of `data.transforms.image_pil.BaseTransformation`.
        """
        aug_list = [T.Resize(opts=self.opts), T.ToTensor(opts=self.opts)]
        return Compose(opts=self.opts, img_transforms=aug_list)

    def _evaluation_transforms(
        self, size: Union[int, Tuple[int, int]], *args, **kwargs
    ) -> T.BaseTransformation:
        """Data augmentation during testing/evaluation.

        The order of transformations is Optional[Resize], ToTensor.

        Args:
            size: Size for resizing the input image. Expected to be an int or a
                tuple (height, width).

        Returns:
            An instance of `data.transforms.image_pil.BaseTransformation`.

        .. note::
            When `evaluation.segmentation.resize_input_images` is enabled, images
            are resized while maintaining the aspect ratio. If size is a tuple of
            integers, then min(size) is used as the size. When
            `evaluation.segmentation.resize_input_images_fixed_size` is enabled,
            images are resized to the given size.
        """
        aug_list = []
        resize_maintain_ar = getattr(
            self.opts, "evaluation.segmentation.resize_input_images"
        )
        resize_fixed_size = getattr(
            self.opts, "evaluation.segmentation.resize_input_images_fixed_size"
        )
        if resize_maintain_ar:
            assert resize_fixed_size is False
            # A standard practice for segmentation tasks is to resize images while
            # maintaining the aspect ratio. To do so during evaluation, we pass
            # min(img_size) as the size argument to the resize function. The resize
            # function then resizes the image while maintaining the aspect ratio.
            aug_list.append(T.Resize(opts=self.opts, img_size=min(size)))
        elif resize_fixed_size:
            assert resize_maintain_ar is False
            # We want to resize to a fixed size, so we pass size as the argument
            # to the resize function.
            aug_list.append(T.Resize(opts=self.opts, img_size=size))
        # default is no resizing

        aug_list.append(T.ToTensor(opts=self.opts))
        return Compose(opts=self.opts, img_transforms=aug_list)
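    # Illustrative summary, not part of the upstream module: the three
    # evaluation-time resizing behaviors selected by the flags above, assuming
    # `size == (512, 1024)`:
    #
    #     resize_input_images=True             -> Resize(img_size=512), ToTensor
    #                                             (aspect ratio preserved)
    #     resize_input_images_fixed_size=True  -> Resize(img_size=(512, 1024)), ToTensor
    #     neither flag                         -> ToTensor only (no resizing)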
    @staticmethod
    def adjust_mask_value() -> int:
        """Adjust the mask value by this amount."""
        # Some datasets (e.g., ADE20k) require us to adjust the mask value.
        # By default, we set it to 0, but child classes can adjust it.
        return 0
    def __len__(self) -> int:
        """Number of samples in the segmentation dataset."""
        return len(self.images)
    @staticmethod
    def color_palette() -> List[int]:
        """Class index to RGB color mapping. The list index corresponds to the
        class id. Note that the color list is flattened."""
        # Child classes may override this method (optionally)
        return Colormap().get_color_map_list()
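    # Illustrative sketch, not part of the upstream module: because the palette
    # is flattened, the RGB triplet for class id `c` lives at positions
    # [3 * c, 3 * c + 3). A flattened list in this layout can also be fed
    # directly to PIL's `Image.putpalette` for mask visualization:
    #
    #     palette = BaseImageSegmentationDataset.color_palette()
    #     r, g, b = palette[3 * c : 3 * c + 3]  # `c` is a class id
    #     pred_mask_pil.putpalette(palette)     # `pred_mask_pil` is a mode-"P" image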
    @staticmethod
    def class_names() -> List[str]:
        """Class index to name mapping. The list index should correspond to the class id."""
        # Child classes may implement this method (optionally)
        raise NotImplementedError
    @staticmethod
    def read_mask_pil(path: str) -> Optional[Image.Image]:
        """Reads a mask image and returns it as a PIL image."""
        try:
            mask = Image.open(path)
            if mask.mode != "L":
                logger.error("Mask mode should be L. Got: {}".format(mask.mode))
            return mask
        except Exception:
            return None
    @staticmethod
    def convert_mask_to_tensor(mask: Image.Image) -> Tensor:
        """Convert a PIL mask to a Tensor."""
        # convert to tensor
        mask = np.array(mask)
        if len(mask.shape) > 2 and mask.shape[-1] > 1:
            # multi-channel masks are stored channels-last; make them channels-first
            mask = np.ascontiguousarray(mask.transpose(2, 0, 1))
        return torch.as_tensor(mask, dtype=torch.long)
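    # Illustrative shape walk-through, not part of the upstream module: a
    # grayscale PIL mask becomes an (H, W) array and is returned as an [H, W]
    # long tensor; an RGB mask becomes (H, W, 3) and is transposed to
    # channels-first [3, H, W] before conversion:
    #
    #     Image mode "L",   size (W, H) -> tensor shape [H, W]
    #     Image mode "RGB", size (W, H) -> tensor shape [3, H, W]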
    def __getitem__(
        self, sample_size_and_index: Tuple[int, int, int], *args, **kwargs
    ) -> Mapping[str, Union[Tensor, Mapping[str, Tensor]]]:
        """Returns the sample corresponding to the input sample index.

        The returned sample is transformed into the size specified by the input.

        Args:
            sample_size_and_index: Tuple of the form (crop_size_h, crop_size_w, sample_index).

        Returns:
            A dictionary with `samples` and `targets` as keys corresponding to the
            input and labels of a sample, respectively.

        Shapes:
            The shapes of the values in the output dictionary, output_data, are as follows:

            output_data["samples"]: Shape is [Channels, Height, Width]
            output_data["targets"]: Shape is [Height, Width]. During evaluation,
                this is a dictionary whose "mask" key has this shape.
        """
        crop_size_h, crop_size_w, img_index = sample_size_and_index
        transform = self.get_augmentation_transforms(size=(crop_size_h, crop_size_w))

        mask = self.read_mask_pil(self.masks[img_index])
        img = self.read_image_pil(self.images[img_index])

        if (img.size[0] != mask.size[0]) or (img.size[1] != mask.size[1]):
            logger.error(
                "Input image and mask sizes are different. Input size: {} and "
                "Mask size: {}".format(img.size, mask.size)
            )

        data = {"image": img}
        if not self.is_evaluation:
            data["mask"] = mask

        data = transform(data)

        if self.is_evaluation:
            # for evaluation purposes, resize only the input and not the mask
            data["mask"] = self.convert_mask_to_tensor(mask)

        output_data = {
            "samples": data["image"],
            # ignore dataset-specific indices in the mask
            "targets": data["mask"] - self.adjust_mask_value(),
        }

        if self.is_evaluation:
            im_width, im_height = img.size
            img_name = self.images[img_index].split(os.sep)[-1].replace("jpg", "png")
            mask = output_data.pop("targets")
            output_data["targets"] = {
                "mask": mask,
                "file_name": img_name,
                "im_width": im_width,
                "im_height": im_height,
            }

        return output_data
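
if __name__ == "__main__":
    # Minimal sanity-check sketch, not part of the upstream module: it exercises
    # the static mask helpers above on a synthetic 4x4 grayscale mask written to
    # a temporary file. It assumes only PIL, numpy, and torch, all imported at
    # the top of this module; the values are made up for illustration.
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
        Image.fromarray(np.arange(16, dtype=np.uint8).reshape(4, 4), mode="L").save(
            tmp.name
        )
        demo_mask = BaseImageSegmentationDataset.read_mask_pil(tmp.name)
        demo_tensor = BaseImageSegmentationDataset.convert_mask_to_tensor(demo_mask)
        assert demo_tensor.shape == (4, 4) and demo_tensor.dtype == torch.long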