Source code for coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary

#  Copyright (c) 2020, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
import operator

import numpy as np

from coremltools.converters.mil.mil import (
    InputSpec,
    Operation,
    TensorInputType,
    precondition,
    types,
)
from coremltools.converters.mil.mil.operation import VALUE
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil.ops.defs._utils import (
    infer_type_with_broadcast,
    promoted_primitive_type,
)


class elementwise_binary(Operation):
    """
    Elementwise Binary Op Superclass
    """
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        y=TensorInputType(type_domain="T"),
    )

    type_domains = {
        "T": (types.fp16, types.fp32, types.int32),
    }

    def type_inference(self):
        typea = self.x.sym_type
        typeb = self.y.sym_type
        primitive_type = promoted_primitive_type(typea, typeb)
        if primitive_type is None:
            raise ValueError("Incompatible primitive types in broadcast operation")
        primitive_type = self.get_dtype(primitive_type)

        return infer_type_with_broadcast(typea, typeb, primitive_type)

    @precondition(allow=VALUE)
    def value_inference(self):
        return self._cast_check_value_inferene(self.x.val, self.y.val)

    def get_operator(self):
        """
        All subclasses have to implement this.
        """
        raise NotImplementedError()

    def get_dtype(self, promoted_dtype):
        """
        Override if output primitive type is different from input types
        (e.g., less, greater)
        """
        return promoted_dtype

    def _cast_check_value_inferene(self, a, b):
        """
        If one of the input is tensor, cast the result to tensor.
        """
        to_cast = any([isinstance(x, np.ndarray) for x in [a, b]])
        result = self.get_operator()(a, b)
        return result if not to_cast else np.array(result)


class elementwise_binary_logical(elementwise_binary):
    """
    Elementwise Binary Logical Op Superclass
    """
    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        y=TensorInputType(type_domain="T"),
    )

    type_domains = {
        "T": (types.bool,),
    }


"""
Elementwise Binary Op Implementation(s)
"""


[docs]@register_op class add(elementwise_binary): """ Return ``x + y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: <\*,T> (Required) * Shape must be compatible with ``y`` in broadcast. y: <\*,T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- <\*,T> Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.add
[docs]@register_op class equal(elementwise_binary): """ Return the truth value of ``x == y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ (``1`` for true, ``0`` for false in numeric domain). Parameters ---------- x: <\*,T> (Required) * Shape must be compatible with ``y`` in broadcast. y: <\*,T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- <\*, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return np.equal def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class floor_div(elementwise_binary): """ Return ``x / y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_, rounded towards negative infinity. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*, T> * A tensor of the same type and shape as the inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.floordiv
[docs]@register_op class greater(elementwise_binary): """ Return the truth value of ``x > y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ (``1`` for true, ``0`` for false in numeric domain). Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.gt def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class greater_equal(elementwise_binary): """ Return the truth value of ``x >= y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ (``1`` for true, ``0`` for false in numeric domain). Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.ge def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class less(elementwise_binary): """ Return the truth value of ``x < y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ (``1`` for true, ``0`` for false in numeric domain). Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.lt def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class less_equal(elementwise_binary): """ Return the truth value of ``x <= y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ (``1`` for true, ``0`` for false in numeric domain). Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.le def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class logical_and(elementwise_binary_logical): """ Return the truth value of ``x AND y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: bool """ def get_operator(self): return np.logical_and def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class logical_or(elementwise_binary_logical): """ Return the truth value of ``x OR y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: bool """ def get_operator(self): return np.logical_or def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class logical_xor(elementwise_binary_logical): """ Return the truth value of ``x XOR y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the same shape as the inputs. Attributes ---------- T: bool """ def get_operator(self): return np.logical_xor def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class maximum(elementwise_binary): """ Return ``x > y ? x : y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return np.maximum
[docs]@register_op class minimum(elementwise_binary): """ Return ``x > y ? y : x`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return np.minimum
[docs]@register_op class mod(elementwise_binary): """ Return ``x % y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.mod
[docs]@register_op class mul(elementwise_binary): """ Return ``x * y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.mul
[docs]@register_op class not_equal(elementwise_binary): """ Return the truth value of ``x != y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ (``1`` for true, ``0`` for false in numeric domain). Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, bool> * A boolean tensor with the broadcasted shape from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.ne def get_dtype(self, promoted_dtype): return types.bool
[docs]@register_op class real_div(elementwise_binary): """ Return ``x / y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.truediv
[docs]@register_op class pow(elementwise_binary): """ Return ``x ^ y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.pow
[docs]@register_op class sub(elementwise_binary): """ Return ``x - y`` element-wise with `broadcasting <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_. Parameters ---------- x: tensor<\*, T> (Required) * Shape must be compatible with ``y`` in broadcast. y: tensor<\*, T> (Required) * Shape must be compatible with ``x`` in broadcast. Returns ------- tensor<\*?, T> * A tensor with the broadcasted shape from inputs, and type is derived from inputs. Attributes ---------- T: fp16, fp32, i32 """ def get_operator(self): return operator.sub