Source code for data.datasets.classification.base_image_classification_dataset

# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
import argparse
from typing import Any, Dict, Tuple, Union

import torch
from torchvision.datasets import ImageFolder

from data.datasets.dataset_base import BaseImageDataset
from data.datasets.utils.common import select_samples_by_category
from data.transforms import image_pil
from data.transforms.common import Compose
from utils import logger
from utils.ddp_utils import is_master

class BaseImageClassificationDataset(BaseImageDataset, ImageFolder):
    """Image Classification Dataset.

    This base class can be used to represent any image classification dataset which is
    stored in a way that meets the expectations of `torchvision.datasets.ImageFolder`.
    New image classification datasets can be derived from this, similar to
    `ImageNetDataset` or `Places365Dataset`, and overwrite the data transformations
    as needed.

    This dataset also supports sampling a random subset of the training set to be used
    for training. The subset size is determined by the arguments
    `dataset.num_samples_per_category` and `dataset.percentage_of_samples` in the input
    `opts`. Only one of these two should be specified. When
    `dataset.percentage_of_samples` is specified, data is sampled from all classes
    according to this percentage such that the distribution of classes does not change.

    The randomness in sampling is controlled by the
    `dataset.sample_selection_random_seed` in the input `opts`.

    Args:
        opts: An argparse.Namespace instance.
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        *args,
        **kwargs,
    ) -> None:
        """Initializes both parent classes and optionally sub-samples the training set.

        Args:
            opts: An argparse.Namespace instance. It is also updated in place with
                `model.classification.n_classes` and, when missing, the default
                collate function names for the train/val/test splits.

        Raises:
            ValueError: If both `dataset.num_samples_per_category` and
                `dataset.percentage_of_samples` request a subset in training mode.
        """
        # NOTE(review): `opts=opts` is passed by keyword ahead of `*args`. If positional
        # arguments are ever supplied and `opts` is the first parameter of
        # `BaseImageDataset.__init__`, this call raises a TypeError -- confirm callers
        # never pass positional arguments.
        BaseImageDataset.__init__(
            self,
            opts=opts,
            *args,
            **kwargs,
        )
        root = self.root
        ImageFolder.__init__(
            self,
            root=root,
            transform=None,
            target_transform=None,
            is_valid_file=None,
        )

        self.n_classes = len(self.class_to_idx)
        setattr(opts, "model.classification.n_classes", self.n_classes)

        # Only fill in a collate function when the user has not configured one.
        for collate_fn in [
            "dataset.collate_fn_name_train",
            "dataset.collate_fn_name_val",
            "dataset.collate_fn_name_test",
        ]:
            if not hasattr(opts, collate_fn):
                setattr(opts, collate_fn, "image_classification_data_collate_fn")

        master = is_master(self.opts)
        if master:
            logger.log("Number of categories: {}".format(self.n_classes))
            logger.log("Total number of samples: {}".format(len(self.samples)))

        num_samples_per_category = getattr(
            self.opts, "dataset.num_samples_per_category"
        )
        percentage_of_samples = getattr(self.opts, "dataset.percentage_of_samples")

        use_fixed_count = num_samples_per_category > 0
        use_percentage = 0 < percentage_of_samples < 100

        if self.is_training and (use_fixed_count or use_percentage):
            if use_fixed_count and use_percentage:
                raise ValueError(
                    "Both `dataset.num_samples_per_category` and `dataset.percentage_of_samples` are specified. "
                    "Please specify only one."
                )

            random_seed = getattr(self.opts, "dataset.sample_selection_random_seed")
            if use_fixed_count:
                selected_sample_indices = select_samples_by_category(
                    sample_category_labels=self.targets,
                    random_seed=random_seed,
                    num_samples_per_category=num_samples_per_category,
                )
                if master:
                    logger.log(
                        "Using {} samples per category.".format(
                            num_samples_per_category
                        )
                    )
            else:
                selected_sample_indices = select_samples_by_category(
                    sample_category_labels=self.targets,
                    random_seed=random_seed,
                    percentage_of_samples_per_category=percentage_of_samples,
                )
                if master:
                    logger.log(
                        "Using {} percentage of samples per category.".format(
                            percentage_of_samples
                        )
                    )

            # Keep the three parallel per-sample lists restricted to the same indices.
            self.samples = [self.samples[ind] for ind in selected_sample_indices]
            self.imgs = [self.imgs[ind] for ind in selected_sample_indices]
            self.targets = [self.targets[ind] for ind in selected_sample_indices]
        elif master:
            logger.log("Using all samples in the dataset.")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Adds dataset related arguments to the parser.

        Args:
            parser: An argparse.ArgumentParser instance.

        Returns:
            Input argparse.ArgumentParser instance with additional arguments.
        """
        if cls != BaseImageClassificationDataset:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser

        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--dataset.num-samples-per-category",
            type=int,
            default=-1,
            help="Number of samples to use per category. If set to -1, all samples will be used.",
        )

        # sample efficient training
        group.add_argument(
            "--dataset.sample-efficient-training.enable",
            action="store_true",
            default=False,
            help="Enabling this flag allows us to use sample efficient training. "
            "Currently only tested on classification tasks.",
        )
        group.add_argument(
            "--dataset.sample-efficient-training.sample-confidence",
            type=float,
            default=0.5,
            help="Confidence for identifying samples as easy or hard. If the "
            "predicted classification probability of a sample is greater than "
            "this value, then sample is classified as easy. Otherwise, it is "
            "classified as hard. Defaults to 0.5.",
        )
        group.add_argument(
            "--dataset.sample-efficient-training.find-easy-samples-every-k-epochs",
            type=int,
            default=5,
            help="Find easy samples after every K epochs. Defaults to 5.",
        )
        group.add_argument(
            "--dataset.sample-efficient-training.min-sample-frequency",
            type=int,
            default=5,
            # The trailing space keeps the two concatenated sentences separated.
            help="Frequency that sample has been classified as easy for N number of times. "
            "Defaults to 5.",
        )
        return parser

    def _training_transforms(
        self, size: Union[Tuple[int, int], int], *args, **kwargs
    ) -> image_pil.BaseTransformation:
        """Returns transformations applied to the input in training mode.

        Order of transformations: RandomResizedCrop, RandomHorizontalFlip,
        One of AutoAugment or RandAugment or TrivialAugmentWide, RandomErasing

        Batch-based augmentations such as Mixup and CutMix are implemented in trainer.

        Args:
            size: Size for resizing the input image. Expected to be an integer
                (width=height) or a tuple (height, width)

        Returns:
            An instance of `data.transforms.image_pil.BaseTransformation.`

        Raises:
            ValueError: If `image_augmentation.random_resized_crop.enable` is not set.
        """
        if not getattr(self.opts, "image_augmentation.random_resized_crop.enable"):
            raise ValueError(
                "`image_augmentation.random_resized_crop.enable` must be set to True in input options."
            )

        aug_list = [image_pil.RandomResizedCrop(opts=self.opts, size=size)]

        if getattr(self.opts, "image_augmentation.random_horizontal_flip.enable"):
            aug_list.append(image_pil.RandomHorizontalFlip(opts=self.opts))

        auto_augment = getattr(self.opts, "image_augmentation.auto_augment.enable")
        rand_augment = getattr(self.opts, "image_augmentation.rand_augment.enable")
        trivial_augment_wide = getattr(
            self.opts, "image_augmentation.trivial_augment_wide.enable"
        )

        # The three policy-based augmentations are mutually exclusive.
        if bool(auto_augment) + bool(rand_augment) + bool(trivial_augment_wide) > 1:
            logger.error(
                "Only one of AutoAugment, RandAugment and TrivialAugmentWide should be used."
            )
        elif auto_augment:
            aug_list.append(image_pil.AutoAugment(opts=self.opts))
        elif rand_augment:
            if getattr(self.opts, "image_augmentation.rand_augment.use_timm_library"):
                aug_list.append(image_pil.RandAugmentTimm(opts=self.opts))
            else:
                aug_list.append(image_pil.RandAugment(opts=self.opts))
        elif trivial_augment_wide:
            aug_list.append(image_pil.TrivialAugmentWide(opts=self.opts))

        aug_list.append(image_pil.ToTensor(opts=self.opts))

        # RandomErasing is appended after ToTensor, as listed in the order above.
        if getattr(self.opts, "image_augmentation.random_erase.enable"):
            aug_list.append(image_pil.RandomErasing(opts=self.opts))

        return Compose(opts=self.opts, img_transforms=aug_list)

    def _validation_transforms(self, *args, **kwargs) -> image_pil.BaseTransformation:
        """Returns transformations applied to the input in validation mode.

        Order of augmentations: Resize followed by CenterCrop

        Raises:
            ValueError: If either `image_augmentation.resize.enable` or
                `image_augmentation.center_crop.enable` is not set.
        """
        if not getattr(self.opts, "image_augmentation.resize.enable"):
            raise ValueError(
                "`image_augmentation.resize.enable` must be set to True in input options."
            )
        aug_list = [image_pil.Resize(opts=self.opts)]

        if not getattr(self.opts, "image_augmentation.center_crop.enable"):
            raise ValueError(
                "`image_augmentation.center_crop.enable` must be set to True in input options."
            )
        aug_list.append(image_pil.CenterCrop(opts=self.opts))
        aug_list.append(image_pil.ToTensor(opts=self.opts))

        return Compose(opts=self.opts, img_transforms=aug_list)

    def __getitem__(
        self, sample_size_and_index: Tuple[int, int, int]
    ) -> Dict[str, Any]:
        """Returns the sample corresponding to the input sample index.

        Returned sample is transformed into the size specified by the input.

        Args:
            sample_size_and_index: Tuple of the form (crop_size_h, crop_size_w, sample_index)

        Returns:
            A dictionary with `samples`, `sample_id` and `targets` as keys corresponding
            to input, index and label of a sample, respectively.

        Shapes:
            The output data dictionary contains three keys (samples, sample_id, and
            targets). The values of these keys has the following shapes:
                data["samples"]: Shape is [Channels, Height, Width]
                data["sample_id"]: Shape is 1
                data["targets"]: Shape is 1
        """
        crop_size_h, crop_size_w, sample_index = sample_size_and_index

        transform_fn = self.get_augmentation_transforms(size=(crop_size_h, crop_size_w))

        img_path, target = self.samples[sample_index]
        input_img = self.read_image_pil(img_path)

        if input_img is None:
            # Sometimes images are corrupt
            # Skip such images
            logger.log("Img index {} is possibly corrupt.".format(sample_index))
            # Substitute an all-zero tensor of the requested crop size and flag the
            # sample with a target of -1; the transforms are not applied to it.
            input_tensor = torch.zeros(
                size=(3, crop_size_h, crop_size_w), dtype=torch.float
            )
            target = -1
            data = {"image": input_tensor}
        else:
            data = {"image": input_img}
            data = transform_fn(data)

        data["samples"] = data.pop("image")
        data["targets"] = target
        data["sample_id"] = sample_index

        return data

    def __len__(self) -> int:
        """Returns the number of samples currently held by the dataset."""
        return len(self.samples)

    def share_dataset_arguments(self) -> Dict[str, Any]:
        """Returns the number of classes in classification dataset along with super-class arguments."""
        share_dataset_specific_opts: Dict[str, Any] = super().share_dataset_arguments()
        share_dataset_specific_opts["model.classification.n_classes"] = self.n_classes
        return share_dataset_specific_opts

    def extra_repr(self) -> str:
        """Returns the super-class representation extended with the number of classes."""
        extra_repr_str = super().extra_repr()
        return extra_repr_str + f"\n\t num_classes={self.n_classes}"