Config API¶
Palettization Configs follow the same philosophy as the Quantization Config.
They are simpler as palettization applies only to the weights in the model.
(Hence there are no op_input_spec and op_output_spec fields in the ModuleKMeansPalettizerConfig and OpKMeansPalettizerConfig.)
PalettizationSpec¶
PalettizationSpec defines the following key properties, among others (for full list see API reference):
n_bits: Number of bits per LUT index. LUT size = 2^n_bits centroids (default: 4)granularity: controls whether there is one LUT for the entire weight tensor or one LUT per group of channels. Allowed:PerTensorGranularity() orPerGroupedChannelGranularity(), defaults to the formercluster_dim: Dimension of the cluster centers, i.e. the entries in the LUT. Defaults to 1; if > 1, it results in vector palettization.lut_qspec: OptionalQuantizationSpecdescribing how to quantize the LUT centroid values themselves. Defaults toNone(LUT centroids are stored at the same dtype as the uncompressed weights). Supported dtypes aretorch.int8,torch.uint8,torch.float8_e4m3fn, andtorch.float8_e5m2.
from coreai_opt.palettization import PalettizationSpec
from coreai_opt.palettization.spec import (
PerGroupedChannelGranularity,
default_weight_palettization_spec,
)
from coreai_opt.quantization import QuantizationSpec
# 4-bit per-tensor (default — 16 centroids)
spec = default_weight_palettization_spec()
# 2-bit per group of 8 channels — finer granularity, better accuracy
spec = PalettizationSpec(
n_bits=2,
granularity=PerGroupedChannelGranularity(axis=0, group_size=8),
)
# 4-bit with quantized LUT entries — LUT centroids stored as INT8
spec = PalettizationSpec(
n_bits=4,
lut_qspec=QuantizationSpec(dtype=torch.int8, qscheme="symmetric"),
)
# Cluster pairs of weight values instead of individual scalars
spec = PalettizationSpec(n_bits=4, cluster_dim=2)
Note
Reproducibility with cluster_dim > 1:
When cluster_dim > 1, palettization uses vector k-means whose centroid initialization relies on numpy.random and torch.randint, so it is non-deterministic.
# model gets different weights when we run it multiple times
model_1 = KMeansPalettizer(model, config).prepare(example_inputs)
model_2 = KMeansPalettizer(model, config).prepare(example_inputs)
To obtain reproducible results, seed numpy and torch before each call to prepare():
seed = 42
# models now have identical palettized weights
np.random.seed(seed)
torch.manual_seed(seed)
model_1 = KMeansPalettizer(model, config).prepare(example_inputs)
np.random.seed(seed)
torch.manual_seed(seed)
model_2 = KMeansPalettizer(model, config).prepare(example_inputs)
When prepare() is called with num_workers > 1, k-means runs in spawned worker processes that do not inherit the parent’s RNG state, so the seeding advice above is only effective with num_workers=1. Use the sequential path if you need reproducible vector palettization.
Scalar palettization (cluster_dim == 1, the default) is deterministic and does not require seeding.
Config classes and their defaults¶
The palettization config system mirrors quantization’s three-class hierarchy, but scoped to weights only (no activation quantization):
KMeansPalettizerConfig— the top-level config for the entire model. It holds aglobal_config, plus optionalmodule_type_configsandmodule_name_configsfor overrides. Same precedence as quantization:module_name_configs>module_type_configs>global_config.ModuleKMeansPalettizerConfig— controls palettization for all ops within a module’s scope (or all modules if used as aglobal_config). LikeModuleQuantizerConfig, it specifies a defaultop_state_specfor ops in the module and allows overrides viaop_type_config,op_name_config, andmodule_state_spec. Since palettization is weight-only, the activation/IO fields (op_input_spec,op_output_spec,module_input_spec,module_output_spec) are absent. For a given op’s weight, the spec is resolved in this priority order (highest first):module_state_spec, the matching entry inop_name_config, the matching entry inop_type_config, then the module’sop_state_spec.OpKMeansPalettizerConfig— controls palettization for a specific op type or op name. Onlyop_state_specis used.
Default behavior when no arguments are provided¶
Creating any of these config classes with no arguments gives you a ready-to-use 4-bit weight palettization configuration:
# All three of these produce equivalent default palettization settings:
config = KMeansPalettizerConfig()
# is equivalent to:
config = KMeansPalettizerConfig(global_config=ModuleKMeansPalettizerConfig())
# which is equivalent to:
config = KMeansPalettizerConfig(
global_config=ModuleKMeansPalettizerConfig(
op_state_spec={
"weight": default_weight_palettization_spec(),
"in_proj_weight": default_weight_palettization_spec(),
},
)
)
op_config = OpKMeansPalettizerConfig()
# is equivalent to:
op_config = OpKMeansPalettizerConfig(
op_state_spec={
"weight": default_weight_palettization_spec(),
"in_proj_weight": default_weight_palettization_spec(),
},
)
The default applies
default_weight_palettization_spec()— 4-bit, per-tensor granularity, scalar clustering — to parameters named"weight"and"in_proj_weight". Other state tensors (e.g.,"bias") are left uncompressed.If you need different behavior — such as palettizing custom parameter names, excluding certain modules, or applying different bit widths to different layers, see the Examples section.
Examples¶
Several examples below configure specific module types or module names. To determine these for your model, see How to get names + types. Since palettization only supports eager execution mode, only the eager mode guidance in that section is relevant.
Apply 4-bit palettization globally, 8-bit to linear layers¶
Apply 4-bit palettization to all supported layers, and override linear layers to 8-bit.
# programmatic — using presets
import torch.nn as nn
from coreai_opt.palettization import (
KMeansPalettizerConfig,
ModuleKMeansPalettizerConfig,
)
# define a config that applies 4-bit per-grouped-channel palettization to all supported layers, using one of the "pre-defined" presets.
config = KMeansPalettizerConfig.presets.w4()
# then update this config, to change the palettization for just the linear layers: to 8-bit per-tensor
config.set_module_type(nn.Linear, ModuleKMeansPalettizerConfig.presets.w8())
The snippet above applies 4-bit palettization globally (covering Conv2d and all other supported modules), then overrides Linear layers to 8-bit.
Config chaining¶
The setters also return the config itself, so multiple modifications can be chained into a single expression. The snippet above is equivalent to:
config = KMeansPalettizerConfig.presets.w4().set_module_type(
nn.Linear, ModuleKMeansPalettizerConfig.presets.w8()
)
# yaml
kmeans_palettization_config:
global_config:
op_state_spec:
weight:
n_bits: 4
granularity: { type: per_grouped_channel, axis: 0, group_size: 16 }
module_type_configs:
torch.nn.modules.linear.Linear:
op_state_spec:
weight:
n_bits: 8
granularity: { type: per_tensor }
Apply 4-bit palettization to conv layers, 8-bit to linear layers¶
When you want to palettize only specific module types and leave everything else uncompressed, construct the config explicitly without a global_config. Each module type gets its own ModuleKMeansPalettizerConfig, and modules not listed in module_type_configs are skipped.
# programmatic — explicit (scoped to specific module types)
from coreai_opt.palettization import (
KMeansPalettizerConfig,
ModuleKMeansPalettizerConfig,
PalettizationSpec,
)
config = KMeansPalettizerConfig(
module_type_configs={
"torch.nn.modules.linear.Linear": ModuleKMeansPalettizerConfig(
op_state_spec={"weight": PalettizationSpec(n_bits=8)},
),
"torch.nn.modules.conv.Conv2d": ModuleKMeansPalettizerConfig(
op_state_spec={"weight": PalettizationSpec(n_bits=4)},
),
},
)
# yaml
kmeans_palettization_config:
module_type_configs:
torch.nn.modules.linear.Linear:
op_state_spec:
weight:
n_bits: 8
granularity: { type: per_tensor }
torch.nn.modules.conv.Conv2d:
op_state_spec:
weight:
n_bits: 4
granularity: { type: per_tensor }