Palettization
- coremltools.optimize.coreml.palettize_weights(*args, **kwargs)
Utility function to convert a float precision MLModel of type mlprogram to a compressed MLModel by reducing the overall number of weights using one or more lookup tables (LUT). A LUT contains a list of float values. An n-bit LUT has \(2^n\) entries.

For example, a float weight vector such as {0.3, 0.3, 0.5, 0.5} can be compressed using a 1-bit LUT: {0.3, 0.5}. In this case the float vector can be replaced with the 1-bit index vector {0, 0, 1, 1}.

This function iterates over all the weights in the mlprogram, discretizes their values, and constructs the LUT according to the algorithm specified in mode. The float values are then converted to the n-bit values, and the LUT is saved alongside each weight. The const ops storing weight values are replaced by constexpr_lut_to_dense ops.

At runtime, the LUT and the n-bit values are used to reconstruct the float weight values, which are then used to perform the float operation the weight is feeding into.
"uniform"
mode (a linear histogram):nbits = 4
mode = "uniform"
weight = [0.11, 0.19, 0.3, 0.08, 0.0, 0.02]
The weight can be converted to a palette with indices
[0, 1, 2, 3]
(2 bits). The indices are a byte array.The data range
[0.0, 0.3]
is divided into four partitions linearly, which is[0.0, 0.1, 0.2, 0.3]
.The LUT would be
[0.0, 0.1, 0.2, 0.3]
.The weight is rounded to
[0.1, 0.2, 0.3, 0.1, 0.0, 0.0]
and represented in the palette as indices[01b, 10b, 11b, 01b, 00b, 00b]
.
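The following NumPy snippet is a rough sketch (not the exact implementation used by coremltools) that reproduces the uniform-mode LUT and indices for this example:

    import numpy as np

    nbits = 2
    weight = np.array([0.11, 0.19, 0.3, 0.08, 0.0, 0.02])

    # Linear histogram: 2**nbits evenly spaced LUT entries spanning [v_min, v_max].
    v_min, v_max = weight.min(), weight.max()
    lut = np.linspace(v_min, v_max, 1 << nbits)  # [0.0, 0.1, 0.2, 0.3]

    # Each weight element is mapped to the index of its nearest LUT entry.
    indices = np.argmin(np.abs(weight[:, None] - lut[None, :]), axis=1)  # [1, 2, 3, 1, 0, 0]

    # At runtime, reconstruction is just a table lookup.
    reconstructed = lut[indices]  # [0.1, 0.2, 0.3, 0.1, 0.0, 0.0]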
- Parameters:
- mlmodel: MLModel
Model to be converted by a LUT. This MLModel should be of type mlprogram.
- config: OptimizationConfig
An OptimizationConfig object that specifies the parameters for weight palettization.
- joint_compression: bool
Whether to further compress the already-compressed input MLModel into a jointly compressed MLModel. See the channelwise_palettize_weights graph pass for information about which compression schemes can be further jointly palettized.
Take "prune + palettize" as an example of joint compression, where the input MLModel is already pruned and the non-zero entries are further palettized. In this case the weight values are represented by constexpr_lut_to_sparse + constexpr_sparse_to_dense ops:

    lut(sparse) -> constexpr_lut_to_sparse -> weight(sparse) -> constexpr_sparse_to_dense -> weight(dense)

A usage sketch of this flow is shown after the Examples section below.
- Returns:
- model: MLModel
The palettized MLModel instance.
Examples
    import coremltools as ct
    import coremltools.optimize as cto

    model = ct.models.MLModel("my_model.mlpackage")
    config = cto.coreml.OptimizationConfig(
        global_config=cto.coreml.OpPalettizerConfig(mode="kmeans", nbits=4)
    )
    compressed_model = cto.coreml.palettize_weights(model, config)
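The joint "prune + palettize" flow mentioned above can be sketched as follows. This is a minimal sketch, assuming the prune_weights API and OpMagnitudePrunerConfig from coremltools.optimize.coreml, with an illustrative (not recommended) 50% target sparsity:

    import coremltools as ct
    import coremltools.optimize as cto

    model = ct.models.MLModel("my_model.mlpackage")

    # Step 1: prune the dense weights (target_sparsity is illustrative).
    prune_config = cto.coreml.OptimizationConfig(
        global_config=cto.coreml.OpMagnitudePrunerConfig(target_sparsity=0.5)
    )
    pruned_model = cto.coreml.prune_weights(model, prune_config)

    # Step 2: further palettize only the non-zero entries of the pruned weights.
    palettize_config = cto.coreml.OptimizationConfig(
        global_config=cto.coreml.OpPalettizerConfig(mode="kmeans", nbits=4)
    )
    joint_model = cto.coreml.palettize_weights(
        pruned_model, palettize_config, joint_compression=True
    )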
- class coremltools.optimize.coreml.OpPalettizerConfig(mode: str = 'kmeans', nbits: int | None = None, lut_function: Callable | None = None, granularity: str | CompressionGranularity = CompressionGranularity.PER_TENSOR, group_size: int = 32, channel_axis: int | None = None, cluster_dim: int = 1, enable_per_channel_scale: bool = False, num_kmeans_workers: int = 1, weight_threshold: int | None = 2048)
- Parameters:
- nbits: int
Number of bits per weight. Required for kmeans or uniform mode, but must not be set for unique or custom mode. A LUT would have \(2^{nbits}\) entries, where nbits can be {1, 2, 3, 4, 6, 8}.
- mode: str
Determine how the LUT is constructed by specifying one of the following:

"kmeans" (default): The LUT is generated by k-means clustering, a method of vector quantization that groups similar data points together to discover underlying patterns by using a fixed number (k) of clusters in a dataset. A cluster refers to a collection of data points aggregated together because of certain similarities. nbits is required.

"uniform": The LUT is generated by a linear histogram:

    [v_min, v_min + scale, v_min + 2 * scale, ..., v_max]

where the weight is in the range [v_min, v_max], and scale = (v_max - v_min) / ((1 << nbits) - 1). nbits is required.

A histogram is a representation of the distribution of a continuous variable, in which the entire range of values is divided into a series of intervals (or bins) and the representation displays how many values fall into each bin. A linear histogram has evenly spaced bins, such as one bin per integer.
"unique"
: The LUT is generated by unique values in the weights. The weights are assumed to be on a discrete lattice but stored in a float data type. This parameter identifies the weights and converts them into the palettized representation.Do not provide
nbits
for this mode.nbits
is picked up automatically, with the smallest possible value in{1, 2, 4, 6, 8}
such that the number of the unique values is<= (1 << nbits)
. If the weight has> 256
unique values, the compression is skipped.For example:
If the weights are
{0.1, 0.2, 0.3, 0.4}
andnbits=2
, the weights are converted to{00b, 01b, 10b, 11b}
, and the generated LUT is[0.1, 0.2, 0.3, 0.4]
.If the weights are
{0.1, 0.2, 0.3, 0.4}
andnbits=1
, nothing happens because the weights are not a 1-bit lattice.If the weights are
{0.1, 0.2, 0.3, 0.4, 0.5}
andnbits=2
, nothing happens because the weights are not a 2-bit lattice.
"custom"
: The LUT and palettization parameters are calculated using a custom function. If this mode is selected thenlut_function
must be provided.Do not provide
nbits
for this mode. The user should customizenbits
in thelut_function
implementation.
- lut_function: callable
A callable function which computes the weight palettization parameters. This must be provided if the mode is set to "custom". The function takes the weight as input and returns the LUT together with the indices:
- weight: np.ndarray
A float precision numpy array.
- Returns:
- lut: list[float]
The lookup table.
- indices: list[int]
A list of indices for each element.
The following is an example that extracts the top_k elements as the LUT. Given that weight = [0.1, 0.5, 0.3, 0.3, 0.5, 0.6, 0.7], the lut_function produces lut = [0, 0.5, 0.6, 0.7], indices = [0, 1, 0, 0, 1, 2, 3].

    import numpy as np


    def lut_function(weight):
        # In this example, we assume all elements in the weight are >= 0.
        weight = weight.flatten()
        nbits = 2

        # Build the LUT from the top k largest unique elements in the weight.
        # Note that k = (1 << nbits) - 1, so the first LUT entry can be 0.
        unique_elements = np.unique(weight)
        k = (1 << nbits) - 1
        top_k = np.sort(unique_elements)[-k:]
        lut = [0.0] + top_k.tolist()

        # Compute the indices: map each element to its LUT entry, or 0 if absent.
        mapping = {v: idx for idx, v in enumerate(lut)}
        indices = [mapping.get(v, 0) for v in weight]

        return lut, indices
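As a usage sketch (reusing the model and cto import from the earlier example), the custom function is passed to OpPalettizerConfig via lut_function together with mode="custom":

    config = cto.coreml.OptimizationConfig(
        global_config=cto.coreml.OpPalettizerConfig(
            mode="custom", lut_function=lut_function
        )
    )
    compressed_model = cto.coreml.palettize_weights(model, config)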
- granularity: str
Granularity for quantization. One of the following:
* "per_tensor" (default)
* "per_grouped_channel"
A combined configuration sketch using this and the following parameters is shown after this parameter list.
- group_size: int
Specify the number of channels in a group. Only effective when granularity is "per_grouped_channel".
Defaults to 32.
- channel_axis: Optional[int] = None
Specify the channel axis that forms the channel groups. Only effective when granularity is "per_grouped_channel".
Defaults to None, in which case the axis is picked automatically based on the op type.
- cluster_dim: int
The dimension of centroids for each lookup table. When cluster_dim == 1, it is scalar palettization, where each entry in the lookup table is a scalar. When cluster_dim > 1, it is vector palettization, where each entry in the lookup table is a vector of length cluster_dim.
More specifically, when cluster_dim > 1, each vector of cluster_dim consecutive weight elements along the channel axis is palettized using the same entry of the 2-D centroid table.
Defaults to 1.
- enable_per_channel_scale: bool
When set to True, weights are normalized along the output channels using per channel scales before being palettized.
- num_kmeans_workers: int
Number of worker processes to use for performing k-means. It is recommended to use more than one worker process to parallelize the clustering, especially when multiple CPUs are available.
Defaults to 1.
- weight_threshold: int
The size threshold, above which weights are compressed. That is, a weight tensor is palettized only if its total number of elements is greater than weight_threshold.
For example, if weight_threshold = 1024 and a weight tensor is of shape [10, 20, 1, 1], hence 200 elements, it will not be compressed.
If not provided, it defaults to 2048, in which case weights with more than 2048 elements are compressed.
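The following configuration sketch ties several of these parameters together (reusing the model and cto import from the earlier example; the specific values are illustrative, not recommendations):

    config = cto.coreml.OptimizationConfig(
        global_config=cto.coreml.OpPalettizerConfig(
            mode="kmeans",
            nbits=4,
            granularity="per_grouped_channel",
            group_size=16,  # every 16 channels along the channel axis share one LUT
            enable_per_channel_scale=True,
            num_kmeans_workers=4,  # run k-means in 4 parallel worker processes
        )
    )
    compressed_model = cto.coreml.palettize_weights(model, config)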