Palettization

Palettization is a mechanism for compressing a model by clustering the model’s float weights into a lookup table (LUT) of centroids and indices.
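To make the LUT-and-indices idea concrete, here is a minimal sketch of scalar palettization using plain k-means in pure Python. It is an illustration only, not the coremltools implementation; the helper name palettize and the evenly spaced centroid initialization are assumptions made for this example.

```python
# Illustrative 1-D k-means palettization (hypothetical helper, not library code).

def palettize(weights, n_bits, iters=10):
    """Cluster scalar weights into 2**n_bits centroids; return (lut, indices)."""
    k = 2 ** n_bits
    lo, hi = min(weights), max(weights)
    # Assumption: start with centroids spread evenly over the weight range.
    lut = [lo + (hi - lo) * i / (k - 1) for i in range(k)]
    indices = [0] * len(weights)
    for _ in range(iters):
        # Assignment step: each weight points at its nearest centroid.
        indices = [min(range(k), key=lambda j: abs(w - lut[j])) for w in weights]
        # Update step: each centroid moves to the mean of its members.
        for j in range(k):
            members = [w for w, idx in zip(weights, indices) if idx == j]
            if members:
                lut[j] = sum(members) / len(members)
    return lut, indices


weights = [-0.9, -0.85, -0.1, 0.0, 0.1, 0.4, 0.45, 1.0]
lut, indices = palettize(weights, n_bits=2)

# The compressed form stores only `lut` (4 floats) and `indices` (2 bits each).
reconstructed = [lut[i] for i in indices]
```

Each reconstructed weight is the centroid its index points to, so the approximation error is bounded by the distance from a weight to its nearest centroid.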

Palettization is implemented as an extension of PyTorch’s QAT APIs. It works by inserting palettization layers in appropriate places inside a model. The model can then be fine-tuned to learn the new palettized layers’ weights in the form of a LUT and indices.

class coremltools.optimize.torch.palettization.ModuleDKMPalettizerConfig(n_bits: int | None = None, weight_threshold: int = 2048, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, enable_per_channel_scale: bool = False, milestone: int = 0, cluster_dim: int | None = None, quant_min: int = -128, quant_max: int = 127, dtype: str | dtype = torch.qint8, lut_dtype: str = 'f32', quantize_activations: bool = False, cluster_permute: tuple | None = None, palett_max_mem: float = 1.0, kmeans_max_iter: int = 3, prune_threshold: float = 1e-07, kmeans_init: str = 'auto', kmeans_opt1d_threshold: int = 1024, enforce_zero: bool = False, palett_mode: str = 'dkm', palett_tau: float = 0.0001, palett_epsilon: float = 0.0001, palett_lambda: float = 0.0, add_extra_centroid: bool = False, palett_cluster_tol: float = 0.0, palett_min_tsize: int = 65536, palett_unique: bool = False, palett_shard: bool = False, palett_batch_mode: bool = False, palett_dist: bool = False, per_channel_scaling_factor_scheme: str = 'min_max', percentage_palett_enable: float = 1.0, kmeans_batch_threshold: int = 4, kmeans_n_init: int = 10, zero_threshold: float = 1e-07, kmeans_error_bnd: float = 0.0, partition_size: int | None = None, cluster_dtype: str | None = None)[source]

Configuration class for specifying global and module-level options for the palettization algorithm implemented in DKMPalettizer.

The parameters specified in this config control the DKM algorithm, described in DKM: Differentiable K-Means Clustering Layer for Neural Network Compression.

For most use cases, the only parameters you need to specify are n_bits, weight_threshold, and milestone.
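A minimal sketch of such a configuration, assuming DKMPalettizerConfig.from_dict accepts a global_config entry in the same way as the other from_dict examples on this page. The values are illustrative, and the import is guarded so the snippet still runs where coremltools is not installed.

```python
# Illustrative values only; tune them for your model.
config_dict = {
    "global_config": {
        "n_bits": 4,               # 2**4 = 16 clusters per lookup table
        "weight_threshold": 2048,  # skip modules whose weights are this size or smaller
        "milestone": 1,            # step or epoch at which palettization begins
    }
}

try:
    from coremltools.optimize.torch.palettization import DKMPalettizerConfig
except ImportError:
    # coremltools (or torch) is not available in this environment.
    DKMPalettizerConfig = None

config = DKMPalettizerConfig.from_dict(config_dict) if DKMPalettizerConfig else None
```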

Note

Most of the parameters in this class are meant for advanced use cases and for further fine-tuning the DKM algorithm. The default values usually work for a majority of tasks.

Note

Change the following parameters only when you use activation quantization in conjunction with DKM weight palettization: quant_min, quant_max, dtype, and quantize_activations.

Parameters:
  • n_bits (int) – Number of bits used to index the clusters. The number of clusters used is 2^n_bits. Defaults to 4 for linear layers and 2 for all other layers.

  • weight_threshold (int) – A module is only palettized if the number of elements in its weight matrix exceeds weight_threshold. If there are multiple weights in a module, such as torch.nn.MultiheadAttention, all of them must have more elements than the weight_threshold for the module to be palettized. Defaults to 2048.

  • granularity (PalettizationGranularity) – Granularity for palettization. One of per_tensor or per_grouped_channel. Defaults to per_tensor.

  • group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.

  • channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to output channel axis. For now, only output channel axis is supported by DKM.

  • enable_per_channel_scale (bool) – When set to True, per-channel scaling is used along the channel dimension.

  • milestone (int) – Step or epoch at which palettization begins. Defaults to 0.

  • cluster_dim (int, optional) – The dimension of each cluster.

  • quant_min (int, optional) – The minimum value for each element in the weight clusters if they are quantized. Defaults to -128.

  • quant_max (int, optional) – The maximum value for each element in the weight clusters if they are quantized. Defaults to 127.

  • dtype (torch.dtype, optional) – The dtype to use for quantizing the activations. Only applies when quantize_activations is True. Defaults to torch.qint8.

  • lut_dtype (str, optional) – dtype to use for quantizing the clusters. Allowed options are 'i8', 'u8', 'f16', 'bf16', 'f32'. Defaults to 'f32', so by default, the clusters aren’t quantized.

  • quantize_activations (bool, optional) – When True, the activations are quantized. Defaults to False.

  • cluster_permute (tuple, optional) – Permutation order to apply to weight partitions. Defaults to None.

  • palett_max_mem (float, optional) – Proportion of available GPU memory that should be used for palettization. Defaults to 1.0.

  • kmeans_max_iter (int, optional) – Maximum number of differentiable k-means iterations. Defaults to 3.

  • prune_threshold (float, optional) – Hardshrinks weights between [-prune_threshold, prune_threshold] to zero. Useful for joint pruning and palettization. Defaults to 1e-7.

  • kmeans_init (str, optional) – k-means algorithm to use. Allowed options are opt1d, cpu.kmeans++ and kmeans++. Defaults to auto.

  • kmeans_opt1d_threshold (int, optional) – Channel threshold to decide if opt1d kmeans should be used. Defaults to 1024.

  • enforce_zero (bool, optional) – If True, enforces the LUT centroid which is closest to the origin to be fixed to zero. Defaults to False.

  • palett_mode (str, optional) – Criteria to calculate attention during k-means. Allowed options are gsm, dkm and hard. Defaults to dkm.

  • palett_tau (float, optional) – Temperature factor for softmax used in DKM algorithm. Defaults to 0.0001.

  • palett_epsilon (float, optional) – Distance threshold for clusters between k-means iterations. Defaults to 0.0001.

  • palett_lambda (float, optional) – Reduces effects of outliers during centroid calculation. Defaults to 0.0.

  • add_extra_centroid (bool, optional) – If True, adds an extra centroid to the LUT. Defaults to False.

  • palett_cluster_tol (float, optional) – Tolerance for non-unique centroids in the LUT. The higher the number, the more tolerance for non-unique centroids. Defaults to 0.0.

  • palett_min_tsize (int, optional) – Weight threshold beyond which to use custom packing and unpacking hook for autograd. Defaults to 64*1024.

  • palett_unique (bool, optional) – If True, reduces the attention map by leveraging the fact that FP16 only has 2^16 unique values. Useful for large models such as LLMs, where attention maps can be huge. Defaults to False. For more details, read eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.

  • palett_shard (bool, optional) – If True, the index list is sharded across GPUs. Defaults to False. For more details, read eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.

  • palett_batch_mode (bool, optional) – If True, performs batch DKM across different partitions created for different blocks. Defaults to False. For more details, read eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.

  • palett_dist (bool, optional) – If True, performs distributed distance calculation in batch_mode if distributed torch is available. Defaults to False.

  • per_channel_scaling_factor_scheme (str, optional) – Criteria to calculate the per_channel_scaling_factor. Allowed options are min_max and abs. Defaults to min_max.

  • percentage_palett_enable (float, optional) – Percentage partitions to enable for DKM. Defaults to 1.0.

  • kmeans_batch_threshold (int, optional) – Threshold to decide what the num_partitions value should be to go through with the sharded centroids list. num_partitions is calculated by dividing the channel size by the group_size provided. If num_partitions matches kmeans_batch_threshold, the algorithm resorts to performing distributed k-means for lower partition numbers, given that num_partitions GPUs are available. Defaults to 4.

  • kmeans_n_init (int, optional) – Number of times the k-means algorithm is run with different centroid seeds. The final result is the best output of kmeans_n_init consecutive runs in terms of inertia. Defaults to 10.

  • zero_threshold (float, optional) – Zero threshold to be used to decide the minimum value of clamp for softmax. Defaults to 1e-7.

  • kmeans_error_bnd (float, optional) – Distance threshold to decide at what distance between parameters and clusters to stop the k-means operation. Defaults to 0.0.
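The n_bits and weight_threshold rules above can be sketched in a few lines of plain Python. The helper names num_clusters and is_palettized are hypothetical and only restate the documented rules; they are not part of the library.

```python
def num_clusters(n_bits):
    """Number of centroids in a lookup table indexed with n_bits."""
    return 2 ** n_bits


def is_palettized(weight_sizes, weight_threshold=2048):
    """weight_sizes holds the element count of every weight in one module.

    Per the description of weight_threshold, every weight must exceed the
    threshold for the module to be palettized.
    """
    return all(size > weight_threshold for size in weight_sizes)


linear_sizes = [64 * 64]     # one 64x64 weight: 4096 elements
small_sizes = [32 * 32]      # one 32x32 weight: 1024 elements
mixed_sizes = [4096, 1024]   # two weights, one of them too small
```

Under these rules, only the module described by linear_sizes would be palettized with the default threshold.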

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).

Note

Grouping is currently only supported along the output channel axis.
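The lookup-table shape arithmetic from the example above can be checked with a short sketch. The function lut_shape is a hypothetical helper; representing the per-tensor case as a single table of shape (1, 2^n_bits) and requiring the channel count to be divisible by group_size are assumptions of this sketch.

```python
def lut_shape(weight_shape, n_bits, group_size=None, channel_axis=0):
    """Return (number of lookup tables, entries per table)."""
    entries = 2 ** n_bits
    if group_size is None:
        # Per-tensor: the whole tensor shares one table.
        return (1, entries)
    channels = weight_shape[channel_axis]
    if channels % group_size != 0:
        # Assumption of this sketch: groups must tile the channel axis exactly.
        raise ValueError("channel count must be divisible by group_size")
    return (channels // group_size, entries)


per_tensor = lut_shape((16, 25), n_bits=4)
per_group = lut_shape((16, 25), n_bits=4, group_size=8)
```

For the (16, 25) weight with group_size = 8, the 16 output channels form 2 groups, so there are 2 tables of 2^4 = 16 entries each.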

as_dict() Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(data_dict: Dict[str, Any]) DictableDataClass

Create class from a dictionary of string keys and values.

Parameters:

data_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

class coremltools.optimize.torch.palettization.DKMPalettizerConfig(global_config: GlobalConfigType | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: ModuleNameConfigType = NOTHING)[source]

Configuration for specifying how different submodules of a model are palettized by DKMPalettizer.

The module_type_configs parameter can accept a list of ModuleDKMPalettizerConfig as values for a given module type. The list can specify different parameters for different weight_threshold values. This is useful if you want to apply different configs to layers of the same type with weights of different sizes.

For example, to use 4-bit palettization for weights with more than 1000 elements and 2-bit palettization for weights with more than 300 but less than 1000 elements, create a config as follows:

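import torch.nn as nn
from coremltools.optimize.torch.palettization import DKMPalettizerConfig

# one entry per weight_threshold; larger layers get more bits
custom_config = {
    nn.Conv2d: [
        {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
        {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
    ]
}
config = DKMPalettizerConfig.from_dict({"module_type_configs": custom_config})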
Parameters:
  • global_config (ModuleDKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleDKMPalettizerConfig) – Module type level configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes. When module_type_config is set to None for a module type, it is not palettized.

  • module_name_configs (dict of str to ModuleDKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top-level module using the module.get_submodule(target) method. When module_name_config is set to None for a module, it is not palettized.
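The threshold-based selection described for module_type_configs can be pictured with the sketch below. How the library resolves the list internally is not stated here, so pick_config is a hypothetical helper showing one reading that is consistent with the 4-bit and 2-bit example.

```python
def pick_config(num_elements, configs):
    """Return the config with the largest weight_threshold that num_elements exceeds."""
    for cfg in sorted(configs, key=lambda c: c["weight_threshold"], reverse=True):
        if num_elements > cfg["weight_threshold"]:
            return cfg
    # Below every threshold: the layer is left unpalettized in this sketch.
    return None


conv_configs = [
    {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
    {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
]

large = pick_config(1500, conv_configs)   # more than 1000 elements
medium = pick_config(500, conv_configs)   # between 300 and 1000 elements
tiny = pick_config(100, conv_configs)     # fewer than 300 elements
```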

as_dict() Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) DKMPalettizerConfig[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

set_global(global_config: ModuleOptimizationConfig | None) OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: ModuleOptimizationConfig | None) OptimizationConfig

Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.

set_module_type(object_type: Callable | str, opt_config: ModuleOptimizationConfig | None) OptimizationConfig

Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.

class coremltools.optimize.torch.palettization.DKMPalettizer(model: Module, config: DKMPalettizerConfig | None = None)[source]

A palettization algorithm based on “DKM: Differentiable K-Means Clustering Layer for Neural Network Compression”. It clusters the weights using a differentiable version of k-means, allowing the lookup table (LUT) and indices of palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.
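The core idea of differentiable k-means can be sketched as a soft assignment: each weight attends to every centroid through a softmax over negative distances, and centroids become attention-weighted means. The pure-Python function below is an illustration of that idea for scalar weights, not the library's implementation. The parameters tau, max_iter, and epsilon loosely mirror palett_tau, kmeans_max_iter, and palett_epsilon, and that mapping is an assumption.

```python
import math


def soft_kmeans(weights, centroids, tau=0.05, max_iter=3, epsilon=1e-4):
    """Soft k-means on scalars; returns the updated centroids."""
    centroids = list(centroids)
    for _ in range(max_iter):
        num = [0.0] * len(centroids)
        den = [0.0] * len(centroids)
        for w in weights:
            # Softmax over negative distances, stabilized by subtracting the max logit.
            logits = [-abs(w - c) / tau for c in centroids]
            top = max(logits)
            exps = [math.exp(l - top) for l in logits]
            total = sum(exps)
            for j, e in enumerate(exps):
                attention = e / total
                num[j] += attention * w
                den[j] += attention
        # Attention-weighted mean; keep the old centroid if nothing attends to it.
        new = [n / d if d > 0 else c for n, d, c in zip(num, den, centroids)]
        shift = max(abs(a - b) for a, b in zip(new, centroids))
        centroids = new
        if shift < epsilon:
            break
    return centroids


result = soft_kmeans([-1.0, -0.9, 0.9, 1.0], centroids=[-0.5, 0.5])
```

Because every operation is smooth in the weights and centroids, the same computation written with tensors can be trained with gradient-based optimizers. A smaller tau makes the assignment closer to hard k-means.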

Example

import torch
from coremltools.optimize.torch.palettization import (
    DKMPalettizer,
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# code that defines the pytorch model, loss and optimizer.
model, loss_fn, optimizer = create_model_loss_and_optimizer()

# initialize the palettizer
config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))

palettizer = DKMPalettizer(model, config)

# prepare the model to insert FakePalettize layers for palettization
model = palettizer.prepare(inplace=True)

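# use palettizer in your PyTorch training loop
# (data is assumed to be an iterable of (inputs, labels) batches)
for inputs, labels in data:
    # clear gradients accumulated in the previous iteration
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    palettizer.step()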

# fold LUT and indices into weights
model = palettizer.finalize(inplace=True)
Parameters:
  • model (torch.nn.Module) – Model on which the palettizer will act.

  • config (DKMPalettizerConfig) – Config which specifies how different submodules in the model will be configured for palettization. Default config is used when passed as None.

finalize(model: Module | None = None, inplace: bool = False) Module[source]

Removes FakePalettize layers from a model and creates new model weights from the LUT and indices buffers.

This function is called to prepare a palettized model for export using coremltools.

Parameters:
  • model (nn.Module) – model to finalize.

  • inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.

prepare(inplace: bool = False) Module[source]

Prepares a model for palettization aware training by inserting FakePalettize layers in appropriate places as specified by the config.

Parameters:

inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.

report() _Report[source]

Returns a dictionary with important statistics related to the current state of palettization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing the statistics, such as number of clusters and cluster dimension, number of parameters, and so on.

step()[source]

Step through the palettizer. When the number of times step is called is equal to milestone, palettization is enabled.
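The relationship between step and milestone can be pictured with a small counter. MilestoneGate is a hypothetical class written for this illustration, not part of the library; treating a milestone of 0 as enabled from the start is an assumption.

```python
class MilestoneGate:
    """Toy model of enabling palettization once step() has been called milestone times."""

    def __init__(self, milestone=0):
        self.milestone = milestone
        self.step_count = 0
        # Assumption: milestone 0 means palettization is active immediately.
        self.enabled = milestone == 0

    def step(self):
        self.step_count += 1
        if self.step_count == self.milestone:
            self.enabled = True


gate = MilestoneGate(milestone=2)
gate.step()
enabled_after_one = gate.enabled
gate.step()
enabled_after_two = gate.enabled
```

Whether step is called once per batch or once per epoch determines the unit in which milestone is measured.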

class coremltools.optimize.torch.palettization.ModulePostTrainingPalettizerConfig(n_bits: int | None = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool | None = False, enable_fast_kmeans_mode: bool | None = True, rounding_precision: int | None = 4)[source]

Configuration class for specifying global and module-level palettization options for the PostTrainingPalettizer algorithm.

Parameters:
  • n_bits (int) – Number of bits to use for palettizing the weights. Defaults to 4.

  • lut_dtype (torch.dtype) – The dtype to use for representing each element in lookup tables. When value is None, no quantization is performed. Supported values are torch.int8 and torch.uint8. Defaults to None.

  • granularity (PalettizationGranularity) – One of per_tensor or per_grouped_channel. Defaults to per_tensor.

  • group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.

  • channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to output channel axis.

  • cluster_dim (int) – The dimension of centroids for each lookup table. The centroid is a scalar by default. When cluster_dim > 1, it indicates 2-D clustering, and each cluster_dim length of weight vectors along the output channel are palettized using the same 2-D centroid. The length of each entry in the lookup tables is equal to cluster_dim.

  • enable_per_channel_scale (bool) – When set to True, weights are normalized along the output channels using per-channel scales before being palettized. This is not supported with cluster_dim > 1.

  • enable_fast_kmeans_mode (bool) – When turned on, will round the weights before clustering if data is in fp16 range. If weight dtype is fp32, weights are cast to fp16 and then rounded. This is not supported with cluster_dim > 1. Defaults to True.

  • rounding_precision (int) – The number of decimal places to set for rounding, when enable_fast_kmeans_mode is enabled. Choose a lower precision for faster processing, at the cost of coarser approximation. Defaults to 4.
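One plausible reading of why rounding speeds up clustering is that it collapses near-identical weights into fewer distinct values, so k-means has fewer points to process. The sketch below illustrates that effect in plain Python. The helper unique_rounded is hypothetical, and the fp16 cast mentioned for enable_fast_kmeans_mode is not modeled.

```python
from collections import Counter


def unique_rounded(weights, rounding_precision=4):
    """Round weights and return sorted (value, count) pairs of the distinct results."""
    rounded = [round(w, rounding_precision) for w in weights]
    return sorted(Counter(rounded).items())


weights = [0.12341, 0.12344, 0.5, 0.50001]
at_precision_4 = unique_rounded(weights, rounding_precision=4)
at_precision_6 = unique_rounded(weights, rounding_precision=6)
```

At precision 4 the four weights collapse to two distinct values, each with a count of 2, while at precision 6 all four stay distinct.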

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).

Note

Grouping is currently only supported along either the input or output channel axis.

class coremltools.optimize.torch.palettization.PostTrainingPalettizer(model: Module, config: PostTrainingPalettizerConfig | None = None)[source]

Perform post-training palettization on a torch model. After palettization, all the weights in supported layers point to elements in a lookup table computed by a k-means operation.

Example

from collections import OrderedDict

import torch.nn as nn
from coremltools.optimize.torch.palettization import (
    PostTrainingPalettizerConfig,
    PostTrainingPalettizer,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

# initialize the palettizer
config = PostTrainingPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
    }
)

ptpalettizer = PostTrainingPalettizer(model, config)
palettized_model = ptpalettizer.compress()
Parameters:
  • model (torch.nn.Module) – Module to be compressed.

  • config (PostTrainingPalettizerConfig) – Config that specifies how different submodules in the model will be palettized.

class coremltools.optimize.torch.palettization.PostTrainingPalettizerConfig(global_config: ModulePostTrainingPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModulePostTrainingPalettizerConfig | None] = NOTHING)[source]

Configuration class for specifying how different submodules of a model should be post-training palettized by PostTrainingPalettizer.

class coremltools.optimize.torch.palettization.ModuleSKMPalettizerConfig(n_bits: int = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool = False)[source]

Configuration class for specifying global and module-level palettization options for the SKMPalettizer algorithm.

Parameters:
  • n_bits (int) – Number of bits to use for palettizing the weights. Defaults to 4.

  • lut_dtype (torch.dtype) – The dtype to use for representing each element in lookup tables. When value is None, no quantization is performed. Supported values are torch.int8 and torch.uint8. Defaults to None.

  • granularity (PalettizationGranularity) – One of per_tensor or per_grouped_channel. Defaults to per_tensor.

  • group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.

  • channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to output channel axis.

  • cluster_dim (int) – The dimension of centroids for each lookup table. The centroid is a scalar by default. When cluster_dim > 1, it indicates 2-D clustering, and each cluster_dim length of weight vectors along the output channel are palettized using the same 2-D centroid. The length of each entry in the lookup tables is equal to cluster_dim.

  • enable_per_channel_scale (bool) – When set to True, weights are normalized along the output channels using per-channel scales before being palettized. This is not supported with cluster_dim > 1.
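The effect of per-channel scaling can be sketched as follows: each output channel (row) is divided by a scale before clustering and multiplied back afterwards. Using the maximum absolute value as the scale is an assumption of this sketch, and normalize_rows is a hypothetical helper, not the library's code.

```python
def normalize_rows(weight):
    """Scale every row to the range [-1, 1]; return (normalized, scales)."""
    # Assumption: the scale is the row's max absolute value (1.0 for all-zero rows).
    scales = [max(abs(v) for v in row) or 1.0 for row in weight]
    normalized = [[v / s for v in row] for row, s in zip(weight, scales)]
    return normalized, scales


weight = [
    [0.5, -1.0, 0.25],    # small-magnitude channel
    [10.0, -20.0, 5.0],   # same pattern, 20x larger
]
normalized, scales = normalize_rows(weight)

# Multiplying back by the per-channel scale recovers the original values.
restored = [[v * s for v in row] for row, s in zip(normalized, scales)]
```

After normalization both channels hold the same values, so a single lookup table can represent them without the large channel dominating the centroids.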

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).

Note

Grouping is currently only supported along either the input or output channel axis.

class coremltools.optimize.torch.palettization.SKMPalettizer(model: Module, config: SKMPalettizerConfig | None = None)[source]

Perform post-training palettization of weights by running a weighted k-means on the model weights. The importance values used for weighting different elements of a model’s weight matrix are computed using the Fisher information matrix, which is an approximation of the Hessian. These importance values indicate how sensitive a given weight element is: the more sensitive an element, the larger the impact perturbing or palettizing it has on the model’s loss function. This means that weighted k-means moves the clusters closer to the sensitive weight values, allowing them to be represented more exactly. This leads to a lower degradation in model performance after palettization. The Fisher information matrix is computed using a few samples of calibration data.

This algorithm implements SqueezeLLM: Dense-and-Sparse Quantization.
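The centroid update step of weighted k-means can be sketched in a few lines. The sensitivities below are made-up numbers standing in for Fisher-information estimates, and weighted_centroid is a hypothetical helper rather than the library's code.

```python
def weighted_centroid(values, sensitivities):
    """Sensitivity-weighted mean of the values assigned to one cluster."""
    total = sum(sensitivities)
    return sum(v * s for v, s in zip(values, sensitivities)) / total


values = [0.0, 1.0]

# Equal sensitivities reduce to the ordinary k-means mean.
plain = weighted_centroid(values, [1.0, 1.0])

# A highly sensitive element pulls the centroid toward itself.
weighted = weighted_centroid(values, [1.0, 9.0])
```

With equal sensitivities the centroid sits at 0.5. With the second element nine times as sensitive it moves to 0.9, so the sensitive weight is represented more exactly.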

Example

from collections import OrderedDict

import torch.nn as nn
from coremltools.optimize.torch.palettization import (
    SKMPalettizer,
    SKMPalettizerConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

dataloader = load_calibration_data()

# define callable for loss function
def loss_fn(model, data):
    inp, target = data
    out = model(inp)
    return nn.functional.mse_loss(out, target)

# initialize the palettizer
config = SKMPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
        "calibration_nsamples": 16,
    }
)

compressor = SKMPalettizer(model, config)
compressed_model = compressor.compress(dataloader=dataloader, loss_fn=loss_fn)
Parameters:
  • model (torch.nn.Module) – Module to be compressed.

  • config (SKMPalettizerConfig) – Config that specifies how different submodules in the model will be compressed.

class coremltools.optimize.torch.palettization.SKMPalettizerConfig(global_config: ModuleSKMPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModuleSKMPalettizerConfig | None] = NOTHING, calibration_nsamples: int = 128)[source]

Configuration class for specifying how different submodules of a model are palettized by SKMPalettizer.

Parameters:
  • global_config (ModuleSKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleSKMPalettizerConfig) – Module type configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes.

  • module_name_configs (dict of str to ModuleSKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must either be a regex or a fully qualified name that can be used to fetch it from the top level module using the module.get_submodule(target) method.

  • calibration_nsamples (int) – Number of samples to be used for calibration. Defaults to 128.