Source code for data.collate_fns.collate_functions

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import numbers
from typing import Any, List, Mapping

import torch
from torch import Tensor
from torch.utils.data import default_collate

from data.collate_fns import COLLATE_FN_REGISTRY
from utils import logger


@COLLATE_FN_REGISTRY.register(name="pytorch_default_collate_fn")
def pytorch_default_collate_fn(batch: Any, *args, **kwargs) -> Any:
    """Collate ``batch`` using PyTorch's stock ``default_collate``.

    Any additional positional or keyword arguments are accepted and ignored;
    presumably this lets callers pass ``opts`` the same way they do for the other
    collate functions in this module.

    Args:
        batch: The list of samples produced by the dataset.

    Returns:
        Whatever ``torch.utils.data.default_collate`` returns for ``batch``.
    """
    return default_collate(batch)
@COLLATE_FN_REGISTRY.register(name="unlabeled_image_data_collate_fn")
def unlabeled_image_data_collate_fn(
    batch: List[Mapping[str, Any]], opts: argparse.Namespace
) -> Mapping[str, Any]:
    """Stack a list of unlabeled sample dictionaries into one batch dictionary.

    Every element of ``batch`` must provide a tensor under ``samples`` and an
    integer under ``sample_id``. Because the training engine expects a ``targets``
    entry, an all-zero ``torch.long`` placeholder of length ``len(batch)`` is added.

    Args:
        batch: A list of dictionaries.
        opts: An argparse.Namespace instance; only ``common.channels_last`` is read.

    Returns:
        A dictionary with ``samples``, ``sample_id`` and ``targets`` as keys.
    """
    first_sample = batch[0]["samples"]
    num_items = len(batch)

    # Output buffers take their per-item shape and dtype from the first sample.
    stacked = torch.zeros(
        size=[num_items, *first_sample.shape], dtype=first_sample.dtype
    )
    ids = torch.zeros(size=[num_items], dtype=torch.long)
    for row, item in enumerate(batch):
        stacked[row] = item["samples"]
        ids[row] = item["sample_id"]

    if getattr(opts, "common.channels_last"):
        stacked = stacked.to(memory_format=torch.channels_last)

    # Placeholder labels so downstream code that reads `targets` keeps working.
    placeholder_targets = torch.zeros(size=[num_items], dtype=torch.long)
    return {"samples": stacked, "sample_id": ids, "targets": placeholder_targets}
@COLLATE_FN_REGISTRY.register(name="image_classification_data_collate_fn")
def image_classification_data_collate_fn(
    batch: List[Mapping[str, Any]], opts: argparse.Namespace
) -> Mapping[str, Any]:
    """Combines a list of dictionaries into a single dictionary by stacking matching fields.

    Each input dictionary is expected to have items with `samples`, `sample_id` and
    `targets` as keys. The value for `samples` is expected to be a tensor and the
    values for `sample_id` and `targets` are expected to be integers. Samples whose
    `targets` value equals -1 are dropped from the output batch.

    Args:
        batch: A list of dictionaries
        opts: An argparse.Namespace instance; only `common.channels_last` is read.

    Returns:
        A dictionary with `samples`, `sample_id` and `targets` as keys. Each value has
        one row per sample whose target is not -1 (possibly zero rows).
    """
    # Per-sample shape/dtype come from the first element even if it is later
    # dropped, so an all-invalid batch still yields correctly shaped empty tensors.
    img_shape = batch[0]["samples"].shape
    img_dtype = batch[0]["samples"].dtype

    # Filter invalid (-1) samples *before* allocating. Previously the whole batch
    # was copied into a buffer and then copied again via `index_select`, even when
    # every sample was valid.
    valid = [batch_i for batch_i in batch if batch_i["targets"] != -1]
    num_valid = len(valid)

    images = torch.zeros(size=[num_valid, *img_shape], dtype=img_dtype)
    sample_ids = torch.zeros(size=[num_valid], dtype=torch.long)
    labels = torch.zeros(size=[num_valid], dtype=torch.long)
    for i, batch_i in enumerate(valid):
        images[i] = batch_i["samples"]
        sample_ids[i] = batch_i["sample_id"]
        labels[i] = batch_i["targets"]

    channels_last = getattr(opts, "common.channels_last")
    if channels_last:
        images = images.to(memory_format=torch.channels_last)

    return {"samples": images, "targets": labels, "sample_id": sample_ids}
@COLLATE_FN_REGISTRY.register(name="default_collate_fn")
def default_collate_fn(
    batch: List[Mapping[str, Tensor]], opts: argparse.Namespace
) -> Mapping[str, Tensor]:
    """Combines a list of dictionaries into a single dictionary by stacking matching fields.

    For every key of ``batch[0]``, the values from all dictionaries are gathered in
    batch order. If the first value is a numeric scalar (any ``numbers.Number``,
    which includes ``bool``, ``int``, ``float`` and NumPy scalar types), the values
    are converted with ``torch.as_tensor``. Otherwise they are stacked with
    ``torch.stack`` along a new leading dimension. If stacking raises, the error is
    reported through ``logger.error``; should that call return rather than raise,
    the field is kept as a Python list of the original values.

    Args:
        batch: A list of dictionaries. All dictionaries are assumed to have at least
            the keys of ``batch[0]``.
        opts: An argparse.Namespace instance (not used by this function).

    Returns:
        A dictionary with the same keys, in the same order, as ``batch[0]``, or an
        empty dictionary when ``batch`` is empty.
    """
    if not batch:
        # Nothing to collate; avoid an opaque IndexError on `batch[0]`.
        return {}

    collated = {}
    for key in batch[0].keys():
        values = [sample[key] for sample in batch]
        if isinstance(values[0], numbers.Number):
            # Python/NumPy scalars -> 1-D tensor with an inferred dtype.
            values = torch.as_tensor(values)
        else:
            # Stack tensors (including 0-dimensional ones).
            try:
                values = torch.stack(values, dim=0).contiguous()
            except (RuntimeError, TypeError) as e:
                # RuntimeError: shape/device mismatch; TypeError: non-tensor values.
                logger.error("Unable to stack the tensors. Error: {}".format(e))
        collated[key] = values
    return collated