Source code for coremltools.converters.mil.mil.passes.defs.quantization

#  Copyright (c) 2020, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from enum import Enum as _Enum
from typing import Set, Text

import numpy as np

from coremltools import _logger as logger
from coremltools.converters.mil.backend.mil.load import should_use_weight_file
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.ops.defs.iOS16 import (
    constexpr_affine_dequantize,
    constexpr_lut_to_dense,
    constexpr_sparse_to_dense,
)
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import block_context_manager
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.converters.mil.mil.program import Program
from coremltools.converters.mil.mil.types.type_mapping import (
    is_builtin,
    nptype_from_builtin,
    numpy_type_to_builtin_type,
)
from coremltools.models.neural_network.quantization_utils import _get_kmeans_lookup_table_and_weight


class ComputePrecision(_Enum):
    FLOAT16 = "float16"
    FLOAT32 = "float32"


class AbstractQuantizationPass(AbstractGraphPass):
    """
    Base class for Post-Training Quantization transforms.

    Derived classes need to implement the following two methods:
        - is_valid_op(op)
        - transform_op(op)
    """

    type_eps = {}
    type_min = {}
    type_negmin = {}

    def __init__(self, op_selector=None):
        super().__init__()
        if op_selector is not None and not callable(op_selector):
            raise TypeError(
                "Argument `op_selector` needs to be a callable function which accepts "
                "a MIL operation object and returns a boolean value."
            )
        self.op_selector = op_selector

    def apply(self, prog):
        """
        Walks over each operation in the graph and performs the following two steps:
        1. Checks whether the operation is valid for the quantized transform, using the `is_valid_op` method.
        2. If yes, calls the `transform_op` method of the derived quantized transform class.

        :param prog: MIL program
        :return: Transformed MIL program
        """
        if not isinstance(prog, Program):
            raise TypeError('Transform "{}" can only be applied to PyMIL programs.'.format(self))

        if getattr(self, "skip_ops_by_type", set()) and self.op_selector is not None:
            raise ValueError(
                "The graph pass option `skip_ops_by_type` cannot be set along with "
                "the `op_selector` in FP16ComputePrecision. Please only use one "
                "method to control which ops to operate on."
            )

        @block_context_manager
        def apply_block(block):
            for op in list(block.operations):
                for b in op.blocks:
                    apply_block(b)

                if self.is_valid_op(op):
                    need_transform: bool
                    if self.op_selector is not None:
                        need_transform = self.op_selector(op)
                    else:
                        need_transform = op.op_type not in getattr(self, "skip_ops_by_type", set())
                    if need_transform:
                        self.transform_op(op)

        for f in prog.functions.values():
            apply_block(f)

    def transform_op(self, op):
        """
        Replaces an op with a transformed op.

        :param op: MIL operation
        :return: None
        """
        raise NotImplementedError(
            'Op transformation for quantization mode "{}" not implemented.'.format(self)
        )

    def is_valid_op(self, op):
        """
        Checks whether an operation is valid for the given quantized transform.

        :param op: MIL operation
        :return: True | False
        """
        raise NotImplementedError(
            'Operation Preconditions for quantization mode "{}" not implemented.'.format(self)
        )

    @classmethod
    def _close_to_zero(cls, val, np_type):
        if np_type not in cls.type_eps:
            cls.type_eps[np_type] = np.finfo(np_type).eps
            cls.type_min[np_type] = np.nextafter(0.0, 1.0, dtype=np_type)
            cls.type_negmin[np_type] = np.nextafter(0.0, -1.0, dtype=np_type)

        return np.isclose(val, 0, atol=cls.type_min[np_type], rtol=cls.type_eps[np_type])

    def __repr__(self):
        return str(self)

    def __str__(self):
        return type(self).__name__


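# Illustrative sketch (not part of this module): a minimal subclass of
# AbstractQuantizationPass only needs to implement `is_valid_op` and `transform_op`.
# The class below and its behavior are hypothetical.
#
#     class log_relu_ops(AbstractQuantizationPass):
#         def is_valid_op(self, op):
#             return op.op_type == "relu"
#
#         def transform_op(self, op):
#             logger.info("visiting relu op: %s", op.name)
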
class FP16ComputePrecision(AbstractQuantizationPass):
    """
    This transform does the following for each valid op, if the "op_selector" returns True:
    - For each input of dtype float32, inject a "cast" op to change it to float16 dtype.
    - For each output of dtype float16, inject a "cast" op to change it back to float32.
    """

    def __init__(self, op_selector=None):
        super(FP16ComputePrecision, self).__init__(op_selector=op_selector)
        self.target_dtype = "fp16"

        # A var that feeds into multiple ops is cast once and cached in this dict.
        # For reference, see test_single_input_to_multiple_operations in `TestFP16CastTransform`.
        self.cache_vars = {}

    def fp16_overflow(self, op):
        # Constants with values greater than 65504 or less than -65504 overflow in FP16
        for _, inputs in op.inputs.items():
            is_list_input = isinstance(inputs, (list, tuple))
            if not is_list_input:
                inputs = [inputs]
            for var in inputs:
                if (
                    var.op is not None
                    and var.op.op_type == "const"
                    and var.is_tensor_or_scalar_of(dtype="fp32")
                ):
                    if np.max(np.abs(var.op.val.val), initial=0.0) > 65504:
                        return True
        return False
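
    # Illustrative note (reference only, not part of the pass): the largest finite FP16
    # value is np.finfo(np.float16).max == 65504.0, so, for example, np.float16(70000.0)
    # overflows to inf. That is why ops fed by constants beyond +/-65504 are kept in FP32.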

    def is_valid_op(self, op):

        if op.op_type in ["cast", "while_loop", "cond"]:
            return False

        if op.op_type in [
            "make_list",
            "list_gather",
            "list_scatter",
            "list_read",
            "list_write",
            "list_length",
        ]:
            return False  #  rdar://74458192

        if op.op_type in ["gru", "rnn", "lstm"]:
            return False

        if self.fp16_overflow(op):
            return False

        return True

    def is_valid_parameter(self, op, param_name):
        type_domain = getattr(op.input_spec.input_types[param_name], "type_domain", None)
        if type_domain is not None:
            if len(type_domain) == 0:
                return True
            return types.fp16 in type_domain
        return True

    def _check_underflow_to_zero(self, new_var, var):
        # Check whether any cast values become 0, which is not ideal for eps purposes.
        # Arrays with more than 400 elements are skipped to avoid scanning a large sparse matrix.
        if (
            new_var.val is not None
            and len(var.val.flatten()) < 400
            and self._close_to_zero(new_var.val, np.float16).any()
        ):
            value_modified = False
            original_val = var.val.flatten()
            new_val = new_var.val.flatten()

            for idx in range(len(original_val)):
                if not self._close_to_zero(original_val[idx], np.float32) and self._close_to_zero(
                    new_val[idx], np.float16
                ):
                    new_val[idx] = (
                        self.type_min[np.float16]
                        if np.sign(original_val[idx]) > 0
                        else self.type_negmin[np.float16]
                    )
                    value_modified = True

            if value_modified:
                if np.isscalar(new_var.val):
                    new_var._sym_val.val = new_val[0]
                else:
                    new_var._sym_val.val = new_val.reshape(new_var.val.shape)

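    # Illustrative note (not part of the pass): np.float32(1e-8) casts to 0.0 in FP16,
    # while np.nextafter(0.0, 1.0, dtype=np.float16) is approximately 6e-08, the smallest
    # positive FP16 value. The method above nudges such underflowed entries to that bound
    # so they stay nonzero (for example, when used as an eps in a division).
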
    def transform_op(self, op):
        block = op.enclosing_block
        casted_inputs = {}
        inputs_modified = False

        for param, inputs in op.inputs.items():
            # First loop, iterates over all the input parameters of an operation.
            if not self.is_valid_parameter(op, param):
                continue

            is_list_input = isinstance(inputs, (list, tuple))
            if not is_list_input:
                inputs = [inputs]

            casted_inputs[param] = list(inputs[:])
            for i, var in enumerate(inputs):
                # Second loop, iterates over all the vars of a python list corresponding to an input parameter.
                if not var.is_tensor_or_scalar_of(dtype="fp32"):
                    continue

                inputs_modified = True
                casted_var_name = var.name + "_to_fp16"
                if (
                    len(var._child_ops) > 1
                    and casted_var_name in self.cache_vars
                    and (block.is_var_visible_in_block(self.cache_vars[casted_var_name]))
                ):
                    casted_inputs[param][i] = self.cache_vars[casted_var_name]
                else:
                    x = mb.cast(x=var, dtype="fp16", name=casted_var_name, before_op=op)
                    self._check_underflow_to_zero(x, var)

                    casted_inputs[param][i] = x
                    if len(var._child_ops) > 1:
                        self.cache_vars[casted_var_name] = casted_inputs[param][i]

            if not is_list_input:
                casted_inputs[param] = casted_inputs[param][0]

        if inputs_modified:
            casted_inputs.update({k: v for k, v in op.inputs.items() if k not in casted_inputs})
            casted_inputs["name"] = op.name + "_cast"
            casted_inputs["before_op"] = op
            quant_output = getattr(mb, op.op_type)(**casted_inputs)

            if not isinstance(quant_output, (list, tuple)):
                quant_output = [quant_output]

            for old_output_var, new_output_var in zip(op.outputs, quant_output):
                if old_output_var.is_tensor_or_scalar_of(dtype="fp32") and (
                    not new_output_var.is_tensor_or_scalar_of(dtype="fp32")
                ):
                    x = mb.cast(
                        x=new_output_var,
                        dtype="fp32",
                        name=new_output_var.name + "_to_fp32",
                        before_op=op,
                    )
                    op.enclosing_block.replace_uses_of_var_after_op(
                        anchor_op=op,
                        old_var=old_output_var,
                        new_var=x,
                        force_replace=True,
                    )
                else:
                    op.enclosing_block.replace_uses_of_var_after_op(
                        anchor_op=op,
                        old_var=old_output_var,
                        new_var=new_output_var,
                        force_replace=True,
                    )

            block.remove_ops([op])


@register_pass(namespace="common")
class add_fp16_cast(FP16ComputePrecision):
    """
    For each input of dtype float32, inject a ``cast`` op to change it to float16 dtype.

    For each output of dtype float16, inject a ``cast`` op to change it back to float32.

    This pass is the registered interface for FP16ComputePrecision, which makes it consistent
    with other passes' interfaces.

    Support options:

    - ``skip_ops_by_type``: Skip op types specified by comma-separated string; for example, ``"mul,const"``.
    """

    _skip_ops_by_type: Set[Text] = set()

    @property
    def skip_ops_by_type(self):
        return self._skip_ops_by_type

    @skip_ops_by_type.setter
    def skip_ops_by_type(self, criteria: Text):
        self._skip_ops_by_type = set(criteria.split(","))

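# Illustrative usage sketch (not part of this module), assuming `prog` is a PyMIL Program:
#
#     from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
#
#     fp16_pass = PASS_REGISTRY["common::add_fp16_cast"]
#     fp16_pass.skip_ops_by_type = "mul,const"   # leave these op types in FP32
#     fp16_pass(prog)
#
# Alternatively, FP16ComputePrecision(op_selector=...) can be constructed directly and
# applied to the program to control casting per op.
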
class SparseParams:
    def __init__(self, nonzero_data=None, mask=None, shape=None):
        self.nonzero_data = nonzero_data
        self.mask = mask
        self.shape = shape


class WeightSparsifier(AbstractQuantizationPass):
    """
    This transform does the following for each const op, if the "op_selector" returns True:

    - A fraction of the values with the least absolute value are zeroed out (controlled by
      ``threshold`` or ``target_percentile``, depending on the ``mode``).
    - If fake_compression=False, the zeroed-out value is encoded via the constexpr_sparse_to_dense op.
    - If fake_compression=True, the zeroed-out value is encoded via the const op.
    - The old const is replaced by a new operation with the zeroed-out value.
    """

    WEIGHT_SPARSIFICATION_MODES = ("THRESHOLD_BASED", "PERCENTILE_BASED")

    def __init__(
        self,
        mode="threshold_based",
        threshold=1e-3,
        target_percentile=1.0,
        fake_compression=False,
        op_selector=None,
    ):
        super().__init__(op_selector=op_selector)
        self.fake_compression = fake_compression
        self.mode = mode.upper()
        self.threshold = threshold
        self.target_percentile = target_percentile

        if self.mode not in WeightSparsifier.WEIGHT_SPARSIFICATION_MODES:
            msg = "Only modes {} are supported for weight sparsification. Got mode {}.".format(
                WeightSparsifier.WEIGHT_SPARSIFICATION_MODES, self.mode
            )
            raise ValueError(msg)

        if self.mode == "PERCENTILE_BASED" and (
            self.target_percentile < 0 or self.target_percentile > 1
        ):
            raise ValueError(
                "Invalid value of target_percentile: {}. Needs to be in [0, 1].".format(
                    self.target_percentile
                )
            )

        if self.mode == "THRESHOLD_BASED" and self.threshold < 0:
            raise ValueError(
                "Invalid value of threshold: {}. Needs to be in [0, inf).".format(self.threshold)
            )

    def is_valid_op(self, op):
        if op.op_type == "const" and should_use_weight_file(op.val.val):
            return True
        return False

    @staticmethod
    def compress(val, mode, target_percentile=None, threshold=None):
        mode = mode.upper()

        def sparsify_with_percentile(val, target_percentile):
            q = target_percentile * 100
            return np.where(np.abs(val) <= np.percentile(np.abs(val), q), 0, val)

        def sparsify_with_threshold(val, threshold):
            return np.where(np.abs(val) <= threshold, 0, val)

        if not isinstance(val, (np.ndarray, np.generic)):
            raise ValueError("Only numpy arrays are supported")

        flattened_val = val.flatten()

        if mode == "PERCENTILE_BASED":
            flattened_val = sparsify_with_percentile(flattened_val, target_percentile)
        elif mode == "THRESHOLD_BASED":
            flattened_val = sparsify_with_threshold(flattened_val, threshold)

        params = SparseParams()
        params.nonzero_data = flattened_val[np.where(flattened_val != 0)]
        params.mask = np.packbits(np.where(flattened_val != 0, 1, 0), bitorder="little")
        params.shape = val.shape
        return params

    @staticmethod
    def decompress(params):
        if not isinstance(params, SparseParams):
            raise ValueError("Invalid type of params")
        return constexpr_sparse_to_dense.decompress(params.nonzero_data, params.mask, params.shape)

    def transform_op(self, op):
        block = op.enclosing_block
        sparse_params = self.compress(
            op.val.val, self.mode, self.target_percentile, self.threshold
        )

        if not self.fake_compression:
            new_var = mb.constexpr_sparse_to_dense(
                nonzero_data=sparse_params.nonzero_data,
                mask=sparse_params.mask,
                shape=np.uint32(sparse_params.shape),
                before_op=op,
                name=op.name + "_sparsified",
            )
        else:
            decompressed_val = self.decompress(sparse_params)
            new_var = mb.const(
                val=decompressed_val,
                before_op=op,
                name=op.name + "_fake_sparsified",
            )

        op.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=op,
            old_var=op.outputs[0],
            new_var=new_var,
            no_check_var_types=True,
        )

        block.remove_ops([op])

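# Illustrative sketch (not part of this module): threshold-based sparsification of a small
# weight tensor via the static helpers above. The array `w` is hypothetical.
#
#     w = np.array([[0.5, 1e-5], [-2.0, 3e-4]], dtype=np.float32)
#     params = WeightSparsifier.compress(w, mode="threshold_based", threshold=1e-3)
#     # params.nonzero_data == [0.5, -2.0]; params.mask is a little-endian packed bitmask
#     dense = WeightSparsifier.decompress(params)   # dense sparsified tensor of shape (2, 2)
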
class LutParams:
    def __init__(self, lut=None, indices=None, shape=None):
        self.lut = lut
        self.indices = indices
        self.shape = shape


class WeightPalettizer(AbstractQuantizationPass):
    """
    This transform does the following for each const op, if the "op_selector" returns True:

    - A linear look-up table with 2**(nbits) entries is created, and the value is represented
      via indexing into this look-up table.
    - If fake_compression=False, the compressed value is encoded via the constexpr_lut_to_dense op.
    - If fake_compression=True, the compressed value is decompressed and then encoded via the const op.
    - The old const op is replaced by the newly created operation.
    """

    WEIGHT_PALETTIZATION_MODES = ("KMEANS", "UNIFORM", "UNIQUE", "CUSTOM")

    def __init__(
        self, nbits, fake_compression=False, op_selector=None, mode="kmeans", lut_function=None
    ):
        super().__init__(op_selector=op_selector)
        self.fake_compression = fake_compression
        self.nbits = nbits
        self.mode = mode.upper()
        self.lut_function = lut_function

        if self.mode not in WeightPalettizer.WEIGHT_PALETTIZATION_MODES:
            msg = "Only modes {} are supported for weight palettization. Got mode {}.".format(
                WeightPalettizer.WEIGHT_PALETTIZATION_MODES, self.mode
            )
            raise ValueError(msg)

        if nbits is None and self.mode in ("KMEANS", "UNIFORM"):
            msg = "nbits must be provided for mode {}".format(mode)
            raise ValueError(msg)

        if nbits is not None and self.mode in ("UNIQUE", "CUSTOM"):
            msg = "nbits must NOT be provided for mode {}".format(mode)
            raise ValueError(msg)

        if self.nbits is not None and self.nbits not in (1, 2, 4, 6, 8):
            raise ValueError(
                "Invalid value of nbits ({}) for palettization. Supported bits are {{1, 2, 4, 6, 8}}".format(
                    nbits
                )
            )

        if (self.mode == "CUSTOM") ^ (lut_function is not None):
            msg = (
                "lut_function must be None unless the mode is custom, "
                "and it cannot be None when the mode is custom."
            )
            raise ValueError(msg)

        if self.mode == "CUSTOM" and not callable(self.lut_function):
            msg = "A function object must be provided as lut_function. Got lut_function of type {}".format(
                type(self.lut_function)
            )
            raise ValueError(msg)

    def is_valid_op(self, op):
        if op.op_type == "const" and should_use_weight_file(op.val.val):
            return True
        return False

    @staticmethod
    def compress(val, mode, nbits=None, lut_function=None):
        mode = mode.upper()

        def compress_kmeans(val, nbits):
            lut, indices = _get_kmeans_lookup_table_and_weight(nbits, val)
            lut = lut.astype(val.dtype)
            indices = indices.astype(np.uint8)
            return lut, indices

        def compress_uniform(val, nbits):
            val = val.flatten()
            val_min = np.amin(val)
            val_max = np.amax(val)
            scale = (val_max - val_min) / ((1 << nbits) - 1)
            indices = np.round(
                ((val - val_min) / (val_max - val_min)) * ((1 << nbits) - 1)
            ).astype(np.uint8)
            lut = np.array(range(0, 1 << nbits)) * scale + val_min
            lut = lut.astype(val.dtype)
            return lut, indices

        def get_nbits_for_unique_mode(val):
            val = val.flatten()
            unique_vals = np.unique(val).tolist()
            for nbits in (1, 2, 4, 6, 8):
                if len(unique_vals) <= 1 << nbits:
                    return nbits
            msg = "Weight values cannot be represented by an 8-bit palettization. Skipped."
            logger.warning(msg)
            return None

        def compress_unique(val, nbits):
            val = val.flatten()
            unique_vals = np.unique(val).tolist()
            if len(unique_vals) > 1 << nbits:
                msg = "Too many unique values ({}) in the weight. Cannot be represented in {} bits.".format(
                    len(unique_vals), nbits
                )
                raise ValueError(msg)
            lut = [0] * (1 << nbits)
            lut[: len(unique_vals)] = unique_vals
            indices = np.zeros((len(val),))
            for i, k in enumerate(lut[: len(unique_vals)]):
                indices += (i + 1) * (val == k).astype(np.int32)
            indices = indices - 1
            assert (
                len(np.where(indices == -1)[0]) == 0
            ), "each weight value must correspond to an existing LUT index"
            lut = np.array(lut).astype(val.dtype)
            indices = indices.astype(np.uint8)
            return lut, indices

        def pack_indices_into_bytes_array(indices, nbits):
            bitarray = np.unpackbits(indices.reshape(-1, 1), bitorder="little", axis=-1)[:, :nbits]
            return np.packbits(bitarray.flatten(), bitorder="little")

        def check_lut_parameters_are_valid(val, lut, indices):
            if not isinstance(lut, np.ndarray) or not isinstance(indices, np.ndarray):
                raise ValueError("LUT and indices must be numpy arrays.")

            if indices.size != val.size:
                msg = "Indices size ({}) does not match the original weight size ({}).".format(
                    indices.size, val.size
                )
                raise ValueError(msg)

            if len(indices.shape) != 1 or indices.dtype != np.uint8:
                msg = "Indices must be a numpy vector of type uint8. Found shape {} with type {}".format(
                    indices.shape, indices.dtype
                )
                raise ValueError(msg)

            if lut.dtype != val.dtype:
                msg = "Dtype mismatch between LUT ({}) and weight ({})".format(
                    lut.dtype, val.dtype
                )
                raise ValueError(msg)

        if not isinstance(val, (np.ndarray, np.generic)):
            raise ValueError("Only numpy arrays are supported")

        if mode == "KMEANS":
            lut, indices = compress_kmeans(val, nbits)
        elif mode == "UNIFORM":
            lut, indices = compress_uniform(val, nbits)
        elif mode == "UNIQUE":
            nbits = get_nbits_for_unique_mode(val)
            if nbits is None:
                return None
            lut, indices = compress_unique(val, nbits)
        elif mode == "CUSTOM":
            lut, indices = lut_function(val)

        check_lut_parameters_are_valid(val, lut, indices)

        params = LutParams()
        params.lut = lut
        params.shape = val.shape
        params.indices = pack_indices_into_bytes_array(indices, int(np.log2(lut.shape[0])))
        return params

    @staticmethod
    def decompress(params):
        if not isinstance(params, LutParams):
            raise ValueError("Invalid type of params")
        return constexpr_lut_to_dense.decompress(params.lut, params.indices, params.shape)

    def transform_op(self, op):
        block = op.enclosing_block
        lut_params = self.compress(op.val.val, self.mode, self.nbits, self.lut_function)

        if lut_params is None:
            return

        if not self.fake_compression:
            new_var = mb.constexpr_lut_to_dense(
                indices=lut_params.indices,
                lut=lut_params.lut,
                shape=np.uint32(lut_params.shape),
                before_op=op,
                name=op.name + "_palettized",
            )
        else:
            decompressed_val = self.decompress(lut_params)
            new_var = mb.const(
                val=decompressed_val,
                before_op=op,
                name=op.name + "_fake_palettized",
            )

        op.enclosing_block.replace_uses_of_var_after_op(
            anchor_op=op,
            old_var=op.outputs[0],
            new_var=new_var,
            no_check_var_types=True,
        )

        block.remove_ops([op])

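# Illustrative sketch (not part of this module): 2-bit uniform palettization of a small
# weight tensor via the static helpers above. The array `w` is hypothetical.
#
#     w = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float32)
#     params = WeightPalettizer.compress(w, mode="uniform", nbits=2)
#     # params.lut == [0., 1., 2., 3.]; params.indices packs the 2-bit indices little-endian
#     dense = WeightPalettizer.decompress(params)   # reconstructs [0., 1., 2., 3.]
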
""" WEIGHT_AFFINE_QUANTIZATION_MODES = ("LINEAR_SYMMETRIC", "LINEAR") WEIGHT_AFFINE_DTYPES = (types.int8, types.uint8) def __init__(self, fake_compression=False, op_selector=None, mode="linear", dtype=np.int8): super().__init__(op_selector=op_selector) self.fake_compression = fake_compression self.mode = mode.upper() # check mode if not self.mode in WeightAffineQuantizer.WEIGHT_AFFINE_QUANTIZATION_MODES: msg = "Only mode {} supported for weight affine quantization. Got mode {}.".format( WeightAffineQuantizer.WEIGHT_AFFINE_QUANTIZATION_MODES, self.mode ) raise ValueError(msg) # check dtype msg = f"dtype={dtype} is unsupported for affine_quantize_weights." if is_builtin(dtype): self.dtype = dtype else: try: self.dtype = numpy_type_to_builtin_type(dtype) except TypeError: raise ValueError(msg) if self.dtype not in WeightAffineQuantizer.WEIGHT_AFFINE_DTYPES: raise ValueError(msg) def is_valid_op(self, op): if op.op_type == "const" and should_use_weight_file(op.val.val): return True return False @staticmethod def _get_axis(op): axis = 0 var = op.outputs[0] if len(var.child_ops) == 1 and var.child_ops[0].op_type == "conv_transpose": axis = 1 return axis @staticmethod def compress(val, axis, mode, dtype): def _ensure_numerical_range_and_cast(val, low, high, np_dtype): ''' For some cases, the computed quantized data might exceed the data range. For instance, after rounding and addition, we might get `128` for the int8 quantization. This utility function ensures the val in the data range before doing the cast. ''' val = np.minimum(val, high) val = np.maximum(val, low) return val.astype(np_dtype) mode = mode.upper() mode_dtype_to_range = { (types.int8, "LINEAR"): (-128, 127), (types.int8, "LINEAR_SYMMETRIC"): (-127, 127), (types.uint8, "LINEAR"): (0, 255), (types.uint8, "LINEAR_SYMMETRIC"): (0, 254), } if not isinstance(val, (np.ndarray, np.generic)): raise ValueError("Only numpy arrays are supported") params = AffineQuantParams() axes = tuple([i for i in range(len(val.shape)) if i != axis]) val_min = np.amin(val, axis=axes, keepdims=True) val_max = np.amax(val, axis=axes, keepdims=True) if mode == "LINEAR_SYMMETRIC": # For the linear_symmetric mode, the range is symmetrical to 0 max_abs = np.maximum(np.abs(val_min), np.abs(val_max)) val_min = -max_abs val_max = max_abs else: assert mode == "LINEAR" # For the linear mode, we need to make sure the data range contains `0` val_min = np.minimum(0.0, val_min) val_max = np.maximum(0.0, val_max) q_val_min, q_val_max = mode_dtype_to_range[(dtype, mode)] # Set the zero point to symmetric mode np_dtype = nptype_from_builtin(dtype) if mode == "LINEAR_SYMMETRIC": if dtype == types.int8: params.zero_point = (0 * np.ones(val_min.shape)).astype(np.int8) else: assert dtype == types.uint8 params.zero_point = (127 * np.ones(val_min.shape)).astype(np.uint8) else: assert mode == "LINEAR" params.zero_point = (q_val_min * val_max - q_val_max * val_min) / (val_max - val_min) params.zero_point = np.round(params.zero_point) params.zero_point = _ensure_numerical_range_and_cast(params.zero_point, q_val_min, q_val_max, np_dtype) # compute the params params.scale = (val_max - val_min) / (q_val_max - q_val_min) params.scale = params.scale.astype(val.dtype).squeeze() params.quantized_data = np.round( val * (q_val_max - q_val_min) / (val_max - val_min) ) params.quantized_data = (params.quantized_data + params.zero_point) params.quantized_data = _ensure_numerical_range_and_cast(params.quantized_data, q_val_min, q_val_max, np_dtype) params.zero_point = params.zero_point.squeeze() 
params.axis = axis return params @staticmethod def decompress(params): if not isinstance(params, AffineQuantParams): raise ValueError("Invalid type of params") return constexpr_affine_dequantize.decompress( params.quantized_data, params.zero_point, params.scale, params.axis ) def transform_op(self, op): block = op.enclosing_block quant_params = self.compress(op.val.val, self._get_axis(op), self.mode, self.dtype) if not self.fake_compression: new_var = mb.constexpr_affine_dequantize( quantized_data=quant_params.quantized_data, zero_point=quant_params.zero_point, scale=quant_params.scale, axis=quant_params.axis, before_op=op, name=op.name + "_affine_quantized", ) else: decompressed_val = self.decompress(quant_params) new_var = mb.const( val=decompressed_val, before_op=op, name=op.name + "_fake_affine_quantized", ) op.enclosing_block.replace_uses_of_var_after_op( anchor_op=op, old_var=op.outputs[0], new_var=new_var, no_check_var_types=True, ) block.remove_ops([op]) class WeightDecompressor(AbstractQuantizationPass): """ This graph pass transforms the constexpr ops back into mb.const op. constexpr ops includes: (1) constexpr_affine_dequantize (2) constexpr_lut_to_dense (3) constexpr_sparse_to_dense """ def __init__(self, op_selector): super().__init__(op_selector=op_selector) def is_valid_op(self, op): return op.op_type in ( "constexpr_affine_dequantize", "constexpr_lut_to_dense", "constexpr_sparse_to_dense", ) def transform_op(self, op): block = op.enclosing_block decompressed_val = op.value_inference() new_var = mb.const( val=decompressed_val, before_op=op, name=op.name, ) op.enclosing_block.replace_uses_of_var_after_op( anchor_op=op, old_var=op.outputs[0], new_var=new_var, no_check_var_types=True, force_replace=True, ) block.remove_ops([op])
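

# Illustrative sketch (not part of this module): per-channel symmetric 8-bit quantization of
# a small weight tensor via the static helpers above. The array `w` is hypothetical.
#
#     w = np.array([[-1.0, 0.5, 1.0], [-2.0, 0.0, 2.0]], dtype=np.float32)
#     params = WeightAffineQuantizer.compress(w, axis=0, mode="linear_symmetric", dtype=types.int8)
#     # params.scale holds one scale per row (axis 0); params.zero_point is all zeros for int8
#     dense = WeightAffineQuantizer.decompress(params)   # approximately reconstructs w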