data.collate_fns package

Submodules

data.collate_fns.byteformer_collate_functions module

This file contains collate functions used by ByteFormer.

Since the model operates on a variety of input types, these collate functions are not associated with a particular dataset.

These transforms are applied in the collate function, before the model rather than inside it. This lets them run in parallel in the data-loading workers, and it avoids moving tensors from the GPU back to the CPU and then to the GPU again (these transforms cannot run on the GPU).
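Because the transforms live in the collate function, they run wherever the PyTorch DataLoader performs collation, which is inside its worker processes when num_workers > 0. The sketch below shows one way to wire a collate function that takes an extra opts argument into a DataLoader by using functools.partial. toy_collate_fn and the in-memory dataset are hypothetical stand-ins, not part of this package.

```python
from argparse import Namespace
from functools import partial
from typing import Any, List, Mapping

import torch
from torch.utils.data import DataLoader


def toy_collate_fn(batch: List[Mapping[str, Any]], opts: Namespace) -> Mapping[str, Any]:
    # Hypothetical collate function: stack equally shaped samples on the CPU.
    return {
        "samples": torch.stack([b["samples"] for b in batch]),
        "targets": torch.tensor([b["targets"] for b in batch]),
    }


# A tiny in-memory dataset of dictionaries (a plain list works as a map-style dataset).
dataset = [{"samples": torch.full((4,), i), "targets": i} for i in range(6)]
opts = Namespace()  # stands in for the global arguments

loader = DataLoader(
    dataset,
    batch_size=3,
    # The DataLoader calls collate_fn(batch) only, so bind `opts` ahead of time.
    collate_fn=partial(toy_collate_fn, opts=opts),
    num_workers=0,  # with num_workers > 0, collation runs in worker processes
)

batches = list(loader)
print(batches[0]["samples"].shape)  # torch.Size([3, 4])
```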

data.collate_fns.byteformer_collate_functions.byteformer_image_collate_fn(batch: List[Mapping[str, Tensor]], opts: Namespace) → Mapping[str, Tensor]

Apply augmentations specific to ByteFormer image training, then perform padded collation.

See ByteFormer for more information on this modeling approach.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
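As a rough illustration of "augment, then pad and collate", the sketch below applies a list of per-sample transforms and then pads the variable-length byte sequences with torch.nn.utils.rnn.pad_sequence. The helper name, the padding_index attribute on opts, and the "samples"/"targets" layout are assumptions made for the example. The real function's set and order of transforms are not reproduced here.

```python
from argparse import Namespace
from typing import Callable, Dict, List

import torch
from torch.nn.utils.rnn import pad_sequence

Sample = Dict[str, torch.Tensor]


def sketch_augment_and_collate(
    batch: List[Sample],
    opts: Namespace,
    transforms: List[Callable[[torch.Tensor], torch.Tensor]],
) -> Dict[str, torch.Tensor]:
    # 1) Augment: apply every (hypothetical) transform to each 1-D byte sequence.
    for elem in batch:
        for transform in transforms:
            elem["samples"] = transform(elem["samples"])
    # 2) Pad and collate: right-pad to the longest sequence and stack into
    #    a [batch_size, max_len] tensor.
    samples = pad_sequence(
        [elem["samples"] for elem in batch],
        batch_first=True,
        padding_value=opts.padding_index,  # hypothetical attribute name
    )
    targets = torch.stack([elem["targets"] for elem in batch])
    return {"samples": samples, "targets": targets}


opts = Namespace(padding_index=-1)
batch = [
    {"samples": torch.tensor([10, 20, 30]), "targets": torch.tensor(0)},
    {"samples": torch.tensor([40]), "targets": torch.tensor(1)},
]
out = sketch_augment_and_collate(batch, opts, transforms=[lambda x: x + 1])
print(out["samples"].tolist())  # [[11, 21, 31], [41, -1, -1]]
```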

data.collate_fns.byteformer_collate_functions.apply_padding(batch: List[Mapping[str, Dict[str, Tensor] | Tensor]], opts: Namespace, key: str | None = None) → List[Mapping[str, Tensor]]

Apply padding to make samples the same length.

The input is a list of dictionaries of the form:

[{"samples": @entry, …}, …].

If @key is specified, @entry has the form {@key: @value}, where @value corresponds to the entry that should be padded. Otherwise, @entry is assumed to be a tensor.

The tensor mentioned in the above paragraph has shape [batch_size, sequence_length, …].

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

  • key – The key of the sample element to pad. If @key is None, the entry is assumed to be a tensor.

Returns:

The modified batch of size [batch_size, padded_sequence_length, …].
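A minimal sketch of the padding step, assuming right-padding along the sequence dimension (dimension 0 of each sample) with a fixed padding value, and covering both the keyed and the plain-tensor forms of @entry. The function name and the explicit padding_index parameter are hypothetical; the documented function takes opts instead.

```python
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F

Entry = Union[Dict[str, torch.Tensor], torch.Tensor]


def pad_samples_sketch(
    batch: List[Dict[str, Entry]],
    padding_index: int,
    key: Optional[str] = None,
) -> List[Dict[str, Entry]]:
    def get_tensor(entry: Entry) -> torch.Tensor:
        # With a key, the tensor to pad is entry[key]; otherwise entry is the tensor.
        return entry[key] if key is not None else entry

    max_len = max(get_tensor(elem["samples"]).shape[0] for elem in batch)
    for elem in batch:
        seq = get_tensor(elem["samples"])
        # F.pad lists (left, right) pairs starting from the LAST dimension, so
        # dimension 0 (the sequence) is the final pair; other dims get no padding.
        pad = [0, 0] * (seq.dim() - 1) + [0, max_len - seq.shape[0]]
        padded = F.pad(seq, pad, value=padding_index)
        if key is not None:
            elem["samples"][key] = padded
        else:
            elem["samples"] = padded
    return batch


audio_batch = [
    {"samples": {"audio": torch.tensor([1, 2, 3])}},
    {"samples": {"audio": torch.tensor([4])}},
]
padded = pad_samples_sketch(audio_batch, padding_index=0, key="audio")
print([elem["samples"]["audio"].tolist() for elem in padded])  # [[1, 2, 3], [4, 0, 0]]
```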

data.collate_fns.byteformer_collate_functions.apply_pil_save(batch: List[Mapping[str, Tensor]], opts: Namespace) → List[Mapping[str, Tensor]]

Apply the PILSave transform to each batch element.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
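PILSave itself is not documented in this section. Assuming it serializes an image to an encoded file format and exposes the file's bytes as integer tokens (the core idea behind ByteFormer), a sketch using Pillow could look like this. pil_save_sketch, the PNG default, and the [3, H, W] float input layout are assumptions.

```python
import io

import torch
from PIL import Image


def pil_save_sketch(image: torch.Tensor, fmt: str = "PNG") -> torch.Tensor:
    """Encode a [3, H, W] float image in [0, 1] to file bytes (hypothetical helper)."""
    # Convert CHW floats to a contiguous HWC uint8 array, the layout PIL expects.
    array = (
        (image.clamp(0, 1) * 255).round().to(torch.uint8).permute(1, 2, 0).contiguous().numpy()
    )
    buffer = io.BytesIO()
    Image.fromarray(array).save(buffer, format=fmt)
    # Each byte of the encoded file becomes one integer token in [0, 255].
    return torch.tensor(list(buffer.getvalue()), dtype=torch.int64)


tokens = pil_save_sketch(torch.rand(3, 8, 8))
print(tokens[:8].tolist())  # PNG signature: [137, 80, 78, 71, 13, 10, 26, 10]
```

The length of the resulting sequence depends on the image content and on the file format, which is why byte sequences in a batch generally need padding before they can be stacked.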

data.collate_fns.byteformer_collate_functions.apply_shuffle_bytes(batch: List[Mapping[str, Tensor]], opts: Namespace) → List[Mapping[str, Tensor]]

Apply the ShuffleBytes transform to each batch element.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
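The ShuffleBytes transform is not specified here. One plausible reading, sketched below under that assumption, is a random reordering of positions within each byte sequence. The sketch also shows the per-element loop that the apply_* helpers in this module describe. Both function names are hypothetical.

```python
from typing import Callable, Dict, List, Optional

import torch


def shuffle_bytes_sketch(
    seq: torch.Tensor, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
    # Randomly reorder positions; the multiset of byte values is unchanged.
    perm = torch.randperm(seq.shape[0], generator=generator)
    return seq[perm]


def apply_to_each_sketch(
    batch: List[Dict[str, torch.Tensor]],
    transform: Callable[[torch.Tensor], torch.Tensor],
) -> List[Dict[str, torch.Tensor]]:
    # Run one transform on the "samples" entry of every batch element.
    for elem in batch:
        elem["samples"] = transform(elem["samples"])
    return batch


g = torch.Generator().manual_seed(0)
batch = [{"samples": torch.arange(5)}, {"samples": torch.arange(3)}]
batch = apply_to_each_sketch(batch, lambda s: shuffle_bytes_sketch(s, g))
print([sorted(elem["samples"].tolist()) for elem in batch])  # [[0, 1, 2, 3, 4], [0, 1, 2]]
```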

data.collate_fns.byteformer_collate_functions.apply_mask_positions(batch: List[Mapping[str, Tensor]], opts: Namespace) → List[Mapping[str, Tensor]]

Apply the MaskPositions transform to each batch element.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
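The MaskPositions transform is not specified here either. A hypothetical sketch of position masking: positions are selected with a seeded random pattern, and every position that is not kept is overwritten with a mask value. The keep_frac and mask_value parameters are assumptions for illustration, not the transform's documented options.

```python
import torch


def mask_positions_sketch(
    seq: torch.Tensor, keep_frac: float, mask_value: int, seed: int = 0
) -> torch.Tensor:
    # A seeded generator makes repeated calls on the same input give the same mask.
    g = torch.Generator().manual_seed(seed)
    keep = torch.rand(seq.shape[0], generator=g) < keep_frac
    # Positions that are not kept are overwritten with the (hypothetical) mask value.
    return torch.where(keep, seq, torch.full_like(seq, mask_value))


x = torch.arange(1, 101)
y = mask_positions_sketch(x, keep_frac=0.5, mask_value=0)
print(y.shape)  # torch.Size([100])
```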

data.collate_fns.byteformer_collate_functions.apply_random_uniform_noise(batch: List[Mapping[str, Tensor]], opts: Namespace) → List[Mapping[str, Tensor]]

Apply the RandomUniformNoise transform to each batch element.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
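A hypothetical sketch of uniform noise on byte tokens: each value is shifted by an integer drawn uniformly from [-width, width] and clamped back into the byte range [0, 255]. The width parameter, the integer noise, and the clamping behavior are all assumptions.

```python
from typing import Optional

import torch


def random_uniform_noise_sketch(
    seq: torch.Tensor, width: int, generator: Optional[torch.Generator] = None
) -> torch.Tensor:
    # Integer noise drawn uniformly from [-width, width] (randint's upper bound is exclusive).
    noise = torch.randint(
        -width, width + 1, seq.shape, generator=generator, dtype=seq.dtype
    )
    # Clamp so every token remains a valid byte value.
    return (seq + noise).clamp(0, 255)


g = torch.Generator().manual_seed(0)
x = torch.full((1000,), 128, dtype=torch.int64)
y = random_uniform_noise_sketch(x, width=5, generator=g)
print(int(y.min()) >= 123, int(y.max()) <= 133)  # True True
```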

data.collate_fns.byteformer_collate_functions.apply_byte_permutation(batch: List[Mapping[str, Tensor]], opts: Namespace) → List[Mapping[str, Tensor]]

Apply the BytePermutation transform to each batch element.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
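Assuming BytePermutation remaps byte values (rather than positions) through a random permutation of 0..255, it can be sketched as a lookup-table gather. The function name and the seed parameter are hypothetical.

```python
import torch


def byte_permutation_sketch(seq: torch.Tensor, seed: int = 0) -> torch.Tensor:
    # A random permutation of the 256 possible byte values, used as a lookup table.
    g = torch.Generator().manual_seed(seed)
    table = torch.randperm(256, generator=g)
    # Indexing the table with the (int64) byte values remaps every byte consistently.
    return table[seq]


x = torch.tensor([0, 1, 1, 255])
y = byte_permutation_sketch(x)
print(y[1].item() == y[2].item())  # True: equal bytes map to equal bytes
```

Equal input bytes always map to equal output bytes, so the order of the sequence is preserved while its values are disguised.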

data.collate_fns.byteformer_collate_functions.apply_torchaudio_save(batch: List[Mapping[str, Tensor]], opts: Namespace) → List[Mapping[str, Tensor]]

Apply the TorchaudioSave transform to each batch element.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.
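TorchaudioSave presumably plays the role for audio that PILSave plays for images: encode a waveform into a file format and expose the file's bytes. To keep the example free of a torchaudio dependency, the sketch below swaps in Python's standard wave module and writes 16-bit mono PCM WAV. wav_save_sketch and these format choices are assumptions, not torchaudio's API or this package's settings.

```python
import io
import wave

import torch


def wav_save_sketch(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
    """Encode a mono float waveform in [-1, 1] as 16-bit PCM WAV bytes (hypothetical helper)."""
    pcm = (waveform.clamp(-1, 1) * 32767).round().to(torch.int16)
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as f:
        f.setnchannels(1)
        f.setsampwidth(2)  # 2 bytes per sample, i.e. 16-bit PCM
        f.setframerate(sample_rate)
        f.writeframes(pcm.numpy().tobytes())
    # Each byte of the WAV file (header plus samples) becomes one integer token.
    return torch.tensor(list(buffer.getvalue()), dtype=torch.int64)


tokens = wav_save_sketch(torch.zeros(100), sample_rate=16000)
print(bytes(tokens[:4].tolist()))  # b'RIFF'
```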

data.collate_fns.byteformer_collate_functions.byteformer_audio_collate_fn(batch: List[Mapping[str, Tensor]], opts: Namespace) → Mapping[str, Tensor]

Apply augmentations specific to ByteFormer audio training, then perform padded collation.

See ByteFormer for more information on this modeling approach.

Parameters:
  • batch – The batch of data.

  • opts – The global arguments.

Returns:

The modified batch.

data.collate_fns.collate_functions module

data.collate_fns.collate_functions.pytorch_default_collate_fn(batch: Any, *args, **kwargs) → Any

A wrapper around PyTorch’s default collate function.
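A sketch of such a wrapper, assuming it simply forwards the batch to torch.utils.data.default_collate and ignores any extra arguments so that it can be called like the other collate functions (for example, with opts). The wrapper name is hypothetical.

```python
from typing import Any

import torch
from torch.utils.data import default_collate


def default_collate_wrapper(batch: Any, *args, **kwargs) -> Any:
    # Extra positional/keyword arguments (e.g. opts) are accepted and ignored,
    # so this function shares a call signature with the other collate functions.
    return default_collate(batch)


batch = [
    {"samples": torch.zeros(3), "targets": 0},
    {"samples": torch.ones(3), "targets": 1},
]
out = default_collate_wrapper(batch, None)
print(out["samples"].shape, out["targets"].tolist())  # torch.Size([2, 3]) [0, 1]
```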

data.collate_fns.collate_functions.unlabeled_image_data_collate_fn(batch: List[Mapping[str, Any]], opts: Namespace) → Mapping[str, Any]

Combines a list of dictionaries into a single dictionary by concatenating matching fields.

Each input dictionary is expected to have samples and sample_id keys. The value for samples is expected to be a tensor, and the value for sample_id is expected to be an integer.

This function adds a targets field with dummy values to the output dictionary to meet the expectations of the training engine.

Parameters:
  • batch – A list of dictionaries

  • opts – An argparse.Namespace instance.

Returns:

A dictionary with samples, sample_id and targets as keys.
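A minimal sketch of this behavior, assuming equally shaped sample tensors that are stacked along a new batch dimension. The value used for the dummy targets is not documented, so zeros are an assumption here, as is the function name.

```python
from argparse import Namespace
from typing import Any, List, Mapping

import torch


def unlabeled_collate_sketch(
    batch: List[Mapping[str, Any]], opts: Namespace
) -> Mapping[str, Any]:
    samples = torch.stack([b["samples"] for b in batch])
    sample_ids = torch.tensor([b["sample_id"] for b in batch], dtype=torch.long)
    # Placeholder targets so code that expects a "targets" key still works.
    # Zeros are an assumption; the real dummy value may differ.
    targets = torch.zeros(len(batch), dtype=torch.long)
    return {"samples": samples, "sample_id": sample_ids, "targets": targets}


batch = [{"samples": torch.rand(3, 4, 4), "sample_id": i} for i in (7, 8)]
out = unlabeled_collate_sketch(batch, Namespace())
print(out["samples"].shape, out["sample_id"].tolist())  # torch.Size([2, 3, 4, 4]) [7, 8]
```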

data.collate_fns.collate_functions.image_classification_data_collate_fn(batch: List[Mapping[str, Any]], opts: Namespace) → Mapping[str, Any]

Combines a list of dictionaries into a single dictionary by concatenating matching fields.

Each input dictionary is expected to have samples, sample_id, and targets keys. The value for samples is expected to be a tensor, and the values for sample_id and targets are expected to be integers.

Parameters:
  • batch – A list of dictionaries

  • opts – An argparse.Namespace instance.

Returns:

A dictionary with samples, sample_id and targets as keys.

data.collate_fns.collate_functions.default_collate_fn(batch: List[Mapping[str, Tensor]], opts: Namespace) → Mapping[str, Tensor]

Combines a list of dictionaries into a single dictionary by concatenating matching fields.

Parameters:
  • batch – A list of dictionaries

  • opts – An argparse.Namespace instance.

Returns:

A dictionary with the same keys as batch[0].

Module contents

data.collate_fns.arguments_collate_fn(parser: ArgumentParser) → ArgumentParser

Add arguments related to collate functions.

data.collate_fns.build_collate_fn(opts, *args, **kwargs)

data.collate_fns.build_test_collate_fn(opts, *args, **kwargs)
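build_collate_fn and build_test_collate_fn are listed without descriptions. A common pattern for functions like these, shown here purely as a hypothetical sketch, is to register collate functions by name, expose the name as a command-line option, and bind opts with functools.partial when building. The registry, the --collate-fn-name flag, and all function names below are invented for illustration and are not this package's actual options.

```python
import argparse
import functools
from typing import Callable, Dict

# Hypothetical registry mapping names to collate functions.
COLLATE_FN_REGISTRY: Dict[str, Callable] = {}


def register_collate_fn(name: str) -> Callable[[Callable], Callable]:
    def decorator(fn: Callable) -> Callable:
        COLLATE_FN_REGISTRY[name] = fn
        return fn

    return decorator


@register_collate_fn("identity")
def identity_collate_fn(batch, opts):
    return batch


def add_collate_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    group = parser.add_argument_group("Collate function arguments")
    group.add_argument(
        "--collate-fn-name",  # hypothetical flag
        type=str,
        default="identity",
        help="Name of a registered collate function.",
    )
    return parser


def build_collate_fn_sketch(opts: argparse.Namespace) -> Callable:
    name = opts.collate_fn_name
    if name not in COLLATE_FN_REGISTRY:
        raise ValueError(f"Unknown collate function {name!r}; choices: {sorted(COLLATE_FN_REGISTRY)}")
    # Bind opts so a DataLoader can call the result with the batch alone.
    return functools.partial(COLLATE_FN_REGISTRY[name], opts=opts)


opts = add_collate_arguments(argparse.ArgumentParser()).parse_args([])
collate = build_collate_fn_sketch(opts)
print(collate([1, 2]))  # [1, 2]
```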