Source code for data.datasets.multi_modal_img_text.zero_shot.base_zero_shot

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import List, Tuple

from options.utils import extract_opts_with_prefix_replacement
from utils import logger


class BaseZeroShotDataset(object):
    """Base Dataset class for zero shot tasks.

    Sub-classes must implement ``__len__``, ``class_names`` and
    ``generate_text_prompts``. They must also set ``self.samples``, an indexable
    collection of ``(image_path, label)`` pairs that ``__getitem__`` reads.

    Args:
        opts: Command-line arguments.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        # ``trove.enable`` is not registered by ``add_arguments`` below, so it
        # may be absent from ``opts``; default it to False in that case.
        if getattr(
            opts, "dataset.multi_modal_img_text.zero_shot.trove.enable", False
        ):
            # Best-effort: on any failure the error is logged and the original
            # ``opts`` are kept unchanged.
            try:
                from internal.utils.server_utils import load_from_data_server

                opts = load_from_data_server(
                    opts=opts,
                    is_training=False,
                    is_evaluation=False,  # use root_val / dir_val
                    arg_prefix="dataset.multi_modal_img_text.zero_shot",
                )
            except Exception as e:
                logger.error("Unable to load from the server. Error: {}".format(str(e)))

        # Extracting zero-shot options to be able to build them separately in
        # child classes: "dataset.multi_modal_img_text.zero_shot.*" options are
        # exposed to them as "dataset.*".
        dataset_opts = extract_opts_with_prefix_replacement(
            opts,
            match_prefix="dataset.multi_modal_img_text.zero_shot.",
            replacement_prefix="dataset.",
        )
        # Build a fresh namespace from a merged dict instead of calling
        # ``vars(dataset_opts).update(...)``. ``vars`` returns the namespace's
        # own ``__dict__``, so updating it would mutate the helper's result.
        # NOTE(review): -1 / 100.0 presumably disable per-category and
        # percentage sub-sampling; confirm against the dataset options.
        self.dataset_opts = argparse.Namespace(
            **{
                **vars(dataset_opts),
                "dataset.num_samples_per_category": -1,
                "dataset.percentage_of_samples": 100.0,
            }
        )

        self.root = getattr(opts, "dataset.multi_modal_img_text.zero_shot.root_val")
        self.opts = opts

        # One entry per class, ordered by class label: the prompts produced by
        # ``generate_text_prompts`` for the lower-cased name returned by the
        # ``class_names`` classmethod.
        self.text_prompts = [
            self.generate_text_prompts(class_name.lower())
            for class_name in self.class_names()
        ]

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the zero-shot evaluation arguments on ``parser``.

        Arguments are only registered when called on ``BaseZeroShotDataset``
        itself. Calls bound to a subclass return ``parser`` untouched, so the
        options are never registered twice.

        Args:
            parser: Parser to extend.

        Returns:
            The same ``parser`` instance.
        """
        if cls is not BaseZeroShotDataset:
            # Don't re-register arguments in subclasses that don't override
            # `add_arguments()`.
            return parser

        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--dataset.multi-modal-img-text.zero-shot-eval",
            action="store_true",
            default=False,
            help="Use zero shot evaluation. Defaults to False.",
        )
        group.add_argument(
            "--dataset.multi-modal-img-text.zero-shot.name",
            type=str,
            default=None,
            help="Name of the dataset for zero-shot evaluation. Defaults to None.",
        )
        group.add_argument(
            "--dataset.multi-modal-img-text.zero-shot.root-val",
            type=str,
            default=None,
            help="Location of the dataset for zero-shot evaluation. Defaults to None.",
        )
        return parser

    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Raises:
            NotImplementedError: Always; sub-classes must override this.
        """
        raise NotImplementedError(
            "Sub-classes should implement `__len__` that returns the number of samples"
            " in the dataset."
        )

    def __getitem__(self, img_index: int) -> Tuple[str, List[List[str]], int]:
        """Return image path and text templates for a given image index.

        Requires ``self.samples`` to be set by the sub-class; the base class
        does not define it.

        Args:
            img_index: Index of the image.

        Returns:
            Tuple containing image path, list of captions (one prompt list per
            class, shared across all samples), and image label.
        """
        img_path, target = self.samples[img_index]
        return img_path, self.text_prompts, target

    @classmethod
    def class_names(cls) -> List[str]:
        """Return the name of the classes in the dataset.

        Label is index in the list. The order of class names in the returned
        list determines the numerical class label.

        Raises:
            NotImplementedError: Always; sub-classes must override this.
        """
        raise NotImplementedError(
            "Sub-classes should define `class_names` that returns the list of class"
            " names in the order of class labels."
        )

    @staticmethod
    def generate_text_prompts(class_name: str) -> List[str]:
        """Return a list of prompts for the given class name.

        Raises:
            NotImplementedError: Always; sub-classes must override this.
        """
        raise NotImplementedError(
            "Sub-classes should define `generate_text_prompts` that creates a list of"
            " prompts for a given class name."
        )

    def __repr__(self) -> str:
        """Return ``ClassName(root=...)`` for debugging."""
        return "{}(root={})".format(self.__class__.__name__, self.root)