Source code for data.datasets.multi_modal_img_text.base_multi_modal_img_text

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
import os
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import torch
from PIL import Image
from torch import Tensor

from data.collate_fns import COLLATE_FN_REGISTRY
from data.datasets.dataset_base import BaseImageDataset
from data.datasets.multi_modal_img_text.zero_shot import (
    BaseZeroShotDataset,
    build_zero_shot_dataset,
)
from data.datasets.utils.text import caption_preprocessing
from data.text_tokenizer import build_tokenizer
from data.transforms import image_pil as T
from data.transforms.common import Compose
from utils import logger
from utils.ddp_utils import is_master, is_start_rank_node


class BaseMultiModalImgText(BaseImageDataset):
    """
    Base class for Image-Text multi-modal learning

    Args:
        opts: command-line arguments
    """

    __SEPARATOR = ":"
    def __init__(
        self,
        opts,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            opts=opts,
            *args,
            **kwargs,
        )
        self.is_master_node = is_master(opts)
        self.is_start_rank_node = is_start_rank_node(opts)

        self.text_tokenizer = build_tokenizer(opts=opts, *args, **kwargs)
        self.context_length = getattr(
            opts, "dataset.multi_modal_img_text.context_length"
        )
        # for sharing padding index across the entire cvnets framework, we will
        # use a special variable "dataset.padding_index".
        setattr(opts, "dataset.padding_index", None)
        self.padding_index = getattr(opts, "dataset.padding_index")

        vocab_size = self.text_tokenizer.get_vocab_size()
        if vocab_size is None or vocab_size == -1:
            logger.error(
                "Vocab size can't be None or -1 in {}. Got: {}".format(
                    self.__class__.__name__, vocab_size
                )
            )
        self.vocab_size = vocab_size
        setattr(opts, "dataset.text_vocab_size", vocab_size)
        setattr(opts, "dataset.text_context_length", self.context_length)

        setattr(
            opts,
            "dataset.collate_fn_name_train",
            "multi_modal_img_text_collate_fn",
        )
        setattr(
            opts,
            "dataset.collate_fn_name_val",
            "multi_modal_img_text_collate_fn",
        )
        setattr(
            opts,
            "dataset.collate_fn_name_test",
            "multi_modal_img_text_collate_fn",
        )

        self.zero_shot_dataset = self.get_zero_shot_dataset(*args, **kwargs)
        self.cached_zero_shot_captions = None

        self.cache_loc = os.path.join(self.root, ".img_text_tar_cache")
        if self.is_training:
            # Folder where we will download data
            # TODO: Training data can't fit on a single node, so we save/cache subset of data on each node.
            # In future, we may enable caching for validation data.
            try:
                os.makedirs(self.cache_loc, exist_ok=True)
            except Exception as e:
                logger.warning(f"Could not create cache location directory: {e}")

        self.dataset = self.get_dataset(*args, **kwargs)
    def get_zero_shot_dataset(self, *args, **kwargs) -> Optional[BaseZeroShotDataset]:
        """If zero-shot evaluation is enabled, the zero-shot dataset is returned.
        Otherwise, None is returned."""
        zero_shot_eval = (
            False
            if self.is_training
            else getattr(
                self.opts, "dataset.multi_modal_img_text.zero_shot_eval", False
            )
        )
        if zero_shot_eval:
            zero_shot_dataset = build_zero_shot_dataset(opts=self.opts, *args, **kwargs)
        else:
            zero_shot_dataset = None
        return zero_shot_dataset
    def get_dataset(self, *args, **kwargs) -> Any:
        """Helper function to get the dataset. Child classes must override this function."""
        raise NotImplementedError
    def share_dataset_arguments(self) -> Dict[str, Any]:
        """Returns the text vocabulary size and context length along with
        super-class arguments."""
        share_dataset_specific_opts: Dict[str, Any] = super().share_dataset_arguments()
        share_dataset_specific_opts["dataset.text_vocab_size"] = self.vocab_size
        share_dataset_specific_opts["dataset.text_context_length"] = self.context_length
        return share_dataset_specific_opts
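    # For illustration only (not part of the original file): the dictionary
    # returned by `share_dataset_arguments` would look roughly like
    #
    #   {
    #       ...,  # super-class arguments
    #       "dataset.text_vocab_size": 49408,  # e.g., the CLIP BPE vocab size
    #       "dataset.text_context_length": 77,
    #   }
    #
    # where the concrete values depend on the tokenizer and the configured
    # context length.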
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add dataset-specific arguments to the parser."""
        if cls != BaseMultiModalImgText:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--dataset.multi-modal-img-text.context-length",
            type=int,
            default=77,
            help="Context length for the text model. Defaults to 77, the same as in the CLIP paper.",
        )
        group.add_argument(
            "--dataset.multi-modal-img-text.trunc-seq-len",
            action="store_true",
            default=False,
            help="Many sequences in a batch do not have lengths equal to the specified context length. "
            "Enabling this flag allows us to truncate the sequences such that the sequence length of a "
            "batch is equal to the sequence with max. non-padded tokens. Defaults to False.",
        )
        return parser
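    # A minimal usage sketch (not part of the original file), assuming a bare
    # `argparse.ArgumentParser` as a stand-in for the framework's own parser
    # construction:
    #
    #   parser = argparse.ArgumentParser()
    #   parser = BaseMultiModalImgText.add_arguments(parser)
    #   opts = parser.parse_args(
    #       ["--dataset.multi-modal-img-text.context-length", "77"]
    #   )
    #   # argparse replaces dashes with underscores in the destination name:
    #   getattr(opts, "dataset.multi_modal_img_text.context_length")  # -> 77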
    def _transform_text(self, text_tensor: Tensor) -> Tuple[Tensor, int]:
        """Helper function to transform the text tensor. If the text tensor is
        smaller than the context length, it is zero-padded. If it is longer, it
        is truncated and the last token is replaced with the EOT token.

        Args:
            text_tensor: Text tensor with N tokens. Shape is (N,).

        Returns:
            A tuple of the text tensor (whose length is equal to the context
            length) and the length of the tensor.
        """
        captions_tensor = torch.zeros(size=(self.context_length,), dtype=torch.long)
        text_len = text_tensor.shape[0]
        if text_len > self.context_length:
            text_tensor = text_tensor[: self.context_length]
            text_tensor[-1] = self.text_tokenizer.get_eot_token()
            text_len = self.context_length

        captions_tensor[:text_len] = text_tensor[:text_len]
        return captions_tensor, text_len

    def _training_transforms(
        self, size: Tuple[int, int], *args, **kwargs
    ) -> T.BaseTransformation:
        """Data augmentation during training.

        The default order is RandomResizedCrop, Optional[RandAugment or AutoAugment],
        ToTensor, Optional[RandomErase].

        Args:
            size: Size for resizing the input image. Expected to be a tuple (height, width).

        Returns:
            An instance of `data.transforms.image_pil.BaseTransformation`.

        .. note::
            1. AutoAugment and RandAugment are mutually exclusive.
            2. Mixup and CutMix are applied on batches and are implemented in the trainer.
        """
        aug_list = [
            T.RandomResizedCrop(opts=self.opts, size=size),
        ]
        auto_augment = getattr(
            self.opts, "image_augmentation.auto_augment.enable", False
        )
        rand_augment = getattr(
            self.opts, "image_augmentation.rand_augment.enable", False
        )
        if auto_augment and rand_augment:
            logger.error(
                "AutoAugment and RandAugment are mutually exclusive. Use either of them, but not both"
            )
        elif auto_augment:
            aug_list.append(T.AutoAugment(opts=self.opts))
        elif rand_augment:
            aug_list.append(T.RandAugment(opts=self.opts))

        aug_list.append(T.ToTensor(opts=self.opts))

        if getattr(self.opts, "image_augmentation.random_erase.enable", False):
            aug_list.append(T.RandomErasing(opts=self.opts))

        return Compose(opts=self.opts, img_transforms=aug_list)

    def _validation_transforms(
        self, size: Union[Tuple, int], *args, **kwargs
    ) -> T.BaseTransformation:
        """Data transforms during validation or evaluation.

        The order is Resize, CenterCrop, ToTensor.

        Args:
            size: Size for resizing the input image. Expected to be an integer
                (width=height) or a tuple (height, width).

        Returns:
            An instance of `data.transforms.image_pil.BaseTransformation`.
        """
        aug_list = [
            T.Resize(opts=self.opts),
            T.CenterCrop(opts=self.opts),
            T.ToTensor(opts=self.opts),
        ]
        return Compose(opts=self.opts, img_transforms=aug_list)

    def _process_img_caption(
        self,
        input_img: Image.Image,
        captions_str: Union[str, List[str], List[List[str]]],
        img_transform_fn: T.BaseTransformation,
        zero_shot: bool,
    ) -> Tuple[Tensor, Tensor, int]:
        """Apply data augmentation to images and pre-processing to text captions.

        Args:
            input_img: Input PIL Image
            captions_str: Text captions
            img_transform_fn: Image transform functions
            zero_shot: zero-shot evaluation or not

        Returns:
            A tuple of image tensor, caption tensor, and max.
            sequence length of a sequence in the caption tensor.
        """
        data = {"image": input_img}
        img_tensor = img_transform_fn(data)["image"]

        if zero_shot and self.cached_zero_shot_captions is not None:
            return (
                img_tensor,
                self.cached_zero_shot_captions[0],
                self.cached_zero_shot_captions[1],
            )

        max_seq_len = 0
        # process caption
        if isinstance(captions_str, str):
            captions_tensor, max_seq_len = self._transform_text(
                self.text_tokenizer(caption_preprocessing(captions_str))
            )
        elif isinstance(captions_str, List):
            captions_tensor = []
            for captions_str_i in captions_str:
                if isinstance(captions_str_i, List):
                    # captions_str is [ [Num_templates_per_class] * Num_classes ]
                    captions_tensor_i = []
                    for (
                        captions_str_i_j
                    ) in captions_str_i:  # number of templates per class
                        seq, seq_len = self._transform_text(
                            self.text_tokenizer(
                                caption_preprocessing(captions_str_i_j)
                            )
                        )
                        captions_tensor_i.append(seq)
                        max_seq_len = max(max_seq_len, seq_len)
                    captions_tensor_i = torch.stack(captions_tensor_i, dim=0)
                    captions_tensor.append(captions_tensor_i)
                elif isinstance(captions_str_i, str):
                    # captions_str is [Num_templates_per_image]
                    seq, seq_len = self._transform_text(
                        self.text_tokenizer(caption_preprocessing(captions_str_i))
                    )
                    captions_tensor.append(seq)
                    max_seq_len = max(max_seq_len, seq_len)
                else:
                    logger.error(
                        "Got captions_str of type {}: {} from {}".format(
                            type(captions_str_i), captions_str_i, captions_str
                        )
                    )
            # the shape of the tensor is [Num_classes, captions_per_class, caption_length]
            # or [Captions_per_image, caption_length]
            captions_tensor = torch.stack(captions_tensor, dim=0)
        else:
            captions_tensor = None
            logger.error(
                "Captions should be either str, List[str], or List[List[str]]"
            )

        if zero_shot and self.cached_zero_shot_captions is None:
            self.cached_zero_shot_captions = (captions_tensor, max_seq_len)

        return img_tensor, captions_tensor, max_seq_len
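    # A worked example (not part of the original file) of the `_transform_text`
    # contract, assuming context_length = 8 and a hypothetical EOT token id of 2:
    #
    #   input tokens:  [5, 9, 7]                        (text_len = 3 < 8)
    #   output:        ([5, 9, 7, 0, 0, 0, 0, 0], 3)    zero-padded to length 8
    #
    #   input tokens:  [5, 9, 7, 4, 6, 1, 3, 8, 2, 4]   (text_len = 10 > 8)
    #   output:        ([5, 9, 7, 4, 6, 1, 3, 2], 8)    truncated to length 8,
    #                                                   last token set to EOT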
    def get_zero_shot_pair(
        self, img_index: int
    ) -> Tuple[Image.Image, Union[str, List[str], List[List[str]]], int]:
        """Get image-text pair for the zero-shot dataset along with the classification label.

        Args:
            img_index: Image index

        Returns:
            A tuple of PIL image, captions, and class label.
        """
        img_path, captions_str, class_label = self.zero_shot_dataset[img_index]
        input_img = self.read_image_pil(img_path)
        return input_img, captions_str, class_label
    def get_dataset_pair(self, img_index: int) -> Any:
        """Get image-text pair from the dataset. Sub-classes must implement this method."""
        raise NotImplementedError
    def __getitem__(
        self, sample_size_and_index: Tuple[int, int, int]
    ) -> Mapping[str, Union[Tensor, Mapping[str, Tensor]]]:
        """Returns the sample corresponding to the input sample index. The returned
        sample is transformed into the size specified by the input.

        Args:
            sample_size_and_index: Tuple of the form (crop_size_h, crop_size_w, sample_index)

        Returns:
            A dictionary with `samples` and `targets` as keys corresponding to the input
            and label of a sample, respectively.

        Shapes:
            The shapes of the values in the output dictionary, output_data, are as follows:

            output_data["samples"]["image"]: Shape is [Channels, Height, Width]
            output_data["samples"]["text"]: Shape is
                * [Context_Length] (single caption, as in CLIP datasets)
                * [Num_classes, Num_Captions, Context_length] (multiple captions per class,
                  as in the 0-shot ImageNet dataset)
            output_data["samples"]["padding_mask"]: Same as output_data["samples"]["text"]
            output_data["samples"]["max_seq_len"]: Shape is [1]
            output_data["targets"]: Shape is [1]
        """
        crop_size_h, crop_size_w, img_index = sample_size_and_index
        transform_fn = self.get_augmentation_transforms(size=(crop_size_h, crop_size_w))

        if self.zero_shot_dataset is not None:
            # read captions and image path from the zero-shot dataset
            input_img, captions_str, class_label = self.get_zero_shot_pair(
                img_index=img_index
            )
        else:
            input_img, captions_str, class_label = self.get_dataset_pair(
                img_index=img_index
            )

        if input_img is None:
            captions_tensor = torch.zeros(size=(self.context_length,), dtype=torch.long)
            data = {
                "samples": {
                    "image": torch.zeros(size=(3, crop_size_h, crop_size_w)),
                    "text": captions_tensor,
                    "padding_mask": (captions_tensor == self.padding_index)
                    if self.padding_index is not None
                    else None,
                    "max_seq_len": self.context_length,
                },
                "targets": -1,
            }
        else:
            (
                img_tensor,
                captions_tensor,
                max_seq_len,
            ) = self._process_img_caption(
                input_img=input_img,
                captions_str=captions_str,
                img_transform_fn=transform_fn,
                zero_shot=self.zero_shot_dataset is not None,
            )

            data = {
                "samples": {
                    "image": img_tensor,
                    "text": captions_tensor,
                    "padding_mask": (captions_tensor == self.padding_index)
                    if self.padding_index is not None
                    else None,
                    "max_seq_len": max_seq_len,
                },
                "targets": class_label,
            }

        if self.zero_shot_dataset is not None:
            data["zero_shot"] = 1

        return data
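    # A minimal sketch (not part of the original file) of a single retrieved
    # sample, assuming context_length = 77, a (224, 224) crop requested by the
    # sampler, and a dataset with a single caption per image:
    #
    #   sample = dataset[(224, 224, 0)]
    #   sample["samples"]["image"].shape  # torch.Size([3, 224, 224])
    #   sample["samples"]["text"].shape   # torch.Size([77])
    #   sample["samples"]["max_seq_len"]  # number of non-padded tokens (int)
    #   sample["targets"]                 # class label (or a placeholder for
    #                                     # caption-only datasets)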
    def extra_repr(self) -> str:
        extra_repr_str = super().extra_repr()
        extra_repr_str += f"\n\t zero_shot={self.zero_shot_dataset}"
        return extra_repr_str
@COLLATE_FN_REGISTRY.register(name="multi_modal_img_text_collate_fn")
def multi_modal_img_text_collate_fn(
    batch: List[Mapping[str, Union[Tensor, Mapping[str, Tensor]]]],
    opts: argparse.Namespace,
) -> Mapping[str, Union[Tensor, Mapping[str, Tensor]]]:
    """Combines a list of dictionaries into a single dictionary by concatenating matching fields."""
    images = []
    text_tokens = []
    padding_mask = []
    labels = []

    truncate_seq_len = getattr(opts, "dataset.multi_modal_img_text.trunc_seq_len")

    zero_shot = batch[0].pop("zero_shot", 0)
    max_seq_len_in_batch = 1  # at least one token is required in the sequence
    for i, batch_i in enumerate(batch):
        inputs_i = batch_i.pop("samples")
        img_tensor = inputs_i.pop("image", None)
        if img_tensor is None:
            continue
        images.append(img_tensor)
        labels.append(batch_i.pop("targets"))
        text_data = inputs_i.pop("text")
        pad_mask = inputs_i.pop("padding_mask", None)
        max_seq_len_in_batch = max(max_seq_len_in_batch, inputs_i.pop("max_seq_len", 0))
        if not zero_shot or (zero_shot and i == 0):
            # For zero-shot, all text captions are the same,
            # so we only aggregate them for one batch element.
            text_tokens.append(text_data)
            if pad_mask is not None:
                padding_mask.append(pad_mask)

    images = torch.stack(images, dim=0)
    text_tokens = torch.stack(text_tokens, dim=0)

    # truncate tokens based on the max. seq length
    if not truncate_seq_len:
        max_seq_len_in_batch = text_tokens.shape[-1]
    text_tokens = text_tokens[..., :max_seq_len_in_batch]

    if len(padding_mask) != 0:
        padding_mask = torch.stack(padding_mask, dim=0)
        padding_mask = padding_mask[..., :max_seq_len_in_batch]
    else:
        padding_mask = None

    labels = torch.tensor(labels, dtype=torch.long)

    channels_last = getattr(opts, "common.channels_last")
    if channels_last:
        images = images.to(memory_format=torch.channels_last)

    return {
        "samples": {
            "image": images,
            "text": text_tokens,
            "padding_mask": padding_mask,
        },
        "targets": labels,
    }
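# A minimal sketch (not part of the original file) of the collate function's
# behavior, assuming a context length of 77, two valid samples whose longest
# caption has 12 non-padded tokens, and `opts` produced by the framework's
# argument parser:
#
#   batch = [dataset[(224, 224, 0)], dataset[(224, 224, 1)]]
#   out = multi_modal_img_text_collate_fn(batch, opts)
#   out["samples"]["image"].shape  # torch.Size([2, 3, 224, 224])
#   out["samples"]["text"].shape   # torch.Size([2, 12]) if
#                                  # dataset.multi_modal_img_text.trunc_seq_len
#                                  # is enabled, torch.Size([2, 77]) otherwise
#   out["targets"].shape           # torch.Size([2])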