# Source code for data.sampler

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Callable, Dict, Mapping, Optional, Union

from torch.utils.data.sampler import Sampler

from data.sampler.base_sampler import BaseSampler, BaseSamplerDDP
from utils.registry import Registry

# Registry holding the data sampler classes, looked up by name in
# `build_sampler` below.
SAMPLER_REGISTRY = Registry(
    # NOTE(review): presumably registered samplers must subclass torch's
    # Sampler — confirm against Registry's implementation.
    base_class=Sampler,
    registry_name="data_samplers",
    # NOTE(review): presumably optional locations for internal-only samplers;
    # confirm how Registry treats these directories when they are absent.
    internal_dirs=["internal", "internal/projects/*"],
    # Sampler modules are imported lazily from this directory.
    lazy_load_dirs=["data/sampler"],
)


def build_sampler(
    opts: argparse.Namespace,
    n_data_samples: Union[int, Mapping[str, int]],
    is_training: bool = False,
    get_item_metadata: Optional[Callable[[int], Dict]] = None,
    *args,
    **kwargs,
) -> Sampler:
    """Build a data sampler from command-line arguments.

    The sampler class is looked up in ``SAMPLER_REGISTRY`` using the value of
    ``sampler.name``. When ``ddp.use_distributed`` is truthy, ``"_ddp"`` is
    appended to that name unless its last underscore-separated segment is
    already ``"ddp"`` or the name is ``"chain_sampler"`` (a chain sampler is
    only a loop over existing samplers).

    Args:
        opts: Command-line arguments.
        n_data_samples: Number of data samples. Either an integer giving the
            number of samples for a given task, or a mapping from task name
            to number of samples per task in case of a chain sampler.
        is_training: Training mode or not. Defaults to False.
        get_item_metadata: A callable that provides sample metadata, given a
            sample index.
        *args: Accepted but ignored; not forwarded to the sampler.
        **kwargs: Accepted but ignored; not forwarded to the sampler.

    Returns:
        Data sampler over which we can iterate.
    """
    name = getattr(opts, "sampler.name")

    # Under DDP, switch to the `_ddp` variant unless the name already carries
    # it; chain_sampler is exempt since it just loops over other samplers.
    if getattr(opts, "ddp.use_distributed") and not (
        name.split("_")[-1] == "ddp" or name == "chain_sampler"
    ):
        name += "_ddp"

    sampler_cls = SAMPLER_REGISTRY[name]
    return sampler_cls(
        opts,
        n_data_samples=n_data_samples,
        is_training=is_training,
        get_item_metadata=get_item_metadata,
    )


def add_sampler_arguments(
    parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
    """Register sampler command-line arguments on ``parser``.

    Arguments are added in this order: those gathered by
    ``SAMPLER_REGISTRY.all_arguments``, then ``BaseSampler``'s, then
    ``BaseSamplerDDP``'s. Each step receives the parser returned by the
    previous one.

    Args:
        parser: Argument parser to extend.

    Returns:
        The parser returned by the final registration step.
    """
    parser = SAMPLER_REGISTRY.all_arguments(parser)
    for base_cls in (BaseSampler, BaseSamplerDDP):
        parser = base_cls.add_arguments(parser)
    return parser