# Copyright (c) 2020, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
import operator
import numpy as np
from coremltools.converters.mil.mil import (
InputSpec,
Operation,
TensorInputType,
precondition,
types,
)
from coremltools.converters.mil.mil.operation import VALUE
from coremltools.converters.mil.mil.ops.defs._op_reqs import register_op
from coremltools.converters.mil.mil.ops.defs._utils import (
infer_type_with_broadcast,
promoted_primitive_type,
)
class elementwise_binary(Operation):
    """
    Elementwise Binary Op Superclass.

    Subclasses implement ``get_operator`` to supply the binary callable used
    for value inference, and may override ``get_dtype`` when the output
    element type differs from the promoted input type (e.g. comparisons).
    """

    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        y=TensorInputType(type_domain="T"),
    )

    type_domains = {
        "T": (types.fp16, types.fp32, types.int32),
    }

    def type_inference(self):
        """
        Infer the output type from the broadcast of ``x`` and ``y``.

        Raises
        ------
        ValueError
            If the element types of ``x`` and ``y`` cannot be promoted.
        """
        typea = self.x.sym_type
        typeb = self.y.sym_type
        primitive_type = promoted_primitive_type(typea, typeb)
        # A ``None`` promotion result is treated as incompatible element types.
        if primitive_type is None:
            raise ValueError("Incompatible primitive types in broadcast operation")
        # Give subclasses (e.g. comparison ops) a chance to replace the dtype.
        primitive_type = self.get_dtype(primitive_type)
        return infer_type_with_broadcast(typea, typeb, primitive_type)

    @precondition(allow=VALUE)
    def value_inference(self):
        """Compute the constant result when both inputs have known values."""
        return self._cast_check_value_inferene(self.x.val, self.y.val)

    def get_operator(self):
        """
        Return the binary callable ``f(a, b)`` used for value inference.

        All subclasses have to implement this.
        """
        raise NotImplementedError()

    def get_dtype(self, promoted_dtype):
        """
        Return the output element type given the promoted input type.

        Override if output primitive type is different from input types
        (e.g., less, greater).
        """
        return promoted_dtype

    def _cast_check_value_inferene(self, a, b):
        """
        Apply ``get_operator()`` to ``a`` and ``b``.

        If either input is an ``np.ndarray``, the result is wrapped with
        ``np.array`` so a tensor input always yields a tensor result.
        """
        # NOTE: the misspelled method name ("inferene") is kept unchanged in
        # case subclasses or other modules reference it by this name.
        to_cast = any(isinstance(v, np.ndarray) for v in (a, b))
        result = self.get_operator()(a, b)
        # NumPy ops on 0-d arrays return NumPy scalars; re-wrap as an ndarray.
        return np.array(result) if to_cast else result
class elementwise_binary_logical(elementwise_binary):
    """
    Elementwise Binary Logical Op Superclass.

    Inherits the type/value inference of ``elementwise_binary`` but restricts
    both operands to the boolean type domain.
    """

    input_spec = InputSpec(
        x=TensorInputType(type_domain="T"),
        y=TensorInputType(type_domain="T"),
    )
    type_domains = {"T": (types.bool,)}


# Elementwise Binary Op Implementation(s)
@register_op
class add(elementwise_binary):
    r"""
    Return ``x + y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        # Binary callable applied by ``elementwise_binary`` value inference.
        return operator.add
@register_op
class equal(elementwise_binary):
    r"""
    Return the truth value of ``x == y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
    (``1`` for true, ``0`` for false in numeric domain).

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return np.equal

    def get_dtype(self, promoted_dtype):
        # Comparisons always produce booleans, whatever the promoted input type.
        return types.bool
@register_op
class floor_div(elementwise_binary):
    r"""
    Return ``x / y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_,
    rounded towards negative infinity.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        # ``//`` floors the quotient (rounds towards negative infinity).
        return operator.floordiv
@register_op
class greater(elementwise_binary):
    r"""
    Return the truth value of ``x > y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
    (``1`` for true, ``0`` for false in numeric domain).

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.gt

    def get_dtype(self, promoted_dtype):
        # Comparisons always produce booleans, whatever the promoted input type.
        return types.bool
@register_op
class greater_equal(elementwise_binary):
    r"""
    Return the truth value of ``x >= y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
    (``1`` for true, ``0`` for false in numeric domain).

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.ge

    def get_dtype(self, promoted_dtype):
        # Comparisons always produce booleans, whatever the promoted input type.
        return types.bool
@register_op
class less(elementwise_binary):
    r"""
    Return the truth value of ``x < y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
    (``1`` for true, ``0`` for false in numeric domain).

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.lt

    def get_dtype(self, promoted_dtype):
        # Comparisons always produce booleans, whatever the promoted input type.
        return types.bool
@register_op
class less_equal(elementwise_binary):
    r"""
    Return the truth value of ``x <= y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
    (``1`` for true, ``0`` for false in numeric domain).

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.le

    def get_dtype(self, promoted_dtype):
        # Comparisons always produce booleans, whatever the promoted input type.
        return types.bool
@register_op
class logical_and(elementwise_binary_logical):
    r"""
    Return the truth value of ``x AND y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: bool
    """

    def get_operator(self):
        return np.logical_and

    def get_dtype(self, promoted_dtype):
        # Output is always boolean.
        return types.bool
@register_op
class logical_or(elementwise_binary_logical):
    r"""
    Return the truth value of ``x OR y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: bool
    """

    def get_operator(self):
        return np.logical_or

    def get_dtype(self, promoted_dtype):
        # Output is always boolean.
        return types.bool
@register_op
class logical_xor(elementwise_binary_logical):
    r"""
    Return the truth value of ``x XOR y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: bool
    """

    def get_operator(self):
        return np.logical_xor

    def get_dtype(self, promoted_dtype):
        # Output is always boolean.
        return types.bool
@register_op
class maximum(elementwise_binary):
    r"""
    Return ``x > y ? x : y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return np.maximum
@register_op
class minimum(elementwise_binary):
    r"""
    Return ``x > y ? y : x`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return np.minimum
@register_op
class mod(elementwise_binary):
    r"""
    Return ``x % y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        # Python/NumPy ``%`` semantics: a non-zero result takes the sign of ``y``.
        return operator.mod
@register_op
class mul(elementwise_binary):
    r"""
    Return ``x * y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.mul
@register_op
class not_equal(elementwise_binary):
    r"""
    Return the truth value of ``x != y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_
    (``1`` for true, ``0`` for false in numeric domain).

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, bool>
        * A boolean tensor with the broadcasted shape from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.ne

    def get_dtype(self, promoted_dtype):
        # Comparisons always produce booleans, whatever the promoted input type.
        return types.bool
@register_op
class real_div(elementwise_binary):
    r"""
    Return ``x / y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        # NOTE(review): for i32 inputs, ``truediv`` yields floating-point values
        # while the inherited type inference reports i32 — confirm whether
        # integer inputs are cast before reaching this op.
        return operator.truediv
@register_op
class pow(elementwise_binary):
    r"""
    Return ``x ^ y`` (``x`` raised to the power ``y``) element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    # NOTE: the class name shadows the builtin ``pow`` at module level; it is
    # kept because it is the registered op name.

    def get_operator(self):
        return operator.pow
@register_op
class sub(elementwise_binary):
    r"""
    Return ``x - y`` element-wise with
    `broadcasting <https://numpy.org/doc/stable/user/basics.broadcasting.html>`_.

    Parameters
    ----------
    x: tensor<\*, T> (Required)
        * Shape must be compatible with ``y`` in broadcast.

    y: tensor<\*, T> (Required)
        * Shape must be compatible with ``x`` in broadcast.

    Returns
    -------
    tensor<\*?, T>
        * A tensor with the broadcasted shape from inputs, and type is derived from inputs.

    Attributes
    ----------
    T: fp16, fp32, i32
    """

    def get_operator(self):
        return operator.sub