# MIL compute-precision (FP16) quantization graph passes.

#  Copyright (c) 2020, Apple Inc. All rights reserved.
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from enum import Enum as _Enum
from typing import Set, Text

import numpy as np

from coremltools.converters.mil._deployment_compatibility import AvailableTarget
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Operation, types
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.program import Program

class ComputePrecision(_Enum):
    """Enumeration of the compute precisions a model can be converted to."""

    FLOAT16 = "float16"
    FLOAT32 = "float32"

class AbstractQuantizationPass(AbstractGraphPass):
    """
    Base class for Post-Training Quantization transforms.

    Derived class needs to implement following two methods:
        - is_valid_op(op)
        - transform_op(op)
    """

    # Per-dtype caches of numpy machine-precision constants, filled lazily by
    # _close_to_zero so that np.finfo/np.nextafter are evaluated once per dtype.
    type_eps = {}
    type_min = {}
    type_negmin = {}

    def __init__(self, op_selector=None):
        """
        :param op_selector: optional callable that takes a MIL operation and
            returns a boolean deciding whether the op should be transformed.
        """
        if op_selector is not None and not callable(op_selector):
            raise TypeError(
                "Argument `op_selector` needs to be a callable function which accepts "
                "a MIL operation object and returns a boolean value."
            )
        self.op_selector = op_selector

    def apply(self, prog):
        """
        Walks over each operation in the graph and performs following two steps,
        1. Checks whether an operation is valid for that quantized transform using `is_valid_op` method.
        2. If yes, calls `transform_op` method of the derived quantized transform class.

        :param prog: MIL program
        :return: Transformed MIL program
        """
        if not isinstance(prog, Program):
            raise TypeError('Transform "{}" can only be applied on PyMIL programs.'.format(self))

        # `skip_ops_by_type` (defined by subclasses such as add_fp16_cast) and
        # `op_selector` are two mutually exclusive ways of filtering ops.
        if getattr(self, "skip_ops_by_type", set()) and self.op_selector is not None:
            raise ValueError(
                "The graph pass option `skip_ops_by_type` cannot be set along with "
                "the `op_selector` in FP16ComputePrecision. Please only use one "
                "method to control which ops to operate on."
            )

        @block_context_manager
        def apply_block(block):
            # Snapshot the op list since transform_op mutates the block.
            for op in list(block.operations):
                # Recurse into nested blocks (e.g. cond / while_loop bodies) first.
                for b in op.blocks:
                    apply_block(b)

                if self.is_valid_op(op):
                    need_transform: bool
                    if self.op_selector is not None:
                        need_transform = self.op_selector(op)
                    else:
                        need_transform = op.op_type not in getattr(self, "skip_ops_by_type", set())
                    if need_transform:
                        self.transform_op(op)

        for f in prog.functions.values():
            apply_block(f)

    def transform_op(self, op):
        """
        Replaces an op with a transformed op.

        :param op: MIL operation
        :return: None
        """
        raise NotImplementedError(
            'Op transformation for quantization mode "{}" not implemented.'.format(self)
        )

    def is_valid_op(self, op):
        """
        Checks whether an operation is valid for given quantized transform.

        :param op: MIL operation
        :return: true | false
        """
        raise NotImplementedError(
            'Operation Preconditions for quantization mode "{}" not implemented.'.format(self)
        )

    @classmethod
    def _close_to_zero(cls, val, np_type):
        """Elementwise check of whether `val` is numerically zero at `np_type` precision."""
        if np_type not in cls.type_eps:
            cls.type_eps[np_type] = np.finfo(np_type).eps
            cls.type_min[np_type] = np.nextafter(0.0, 1.0, dtype=np_type)
            cls.type_negmin[np_type] = np.nextafter(0.0, -1.0, dtype=np_type)

        return np.isclose(val, 0, atol=cls.type_min[np_type], rtol=cls.type_eps[np_type])

    def __repr__(self):
        return str(self)

    def __str__(self):
        return type(self).__name__

class FP16ComputePrecision(AbstractQuantizationPass):
    """
    This transform does the following, for each valid op and if the "op_selector" return True:
    - For each input of dtype float32, inject a "cast" op to change it to float16 dtype
    - For each output of dtype float16, inject a "cast" op to change it back to float32
    """

    # Activation related ops with alpha/beta parameters.
    _ACTIVATION_ALPHA_OPS: Set[str] = {"elu", "leaky_relu", "prelu", "thresholded_relu"}
    _ACTIVATION_ALPHA_BETA_OPS: Set[str] = {
        "clamped_relu",
        "linear_activation",
        "scaled_tanh",
        "sigmoid_hard",
        "softplus_parametric",
    }
    _ELEMENTWISE_UNARY_EPSILON_OPS: Set[str] = {"inverse", "log", "rsqrt"}

    def __init__(self, op_selector=None):
        super(FP16ComputePrecision, self).__init__(op_selector=op_selector)
        self.target_dtype = "fp16"

        # Var that feeds into multiple ops will be casted once and cached into this dict
        # For reference: Checkout test_single_input_to_multiple_operations in `TestFP16CastTransform`.
        self.cache_vars = {}

    def fp16_overflow(self, op: Operation) -> bool:
        """Return True if any fp32 const input of `op` would overflow the fp16 range."""
        # This overflow check consists of two parts:
        # 1. For valid fp32 numbers (abs < 1e38), we want their exact values,
        #    so we make sure they are within fp16 range [-65504, 65504]
        # 2. For inifinities (abs >= 1e38), their exact values does not matter,
        #    so we can always downcast them to fp16 inf. For example, in attention mask
        #    we just want -inf to make the masked entries have 0 probability after softmax
        for _, inputs in op.inputs.items():
            is_list_input = isinstance(inputs, (list, tuple))
            if not is_list_input:
                inputs = [inputs]
            for var in inputs:
                if (
                    var.op is not None
                    and var.op.op_type == "const"
                    and var.is_tensor_or_scalar_of(dtype="fp32")
                ):
                    value = np.expand_dims(var.op.val.val, 0)
                    abs_value = np.abs(value)
                    if np.max(abs_value[np.where(abs_value < 1e38)], initial=0.0) > 65504:
                        return True
        return False

    def is_valid_op(self, op: Operation) -> bool:
        """Determines if op is valid for fp16 casting."""
        if op.op_type in ["cast", "while_loop", "cond"]:
            return False

        # TODO: Remove after supporting FP16 dynamic quantize transformation for list ops (rdar://74458192)
        if op.op_type in [
            "make_list",
            "list_gather",
            "list_scatter",
            "list_read",
            "list_write",
            "list_length",
        ]:
            return False

        if self.fp16_overflow(op):
            return False

        return True

    def should_cast_parameter(self, op: Operation, param_name: str) -> bool:
        """Determines if a param of an op should be casted to fp16."""
        # Make sure the param is valid for fp16 when type domain is specified.
        type_domain = getattr(op.input_spec.input_types[param_name], "type_domain", None)
        if type_domain and types.fp16 not in type_domain:
            return False

        if op.opset_version >= AvailableTarget.iOS17:
            # In IOS17+ activation ops with alpha/beta support mixed precision, and we don't want to
            # cast alpha/beta to fp16 for better numerical accuracy.
            if op.op_type in self._ACTIVATION_ALPHA_OPS and param_name == "alpha":
                return False
            if op.op_type in self._ACTIVATION_ALPHA_BETA_OPS and param_name in {"alpha", "beta"}:
                return False

            # Element-wise unary ops with epsilon also support mixed precision.
            if op.op_type in self._ELEMENTWISE_UNARY_EPSILON_OPS and param_name == "epsilon":
                return False

        return True

    def _check_underflow_to_zero(self, new_var, var):
        """Nudge fp16-casted values that underflowed to 0 back to the smallest fp16 magnitude."""
        # We check whether there are casted values that "becomes" 0 which is not ideal for eps purposes.
        # However we skip arrays with more than 400 in case we compare through a large sparse matrix.
        if (
            new_var.val is not None
            and len(var.val.flatten()) < 400
            and self._close_to_zero(new_var.val, np.float16).any()
        ):
            value_modified = False
            original_val = var.val.flatten()
            new_val = new_var.val.flatten()

            for idx in range(len(original_val)):
                # Only fix entries that were non-zero in fp32 but became zero in fp16.
                if not self._close_to_zero(original_val[idx], np.float32) and self._close_to_zero(
                    new_val[idx], np.float16
                ):
                    new_val[idx] = (
                        self.type_min[np.float16]
                        if np.sign(original_val[idx]) > 0
                        else self.type_negmin[np.float16]
                    )
                    value_modified = True

            if value_modified:
                if np.isscalar(new_var.val):
                    new_var._sym_val.val = new_val[0]
                else:
                    new_var._sym_val.val = new_val.reshape(new_var.val.shape)

    def transform_op(self, op):
        """Rebuild `op` with fp16-casted inputs and cast fp16 outputs back to fp32."""
        block = op.enclosing_block
        casted_inputs = {}
        inputs_modified = False

        for param, inputs in op.inputs.items():
            # First loop, iterates over all the input parameters of an operation.
            if not self.should_cast_parameter(op, param):
                continue

            is_list_input = isinstance(inputs, (list, tuple))
            if not is_list_input:
                inputs = [inputs]

            casted_inputs[param] = list(inputs[:])
            for i, var in enumerate(inputs):
                # Second loop, iterates over all the vars of a python list corresponding to an input parameter.
                if not var.is_tensor_or_scalar_of(dtype="fp32"):
                    continue

                inputs_modified = True
                casted_var_name = var.name + "_to_fp16"
                if (
                    len(var._child_ops) > 1
                    and casted_var_name in self.cache_vars
                    and (block.is_var_visible_in_block(self.cache_vars[casted_var_name]))
                ):
                    # Reuse the cached cast for vars consumed by multiple ops.
                    casted_inputs[param][i] = self.cache_vars[casted_var_name]
                else:
                    x = mb.cast(x=var, dtype="fp16", name=casted_var_name, before_op=op)
                    self._check_underflow_to_zero(x, var)

                    casted_inputs[param][i] = x
                    if len(var._child_ops) > 1:
                        self.cache_vars[casted_var_name] = casted_inputs[param][i]

            if not is_list_input:
                casted_inputs[param] = casted_inputs[param][0]

        if inputs_modified:
            # Re-emit the op with casted inputs, keeping non-casted params as-is.
            casted_inputs.update({k: v for k, v in op.inputs.items() if k not in casted_inputs})
            casted_inputs["name"] = op.name + "_cast"
            casted_inputs["before_op"] = op
            quant_output = getattr(mb, op.op_type)(**casted_inputs)

            if not isinstance(quant_output, (list, tuple)):
                quant_output = [quant_output]

            for old_output_var, new_output_var in zip(op.outputs, quant_output):
                if old_output_var.is_tensor_or_scalar_of(dtype="fp32") and (
                    not new_output_var.is_tensor_or_scalar_of(dtype="fp32")
                ):
                    # Output dtype changed to fp16: cast back so downstream consumers see fp32.
                    x = mb.cast(
                        x=new_output_var,
                        dtype="fp32",
                        name=new_output_var.name + "_to_fp32",
                        before_op=op,
                    )
                    op.enclosing_block.replace_uses_of_var_after_op(
                        anchor_op=op,
                        old_var=old_output_var,
                        new_var=x,
                        force_replace=True,
                    )
                else:
                    op.enclosing_block.replace_uses_of_var_after_op(
                        anchor_op=op,
                        old_var=old_output_var,
                        new_var=new_output_var,
                        force_replace=True,
                    )

            block.remove_ops([op])

@register_pass(namespace="common")
class add_fp16_cast(FP16ComputePrecision):
    """
    For each input of dtype float32, inject a ``cast`` op to change it to float16 dtype.
    For each output of dtype float16, inject a ``cast`` op to change it back to float32.

    This pass is the registered interface for FP16ComputePrecision, which makes it consistent
    with other passes' interfaces.

    Support options:

    - ``skip_ops_by_type``: Skip op types specified by comma-separated string; for example,
      ``"mul,const"``.
    """

    # Op types (by name) that this pass must leave untouched; set via the property below.
    _skip_ops_by_type: Set[Text] = set()

    @property
    def skip_ops_by_type(self):
        return self._skip_ops_by_type

    @skip_ops_by_type.setter
    def skip_ops_by_type(self, criteria: Text):
        # Parse the comma-separated option string into a set of op-type names.
        self._skip_ops_by_type = set(criteria.split(","))