Palettization

Palettization is a mechanism for compressing a model by clustering the model’s float weights into a lookup table (LUT) of centroids and indices.

Palettization is implemented as an extension of PyTorch’s QAT APIs. It works by inserting palettization layers in appropriate places inside a model. The model can then be fine-tuned to learn the new palettized layers’ weights in the form of a LUT and indices.
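As a conceptual illustration (not part of the coremltools API), the sketch below shows how a 2-bit palettized weight can be represented by a small LUT of centroids plus one index per weight element:

import torch

# Conceptual sketch only (not a coremltools API): a 2-bit palettization stores
# a small lookup table (LUT) of centroids and one index per weight element
# instead of the original float weights.
weight = torch.randn(4, 6)
n_bits = 2

# assume these 2**n_bits centroids were produced by k-means on the weight values
lut = torch.tensor([-0.6, -0.2, 0.2, 0.6])

# each weight element is mapped to the index of its nearest centroid ...
indices = torch.argmin((weight.reshape(-1, 1) - lut).abs(), dim=1).reshape(weight.shape)

# ... and the palettized weight is reconstructed by a simple lookup
palettized_weight = lut[indices]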

class coremltools.optimize.torch.palettization.ModuleDKMPalettizerConfig(n_bits: int | None = None, weight_threshold: int = 2048, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, enable_per_channel_scale: bool = False, milestone: int = 0, cluster_dim: int | None = None, quant_min: int = -128, quant_max: int = 127, dtype: str | dtype = torch.qint8, lut_dtype: str = 'f32', quantize_activations: bool = False, cluster_permute: tuple | None = None, palett_max_mem: float = 1.0, kmeans_max_iter: int = 3, prune_threshold: float = 1e-07, kmeans_init: str = 'auto', kmeans_opt1d_threshold: int = 1024, enforce_zero: bool = False, palett_mode: str = 'dkm', palett_tau: float = 0.0001, palett_epsilon: float = 0.0001, palett_lambda: float = 0.0, add_extra_centroid: bool = False, palett_cluster_tol: float = 0.0, palett_min_tsize: int = 65536, palett_unique: bool = False, palett_shard: bool = False, palett_batch_mode: bool = False, palett_dist: bool = False, per_channel_scaling_factor_scheme: str = 'min_max', percentage_palett_enable: float = 1.0, kmeans_batch_threshold: int = 4, kmeans_n_init: int = 100, zero_threshold: float = 1e-07, kmeans_error_bnd: float = 0.0, partition_size: int | None = None, cluster_dtype: str | None = None)[source]

Configuration class for specifying global and module-level options for the palettization algorithm implemented in DKMPalettizer.

The parameters specified in this config control the DKM algorithm, described in DKM: Differentiable K-Means Clustering Layer for Neural Network Compression.

For most use cases, the only parameters you need to specify are n_bits, weight_threshold, and milestone.
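For example, a minimal config touching only those three parameters might look as follows (a sketch; the threshold and milestone values are arbitrary):

from coremltools.optimize.torch.palettization import ModuleDKMPalettizerConfig

# 4-bit palettization for modules whose weights have more than 1024 elements,
# with palettization starting at step 2 of fine-tuning
config = ModuleDKMPalettizerConfig(n_bits=4, weight_threshold=1024, milestone=2)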

Note

Most of the parameters in this class are meant for advanced use cases and for further fine-tuning the DKM algorithm. The default values usually work for a majority of tasks.

Note

Change the following parameters only when you use activation quantization in conjunction with DKM weight palettization: quant_min, quant_max, dtype, and quantize_activations.

Parameters:
  • n_bits (int) – Number of bits used for palettizing the weights. The number of clusters used is \(2^{n\_bits}\). Defaults to 4 for linear layers and 2 for all other layers.

  • weight_threshold (int) – A module is only palettized if the number of elements in its weight matrix exceeds weight_threshold. If there are multiple weights in a module, such as torch.nn.MultiheadAttention, all of them must have more elements than the weight_threshold for the module to be palettized. Defaults to 2048.

  • granularity (PalettizationGranularity) – Granularity for palettization. One of per_tensor or per_grouped_channel. Defaults to per_tensor.

  • group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.

  • channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to output channel axis. For now, only output channel axis is supported by DKM.

  • enable_per_channel_scale (bool) – When set to True, per-channel scaling is used along the channel dimension.

  • milestone (int) – Step or epoch at which palettization begins. Defaults to 0.

  • cluster_dim (int, optional) – The dimension of each cluster.

  • quant_min (int, optional) – The minimum value for each element in the weight clusters if they are quantized. Defaults to -128.

  • quant_max (int, optional) – The maximum value for each element in the weight clusters if they are quantized. Defaults to 127.

  • dtype (torch.dtype, optional) – The dtype to use for quantizing the activations. Only applies when quantize_activations is True. Defaults to torch.qint8.

  • lut_dtype (str, optional) – dtype to use for quantizing the clusters. Allowed options are 'i8', 'u8', 'f16', 'bf16', 'f32'. Defaults to 'f32', so by default, the clusters aren’t quantized.

  • quantize_activations (bool, optional) – When True, the activations are quantized. Defaults to False.

  • cluster_permute (tuple, optional) – Permutation order to apply to weight partitions. Defaults to None.

  • palett_max_mem (float, optional) – Proportion of available GPU memory that should be used for palettization. Defaults to 1.0.

  • kmeans_max_iter (int, optional) – Maximum number of differentiable k-means iterations. Defaults to 3.

  • prune_threshold (float, optional) – Hardshrinks weights between [-prune_threshold, prune_threshold] to zero. Useful for joint pruning and palettization. Defaults to 1e-7.

  • kmeans_init (str, optional) – k-means initialization algorithm to use. Allowed options are auto, opt1d, cpu.kmeans++, and kmeans++. Defaults to auto.

  • kmeans_opt1d_threshold (int, optional) – Channel threshold to decide if opt1d kmeans should be used. Defaults to 1024.

  • enforce_zero (bool, optional) – If True, enforces that the LUT centroid closest to the origin is fixed to zero. Defaults to False.

  • palett_mode (str, optional) – Criteria to calculate attention during k-means. Allowed options are gsm, dkm and hard. Defaults to dkm.

  • palett_tau (float, optional) – Temperature factor for softmax used in DKM algorithm. Defaults to 0.0001.

  • palett_epsilon (float, optional) – Distance threshold for clusters between k-means iterations. Defaults to 0.0001.

  • palett_lambda (float, optional) – Reduces effects of outliers during centroid calculation. Defaults to 0.0.

  • add_extra_centroid (bool, optional) – If True, adds an extra centroid to the LUT. Defaults to False.

  • palett_cluster_tol (float, optional) – Tolerance for non-unique centroids in the LUT. The higher the number, the more tolerance for non-unique centroids. Defaults to 0.0.

  • palett_min_tsize (int, optional) – Weight threshold beyond which to use custom packing and unpacking hook for autograd. Defaults to 64*1024.

  • palett_unique (bool, optional) – If True, reduces the attention map by leveraging the fact that FP16 has only 2^16 unique values. Useful for large models such as LLMs, where attention maps can be huge. Defaults to False. For more details, read eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.

  • palett_shard (bool, optional) – If True, the index list is sharded across GPUs. Defaults to False. For more details, read eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.

  • palett_batch_mode (bool, optional) – If True, performs batch DKM across the different partitions created for different blocks. Defaults to False. More details can be found in eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.

  • palett_dist (bool, optional) – If True, performs distributed distance calculation in batch_mode if distributed torch is available. Defaults to False.

  • per_channel_scaling_factor_scheme (str, optional) – Criteria to calculate the per_channel_scaling_factor. Allowed options are min_max and abs. Defaults to min_max.

  • percentage_palett_enable (float, optional) – Percentage of partitions to enable for DKM. Defaults to 1.0.

  • kmeans_batch_threshold (int, optional) – Threshold on num_partitions used to decide whether to use the sharded centroids list. num_partitions is calculated by dividing the channel size by the provided group_size. If num_partitions matches kmeans_batch_threshold, the algorithm performs distributed k-means for lower partition numbers, provided that num_partitions GPUs are available. Defaults to 4.

  • kmeans_n_init (int, optional) – Number of times the k-means algorithm is run with different centroid seeds. The final result is the best output of kmeans_n_init consecutive runs in terms of inertia. Defaults to 100.

  • zero_threshold (float, optional) – Zero threshold used to decide the minimum value of clamp for softmax. Defaults to 1e-7.

  • kmeans_error_bnd (float, optional) – Distance threshold to decide at what distance between parameters and clusters to stop the k-means operation. Defaults to 0.0.

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits), as shown in the sketch after the note below.

Note

Grouping is currently only supported along the output channel axis.
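For instance, the per-grouped-channel configuration described above could be constructed as follows (a minimal sketch; the group size is arbitrary):

from coremltools.optimize.torch.palettization import ModuleDKMPalettizerConfig

# every 8 channels along the output channel axis share one 4-bit lookup table
config = ModuleDKMPalettizerConfig(
    n_bits=4,
    granularity="per_grouped_channel",
    group_size=8,
)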

as_dict() Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(data_dict: Dict[str, Any]) DictableDataClass

Create class from a dictionary of string keys and values.

Parameters:

data_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

class coremltools.optimize.torch.palettization.DKMPalettizerConfig(global_config: GlobalConfigType | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: ModuleNameConfigType = NOTHING)[source]

Configuration for specifying how different submodules of a model are palettized by DKMPalettizer.

The module_type_configs parameter can accept a list of ModuleDKMPalettizerConfig as values for a given module type. The list can specify different parameters for different weight_threshold values. This is useful if you want to apply different configs to layers of the same type with weights of different sizes.

For example, to use 4-bit palettization for weights with more than 1000 elements and 2-bit palettization for weights with more than 300 but fewer than 1000 elements, create a config as follows:

custom_config = {
    nn.Conv2d: [
        {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
        {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
    ]
}
config = DKMPalettizerConfig.from_dict({"module_type_configs": custom_config})
Parameters:
  • global_config (ModuleDKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleDKMPalettizerConfig) – Module type level configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes. When module_type_config is set to None for a module type, it is not palettized.

  • module_name_configs (dict of str to ModuleDKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top-level module using the module.get_submodule(target) method. When module_name_config is set to None for a module, it is not palettized.

as_dict() Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) DKMPalettizerConfig[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: IO | str) DictableDataClass

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

set_global(global_config: ModuleOptimizationConfig | None) OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: ModuleOptimizationConfig | None) OptimizationConfig

Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.

set_module_type(object_type: Callable | str, opt_config: ModuleOptimizationConfig | None) OptimizationConfig

Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.
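As an illustration of these setters (a sketch; the module name used below is hypothetical):

import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# 4-bit palettization globally
config = DKMPalettizerConfig().set_global(ModuleDKMPalettizerConfig(n_bits=4))

# 2-bit palettization for all Linear layers
config = config.set_module_type(nn.Linear, ModuleDKMPalettizerConfig(n_bits=2))

# skip palettization for one specific module (hypothetical name)
config = config.set_module_name("encoder.layer1.conv", None)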

class coremltools.optimize.torch.palettization.DKMPalettizer(model: Module, config: DKMPalettizerConfig | None = None)[source]

A palettization algorithm based on “DKM: Differentiable K-Means Clustering Layer for Neural Network Compression”. It clusters the weights using a differentiable version of k-means, allowing the lookup table (LUT) and indices of palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.

Example

import torch
from coremltools.optimize.torch.palettization import (
    DKMPalettizer,
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# code that defines the pytorch model, loss and optimizer.
model, loss_fn, optimizer = create_model_loss_and_optimizer()

# initialize the palettizer
config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))

palettizer = DKMPalettizer(model, config)

# prepare the model to insert FakePalettize layers for palettization
model = palettizer.prepare(inplace=True)

# use palettizer in your PyTorch training loop
for inputs, labels in data:
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    palettizer.step()

# fold LUT and indices into weights
model = palettizer.finalize(inplace=True)
Parameters:
  • model (torch.nn.Module) – Model on which the palettizer will act.

  • config (DKMPalettizerConfig) – Config which specifies how different submodules in the model will be configured for palettization. Default config is used when passed as None.

finalize(model: Module | None = None, inplace: bool = False) Module[source]

Removes FakePalettize layers from a model and creates new model weights from the LUT and indices buffers.

This function is called to prepare a palettized model for export using coremltools.

Parameters:
  • model (nn.Module) – model to finalize.

  • inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.

prepare(inplace: bool = False) Module[source]

Prepares a model for palettization aware training by inserting FakePalettize layers in appropriate places as specified by the config.

Parameters:

inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned.

report() _Report[source]

Returns a dictionary with important statistics related to the current state of palettization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing statistics such as the number of clusters, the cluster dimension, the number of parameters, and so on.

step()[source]

Step through the palettizer. When the number of times step is called is equal to milestone, palettization is enabled.

class coremltools.optimize.torch.palettization.ModulePostTrainingPalettizerConfig(n_bits: int | None = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool | None = False)[source]

Configuration class for specifying global and module-level palettization options for the PostTrainingPalettizer algorithm.

Parameters:
  • n_bits (int) – Number of bits to use for palettizing the weights. Defaults to 4.

  • lut_dtype (torch.dtype) – The dtype to use for representing each element in lookup tables. When value is None, no quantization is performed. Supported values are torch.int8 and torch.uint8. Defaults to None.

  • granularity (PalettizationGranularity) – One of per_tensor or per_grouped_channel. Defaults to per_tensor.

  • group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.

  • channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to output channel axis.

  • cluster_dim (int) – The dimension of centroids for each lookup table. The centroid is a scalar by default. When cluster_dim > 1, it indicates 2-D clustering, and each group of cluster_dim weight values along the output channel is palettized using the same 2-D centroid. The length of each entry in the lookup tables is equal to cluster_dim.

  • enable_per_channel_scale (bool) – When set to True, weights are normalized along the output channels using per-channel scales before being palettized. This is not supported with cluster_dim > 1.

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).

Note

Grouping is currently only supported along either the input or output channel axis.
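For example, per-grouped-channel and vector (cluster_dim > 1) configurations might be constructed as follows (a minimal sketch; the group size and cluster dimension are arbitrary):

import torch

from coremltools.optimize.torch.palettization import ModulePostTrainingPalettizerConfig

# 4-bit per-grouped-channel palettization with an int8-quantized lookup table
grouped_config = ModulePostTrainingPalettizerConfig(
    n_bits=4,
    granularity="per_grouped_channel",
    group_size=8,
    lut_dtype=torch.int8,
)

# 4-bit vector palettization: 2-D centroids shared by pairs of weight values
# along the output channel
vector_config = ModulePostTrainingPalettizerConfig(n_bits=4, cluster_dim=2)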

class coremltools.optimize.torch.palettization.PostTrainingPalettizer(model: Module, config: PostTrainingPalettizerConfig | None = None)[source]

Perform post-training palettization on a torch model. After palettization, the weights in supported layers point to entries in a lookup table computed using a k-means operation.

Example

from collections import OrderedDict

import torch.nn as nn
from coremltools.optimize.torch.palettization import (
    PostTrainingPalettizerConfig,
    PostTrainingPalettizer,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

# initialize the palettizer
config = PostTrainingPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
    }
)

ptpalettizer = PostTrainingPalettizer(model, config)
palettized_model = ptpalettizer.compress()
Parameters:
  • model (torch.nn.Module) – Module to be compressed.

  • config (PostTrainingPalettizerConfig) – Config that specifies how different submodules in the model will be palettized.

class coremltools.optimize.torch.palettization.PostTrainingPalettizerConfig(global_config: ModulePostTrainingPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModulePostTrainingPalettizerConfig | None] = NOTHING)[source]

Configuration class for specifying how different submodules of a model should be post-training palettized by PostTrainingPalettizer.

Parameters:
  • global_config (ModulePostTrainingPalettizerConfig) – Config to be applied globally to all supported modules.

  • module_type_configs (dict of str to ModulePostTrainingPalettizerConfig) – Module type configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes.

  • module_name_configs (dict of str to ModulePostTrainingPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top-level module using the module.get_submodule(target) method.

class coremltools.optimize.torch.palettization.ModuleSKMPalettizerConfig(n_bits: int = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool = False)[source]

Configuration class for specifying global and module-level palettization options for the SKMPalettizer algorithm.

Parameters:
  • n_bits (int) – Number of bits to use for palettizing the weights. Defaults to 4.

  • lut_dtype (torch.dtype) – The dtype to use for representing each element in lookup tables. When value is None, no quantization is performed. Supported values are torch.int8 and torch.uint8. Defaults to None.

  • granularity (PalettizationGranularity) – One of per_tensor or per_grouped_channel. Defaults to per_tensor.

  • group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.

  • channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to output channel axis.

  • cluster_dim (int) – The dimension of centroids for each lookup table. The centroid is a scalar by default. When cluster_dim > 1, it indicates 2-D clustering, and each group of cluster_dim weight values along the output channel is palettized using the same 2-D centroid. The length of each entry in the lookup tables is equal to cluster_dim.

  • enable_per_channel_scale (bool) – When set to True, weights are normalized along the output channels using per-channel scales before being palettized. This is not supported with cluster_dim > 1.

This class supports two different configurations to structure the palettization:

1. Per-tensor palettization: This is the default configuration where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.

2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).

Note

Grouping is currently only supported along either the input or output channel axis.

class coremltools.optimize.torch.palettization.SKMPalettizer(model: Module, config: SKMPalettizerConfig | None = None)[source]

Perform post-training palettization of weights by running a weighted k-means on the model weights. The weight values used for weighing different elements of a model’s weight matrix are computed using the Fisher information matrix, which is an approximation of the Hessian. These weight values indicate how sensitive a given weight element is: the more sensitive an element, the larger the impact perturbing or palettizing it has on the model’s loss function. This means that weighted k-means moves the clusters closer to the sensitive weight values, allowing them to be represented more exactly. This leads to a lower degradation in model performance after palettization. The Fisher information matrix is computed using a few samples of calibration data.

This algorithm implements SqueezeLLM: Dense-and-Sparse Quantization.

Example

from collections import OrderedDict

import torch.nn as nn
from coremltools.optimize.torch.palettization import (
    SKMPalettizer,
    SKMPalettizerConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

dataloader = load_calibration_data()

# define callable for loss function
def loss_fn(model, data):
    inp, target = data
    out = model(inp)
    return nn.functional.mse_loss(out, target)

# initialize the palettizer
config = SKMPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
        "calibration_nsamples": 16,
    }
)

compressor = SKMPalettizer(model, config)
compressed_model = compressor.compress(dataloader=dataloader, loss_fn=loss_fn)
Parameters:
  • model (torch.nn.Module) – Module to be compressed.

  • config (SKMPalettizerConfig) – Config that specifies how different submodules in the model will be compressed.

class coremltools.optimize.torch.palettization.SKMPalettizerConfig(global_config: ModuleSKMPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModuleSKMPalettizerConfig | None] = NOTHING, calibration_nsamples: int = 128)[source]

Configuration class for specifying how different submodules of a model are palettized by SKMPalettizer.

Parameters:
  • global_config (ModuleSKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleSKMPalettizerConfig) – Module type configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes.

  • module_name_configs (dict of str to ModuleSKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must either be a regex or a fully qualified name that can be used to fetch it from the top level module using the module.get_submodule(target) method.

  • calibration_nsamples (int) – Number of samples to be used for calibration.
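For example, module_name_configs can use a regex to target groups of layers (a minimal sketch; the module name pattern and bit-widths are illustrative only):

from coremltools.optimize.torch.palettization import (
    ModuleSKMPalettizerConfig,
    SKMPalettizerConfig,
)

# 4-bit palettization globally, 6 bits for modules whose names match the regex
config = SKMPalettizerConfig(
    global_config=ModuleSKMPalettizerConfig(n_bits=4),
    module_name_configs={
        r"model\.layers\..*\.mlp": ModuleSKMPalettizerConfig(n_bits=6),
    },
    calibration_nsamples=64,
)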