#
# Copyright 2020 Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import dataclasses
import enum
from functools import partial
import numpy as np
from ._base_processor import Processor
from dnikit.base import Batch, PipelineStage
from dnikit._dict_utils import rename_keys, delete_keys
import dnikit.typing as dt
import dnikit.typing._types as t
@t.final
class MeanStdNormalizer(Processor):
"""
A :class:`Processor` that standardizes a :attr:`field <dnikit.base.Batch.fields>` of
a :class:`Batch <dnikit.base.Batch>` by subtracting the mean and adjusting the standard
deviation to 1.
More precisely, if ``x`` is the data to be processed, the following processing
is applied: ``(x - mean) / std``.
Args:
mean: **[keyword arg]** The mean to be applied
std: **[keyword arg]** The standard deviation to be applied
fields: **[keyword arg, optional]** a single :attr:`field <dnikit.base.Batch.fields>` name,
or an iterable of :attr:`field <dnikit.base.Batch.fields>` names, to be processed.
If ``fields`` param is ``None``, then all :attr:`fields <dnikit.base.Batch.fields>`
will be processed.
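    Example:
        A minimal sketch of the computation applied to each selected field
        (plain :mod:`numpy`, outside any pipeline; the values are illustrative)::

            import numpy as np

            x = np.array([[2.0, 4.0], [6.0, 8.0]])
            # equivalent to what MeanStdNormalizer(mean=5.0, std=2.0) applies
            normalized = (x - 5.0) / 2.0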
"""
def __init__(self, *, mean: float, std: float, fields: dt.OneManyOrNone[str] = None):
def func(data: np.ndarray) -> np.ndarray:
return np.divide(data - mean, std)
super().__init__(func, fields=fields)
@t.final
class Transposer(Processor):
"""
A :class:`Processor` that transposes dimensions in a data
:attr:`field <dnikit.base.Batch.fields>` from a :class:`Batch <dnikit.base.Batch>`.
This processor will reorder the dimensions of the data as specified in the :attr:`dim` param.
    Example:
        To reorder ``NCHW`` to ``NHWC``, specify ``Transposer(dim=[0, 2, 3, 1])``;
        to reorder ``NHWC`` to ``NCHW``, specify ``Transposer(dim=[0, 3, 1, 2])``.
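        A minimal :mod:`numpy` sketch of the same reordering (outside any
        pipeline; shapes are illustrative)::

            import numpy as np

            nchw = np.zeros((8, 3, 32, 32))               # NCHW
            nhwc = np.transpose(nchw, axes=[0, 2, 3, 1])
            assert nhwc.shape == (8, 32, 32, 3)           # NHWC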
Args:
dim: **[keyword arg]** the new order of the dimensions. It is illegal to reorder the
0th dimension.
fields: **[keyword arg, optional]** a single :attr:`field <dnikit.base.Batch.fields>` name,
or an iterable of :attr:`field <dnikit.base.Batch.fields>` names, to be transposed.
If ``fields`` param is ``None``, then all :attr:`fields <dnikit.base.Batch.fields>`
will be transposed.
See also:
:func:`numpy.transpose`
Raises:
ValueError: if input specifies reordering the 0th dimension
"""
def __init__(self, *,
dim: t.Sequence[int],
fields: dt.OneManyOrNone[str] = None):
if dim[0] != 0:
raise ValueError("Unable to move the 0th (batch) dimension.")
def func(data: np.ndarray) -> np.ndarray:
return np.transpose(data, axes=dim)
super().__init__(func, fields=fields)
@t.final
class FieldRemover(PipelineStage):
"""
A :class:`PipelineStage <dnikit.base.PipelineStage>` that removes some
:attr:`fields <dnikit.base.Batch.fields>` from a :class:`Batch <dnikit.base.Batch>`.
Args:
fields: **[keyword arg]** a single :attr:`field <dnikit.base.Batch.fields>` name, or an
iterable of :attr:`field <dnikit.base.Batch.fields>` names, to be removed.
        keep: **[keyword arg, optional]** if ``True``, the ``fields`` listed will be kept
            and all others will be removed
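    Example:
        A hypothetical usage sketch (``"images"`` and ``"embeddings"`` are
        made-up field names)::

            # drop the "images" field
            remover = FieldRemover(fields="images")

            # keep only the "embeddings" field, dropping all others
            keeper = FieldRemover(fields="embeddings", keep=True)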
"""
def __init__(self, *, fields: dt.OneOrMany[str], keep: bool = False):
super().__init__()
self._fields = dt.resolve_one_or_many(fields, str)
self._keep = keep
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
def batch_processor(batch: Batch) -> Batch:
builder = Batch.Builder(base=batch)
delete_keys(builder.fields, keys=self._fields, keep=self._keep)
self.logger.debug(
f"Result of FieldRemover has fields {list(builder.fields.keys())}"
)
return builder.make_batch()
return batch_processor
@t.final
class FieldRenamer(PipelineStage):
"""
A :class:`PipelineStage <dnikit.base.PipelineStage>` that renames some
:attr:`fields <dnikit.base.Batch.fields>` from a :class:`Batch <dnikit.base.Batch>`.
Args:
mapping: a dictionary (or similar) whose keys are the old
:attr:`field <dnikit.base.Batch.fields>` names
and values are the new :attr:`field <dnikit.base.Batch.fields>` names.
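    Example:
        A hypothetical usage sketch (the field names are made up)::

            renamer = FieldRenamer(mapping={"conv5": "features", "fc7": "logits"})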
"""
    def __init__(self, mapping: t.Mapping[str, str]):
        super().__init__()
        self._mapping = mapping
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
def batch_processor(batch: Batch) -> Batch:
builder = Batch.Builder(base=batch)
builder.fields = rename_keys(builder.fields, self._mapping)
self.logger.debug(
f"Result of FieldRenamer has fields {frozenset(builder.fields.keys())}"
)
return builder.make_batch()
return batch_processor
@t.final
class Flattener(Processor):
"""
    A :class:`Processor` that collapses arrays of shape ``BxN1xN2x...`` into shape ``BxN``.
Args:
order: **[optional]** {``C``, ``F``, ``A``, ``K``}:
``C`` (default) means to flatten in row-major (C-style) order.
``F`` means to flatten in column-major (Fortran-style) order.
``A`` means to flatten in column-major order if it is Fortran contiguous in
memory, row-major order otherwise.
``K`` means to flatten in the order the elements occur in memory.
        fields: **[optional]** a single :attr:`field <dnikit.base.Batch.fields>` name, or an
            iterable of :attr:`field <dnikit.base.Batch.fields>` names, to be flattened. If the
            ``fields`` param is ``None``, then all the :attr:`fields <dnikit.base.Batch.fields>`
            in the :class:`batch <dnikit.base.Batch>` will be flattened.
Raises:
ValueError: if ``order`` param is not one of {``C``, ``F``, ``A``, ``K``}
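    Example:
        A minimal :mod:`numpy` sketch of the effect on shapes (outside any
        pipeline; shapes are illustrative)::

            import numpy as np

            data = np.zeros((8, 4, 4, 3))              # BxN1xN2xN3
            flat = data.reshape((data.shape[0], -1))   # row-major, as with order='C'
            assert flat.shape == (8, 48)               # BxN with N = 4*4*3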
"""
def __init__(self, order: str = 'C', fields: dt.OneManyOrNone[str] = None) -> None:
if order not in ['C', 'F', 'A', 'K']:
raise ValueError(
"``order`` param for ``Flattener`` must be `C`, `F`, `A`, or `K`}")
def func(data: np.ndarray) -> np.ndarray:
if order == 'C':
return data.reshape((data.shape[0], -1))
else:
return np.array([d.flatten(order) for d in data])
super().__init__(func, fields=fields)
@t.final
@dataclasses.dataclass(frozen=True)
class SnapshotSaver(PipelineStage):
"""
A :class:`PipelineStage <dnikit.base.PipelineStage>` that attaches the current
:class:`Batch <dnikit.base.Batch>` as the :attr:`snapshot <dnikit.base.Batch.snapshots>`.
Args:
save: **[optional]** see :attr:`save`
fields: **[optional]** see :attr:`fields`
keep: **[optional]** see :attr:`keep`
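    Example:
        A hypothetical usage sketch (``"inputs"`` and ``"image"`` are
        made-up names)::

            # snapshot only the "image" field under the key "inputs"
            saver = SnapshotSaver(save="inputs", fields="image", keep=True)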
"""
save: str = "snapshot"
"""save the current state of the :class:`batches <dnikit.base.Batch>` under the given key."""
fields: dt.OneManyOrNone[str] = None
"""
Optional list of :attr:`fields <dnikit.base.Batch.fields>` to include/remove in the
saved :attr:`snapshot <dnikit.base.Batch.snapshots>` or ``None`` for all."""
keep: bool = True
"""If ``True``, the ``fields`` list are the fields to keep, if ``False``, the ones to omit."""
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
fields = dt.resolve_one_many_or_none(self.fields, str)
def batch_processor(batch: Batch) -> Batch:
builder = Batch.Builder(base=batch)
snapshot = Batch.Builder(base=batch)
if fields:
delete_keys(snapshot.fields, keys=fields, keep=self.keep)
builder.snapshots[self.save] = snapshot.make_batch()
return builder.make_batch()
return batch_processor
@t.final
@dataclasses.dataclass(frozen=True)
class SnapshotRemover(PipelineStage):
"""
A :class:`PipelineStage <dnikit.base.PipelineStage>` that removes snapshots from a
:class:`Batch <dnikit.base.Batch>`. If used with no arguments, this
will remove *all* :attr:`snapshots <dnikit.base.Batch.snapshots>`.
Args:
snapshots: **[optional]** see :attr:`snapshots` attribute
keep: **[optional]** see :attr:`keep` attribute
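    Example:
        A hypothetical usage sketch (``"inputs"`` is a made-up snapshot key)::

            remove_all = SnapshotRemover()
            keep_only_inputs = SnapshotRemover(snapshots="inputs", keep=True)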
"""
snapshots: dt.OneManyOrNone[str] = None
"""List of :attr:`snapshots <dnikit.base.Batch.snapshots>` to keep/remove."""
keep: bool = False
"""If ``True``, the listed ``snapshots`` are kept, else the ``snapshots`` will be removed."""
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
snapshots = dt.resolve_one_many_or_none(self.snapshots, str)
def batch_processor(batch: Batch) -> Batch:
builder = Batch.Builder(base=batch)
if snapshots:
delete_keys(builder.snapshots, keys=snapshots, keep=self.keep)
else:
builder.snapshots = {}
return builder.make_batch()
return batch_processor
@t.final
@dataclasses.dataclass(frozen=True)
class PipelineDebugger(PipelineStage):
"""
A :class:`PipelineStage <dnikit.base.PipelineStage>` that can be used to inspect
:class:`batches <dnikit.base.Batch>` in a :class:`pipeline <dnikit.base.pipeline>`.
Args:
label: **[optional]** see :attr:`label`
first_only: **[optional]** see :attr:`first_only`
dump_fields: **[optional]** see :attr:`dump_fields`
fields: **[optional]** see :attr:`fields`
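    Example:
        A hypothetical usage sketch (the label text is arbitrary)::

            debugger = PipelineDebugger(label="after-model", dump_fields=True)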
"""
label: str = ""
"""Optional label to display."""
first_only: bool = True
"""Show the first batch only."""
dump_fields: bool = False
"""If ``True``, print the contents of the fields."""
fields: dt.OneManyOrNone[str] = None
"""List of fields of interest. Default is None which means all. See ``dump_fields``"""
    @staticmethod
def dump(batch: t.Union[Batch, Batch.Builder],
label: str = "",
dump_fields: bool = False,
fields: dt.OneManyOrNone[str] = None) -> str:
"""
Utility method to produce a dump of a :class:`Batch <dnikit.base.Batch>` or a
:class:`Batch.Builder <dnikit.base.Batch.Builder>`.
Args:
batch: :class:`Batch <dnikit.base.Batch>` or
:class:`Batch.Builder <dnikit.base.Batch.Builder>` to dump
label: **[optional]** see :attr:`label`
dump_fields: **[optional]** see :attr:`dump_fields`
fields: **[optional]** see :attr:`fields`
"""
prefix = f"{label} " if label else ""
batch_size = f"batch_size={len(next(iter(batch.fields.values())))}" if batch.fields else ""
result = f"{prefix}Batch({batch_size}) {{\n"
fields = dt.resolve_one_many_or_none(fields, str)
for name in sorted(batch.fields.keys()):
data = batch.fields[name]
if dump_fields and (fields is None or name in fields):
result += f"{name}: {data.shape}\n{data}\n"
else:
result += f"{name}: {data.shape}\n"
snapshots = batch.snapshots
if snapshots:
result += "\nSnapshots:\n"
for key in sorted(snapshots.keys()):
result += f"{key}: {sorted(snapshots[key].fields.keys())}\n"
metadata = batch.metadata
if metadata:
result += "\nMetadata:\n"
            for meta_key in sorted(metadata.keys(), key=lambda x: x.name):
                result += f"{meta_key}"
                if isinstance(meta_key, Batch.DictMetaKey):
                    result += f": {sorted(metadata[meta_key].keys())}"
                result += "\n"
result += "}\n"
return result
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
fields = dt.resolve_one_many_or_none(self.fields, str)
        # track whether the first batch was processed. note that it's necessary to
        # capture a reference to something mutable -- a one-element list in this case.
first = [True]
def batch_processor(batch: Batch) -> Batch:
# show everything unless asked to show the first only
show = not self.first_only or first[0]
if show:
print(
PipelineDebugger.dump(
batch, label=self.label, dump_fields=self.dump_fields, fields=fields
),
end="\n\n"
)
# mark that the first item was handled
first[0] = False
return batch
return batch_processor
@t.final
class Pooler(Processor):
"""
A :class:`Processor` that pools the axes of a data field from a
:class:`Batch <dnikit.base.Batch>` with a specific method.
Args:
dim: **[keyword arg]** The dimension (one or many) to be pooled.
E.g., Spatial pooling is generally ``(1, 2)``.
method: **[keyword arg]** Pooling method. See :class:`Pooler.Method` for full list
of options.
fields: **[keyword arg, optional]** a single :attr:`field <dnikit.base.Batch.fields>`
name, or an iterable of :attr:`field <dnikit.base.Batch.fields>` names, to be pooled.
If the ``fields`` param is ``None``, then all the
:attr:`fields <dnikit.base.Batch.fields>` in the
:class:`batch <dnikit.base.Batch>` will be pooled.
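    Example:
        A minimal :mod:`numpy` sketch of spatial max pooling (outside any
        pipeline; shapes are illustrative)::

            import numpy as np

            data = np.arange(96, dtype=np.float32).reshape((2, 4, 4, 3))
            # what Pooler(dim=(1, 2), method=Pooler.Method.MAX) computes per field
            pooled = np.max(data, axis=(1, 2))
            assert pooled.shape == (2, 3)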
"""
    class Method(enum.Enum):
        """Pooling method for :class:`Pooler`."""
        MAX = enum.auto()
        SUM = enum.auto()
        AVERAGE = enum.auto()
def __init__(self, *,
dim: dt.OneOrMany[int],
method: Method,
fields: dt.OneManyOrNone[str] = None):
dims = tuple(dt.resolve_one_or_many(dim, int))
assert 0 not in dims, "Unable to pool the 0th (batch) dimension."
def func(data: np.ndarray) -> np.ndarray:
assert data.shape, 'Data with no dimensions.'
assert len(dims) <= len(data.shape), (
f'data of dimension {data.shape}, too many dim {dims} selected.')
assert max(dims) < len(data.shape), (
f'dim {dims} out of data shape {data.shape}.')
if method is self.Method.MAX:
return np.max(data, axis=dims)
elif method is self.Method.SUM:
return np.sum(data, axis=dims)
elif method is self.Method.AVERAGE:
return np.mean(data, axis=dims)
else:
raise NotImplementedError(f'Pooling method {method.name} not implemented.')
super().__init__(func, fields=fields)
@t.final
@dataclasses.dataclass(frozen=True)
class Concatenator(PipelineStage):
"""
This :class:`PipelineStage <dnikit.base.PipelineStage>` will concatenate 2 or more
    :attr:`fields <dnikit.base.Batch.fields>` in the :class:`Batch <dnikit.base.Batch>` and produce
a new field with the given ``output_field``.
Example:
If there were fields ``M`` and ``N`` with dimensions ``BxM1xZ`` and ``BxN1xZ`` and
they were concatenated along dimension 1,
the result will have a new field of size ``Bx(M1+N1)xZ``.
Args:
dim: see :attr:`dim`
output_field: see :attr:`output_field`
fields: see :attr:`fields`
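    Example:
        A minimal :mod:`numpy` sketch of the shape arithmetic (outside any
        pipeline; shapes are illustrative)::

            import numpy as np

            m = np.zeros((8, 5, 16))                  # BxM1xZ
            n = np.zeros((8, 7, 16))                  # BxN1xZ
            out = np.concatenate([m, n], axis=1)
            assert out.shape == (8, 12, 16)           # Bx(M1+N1)xZ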
"""
dim: int
"""the dimension to concatenate along"""
output_field: str
"""name of the new :attr:`field <dnikit.base.Batch.fields>` (layer name) to hold the result"""
fields: t.Sequence[str]
"""a sequence of :attr:`fields <dnikit.base.Batch.fields>` to concatenate, in order"""
def __post_init__(self) -> None:
assert self.dim != 0, "Unable to concatenate along dimension 0 (batch dimension)"
assert len(self.fields) > 0, "Must specify fields to concatenate"
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
# note that this is a PipelineStage rather than a Processor -- it needs to read
# multiple layers by name at once and add a new field
def batch_processor(batch: Batch) -> Batch:
builder = Batch.Builder(base=batch)
# collect a list of the source fields in order and concatenate them
builder.fields[self.output_field] = np.concatenate([
batch.fields[field]
for field in self.fields
], axis=self.dim)
return builder.make_batch()
return batch_processor
@t.final
class Composer(PipelineStage):
"""
    Apply a filter function to every :class:`batch <dnikit.base.Batch>`, replacing each batch
    ``b`` with ``filter(b)``.
Args:
filter: The filter function to apply to every :class:`batch <dnikit.base.Batch>` in the
:func:`pipeline <dnikit.base.pipeline>`. The ``filter`` should take a single
:class:`Batch <dnikit.base.Batch>` as input and return a transformed
:class:`batch <dnikit.base.Batch>` (e.g. a subset) or ``None``
(to produce an empty :class:`batch <dnikit.base.Batch>`).
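    Example:
        A hypothetical sketch of a filter that keeps only the even-indexed
        elements of each batch (``Batch.elements`` index-list selection is
        used the same way inside this class)::

            composer = Composer(
                filter=lambda b: b.elements[[
                    i for i, _ in enumerate(b.elements) if i % 2 == 0
                ]]
            )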
"""
def __init__(self, filter: t.Callable[[Batch], t.Optional[Batch]]) -> None:
super().__init__()
self._filter = filter
def _get_batch_processor(self) -> t.Callable[[Batch], Batch]:
def batch_processor(batch: Batch) -> Batch:
result = self._filter(batch)
if result is None:
result = batch.elements[[]]
return result
return batch_processor
    @classmethod
def from_element_filter(cls, elem_filter: t.Callable[[Batch.ElementType], bool]) -> 'Composer':
"""
Initialize a :class:`Composer` that filters batch data based on element-wise filter criteria
Args:
            elem_filter: :attr:`Batch.element <dnikit.base.Batch.elements>`-wise validation
                function; returns ``True`` for elements to keep and ``False`` otherwise
        Returns:
            A :class:`Composer` that filters batches down to the elements that meet the
            filter criteria
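        Example:
            A hypothetical sketch (assumes each element exposes its data via
            ``element.fields``; ``"confidence"`` is a made-up field name)::

                composer = Composer.from_element_filter(
                    lambda element: float(element.fields["confidence"]) > 0.5
                )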
"""
def batch_filter(batch: Batch,
element_filter: t.Callable[[Batch.ElementType], bool]
) -> t.Optional[Batch]:
return batch.elements[[
i for i, element in enumerate(batch.elements)
if element_filter(element)
]]
return cls(filter=partial(batch_filter, element_filter=elem_filter))