#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
from typing import Any, Optional
import torch
from torch import Tensor, nn
class Clip(nn.Module):
    """Constrain values to the range spanned by ``min_val`` and ``max_val``.

    Two modes are supported:

    * **hard** clipping clamps the input *in place* with gradient tracking
      disabled and returns that same (mutated) tensor;
    * **soft** clipping passes the input through a sigmoid and rescales the
      result by ``max_val - min_val`` before shifting it by ``min_val``.

    Args:
        min_val: Lower bound of the output range.
        max_val: Upper bound of the output range.
        hard_clip: When truthy, use hard clipping; otherwise (including
            ``None``) use soft clipping. Default: ``False``.
        *args: Accepted for signature compatibility and ignored.
        **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        min_val: float,
        max_val: float,
        hard_clip: Optional[bool] = False,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val
        self.hard_clip = hard_clip

    def forward(self, x: Any) -> Any:
        """Clip ``x`` using the configured mode.

        Args:
            x: Input tensor. In hard mode it is modified in place.

        Returns:
            The clamped input itself (hard mode) or a new tensor holding the
            rescaled sigmoid of the input (soft mode).
        """
        if not self.hard_clip:
            span = self.max_val - self.min_val
            return (torch.sigmoid(x) * span) + self.min_val

        # Autograd is switched off so that the in-place clamp is not recorded;
        # the caller's tensor is overwritten and handed back.
        with torch.no_grad():
            return x.clamp_(min=self.min_val, max=self.max_val)

    def __repr__(self) -> str:
        mode = "hard" if self.hard_clip else "soft"
        return (
            f"{self.__class__.__name__}"
            f"(min={self.min_val}, max={self.max_val}, clipping={mode})"
        )
class Identity(nn.Module):
    """Module whose forward pass returns its input unchanged.

    Any positional or keyword constructor arguments are accepted and ignored.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()

    def forward(self, x: Any) -> Any:
        """Return ``x`` itself, without copying or modifying it."""
        return x
class FixedSampler(nn.Module):
    """Sampler that always yields the same learnable value.

    The value is stored as an ``nn.Parameter`` of shape ``(1, 3, 1, 1)``
    initialised to ``value`` and is passed through ``clip_fn`` on every call.

    Args:
        value: Initial value written to every element of the parameter.
        clip_fn: Module applied to the parameter before it is returned.
            When ``None`` (the default) an ``Identity`` module is used, i.e.
            the parameter is returned as is.
        *args: Accepted for signature compatibility and ignored.
        **kwargs: Accepted for signature compatibility and ignored.
    """

    def __init__(
        self,
        value: float,
        clip_fn: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ):
        super().__init__()
        self._value = nn.Parameter(
            torch.full((1, 3, 1, 1), value, dtype=torch.float)
        )
        # Build the fallback per instance rather than in the signature, so
        # samplers never share a single default module object and an explicit
        # ``None`` does not fail later in ``forward``.
        self.clip_fn = Identity() if clip_fn is None else clip_fn

    def forward(
        self, sample_shape=(), data_type=torch.float, device=torch.device("cpu")
    ) -> Tensor:
        """Return the stored value after applying ``clip_fn``.

        Nothing is sampled here: ``sample_shape``, ``data_type`` and
        ``device`` are accepted but not used by this implementation.
        """
        return self.clip_fn(self._value)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(clip_fn={self.clip_fn})"
def random_noise(x: Tensor, variance: Tensor, *args, **kwargs) -> Tensor:
    """Add zero-mean Gaussian noise, scaled by ``variance``, to ``x``.

    Noise is drawn with ``torch.randn_like(x)`` and multiplied by
    ``variance`` before being added, so ``variance`` acts as a multiplicative
    scale on standard-normal samples.

    Args:
        x: Input tensor; it is not modified.
        variance: Scale applied to the noise; must broadcast against ``x``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A new tensor equal to ``x`` plus the scaled noise.
    """
    scaled_noise = torch.randn_like(x) * variance
    return x + scaled_noise
def random_contrast(x: Tensor, magnitude: Tensor, *args, **kwargs) -> Tensor:
    """Adjust contrast by blending ``x`` with its mean over the last two dims.

    The result is ``(1 - magnitude) * mean + magnitude * x``, where ``mean``
    is taken over the last two dimensions of ``x`` (kept for broadcasting).
    A magnitude of 1 leaves ``x`` unchanged; a magnitude of 0 replaces every
    element with that mean.

    Args:
        x: Input tensor with at least two dimensions.
        magnitude: Contrast factor; must broadcast against ``x``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A new tensor with the adjusted contrast.
    """
    # Mean over the two trailing dimensions, kept so it broadcasts back.
    mean = torch.mean(x, dim=[-1, -2], keepdim=True)
    mean_term = (1.0 - magnitude) * mean
    signal_term = x * magnitude
    return mean_term + signal_term
def random_brightness(x: Tensor, magnitude: Tensor, *args, **kwargs) -> Tensor:
    """Scale ``x`` element-wise by ``magnitude`` to change its brightness.

    Args:
        x: Input tensor; it is not modified.
        magnitude: Multiplicative factor; must broadcast against ``x``.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A new tensor equal to ``x * magnitude``.
    """
    return x * magnitude
def identity(x: Tensor, *args, **kwargs) -> Tensor:
    """Return ``x`` unchanged; extra arguments are accepted and ignored."""
    return x