# Source code for data.collate_fns

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse

from utils import logger
from utils.registry import Registry

# Registry named "collate_fn". Entries are looked up by name (the values of the
# ``--dataset.collate-fn-name-{train,val,test}`` options) via
# ``COLLATE_FN_REGISTRY[name]``. The directory lists are passed through to
# ``Registry``; presumably they control where modules that register collate
# functions are discovered/imported from (TODO confirm against utils.registry).
COLLATE_FN_REGISTRY = Registry(
    "collate_fn",
    internal_dirs=["internal", "internal/projects/*"],
    lazy_load_dirs=["data/collate_fns"],
)


def arguments_collate_fn(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the collate-function selection options on ``parser``.

    Adds an argument group titled "Collate function arguments" containing
    ``--dataset.collate-fn-name-train``, ``--dataset.collate-fn-name-val`` and
    ``--dataset.collate-fn-name-test``. Each is a string option defaulting to
    ``"pytorch_default_collate_fn"``. Because of how argparse derives
    destinations, the parsed values land on the attributes
    ``dataset.collate_fn_name_{train,val,test}``.

    Args:
        parser: The parser to extend (modified in place).

    Returns:
        The same ``parser`` object, so calls can be chained.
    """
    fallback_name = "pytorch_default_collate_fn"
    # (option string, help text), in the order the options are registered.
    option_specs = (
        (
            "--dataset.collate-fn-name-train",
            "Name of collate function for training. Defaults to pytorch_default_collate_fn.",
        ),
        (
            "--dataset.collate-fn-name-val",
            "Name of collate function for validation. Defaults to pytorch_default_collate_fn.",
        ),
        (
            "--dataset.collate-fn-name-test",
            "Name of collate function used for evaluation. "
            "Default is pytorch_default_collate_fn.",
        ),
    )

    collate_group = parser.add_argument_group("Collate function arguments")
    for option_string, help_text in option_specs:
        collate_group.add_argument(
            option_string,
            type=str,
            default=fallback_name,
            help=help_text,
        )
    return parser


def build_collate_fn(opts, *args, **kwargs):
    """Resolve the training and validation collate functions from ``opts``.

    The names are read from the ``dataset.collate_fn_name_train`` and
    ``dataset.collate_fn_name_val`` attributes of ``opts`` (the destinations of
    the ``--dataset.collate-fn-name-train`` / ``--dataset.collate-fn-name-val``
    options). Each name is looked up in ``COLLATE_FN_REGISTRY``.

    Args:
        opts: Parsed options object exposing the attributes above.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        A ``(collate_fn_train, collate_fn_val)`` tuple of registry entries.
    """
    collate_fn_name_train = getattr(opts, "dataset.collate_fn_name_train")
    if collate_fn_name_train is None:
        # NOTE(review): assumes ``logger.error`` terminates the run; if it only
        # logs, the registry lookup below is attempted with ``None``.
        logger.error(
            "Please specify collate function for training dataset using "
            "--dataset.collate-fn-name-train"
        )

    collate_fn_name_val = getattr(opts, "dataset.collate_fn_name_val")
    if collate_fn_name_val is None:
        # Previously this message said "training dataset", which was a
        # copy-paste error: this branch concerns the validation option.
        logger.error(
            "Please specify collate function for validation dataset using "
            "--dataset.collate-fn-name-val"
        )

    collate_fn_train = COLLATE_FN_REGISTRY[collate_fn_name_train]
    collate_fn_val = COLLATE_FN_REGISTRY[collate_fn_name_val]
    return collate_fn_train, collate_fn_val


def build_test_collate_fn(opts, *args, **kwargs):
    """Resolve the collate function used at test/evaluation time.

    Reads ``dataset.collate_fn_name_test`` from ``opts``. The attribute must
    exist; ``getattr`` is called without a default.

    Args:
        opts: Parsed options object exposing ``dataset.collate_fn_name_test``.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        ``COLLATE_FN_REGISTRY[name]`` when a name is set, otherwise ``None``.
    """
    test_fn_name = getattr(opts, "dataset.collate_fn_name_test")
    if test_fn_name is None:
        # No test-time collate function configured.
        return None
    return COLLATE_FN_REGISTRY[test_fn_name]