# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
import argparse
import math
from typing import List, Mapping, Union
import torch
from torch import Tensor
from torch.nn import functional as F
from loss_fn import LOSS_REGISTRY, BaseCriteria
from utils import logger
from utils.ddp_utils import is_master
# NeuralAugmentation can be used with any task. Therefore, we register both name and type
# as the same.
@LOSS_REGISTRY.register(name="neural_augmentation", type="neural_augmentation")
class NeuralAugmentation(BaseCriteria):
    """Compute the augmentation loss, as described in the
    `RangeAugment <https://arxiv.org/abs/2212.10553>`_ paper.

    The loss pulls the per-sample mean squared error (MSE) between an input image
    and its augmented version towards a target MSE. The target MSE is derived from
    a target PSNR value, which is either a single number (fixed target) or a
    ``[start, end]`` pair that is annealed over training with a curriculum.

    Args:
        opts: command line arguments
    """

    __supported_metrics = ["psnr"]

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)

        # NOTE(review): validation below reports problems through `logger.error`
        # on the master node only and then continues; this presumably relies on
        # `logger.error` aborting the run -- confirm against utils.logger.
        perceptual_metric = getattr(opts, "loss.neural_augmentation.perceptual_metric")
        is_master_node = is_master(opts)
        if perceptual_metric is None and is_master_node:
            logger.error(
                "Perceptual metric can't be none. "
                "Please specify perceptual metric using --loss.neural-augmentation.perceptual-metric argument"
            )
        if not isinstance(perceptual_metric, str) and is_master_node:
            logger.error(
                "The type of perceptual metric is not string. Got: {}".format(
                    type(perceptual_metric)
                )
            )
        perceptual_metric = perceptual_metric.lower()

        target_value = getattr(opts, "loss.neural_augmentation.target_value")
        alpha = getattr(opts, "loss.neural_augmentation.alpha")

        # The CLI option is declared with nargs="+", so a single value supplied on
        # the command line arrives as a one-element list. Treat it as a scalar
        # target instead of rejecting it.
        if isinstance(target_value, (list, tuple)) and len(target_value) == 1:
            target_value = target_value[0]

        # Attribute name (including its spelling) is kept for backward compatibility.
        self.curriculumn_learning = False
        self.iteration_based_training = getattr(
            opts, "scheduler.is_iteration_based", False
        )
        self.target_str = f"{target_value}"

        if perceptual_metric == "psnr":
            if target_value is None and is_master_node:
                logger.error("Target PSNR value can not be None.")

            if isinstance(target_value, (int, float)):
                # Fixed target: a single PSNR value converted to a 1-element MSE tensor.
                if target_value < 0 and is_master_node:
                    logger.error(
                        "PSNR value should be >= 0 in {}. Got: {}".format(
                            self.__class__.__name__, target_value
                        )
                    )
                self.target_value = torch.full(
                    (1,), self._psnr_to_mse(target_value), dtype=torch.float
                )
            elif isinstance(target_value, (list, tuple)) and len(target_value) == 2:
                # Curriculum target: anneal the target MSE from start to end.
                start_target_value, end_target_value = target_value
                if (start_target_value < 0 or end_target_value < 0) and is_master_node:
                    logger.error(
                        "PSNR value should be >= 0 in {}. Got: {}".format(
                            self.__class__.__name__, target_value
                        )
                    )
                start_target_mse = self._psnr_to_mse(start_target_value)
                end_target_mse = self._psnr_to_mse(end_target_value)

                # The curriculum is indexed by iteration or by epoch, depending on
                # how the scheduler is configured.
                max_steps = (
                    getattr(opts, "scheduler.max_iterations")
                    if self.iteration_based_training
                    else getattr(opts, "scheduler.max_epochs")
                )
                if max_steps is None and is_master_node:
                    logger.error(
                        "Please specify {}. Got None.".format(
                            "--scheduler.max-iterations"
                            if self.iteration_based_training
                            else "--scheduler.max-epochs"
                        )
                    )

                curriculum_method = getattr(
                    opts, "loss.neural_augmentation.curriculum_method"
                )
                if curriculum_method not in CURRICULUM_METHOD:
                    raise NotImplementedError(
                        f"Unsupported curriculum method: {curriculum_method}"
                    )
                # 1-D tensor of target MSE values, one entry per step.
                self.target_value = CURRICULUM_METHOD[curriculum_method](
                    start=start_target_mse, end=end_target_mse, period=max_steps
                )
                self.curriculumn_learning = True
                self.target_str = f"[{start_target_value}, {end_target_value}]"
            else:
                raise NotImplementedError(
                    "Target value should be a number or a pair of numbers. "
                    f"Got: {target_value}"
                )

            # the maximum possible MSE error is computed as:
            # a = torch.ones((3, H, W)) * 255.0 # Max. input value is 255.0
            # b = torch.zeros((3, H, W)) # min. input value is 0.0
            # mse = torch.mean( (a -b) ** 2)
            # 65025 is the maximum mse
            self.alpha = alpha / 65025.0
        else:
            if is_master_node:
                logger.error(
                    "Supported perceptual metrics are: {}. Got: {}".format(
                        self.__supported_metrics, perceptual_metric
                    )
                )

        self.perceptual_metric = perceptual_metric
        self.device = getattr(opts, "dev.device", torch.device("cpu"))

    @staticmethod
    def _psnr_to_mse(psnr: float) -> float:
        """Convert a PSNR value to the corresponding MSE.

        Uses ``PSNR = 20 log10(255) - 10 log10(MSE)``, i.e. a peak value of 255.
        """
        return 10.0 ** ((20.0 * math.log10(255.0) - psnr) / 10.0)

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the neural-augmentation loss options on ``parser`` and return it."""
        # Keep in sync with the keys of the module-level CURRICULUM_METHOD mapping.
        supported_curriculums = ["linear", "cosine"]

        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--loss.neural-augmentation.perceptual-metric",
            type=str,
            default="psnr",
            choices=cls.__supported_metrics,
            help=f"Name of the perceptual metric to be used in {cls.__name__}.",
        )
        group.add_argument(
            "--loss.neural-augmentation.target-value",
            type=float,
            default=[40, 20],
            nargs="+",
            help=f"Target image similarity value in {cls.__name__}. Defaults to [40, 20]",
        )
        group.add_argument(
            "--loss.neural-augmentation.curriculum-method",
            type=str,
            default="cosine",
            choices=supported_curriculums,
            help=f"Curriculum for varying the target image similarity value in {cls.__name__}. "
            f"Supported curriculums are {supported_curriculums}. Defaults to cosine",
        )
        group.add_argument(
            "--loss.neural-augmentation.alpha",
            default=100.0,
            type=float,
            help="Scale loss value by alpha value. Defaults to 100. "
            "Note: When perceptual metric is PSNR, alpha value is divided by 65025",
        )
        return parser

    def _forward_psnr(
        self, input_tensor: Tensor, augmented_tensor: Tensor, *args, **kwargs
    ) -> Tensor:
        """Compute the MSE error between input and augmented image, and minimizes
        the distance between MSE error and target error.

        Args:
            input_tensor: Input image of shape [N, C, H, W]
            augmented_tensor: Augmented image of shape [N, C, H, W]
            **kwargs: When a curriculum is used, the current step is read from
                ``iterations`` (iteration-based training) or ``epoch`` (otherwise);
                a missing key is treated as step 0.

        Returns:
            A scalar loss value
        """
        # The difference is scaled by 255 before squaring so that the MSE is on
        # the same scale as the PSNR-derived target.
        # NOTE(review): this presumably expects images in [0, 1] -- confirm.
        squared_err = ((augmented_tensor - input_tensor) * 255.0) ** 2
        # [B, C, H, W] --> [B]
        pred_mse = torch.mean(squared_err, dim=[1, 2, 3])

        if self.curriculumn_learning:
            step = (
                kwargs.get("iterations", 0)
                if self.iteration_based_training
                else kwargs.get("epoch", 0)
            )
            # Past the end of the curriculum, keep using its final value.
            if step >= len(self.target_value):
                step = -1
            target_mse = self.target_value[step]
        else:
            target_mse = self.target_value

        # compute smooth L1 loss between target and current MSE
        smooth_l1_loss = F.smooth_l1_loss(
            input=pred_mse,
            target=target_mse.expand_as(pred_mse).to(
                device=pred_mse.device, dtype=pred_mse.dtype
            ),
            reduction="mean",
        )
        return smooth_l1_loss * self.alpha

    def _compute_loss(
        self, input_tensor: Tensor, augmented_tensor: Tensor, *args, **kwargs
    ) -> Tensor:
        """Compute the neural augmentation loss.

        Dispatches to ``_forward_<perceptual_metric>``.

        Args:
            input_tensor: Input image of shape [N, C, H, W]
            augmented_tensor: Augmented image of shape [N, C, H, W]

        Returns:
            A scalar value
        """
        if augmented_tensor is None:
            logger.error(
                f"Augmented tensor can't be None in {self.__class__.__name__} during training mode."
            )

        forward_loss_fn = getattr(self, f"_forward_{self.perceptual_metric}")
        # The tensors are passed positionally: unpacking *args after keyword
        # arguments would bind any extra positional argument to `input_tensor`
        # and raise a TypeError.
        return forward_loss_fn(input_tensor, augmented_tensor, *args, **kwargs)

    def forward(
        self,
        input_sample: Union[Tensor, Mapping[str, Union[Tensor, List[Tensor]]]],
        prediction: Mapping[str, Tensor],
        *args,
        **kwargs,
    ) -> Tensor:
        """Compute the loss between input and augmented image, as described in
        `RangeAugment <https://arxiv.org/abs/2212.10553>`_ paper.

        Args:
            input_sample: Input sample can either be a Tensor or a dictionary with mandatory key "image". In
                case of a dictionary, the values can be a Tensor or list of Tensors.
            prediction: Output of augmentation model. Mapping of (string: Tensor) with `augmented_tensor`
                as the required key.

        Shapes:
            input_sample:
                * Tensor: The shape of input tensor is [N, C, H, W]
                * Mapping[str, Tensor]: The shape of tensor is [N, C, H, W]
                * Mapping[str, List[Tensor]]: The length of List is N, and the shape of each tensor is [1, C, H, W]
            prediction: The shape of prediction["augmented_tensor"] is [N, C, H, W]

        Returns:
            A scalar loss value

        .. note::
            During validation or evaluation, neural augmentation loss is not computed and 0 is returned
        """
        if not self.training:
            return torch.tensor(0.0, device=self.device, dtype=torch.float)

        if not isinstance(prediction, Mapping):
            logger.error(
                "Prediction needs to be an instance of Mapping and must contain augmented_tensor"
                " as keys"
            )

        if isinstance(input_sample, Mapping):
            input_sample = input_sample["image"]

        if isinstance(input_sample, list):
            # if its a list of images, stack them
            # NOTE(review): stacking tensors of shape [1, C, H, W] yields
            # [N, 1, C, H, W]; confirm the per-item shape expected here.
            input_sample = torch.stack(input_sample, dim=0)

        augmented_tensor = prediction["augmented_tensor"]
        # Positional pass-through for the same reason as in `_compute_loss`.
        return self._compute_loss(input_sample, augmented_tensor, *args, **kwargs)
def linear_curriculum(start: float, end: float, period: int) -> Tensor:
    """Build a linear curriculum between two values.

    Args:
        start: value at step 0
        end: value at step ``period``
        period: index of the last step; the curriculum covers steps ``0..period``

    Returns:
        A 1-D float tensor of length ``period + 1`` whose values are evenly
        spaced from ``start`` to ``end`` (both inclusive)
    """
    # period + 1 points so that index `period` is a valid step holding `end`.
    return torch.linspace(start=start, end=end, steps=period + 1, dtype=torch.float)
def cosine_curriculum(start: float, end: float, period: int) -> Tensor:
    """Build a cosine curriculum between two values.

    The value at step ``i`` is
    ``end + 0.5 * (start - end) * (1 + cos(pi * i / period))``, so the curriculum
    equals ``start`` at step 0 and ``end`` at step ``period``.

    Args:
        start: value at step 0
        end: value at step ``period``
        period: index of the last step; the curriculum covers steps ``0..period``

    Returns:
        A 1-D float tensor of length ``period + 1``
    """
    # Normalise by `period` (not `period + 1`) so the final entry is exactly
    # `end`, matching the linear curriculum. The max() guard avoids a division
    # by zero when period == 0, in which case the single entry is `start`.
    denom = max(period, 1)
    values = [
        end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * step / denom))
        for step in range(period + 1)
    ]
    return torch.tensor(values, dtype=torch.float)
# Maps the value of --loss.neural-augmentation.curriculum-method to the function
# that builds the per-step target schedule. Each function is called with the
# keyword arguments `start`, `end` and `period`.
CURRICULUM_METHOD = dict(
    linear=linear_curriculum,
    cosine=cosine_curriculum,
)