Source code for data.sampler.chain_sampler

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from __future__ import annotations

import argparse
import copy
import itertools
import json
from typing import Iterator, List, Mapping, Optional, Tuple

from torch.utils.data.sampler import Sampler

from data.sampler import SAMPLER_REGISTRY, build_sampler
from options.utils import flatten_yaml_as_dict
from utils import logger


[docs]@SAMPLER_REGISTRY.register(name="chain_sampler") class ChainSampler(Sampler): """ This class is a wrapper for iterating over datasets for multiple or similar tasks, typically useful for multi-task training. `task_name` and `sampler_config` are two mandatory keys that allows us to use task-specific data samplers. For specifying batch sizes, we use `train_batch_size0`, and `val_batch_size0` as keys for training and validation sets. Note that the batch sizes are scaled automatically depending on the number of GPUs. Args: opts: Command-line arguments data_samplers: dictionary containing different samplers Example:: # Example yaml config for combining different samplers is given below. # Please note that configuration for each sampler should start with `-` in `chain_sampler`. sampler: name: "chain_sampler" chain_sampler_mode: "sequential" chain_sampler: - task_name: "segmentation" train_batch_size0: 10 sampler_config: name: "variable_batch_sampler" use_shards: false num_repeats: 4 truncated_repeat_aug_sampler: false vbs: crop_size_width: 512 crop_size_height: 512 max_n_scales: 25 min_crop_size_width: 256 max_crop_size_width: 768 min_crop_size_height: 256 max_crop_size_height: 768 check_scale: 16 - task_name: "classification" train_batch_size0: 20 sampler_config: name: "batch_sampler" bs: crop_size_width: 512 crop_size_height: 512 """ _SUPPORTED_SAMPLING_MODES = ["sequential", "interleave"]
    def __init__(
        self,
        opts: argparse.Namespace,
        *args,
        **kwargs,
    ) -> None:
        data_samplers = ChainSampler.build_chain_sampler(opts, *args, **kwargs)

        sampling_mode = getattr(opts, "sampler.chain_sampler_mode")
        if sampling_mode is None:
            logger.error(f"Sampling mode can't be None in {self.__class__.__name__}")
        if not isinstance(sampling_mode, str):
            logger.error(
                f"Sampling mode in {self.__class__.__name__} should be a string. "
                f"Got: {type(sampling_mode)}"
            )
        if sampling_mode not in self._SUPPORTED_SAMPLING_MODES:
            logger.error(
                f"Supported sampling modes in {self.__class__.__name__} are "
                f"{self._SUPPORTED_SAMPLING_MODES}. Got: {sampling_mode}"
            )

        self.samplers_dict = data_samplers
        self.sampling_mode = sampling_mode
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add arguments for chain sampler."""
        if cls != ChainSampler:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser
        group = parser.add_argument_group(cls.__name__)
        group.add_argument("--sampler.chain-sampler", type=json.loads, action="append")
        group.add_argument(
            "--sampler.chain-sampler-mode",
            type=str,
            default="sequential",
            choices=cls._SUPPORTED_SAMPLING_MODES,
            help="Chain sampler mode. Defaults to sequential.",
        )
        return parser
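    # Illustrative usage (not part of the library): the chain sampler configuration is normally
    # supplied via a yaml file (see the class docstring), but because `--sampler.chain-sampler`
    # is parsed with `json.loads` and `action="append"`, each task entry can also be passed as a
    # JSON string on the command line. A minimal sketch:
    #
    #   parser = ChainSampler.add_arguments(argparse.ArgumentParser())
    #   opts = parser.parse_args(
    #       [
    #           "--sampler.chain-sampler-mode", "interleave",
    #           "--sampler.chain-sampler",
    #           '{"task_name": "classification", "train_batch_size0": 20, '
    #           '"sampler_config": {"name": "batch_sampler", '
    #           '"bs": {"crop_size_width": 512, "crop_size_height": 512}}}',
    #       ]
    #   )
    #   getattr(opts, "sampler.chain_sampler_mode")  # -> "interleave"
    #   getattr(opts, "sampler.chain_sampler")       # -> [{"task_name": "classification", ...}]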
    @classmethod
    def build_chain_sampler(
        cls,
        opts: argparse.Namespace,
        n_data_samples: Mapping[str, int],
        is_training: bool = False,
        *args,
        **kwargs,
    ) -> Mapping[str, Sampler]:
        """Build the chain sampler from command-line arguments and the sampler registry.

        Args:
            opts: Command-line arguments
            n_data_samples: Mapping from task name to the number of samples in the
                task-specific dataset
            is_training: Training mode or not

        Returns:
            A dictionary, `sampler_dict`, mapping each task name to its sampler module.
        """
        chain_sampler_opts = getattr(opts, "sampler.chain_sampler")
        if chain_sampler_opts is None:
            logger.error(
                f"sampler.chain_sampler in {cls.__name__} can't be None. Please specify "
                f"sampler.chain_sampler using a yaml file."
            )
        if not isinstance(chain_sampler_opts, List):
            logger.error(
                f"Chain sampler options are expected as a List. "
                f"Got type: {type(chain_sampler_opts)} and values: {chain_sampler_opts}"
            )

        num_samplers = len(chain_sampler_opts)
        if num_samplers < 1:
            logger.error("We need at least one sampler when using the chain sampler")

        sampler_dict = {}
        for i, sampler_opts_ in enumerate(chain_sampler_opts):
            task_name = sampler_opts_.get("task_name", None)
            if task_name is None:
                logger.error("task_name is a mandatory key when using the chain sampler")

            # get the sampler configuration
            sampler_opts_as_dict = sampler_opts_.get("sampler_config", None)
            if sampler_opts_as_dict is None:
                logger.error(
                    "sampler_config is a mandatory key when using the chain sampler"
                )

            train_batch_size = sampler_opts_.get("train_batch_size0", None)
            val_batch_size = sampler_opts_.get("val_batch_size0", None)

            # flatten the dictionary
            sampler_opts_as_dict = flatten_yaml_as_dict(sampler_opts_as_dict)

            # Create a local copy of the global opts and override it with the task-specific
            # sampler opts. `sampler_opts_as_dict` only contains the values of command-line
            # arguments that are defined in the yaml file. Therefore, if a user misses a few
            # arguments, we won't have access to their default values, leading to an error.
            # To avoid this, we create a local copy of the global command-line arguments and
            # update it with the `sampler_opts_as_dict` arguments.
            sampler_opts = copy.deepcopy(opts)
            for k, v in sampler_opts_as_dict.items():
                # We need to prefix each argument with `sampler.` because individual samplers
                # define their arguments as `sampler.vbs.*` and not `vbs.*`.
                setattr(sampler_opts, "sampler." + k, v)

            # override the batch sizes of the sampler
            if train_batch_size is not None:
                setattr(sampler_opts, "dataset.train_batch_size0", train_batch_size)
            if val_batch_size is not None:
                setattr(sampler_opts, "dataset.val_batch_size0", val_batch_size)

            if not isinstance(n_data_samples, Mapping):
                logger.error(
                    "For the chain sampler, we need n_data_samples as a dictionary with the "
                    f"task name as key and the number of data points as value. Got: {n_data_samples}"
                )

            if task_name not in n_data_samples:
                logger.error(
                    f"Sample mapping from the dataset has the following keys "
                    f"({n_data_samples.keys()}) and does not contain {task_name}. Please check."
                )

            # build the sampler for the task
            sampler_dict[task_name] = build_sampler(
                opts=sampler_opts,
                n_data_samples=n_data_samples[task_name],
                is_training=is_training,
                *args,
                **kwargs,
            )

        # Sanity check: the task names in n_data_samples and sampler_dict must overlap,
        # i.e., their intersection must not be empty.
        keys_are_disjoint = n_data_samples.keys().isdisjoint(sampler_dict)
        assert keys_are_disjoint is False, (
            f"The keys in n_data_samples and sampler_dict do not overlap. "
            f"Got: {n_data_samples.keys()} and {sampler_dict.keys()}"
        )
        return sampler_dict
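    # Illustrative call (the task names below follow the yaml example in the class docstring;
    # the dataset sizes are hypothetical and normally come from the task-specific datasets):
    #
    #   samplers = ChainSampler.build_chain_sampler(
    #       opts,
    #       n_data_samples={"segmentation": 20_000, "classification": 1_300_000},
    #       is_training=True,
    #   )
    #   # samplers maps each task name to its sampler, e.g.
    #   # {"segmentation": <variable_batch_sampler>, "classification": <batch_sampler>}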
    def _sequential_sampling(self) -> Iterator[List[Tuple]]:
        """Assuming we have samples from N datasets, this function first iterates over the
        entire first dataset, then the entire second dataset, and so on.

        Example:
            Dataset 1: [A, B, C]
            Dataset 2: [D, E, F, G, H]

            The result of this sampler would be [A, B, C, D, E, F, G, H]
        """
        for task_name, dataset_sampler in self.samplers_dict.items():
            for batch_data in dataset_sampler:
                # append the task name to each element of the batch
                yield [x + (task_name,) for x in batch_data]

    def _interleave_sampling(self) -> Iterator[List[Tuple]]:
        """Assuming we have samples from N datasets, this function yields a batch from the
        first dataset, then a batch from the second dataset, and so on. In other words,
        batches are sampled from the N datasets in a round-robin fashion.

        Example:
            Dataset 1: [A, B, C]
            Dataset 2: [D, E, F, G, H]

            The result of this sampler would be [A, D, B, E, C, F, G, H]
        """
        items = self.samplers_dict.items()
        task_names, samplers = zip(*items)

        num_active_samplers = len(samplers)
        next_samplers = itertools.cycle(
            iter(data_sampler).__next__ for data_sampler in samplers
        )
        while num_active_samplers:
            try:
                for i, next_sampler in enumerate(next_samplers):
                    yield [
                        x + (task_names[i % num_active_samplers],)
                        for x in next_sampler()
                    ]
            except StopIteration:
                # Remove the sampler that we just exhausted from the cycle.
                num_active_samplers -= 1
                next_samplers = itertools.cycle(
                    itertools.islice(next_samplers, num_active_samplers)
                )

    def __iter__(self) -> Iterator[Tuple]:
        if self.sampling_mode == "sequential":
            return self._sequential_sampling()
        elif self.sampling_mode == "interleave":
            return self._interleave_sampling()

    def __len__(self) -> int:
        return sum(
            len(dataset_sampler)
            for task_name, dataset_sampler in self.samplers_dict.items()
        )
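    # `_interleave_sampling` above follows the standard `itertools` round-robin recipe. A minimal,
    # self-contained sketch of the same pattern on plain iterables (ignoring batching and task
    # names) reproduces the behavior described in its docstring:
    #
    #   def roundrobin(*iterables):
    #       num_active = len(iterables)
    #       nexts = itertools.cycle(iter(it).__next__ for it in iterables)
    #       while num_active:
    #           try:
    #               for next_fn in nexts:
    #                   yield next_fn()
    #           except StopIteration:
    #               # drop the exhausted iterator and keep cycling over the remaining ones
    #               num_active -= 1
    #               nexts = itertools.cycle(itertools.islice(nexts, num_active))
    #
    #   list(roundrobin("ABC", "DEFGH"))  # -> ['A', 'D', 'B', 'E', 'C', 'F', 'G', 'H']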
    def set_epoch(self, epoch: int) -> None:
        """Helper function to set epoch in each sampler.

        Args:
            epoch: Current epoch

        Returns:
            Nothing
        """
        for task_name, dataset_sampler in self.samplers_dict.items():
            if hasattr(dataset_sampler, "set_epoch"):
                dataset_sampler.set_epoch(epoch)
    def update_scales(
        self, epoch: int, is_master_node: Optional[bool] = False, *args, **kwargs
    ) -> None:
        """Helper function to update scales in each sampler. This is typically useful for
        variable-batch samplers.

        Args:
            epoch: Current epoch
            is_master_node: Master node or not

        Returns:
            Nothing
        """
        for task_name, dataset_sampler in self.samplers_dict.items():
            if hasattr(dataset_sampler, "update_scales"):
                dataset_sampler.update_scales(
                    epoch, is_master_node=is_master_node, *args, **kwargs
                )
    def update_indices(self, new_indices: List[int]) -> None:
        """Update the sample indices of the datasets with these new indices.

        Args:
            new_indices: Filtered indices of the samples that need to be used in the next epoch.

        Returns:
            Nothing

        .. note::
            This function is useful for sample-efficient training. It may be implemented in
            the future (depending on the use case).
        """
        raise NotImplementedError
    def __repr__(self) -> str:
        repr_str = f"{self.__class__.__name__}(\n"
        for k, v in self.samplers_dict.items():
            repr_str += f"{k} --> " + v.__repr__().replace("\n\t", "\n\t\t").replace(
                "\n)", "\n\t)"
            )
            repr_str += "\n"
        repr_str += "\n)"
        return repr_str
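# End-to-end sketch (hypothetical; the task names follow the class docstring example). Each
# yielded batch is a list of tuples produced by the task-specific sampler, with the task name
# appended as the last element so downstream code can tell which task a batch belongs to:
#
#   chain_sampler = ChainSampler(
#       opts,
#       n_data_samples={"segmentation": 20_000, "classification": 1_300_000},
#       is_training=True,
#   )
#   for batch in chain_sampler:
#       # e.g., each element of `batch` looks like (..., "classification"), where the
#       # leading fields depend on the underlying task-specific sampler.
#       ...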