#
# Copyright 2021 Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import dnikit.typing._types as t
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
from dnikit.base import Producer, Batch
# Type of one entry in ``ProducerTorchDataset.mapping`` (see that class's docs below).
# An entry is one of:
#   - str                               -- name of a batch field to emit
#   - Batch.DictMetaKey / Batch.MetaKey -- metadata key whose value is emitted
#   - callable(Batch.ElementType)       -- custom code; its return value is emitted
PRODUCER_TORCH_MAPPING = t.Union[str, Batch.DictMetaKey, Batch.MetaKey,
t.Callable[[Batch.ElementType], t.Any]]
# Type of one entry in ``TorchProducer.mapping`` (see that class's docs below).
# An entry is one of:
#   - str                               -- name of the batch field to store the value in
#   - Batch.DictMetaKey / Batch.MetaKey -- metadata key to store the value under
#   - callable(value, Batch.Builder)    -- custom code that updates the builder itself
#   - None                              -- discard the positional value
#
# ``None`` is part of the union because ``TorchProducer._transform`` explicitly
# accepts it and the documented usage passes it (``["image", None, key1, key2]``).
TORCH_PRODUCER_MAPPING = t.Union[
    str,
    Batch.DictMetaKey,
    Batch.MetaKey,
    t.Callable[[t.Any, Batch.Builder], None],
    None,
]
@dataclasses.dataclass(frozen=True)
class ProducerTorchDataset(IterableDataset):
    """
    Adaptor that transforms any :class:`Producer <dnikit.base.Producer>` into a
    :class:`PyTorch IterableDataset <torch.utils.data.IterableDataset>`.

    The Producer can be something simple like a :class:`ImageProducer <dnikit.base.ImageProducer>`
    or a more complex :func:`pipeline <dnikit.base.pipeline>` of stages.

    Instances are given a :attr:`mapping` that describes how to transform the structured
    data in a :class:`Batch.ElementType <dnikit.base.Batch.ElementType>` (type of single
    :class:`batch.elements <dnikit.base.Batch.elements>`) into an unstructured tuple
    that PyTorch expects from a Dataset. This same mapping can be used to map the positional
    values from PyTorch back into a dnikit Producer via :class:`TorchProducer`.

    This class also supports an optional :attr:`transforms` that works similar
    to the ``transforms`` attr on
    `PyTorch image datasets <https://pytorch.org/vision/stable/datasets.html>`_.

    See Also
        - :class:`TorchProducer` -- converts a PyTorch Dataset/DataLoader into a
          :class:`Producer <dnikit.base.Producer>`

    Args:
        producer: see :attr:`producer`
        mapping: see :attr:`mapping`
        batch_size: **[optional]** see :attr:`batch_size`
        transforms: **[optional]** see :attr:`transforms`
    """

    producer: Producer
    """
    The Producer to represent as a PyTorch Dataset.
    """

    mapping: t.Sequence[PRODUCER_TORCH_MAPPING]
    """
    Describes how to map a :class:`Batch.ElementType <dnikit.base.Batch.ElementType>` to the
    Dataset result. Typically the first value returned from a Dataset is an array-like piece of
    data, e.g. an ``image`` :attr:`field <dnikit.base.Batch.fields>` in a typical
    :class:`Batch <dnikit.base.Batch>`.

    The mapping supports several different types of values:

    - string -- names a :attr:`batch.fields <dnikit.base.Batch.fields>` to copy into the output
    - :class:`DictMetaKey <dnikit.base.Batch.DictMetaKey>` /
      :class:`MetaKey <dnikit.base.Batch.MetaKey>` -- names a
      :attr:`batch.metadata <dnikit.base.Batch.metadata>` to copy into the output.  A
      ``DictMetaKey`` whose dictionary holds exactly one entry is emitted as that entry's
      value alone; otherwise the whole dictionary is emitted.
    - callable -- custom code to produce a custom result

    For example:

    .. code-block:: python

        # consider a Batch.ElementType with data like this:
        im = np.random.randint(255, size=(64, 64), dtype=np.uint8)
        fields = {
            "image": im,
            "image2": im,
        }
        key1 = Batch.DictMetaKey[dict]("KEY1")
        metadata = {
            key1: {"k1": "v1", "k2": "v2"}
        }

        # it's possible to define the mapping like this:
        def transform(element: Batch.ElementType) -> np.ndarray:
            # note: PyTorch requires a writable copy of the ndarray
            return element.fields["image"].reshape((128, 32)).copy()

        ds = ProducerTorchDataset(producer, ["image", "image2", key1, transform])

    In this example the Dataset will produce two ndarrays, a dictionary and a reshaped ndarray.
    """

    batch_size: int = 100
    """
    The size of batch to read from the producer.  This is independent of the downstream
    batch size in PyTorch.
    """

    transforms: t.Optional[t.Mapping[str, t.Callable[[torch.Tensor], torch.Tensor]]] = None
    """
    Optional transforms (https://pytorch.org/vision/stable/transforms.html).  This
    is a mapping from field name to a Tensor transform, e.g. image and audio transforms.

    Typical PyTorch Datasets provide a ``transform`` and ``target_transform`` to transform
    the first and second values.  This class requires passing in specific field names for the
    transforms to apply to.

    For example:

    .. code-block:: python

        dataset = ProducerTorchDataset(
            producer, ["image", "mask", "heights"],
            transforms={
                "image": transforms.RandomCrop(32, 32),
                "mask": transforms.Compose([
                    transforms.CenterCrop(10),
                    transforms.ColorJitter(),
                ]),
            })
    """

    def __post_init__(self) -> None:
        """Reject a bare ``str`` passed as :attr:`mapping`."""
        # since str is-a Sequence, double check to make sure the caller
        # didn't pass in a bare str by accident
        assert not isinstance(self.mapping, str), (
            "mapping should be a list of strings, callables and DictMetaKeys, not a single str.")

    def __iter__(self) -> t.Iterator:
        """
        Yield one tuple per element of every batch read from :attr:`producer`.

        Each tuple has one entry per item in :attr:`mapping`, in the same order.

        Raises:
            ValueError: if an item in :attr:`mapping` is not a ``str``, a metadata key
                or a callable.
        """
        transforms = self.transforms or {}

        for batch in self.producer(self.batch_size):
            for element in batch.elements:
                result: t.List[t.Any] = []  # this is a t.Any to allow for custom mappings

                for mapping in self.mapping:
                    if isinstance(mapping, str):
                        # take a copy of the field data -- torch.Tensor requires that the
                        # data be writable.
                        data = element.fields[mapping].copy()
                        if mapping in transforms:
                            # transforms operate on Tensors; convert in, apply, convert back
                            tensor = data if isinstance(data, torch.Tensor) else torch.Tensor(data)
                            transformed = transforms[mapping](tensor)
                            data = transformed.detach().cpu().numpy()
                        result.append(data)

                    elif isinstance(mapping, Batch.MetaKey):
                        # simple metadata -- pass the value through unchanged
                        result.append(element.metadata[mapping])

                    elif isinstance(mapping, Batch.DictMetaKey):
                        # metadata is a dictionary of values
                        meta_data = element.metadata[mapping]
                        if len(meta_data) == 1:
                            # dictionary with a single entry -> emit just that entry's value
                            result.append(next(iter(meta_data.values())))
                        else:
                            result.append(meta_data)

                    elif callable(mapping):
                        # custom mapping
                        result.append(mapping(element))

                    else:
                        raise ValueError(f'mapping "{mapping}" is an '
                                         f'unhandled type: {type(mapping)}')

                yield tuple(result)
@dataclasses.dataclass(frozen=True)
class TorchProducer(Producer):
    """
    Adaptor that transforms a PyTorch
    `DataLoader <https://pytorch.org/tutorials/beginner/basics/data_tutorial.html>`_ into a DNIKit
    :class:`Producer <dnikit.base.Producer>`.  This enables reuse of PyTorch Datasets with
    DNIKit :func:`pipelines <dnikit.base.pipeline>`.

    Instances are given a :attr:`mapping` that describes how to transform the unstructured
    tuple that a PyTorch Dataset produces into a structured DNIKit
    :class:`Batch <dnikit.base.Batch>`.

    This same mapping can be used to map a :class:`Batch <dnikit.base.Batch>` into a
    Dataset in :class:`ProducerTorchDataset`.

    See Also
        - :class:`ProducerTorchDataset` -- :class:`Producer <dnikit.base.Producer>` into a
          PyTorch Dataset

    Args:
        data_loader: see :attr:`data_loader`
        mapping: see :attr:`mapping`
        anonymous_field_name: see :attr:`anonymous_field_name`
    """

    data_loader: DataLoader
    """
    The PyTorch DataLoader to adapt to a :class:`Producer <dnikit.base.Producer>`.
    """

    mapping: t.Sequence[TORCH_PRODUCER_MAPPING]
    """
    This mapping defines how the positional values in a PyTorch Dataset map
    back to a structured :class:`Batch <dnikit.base.Batch>`.  This is essentially the same mapping
    used in :class:`ProducerTorchDataset` -- the same mapping could be used
    to round-trip the data between PyTorch and dnikit.

    The values in the mapping correspond to the positions in the Dataset result
    and convert values as follows:

    - string -- map a Tensor into a :attr:`batch.fields <dnikit.base.Batch.fields>`
      :class:`numpy.ndarray`
    - :class:`DictMetaKey <dnikit.base.Batch.DictMetaKey>` /
      :class:`MetaKey <dnikit.base.Batch.MetaKey>` -- map a value into
      :attr:`batch.metadata <dnikit.base.Batch.metadata>`
    - callable -- perform custom conversion and update the
      :class:`Batch.Builder <dnikit.base.Batch.Builder>`
    - None -- discard a value

    For example, given a Dataset that produced data like this:

    .. code-block:: python

        yield ndarray, ndarray, 50, {"k1": "v1", "k2": "v2"}

    it can be mapped into dnikit :attr:`metadata <dnikit.base.Batch.metadata>` like this:

    .. code-block:: python

        key1 = Batch.DictMetaKey[int]("KEY1")
        key2 = Batch.DictMetaKey[t.Mapping[str, str]]("KEY2")

        producer = TorchProducer(loader, ["image", None, key1, key2])

    That will map the first field into ``batch.fields["image"]`` as an
    :class:`numpy.ndarray`.  The second field will be discarded.  The third and fourth fields
    will come across as :attr:`metadata <dnikit.base.Batch.metadata>` like this:

    .. code-block:: python

        element.metadata[key1] == { "_": 50 }
        element.metadata[key2] == {"k1": "v1", "k2": "v2"}

    If the Dataset only produces image data, a single mapping will be sufficient: ``["image"]``
    """

    anonymous_field_name: str = "_"
    """
    The field name to use when mapping non-dictionary metadata to
    :class:`DictMetaKey <dnikit.base.Batch.DictMetaKey>`.

    For example, if a PyTorch Dataset produces:

    .. code-block:: python

        yield ndarray, [10, 20, 30]

    it can be mapped into a :class:`DictMetaKey <dnikit.base.Batch.DictMetaKey>` like this:

    .. code-block:: python

        key1 = Batch.DictMetaKey[t.List[int]]("KEY1")
        producer = TorchProducer(loader, ["image", key1])

        element = next(iter(producer(1))).elements[0]

        # this is how the metadata is surfaced
        element.metadata[key1] == { "_": [10, 20, 30] }

    Ideally a :class:`MetaKey <dnikit.base.Batch.MetaKey>` is used in these cases.
    """

    def __post_init__(self) -> None:
        """Reject a bare ``str`` passed as :attr:`mapping`."""
        # since str is-a Sequence, double check to make sure the caller
        # didn't pass in a bare str by accident
        assert not isinstance(self.mapping, str), ("mapping should be a list of strings, "
                                                   "callables and metadata keys, not a single str.")

    @property
    def batch_size(self) -> int:
        """Batch size of :attr:`data_loader`, or ``100`` when the loader reports a falsy value."""
        return self.data_loader.batch_size or 100

    def _transform(self, value: t.Any, mapping: TORCH_PRODUCER_MAPPING,
                   batch: Batch.Builder) -> None:
        """
        Store one positional ``value`` from the DataLoader into ``batch`` according to ``mapping``.

        Args:
            value: one positional value of the data produced by :attr:`data_loader`
            mapping: the entry of :attr:`mapping` at the same position as ``value``
            batch: the builder that receives the converted value

        Raises:
            ValueError: if ``mapping`` is of an unsupported type, or if ``value`` has a type
                that the given kind of ``mapping`` cannot store.
        """
        if isinstance(mapping, str):
            # field data -- fields are stored as ndarrays
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu().numpy()
            elif isinstance(value, list):
                value = np.array(value)
            elif isinstance(value, np.ndarray):
                pass
            else:
                raise ValueError(f'field mapping "{mapping}" is an unhandled type: {type(value)}')

            batch.fields[mapping] = value

        elif isinstance(mapping, Batch.MetaKey):
            # simple metadata, use the value directly (unwrapping tensors as needed)
            if isinstance(value, torch.Tensor):
                batch.metadata[mapping] = value.tolist()
            elif isinstance(value, t.Sequence):
                batch.metadata[mapping] = value
            else:
                raise ValueError(f'metadata mapping "{mapping}" '
                                 f'is an unhandled type: {type(value)}')

        elif isinstance(mapping, Batch.DictMetaKey):
            # dictionary metadata
            #
            # if a Dataset returns:
            #
            # <number> -> Tensor([...])
            # ndarray -> Tensor([ndarray, ndarray, ...])
            #
            # <str> -> [...]
            # [a, b, ...] -> [ Tensor([a, ..]), Tensor([b, ...]), ...] or [["a", ...], ["b", ...]]
            # {k1: v1, k2: v2} -> {k1: Tensor[v1, ...], k2: Tensor[v2, ...]}
            if isinstance(value, torch.Tensor):
                # map to an array of values
                batch.metadata[mapping] = {self.anonymous_field_name: value.tolist()}

            elif isinstance(value, t.Sequence):
                # check for emptiness before peeking at the first item; an empty
                # sequence is stored as an empty list instead of raising IndexError
                holds_nested_sequences = (
                    len(value) > 0 and isinstance(value[0], (list, torch.Tensor, tuple))
                )
                if holds_nested_sequences:
                    # transpose the lists
                    # -- lists of [a, b, c, ...] are expected, not [a, a, a, ...]
                    batch.metadata[mapping] = {
                        self.anonymous_field_name: list(map(list, zip(*value)))}
                else:
                    batch.metadata[mapping] = {self.anonymous_field_name: list(value)}

            elif isinstance(value, dict):
                # turn the keys into fields in the metadata -- the dicts at the element
                # level will match what the Dataset returned
                batch.metadata[mapping] = {
                    k: v.tolist() if isinstance(v, torch.Tensor) else v
                    for k, v in value.items()
                }

            else:
                raise ValueError(f'metadata mapping "{mapping}" '
                                 f'is an unhandled type: {type(value)}')

        elif mapping is None:
            # skip the field in the tuple
            pass

        elif callable(mapping):
            # custom conversion -- the callable updates the builder itself
            mapping(value, batch)

        else:
            raise ValueError(f'cannot handle mapping of type: {type(mapping)}')

    def __call__(self, batch_size: int) -> t.Iterable[Batch]:
        """
        Yield one :class:`Batch <dnikit.base.Batch>` per item produced by :attr:`data_loader`.

        Args:
            batch_size: requested batch size; must equal :attr:`batch_size` because the
                batching is performed by the DataLoader, not by this class.

        Raises:
            ValueError: if ``batch_size`` differs from :attr:`batch_size`.
        """
        if batch_size != self.batch_size:
            raise ValueError('The Torch DataLoader used in this instance produces batches '
                             f'of size {self.batch_size}, '
                             f'requested batch size: {batch_size}')

        for data in self.data_loader:
            batch = Batch.Builder()

            if isinstance(data, list):
                # if the Dataset returns multiple values, data will be a list;
                # pair each positional value with its mapping entry
                for value, mapping in zip(data, self.mapping):
                    self._transform(value, mapping, batch)
            else:
                # a single value -- this is the result of a Dataset that returns a single value
                self._transform(data, self.mapping[0], batch)

            yield batch.make_batch()