# Copyright (c) 2023, Apple Inc. All rights reserved.
#
# Use of this source code is governed by a BSD-3-clause license that can be
# found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
from collections import OrderedDict
from typing import Dict, List, Optional
import numpy as np
from attrs import define, field, validators
from tqdm import tqdm
from coremltools.converters.mil.frontend.milproto import load as _milproto_to_pymil
from coremltools.converters.mil.mil.passes.graph_pass import PassOption
from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY
from coremltools.models import model as _model
from coremltools.models import utils as _model_utils
from coremltools.optimize.coreml import OptimizationConfig as _OptimizationConfig
from coremltools.optimize.coreml._config import _MetaDataDict
from ._quantization_passes import WeightDecompressor as _WeightDecompressor
def _is_valid_const(val, weight_threshold):
return isinstance(val, np.ndarray) and val.size >= weight_threshold
def _multifunction_unsupported(func):
"""
    The decorator marks a PTQ API that does not support multifunction models.
    We should use this decorator until the radar is fixed:
    rdar://126084385 ([Infra] Figure out the story of PTQ or other passes operate on loaded Mutli-function model)
    Note that the decorated API must take ``mlmodel`` (of type ``MLModel``) as an input.
"""
def decorator(*args, **kwargs):
num_args = func.__code__.co_argcount
arg_names = list(func.__code__.co_varnames)[:num_args]
        param_dict = {k: v for k, v in zip(arg_names, args)}
        # Include keyword arguments as well, so "mlmodel" is found even when passed by keyword.
        param_dict.update(kwargs)
model = param_dict.get("mlmodel", None)
if model is None:
raise ValueError(
                f'Function {func} decorated with _multifunction_unsupported must take "mlmodel" as an input.'
)
if model._is_multifunction():
raise ValueError(f"{func} is not supported for a multifunction model.")
return func(*args, **kwargs)
decorator.__doc__ = func.__doc__
return decorator
@_multifunction_unsupported
def linear_quantize_weights(
mlmodel: _model.MLModel, config: _OptimizationConfig, joint_compression: bool = False
):
"""
    Utility function to convert a float precision MLModel of type ``mlprogram`` into a
    compressed MLModel that uses n-bit weights (currently only ``n=4`` and ``n=8`` are
    supported). This is achieved by converting the float weight values stored in ``const`` ops
    into ``constexpr_affine_dequantize`` or ``constexpr_blockwise_shift_scale`` ops (based on
    the model's minimum deployment target).
    This function uses linear quantization on the float weights, providing up to 2x (for 8-bit)
    or 4x (for 4-bit) savings in storage compared to float 16, and up to 4x or 8x savings
    compared to float 32.
All computation at runtime uses float precision; the precision of the intermediate
tensors and the compute precision of the ops are not altered.
For each weight, this utility function converts the weight into the int4/8 or uint4/8 type using
either `linear interpolation` (``"linear"`` mode) or `linear symmetric interpolation`
(``"linear_symmetric"`` mode, the default).
**Linear interpolation**
    The following description uses 8-bit quantization as an illustration; 4-bit quantization
    works analogously.
Linear interpolation (``"linear"`` mode) maps the min/max of the float
range to the 8-bit integer range ``[low, high]`` using a zero point (also called quantization bias, or
offset) and a scale factor. For the int8 quantization, ``[low, high] = [-128, 127]``, while uint8
quantization uses range ``[0, 255]``.
``"linear"`` mode uses the quantization formula:
.. math::
w_r = s * (w_q - z)
Where:
* :math:`w_r` and :math:`s` are of type float.
    * :math:`w_r` represents the float precision weight.
* :math:`s` represents the scale.
* :math:`w_q` and :math:`z` are of type 8-bit integer.
* :math:`w_q` represents quantized weight.
* :math:`z` represents the zero point.
Quantized weights are computed as follows:
.. math::
w_q = cast\_to\_8\_bit\_integer(w_r / s + cast\_to\_float(z))
Note: :math:`cast\_to\_8\_bit\_integer` is the process of clipping the input to range ``[low, high]`` followed by rounding and casting to 8-bit integer.
In ``"linear"`` mode, ``s, z`` are computed by mapping the original float range
``[A, B]`` into the 8-bit integer range ``[-128, 127]`` or ``[0, 255]``. That is, you are solving the
following linear equations:
* ``B = s * (high - z)``
* ``A = s * (low - z)``
The equations result in the following:
* ``s = (B - A) / (high - low)``
* ``z = cast_to_8_bit_integer((low * B - high * A) / (B - A))``
When the rank of weight ``w`` is 1, then ``s`` and ``z`` are both scalars. When the
rank of the weight is greater than 1, then ``s`` and ``z`` are both vectors. In that
case, scales are computed per `channel`, in which `channel` is the output dimension,
which corresponds to the first dimension for ops such as ``conv`` and ``linear``, and
the second dimension for the ``conv_transpose`` op.
For ``"linear"`` mode, :math:`A = min(w_r)`, :math:`B = max(w_r)`.
**Linear symmetric interpolation**
With linear symmetric interpolation (``"linear_symmetric"`` mode, the default), rather than
mapping the exact min/max of the float range to the quantized range, the function
chooses the maximum absolute value between the min/max, which results in a
floating-point range that is symmetric with respect to zero. This also makes the resulting zero
point ``0`` for int8 weight and ``127`` for uint8 weight.
For ``"linear_symmetric"`` mode:
* :math:`A = -R` and :math:`B = R`, where :math:`R = max(abs(w_r))`.
* This function maps to the range of ``[-127, 127]`` for int8 weight and ``[0, 254]`` for uint8 weight.
* The result is ``s=(B-A)/254`` -> ``s=2R/254`` -> ``s=R/127``.
* Solving for ``z``:
* int8: ``z = (-127 * R + 127 * R)/2R`` -> ``z=0``.
* uint8: ``z = (0 * R + 254 * R)/2R`` -> ``z=127``.
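
    As a worked illustration of ``"linear_symmetric"`` mode (same made-up weight range as above,
    so ``R = max(abs(w_r)) = 3.0``):

    * int8: ``s = R / 127 = 3.0 / 127 ≈ 0.0236`` and ``z = 0``, so ``w_q = cast_to_8_bit_integer(w_r / s)``.
    * uint8: ``s`` is the same and ``z = 127``, so ``w_q = cast_to_8_bit_integer(w_r / s + 127.0)``.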
Parameters
----------
mlmodel: MLModel
Model to be quantized. This MLModel should be of type ``mlprogram``.
config: OptimizationConfig
An :py:class:`OptimizationConfig` object that specifies the parameters for weight quantization.
    joint_compression: bool
        Whether to further compress an already-compressed input MLModel into a jointly
        compressed MLModel. See the ``linear_quantize_weights`` graph pass for information
        about which compression schemas can be further jointly quantized.
        Take "palettize + quantize" as an example of joint compression, where the input MLModel is
        already palettized, and the palettization's lookup table will be further quantized. In such
        a case, the weight values are represented by ``constexpr_blockwise_shift_scale`` +
        ``constexpr_lut_to_dense`` ops:
        ``lut(int8)`` -> ``constexpr_blockwise_shift_scale`` -> ``lut(fp16)`` -> ``constexpr_lut_to_dense`` -> ``dense(fp16)``
        See the joint compression sketch in the Examples section below.
Returns
-------
model: MLModel
The quantized MLModel instance.
Examples
--------
.. sourcecode:: python
import coremltools as ct
import coremltools.optimize as cto
        model = ct.models.MLModel("my_model.mlpackage")
config = cto.coreml.OptimizationConfig(
global_config=cto.coreml.OpLinearQuantizerConfig(mode="linear_symmetric")
)
compressed_model = cto.coreml.linear_quantize_weights(model, config)
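
    Continuing from the example above, the following is a hypothetical sketch of joint
    compression: ``palettized_model`` is assumed to be an already-palettized MLModel (for
    example, the output of ``palettize_weights``), and ``joint_compression=True`` then quantizes
    its lookup tables:

    .. sourcecode:: python

        joint_config = cto.coreml.OptimizationConfig(
            global_config=cto.coreml.OpLinearQuantizerConfig(mode="linear_symmetric")
        )
        # Quantize the LUTs of an already-palettized model (joint "palettize + quantize").
        jointly_compressed_model = cto.coreml.linear_quantize_weights(
            palettized_model, joint_config, joint_compression=True
        )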
"""
blockwise_weight_quantizer = PASS_REGISTRY["compression::linear_quantize_weights"]
blockwise_weight_quantizer.set_options(
[PassOption("config", config), PassOption("joint_compression", joint_compression)]
)
return _model_utils._apply_graph_pass(mlmodel, blockwise_weight_quantizer)
@_multifunction_unsupported
def palettize_weights(
mlmodel: _model.MLModel, config: _OptimizationConfig, joint_compression: bool = False
):
"""
Utility function to convert a float precision MLModel of type ``mlprogram`` to a
compressed MLModel by reducing the overall number of weights using one or more lookup tables
    (LUT). A LUT contains a list of float values. An ``n``-bit LUT has :math:`2^{n}` entries.
For example, a float weight vector such as ``{0.3, 0.3, 0.5, 0.5}`` can be compressed
using a 1-bit LUT: ``{0.3, 0.5}``. In this case the float vector can be replaced
with a 1-bit vector ``{0, 0, 1, 1}``.
    This function iterates over all the weights in the ``mlprogram``, discretizes their values,
and constructs the LUT according to the algorithm specified in ``mode``. The float
values are then converted to the ``n-bit`` values, and the LUT is saved alongside each
weight. The ``const`` ops storing weight values are replaced by
``constexpr_lut_to_dense`` ops.
At runtime, the LUT and the ``n-bit`` values are used to reconstruct the float weight
values, which are then used to perform the float operation the weight is feeding into.
Consider the following example of ``"uniform"`` mode (a linear histogram):
    * ``nbits = 2``
* ``mode = "uniform"``
* ``weight = [0.11, 0.19, 0.3, 0.08, 0.0, 0.02]``
The weight can be converted to a palette with indices ``[0, 1, 2, 3]`` (2 bits). The
indices are a byte array.
    The data range ``[0.0, 0.3]`` is divided linearly into four evenly spaced values, which are
    ``[0.0, 0.1, 0.2, 0.3]``.
* The LUT would be ``[0.0, 0.1, 0.2, 0.3]``.
* The weight is rounded to ``[0.1, 0.2, 0.3, 0.1, 0.0, 0.0]`` and represented in
the palette as indices ``[01b, 10b, 11b, 01b, 00b, 00b]``.
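
    The following is a minimal NumPy sketch (for illustration only, not the pass implementation)
    that reproduces the ``"uniform"`` example above:

    .. sourcecode:: python

        import numpy as np

        weight = np.array([0.11, 0.19, 0.3, 0.08, 0.0, 0.02])
        nbits = 2
        # Uniform (linear histogram) LUT: evenly spaced entries spanning the weight range.
        lut = np.linspace(weight.min(), weight.max(), 2**nbits)  # [0.0, 0.1, 0.2, 0.3]
        # Each weight element maps to the index of its nearest LUT entry.
        indices = np.abs(weight[:, None] - lut[None, :]).argmin(axis=1)  # [1, 2, 3, 1, 0, 0]
        reconstructed = lut[indices]  # [0.1, 0.2, 0.3, 0.1, 0.0, 0.0]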
Parameters
----------
mlmodel: MLModel
Model to be converted by a LUT. This MLModel should be of type ``mlprogram``.
config: OptimizationConfig
An :py:class:`OptimizationConfig` object that specifies the parameters for weight palettization.
    joint_compression: bool
        Whether to further compress an already-compressed input MLModel into a jointly
        compressed MLModel. See the ``palettize_weights`` graph pass for information
        about which compression schemas can be further jointly palettized.
        Take "prune + palettize" as an example of joint compression, where the input MLModel is
        already pruned, and the non-zero entries will be further palettized. In such a case, the
        weight values are represented by ``constexpr_lut_to_sparse`` + ``constexpr_sparse_to_dense`` ops:
        ``lut(sparse)`` -> ``constexpr_lut_to_sparse`` -> ``weight(sparse)`` -> ``constexpr_sparse_to_dense`` -> ``weight(dense)``
        See the joint compression sketch in the Examples section below.
Returns
-------
model: MLModel
The palettized MLModel instance.
Examples
--------
.. sourcecode:: python
import coremltools as ct
import coremltools.optimize as cto
model = ct.models.MLModel("my_model.mlpackage")
config = cto.coreml.OptimizationConfig(
global_config=cto.coreml.OpPalettizerConfig(mode="kmeans", nbits=4)
)
compressed_model = cto.coreml.palettize_weights(model, config)
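
    Continuing from the example above, a hypothetical joint compression sketch: ``pruned_model``
    is assumed to be an already-pruned MLModel (for example, the output of ``prune_weights``),
    and ``joint_compression=True`` then palettizes only its non-zero entries:

    .. sourcecode:: python

        joint_config = cto.coreml.OptimizationConfig(
            global_config=cto.coreml.OpPalettizerConfig(mode="kmeans", nbits=4)
        )
        # Palettize the non-zero entries of an already-pruned model (joint "prune + palettize").
        jointly_compressed_model = cto.coreml.palettize_weights(
            pruned_model, joint_config, joint_compression=True
        )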
"""
weight_palettizer = PASS_REGISTRY["compression::palettize_weights"]
weight_palettizer.set_options(
[PassOption("config", config), PassOption("joint_compression", joint_compression)]
)
return _model_utils._apply_graph_pass(mlmodel, weight_palettizer)
@_multifunction_unsupported
def prune_weights(
mlmodel: _model.MLModel, config: _OptimizationConfig, joint_compression: bool = False
):
"""
Utility function to convert a float precision MLModel of type ``mlprogram`` to a
compressed MLModel using sparse representation. The ``const`` ops storing weight
values are replaced by ``constexpr_sparse_to_dense`` ops.
This function is useful if the model is trained with pruning techniques so that
a lot of weights have zero values. If a large percentage of weight values are zero,
a sparse representation is more efficient than a dense one (the default).
The sparsified weights are stored in a bit mask. If the weight values are
``{0, 0, 0, 0, 0, 0, 0, 56.3}``, its sparse representation contains a bit mask with
ones on locations where the value is non-zero: ``00000001b``. This is accompanied by
non-zero data, which is a size-1 vector of value ``{56.3}``.
For example, given the following:
* ``weight = [0.3, 0, 0, 0.5, 0, 0]``
* ``non_zero_data, bit_mask = sparsify(weight)``
    The resulting non-zero values and bit mask are:
* ``non_zero_data = [0.3, 0.5]``
* ``bit_mask = "100100"``
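
    The following is a minimal NumPy sketch (for illustration only; ``sparsify`` above is
    pseudocode, not a real API) that reproduces these numbers:

    .. sourcecode:: python

        import numpy as np

        weight = np.array([0.3, 0, 0, 0.5, 0, 0])
        mask = weight != 0
        non_zero_data = weight[mask]  # [0.3, 0.5]
        bit_mask = "".join("1" if m else "0" for m in mask)  # "100100"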
Parameters
----------
mlmodel: MLModel
Model to be sparsified. This MLModel should be of type ``mlprogram``.
config: OptimizationConfig
An :py:class:`OptimizationConfig` object that specifies the parameters for weight pruning.
    joint_compression: bool
        Whether to further prune an already-compressed input MLModel into a jointly
        compressed MLModel. See the ``prune_weights`` graph pass for information
        about which compression schemas can be further pruned.
        Take "quantize + prune" as an example of joint compression, where the input MLModel is
        already quantized, and it will be further pruned. In such a case, the weight values are
        represented by ``constexpr_sparse_blockwise_shift_scale`` + ``constexpr_sparse_to_dense`` ops:
        ``quantized(sparse)`` -> ``constexpr_sparse_blockwise_shift_scale`` -> ``weight(sparse)`` -> ``constexpr_sparse_to_dense`` -> ``weight(dense)``
        See the joint compression sketch in the Examples section below.
Returns
-------
model: MLModel
The sparse MLModel instance.
Examples
--------
.. sourcecode:: python
import coremltools as ct
import coremltools.optimize as cto
model = ct.models.MLModel("my_model.mlpackage")
config = cto.coreml.OptimizationConfig(
global_config=cto.coreml.OpThresholdPrunerConfig(threshold=1e-12)
)
compressed_model = cto.coreml.prune_weights(model, config)
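
    Continuing from the example above, a hypothetical joint compression sketch:
    ``quantized_model`` is assumed to be an already-quantized MLModel (for example, the output of
    ``linear_quantize_weights``), and ``joint_compression=True`` then prunes the quantized weights:

    .. sourcecode:: python

        joint_config = cto.coreml.OptimizationConfig(
            global_config=cto.coreml.OpThresholdPrunerConfig(threshold=1e-12)
        )
        # Prune an already-quantized model (joint "quantize + prune").
        jointly_compressed_model = cto.coreml.prune_weights(
            quantized_model, joint_config, joint_compression=True
        )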
"""
weight_pruner = PASS_REGISTRY["compression::prune_weights"]
weight_pruner.set_options(
[PassOption("config", config), PassOption("joint_compression", joint_compression)]
)
return _model_utils._apply_graph_pass(mlmodel, weight_pruner)
@_multifunction_unsupported
def decompress_weights(mlmodel: _model.MLModel):
"""
    Utility function to convert weights that are sparse, palettized, or affine quantized back to
    the float format.
That is, convert any of the following three ops to ``mb.const``:
(1) ``constexpr_affine_dequantize``
(2) ``constexpr_lut_to_dense``
(3) ``constexpr_sparse_to_dense``
Parameters
----------
mlmodel: MLModel
Model which will be decompressed.
Returns
-------
model: MLModel
The MLModel with no ``constexpr`` ops included.
Examples
--------
.. sourcecode:: python
import coremltools as ct
model = ct.models.MLModel("my_compressed_model.mlpackage")
decompressed_model = ct.optimize.coreml.decompress_weights(model)
"""
weight_decompressor = _WeightDecompressor(op_selector=lambda op: True)
return _model_utils._apply_graph_pass(mlmodel, weight_decompressor)