Training-Time Palettization

Palettization is a mechanism for compressing a model by clustering the model’s float weights into a look-up table (LUT) of centroids and indices.

Palettization is implemented as an extension of PyTorch's quantization-aware training (QAT) APIs. It works by inserting palettization layers at appropriate places in a model. The model can then be fine-tuned to learn the palettized layers' weights in the form of a LUT and indices.
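To make the LUT-and-indices representation concrete, here is a minimal pure-Python sketch that palettizes a flat list of scalar weights with ordinary hard-assignment k-means. The function name `palettize` and the evenly spaced initialization are illustrative assumptions. This is neither the DKM algorithm nor part of the coremltools API.

```python
from typing import List, Tuple


def palettize(weights: List[float], n_bits: int, iters: int = 10) -> Tuple[List[float], List[int]]:
    """Cluster scalar weights into a LUT of 2**n_bits centroids plus one index per weight.

    Plain hard-assignment 1-D k-means, used only to illustrate the LUT + indices
    representation; it is not the differentiable DKM algorithm.
    """
    k = 2 ** n_bits
    ordered = sorted(weights)
    # Deterministic initialization: pick k evenly spaced values from the sorted weights.
    lut = [ordered[(i * (len(ordered) - 1)) // max(k - 1, 1)] for i in range(k)]
    indices = [0] * len(weights)
    for _ in range(iters):
        # Assignment step: each weight points at its nearest centroid.
        indices = [min(range(k), key=lambda j: abs(w - lut[j])) for w in weights]
        # Update step: each centroid moves to the mean of the weights assigned to it.
        for j in range(k):
            members = [w for w, idx in zip(weights, indices) if idx == j]
            if members:
                lut[j] = sum(members) / len(members)
    return lut, indices


weights = [0.9, 1.1, -1.0, -0.8, 0.2, -0.1, 1.0, -0.9]
lut, indices = palettize(weights, n_bits=1)
# The compressed weights are rebuilt by looking each index up in the LUT.
reconstructed = [lut[i] for i in indices]
```

With `n_bits=1` the eight weights collapse onto two centroids (about -0.7 and 0.8), and each weight is stored as a 1-bit index into the LUT.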


class coremltools.optimize.torch.palettization.ModuleDKMPalettizerConfig(n_bits: Optional[int] = None, weight_threshold: int = 2048, milestone: int = 0, cluster_dim: Optional[int] = None, quant_min: int = -128, quant_max: int = 127, dtype: Union[str, dtype] = torch.qint8, cluster_dtype: str = 'f32', quantize_activations: bool = False, partition_size: int = 2000000000, cluster_permute: Optional[tuple] = None, palett_max_mem: float = 1.0, kmeans_max_iter: int = 3, prune_threshold: float = 0.0, kmeans_init: str = 'cpu.kmeans++', kmeans_opt1d_threshold: int = 1024, enforce_zero: bool = False, palett_mode: str = 'dkm', palett_tau: float = 0.0001, palett_epsilon: float = 0.0001, palett_lambda: float = 0.0, add_extra_centroid: bool = False, palett_cluster_tol: float = 0.05)

Configuration class for specifying global and module-level options for the palettization algorithm implemented in DKMPalettizer.

The parameters specified in this config control the DKM algorithm, described in DKM: Differentiable K-Means Clustering Layer for Neural Network Compression.
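The core idea of DKM can be sketched for scalar weights: instead of a hard nearest-centroid assignment, each weight attends to every centroid through a softmax over negative distances divided by a temperature, and each centroid is updated as the attention-weighted mean of the weights. Because every step is differentiable, a framework such as PyTorch can backpropagate through it; this sketch omits gradients entirely. It is a simplified, hypothetical illustration (absolute distance, `cluster_dim = 1`, a caller-supplied initial LUT). Its `tau`, `epsilon`, and `max_iter` arguments only loosely mirror `palett_tau`, `palett_epsilon`, and `kmeans_max_iter`, and it does not reproduce the coremltools implementation.

```python
import math
from typing import List


def dkm_cluster(weights: List[float], lut: List[float], tau: float,
                max_iter: int = 3, epsilon: float = 1e-4) -> List[float]:
    """Soft k-means in the spirit of DKM for scalar weights (cluster_dim = 1).

    Each weight attends to every centroid with softmax(-|w - c| / tau), and each
    centroid becomes the attention-weighted mean of all weights. Iteration stops
    after max_iter rounds or once no centroid moves by epsilon or more.
    """
    lut = list(lut)
    for _ in range(max_iter):
        # attention[i][j]: soft assignment of weight i to centroid j (rows sum to 1).
        attention = []
        for w in weights:
            logits = [-abs(w - c) / tau for c in lut]
            peak = max(logits)  # subtract the max so exp() cannot overflow
            exps = [math.exp(z - peak) for z in logits]
            total = sum(exps)
            attention.append([e / total for e in exps])
        # Centroid update: attention-weighted average of the weights.
        new_lut = []
        for j, old in enumerate(lut):
            mass = sum(row[j] for row in attention)
            weighted = sum(row[j] * w for row, w in zip(attention, weights))
            new_lut.append(weighted / mass if mass > 0 else old)
        moved = max(abs(a - b) for a, b in zip(new_lut, lut))
        lut = new_lut
        if moved < epsilon:
            break
    return lut


weights = [-1.0, -0.8, 0.9, 1.1]
hard = dkm_cluster(weights, [-0.5, 0.5], tau=1e-4)  # tiny temperature: near-hard assignment
soft = dkm_cluster(weights, [-0.5, 0.5], tau=10.0)  # large temperature: centroids pulled together
```

A small temperature, such as the default `palett_tau` of 0.0001, makes the assignment nearly hard (here the centroids settle at about -0.9 and 1.0), while a large temperature pulls the centroids toward each other.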

For most use cases, the only parameters you need to specify are n_bits, weight_threshold, and milestone.


Most of the parameters in this class are meant for advanced use cases and for further fine-tuning the DKM algorithm. The default values usually work for a majority of tasks.


Change the following parameters only when you use activation quantization in conjunction with DKM weight palettization: quant_min, quant_max, dtype, and quantize_activations.

  • n_bits (int) – Number of palettization bits. The number of clusters used is \(2^{n\_bits}\). Defaults to 4 for linear layers and 2 for all other layers.

  • weight_threshold (int) – A module is only palettized if the number of elements in its weight matrix exceeds weight_threshold. Defaults to 2048.

  • milestone (int) – Step or epoch at which palettization begins. Defaults to 0.

  • cluster_dim (int, optional) – The dimension of each cluster. Defaults to 1.

  • quant_min (int, optional) – The minimum value for each element in the weight clusters if they are quantized. Defaults to -128.

  • quant_max (int, optional) – The maximum value for each element in the weight clusters if they are quantized. Defaults to 127.

  • dtype (torch.dtype, optional) – The dtype to use for quantizing the activations. Only applies when quantize_activations is True. Defaults to torch.qint8.

  • cluster_dtype (str, optional) – dtype to use for quantizing the clusters. Allowed options are 'i8', 'u8', 'f16', 'bf16', 'f32'. Defaults to 'f32', i.e., by default, the clusters aren’t quantized.

  • quantize_activations (bool, optional) – When True, the activations are quantized. Defaults to False.

  • partition_size (int, optional) – Helps with per-channel palettization. Defaults to 2000000000.

  • cluster_permute (tuple, optional) – Permutation order to apply to weight partitions. Defaults to None.

  • palett_max_mem (float, optional) – Proportion of available GPU memory that should be used for palettization. Defaults to 1.0.

  • kmeans_max_iter (int, optional) – Maximum number of differentiable k-means iterations. Defaults to 3.

  • prune_threshold (float, optional) – Hard-shrinks weights between [-prune_threshold, prune_threshold] to zero. Useful for joint pruning and palettization. Defaults to 0.0.

  • kmeans_init (str, optional) – k-means algorithm to use. Allowed options are efficient_kmeans, cpu.kmeans++ and kmeans_pp. Defaults to cpu.kmeans++.

  • kmeans_opt1d_threshold (int, optional) – Channel threshold to decide if opt1d kmeans should be used. Defaults to 1024.

  • enforce_zero (bool, optional) – If True, the LUT centroid closest to the origin is fixed at zero. Defaults to False.

  • palett_mode (str, optional) – Criteria to calculate attention during k-means. Allowed options are gsm, dkm and hard. Defaults to dkm.

  • palett_tau (float, optional) – Temperature factor for softmax used in DKM algorithm. Defaults to 0.0001.

  • palett_epsilon (float, optional) – Distance threshold for clusters between k-means iterations. Defaults to 0.0001.

  • palett_lambda (float, optional) – Reduces effects of outliers during centroid calculation. Defaults to 0.0.

  • add_extra_centroid (bool, optional) – If True, adds an extra centroid to the LUT. Defaults to False.

  • palett_cluster_tol (float, optional) – Tolerance for non-unique centroids in the LUT. The higher the number, the more tolerance for non-unique centroids. Defaults to 0.05.
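As a rough, back-of-the-envelope illustration of how `weight_threshold`, `n_bits`, and `cluster_dim` interact, the helpers below decide whether a weight qualifies for palettization and estimate its palettized storage. The function names are hypothetical, and the storage model (one `n_bits` index per group of `cluster_dim` elements, plus a LUT of \(2^{n\_bits}\) centroids stored as 32-bit floats, as with `cluster_dtype='f32'`) is an assumption that ignores padding and metadata.

```python
import math


def is_palettized(weight_numel: int, weight_threshold: int = 2048) -> bool:
    # A module qualifies only if its weight has MORE elements than weight_threshold.
    return weight_numel > weight_threshold


def palettized_size_bits(weight_numel: int, n_bits: int, cluster_dim: int = 1,
                         lut_bits_per_value: int = 32) -> int:
    """Rough storage estimate, in bits, for one palettized weight tensor."""
    num_groups = math.ceil(weight_numel / cluster_dim)    # one index per cluster_dim elements
    index_bits = num_groups * n_bits                       # each index needs n_bits bits
    lut_bits = (2 ** n_bits) * cluster_dim * lut_bits_per_value
    return index_bits + lut_bits


weight_numel = 512 * 512          # e.g. the weight of a 512 x 512 linear layer
dense_bits = weight_numel * 32    # float32 storage
palettized_bits = palettized_size_bits(weight_numel, n_bits=4)
ratio = dense_bits / palettized_bits
```

For this 262,144-element weight, 4-bit palettization needs about 1.05 million bits versus about 8.39 million for float32, roughly an 8× reduction.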

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → ModuleOptimizationConfig

Create class from a dictionary of string keys and values.


config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: Union[IO, str]) → ModuleOptimizationConfig

Create class from a yaml stream.


yml – An IO stream containing yaml or a str path to the yaml file.

class coremltools.optimize.torch.palettization.DKMPalettizerConfig(global_config: Optional[GlobalConfigType] = None, module_type_configs: ModuleTypeConfigType = _Nothing.NOTHING, module_name_configs: ModuleNameConfigType = _Nothing.NOTHING)

Configuration for specifying how different submodules of a model are palettized by DKMPalettizer.

The module_type_configs parameter can accept a list of ModuleDKMPalettizerConfig as values for a given module type. The list can specify different parameters for different weight_threshold values. This is useful if you want to apply different configs to layers of the same type with weights of different sizes.

For example, to use 4-bit palettization for weights with more than 1000 elements and 2-bit palettization for weights with more than 300 but fewer than 1000 elements, create a config as follows:

import torch.nn as nn
from coremltools.optimize.torch.palettization import DKMPalettizerConfig

custom_config = {
    nn.Conv2d: [
        {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
        {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
    ]
}
config = DKMPalettizerConfig.from_dict({"module_type_configs": custom_config})

  • global_config (ModuleDKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleDKMPalettizerConfig) – Module-type-level configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes. If the config for a module type is set to None, modules of that type are not palettized.

  • module_name_configs (dict of str to ModuleDKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top-level module using the module.get_submodule(target) method. If the config for a module name is set to None, that module is not palettized.
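The lookup implied by these parameters, together with the per-type list keyed on `weight_threshold` described above, can be sketched as a small resolver. It assumes that a module-name config takes precedence over a module-type config, which in turn takes precedence over the global config, and that from a list the config with the largest `weight_threshold` still exceeded by the weight size wins. Both rules, the function names, and the use of string module-type keys are assumptions for illustration, not the coremltools internals.

```python
from typing import Dict, List, Optional, Union

ConfigDict = Dict[str, object]
DEFAULT_WEIGHT_THRESHOLD = 2048  # documented default of weight_threshold


def pick_by_threshold(configs: Union[ConfigDict, List[ConfigDict]],
                      weight_numel: int) -> Optional[ConfigDict]:
    """Return the config with the largest weight_threshold that weight_numel exceeds."""
    if isinstance(configs, dict):
        configs = [configs]
    eligible = [c for c in configs
                if weight_numel > c.get("weight_threshold", DEFAULT_WEIGHT_THRESHOLD)]
    return max(eligible,
               key=lambda c: c.get("weight_threshold", DEFAULT_WEIGHT_THRESHOLD),
               default=None)


def resolve_config(module_name: str, module_type: str, weight_numel: int,
                   global_config: Optional[ConfigDict] = None,
                   module_type_configs: Optional[dict] = None,
                   module_name_configs: Optional[dict] = None) -> Optional[ConfigDict]:
    """Hypothetical lookup order: module name, then module type, then global config."""
    name_cfgs = module_name_configs or {}
    type_cfgs = module_type_configs or {}
    if module_name in name_cfgs:
        chosen = name_cfgs[module_name]
    elif module_type in type_cfgs:
        chosen = type_cfgs[module_type]
    else:
        chosen = global_config
    if chosen is None:
        return None  # explicitly disabled, or nothing configured: not palettized
    return pick_by_threshold(chosen, weight_numel)


custom_config = {
    "Conv2d": [
        {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
        {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
    ]
}
big = resolve_config("features.0", "Conv2d", 4608, module_type_configs=custom_config)
small = resolve_config("features.3", "Conv2d", 576, module_type_configs=custom_config)
```

Under these assumptions, a 4608-element convolution weight gets the 4-bit entry and a 576-element one gets the 2-bit entry.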

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → DKMPalettizerConfig

Create class from a dictionary of string keys and values.


config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: Union[IO, str]) → OptimizationConfig

Create class from a yaml stream.


yml – An IO stream containing yaml or a str path to the yaml file.

set_global(global_config: Optional[ModuleOptimizationConfig]) → OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: Optional[ModuleOptimizationConfig]) → OptimizationConfig

Set the module-level optimization config for a given module instance. If a module-level optimization config was already set for that module, the new config overrides the old one.

set_module_type(object_type: Union[Callable, str], opt_config: Optional[ModuleOptimizationConfig]) → OptimizationConfig

Set the module-level optimization config for a given module type. If a module-level optimization config was already set for that module type, the new config overrides the old one.

class coremltools.optimize.torch.palettization.DKMPalettizer(model: Module, config: Optional[DKMPalettizerConfig] = None)

A palettization algorithm based on “DKM: Differentiable K-Means Clustering Layer for Neural Network Compression”. It clusters the weights using a differentiable version of k-means, allowing the look-up-table (LUT) and indices of palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.


import torch
from coremltools.optimize.torch.palettization import (
    DKMPalettizer,
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# code that defines the PyTorch model, loss, and optimizer
model, loss_fn, optimizer = create_model_loss_and_optimizer()

# initialize the palettizer
config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))

palettizer = DKMPalettizer(model, config)

# prepare the model to insert FakePalettize layers for palettization
model = palettizer.prepare(inplace=True)

# use palettizer in your PyTorch training loop
# (data is an iterable of (inputs, labels) batches)
for inputs, labels in data:
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    palettizer.step()

# fold LUT and indices into weights
model = palettizer.finalize(inplace=True)

  • model (torch.nn.Module) – Model on which the palettizer will act.

  • config (DKMPalettizerConfig) – Config which specifies how different submodules in the model will be configured for palettization. Default config is used when passed as None.

finalize(model: Optional[Module] = None, inplace: bool = False) → Module

Removes FakePalettize layers from a model and creates new model weights from the LUT and indices buffers.

This function is called to prepare a palettized model for export using coremltools.

  • model (nn.Module) – model to finalize.

  • inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.
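Conceptually, finalizing replaces each index with the centroid it points to. The sketch below folds a LUT whose centroids hold `cluster_dim` values each back into a flat weight list. The function name is hypothetical, and it assumes consecutive elements form one cluster; the actual grouping depends on the weight layout and on options such as `cluster_permute`.

```python
from typing import List, Sequence


def fold_lut(lut: Sequence[Sequence[float]], indices: Sequence[int]) -> List[float]:
    """Rebuild a flat weight list from a LUT of cluster_dim-sized centroids.

    Each index selects one centroid, whose cluster_dim values are written out in
    order, so len(result) == len(indices) * cluster_dim.
    """
    weights: List[float] = []
    for idx in indices:
        weights.extend(lut[idx])
    return weights


# 2**2 = 4 centroids (n_bits = 2), each with cluster_dim = 2 values
lut = [[0.0, 0.0], [0.5, -0.5], [1.0, 1.0], [-1.0, 0.25]]
indices = [2, 0, 1, 3]
dense = fold_lut(lut, indices)
```

Here four 2-bit indices expand back into eight dense weights.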

prepare(inplace: bool = False) → Module

Prepares a model for palettization-aware training by inserting FakePalettize layers in appropriate places as specified by the config.


inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.

report() → _Report

Returns a dictionary with important statistics related to current state of palettization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing the statistics, such as number of clusters and cluster dimension, number of parameters, and so on.


step()

Step through the palettizer. When the number of times step is called is equal to milestone, palettization is enabled.
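Read literally, the milestone rule means palettization switches on once `step()` has been called `milestone` times, so whether `milestone` counts training steps or epochs depends on where you call `step()`. The toy class below models only that rule; the class name is hypothetical, and the library's exact off-by-one behavior may differ.

```python
class MilestoneCounter:
    """Toy model of the documented milestone rule (not coremltools code)."""

    def __init__(self, milestone: int = 0):
        self.milestone = milestone
        self.num_steps = 0

    @property
    def palettization_enabled(self) -> bool:
        # On once step() has been called `milestone` times (immediately if milestone == 0).
        return self.num_steps >= self.milestone

    def step(self) -> None:
        self.num_steps += 1


counter = MilestoneCounter(milestone=2)
history = []
for _ in range(3):  # e.g. three training steps, calling step() once per step
    counter.step()
    history.append(counter.palettization_enabled)
```

With `milestone=2`, palettization is off after the first call and on from the second call onward.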

Palettization Layers

class coremltools.optimize.torch.palettization.FakePalettize(observer: ObserverBase, n_bits: int, cluster_dim: int, quant_min: int = -128, quant_max: int = 127, cluster_dtype: str = 'f32', advanced_options: dict = {}, **observer_kwargs)

A class that implements palettization algorithm described in DKM: Differentiable K-Means Clustering Layer for Neural Network Compression. It clusters the weights using a differentiable version of k-means, allowing the look-up-table (LUT) and indices of palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.

Extends torch.quantization.FakeQuantize to add support for palettization.


from collections import OrderedDict
import torch
import torch.nn as nn
import coremltools.optimize.torch.palettization as palett

model = nn.Sequential(
    OrderedDict(
        [
            ("linear1", nn.Linear(4, 5)),
            ("sigmoid1", nn.Sigmoid()),
            ("linear2", nn.Linear(5, 4)),
            ("sigmoid2", nn.Sigmoid()),
        ]
    )
)

# Activations are passed through unchanged; only linear2's weight is palettized.
fq_activation = nn.Identity
fq_weight = palett.FakePalettize.with_args(
    observer=torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(
        quant_min=-128, quant_max=127, dtype=torch.qint8
    ),
    n_bits=2,
    cluster_dim=1,
)
model.linear2.qconfig = torch.quantization.QConfig(
    activation=fq_activation, weight=fq_weight
)

palettized_model = palett.prepare_palettizer(model)

# Fine-tune palettized_model here with your training loop.

palettized_converted_model = palett.finalize(palettized_model)

  • observer (ObserverBase) – Observer for quantizing the LUT.

  • n_bits (int) – Number of palettization bits. The LUT contains \(2^{n\_bits}\) unique weights.

  • cluster_dim (int) – Dimensionality of centroids to use for clustering.

  • quant_min (int) – The minimum allowable quantized value.

  • quant_max (int) – The maximum allowable quantized value.

  • cluster_dtype (str) – Determines whether the LUT is quantized, and to which dtype. The supported values and their dtypes are 'u8' (uint8), 'i8' (int8), and 'f16' (float16); the default, 'f32', leaves the LUT unquantized.

  • advanced_options (dict) – Advanced options to configure the palettization algorithm.

  • observer_kwargs (optional) – Arguments for the observer module.


Allowed keys for advanced_options are the parameters listed as optional in ModuleDKMPalettizerConfig, besides the ones already covered by other parameters in this class.
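When `cluster_dtype` requests an integer LUT, the centroids themselves are quantized into the range given by `quant_min` and `quant_max`. The sketch below shows one plausible scheme, symmetric per-tensor quantization with a zero point of 0. The scheme and the function name are assumptions; the observer passed to FakePalettize determines the actual scale and zero point.

```python
from typing import List, Tuple


def quantize_lut(lut: List[float], quant_min: int = -128,
                 quant_max: int = 127) -> Tuple[List[int], float]:
    """Symmetric per-tensor quantization of LUT centroids to integers in [quant_min, quant_max].

    Returns the integer codes and the scale used to dequantize them
    (value ~= code * scale). Illustrative only.
    """
    max_abs = max(abs(v) for v in lut)
    scale = max_abs / quant_max if max_abs > 0 else 1.0
    # Round to the nearest integer code, then clamp into the allowed range.
    codes = [max(quant_min, min(quant_max, round(v / scale))) for v in lut]
    return codes, scale


lut = [-0.5, -0.1, 0.2, 0.635]
codes, scale = quantize_lut(lut)
dequantized = [c * scale for c in codes]
```

Here the largest-magnitude centroid maps to 127, giving a scale of about 0.005, and every dequantized centroid lies within half a step of its original value.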