Source code for coreai_opt.quantization.config.quantization_config

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

"""Quantization config class definitions."""

from __future__ import annotations

from enum import auto
from typing import TYPE_CHECKING, ClassVar, NamedTuple, Self, TypeAlias, final

from pydantic import BaseModel, ConfigDict, Field, model_validator

from coreai_opt._utils.config_utils import ALL_TENSORS as _ALL_TENSORS
from coreai_opt.common import (
    _DeprecatedMemberEnumMeta,
    _StrEnum,
)
from coreai_opt.config import (
    CompressionConfig,
    ModuleCompressionConfig,
    OpCompressionConfig,
)
from coreai_opt.quantization.spec import (
    QuantizationSpec,
    default_activation_quantization_spec,
    default_weight_quantization_spec,
)

if TYPE_CHECKING:
    from coreai_opt.quantization.config._presets import (
        _ModuleQuantizerConfigPresets,
        _QuantizerConfigPresets,
    )


class QATSchedule(BaseModel):
    """Schedule for controlling observer and fake quantization state in QAT.

    Defines step thresholds for enabling/disabling observers and fake
    quantization during quantization-aware training. Must be used in
    conjunction with the ``quantizer.step()`` API to advance the schedule.

    The step values correspond to the cadence at which ``quantizer.step()``
    is called. For example, if ``step()`` is called once per batch, the
    thresholds represent batch steps; if called once per epoch, they
    represent epochs.

    Calling ``step()`` increments the step counter and immediately applies
    the corresponding observer/fake-quantization state. Where you place
    ``step()`` in your training loop determines when the model sees the new
    state.

    Observers are active on the half-open step interval
    ``[enable_observer, disable_observer)``; fake quantization is active
    from ``enable_fake_quant`` onwards.

    Attributes:
        enable_observer: Step count at which observers are enabled.
            Must be >= 0.
        enable_fake_quant: Step count at which fake quantization is enabled.
            Must be >= enable_observer.
        disable_observer: Step count at which observers are disabled.
            Must be > enable_observer and >= enable_fake_quant if provided.
            None means observers are never disabled by the schedule.

    Example:
        >>> schedule = QATSchedule(
        ...     enable_observer=0,
        ...     enable_fake_quant=500,
        ...     disable_observer=1500,
        ... )

    Note:
        In graph execution mode, when consecutive modules both quantize the
        intermediate edge (one via ``op_output_spec``, the next via
        ``op_input_spec``), graph mode deduplicates them into a single
        fake-quantize node. The schedule of the consuming module is always
        applied to the deduplicated node, irrespective of the choice of
        deduplication made by the graph preparation.

    Note:
        When two modules share a weight parameter and have different
        schedules, the schedule of the first module encountered in the
        module tree is applied. A warning is emitted for the conflict if
        there is no fake-quantize node deduplication happening (in Eager
        execution mode).
    """

    model_config = ConfigDict(frozen=True)

    class _ScheduleState(NamedTuple):
        # Whether observers / fake quantization are active at a given step.
        obs_on: bool
        fq_on: bool

    enable_observer: int = Field(default=0, ge=0)
    enable_fake_quant: int = Field(default=0, ge=0)
    disable_observer: int | None = Field(default=None, gt=0)

    @model_validator(mode="after")
    def _validate_schedule(self) -> Self:
        """Check that the thresholds are ordered consistently.

        Returns:
            Self: The validated schedule.

        Raises:
            ValueError: If ``enable_fake_quant`` precedes ``enable_observer``,
                or if ``disable_observer`` is set and is not strictly after
                ``enable_observer`` or precedes ``enable_fake_quant``.
        """
        if self.enable_fake_quant < self.enable_observer:
            raise ValueError(
                f"enable_fake_quant ({self.enable_fake_quant}) must be >= "
                f"enable_observer ({self.enable_observer})"
            )
        # With no disable threshold there is nothing further to order.
        if self.disable_observer is None:
            return self
        if self.disable_observer <= self.enable_observer:
            raise ValueError(
                f"disable_observer ({self.disable_observer}) must be > "
                f"enable_observer ({self.enable_observer})"
            )
        if self.disable_observer < self.enable_fake_quant:
            raise ValueError(
                f"disable_observer ({self.disable_observer}) must be >= "
                f"enable_fake_quant ({self.enable_fake_quant})"
            )
        return self

    def _compute_state(self, step_count: int) -> _ScheduleState:
        """Return the observer/fake_quant state at the given step.

        Args:
            step_count: Current value of the schedule's step counter.

        Returns:
            _ScheduleState: ``obs_on`` is True while
            ``enable_observer <= step_count`` and the ``disable_observer``
            threshold (if any) has not been reached; ``fq_on`` is True once
            ``step_count >= enable_fake_quant``.
        """
        observer_started = step_count >= self.enable_observer
        # Explicit None check: "no threshold" must not be confused with a
        # falsy threshold value, and no float sentinel is needed.
        observer_stopped = (
            self.disable_observer is not None and step_count >= self.disable_observer
        )
        return self._ScheduleState(
            obs_on=observer_started and not observer_stopped,
            fq_on=step_count >= self.enable_fake_quant,
        )
# Key names that QuantizerConfig exposes as its _CONFIG_KEY / _SPEC_KEY class
# attributes (the config key pattern used in from_dict/from_yaml).
_QUANTIZATION_CONFIG = "quantization_config"
_QUANTIZATION_SPEC = "quantization_spec"

# Alias for a dict keyed by str or int whose values are a QuantizationSpec or
# None. This matches the annotated return type of the default input/output
# spec getters on OpQuantizerConfig.
_ACTIVATION_SPEC_DICT: TypeAlias = dict[str | int, QuantizationSpec | None]
# Alias for a dict keyed by str only, matching the annotated return type of
# the default state spec getter on OpQuantizerConfig.
_STATE_SPEC_DICT: TypeAlias = dict[str, QuantizationSpec | None]
class ExecutionMode(_StrEnum, metaclass=_DeprecatedMemberEnumMeta):
    """Enum representing quantization execution modes.

    Each member is a string value representing the execution mode used for
    quantization.

    Attributes:
        GRAPH: Graph-based quantization using ``torch.export`` to capture the
            model as an FX graph, then applying quantization on top. Built on
            ``torchao``'s PT2E implementation. Requires the model to be
            exportable via ``torch.export.export``. Recommended default.
        EAGER: Eager-mode quantization that works directly on ``nn.Module``
            without graph capture. Supports dynamic control flow (if/else,
            loops) and is the fallback when a model is not exportable.
    """

    GRAPH = auto()
    EAGER = auto()

    # Maps a deprecated member name to the name of the member replacing it.
    # NOTE(review): presumably consumed by _DeprecatedMemberEnumMeta to resolve
    # the old name at runtime — confirm against the metaclass implementation.
    __deprecated_aliases__: ClassVar[dict[str, str]] = {"PT2E": "GRAPH"}

    if TYPE_CHECKING:
        # Surface the deprecated alias above for static type checkers.
        PT2E: ExecutionMode
        """Deprecated. Use ``ExecutionMode.GRAPH`` instead."""
class OpQuantizerConfig(OpCompressionConfig[QuantizationSpec]):
    r"""Configuration class for quantization at the operation level.

    This class specifies quantization settings for inputs, outputs, and state
    tensors of individual operations (ops) in a neural network. Each tensor can
    have its own quantization specification.

    Quantization for operations will require the operation to be registered.
    Even if globally all ops are configured to be quantized, ops which are not
    recognized will not be quantized.

    Attributes:
        op_input_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for operation inputs.
            Keys can be either all indices or all string names, but not a mix
            of both. The special key "\*" can be used in both cases to refer to
            all inputs.
            Example keys:

            - int: Input index (e.g., 0 for first input, 1 for second input)
            - str: Named input identifier (e.g., "x", "input_0")
            - "\*": Applies to all inputs for the operation. Other tensors can
              be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to
            quantize each input. None value represents disabling quantization.
            Default: {"\*": default_activation_quantization_spec()}
            (int8 quantization for all inputs)
        op_output_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for operation outputs.
            Keys can be either all indices or all string names, but not a mix
            of both. The special key "\*" can be used in both cases to refer to
            all outputs.
            Example keys:

            - int: Output index (e.g., 0 for first output, 1 for second output)
            - str: Named output identifier (e.g., "y", "output_0")
            - "\*": Applies to all outputs for the operation. Other tensors can
              be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to
            quantize each output. None value represents disabling quantization.
            Default: {"\*": default_activation_quantization_spec()}
            (int8 quantization for all outputs)
        op_state_spec (dict[str, QuantizationSpec | None] | None):
            Quantization specifications for operation state tensors
            (parameters, buffers, constants).
            Keys can be string names (e.g. "weight", "bias") or "\*" to refer
            to all state inputs.
            Values are QuantizationSpec objects or None defining how to
            quantize each state tensor. None value represents disabling
            quantization.
            Default: {"weight": default_weight_quantization_spec()}
            (int8 quantization for weight inputs)

    Example:
        >>> # Quantize first input, disable first output, quantize weight tensor
        >>> op_config = OpQuantizerConfig(
        ...     op_input_spec={
        ...         0: QuantizationSpec(
        ...             dtype=torch.int8,
        ...             qscheme="symmetric",
        ...             granularity={"type": "per_tensor"},
        ...             fake_quantize_cls="default",
        ...             qparam_calculator_cls="moving_average",
        ...             range_calculator_cls="minmax",
        ...         )
        ...     },
        ...     op_output_spec={
        ...         0: None
        ...     },
        ...     op_state_spec={
        ...         "weight": QuantizationSpec(
        ...             dtype=torch.int4,
        ...             qscheme="symmetric",
        ...             granularity={"type": "per_channel", "axis": 1},
        ...             fake_quantize_cls="default",
        ...             qparam_calculator_cls="default",
        ...             range_calculator_cls="minmax",
        ...         )
        ...     }
        ... )
    """

    @classmethod
    def get_default_input_spec(cls) -> dict[str | int, QuantizationSpec | None]:
        """Provide default input spec for quantization.

        Returns:
            dict[str | int, QuantizationSpec | None]: A new dict mapping the
            all-tensors wildcard key to the default activation spec.
        """
        return {_ALL_TENSORS: default_activation_quantization_spec()}

    @classmethod
    def get_default_output_spec(cls) -> dict[str | int, QuantizationSpec | None]:
        """Provide default output spec for quantization.

        Returns:
            dict[str | int, QuantizationSpec | None]: A new dict mapping the
            all-tensors wildcard key to the default activation spec.
        """
        return {_ALL_TENSORS: default_activation_quantization_spec()}

    @classmethod
    def get_default_state_spec(cls) -> dict[str, QuantizationSpec | None]:
        """Provide default state spec for quantization.

        Returns:
            dict[str, QuantizationSpec | None]: A new dict mapping ``"weight"``
            to the default weight spec; no other state tensor is listed.
        """
        return {"weight": default_weight_quantization_spec()}
@final
class ModuleQuantizerConfig(ModuleCompressionConfig[OpQuantizerConfig, QuantizationSpec]):
    r"""Configuration class for quantization at the module level.

    This class manages quantization settings for an entire module, including:

    - Operation-level configurations (default, by type, by name)
    - Module-level input/output quantization
    - Module-level state (parameter) quantization

    The operation configurations follow a hierarchical precedence:

    1. op_name_config (most specific - applies to operations matching a name
       pattern)
    2. op_type_config (applies to operations of a specific type)
    3. op_input/output/state_spec (least specific - applies to all operations
       not otherwise configured)

    Module-level input, output, and state settings treat the module as an
    opaque entity, setting quantization settings for specified tensors and
    ignoring op specific quantization capabilities. Module-level settings also
    don't check whether the operation receiving the quantized tensor is a
    registered operation or not. Module-level settings will override any op
    specific settings.

    Attributes:
        op_input_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for operation inputs applied to all
            registered operations/patterns within this module that don't have
            a more specific configuration.
            Keys can be either all indices or all string names, but not a mix
            of both. The special key "\*" can be used in both cases to refer to
            all inputs.
            Example keys:

            - int: Input index (e.g., 0 for first input, 1 for second input)
            - str: Named input identifier (e.g., "x", "input_0")
            - "\*": Applies to all inputs for the operation. Other tensors can
              be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to
            quantize each input. None value represents disabling quantization.
            Default: {"\*": default_activation_quantization_spec()}
            (int8 quantization for all inputs)
        op_output_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for operation outputs applied to all
            registered operations/patterns within this module that don't have
            a more specific configuration.
            Keys can be either all indices or all string names, but not a mix
            of both. The special key "\*" can be used in both cases to refer to
            all outputs.
            Example keys:

            - int: Output index (e.g., 0 for first output, 1 for second output)
            - str: Named output identifier (e.g., "y", "output_0")
            - "\*": Applies to all outputs for the operation. Other tensors can
              be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to
            quantize each output. None value represents disabling quantization.
            Default: {"\*": default_activation_quantization_spec()}
            (int8 quantization for all outputs)
        op_state_spec (dict[str, QuantizationSpec | None] | None):
            Quantization specifications for operation state tensors
            (parameters, buffers, constants) applied to all registered
            operations/patterns within this module that don't have a more
            specific configuration.
            Keys can be string names (e.g. "weight", "bias") or "\*" to refer
            to all state inputs.
            Values are QuantizationSpec objects or None defining how to
            quantize each state tensor. None value represents disabling
            quantization.
            Default: {"weight": default_weight_quantization_spec()}
            (int8 quantization for weight inputs)
        op_type_config (dict[str, OpQuantizerConfig | None] | None):
            Operation type-specific configurations.
            Keys are operation types (e.g., "linear", "conv2d"). Generally
            speaking, operation types will match the torch functional name
            (https://docs.pytorch.org/docs/stable/nn.functional.html) or
            operation name within the torch namespace
            (https://docs.pytorch.org/docs/stable/torch.html) when taking the
            portion of the name to the right of the last period. For example,
            to refer to a Maxpool 2D operation, take the name used for the
            torch functional, torch.nn.functional.max_pool2d, and use the
            portion of the string after the last period: "max_pool2d".
            Values are OpQuantizerConfig objects or None defining how to
            quantize operations of that type. None value represents disabling
            quantization.
            Default: {} (empty dict, no type-specific configs)
        op_name_config (dict[str, OpQuantizerConfig | None] | None):
            Operation name-specific configurations.
            Keys are operation name patterns (supports regex matching).
            Values are OpQuantizerConfig objects or None defining how to
            quantize operations matching those names. None value represents
            disabling quantization.
            Default: {} (empty dict, no name-specific configs)
        module_input_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for module inputs. Module input
            settings treat the module as an opaque entity, setting quantization
            settings for input tensors to the module without checking whether
            the op receiving the input is quantizable. Module input settings
            override op level settings for the op receiving the module input.
            Keys can be either all indices or all string names, but not a mix
            of both. The special key "\*" can be used in both cases to refer to
            all module inputs.
            Example keys:

            - int: Input index (e.g., 0 for first input, 1 for second input)
            - str: Named input identifier (e.g., "y", "input_0")
            - "\*": Applies to all inputs for the operation. Other tensors can
              be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None. None value represents
            disabling quantization.
            Default: {} (empty dict, no specific module input settings)
        module_output_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for module outputs. Module output
            settings treat the module as an opaque entity, setting quantization
            settings for output tensors to the module without checking whether
            the op receiving the output is quantizable. Module output settings
            override op level settings for the op receiving the module output.
            Keys can be:

            - int: Output index (e.g., 0 for first output)
            - str: Named output identifier
            - "\*": Applies to all outputs for the operation. Other tensors can
              be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None. None value represents
            disabling quantization.
            Default: {} (empty dict, no specific module output settings)
        module_state_spec (dict[str, QuantizationSpec | None] | None):
            Quantization specifications for module state tensors (parameters,
            buffers, and constants). Module state settings will override op
            state settings for the same state tensors.
            Keys can be string names (e.g. "weight", "bias") or "\*" to refer
            to all state inputs.
            Values are QuantizationSpec objects or None. None value represents
            disabling quantization.
            Default: {} (empty dict, no specific module state settings)
        qat_schedule (QATSchedule | None):
            Optional QAT schedule for controlling observer and fake
            quantization state transitions during training. When set, the
            ``quantizer.step()`` API must be used to advance the schedule.
            See :class:`QATSchedule` for details. When None (default), both
            observer and fake quantization are enabled from the start of
            training.

    Example:
        >>> # Configure a module with default op config and specific settings for
        >>> # linear ops
        >>> module_config = ModuleQuantizerConfig(
        ...     # Omitted op_input/output/state_specs sets default quantization for all
        ...     # ops
        ...     op_type_config={
        ...         "linear": OpQuantizerConfig(
        ...             op_input_spec={
        ...                 0: ...
        ...             },
        ...             op_output_spec={
        ...                 0: ...
        ...             },
        ...             op_state_spec={
        ...                 "weight": ...
        ...             }
        ...         )
        ...     },
        ... )
    """

    def __init_subclass__(cls, **kwargs):
        """Reject any attempt to subclass this config.

        Raises:
            TypeError: Always, after delegating to the parent hook.
        """
        # Prohibit subclassing due to preset limitation: presets remain bound
        # to the base class. Revisit if subclass support is needed.
        super().__init_subclass__(**kwargs)
        msg = f"{cls.__name__} cannot subclass ModuleQuantizerConfig (marked final)."
        raise TypeError(msg)

    qat_schedule: QATSchedule | None = None

    # Namespace exposing built-in preset constructors for module-level configs.
    # Wired at the bottom of this module after the class is fully defined.
    presets: ClassVar[_ModuleQuantizerConfigPresets]
@final
class QuantizerConfig(CompressionConfig[ModuleQuantizerConfig]):
    """Top-level configuration class for quantization.

    Holds the full quantization configuration of a model as a hierarchy of
    module-level configurations. It specializes CompressionConfig for
    quantization by using ModuleQuantizerConfig as the per-module config type.

    Lookup precedence, from most to least specific:

    1. module_name_configs - module instances matching a name pattern
       (supports regex)
    2. module_type_configs - all modules of a specific type
       (e.g., torch.nn.modules.linear.Linear)
    3. global_config - default for every module not otherwise configured

    Attributes:
        global_config (ModuleQuantizerConfig | None):
            Default module-level quantization configuration, used for modules
            without a more specific configuration. When QuantizerConfig is
            initialized with no arguments, a default global_config with
            standard int8 quantization is created automatically. Setting
            global_config to None disables quantization by default globally.
            Default: Auto-created with int8 quantization specs when no args
            provided
        module_type_configs (dict[str, ModuleQuantizerConfig | None] | None):
            Module type-specific configurations.
            Keys are fully-qualified module type names
            (e.g., "torch.nn.modules.linear.Linear",
            "torch.nn.modules.conv.Conv2d").
            Values are ModuleQuantizerConfig objects, or None to disable
            quantization for that module type.
            Default: {} (empty dict, no type-specific configs)
        module_name_configs (dict[str, ModuleQuantizerConfig | None] | None):
            Module name-specific configurations.
            Keys are module name patterns (supports regex matching, e.g.,
            "model.layer1.*", "decoder.layers.0").
            Values are ModuleQuantizerConfig objects, or None to disable
            quantization for matching modules.
            Default: {} (empty dict, no name-specific configs)
        preserved_attributes (list[str] | None):
            Names of attributes of the model which should be preserved on the
            prepared and finalized models, even if they are not used in the
            model's forward pass.
        execution_mode (ExecutionMode | str):
            Specifies which quantization execution mode to use. Options are:

            - ExecutionMode.GRAPH / "graph": Graph-based quantization using
              ``torch.export`` and FX graphs, built on ``torchao``'s PT2E
              implementation. Requires the model to be exportable.
            - ExecutionMode.EAGER / "eager": Works directly on ``nn.Module``
              without converting to a graph representation. Supports dynamic
              control flow (if/else, loops) and doesn't require
              ``torch.export``.

            Default: ExecutionMode.GRAPH

    Example:
        >>> # Create default quantizer config (auto-creates int8 global
        >>> # config)
        >>> config = QuantizerConfig()
        >>> # config.global_config is automatically created with default int8 specs
        >>>
        >>> # Disable quantization globally
        >>> config = QuantizerConfig(
        ...     global_config=None
        ... )
        >>>
        >>> # Create custom quantizer config with type-specific settings for Linear
        >>> # modules.
        >>> config = QuantizerConfig(
        ...     # Omitted global_config section defaults to int8/int8 weight/activation
        ...     # quantization for all operations
        ...     module_type_configs={
        ...         "torch.nn.modules.linear.Linear": ModuleQuantizerConfig(
        ...             op_input_spec={
        ...                 0: ...
        ...             },
        ...             op_output_spec={
        ...                 0: ...
        ...             },
        ...             op_state_spec={
        ...                 'weight': ...
        ...             }
        ...         )
        ...     },
        ... )
        >>>
        >>> # Load quantizer config from YAML file
        >>> config = QuantizerConfig.from_yaml("config.yaml")

    Notes:
        - When initialized with no arguments, a default configuration is
          created with int8 symmetric quantization for activations and weights
        - The from_yaml class method provides an alternative way to create
          configurations from YAML files
        - Setting a config to None explicitly disables quantization for that
          scope
        - More specific configurations (name > type > global) always override
          less specific ones
    """

    def __init_subclass__(cls, **kwargs):
        # Subclassing is blocked because presets stay bound to the base class;
        # revisit this if subclass support is ever required.
        super().__init_subclass__(**kwargs)
        raise TypeError(f"{cls.__name__} cannot subclass QuantizerConfig (marked final).")

    # Key names for the config key pattern used in from_dict/from_yaml.
    _CONFIG_KEY: ClassVar[str] = _QUANTIZATION_CONFIG
    _SPEC_KEY: ClassVar[str] = _QUANTIZATION_SPEC

    # Preset namespace (built-in and registered constructors); assigned at the
    # bottom of this module once the class is fully defined.
    presets: ClassVar[_QuantizerConfigPresets]

    preserved_attributes: list[str] | None = None
    execution_mode: ExecutionMode = ExecutionMode.GRAPH

    def set_execution_mode(self, mode: ExecutionMode | str) -> Self:
        """Select the quantization execution mode and return this config.

        Args:
            mode (ExecutionMode | str): Either an ``ExecutionMode`` member
                (e.g. ``ExecutionMode.EAGER``) or the matching string value
                (e.g. ``"graph"``, ``"eager"``).

        Returns:
            Self: This config, for method chaining.

        Raises:
            ValueError: If ``mode`` is a string that is not a valid
                ``ExecutionMode`` value.

        Example:
            >>> config = QuantizerConfig.presets.w4()
            >>> config.set_execution_mode(ExecutionMode.EAGER)
        """
        # Normalize first so an invalid value raises before anything is stored.
        resolved_mode = ExecutionMode(mode)
        self.execution_mode = resolved_mode
        return self
# Preset wiring — after all classes so _presets can import ExecutionMode. from coreai_opt.quantization.config._presets import ( # noqa: E402, PLC0415 _ModuleQuantizerConfigPresets, _QuantizerConfigPresets, ) ModuleQuantizerConfig.presets = _ModuleQuantizerConfigPresets(owner_cls=ModuleQuantizerConfig) QuantizerConfig.presets = _QuantizerConfigPresets(owner_cls=QuantizerConfig)