Compression Configuration

class coremltools.optimize.coreml.OpLinearQuantizerConfig(mode: str = 'linear_symmetric', dtype: type = np.int8, weight_threshold: Optional[int] = 2048)[source]
Parameters:
mode: str

Mode for linear quantization:

  • "linear_symmetric" (default): Input data are quantized in the range [-R, R], where \(R = max(abs(w_r))\).

  • "linear": Input data are quantized in the range \([min(w_r), max(w_r)]\).

dtype: np.generic or mil.type type

Determines the quantized data type (int8/uint8).

  • The allowed values are:
    • np.int8 (the default)

    • np.uint8

    • coremltools.converters.mil.mil.types.int8

    • coremltools.converters.mil.mil.types.uint8

weight_threshold: int

The size threshold above which weights are quantized. That is, a weight tensor is quantized only if its total number of elements is greater than weight_threshold.

For example, if weight_threshold = 1024 and a weight tensor is of shape [10, 20, 1, 1], hence 200 elements, it will not be quantized.

  • If not provided, it will be set to 2048, so that only weights with more than 2048 elements are compressed.
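
The two mode values described above can be illustrated with a minimal NumPy sketch. This is not the coremltools implementation: the helper name quantize_int8, the choice of 127 as the symmetric integer bound, and the way the zero point is derived are assumptions made for illustration only.

```python
import numpy as np

def quantize_int8(w, mode="linear_symmetric"):
    """Hypothetical helper: returns (quantized int8 array, scale, zero_point)."""
    w = np.asarray(w, dtype=np.float64)
    if mode == "linear_symmetric":
        # Range [-R, R] with R = max(abs(w)); the zero point is fixed at 0.
        r = float(np.max(np.abs(w)))
        scale = r / 127.0 if r > 0 else 1.0
        zero_point = 0
        q = np.clip(np.round(w / scale), -127, 127)
    elif mode == "linear":
        # Range [min(w), max(w)] mapped onto the full int8 range [-128, 127].
        lo, hi = float(np.min(w)), float(np.max(w))
        scale = (hi - lo) / 255.0 if hi > lo else 1.0
        zero_point = int(np.round(-128 - lo / scale))
        q = np.clip(np.round(w / scale) + zero_point, -128, 127)
    else:
        raise ValueError(f"unknown mode: {mode}")
    return q.astype(np.int8), scale, zero_point

def dequantize(q, scale, zero_point):
    """Recover approximate float values from the quantized representation."""
    return (q.astype(np.float64) - zero_point) * scale

w = np.array([-1.0, 0.0, 0.5, 2.0])
q_sym, s_sym, z_sym = quantize_int8(w, "linear_symmetric")
q_lin, s_lin, z_lin = quantize_int8(w, "linear")
```

In this sketch the "linear" mode spends the whole integer range on [min(w), max(w)], so its scale is smaller than the symmetric one whenever the weights are not centered on zero.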

class coremltools.optimize.coreml.OpThresholdPrunerConfig(threshold: float = 1e-12, minimum_sparsity_percentile: float = 0.5, weight_threshold: Optional[int] = 2048)[source]

All weights with absolute value smaller than threshold are changed to 0, and the tensor is stored in a sparse format.

For example, given the following:

  • weight = [0.3, -0.2, -0.01, 0.05]

  • threshold = 0.03

The sparsified weight would be [0.3, -0.2, 0, 0.05].

Parameters:
threshold: float

All weight values with an absolute value below this threshold are set to 0.

  • Default value is 1e-12.

minimum_sparsity_percentile: float

The sparsity level must be above this value for the weight representation to be stored in the sparse format rather than the dense format.

For example, if minimum_sparsity_percentile = 0.6 and the sparsity level is 0.54 (that is, 54% of the weight values are exactly 0), then the resulting weight tensor will be stored as a dense const op, and not converted to the constexpr_sparse_to_dense op (which stores the weight values in a sparse format).

  • Must be a value between 0 and 1.

  • Default value is 0.5.

weight_threshold: int

The size threshold above which weights are pruned. That is, a weight tensor is pruned only if its total number of elements is greater than weight_threshold.

For example, if weight_threshold = 1024 and a weight tensor is of shape [10, 20, 1, 1], hence 200 elements, it will not be pruned.

  • If not provided, it will be set to 2048, so that only weights with more than 2048 elements are compressed.
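
The interaction between threshold and minimum_sparsity_percentile can be sketched in NumPy as follows. The helper threshold_prune is hypothetical and only mirrors the description above. Treating a sparsity level exactly equal to minimum_sparsity_percentile as sufficient is an assumption.

```python
import numpy as np

def threshold_prune(weight, threshold=1e-12, minimum_sparsity_percentile=0.5):
    """Hypothetical helper: returns (pruned weight, sparsity, store_sparse)."""
    w = np.asarray(weight, dtype=np.float64)
    # Values whose absolute value is smaller than the threshold become 0.
    pruned = np.where(np.abs(w) < threshold, 0.0, w)
    # Fraction of entries that are exactly 0 after pruning.
    sparsity = float(np.mean(pruned == 0))
    # Assumption: equality with the minimum counts as sparse enough.
    store_sparse = sparsity >= minimum_sparsity_percentile
    return pruned, sparsity, store_sparse

pruned, sparsity, store_sparse = threshold_prune(
    [0.3, -0.2, -0.01, 0.05], threshold=0.03
)
```

With the example weight from this section, only one of the four values is zeroed, so the sparsity level is 0.25 and the tensor would stay in the dense format under the default minimum of 0.5.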

class coremltools.optimize.coreml.OpMagnitudePrunerConfig(target_sparsity: Optional[float] = None, block_size: Optional[int] = None, n_m_ratio: Optional[Tuple[int, int]] = None, dim: Optional[int] = None, weight_threshold: Optional[int] = 2048)[source]

Prune the weight with a constant sparsity percentile, which can be specified by either target_sparsity or n_m_ratio.

If target_sparsity is set, where n = floor(size_of_weight_tensor * target_sparsity), the n lowest absolute weight values are changed to 0. For example, given the following:

  • weight = [0.3, -0.2, -0.01, 0.05]

  • target_sparsity = 0.75

The sparsified weight would be [0.3, 0, 0, 0].
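
The target_sparsity rule above can be reproduced with a short NumPy sketch. The function name magnitude_prune is hypothetical and this is an illustration of the arithmetic, not the library's code.

```python
import numpy as np

def magnitude_prune(weight, target_sparsity):
    """Hypothetical helper: zero the n lowest-magnitude values of a tensor."""
    w = np.asarray(weight, dtype=np.float64)
    flat = w.flatten()  # flatten() returns a copy, so the input is untouched
    # n = floor(size_of_weight_tensor * target_sparsity)
    n = int(np.floor(flat.size * target_sparsity))
    if n > 0:
        # Indices of the n smallest absolute values.
        lowest = np.argsort(np.abs(flat), kind="stable")[:n]
        flat[lowest] = 0.0
    return flat.reshape(w.shape)

result = magnitude_prune([0.3, -0.2, -0.01, 0.05], target_sparsity=0.75)
```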

If block_size is set, then weights are pruned in a block structured manner; that is, chunks of weight values, as big as the block_size, will be set to 0. Block sparsity can only be applied to linear and conv layers. For example:

# Given a 4 x 2 weight with the following value, and block_size = 2, dim = 0.
[
    [1, 3],
    [-6, -7],
    [0, 3],
    [-9, 2],
]

# We first flatten the matrix along axis = 0.
[1, -6, 0, -9, 3, -7, 3, 2]

# For block size 2, the L2 norm is computed over the first and second elements, then the third and fourth elements, and so on.
[6.08, 9.00, 7.62, 3.61]

# Then the blocks with the smallest norms are picked for pruning. So if target_sparsity = 0.5, the blocks that
# are pruned are the ones with L2 norm values of 6.08 and 3.61. Hence, the elements in the first and fourth
# blocks are pruned, resulting in the following flattened pruned tensor:
[0, 0, 0, -9, 3, -7, 0, 0]

# The final pruned tensor is:
[
    [0, 3],
    [0, -7],
    [0, 0],
    [-9, 0],
]
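
The block-structured walk-through above can be checked with a small NumPy sketch. The helper block_prune_2d is hypothetical, handles only 2-D weights, and assumes the size along dim is divisible by block_size, so the zero padding mentioned for block_size is not modeled.

```python
import numpy as np

def block_prune_2d(weight, target_sparsity, block_size, dim=0):
    """Hypothetical helper: prune whole blocks with the lowest L2 norms."""
    w = np.asarray(weight, dtype=np.float64)
    # Arrange the data so that consecutive elements run along `dim`.
    arranged = w.T if dim == 0 else w
    blocks = arranged.flatten().reshape(-1, block_size)
    # One L2 norm per block.
    norms = np.linalg.norm(blocks, axis=1)
    n_prune = int(np.floor(len(norms) * target_sparsity))
    # Zero the blocks with the smallest norms.
    blocks[np.argsort(norms, kind="stable")[:n_prune]] = 0.0
    restored = blocks.reshape(arranged.shape)
    return (restored.T if dim == 0 else restored), norms

weight = [[1, 3], [-6, -7], [0, 3], [-9, 2]]
pruned, norms = block_prune_2d(weight, target_sparsity=0.5, block_size=2, dim=0)
```

The norms come out as roughly [6.08, 9.00, 7.62, 3.61], and the two smallest belong to the first and fourth blocks, which matches the final pruned tensor shown above.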

The n_m_ratio triggers n:m pruning along the dim axis. In n:m pruning, out of every m elements, n with lowest magnitude are set to 0. For more information, see Learning N:M Fine-Grained Structured Sparse Neural Networks From Scratch.

n:m pruning can be applied only to linear and conv layers.

Example

# Given a 4 x 4 weight of
[
    [3, 4, 7, 6],
    [1, 8, -3, -8],
    [-2, -3, -4, 0],
    [5, 4, -3, -2],
]

# For n_m_ratio = (1, 2) with dim = 1 (default), the resulting pruned weight is
[
    [0, 4, 7, 0],
    [0, 8, 0, -8],
    [0, -3, -4, 0],
    [5, 0, -3, 0],
]

# For dim = 0, we get
[
    [3, 0, 7, 0],
    [0, 8, 0, -8],
    [0, 0, -4, 0],
    [5, 4, 0, -2],
]
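
The n:m example above can be reproduced with the following NumPy sketch. The helper prune_n_m is hypothetical and assumes the size along dim is divisible by m, so padding is not modeled. It is an illustration of the rule, not the coremltools implementation.

```python
import numpy as np

def prune_n_m(weight, n, m, dim=1):
    """Hypothetical helper: in every group of m values along `dim`,
    set the n values with the lowest magnitude to 0."""
    w = np.asarray(weight, dtype=np.float64)
    # Move the pruning axis last so groups are contiguous after reshaping.
    moved = np.moveaxis(w, dim, -1)
    groups = moved.reshape(-1, m).copy()
    # Positions of the n smallest magnitudes inside each group.
    lowest = np.argsort(np.abs(groups), axis=1, kind="stable")[:, :n]
    np.put_along_axis(groups, lowest, 0.0, axis=1)
    return np.moveaxis(groups.reshape(moved.shape), -1, dim)

weight = [
    [3, 4, 7, 6],
    [1, 8, -3, -8],
    [-2, -3, -4, 0],
    [5, 4, -3, -2],
]
pruned_dim1 = prune_n_m(weight, n=1, m=2, dim=1)
pruned_dim0 = prune_n_m(weight, n=1, m=2, dim=0)
```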
Parameters:
target_sparsity: float

The percentage of sparsity for compression, which needs to be in the range [0, 1]. When 0, no sparsification occurs. For 1, all weights become 0.

block_size: int

Block size for inducing block sparsity. This is applied on the dim dimension of the parameter. Having the zeros aligned in the parameter helps gain latency/memory performance on-device.

  • If set, must be greater than 1 to enable block sparsity.

  • Block sparsity can be applied only to linear and conv layers.

  • The channel will be padded with 0 if it is not divisible by block_size.

n_m_ratio: tuple[int]

A tuple of two integers which specify the ratio for n:m pruning.

  • n must be smaller than or equal to m.

  • The channel would be padded with 0 if it is not divisible by m.

dim: int

Dimension where the block sparsity or n:m sparsity is applied.

  • Must be either 0 or 1.

  • The default value for block sparsity is 0 (output channel).

  • The default value for n:m sparsity is 1 (input channel).

weight_threshold: int

The size threshold, above which weights are pruned. That is, a weight tensor is pruned only if its total number of elements is greater than weight_threshold.

For example, if weight_threshold = 1024 and a weight tensor is of shape [10, 20, 1, 1], hence 200 elements, it will not be pruned.

  • If not provided, it will be set to 2048, so that only weights with more than 2048 elements are compressed.
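
The weight_threshold check is a plain element count comparison, which can be sketched as follows. The helper exceeds_weight_threshold is hypothetical and only restates the rule above.

```python
import numpy as np

def exceeds_weight_threshold(shape, weight_threshold=2048):
    """Hypothetical helper: True if a tensor of this shape would be compressed."""
    # The tensor qualifies only if its element count is strictly greater.
    return int(np.prod(shape)) > weight_threshold

small = exceeds_weight_threshold([10, 20, 1, 1], weight_threshold=1024)  # 200 elements
large = exceeds_weight_threshold([64, 64, 3, 3])  # 36864 elements, default threshold
```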

class coremltools.optimize.coreml.OpPalettizerConfig(mode: str = 'kmeans', nbits: Optional[int] = None, lut_function: Optional[Callable] = None, weight_threshold: Optional[int] = 2048)[source]
Parameters:
nbits: int

Number of bits per weight. Required for kmeans or uniform mode, but must not be set for unique or custom mode. A LUT would have 2^nbits entries, where nbits can be one of {1, 2, 4, 6, 8}.

mode: str

Determine how the LUT is constructed by specifying one of the following:

  • "kmeans" (default): The LUT is generated by k-means clustering, a method of vector quantization that groups similar data points together to discover underlying patterns by using a fixed number (k) of clusters in a dataset. A cluster refers to a collection of data points aggregated together because of certain similarities. nbits is required.

  • "uniform": The LUT is generated by a linear histogram.

    • [v_min, v_min + scale, v_min + 2 * scale, ..., v_max]

    • Where the weight is in the range [v_min, v_max], and scale = (v_max - v_min) / ((1 << nbits) - 1).

    • nbits is required.

    A histogram is a representation of the distribution of a continuous variable, in which the entire range of values is divided into a series of intervals (or bins) and the representation displays how many values fall into each bin. Linear histograms have one bin at even intervals, such as one bin per integer.

  • "unique": The LUT is generated by unique values in the weights. The weights are assumed to be on a discrete lattice but stored in a float data type. This parameter identifies the weights and converts them into the palettized representation.

    Do not provide nbits for this mode. nbits is picked up automatically, with the smallest possible value in {1, 2, 4, 6, 8} such that the number of the unique values is <= (1 << nbits). If the weight has > 256 unique values, the compression is skipped.

    For example:

    • If the weights are {0.1, 0.2, 0.3, 0.4} and nbits=2, the weights are converted to {00b, 01b, 10b, 11b}, and the generated LUT is [0.1, 0.2, 0.3, 0.4].

    • If the weights are {0.1, 0.2, 0.3, 0.4} and nbits=1, nothing happens because the weights are not a 1-bit lattice.

    • If the weights are {0.1, 0.2, 0.3, 0.4, 0.5} and nbits=2, nothing happens because the weights are not a 2-bit lattice.

  • "custom": The LUT and palettization parameters are calculated using a custom function. If this mode is selected then lut_function must be provided.

    Do not provide nbits for this mode. The user should customize nbits in the lut_function implementation.

lut_function: callable

A callable function which computes the weight palettization parameters. This must be provided if the mode is set to "custom".

weight: np.ndarray

A float precision numpy array.

Returns: lut: list[float]

The lookup table.

indices: list[int]

A list of indices for each element.

The following is an example that extracts the top_k elements as the LUT. Given that weight = [0.1, 0.5, 0.3, 0.3, 0.5, 0.6, 0.7], the lut_function produces lut = [0, 0.5, 0.6, 0.7], indices = [0, 1, 0, 0, 1, 2, 3].

import numpy as np

def lut_function(weight):
    # In this example, we assume all elements in the weight are >= 0
    weight = weight.flatten()
    # nbits = 2 gives a LUT with 4 entries
    nbits = 2

    # Build the LUT from the top k unique values in the weight.
    # Note that k = (1 << nbits) - 1, which leaves the first LUT entry for 0.
    unique_elements = np.unique(weight)  # np.unique returns sorted values
    k = (1 << nbits) - 1
    top_k = unique_elements[-k:]
    lut = [0.0] + top_k.tolist()

    # Compute the indices; values that are not in the LUT map to entry 0
    mapping = {v: idx for idx, v in enumerate(lut)}
    indices = [mapping[v] if v in mapping else 0 for v in weight]

    return lut, indices
weight_threshold: int

The size threshold above which weights are palettized. That is, a weight tensor is palettized only if its total number of elements is greater than weight_threshold.

For example, if weight_threshold = 1024 and a weight tensor is of shape [10, 20, 1, 1], hence 200 elements, it will not be palettized.

  • If not provided, it will be set to 2048, so that only weights with more than 2048 elements are compressed.
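
The "uniform" LUT layout and the automatic nbits selection described for the "unique" mode can be sketched in NumPy. Both helpers, uniform_lut and choose_nbits_for_unique, are hypothetical names used for illustration; they follow the formulas stated above and are not taken from the coremltools source.

```python
import numpy as np

ALLOWED_NBITS = (1, 2, 4, 6, 8)

def uniform_lut(weight, nbits):
    """Hypothetical helper: evenly spaced LUT from v_min to v_max."""
    w = np.asarray(weight, dtype=np.float64)
    v_min, v_max = float(w.min()), float(w.max())
    # 2^nbits entries need (2^nbits - 1) equal steps between v_min and v_max.
    scale = (v_max - v_min) / ((1 << nbits) - 1)
    return v_min + scale * np.arange(1 << nbits)

def choose_nbits_for_unique(weight):
    """Hypothetical helper: smallest allowed nbits that can hold every
    unique value, or None when there are more than 256 unique values."""
    n_unique = len(np.unique(weight))
    for nbits in ALLOWED_NBITS:
        if n_unique <= (1 << nbits):
            return nbits
    return None

lut = uniform_lut([0.0, 1.2, 3.0], nbits=2)
nbits_small = choose_nbits_for_unique([0.1, 0.2, 0.3, 0.4])
nbits_skip = choose_nbits_for_unique(np.arange(300))
```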

class coremltools.optimize.coreml.OptimizationConfig(global_config: Optional[OpCompressorConfig] = None, op_type_configs: Optional[OpCompressorConfig] = None, op_name_configs: Optional[OpCompressorConfig] = None, is_deprecated: bool = False, op_selector: Optional[Callable] = None)[source]

A configuration wrapper that enables fine-grained control when compressing a model, providing the following levels: global, op type, and op name.

  1. global_config: The default configuration applied to all ops / consts.

  2. op_type_configs: Configurations applied to a specific op type. It overrides global_config.

  3. op_name_configs: Configurations applied to specific constant or op instance. It overrides global_config and op_type_configs.

The following is an example that constructs an optimization config for weight palettization.

from coremltools.optimize.coreml import OpPalettizerConfig, OptimizationConfig

# The default global configuration is 8 bits palettization with kmeans
global_config = OpPalettizerConfig(mode="kmeans", nbits=8)

# We use 2 bits palettization for convolution layers, and skip the compression for linear layers
op_type_configs = {
    "conv": OpPalettizerConfig(mode="kmeans", nbits=2),
    "linear": None,
}

# We want a convolution layer named "conv_1" to have a 4 bits palettization with a different mode
op_name_configs = {
    "conv_1": OpPalettizerConfig(mode="uniform", nbits=4),
}

# Now we can put all configuration across three levels to construct an OptimizationConfig object
config = OptimizationConfig(
    global_config=global_config,
    op_type_configs=op_type_configs,
    op_name_configs=op_name_configs,
)
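
The override order of the three levels (op name over op type over global) can be illustrated with plain dictionaries. The function resolve_config is a hypothetical stand-in that uses strings in place of OpCompressorConfig objects. It shows the precedence described above and is not how coremltools resolves configs internally.

```python
def resolve_config(op_name, op_type, global_config=None,
                   op_type_configs=None, op_name_configs=None):
    """Hypothetical helper: pick the config for one op by precedence."""
    op_type_configs = op_type_configs or {}
    op_name_configs = op_name_configs or {}
    # The most specific level wins: op name, then op type, then global.
    if op_name in op_name_configs:
        return op_name_configs[op_name]
    if op_type in op_type_configs:
        return op_type_configs[op_type]  # may be None, meaning "skip"
    return global_config

levels = dict(
    global_config="kmeans-8bit",
    op_type_configs={"conv": "kmeans-2bit", "linear": None},
    op_name_configs={"conv_1": "uniform-4bit"},
)
conv_1 = resolve_config("conv_1", "conv", **levels)
conv_2 = resolve_config("conv_2", "conv", **levels)
linear_1 = resolve_config("linear_1", "linear", **levels)
matmul_1 = resolve_config("matmul_1", "matmul", **levels)
```

Here "conv_1" takes the name-level entry, other convolutions take the type-level entry, linear layers resolve to None and are skipped, and everything else falls back to the global entry.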
Parameters:
global_config: OpCompressorConfig

Config to be applied globally to all supported ops.

op_type_configs: dict[str, OpCompressorConfig]

Op type level configs applied to a specific op class.

  • The keys of the dictionary are the string of the op type, and the values are the corresponding OpCompressorConfig.

  • An op type will not be compressed if the value is set to None.

op_name_configs: dict[str, OpCompressorConfig]

Op instance level configs applied to a specific constant or op.

  • The keys of the dictionary are the name of a constant or an op instance, and the values are the corresponding OpCompressorConfig.

  • An op instance will not be compressed if the value is set to None.

  • You can use coremltools.optimize.coreml.get_weights_metadata to get the name of the constants / op instances in the model.

classmethod from_dict(config_dict: Dict[str, Any]) OptimizationConfig[source]

Construct an OptimizationConfig instance from a nested dictionary. The dictionary may contain only the following four str keys, each of which is optional:

  • "config_type": Specify the configuration class type.

  • "global_config": Parameters for global_config.

  • "op_type_configs": A nested dictionary for op_type_configs.

  • "op_name_configs": A nested dictionary for op_name_configs.

The following is a nested dictionary that creates an optimization config for weight palettization:

config_dict = {
    "config_type": "OpPalettizerConfig",
    "global_config": {
        "mode": "kmeans",
        "nbits": 4,
    },
    "op_type_configs": {
        "conv": {
            "mode": "uniform",
            "nbits": 1,
        }
    },
    "op_name_configs": {
        "conv_1": {
            "mode": "unique",
        }
    },
}

Note that you can override the config_type. For instance, to apply threshold-based pruning to the whole model while applying magnitude pruning to the convolution layers, you can use the following nested dictionary:

config_dict = {
    "config_type": "OpThresholdPrunerConfig",
    "global_config": {
        "threshold": 0.01,
    },
    "op_type_configs": {
        "conv": {
            "config_type": "OpMagnitudePrunerConfig",
            "n_m_ratio": [3, 4],
        }
    },
}
Parameters:
config_dict: dict[str, Any]

A dictionary that represents the configuration structure.

classmethod from_yaml(yml: Union[IO, str]) OptimizationConfig[source]

Construct an OptimizationConfig instance from a YAML file. The YAML file may contain only the following four str keys, each of which is optional:

  • "config_type": Specify the configuration class type.

  • "global_config": Parameters for global_config.

  • "op_type_configs": A nested dictionary for op_type_configs.

  • "op_name_configs": A nested dictionary for op_name_configs.

The following is a YAML file that creates an optimization config for weight palettization:

config_type: OpPalettizerConfig
global_config:
    mode: kmeans
    nbits: 4
op_type_configs:
    conv:
        mode: uniform
        nbits: 1
op_name_configs:
    conv_1:
        mode: unique

Note that you can override the config_type. For instance, to apply threshold-based pruning to the whole model while applying magnitude pruning to the convolution layers, you can use the following YAML file:

config_type: OpThresholdPrunerConfig
global_config:
    threshold: 0.01
op_type_configs:
    conv:
        config_type: OpMagnitudePrunerConfig
        n_m_ratio: [3, 4]
Parameters:
yml: str, IO

A YAML file or the path to the file.

set_global(op_config: OpCompressorConfig)[source]

Sets the global config that is applied to all constant ops.

from coremltools.optimize.coreml import OpPalettizerConfig, OptimizationConfig

config = OptimizationConfig()
global_config = OpPalettizerConfig(mode="kmeans", nbits=8)
config.set_global(global_config)
Parameters:
op_config: OpCompressorConfig

Config to be applied globally to all supported ops.

set_op_name(op_name: str, op_config: OpCompressorConfig)[source]

Sets the compression config at the level of constant / op instance by name.

from coremltools.optimize.coreml import OpPalettizerConfig, OptimizationConfig

config = OptimizationConfig()
op_config = OpPalettizerConfig(mode="kmeans", nbits=2)
config.set_op_name("conv_1", op_config)

To get the name of a constant or an op instance, use the coremltools.optimize.coreml.get_weights_metadata API.

Parameters:
op_name: str

The name of a constant or an op instance.

op_config: OpCompressorConfig

Op instance level config applied to a specific constant or op with name op_name.

set_op_type(op_type: str, op_config: OpCompressorConfig)[source]

Sets the compression config at the level of op type.

from coremltools.optimize.coreml import OpPalettizerConfig, OptimizationConfig

config = OptimizationConfig()
conv_config = OpPalettizerConfig(mode="kmeans", nbits=2)
config.set_op_type("conv", conv_config)
Parameters:
op_type: str

The type of an op. For instance, "conv", "linear".

op_config: OpCompressorConfig

Op type level config applied to a specific op class op_type.