#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Optional, Tuple
from data.datasets.classification.base_image_classification_dataset import (
BaseImageClassificationDataset,
)
from data.datasets.classification.base_imagenet_shift_dataset import (
BaseImageNetShiftDataset,
)
from data.datasets.dataset_base import BaseDataset, BaseImageDataset, BaseVideoDataset
from data.datasets.detection.base_detection import BaseDetectionDataset
from data.datasets.multi_modal_img_text import arguments_multi_modal_img_text
from data.datasets.segmentation.base_segmentation import BaseImageSegmentationDataset
from utils import logger
from utils.ddp_utils import is_master
from utils.registry import Registry
DATASET_REGISTRY = Registry(
registry_name="dataset_registry",
base_class=BaseDataset,
lazy_load_dirs=["data/datasets"],
internal_dirs=["internal", "internal/projects/*"],
)
[docs]def build_dataset_from_registry(
opts: argparse.Namespace,
is_training: bool = True,
is_evaluation: bool = False,
*args,
**kwargs,
) -> BaseDataset:
"""Helper function to build a dataset from dataset registry
Args:
opts: Command-line arguments
is_training: Training mode or not. Defaults to True.
is_evaluation: Evaluation mode or not. Defaults to False.
Returns:
An instance of BaseDataset
...note:
`is_training` is used to indicate whether the dataset is used for training or validation
On the other hand, `is_evaluation` mode is used to indicate the dataset is used for testing.
Theoretically, `is_training=False` and `is_evaluation=True` should be the same. However, for some datasets
(especially segmentation), validation dataset transforms are different from
test transforms because each image has different resolution, making it difficult to construct
batches. Therefore, we treat these two modes different. For datasets, where validation and testing
transforms are the same, we set evaluation transforms the same as the validation transforms (e.g., in ImageNet
object classification).
"""
dataset_category = getattr(opts, "dataset.category")
if dataset_category is None:
logger.error("Please specify dataset category using --dataset.category")
dataset_name = getattr(opts, f"dataset.name")
if dataset_name is None:
logger.error("Please specify dataset name using --dataset.name")
dataset = DATASET_REGISTRY[dataset_name, dataset_category](
opts=opts, is_training=is_training, is_evaluation=is_evaluation, *args, **kwargs
)
return dataset
[docs]def get_test_dataset(opts: argparse.Namespace, *args, **kwargs) -> BaseDataset:
"""Helper function to build a dataset for testing.
Args:
opts: Command-line arguments
Returns:
An instance of BaseDataset
"""
test_dataset = build_dataset_from_registry(
opts, is_training=False, is_evaluation=True, *args, **kwargs
)
if is_master(opts):
logger.log("Evaluation dataset details: ")
print("{}".format(test_dataset))
return test_dataset
[docs]def get_train_val_datasets(
opts: argparse.Namespace, *args, **kwargs
) -> Tuple[BaseDataset, Optional[BaseDataset]]:
"""Helper function to build a dataset for training and validation.
Args:
opts: Command-line arguments
Returns:
Training and (optionally) validation datasets.
"""
disable_val = getattr(opts, "dataset.disable_val")
is_master_node = is_master(opts)
train_dataset = build_dataset_from_registry(
opts, is_training=True, is_evaluation=False, *args, **kwargs
)
if is_master_node:
logger.log("Training dataset details are given below")
print(train_dataset)
valid_dataset = None
if not disable_val:
valid_dataset = build_dataset_from_registry(
opts, is_training=False, is_evaluation=False, *args, **kwargs
)
if is_master_node:
logger.log("Validation dataset details are given below")
print(valid_dataset)
return train_dataset, valid_dataset
[docs]def arguments_dataset(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Add dataset-specific arguments from BaseDataset, BaseImageDataset,
BaseImageClassificationDataset, BaseImageNetShiftDataset,
BaseVideoDataset, zero-shot datasets, and DATASET_REGISTRY.
"""
parser = BaseDataset.add_arguments(parser)
parser = BaseImageDataset.add_arguments(parser)
parser = BaseImageSegmentationDataset.add_arguments(parser)
parser = BaseVideoDataset.add_arguments(parser)
parser = BaseImageClassificationDataset.add_arguments(parser)
parser = BaseImageNetShiftDataset.add_arguments(parser)
parser = BaseDetectionDataset.add_arguments(parser)
try:
from internal.utils.server_utils import dataset_server_args
parser = dataset_server_args(parser)
except ImportError:
pass
# add multi-modal and zero-shot arguments
parser = arguments_multi_modal_img_text(parser=parser)
# add dataset specific arguments
parser = DATASET_REGISTRY.all_arguments(parser)
return parser