#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from functools import partial
from typing import Mapping, Optional, Tuple, Union
from torch.utils.data.sampler import Sampler
from data.collate_fns import build_collate_fn, build_test_collate_fn
from data.datasets import BaseDataset, get_test_dataset, get_train_val_datasets
from data.loader.dataloader import CVNetsDataLoader
from data.sampler import build_sampler
from utils import logger
from utils.ddp_utils import is_master
from utils.tensor_utils import image_size_from_opts
def create_test_loader(opts: argparse.Namespace) -> CVNetsDataLoader:
    """Build the data loader used to evaluate on the test dataset.

    ``opts`` is modified in place: ``dataset.val_batch_size0`` is overwritten
    with the value of ``dataset.eval_batch_size0``, and for some sampler names
    the ``sampler.*`` options are rewritten so that ``batch_sampler`` or
    ``video_batch_sampler`` is built instead (see the inline comments).

    Args:
        opts: Command-line arguments.

    Returns:
        A ``CVNetsDataLoader`` over the test dataset, driven by the evaluation
        batch sampler.

    Raises:
        AttributeError: If ``dataset.eval_batch_size0`` is not set on ``opts``
            (it is read without a default).
    """
    dataset = get_test_dataset(opts)
    num_samples = get_num_data_samples_as_int_or_mapping(dataset)
    on_master = is_master(opts)

    # Overwrite the validation batch-size argument with the evaluation one.
    eval_batch_size = getattr(opts, "dataset.eval_batch_size0")
    setattr(opts, "dataset.val_batch_size0", eval_batch_size)

    # We don't need a variable batch sampler for evaluation.
    requested_sampler = getattr(opts, "sampler.name", "batch_sampler")
    height, width = image_size_from_opts(opts)

    if "video" in requested_sampler and requested_sampler != "video_batch_sampler":
        # Any other video sampler is replaced by the plain video batch sampler,
        # carrying over the clip settings from the "sampler.vbs.*" options.
        overrides = {
            "sampler.name": "video_batch_sampler",
            "sampler.bs.crop_size_width": width,
            "sampler.bs.crop_size_height": height,
            "sampler.bs.clips_per_video": getattr(
                opts, "sampler.vbs.clips_per_video", 1
            ),
            "sampler.bs.num_frames_per_clip": getattr(
                opts, "sampler.vbs.num_frames_per_clip", 8
            ),
        }
    elif "var" in requested_sampler:
        # Samplers with "var" in their name are replaced by the plain batch sampler.
        overrides = {
            "sampler.name": "batch_sampler",
            "sampler.bs.crop_size_width": width,
            "sampler.bs.crop_size_height": height,
        }
    else:
        # The requested sampler is used unchanged.
        overrides = {}

    for option_name, option_value in overrides.items():
        setattr(opts, option_name, option_value)

    sampler = build_sampler(
        opts=opts,
        n_data_samples=num_samples,
        is_training=False,
        get_item_metadata=dataset.get_item_metadata,
    )

    # Bind `opts` into the collate function when one is configured.
    collate_fn = build_test_collate_fn(opts=opts)
    if collate_fn is not None:
        collate_fn = partial(collate_fn, opts=opts)

    # Worker persistence and memory pinning are always disabled for the test loader.
    loader = CVNetsDataLoader(
        dataset=dataset,
        batch_size=1,
        batch_sampler=sampler,
        num_workers=getattr(opts, "dataset.workers", 1),
        pin_memory=False,
        persistent_workers=False,
        collate_fn=collate_fn,
    )

    if on_master:
        logger.log("Evaluation sampler details: ")
        print(f"{sampler}")

    return loader
def create_train_val_loader(
    opts: argparse.Namespace,
) -> Tuple[CVNetsDataLoader, Optional[CVNetsDataLoader], Sampler]:
    """Build the training data loader and, when available, the validation one.

    Args:
        opts: Command-line arguments.

    Returns:
        A tuple ``(train_loader, val_loader, train_sampler)``. ``val_loader`` is
        ``None`` when ``get_train_val_datasets`` returns no validation dataset.
    """

    def _bind_opts(collate_fn):
        # Collate functions receive `opts` as a keyword argument; a missing
        # collate function is passed through as None.
        return None if collate_fn is None else partial(collate_fn, opts=opts)

    train_dataset, valid_dataset = get_train_val_datasets(opts)
    has_validation = valid_dataset is not None

    num_train = get_num_data_samples_as_int_or_mapping(train_dataset)
    on_master = is_master(opts)

    train_sampler = build_sampler(
        opts=opts,
        n_data_samples=num_train,
        is_training=True,
        get_item_metadata=train_dataset.get_item_metadata,
    )

    valid_sampler = None
    if has_validation:
        num_valid = get_num_data_samples_as_int_or_mapping(valid_dataset)
        valid_sampler = build_sampler(
            opts=opts,
            n_data_samples=num_valid,
            is_training=False,
            get_item_metadata=valid_dataset.get_item_metadata,
        )

    data_workers = getattr(opts, "dataset.workers", 1)

    # Persistent workers are only requested when there is at least one worker.
    persistent_workers = getattr(opts, "dataset.persistent_workers", False)
    if persistent_workers:
        persistent_workers = data_workers > 0

    # Settings shared by both loaders; batching is handled inside the data
    # sampler, hence batch_size=1.
    shared_kwargs = dict(
        batch_size=1,
        num_workers=data_workers,
        pin_memory=getattr(opts, "dataset.pin_memory", False),
        persistent_workers=persistent_workers,
    )
    prefetch_factor = getattr(opts, "dataset.prefetch_factor", 2)

    train_collate, val_collate = build_collate_fn(opts=opts)

    # `prefetch_factor` is passed to the training loader only.
    train_loader = CVNetsDataLoader(
        dataset=train_dataset,
        batch_sampler=train_sampler,
        collate_fn=_bind_opts(train_collate),
        prefetch_factor=prefetch_factor,
        **shared_kwargs,
    )

    val_loader = None
    if has_validation:
        val_loader = CVNetsDataLoader(
            dataset=valid_dataset,
            batch_sampler=valid_sampler,
            collate_fn=_bind_opts(val_collate),
            **shared_kwargs,
        )

    # Only the master node reports the sampler and worker configuration.
    if on_master:
        logger.log("Training sampler details: ")
        print(f"{train_sampler}")

        if has_validation:
            logger.log("Validation sampler details: ")
            print(f"{valid_sampler}")

        logger.log(f"Number of data workers: {data_workers}")

    return train_loader, val_loader, train_sampler
def get_num_data_samples_as_int_or_mapping(
    dataset: BaseDataset,
) -> Union[int, Mapping[str, int]]:
    """Report how many samples ``dataset`` contains.

    A dataset may be a single dataset or a composition of several datasets (as
    in multi-task learning). A dataset that exposes
    ``get_dataset_length_as_mapping`` is treated as a composition, and the
    result of that method is returned (a mapping from task name to number of
    samples). Any other dataset is treated as a single dataset and its
    ``len()`` is returned.

    Args:
        dataset: An instance of `data.datasets.BaseDataset` class.

    Returns:
        An integer for a single dataset and a mapping for composite datasets.
    """
    # Single dataset: no per-task breakdown is available.
    if not hasattr(dataset, "get_dataset_length_as_mapping"):
        return len(dataset)

    return dataset.get_dataset_length_as_mapping()