# Source code for data.transforms.image_bytes

# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.

import argparse
import io
from typing import Dict, Union

import numpy as np
import torch
from PIL import Image

from data.transforms import TRANSFORMATIONS_REGISTRY, BaseTransformation

def _image_to_bytes(x: torch.Tensor, **kwargs) -> io.BytesIO:
    """
    Take an image in [0, 1] and save it as file bytes using PIL.

    Args:
        x: an image tensor in [C, H, W] order, where C is the number of channels,
            and H, W are the height and width.
        kwargs: any keyword arguments that can be passed to PIL's
            Image.save() (e.g. ``format="png"`` or ``quality=90``).

    Returns:
        The file bytes, as an in-memory buffer. The buffer's position is left
        at the end of the written data; use ``getvalue()`` to read everything.
    """
    assert x.min() >= 0
    assert x.max() <= 1
    x = (x * 255).byte().permute(1, 2, 0).cpu().numpy()  # Bytes in H, W, C order

    img = Image.fromarray(x)
    byte_array = io.BytesIO()
    # Encode the image into the in-memory buffer rather than a file on disk.
    img.save(byte_array, **kwargs)
    return byte_array

def _bytes_to_int32(byte_array: io.BytesIO) -> torch.Tensor:
    """
    Convert a byte array to int32 values.

    Args:
        byte_array: The input byte array.

    Returns:
        The int32 tensor, holding one element per input byte.
    """
    buf = np.frombuffer(byte_array.getvalue(), dtype=np.uint8)
    # The copy operation is required to avoid a warning about non-writable
    # tensors.
    buf = torch.from_numpy(buf.copy()).to(dtype=torch.int32)
    return buf

@TRANSFORMATIONS_REGISTRY.register(name="pil_save", type="image_torch")
class PILSave(BaseTransformation):
    """
    Encode an image with a supported file encoding.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """
        Read the file encoding and JPEG quality from @opts.

        Args:
            opts: The parsed command-line options.
        """
        super().__init__(opts=opts, *args, **kwargs)
        self.file_encoding = getattr(opts, "image_augmentation.pil_save.file_encoding")
        self.quality = getattr(opts, "image_augmentation.pil_save.quality")
        self.opts = opts

    def __call__(
        self, data: Dict[str, Union[torch.Tensor, int]]
    ) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Serialize an image as file bytes.

        Args:
            data: A dictionary containing a key called "samples", which contains
                an image tensor of shape [C, H, W].

        Returns:
            The transformed data. The "samples" entry is replaced, in place, by
            a 1-dimensional int32 tensor.

        Raises:
            NotImplementedError: If the configured file encoding is not one of
                the supported values.
        """
        x = data["samples"]

        if self.file_encoding == "fCHW":
            # Raw pixel values, flattened in channel-first order.
            x = (x * 255).byte().to(dtype=torch.int32).reshape(-1)
        elif self.file_encoding == "fHWC":
            # Raw pixel values, flattened in channel-last order.
            x = (x * 255).byte().to(dtype=torch.int32).permute(1, 2, 0).reshape(-1)
        elif self.file_encoding == "TIFF":
            x = _bytes_to_int32(_image_to_bytes(x, format="tiff"))
        elif self.file_encoding == "PNG":
            x = _bytes_to_int32(_image_to_bytes(x, format="png"))
        elif self.file_encoding == "JPEG":
            quality = getattr(self.opts, "image_augmentation.pil_save.quality")
            x = _bytes_to_int32(_image_to_bytes(x, format="jpeg", quality=quality))
        else:
            raise NotImplementedError(
                f"Invalid file encoding {self.file_encoding}. Expected one of 'fCHW, fHWC, TIFF, PNG, JPEG'."
            )

        data["samples"] = x
        return data

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(file_encoding={self.file_encoding}, quality={self.quality})"

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the command-line arguments of this transform on @parser.

        Args:
            parser: The parser to extend.

        Returns:
            The same parser, with the new argument group added.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.pil-save.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.pil-save.file-encoding",
            choices=("fCHW", "fHWC", "TIFF", "PNG", "JPEG"),
            help="The type of file encoding to use. Defaults to TIFF.",
            default="TIFF",
        )
        group.add_argument(
            "--image-augmentation.pil-save.quality",
            help="JPEG quality if using JPEG encoding. Defaults to 100.",
            type=int,
            default=100,
        )
        return parser
@TRANSFORMATIONS_REGISTRY.register(name="shuffle_bytes", type="image_torch")
class ShuffleBytes(BaseTransformation):
    """
    Reorder the bytes in a 1-dimensional buffer.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """
        Read the shuffling mode, stride and window size from @opts.

        Args:
            opts: The parsed command-line options.
        """
        super().__init__(opts=opts, *args, **kwargs)
        self.mode = getattr(opts, "image_augmentation.shuffle_bytes.mode")
        self.stride = getattr(opts, "image_augmentation.shuffle_bytes.stride")
        window_size = getattr(opts, "image_augmentation.shuffle_bytes.window_size")
        # The within-window permutation is drawn once here and then reused for
        # every window of every input.
        self.window_shuffle = torch.randperm(window_size)

    def __call__(
        self, data: Dict[str, Union[torch.Tensor, int]]
    ) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Reorder the bytes of a 1-dimensional buffer.

        Args:
            data: A dictionary containing a key called "samples", which contains
                a tensor of shape [N], where N is the number of bytes.

        Returns:
            The transformed data.

        Raises:
            ValueError: If the "samples" tensor is not 1-dimensional.
            NotImplementedError: If the configured mode is not supported.
        """
        x = data["samples"]
        if not x.dim() == 1:
            raise ValueError(f"Expected 1d input, got {x.shape}")

        if self.mode == "reverse":
            x = torch.fliplr(x.view(1, -1))[0]
        elif self.mode == "random_shuffle":
            x = x[torch.randperm(x.shape[0])]
        elif self.mode == "cyclic_half_length":
            x = torch.roll(x, x.shape[0] // 2)
        elif self.mode == "stride":
            length = len(x)
            # Gather elements 0, s, 2s, ... then 1, s+1, 2s+1, ... and so on.
            values = [x[i :: self.stride] for i in range(self.stride)]
            x = torch.cat(values, dim=0)
            assert len(x) == length
        elif self.mode == "window_shuffle":
            length = len(x)
            window_size = self.window_shuffle.shape[0]
            num_windows = length // window_size
            values = []
            for i in range(num_windows):
                chunk = x[i * window_size : (i + 1) * window_size]
                values.append(chunk[self.window_shuffle])

            # Add the last bits that fall outside the shuffling window.
            values.append(x[num_windows * window_size :])
            x = torch.cat(values, dim=0)
            assert len(x) == length
        else:
            raise NotImplementedError(
                f"mode={self.mode} not implemented. Expected one of 'reverse, random_shuffle, cyclic_half_length, stride, window_shuffle'."
            )

        data["samples"] = x
        return data

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the command-line arguments of this transform on @parser.

        Args:
            parser: The parser to extend.

        Returns:
            The same parser, with the new argument group added.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.shuffle-bytes.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.shuffle-bytes.mode",
            default="reverse",
            help="The mode to use when shuffling bytes. Defaults to 'reverse'.",
            choices=(
                "reverse",
                "random_shuffle",
                "cyclic_half_length",
                "stride",
                "window_shuffle",
            ),
        )
        group.add_argument(
            "--image-augmentation.shuffle-bytes.stride",
            type=int,
            default=1024,
            help="The stride of the window used in shuffling operations that are windowed. Defaults to 1024.",
        )
        group.add_argument(
            "--image-augmentation.shuffle-bytes.window-size",
            type=int,
            default=1024,
            help="The size of the window used in shuffling operations that are windowed. Defaults to 1024.",
        )
        return parser
@TRANSFORMATIONS_REGISTRY.register(name="mask_positions", type="image_torch")
class MaskPositions(BaseTransformation):
    """
    Mask out values in a 1-dimensional buffer using a fixed masking pattern.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """
        Read the fraction of positions to keep from @opts.

        Args:
            opts: The parsed command-line options.
        """
        super().__init__(opts=opts, *args, **kwargs)
        self.keep_frac = getattr(opts, "image_augmentation.mask_positions.keep_frac")
        # Lazily built boolean mask; see _generate_masks.
        self._cached_masks = None

    def _generate_masks(self, N: int) -> torch.Tensor:
        """
        Return the boolean keep-mask for buffers of length @N.

        The mask is built from a fixed seed, so it is the same for every input
        of a given length. It is cached and only rebuilt when the requested
        length differs from the cached one.

        Args:
            N: The length of the buffer to be masked.

        Returns:
            A boolean tensor of shape [N] with int(keep_frac * N) True entries.
        """
        cached = self._cached_masks
        # A cached mask of a different length cannot index the current input,
        # so rebuild it (deterministically) for the new length.
        if cached is None or cached.shape[0] != N:
            g = torch.Generator()
            # We want to fix the mask across all inputs, so we fix the seed.
            # Choose a seed with a good balance of 0 and 1 bits. See the
            # torch.Generator.manual_seed documentation.
            g.manual_seed(2147483647)
            random_mask = torch.zeros([N], requires_grad=False, dtype=torch.bool)
            random_mask[torch.randperm(N, generator=g)[: int(self.keep_frac * N)]] = 1
            self._cached_masks = random_mask
        return self._cached_masks

    def __call__(
        self, data: Dict[str, Union[torch.Tensor, int]]
    ) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Mask values in a 1-dimensional buffer with a fixed masking pattern.

        Args:
            data: A dictionary containing a key called "samples", which contains
                a tensor of shape [N], where N is the number of bytes.

        Returns:
            The transformed data. Only the kept positions remain in "samples".
        """
        x = data["samples"]
        mask = self._generate_masks(x.shape[0])
        x = x[mask]
        data["samples"] = x
        return data

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the command-line arguments of this transform on @parser.

        Args:
            parser: The parser to extend.

        Returns:
            The same parser, with the new argument group added.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.mask-positions.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.mask-positions.keep-frac",
            type=float,
            default=0.5,
            help="The fraction of bytes to keep. Defaults to 0.5.",
        )
        return parser
@TRANSFORMATIONS_REGISTRY.register(name="byte_permutation", type="image_torch")
class BytePermutation(BaseTransformation):
    """
    Remap byte values in [0, 255] to new values in [0, 255] using a permutation.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """
        Build the fixed permutation of the 256 byte values.

        Args:
            opts: The parsed command-line options.
        """
        super().__init__(opts=opts, *args, **kwargs)
        # A fixed seed makes the permutation identical across instances.
        g = torch.Generator()
        g.manual_seed(2147483647)
        # Lookup table: entry i holds the value that byte i is remapped to.
        self.mask = torch.randperm(256, generator=g)

    def __call__(
        self, data: Dict[str, Union[torch.Tensor, int]]
    ) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Remap byte values in [0, 255] to new values in [0, 255] using a permutation.

        Args:
            data: A dictionary containing a key called "samples", which contains
                a tensor of shape [N], where N is the number of bytes.

        Returns:
            The transformed data.

        Raises:
            ValueError: If the "samples" tensor is not 1-dimensional.
        """
        x = data["samples"]

        if x.dim() != 1:
            raise ValueError(f"Expected 1d tensor. Got {x.shape}.")

        # Each input value is used as an index into the permutation table.
        x = torch.index_select(self.mask, dim=0, index=x)
        data["samples"] = x
        return data

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the command-line arguments of this transform on @parser.

        Args:
            parser: The parser to extend.

        Returns:
            The same parser, with the new argument group added.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.byte-permutation.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        return parser
@TRANSFORMATIONS_REGISTRY.register(name="random_uniform", type="image_torch")
class RandomUniformNoise(BaseTransformation):
    """
    Add random uniform noise to integer values.
    """

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """
        Read the inclusive [low, high] noise range from @opts.

        Args:
            opts: The parsed command-line options.
        """
        super().__init__(opts=opts, *args, **kwargs)
        self.opts = opts
        self.width_range = getattr(
            opts, "image_augmentation.random_uniform.width_range"
        )

    def __call__(
        self, data: Dict[str, Union[torch.Tensor, int]]
    ) -> Dict[str, Union[torch.Tensor, int]]:
        """
        Add random uniform noise to byte values.

        Args:
            data: A dict containing a tensor in its "samples" key. The tensor
                contains integers representing byte values. Integers are used
                because negative padding values may be added later. The shape
                of the tensor is [N], where N is the number of bytes.

        Returns:
            The transformed data.
        """
        x = data["samples"]
        # randint_like excludes the upper bound, hence the +1 to make the
        # configured range inclusive.
        noise = torch.randint_like(x, self.width_range[0], self.width_range[1] + 1)
        dtype = x.dtype
        x = x.int()
        x = x + noise
        # Wrap around so the result is again a valid byte value.
        x = x % 256
        x = x.to(dtype)
        data["samples"] = x
        return data

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the command-line arguments of this transform on @parser.

        Args:
            parser: The parser to extend.

        Returns:
            The same parser, with the new argument group added.
        """
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--image-augmentation.random-uniform.enable",
            action="store_true",
            help="Use {}. This flag is useful when you want to study the effect of different "
            "transforms.".format(cls.__name__),
        )
        group.add_argument(
            "--image-augmentation.random-uniform.width-range",
            type=int,
            nargs=2,
            default=[-5, 5],
            help="The range of values from which to add noise. It is specified"
            " as [low, high] (inclusive). Defaults to [-5, 5].",
        )
        return parser