#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
import random
from typing import List, Optional
import torch
from torch import Tensor, nn
from cvnets.misc.common import parameter_list
from cvnets.neural_augmentor.utils.neural_aug_utils import (
Clip,
FixedSampler,
UniformSampler,
random_brightness,
random_contrast,
random_noise,
)
from utils import logger
_distribution_tuple = (UniformSampler,)
class BaseNeuralAugmentor(nn.Module):
    """
    Base class for `neural (or range) augmentation <https://arxiv.org/abs/2212.10553>`_

    Child classes assign learnable samplers (or ``nn.Parameter``) to
    ``self.brightness`` / ``self.contrast`` / ``self.noise`` and populate
    ``self.aug_fns`` with the corresponding bound ``_apply_*`` methods.
    """

    def __init__(self, opts, *args, **kwargs):
        super().__init__()
        self.opts = opts
        # Scales the learning rate of augmentation parameters relative to the
        # base LR (consumed by get_trainable_parameters).
        self.lr_multiplier = getattr(
            opts, "model.learn_augmentation.lr_multiplier", 1.0
        )
        # Set variables corresponding to different transforms to None.
        # We will override them in child classes with learnable versions
        self.brightness = None
        self.contrast = None
        self.noise = None
        self.aug_fns = []

    def _is_valid_aug_fn_list(self, aug_fns):
        # During training an empty augmentation list makes this module a
        # no-op, which is a configuration error; report it via the logger.
        if self.training:
            if len(aug_fns) == 0:
                logger.error(
                    "{} needs at least one learnable function.".format(
                        self.__class__.__name__
                    )
                )

    def get_trainable_parameters(
        self,
        weight_decay: Optional[float] = 0.0,
        no_decay_bn_filter_bias: Optional[bool] = False,
        *args,
        **kwargs
    ):
        """Get trainable parameters

        Returns a tuple of (parameter groups, per-group LR multipliers); every
        group uses ``self.lr_multiplier``.
        """
        param_list = parameter_list(
            named_parameters=self.named_parameters,
            weight_decay=weight_decay,
            no_decay_bn_filter_bias=no_decay_bn_filter_bias,
        )
        return param_list, [self.lr_multiplier] * len(param_list)

    def __repr__(self):
        aug_str = "{}(".format(self.__class__.__name__)
        # For a fixed nn.Parameter show its shape; otherwise delegate to the
        # sampler's own repr.
        if self.brightness is not None:
            aug_str += "\n\tBrightness={}, ".format(
                self.brightness.data.shape
                if isinstance(self.brightness, nn.Parameter)
                else self.brightness
            )
        if self.contrast is not None:
            aug_str += "\n\tContrast={}, ".format(
                self.contrast.data.shape
                if isinstance(self.contrast, nn.Parameter)
                else self.contrast
            )
        if self.noise is not None:
            aug_str += "\n\tNoise={}, ".format(
                self.noise.data.shape
                if isinstance(self.noise, nn.Parameter)
                else self.noise
            )
        aug_str += self.extra_repr()
        aug_str += ")"
        return aug_str

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Add model-specific arguments"""
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--model.learn-augmentation.mode",
            type=str,
            default=None,
            choices=["basic", "distribution"],
            help="Neural augmentation mode",
        )
        group.add_argument(
            "--model.learn-augmentation.brightness",
            action="store_true",
            help="Learn parameters for brightness",
        )
        group.add_argument(
            "--model.learn-augmentation.contrast",
            action="store_true",
            help="Learn parameters for contrast",
        )
        group.add_argument(
            "--model.learn-augmentation.noise",
            action="store_true",
            help="Learn parameters for noise",
        )
        # LR multiplier
        group.add_argument(
            "--model.learn-augmentation.lr-multiplier",
            type=float,
            default=1.0,
            help="LR multiplier for neural aug parameters",
        )
        return parser

    def _build_aug_fns(self, opts) -> List:
        # Child classes construct the list of learnable augmentation callables.
        raise NotImplementedError

    def _sample_magnitude(self, param, x: Tensor) -> Tensor:
        """Resolve the augmentation magnitude for input ``x``.

        ``param`` is either an ``nn.Parameter`` (a fixed learnable value shared
        by the whole batch) or a distribution sampler, in which case one value
        per sample is drawn with shape ``(batch, 1, ..., 1)`` so it broadcasts
        over the remaining dimensions.
        """
        per_sample_shape = [x.shape[0]] + [1] * (x.dim() - 1)
        if isinstance(param, nn.Parameter):
            # learning a fixed number of parameters
            return param
        elif isinstance(param, _distribution_tuple):
            # learning a distribution range from which parameter is sampled.
            return param(per_sample_shape, device=x.device, data_type=x.dtype)
        raise NotImplementedError

    def _apply_brightness(self, x: Tensor, *args, **kwargs) -> Tensor:
        """
        Apply brightness augmentation function with learnable parameters.
        """
        return random_brightness(
            x, self._sample_magnitude(self.brightness, x), *args, **kwargs
        )

    def _apply_contrast(self, x: Tensor, *args, **kwargs) -> Tensor:
        """
        Apply contrast augmentation function with learnable parameters.
        """
        # Bug fix: kwargs were previously forwarded as `*kwargs`, which unpacks
        # the dict KEYS as positional arguments; `**kwargs` forwards them as
        # keyword arguments.
        return random_contrast(
            x, self._sample_magnitude(self.contrast, x), *args, **kwargs
        )

    def _apply_noise(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Apply noise augmentation function with learnable parameters."""
        # Bug fix: same `*kwargs` -> `**kwargs` correction as in
        # _apply_contrast.
        return random_noise(
            x, self._sample_magnitude(self.noise, x), *args, **kwargs
        )

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Augment a random half of the batch with each enabled transform.

        Args:
            x: Input batch whose first dimension is the batch axis (typically
               (batch, channels, height, width)); values are expected in
               [0, 1] and the output is clipped to that range.
        """
        batch_size = x.shape[0]
        # Randomly apply augmentation to 50% of the samples
        n_aug_samples = max(1, (batch_size // 2))
        # shuffle the order of augmentations
        random.shuffle(self.aug_fns)
        for aug_fn in self.aug_fns:
            # select 50% samples for augmentation
            sample_ids = torch.randperm(
                n=batch_size, dtype=torch.long, device=x.device
            )[:n_aug_samples]
            x_aug = torch.index_select(x, dim=0, index=sample_ids)
            # apply augmentation
            x_aug = aug_fn(x=x_aug)
            # copy augmented samples to tensor
            x = torch.index_copy(x, dim=0, source=x_aug, index=sample_ids)
        # clip the values so that they are between 0 and 1
        x = torch.clip(x, min=0.0, max=1.0)
        return x
class BasicNeuralAugmentor(BaseNeuralAugmentor):
    """
    Basic neural augmentation. This class learns per-channel augmentation parameters
    and apply the same parameter to all images in a batch.
    See `neural (or range) augmentation <https://arxiv.org/abs/2212.10553>`_ paper for details.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        super().__init__(opts=opts, *args, **kwargs)
        learnable_fns = self._build_aug_fns(opts=opts)
        self._is_valid_aug_fn_list(learnable_fns)
        self.aug_fns = learnable_fns

    def _build_aug_fns(self, opts) -> List:
        # One (option flag, attribute, sampler factory, apply fn) spec per
        # transform; a FixedSampler learns a single value shared by the batch.
        # Factories are lazy so samplers exist only for enabled transforms.
        specs = (
            (
                "model.learn_augmentation.brightness",
                "brightness",
                lambda: FixedSampler(
                    value=1.0, clip_fn=Clip(min_val=0.1, max_val=10.0)
                ),
                self._apply_brightness,
            ),
            (
                "model.learn_augmentation.contrast",
                "contrast",
                lambda: FixedSampler(
                    value=1.0, clip_fn=Clip(min_val=0.1, max_val=10.0)
                ),
                self._apply_contrast,
            ),
            (
                "model.learn_augmentation.noise",
                "noise",
                lambda: FixedSampler(
                    value=0.0, clip_fn=Clip(min_val=0.0, max_val=1.0)
                ),
                self._apply_noise,
            ),
        )
        enabled_fns = []
        for flag, attr_name, make_sampler, apply_fn in specs:
            if getattr(opts, flag, False):
                setattr(self, attr_name, make_sampler())
                enabled_fns.append(apply_fn)
        return enabled_fns
class DistributionNeuralAugmentor(BaseNeuralAugmentor):
    """
    Distribution-based neural (or range) augmentation. This class samples the augmentation parameters
    from a specified distribution with learnable range.
    See `neural (or range) augmentation <https://arxiv.org/abs/2212.10553>`_ paper for details.
    """

    def __init__(self, opts, *args, **kwargs) -> None:
        super().__init__(opts=opts, *args, **kwargs)
        learnable_fns = self._build_aug_fns_with_uniform_dist(opts=opts)
        self._is_valid_aug_fn_list(learnable_fns)
        self.aug_fns = learnable_fns

    def _build_aug_fns_with_uniform_dist(self, opts) -> List:
        # need to define the learnable parameters in a way that are compatible with bucketing
        # One (option flag, attribute, sampler factory, apply fn) spec per
        # transform; each UniformSampler learns the [low, high] range it
        # samples from, with min_fn/max_fn clipping the learned endpoints.
        specs = (
            (
                "model.learn_augmentation.brightness",
                "brightness",
                lambda: UniformSampler(
                    low=0.5,
                    high=1.5,
                    min_fn=Clip(min_val=0.1, max_val=0.9),
                    max_fn=Clip(min_val=1.1, max_val=10.0),
                ),
                self._apply_brightness,
            ),
            (
                "model.learn_augmentation.contrast",
                "contrast",
                lambda: UniformSampler(
                    low=0.5,
                    high=1.5,
                    min_fn=Clip(min_val=0.1, max_val=0.9),
                    max_fn=Clip(min_val=1.1, max_val=10.0),
                ),
                self._apply_contrast,
            ),
            (
                "model.learn_augmentation.noise",
                "noise",
                lambda: UniformSampler(
                    low=0.0,
                    high=0.1,
                    min_fn=Clip(min_val=0.0, max_val=0.00005),
                    max_fn=Clip(min_val=0.0001, max_val=1.0),
                ),
                self._apply_noise,
            ),
        )
        enabled_fns = []
        for flag, attr_name, make_sampler, apply_fn in specs:
            if getattr(opts, flag, False):
                setattr(self, attr_name, make_sampler())
                enabled_fns.append(apply_fn)
        return enabled_fns
def build_neural_augmentor(opts, *args, **kwargs):
    """Factory for neural augmentors.

    Reads `model.learn_augmentation.mode` from ``opts`` and returns the
    matching augmentor instance, or ``None`` when augmentation is disabled
    (mode missing or unrecognized).
    """
    mode = getattr(opts, "model.learn_augmentation.mode", None)
    mode = "none" if mode is None else mode.lower()
    if mode == "distribution":
        return DistributionNeuralAugmentor(opts=opts, *args, **kwargs)
    if mode == "basic":
        return BasicNeuralAugmentor(opts=opts, *args, **kwargs)
    return None