Palettization
Palettization compresses a model by clustering its float weights into a lookup table (LUT) of centroids and replacing each weight with an index into that table.
Palettization is implemented as an extension of PyTorch’s quantization-aware training (QAT) APIs. It works by inserting palettization layers in appropriate places inside a model. The model can then be fine-tuned to learn the weights of the new palettized layers in the form of a LUT and indices.
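For illustration only (this is not a coremltools API), the sketch below shows what the LUT-and-indices representation means in plain PyTorch: an n_bits palettization stores 2^n_bits float centroids plus one small integer index per weight element, and the decompressed weight is just a table lookup. The centroids here are chosen naively; the algorithms documented below learn them with differentiable or weighted k-means.
```python
import torch

# Hypothetical illustration: palettize one weight tensor with a 2-bit LUT.
weight = torch.randn(4, 6)
n_bits = 2
n_clusters = 2**n_bits  # 4 centroids in the lookup table

# Naive centroid choice (evenly spaced quantiles of the weight values).
lut = torch.quantile(weight.flatten(), torch.linspace(0, 1, n_clusters))

# Map each weight element to the index of its nearest centroid.
indices = torch.argmin((weight.unsqueeze(-1) - lut).abs(), dim=-1)

# Decompression is a lookup: same shape as the original weight,
# but only n_clusters distinct values remain.
palettized_weight = lut[indices]
```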
- class coremltools.optimize.torch.palettization.ModuleDKMPalettizerConfig(n_bits: int | None = None, weight_threshold: int = 2048, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, enable_per_channel_scale: bool = False, milestone: int = 0, cluster_dim: int | None = None, quant_min: int = -128, quant_max: int = 127, dtype: str | dtype = torch.qint8, lut_dtype: str = 'f32', quantize_activations: bool = False, cluster_permute: tuple | None = None, palett_max_mem: float = 1.0, kmeans_max_iter: int = 3, prune_threshold: float = 1e-07, kmeans_init: str = 'auto', kmeans_opt1d_threshold: int = 1024, enforce_zero: bool = False, palett_mode: str = 'dkm', palett_tau: float = 0.0001, palett_epsilon: float = 0.0001, palett_lambda: float = 0.0, add_extra_centroid: bool = False, palett_cluster_tol: float = 0.0, palett_min_tsize: int = 65536, palett_unique: bool = False, palett_shard: bool = False, palett_batch_mode: bool = False, palett_dist: bool = False, per_channel_scaling_factor_scheme: str = 'min_max', percentage_palett_enable: float = 1.0, kmeans_batch_threshold: int = 4, kmeans_n_init: int = 10, zero_threshold: float = 1e-07, kmeans_error_bnd: float = 0.0, partition_size: int | None = None, cluster_dtype: str | None = None)[source]
Configuration class for specifying global and module-level options for the palettization algorithm implemented in DKMPalettizer. The parameters specified in this config control the DKM algorithm, described in DKM: Differentiable K-Means Clustering Layer for Neural Network Compression.
For most use cases, the only parameters you need to specify are n_bits, weight_threshold, and milestone.
Note
Most of the parameters in this class are meant for advanced use cases and for further fine-tuning of the DKM algorithm. The default values usually work for a majority of tasks.
Note
Change the following parameters only when you use activation quantization in conjunction with DKM weight palettization: quant_min, quant_max, dtype, and quantize_activations.
- Parameters:
n_bits (int) – Number of clusters. The number of clusters used is 2^n_bits. Defaults to 4 for linear layers and 2 for all other layers.
weight_threshold (int) – A module is only palettized if the number of elements in its weight matrix exceeds weight_threshold. If there are multiple weights in a module, such as torch.nn.MultiheadAttention, all of them must have more elements than the weight_threshold for the module to be palettized. Defaults to 2048.
granularity (PalettizationGranularity) – Granularity for palettization. One of per_tensor or per_grouped_channel. Defaults to per_tensor.
group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.
channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to the output channel axis. For now, only the output channel axis is supported by DKM.
enable_per_channel_scale (bool) – When set to True, per-channel scaling is used along the channel dimension.
milestone (int) – Step or epoch at which palettization begins. Defaults to 0.
cluster_dim (int, optional) – The dimension of each cluster.
quant_min (int, optional) – The minimum value for each element in the weight clusters if they are quantized. Defaults to -128.
quant_max (int, optional) – The maximum value for each element in the weight clusters if they are quantized. Defaults to 127.
dtype (torch.dtype, optional) – The dtype to use for quantizing the activations. Only applies when quantize_activations is True. Defaults to torch.qint8.
lut_dtype (str, optional) – dtype to use for quantizing the clusters. Allowed options are 'i8', 'u8', 'f16', 'bf16', and 'f32'. Defaults to 'f32', so by default, the clusters aren’t quantized.
quantize_activations (bool, optional) – When True, the activations are quantized. Defaults to False.
cluster_permute (tuple, optional) – Permutation order to apply to weight partitions. Defaults to None.
palett_max_mem (float, optional) – Proportion of available GPU memory that should be used for palettization. Defaults to 1.0.
kmeans_max_iter (int, optional) – Maximum number of differentiable k-means iterations. Defaults to 3.
prune_threshold (float, optional) – Hard-shrinks weights in the range [-prune_threshold, prune_threshold] to zero. Useful for joint pruning and palettization. Defaults to 1e-7.
kmeans_init (str, optional) – k-means algorithm to use. Allowed options are opt1d, cpu.kmeans++, and kmeans++. Defaults to auto.
kmeans_opt1d_threshold (int, optional) – Channel threshold to decide if opt1d k-means should be used. Defaults to 1024.
enforce_zero (bool, optional) – If True, enforces the LUT centroid closest to the origin to be fixed at zero. Defaults to False.
palett_mode (str, optional) – Criterion used to calculate attention during k-means. Allowed options are gsm, dkm, and hard. Defaults to dkm.
palett_tau (float, optional) – Temperature factor for the softmax used in the DKM algorithm. Defaults to 0.0001.
palett_epsilon (float, optional) – Distance threshold for clusters between k-means iterations. Defaults to 0.0001.
palett_lambda (float, optional) – Reduces the effect of outliers during centroid calculation. Defaults to 0.0.
add_extra_centroid (bool, optional) – If True, adds an extra centroid to the LUT. Defaults to False.
palett_cluster_tol (float, optional) – Tolerance for non-unique centroids in the LUT. The higher the number, the more tolerance for non-unique centroids. Defaults to 0.0.
palett_min_tsize (int, optional) – Weight threshold beyond which a custom packing and unpacking hook is used for autograd. Defaults to 64*1024.
palett_unique (bool, optional) – If True, reduces the attention map by leveraging the fact that FP16 has only 2^16 unique values. Useful for large models, such as LLMs, where attention maps can be huge. Defaults to False. For more details, see eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.
palett_shard (bool, optional) – If True, the index list is sharded across GPUs. Defaults to False. For more details, see eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.
palett_batch_mode (bool, optional) – If True, performs batch DKM across the different partitions created for different blocks. Defaults to False. For more details, see eDKM: An Efficient and Accurate Train-time Weight Clustering for Large Language Models.
palett_dist (bool, optional) – If True, performs distributed distance calculation in batch mode if distributed torch is available. Defaults to False.
per_channel_scaling_factor_scheme (str, optional) – Criterion used to calculate the per_channel_scaling_factor. Allowed options are min_max and abs. Defaults to min_max.
percentage_palett_enable (float, optional) – Percentage of partitions to enable for DKM. Defaults to 1.0.
kmeans_batch_threshold (int, optional) – Threshold to decide what the num_partitions value should be to go through with the sharded centroids list. num_partitions is calculated by dividing the channel size by the provided group_size. If num_partitions matches kmeans_batch_threshold, the algorithm performs distributed k-means for lower partition numbers, provided num_partitions GPUs are available. Defaults to 4.
kmeans_n_init (int, optional) – Number of times the k-means algorithm is run with different centroid seeds. The final result is the best output of kmeans_n_init consecutive runs in terms of inertia.
zero_threshold (float, optional) – Zero threshold used to decide the minimum value of the clamp for softmax. Defaults to 1e-7.
kmeans_error_bnd (float, optional) – Distance threshold to decide at what distance between parameters and clusters to stop the k-means operation. Defaults to 0.0.
This class supports two different configurations to structure the palettization:
1. Per-tensor palettization: This is the default configuration, where the whole tensor shares a single lookup table. The granularity is set to per_tensor and group_size is None.
2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).
Note
Grouping is currently only supported along the output channel axis.
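As an illustrative sketch (the bit-width, group size, and milestone values below are arbitrary), the two configurations can be constructed as follows:
```python
from coremltools.optimize.torch.palettization import ModuleDKMPalettizerConfig

# Per-tensor config: only the commonly tuned parameters are set.
per_tensor_config = ModuleDKMPalettizerConfig(
    n_bits=4,               # 2**4 = 16 clusters in the LUT
    weight_threshold=2048,  # skip modules with fewer weight elements
    milestone=2,            # start palettization at step/epoch 2
)

# Per-grouped-channel config: groups of 8 output channels share one LUT.
grouped_config = ModuleDKMPalettizerConfig(
    n_bits=4,
    granularity="per_grouped_channel",
    group_size=8,
)
```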
- as_dict() → Dict[str, Any]
Returns the config as a dictionary.
- classmethod from_dict(data_dict: Dict[str, Any]) → DictableDataClass
Create class from a dictionary of string keys and values.
- Parameters:
data_dict (dict of str and values) – A nested dictionary of strings and values.
- classmethod from_yaml(yml: IO | str) → DictableDataClass
Create class from a yaml stream.
- Parameters:
yml – An IO stream containing yaml or a str path to the yaml file.
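A brief sketch of how these helpers fit together; the YAML path below is hypothetical:
```python
from coremltools.optimize.torch.palettization import ModuleDKMPalettizerConfig

# Build a config from a plain dictionary, then dump it back out.
config = ModuleDKMPalettizerConfig.from_dict({"n_bits": 2, "milestone": 1})
config_dict = config.as_dict()

# Equivalently, load the same options from a YAML file
# ("palettization.yaml" is a hypothetical path).
config_from_yaml = ModuleDKMPalettizerConfig.from_yaml("palettization.yaml")
```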
- class coremltools.optimize.torch.palettization.DKMPalettizerConfig(global_config: GlobalConfigType | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: ModuleNameConfigType = NOTHING)[source]
Configuration for specifying how different submodules of a model are palettized by DKMPalettizer.
The module_type_configs parameter can accept a list of ModuleDKMPalettizerConfig as values for a given module type. The list can specify different parameters for different weight_threshold values. This is useful if you want to apply different configs to layers of the same type with weights of different sizes.
For example, to use 4-bit palettization for weights with more than 1000 elements and 2-bit palettization for weights with more than 300 but fewer than 1000 elements, create a config as follows:
```python
import torch.nn as nn

from coremltools.optimize.torch.palettization import DKMPalettizerConfig

custom_config = {
    nn.Conv2d: [
        {"n_bits": 4, "cluster_dim": 4, "weight_threshold": 1000},
        {"n_bits": 2, "cluster_dim": 2, "weight_threshold": 300},
    ]
}
config = DKMPalettizerConfig.from_dict({"module_type_configs": custom_config})
```
- Parameters:
global_config (ModuleDKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.
module_type_configs (dict of str to ModuleDKMPalettizerConfig) – Module type level configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes. When module_type_config is set to None for a module type, it is not palettized.
module_name_configs (dict of str to ModuleDKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top-level module using the module.get_submodule(target) method. When module_name_config is set to None for a module, it is not palettized.
- as_dict() → Dict[str, Any]
Returns the config as a dictionary.
- classmethod from_dict(config_dict: Dict[str, Any]) → DKMPalettizerConfig [source]
Create class from a dictionary of string keys and values.
- Parameters:
config_dict (dict of str and values) – A nested dictionary of strings and values.
- classmethod from_yaml(yml: IO | str) → DictableDataClass
Create class from a yaml stream.
- Parameters:
yml – An IO stream containing yaml or a str path to the yaml file.
- set_global(global_config: ModuleOptimizationConfig | None) → OptimizationConfig
Set the global config.
- set_module_name(module_name: str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig
Set the module-level optimization config for a given module instance. If the module-level optimization config for an existing module was already set, the new config will override the old one.
- set_module_type(object_type: Callable | str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig
Set the module-level optimization config for a given module type. If the module-level optimization config for an existing module type was already set, the new config will override the old one.
- class coremltools.optimize.torch.palettization.DKMPalettizer(model: Module, config: DKMPalettizerConfig | None = None)[source]
A palettization algorithm based on “DKM: Differentiable K-Means Clustering Layer for Neural Network Compression”. It clusters the weights using a differentiable version of k-means, allowing the lookup table (LUT) and indices of palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.
Example
```python
import torch

from coremltools.optimize.torch.palettization import (
    DKMPalettizer,
    DKMPalettizerConfig,
    ModuleDKMPalettizerConfig,
)

# code that defines the pytorch model, loss and optimizer.
model, loss_fn, optimizer = create_model_loss_and_optimizer()

# initialize the palettizer
config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))
palettizer = DKMPalettizer(model, config)

# prepare the model to insert FakePalettize layers for palettization
model = palettizer.prepare(inplace=True)

# use palettizer in your PyTorch training loop
for inputs, labels in data:
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    palettizer.step()

# fold LUT and indices into weights
model = palettizer.finalize(inplace=True)
```
- Parameters:
model (torch.nn.Module) – Model on which the palettizer will act.
config (DKMPalettizerConfig) – Config which specifies how different submodules in the model will be configured for palettization. The default config is used when passed as None.
- finalize(model: Module | None = None, inplace: bool = False) → Module [source]
Removes FakePalettize layers from a model and creates new model weights from the LUT and indices buffers. This function is called to prepare a palettized model for export using coremltools.
- Parameters:
model (nn.Module) – Model to finalize.
inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.
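Continuing the training-loop example above, a finalized model holds ordinary float weights and can be traced and converted with coremltools like any other PyTorch model. This is a sketch under assumptions: the input shape and output file name are illustrative, and depending on the coremltools version you may also need to pass a palettization-specific pass pipeline to ct.convert.
```python
import torch
import coremltools as ct

# Fold the learned LUT and indices back into regular float weights.
finalized_model = palettizer.finalize()
finalized_model.eval()

# The input shape below is an assumption for illustration.
example_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(finalized_model, example_input)

mlmodel = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
    minimum_deployment_target=ct.target.iOS16,
)
mlmodel.save("palettized_model.mlpackage")
```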
- prepare(inplace: bool = False) → Module [source]
Prepares a model for palettization-aware training by inserting FakePalettize layers in appropriate places as specified by the config.
- Parameters:
inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.
- report() → _Report [source]
Returns a dictionary with important statistics related to the current state of palettization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing statistics such as the number of clusters, cluster dimension, number of parameters, and so on.
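For example, continuing the training-loop example above, the report can be inspected after prepare() or at any point during training:
```python
# Print per-module palettization statistics.
report = palettizer.report()
for module_name, stats in report.items():
    print(module_name, stats)
```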
- class coremltools.optimize.torch.palettization.ModulePostTrainingPalettizerConfig(n_bits: int | None = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool | None = False, enable_fast_kmeans_mode: bool | None = True, rounding_precision: int | None = 4)[source]
Configuration class for specifying global and module-level palettization options for the PostTrainingPalettizer algorithm.
- Parameters:
n_bits (int) – Number of bits to use for palettizing the weights. Defaults to 4.
lut_dtype (torch.dtype) – The dtype to use for representing each element in lookup tables. When value is None, no quantization is performed. Supported values are torch.int8 and torch.uint8. Defaults to None.
granularity (PalettizationGranularity) – One of per_tensor or per_grouped_channel. Defaults to per_tensor.
group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.
channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to the output channel axis.
cluster_dim (int) – The dimension of the centroids for each lookup table. The centroid is a scalar by default. When cluster_dim > 1, it indicates 2-D clustering, and each cluster_dim-length vector of weights along the output channel is palettized using the same 2-D centroid. The length of each entry in the lookup tables is equal to cluster_dim.
enable_per_channel_scale (bool) – When set to True, weights are normalized along the output channels using per-channel scales before being palettized. This is not supported with cluster_dim > 1.
enable_fast_kmeans_mode (bool) – When enabled, weights are rounded before clustering if the data is in the fp16 range. If the weight dtype is fp32, weights are cast to fp16 and then rounded. This is not supported with cluster_dim > 1. Defaults to True.
rounding_precision (int) – The number of decimal places to use for rounding when enable_fast_kmeans_mode is enabled. Choose a lower precision for faster processing, at the cost of a coarser approximation. Defaults to 4.
This class supports two different configurations to structure the palettization:
1. Per-tensor palettization: This is the default configuration, where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.
2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).
Note
Grouping is currently only supported along either the input or output channel axis.
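As a sketch (the group size is illustrative), a per-grouped-channel post-training config can be built as follows:
```python
from coremltools.optimize.torch.palettization import (
    ModulePostTrainingPalettizerConfig,
    PostTrainingPalettizerConfig,
)

# Every 16 output channels share one 4-bit lookup table.
module_config = ModulePostTrainingPalettizerConfig(
    n_bits=4,
    granularity="per_grouped_channel",
    group_size=16,
)
config = PostTrainingPalettizerConfig(global_config=module_config)
```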
- class coremltools.optimize.torch.palettization.PostTrainingPalettizer(model: Module, config: PostTrainingPalettizerConfig | None = None)[source]
Perform post-training palettization on a torch model. After palettization, all weights in supported layers point to entries in a lookup table computed by a k-means operation.
Example
```python
from collections import OrderedDict

import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    PostTrainingPalettizerConfig,
    PostTrainingPalettizer,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

# initialize the palettizer
config = PostTrainingPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
    }
)

ptpalettizer = PostTrainingPalettizer(model, config)
palettized_model = ptpalettizer.compress()
```
- Parameters:
model (torch.nn.Module) – Module to be compressed.
config (PostTrainingPalettizerConfig) – Config that specifies how different submodules in the model will be palettized.
- class coremltools.optimize.torch.palettization.PostTrainingPalettizerConfig(global_config: ModulePostTrainingPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModulePostTrainingPalettizerConfig | None] = NOTHING)[source]
Configuration class for specifying how different submodules of a model should be post-training palettized by PostTrainingPalettizer.
- Parameters:
global_config (ModulePostTrainingPalettizerConfig) – Config to be applied globally to all supported modules.
module_type_configs (dict of str to ModulePostTrainingPalettizerConfig) – Module type configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes.
module_name_configs (dict of str to ModulePostTrainingPalettizerConfig) – Module name configs applied to specific modules. This can be a dictionary with module names pointing to their corresponding ModulePostTrainingPalettizerConfig.
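The following sketch mixes a global config with per-module overrides; the module names "head.linear" and "stem.conv" are hypothetical, and skipping a module by setting its entry to None relies on the optional value type in the signature above:
```python
from coremltools.optimize.torch.palettization import PostTrainingPalettizerConfig

# 4 bits by default, 6 bits for one sensitive layer, and skip another
# (both module names are hypothetical).
config = PostTrainingPalettizerConfig.from_dict(
    {
        "global_config": {"n_bits": 4},
        "module_name_configs": {
            "head.linear": {"n_bits": 6},
            "stem.conv": None,
        },
    }
)
```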
- class coremltools.optimize.torch.palettization.ModuleSKMPalettizerConfig(n_bits: int = 4, lut_dtype=None, granularity='per_tensor', group_size: int | None = None, channel_axis: int = 0, cluster_dim: int | None = None, enable_per_channel_scale: bool = False)[source]
Configuration class for specifying global and module-level palettization options for the SKMPalettizer algorithm.
- Parameters:
n_bits (int) – Number of bits to use for palettizing the weights. Defaults to 4.
lut_dtype (torch.dtype) – The dtype to use for representing each element in lookup tables. When value is None, no quantization is performed. Supported values are torch.int8 and torch.uint8. Defaults to None.
granularity (PalettizationGranularity) – One of per_tensor or per_grouped_channel. Defaults to per_tensor.
group_size (int) – Specify the number of channels in a group. Only effective when granularity is per_grouped_channel.
channel_axis (int) – Specify the channel axis to form a group of channels. Only effective when granularity is per_grouped_channel. Defaults to the output channel axis.
cluster_dim (int) – The dimension of the centroids for each lookup table. The centroid is a scalar by default. When cluster_dim > 1, it indicates 2-D clustering, and each cluster_dim-length vector of weights along the output channel is palettized using the same 2-D centroid. The length of each entry in the lookup tables is equal to cluster_dim.
enable_per_channel_scale (bool) – When set to True, weights are normalized along the output channels using per-channel scales before being palettized. This is not supported with cluster_dim > 1.
This class supports two different configurations to structure the palettization:
1. Per-tensor palettization: This is the default configuration, where the whole tensor shares a single lookup table. The granularity is set to per_tensor, and group_size is None.
2. Per-grouped-channel palettization: In this configuration, group_size number of channels along channel_axis share the same lookup table. For example, for a weight matrix of shape (16, 25), if we provide group_size = 8, the shape of the lookup table would be (2, 2^n_bits).
Note
Grouping is currently only supported along either the input or output channel axis.
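As a sketch (the values are illustrative), the following config palettizes weights to a 4-bit LUT whose entries are additionally stored as int8, with per-channel scaling enabled:
```python
import torch

from coremltools.optimize.torch.palettization import ModuleSKMPalettizerConfig

# 4-bit palette whose LUT entries are themselves quantized to int8,
# with per-channel scales applied before clustering.
module_config = ModuleSKMPalettizerConfig(
    n_bits=4,
    lut_dtype=torch.int8,
    enable_per_channel_scale=True,
)
```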
- class coremltools.optimize.torch.palettization.SKMPalettizer(model: Module, config: SKMPalettizerConfig | None = None)[source]
Perform post-training palettization of weights by running a weighted k-means on the model weights. The weight values used for weighing different elements of a model’s weight matrix are computed using the Fisher information matrix, which is an approximation of the Hessian. These weight values indicate how sensitive a given weight element is: the more sensitive an element, the larger the impact perturbing or palettizing it has on the model’s loss function. This means that weighted k-means moves the clusters closer to the sensitive weight values, allowing them to be represented more exactly. This leads to a lower degradation in model performance after palettization. The Fisher information matrix is computed using a few samples of calibration data.
This algorithm implements SqueezeLLM: Dense-and-Sparse Quantization.
Example
```python
from collections import OrderedDict

import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    SKMPalettizer,
    SKMPalettizerConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

# user-defined loader that yields calibration samples (placeholder)
dataloader = load_calibration_data()


# define callable for loss function
def loss_fn(model, data):
    inp, target = data
    out = model(inp)
    return nn.functional.mse_loss(out, target)


# initialize the palettizer
config = SKMPalettizerConfig.from_dict(
    {
        "global_config": {
            "n_bits": 4,
        },
        "calibration_nsamples": 16,
    }
)

compressor = SKMPalettizer(model, config)
compressed_model = compressor.compress(dataloader=dataloader, loss_fn=loss_fn)
```
- Parameters:
model (torch.nn.Module) – Module to be compressed.
config (SKMPalettizerConfig) – Config that specifies how different submodules in the model will be compressed.
- class coremltools.optimize.torch.palettization.SKMPalettizerConfig(global_config: ModuleSKMPalettizerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModuleSKMPalettizerConfig | None] = NOTHING, calibration_nsamples: int = 128)[source]
Configuration class for specifying how different submodules of a model are palettized by SKMPalettizer.
- Parameters:
global_config (ModuleSKMPalettizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.
module_type_configs (dict of str to ModuleSKMPalettizerConfig) – Module type configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes.
module_name_configs (dict of str to ModuleSKMPalettizerConfig) – Module-level configs applied to specific modules. The name of the module must either be a regex or a fully qualified name that can be used to fetch it from the top-level module using the module.get_submodule(target) method.
calibration_nsamples (int) – Number of samples to be used for calibration.
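A sketch combining global, module-type, and module-name settings; the regex below is illustrative, and skipping matching modules by mapping them to None relies on the optional value type in the signature above:
```python
import torch.nn as nn

from coremltools.optimize.torch.palettization import (
    ModuleSKMPalettizerConfig,
    SKMPalettizerConfig,
)

config = SKMPalettizerConfig(
    # 4-bit palettization for every supported module by default.
    global_config=ModuleSKMPalettizerConfig(n_bits=4),
    # 2 bits for all Conv2d layers.
    module_type_configs={nn.Conv2d: ModuleSKMPalettizerConfig(n_bits=2)},
    # Skip modules whose fully qualified name matches this (illustrative) regex.
    module_name_configs={r".*lm_head.*": None},
    # Number of samples used to estimate the Fisher information.
    calibration_nsamples=64,
)
```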