Source code for data.transforms.audio

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
import math
import pathlib
import random
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torchaudio

from data.transforms import TRANSFORMATIONS_REGISTRY, BaseTransformation
from data.transforms.audio_aux import mfccs


[docs]@TRANSFORMATIONS_REGISTRY.register(name="audio_gain", type="audio") class Gain(BaseTransformation): """ This class implements gain augmentation for audio. """

    def __init__(self, opts, *args, **kwargs) -> None:
        super().__init__(opts=opts)
        self.gain_levels = getattr(opts, "audio_augmentation.gain.levels")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--audio-augmentation.gain.levels",
            type=float,
            default=[0],
            nargs="+",
            help="Gain levels to use for augmentation (in dB).",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        """
        This function implements the gain transformation by scaling the input
        audio with a constant determined by a gain level randomly chosen from
        @gain_levels.

        Args:
            data: A dictionary containing {"samples": {"audio": @audio}}, where
                @audio is a tensor of shape [num_channels, sequence_length].

        Returns:
            The modified dictionary with the augmented audio.
        """
        audio = data["samples"]["audio"]
        gain_level = random.choice(self.gain_levels)
        # A gain of @gain_level dB corresponds to a linear scale factor of
        # 10 ** (gain_level / 20).
        augmented_audio = 10.0 ** (gain_level / 20.0) * audio
        data["samples"]["audio"] = augmented_audio
        return data

    def __repr__(self):
        return "{}(gain_levels={})".format(self.__class__.__name__, self.gain_levels)
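
# Usage sketch (illustrative, not part of the original module): a gain level g
# in dB scales the waveform by 10 ** (g / 20), so g = -6 roughly halves the
# amplitude. The option names below mirror the arguments registered above.
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.gain.levels", [-6.0, 0.0, 6.0])
#   gain = Gain(opts)
#   data = {"samples": {"audio": torch.zeros(2, 16000).uniform_(-1, 1)}}
#   data = gain(data)  # audio scaled by a randomly chosen gain level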
[docs]@TRANSFORMATIONS_REGISTRY.register(name="audio_ambient_noise", type="audio") class Noise(BaseTransformation): """ This class implements ambient noise augmentation for audio. """

    def __init__(
        self,
        opts: argparse.Namespace,
        is_training: bool = True,
        noise_files_dir: Optional[str] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(opts=opts)
        self.gain_levels = getattr(opts, "audio_augmentation.noise.levels")
        self.cache_size = getattr(opts, "audio_augmentation.noise.cache_size")
        self.refresh_freq = getattr(opts, "audio_augmentation.noise.refresh_freq")
        self.refresh_counter = self.refresh_freq
        self.noise_files_dir = noise_files_dir
        if self.noise_files_dir is None:
            self.noise_files_dir = getattr(opts, "audio_augmentation.noise.files_dir")
        self.noise_files = []
        self.pointer = 0
        if self.noise_files_dir is not None:
            self.noise_files = sorted(
                pathlib.Path(self.noise_files_dir).glob("**/*.wav")
            )
            if is_training:
                random.shuffle(self.noise_files)

        # Check for an empty file list before filling the cache, so that a
        # missing or empty noise directory produces a clear error instead of a
        # ZeroDivisionError inside @load_noise_files.
        if len(self.noise_files) == 0:
            raise ValueError(
                "--audio-augmentation.noise.files-dir must be provided for this"
                " augmentation"
            )

        self.noise_waves = self.load_noise_files(cache_size=self.cache_size)

    def load_noise_files(self, cache_size: int) -> List[Tuple[torch.Tensor, int]]:
        """
        This method caches a list of (noise waveform, sample rate) pairs for
        on-the-fly augmentation.
        """
        noise_waves = []
        for _ in range(cache_size):
            noise_wav_file = self.noise_files[self.pointer % len(self.noise_files)]
            self.pointer += 1
            noise, sample_rate = torchaudio.load(noise_wav_file)
            assert (
                noise.dtype == torch.float32
            ), f"Expected noise file {noise_wav_file} to decode to float32 audio, but got {noise.dtype}."
            noise_waves.append((noise, sample_rate))
        return noise_waves

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--audio-augmentation.noise.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of"
            " different transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--audio-augmentation.noise.levels",
            type=float,
            default=[-100],
            nargs="+",
            help="Gain levels to use for noise augmentation (in dB).",
        )
        group.add_argument(
            "--audio-augmentation.noise.cache-size",
            type=int,
            default=10,
            help="Number of augmentation noises to cache.",
        )
        group.add_argument(
            "--audio-augmentation.noise.files-dir",
            type=str,
            default=None,
            help="Directory of noise files.",
        )
        group.add_argument(
            "--audio-augmentation.noise.refresh-freq",
            type=int,
            default=0,
            help="Frequency to refresh noise files (default 0 means never refresh).",
        )
        return parser

    def __call__(self, data: Dict) -> Dict:
        """
        This function adds a random noise sample, selected from the noise
        samples provided in the noise directory, scaled by a random gain. The
        sample should contain floating-point values in [-1, 1].

        Args:
            data: A dictionary containing {"samples": {"audio": @audio}} and
                {"metadata": {"audio_fps": @audio_fps}}, where @audio is a
                tensor of shape [num_channels, sequence_length] and @audio_fps
                is its sampling rate.

        Returns:
            The modified dictionary with the augmented audio.
        """
        audio = data["samples"]["audio"]
        assert audio.shape[0] in [1, 2]
        gain_level = random.choice(self.gain_levels)
        noise_wave, noise_fps = random.choice(self.noise_waves)
        # @noise_wave is in [num_channels, sequence_length] format.
        assert math.isclose(data["metadata"]["audio_fps"], noise_fps, rel_tol=1e-6)
        if noise_wave.shape[-1] >= audio.shape[-1]:
            random_start_point = random.randint(
                0, noise_wave.shape[-1] - audio.shape[-1]
            )
            noise_wave = noise_wave[
                :, random_start_point : random_start_point + audio.shape[-1]
            ]
        else:
            noise_wave = torch.nn.functional.pad(
                noise_wave.unsqueeze(0),
                (0, audio.shape[-1] - noise_wave.shape[-1]),
                mode="circular",
            )
            # @noise_wave is in [1, num_channels, sequence_length] format.
            noise_wave = noise_wave[0]
        augmented_audio = audio + 10.0 ** (gain_level / 20.0) * noise_wave
        data["samples"]["audio"] = augmented_audio

        self.refresh_counter -= 1
        if (
            self.refresh_counter <= 0
            and self.refresh_freq > 0
            and self.noise_files_dir is not None
        ):
            # Refresh the cache once the refresh criteria are met.
            self.noise_waves = self.load_noise_files(self.cache_size)
            self.refresh_counter = self.refresh_freq
        return data

    def __repr__(self):
        return "{}(gain_levels={}, noise_files_dir={})".format(
            self.__class__.__name__, self.gain_levels, self.noise_files_dir
        )
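
# Usage sketch (illustrative; the directory path is hypothetical): the noise
# transform mixes a cached noise clip into the input at a randomly chosen gain,
# so a level of -100 dB adds an almost inaudible noise floor while 0 dB adds
# the noise at its original amplitude. The noise files must match the sampling
# rate reported in data["metadata"]["audio_fps"].
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.noise.levels", [-30.0, -20.0])
#   setattr(opts, "audio_augmentation.noise.cache_size", 10)
#   setattr(opts, "audio_augmentation.noise.refresh_freq", 0)
#   setattr(opts, "audio_augmentation.noise.files_dir", None)
#   noise = Noise(opts, is_training=True, noise_files_dir="/path/to/noises")
#   data = {
#       "samples": {"audio": torch.zeros(1, 16000)},
#       "metadata": {"audio_fps": 16000},
#   }
#   data = noise(data)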
[docs]@TRANSFORMATIONS_REGISTRY.register(name="set_fixed_length", type="audio") class SetFixedLength(BaseTransformation): """Set the audio buffer to a fixed length."""

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)
        self.length = getattr(opts, "audio_augmentation.set_fixed_length.length")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--audio-augmentation.set-fixed-length.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of"
            " different transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--audio-augmentation.set-fixed-length.length",
            default=16000,
            type=int,
            help="Length to which to trim or pad the audio buffer.",
        )
        return parser

    def __call__(
        self,
        data: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor, int]],
        *args,
        **kwargs,
    ) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor, int]]:
        """
        Apply the transformation to the input data.

        Input data must have {"samples": {"audio": torch.Tensor}}. The audio
        must be [C, N] in shape, where C is the number of channels, and N is
        the number of samples.

        Returns:
            The transformed batch.
        """
        audio = data["samples"]["audio"]
        if audio.shape[0] not in (1, 2):
            raise ValueError(f"Expected channels first. Got audio shape {audio.shape}")

        if audio.shape[1] < self.length:
            audio = torch.nn.functional.pad(audio, (0, self.length - audio.shape[1]))
        else:
            audio = audio[:, 0 : self.length]

        data["samples"]["audio"] = audio
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(length={self.length})"
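
# Usage sketch (illustrative): clips shorter than @length are zero-padded on
# the right and longer clips are truncated, so variable-length clips can be
# collated into a single batch tensor.
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.set_fixed_length.length", 16000)
#   fixed = SetFixedLength(opts)
#   data = {"samples": {"audio": torch.randn(1, 12345)}}
#   data = fixed(data)  # audio is now [1, 16000]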
[docs]@TRANSFORMATIONS_REGISTRY.register(name="roll", type="audio") class Roll(BaseTransformation): """Perform a roll augmentation by shifting the window in a circular manner."""

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        self.window = getattr(opts, "audio_augmentation.roll.window")
        super().__init__(opts, *args, **kwargs)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--audio-augmentation.roll.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of"
            " different transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--audio-augmentation.roll.window",
            default=0.1,
            type=float,
            help="Maximum fraction of the audio buffer to move.",
        )
        return parser

    def __call__(
        self,
        data: Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor, int]],
        *args,
        **kwargs,
    ) -> Dict[str, Union[Dict[str, torch.Tensor], torch.Tensor, int]]:
        """
        Apply the transformation to the input data.

        Input data must have {"samples": {"audio": torch.Tensor}}. The audio
        must be [C, N] in shape, where C is the number of channels, and N is
        the number of samples.

        Returns:
            The transformed batch.
        """
        audio = data["samples"]["audio"]
        C, N = audio.shape
        if C not in (1, 2):
            raise ValueError(f"Unexpected number of channels {C}")

        # Shift the waveform circularly along the time dimension by a random
        # offset of at most @window * N samples in either direction.
        audio = torch.roll(
            audio,
            torch.randint(-int(N * self.window), int(N * self.window), [1]).item(),
            1,
        )
        data["samples"]["audio"] = audio
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(window={self.window})"
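
# Usage sketch (illustrative): with window=0.1 and a 16000-sample clip, the
# waveform is circularly shifted by a random offset in [-1600, 1600) samples.
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.roll.window", 0.1)
#   roll = Roll(opts)
#   data = {"samples": {"audio": torch.randn(2, 16000)}}
#   data = roll(data)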
[docs]@TRANSFORMATIONS_REGISTRY.register(name="mfccs", type="audio") class MFCCs(BaseTransformation):

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)
        self.num_mfccs = getattr(opts, "audio_augmentation.mfccs.num_mfccs")
        self.window_length = getattr(opts, "audio_augmentation.mfccs.window_length")
        self.num_frames = getattr(opts, "audio_augmentation.mfccs.num_frames")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--audio-augmentation.mfccs.num-mfccs",
            default=20,
            type=int,
            help="Number of MFCC features.",
        )
        group.add_argument(
            "--audio-augmentation.mfccs.window-length",
            type=float,
            default=0.023,
            help="Window length (unit: seconds) for MFCC calculation.",
        )
        group.add_argument(
            "--audio-augmentation.mfccs.num-frames",
            type=int,
            default=8,
            help="Number of sub-time-slice temporal components. This argument is used"
            " for splitting the temporal dimension of the spectrogram into frames.",
        )
        return parser

    def __call__(self, data: Dict, *args, **kwargs) -> Dict:
        """
        Converts the audio signal of the samples to MFCC features. See the
        documentation of @data.transforms.audio_aux.mfccs.get_mfcc_features
        for further details.

        Args:
            data: A dictionary of the following format:
                {
                    "samples": {
                        "audio": torch.FloatTensor[num_clips, temporal_size, num_channels],
                        "metadata": {"audio_fps": float},
                    }
                }

        Returns:
            A dictionary of the following format:
                {
                    "samples": {
                        "audio": torch.FloatTensor[num_clips, C, num_mfccs,
                            num_frames, ceil(spectrogram_length / num_frames)],
                    }
                }
        """
        audio_fps = data["samples"]["metadata"]["audio_fps"]
        audio_image = mfccs.get_mfcc_features(
            data["samples"]["audio"],
            sampling_rate=audio_fps,
            num_mfccs=self.num_mfccs,
            window_length=self.window_length,
            num_frames=self.num_frames,
        ).detach()
        data["samples"]["audio"] = audio_image
        return data
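
# Usage sketch (illustrative): with num_mfccs=20 and num_frames=8, a raw
# waveform batch is converted into the 5-D MFCC "image" described in the
# docstring above, suitable as input to convolutional encoders.
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.mfccs.num_mfccs", 20)
#   setattr(opts, "audio_augmentation.mfccs.window_length", 0.023)
#   setattr(opts, "audio_augmentation.mfccs.num_frames", 8)
#   to_mfcc = MFCCs(opts)
#   data = {
#       "samples": {
#           "audio": torch.randn(2, 16000, 1),
#           "metadata": {"audio_fps": 16000},
#       }
#   }
#   data = to_mfcc(data)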


class LambdaAudio(BaseTransformation):
    """
    Similar to @torchvision.transforms.Lambda, applies a user-defined lambda
    on the audio samples as a transform.
    """

    def __init__(
        self,
        opts: argparse.Namespace,
        func: Callable[[torch.Tensor], torch.Tensor],
        *args,
        **kwargs,
    ) -> None:
        self.func = func
        super().__init__(opts, *args, **kwargs)

    def __call__(self, data: Dict, *args, **kwargs) -> Dict:
        data["samples"]["audio"] = self.func(data["samples"]["audio"])
        return data
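
# Usage sketch (illustrative): LambdaAudio wraps any Tensor -> Tensor function,
# e.g. peak normalization, without defining a new transform class.
#
#   opts = argparse.Namespace()
#   normalize = LambdaAudio(opts, func=lambda x: x / x.abs().max().clamp(min=1e-8))
#   data = {"samples": {"audio": torch.randn(1, 16000)}}
#   data = normalize(data)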
[docs]@TRANSFORMATIONS_REGISTRY.register(name="audio-resample", type="audio") class AudioResample(BaseTransformation): """Resample audio to a specified framerate."""

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--audio-augmentation.audio-resample.audio-fps",
            default=16000,
            type=int,
            help="Frames per second in the incoming audio stream. Defaults to 16000.",
        )
        return parser

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)
        self.sample_rate = getattr(opts, "audio_augmentation.audio_resample.audio_fps")
        self.effects = [["rate", str(self.sample_rate)]]

    def __call__(self, data: Dict, *args, **kwargs) -> Dict:
        """Resample audio to the specified audio fps.

        Args:
            data: A dict of data input in the following format:
                {
                    "samples": {
                        "audio": torch.FloatTensor[num_clips, temporal_size, num_channels],
                        "metadata": {"audio_fps": float},
                    }
                }

        Returns:
            A dict in the same format, with the audio resampled to the target
            rate and "audio_fps" updated accordingly.
        """
        audio = data["samples"]["audio"]
        audio_rate = data["samples"]["metadata"]["audio_fps"]
        resampled_audio = []
        for audio_tensor in audio:
            (
                resampled_audio_tensor,
                sample_rate,
            ) = torchaudio.sox_effects.apply_effects_tensor(
                audio_tensor, audio_rate, self.effects, channels_first=False
            )
            resampled_audio.append(resampled_audio_tensor)
        data["samples"]["audio"] = torch.stack(resampled_audio, dim=0)
        data["samples"]["metadata"]["audio_fps"] = self.sample_rate
        return data
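
# Usage sketch (illustrative): resampling a batch of two 1-second 44.1 kHz
# clips to 16 kHz via sox's "rate" effect. apply_effects_tensor is called with
# channels_first=False, matching the [num_clips, temporal_size, num_channels]
# layout.
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.audio_resample.audio_fps", 16000)
#   resample = AudioResample(opts)
#   data = {
#       "samples": {
#           "audio": torch.randn(2, 44100, 1),
#           "metadata": {"audio_fps": 44100},
#       }
#   }
#   data = resample(data)  # audio becomes roughly [2, 16000, 1]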
[docs]@TRANSFORMATIONS_REGISTRY.register(name="standardize_channels", type="audio") class StandardizeChannels(BaseTransformation):

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)
        self.num_channels = getattr(
            opts, "audio_augmentation.standardize_channels.num_channels"
        )
        self.enable = getattr(opts, "audio_augmentation.standardize_channels.enable")

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--audio-augmentation.standardize-channels.num-channels",
            default=2,
            type=int,
            help="Number of output audio channels. Defaults to 2.",
        )
        group.add_argument(
            "--audio-augmentation.standardize-channels.enable",
            default=False,
            action="store_true",
            help=f"Use {cls.__name__} transformation. Defaults to False.",
        )
        return parser

    def __call__(self, data: Dict, *args, **kwargs) -> Dict:
        """Ensures all audio samples have a specific number of channels.

        To reduce the number of audio channels from 2 to 1, the average of the
        two channels is used.

        Args:
            data (Dict): {
                "samples": {
                    "audio": Tensor[N, T, C] where N is the number of audio
                        clips, T is the audio sequence length, and C is the
                        number of channels.
                }
            }

        Returns:
            Dict: The input dictionary with the audio standardized to
            @num_channels channels, i.e. Tensor[N, T, num_channels].
        """
        if not self.enable:
            return data
        audio = data["samples"]["audio"]  # N, T, C
        assert audio.ndim == 3, f"Invalid audio dimension {audio.ndim}. Expected 3."
        num_input_channels = audio.shape[2]
        if num_input_channels == self.num_channels:
            return data
        if (num_input_channels, self.num_channels) == (1, 2):
            audio = audio.repeat(1, 1, 2)  # N, T, 2
        elif (num_input_channels, self.num_channels) == (2, 1):
            audio = audio.mean(dim=2, keepdim=True)  # N, T, 1
        else:
            raise NotImplementedError(
                f"The logic for standardizing audio channels with input shape of"
                f" {audio.shape} to {self.num_channels} channels is not implemented."
            )
        data["samples"]["audio"] = audio
        return data
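
# Usage sketch (illustrative): converting mono clips to stereo by duplicating
# the single channel; the reverse direction averages the two channels.
#
#   opts = argparse.Namespace()
#   setattr(opts, "audio_augmentation.standardize_channels.num_channels", 2)
#   setattr(opts, "audio_augmentation.standardize_channels.enable", True)
#   standardize = StandardizeChannels(opts)
#   data = {"samples": {"audio": torch.randn(4, 16000, 1)}}
#   data = standardize(data)  # audio becomes [4, 16000, 2]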