Source code for coreai_opt.palettization.config.palettization_config

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

from __future__ import annotations

from typing import TYPE_CHECKING, ClassVar, final

from pydantic import PositiveInt, model_validator

from coreai_opt.config import (
    CompressionConfig,
    ModuleCompressionConfig,
    OpCompressionConfig,
    WeightOnlyModuleValidationMixin,
    WeightOnlyOpValidationMixin,
)
from coreai_opt.palettization.spec import (
    PalettizationSpec,
    default_weight_palettization_spec,
)

if TYPE_CHECKING:
    from coreai_opt.palettization.config._presets import (
        _KMeansPalettizerConfigPresets,
        _ModuleKMeansPalettizerConfigPresets,
    )

# Key names consumed by KMeansPalettizerConfig below as its ``_CONFIG_KEY`` and
# ``_SPEC_KEY`` class attributes (the config key pattern used in
# from_dict/from_yaml).
_KMEANS_PALETTIZATION_CONFIG = "kmeans_palettization_config"
_PALETTIZATION_SPEC = "palettization_spec"


class OpKMeansPalettizerConfig(WeightOnlyOpValidationMixin, OpCompressionConfig[PalettizationSpec]):
    """
    Operation-level palettization configuration.

    Palettization is a weight-only compression technique, so it doesn't apply to
    activations (inputs/outputs). Only ``op_state_spec`` is used to configure
    which state tensors (e.g., weights, biases) should be palettized.

    Attributes:
        op_state_spec (dict[str, PalettizationSpec | None] | None): Palettization
            specifications for operation state tensors (parameters, buffers,
            constants). Keys are string names (e.g., "weight", "bias") or "*" to
            refer to all state inputs. Values are PalettizationSpec objects or
            None to disable palettization for that state tensor.
            Default: 4-bit palettization for "weight" and "in_proj_weight" state
            tensors via ``default_weight_palettization_spec()``.

    Example:
        >>> # Palettize weights with 2-bit LUT
        >>> op_config = OpKMeansPalettizerConfig(
        ...     op_state_spec={
        ...         "weight": PalettizationSpec(n_bits=2)
        ...     }
        ... )
        >>>
        >>> # Disable palettization for a specific operation
        >>> op_config = OpKMeansPalettizerConfig(
        ...     op_state_spec=None  # No state tensors will be palettized
        ... )
    """

    @classmethod
    def get_default_state_spec(cls) -> dict[str, PalettizationSpec | None]:
        """Return the default state spec for palettization.

        A single spec object from ``default_weight_palettization_spec()`` is
        shared by the "weight" and "in_proj_weight" entries.
        """
        return dict.fromkeys(("weight", "in_proj_weight"), default_weight_palettization_spec())
@final
class ModuleKMeansPalettizerConfig(
    WeightOnlyModuleValidationMixin,
    ModuleCompressionConfig[OpKMeansPalettizerConfig, PalettizationSpec],
):
    """K-means palettization settings for one specific module.

    This class manages palettization settings for an entire module, including:
    - Operation-level configurations (default, by type, by name)
    - Module-level state (parameter) palettization

    The operation configurations follow a hierarchical precedence:
    1. op_name_config (most specific - applies to operations matching a name pattern)
    2. op_type_config (applies to operations of a specific type)
    3. op_state_spec (least specific - applies to all operations not otherwise configured)

    Module-level state settings treat the module as an opaque entity, setting
    palettization settings for specified tensors and ignoring op specific
    palettization capabilities. Module-level settings also don't check whether
    the operation receiving the palettized tensor is a registered operation or
    not. Module-level settings will override any op specific settings.

    Attributes:
        op_state_spec (dict[str, PalettizationSpec | None] | None): Palettization
            specifications for operation state tensors (parameters, buffers,
            constants) applied to all registered operations/patterns within this
            module that don't have a more specific configuration. Keys can be
            string names (e.g. "weight", "bias") or "*" to refer to all state
            inputs. Values are PalettizationSpec objects or None defining how to
            palettize each state tensor. None value represents disabling
            palettization.
            Default: 4-bit palettization for "weight" and "in_proj_weight" state
            tensors via ``default_weight_palettization_spec()``.
        op_type_config (dict[str, OpKMeansPalettizerConfig | None] | None):
            Operation type-specific configurations. Keys are operation type names
            (e.g., "aten.linear.default", "aten.conv2d.default"). Values are
            OpKMeansPalettizerConfig objects or None, defining how to palettize
            operations of that type. None value represents disabling
            palettization.
            Default: {} (empty dict, no type-specific configs)
        op_name_config (dict[str, OpKMeansPalettizerConfig | None] | None):
            Operation name-specific configurations. Keys are operation name
            patterns (supports regex matching). Values are
            OpKMeansPalettizerConfig objects or None, defining how to palettize
            operations matching those names. None value represents disabling
            palettization.
            Default: {} (empty dict, no name-specific configs)
        module_state_spec (dict[str, PalettizationSpec | None] | None):
            Palettization specifications for module state tensors (parameters,
            buffers, and constants). Module state settings will override op state
            settings for the same state tensors. Keys can be string names
            (e.g. "weight", "bias") or "*" to refer to all state inputs. Values
            are PalettizationSpec objects or None. None value represents
            disabling palettization.
            Default: {} (empty dict, no specific module state settings)
        enable_fast_kmeans_mode (bool): When True, enables optimizations for
            faster K-means clustering by rounding the weights before clustering
            if data is in float16 range. If weight dtype is float32, weights are
            cast to float16 and then rounded. This is not supported with
            ``cluster_dim > 1``.
            Default: True.
        rounding_precision (PositiveInt): Number of decimal places to round to
            during fast K-means clustering. Higher values preserve more precision
            but may reduce speed benefits. Only used when
            enable_fast_kmeans_mode is True.
            Default: 4.

    Example:
        >>> config = ModuleKMeansPalettizerConfig()  # Uses defaults
        >>> # Or with custom settings:
        >>> from coreai_opt.palettization.spec import PalettizationSpec
        >>> config = ModuleKMeansPalettizerConfig(
        ...     op_state_spec={"weight": PalettizationSpec(n_bits=2)},
        ...     enable_fast_kmeans_mode=False,  # Disable for maximum precision
        ...     rounding_precision=6  # Higher precision when fast mode is enabled
        ... )
    """

    def __init_subclass__(cls, **kwargs):
        # Subclassing is rejected because presets remain bound to the base
        # class. Revisit if subclass support is needed.
        message = f"{cls.__name__} cannot subclass ModuleKMeansPalettizerConfig (marked final)."
        super().__init_subclass__(**kwargs)
        raise TypeError(message)

    enable_fast_kmeans_mode: bool = True
    rounding_precision: PositiveInt = 4

    # Namespace exposing built-in preset constructors.
    presets: ClassVar[_ModuleKMeansPalettizerConfigPresets]

    @model_validator(mode="after")
    def validate_fast_kmeans_cluster_dim_constraint(
        self,
    ) -> ModuleKMeansPalettizerConfig:
        """Reject fast K-means mode combined with ``cluster_dim > 1``.

        Only the specs held in ``op_state_spec`` are inspected.

        Returns:
            This instance, unchanged, when the constraint holds.

        Raises:
            ValueError: If ``enable_fast_kmeans_mode`` is True and some non-None
                spec in ``op_state_spec`` has ``cluster_dim > 1``. The first
                offending entry (in dict order) is reported.
        """
        # Nothing to validate without specs, or when fast mode is off.
        if self.op_state_spec is None:
            return self
        if not self.enable_fast_kmeans_mode:
            return self

        # Find the first state whose spec uses vector (cluster_dim > 1) clustering.
        first_violation = next(
            (
                (state_name, state_spec)
                for state_name, state_spec in self.op_state_spec.items()
                if state_spec is not None and state_spec.cluster_dim > 1
            ),
            None,
        )
        if first_violation is None:
            return self

        state, spec = first_violation
        raise ValueError(
            f"enable_fast_kmeans_mode is not supported when cluster_dim > 1. "
            f"Got enable_fast_kmeans_mode={self.enable_fast_kmeans_mode}, "
            f"cluster_dim={spec.cluster_dim} for state '{state}'"
        )
@final
class KMeansPalettizerConfig(CompressionConfig[ModuleKMeansPalettizerConfig]):
    """
    Top-level configuration for kmeans palettization of a model.

    This class manages the complete palettization configuration for a neural
    network model, organizing module-level configurations in a hierarchical
    structure. It inherits from CompressionConfig and specializes it for
    palettization using ModuleKMeansPalettizerConfig.

    The configuration lookup follows a hierarchical precedence (most to least
    specific):
    1. module_name_configs - Applies to module instances matching a name pattern
       (supports regex)
    2. module_type_configs - Applies to all modules of a specific type
       (e.g., torch.nn.modules.linear.Linear)
    3. global_config - Default configuration applied to all modules not
       otherwise configured

    Attributes:
        global_config (ModuleKMeansPalettizerConfig | None): Default module-level
            palettization configuration applied to all modules that don't have a
            more specific configuration. When KMeansPalettizerConfig is
            initialized with no arguments, a default global_config is
            automatically created with standard 4-bit palettization. Setting
            global_config to None disables palettization by default globally.
            Default: Auto-created with 4-bit palettization spec when no args
            provided
        module_type_configs (dict[str, ModuleKMeansPalettizerConfig | None] | None):
            Module type-specific configurations. Keys are fully-qualified module
            type names (e.g., "torch.nn.modules.linear.Linear",
            "torch.nn.modules.conv.Conv2d"). Values are
            ModuleKMeansPalettizerConfig objects or None to disable palettization
            for that module type.
            Default: {} (empty dict, no type-specific configs)
        module_name_configs (dict[str, ModuleKMeansPalettizerConfig | None] | None):
            Module name-specific configurations. Keys are module name patterns
            (supports regex matching, e.g., "model.layer1.*",
            "decoder.layers.0"). Values are ModuleKMeansPalettizerConfig objects
            or None to disable palettization for matching modules.
            Default: {} (empty dict, no name-specific configs)

    Example:
        >>> import torch.nn as nn
        >>> config = KMeansPalettizerConfig()  # Uses defaults
        >>> # Or with custom settings:
        >>> config = KMeansPalettizerConfig(
        ...     module_type_configs={nn.Linear: ModuleKMeansPalettizerConfig(...)},
        ...     module_name_configs={"layer1": None}  # Skip palettization
        ... )
    """

    def __init_subclass__(cls, **kwargs):
        # Subclassing is rejected because presets remain bound to the base
        # class. Revisit if subclass support is needed.
        message = f"{cls.__name__} cannot subclass KMeansPalettizerConfig (marked final)."
        super().__init_subclass__(**kwargs)
        raise TypeError(message)

    # Config key pattern used in from_dict/from_yaml.
    _CONFIG_KEY: ClassVar[str] = _KMEANS_PALETTIZATION_CONFIG
    _SPEC_KEY: ClassVar[str] = _PALETTIZATION_SPEC

    # Namespace exposing built-in and registered preset constructors.
    presets: ClassVar[_KMeansPalettizerConfigPresets]
# Preset wiring — after the class so _presets can resolve references at runtime. from coreai_opt.palettization.config._presets import ( # noqa: E402, PLC0415 _KMeansPalettizerConfigPresets, _ModuleKMeansPalettizerConfigPresets, ) KMeansPalettizerConfig.presets = _KMeansPalettizerConfigPresets(owner_cls=KMeansPalettizerConfig) ModuleKMeansPalettizerConfig.presets = _ModuleKMeansPalettizerConfigPresets( owner_cls=ModuleKMeansPalettizerConfig )