Source code for data.datasets.multi_modal_img_text.flickr

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
import json
import os
from typing import Dict, List, Optional, Tuple

import torch

from data.datasets import DATASET_REGISTRY
from data.datasets.multi_modal_img_text.base_multi_modal_img_text import (
    BaseMultiModalImgText,
)
from utils import logger
from utils.ddp_utils import is_master


@DATASET_REGISTRY.register(name="flickr", type="multi_modal_image_text")
class FlickrDataset(BaseMultiModalImgText):
    """
    Dataset loader for the Flickr-30k and Flickr-8k datasets.

    For more info, see:
        http://hockenmaier.cs.illinois.edu/8k-pictures.html
        https://shannon.cs.illinois.edu/DenotationGraph/

    Splits: train, val, and test. Also known in the literature as the Karpathy splits:
        https://cs.stanford.edu/people/karpathy/deepimagesent/

    Tracking license info:
        Captions are under the CC BY 3.0 license (see links above).
        Splits are under the BSD license (see the NeuralTalk GitHub repository by Karpathy et al.).
        Images are from Flickr. We do not own them; they are used for research purposes only.

    Args:
        opts: Command-line arguments.
        is_training (Optional[bool]): A flag used to indicate training or validation mode.
            Default: True.
        is_evaluation (Optional[bool]): A flag used to indicate evaluation (or inference) mode.
            Default: False.
    """
    def get_dataset(self, *args, **kwargs) -> None:
        """
        The data under `self.root` is expected to consist of:
            dataset.json  # Karpathy splits + captions
            images/       # Raw images

        The metadata can be downloaded from:
            https://cs.stanford.edu/people/karpathy/deepimagesent/flickr30k.zip

        Images can be obtained from:
            Flickr-8k: http://hockenmaier.cs.illinois.edu/8k-pictures.html
            Flickr-30k: https://shannon.cs.illinois.edu/DenotationGraph/
        """
        metadata_path = os.path.join(self.root, "dataset.json")
        with open(metadata_path, "r") as metadata_file:
            metadata = json.load(metadata_file)["images"]

        split = self.mode
        self.samples = [
            {
                "image_name": sample["filename"],
                "captions": [x["raw"] for x in sample["sentences"]],
            }
            for sample in metadata
            if sample["split"] == split
        ]

        if self.is_training:
            # For training, flatten the captions by copying each image multiple times.
            # This way, at each epoch every caption is seen once, but each image is
            # seen num_captions (= 5) times.
            self.samples = [
                {"image_name": sample["image_name"], "captions": caption}
                for sample in self.samples
                for caption in sample["captions"]
            ]
            # No need for shuffling, since the dataloader takes care of it.

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, batch_indexes_tup: Tuple[int, int, int]) -> Dict[str, Dict]:
        """Return the sample for a given index and a pre-specified height and width.

        Args:
            batch_indexes_tup: Tuple of (crop size height, crop size width, image index).

        Returns:
            A dict like:
            {
                "samples": {
                    "image": FloatTensor of shape [C, H, W] -> [B, C, H, W] after collate,
                    "text": LongTensor of shape
                        if is_training: [num_tokens] -> [B, T] after collate,
                        otherwise: [num_captions, num_tokens] -> [B, N, T] after collate,
                    "padding_mask": Optional; same size as text,
                    "max_seq_len": int indicating the maximum caption length (T).
                        Gets removed after collate.
                },
                "targets": -1
            }
        """
        crop_size_h, crop_size_w, img_index = batch_indexes_tup
        crop_size = (crop_size_h, crop_size_w)

        transform_fn = self.get_augmentation_transforms(size=crop_size)

        sample = self.samples[img_index]
        img_path = os.path.join(self.root, "images", sample["image_name"])
        captions = sample["captions"]

        input_img = self.read_image_pil(img_path)
        img_tensor, captions_tensor, max_seq_len = self._process_img_caption(
            input_img=input_img,
            captions_str=captions,
            img_transform_fn=transform_fn,
            zero_shot=self.zero_shot_dataset is not None,
        )

        data = {
            "samples": {
                "image": img_tensor,
                "text": captions_tensor,
                "padding_mask": (captions_tensor == self.padding_index)
                if self.padding_index is not None
                else None,
                "max_seq_len": max_seq_len,
            },
            "targets": -1,
        }
        return data

    def __repr__(self) -> str:
        return "{}(\n\troot={}\n\tis_training={}\n\tsamples={}\n)".format(
            self.__class__.__name__,
            self.root,
            self.is_training,
            len(self),
        )
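

# ----------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a self-contained demo
# of the `dataset.json` schema that `get_dataset` expects and of the
# split-filtering / caption-flattening it performs. The file names and captions
# below are made up for illustration only.
if __name__ == "__main__":
    import tempfile

    # Minimal toy metadata in the Karpathy-split format read by get_dataset:
    # a top-level "images" list whose entries carry "filename", "split", and
    # "sentences" (each sentence holding a "raw" caption string).
    toy_metadata = {
        "images": [
            {
                "filename": "example_0001.jpg",
                "split": "train",
                "sentences": [
                    {"raw": "A dog runs on the beach."},
                    {"raw": "A brown dog plays near the water."},
                ],
            },
            {
                "filename": "example_0002.jpg",
                "split": "val",
                "sentences": [{"raw": "Two children ride bicycles."}],
            },
        ]
    }

    with tempfile.TemporaryDirectory() as root:
        with open(os.path.join(root, "dataset.json"), "w") as f:
            json.dump(toy_metadata, f)

        with open(os.path.join(root, "dataset.json"), "r") as f:
            metadata = json.load(f)["images"]

        # Same filtering as get_dataset for the "train" split.
        samples = [
            {
                "image_name": s["filename"],
                "captions": [x["raw"] for x in s["sentences"]],
            }
            for s in metadata
            if s["split"] == "train"
        ]
        # Same flattening as the is_training branch: one entry per caption.
        flat = [
            {"image_name": s["image_name"], "captions": c}
            for s in samples
            for c in s["captions"]
        ]
        print(flat)  # two (image, caption) pairs, both for example_0001.jpg

    # In actual use, FlickrDataset(opts, is_training=True) builds self.samples the
    # same way, and items are fetched via a (crop_height, crop_width, index) tuple,
    # e.g. dataset[(224, 224, 0)], typically supplied by the batch sampler.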