# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause
"""Quantization config class definitions."""
from __future__ import annotations
from enum import auto
from typing import TYPE_CHECKING, ClassVar, NamedTuple, Self, TypeAlias, final
from pydantic import BaseModel, ConfigDict, Field, model_validator
from coreai_opt._utils.config_utils import ALL_TENSORS as _ALL_TENSORS
from coreai_opt.common import (
_DeprecatedMemberEnumMeta,
_StrEnum,
)
from coreai_opt.config import (
CompressionConfig,
ModuleCompressionConfig,
OpCompressionConfig,
)
from coreai_opt.quantization.spec import (
QuantizationSpec,
default_activation_quantization_spec,
default_weight_quantization_spec,
)
if TYPE_CHECKING:
from coreai_opt.quantization.config._presets import (
_ModuleQuantizerConfigPresets,
_QuantizerConfigPresets,
)
[docs]
class QATSchedule(BaseModel):
    """Step-driven schedule for observer and fake-quantization state in QAT.

    The schedule stores the step thresholds at which observers and fake
    quantization are switched on or off during quantization-aware training.
    It must be used together with the ``quantizer.step()`` API, which is what
    advances the schedule.

    Thresholds are counted in ``quantizer.step()`` calls. If ``step()`` is
    called once per batch the thresholds are batch steps; if it is called once
    per epoch they are epochs.

    Every ``step()`` call increments the step counter and immediately applies
    the matching observer/fake-quantization state, so the position of
    ``step()`` in the training loop decides when the model sees the new state.

    Attributes:
        enable_observer: Step count at which observers are enabled. Must be >= 0.
        enable_fake_quant: Step count at which fake quantization is enabled.
            Must be >= enable_observer.
        disable_observer: Step count at which observers are disabled. Must be
            > enable_observer and >= enable_fake_quant if provided. None means
            observers are never disabled by the schedule.

    Example:
        >>> schedule = QATSchedule(
        ...     enable_observer=0,
        ...     enable_fake_quant=500,
        ...     disable_observer=1500,
        ... )

    Note:
        In graph execution mode, when consecutive modules both quantize the
        intermediate edge (one via ``op_output_spec``, the next via
        ``op_input_spec``), graph mode deduplicates them into a single
        fake-quantize node. The schedule of the consuming module is always
        applied to the deduplicated node, irrespective of the choice of
        deduplication made by the graph preparation.

    Note:
        When two modules share a weight parameter and have different
        schedules, the schedule of the first module encountered in the
        module tree is applied. A warning is emitted for the conflict if
        there is no fake-quantize node deduplication happening (in Eager
        execution mode).
    """

    # Instances are immutable once validated.
    model_config = ConfigDict(frozen=True)

    class _ScheduleState(NamedTuple):
        # Whether observers / fake quantization are active at a given step.
        obs_on: bool
        fq_on: bool

    enable_observer: int = Field(default=0, ge=0)
    enable_fake_quant: int = Field(default=0, ge=0)
    disable_observer: int | None = Field(default=None, gt=0)

    @model_validator(mode="after")
    def _validate_schedule(self) -> QATSchedule:
        """Check the ordering constraints between the three thresholds.

        Raises:
            ValueError: If ``enable_fake_quant`` precedes ``enable_observer``,
                or if ``disable_observer`` is set and is not strictly after
                ``enable_observer`` or precedes ``enable_fake_quant``.
        """
        if self.enable_fake_quant < self.enable_observer:
            raise ValueError(
                f"enable_fake_quant ({self.enable_fake_quant}) must be >= "
                f"enable_observer ({self.enable_observer})"
            )

        # Without a disable threshold there is nothing further to check.
        if self.disable_observer is None:
            return self

        if self.disable_observer <= self.enable_observer:
            raise ValueError(
                f"disable_observer ({self.disable_observer}) must be > "
                f"enable_observer ({self.enable_observer})"
            )
        if self.disable_observer < self.enable_fake_quant:
            raise ValueError(
                f"disable_observer ({self.disable_observer}) must be >= "
                f"enable_fake_quant ({self.enable_fake_quant})"
            )
        return self

    def _compute_state(self, step_count: int) -> _ScheduleState:
        """Return the observer/fake_quant state at the given step.

        Args:
            step_count: Step counter value to evaluate the schedule at.

        Returns:
            A ``_ScheduleState`` whose ``obs_on`` is true while ``step_count``
            lies in ``[enable_observer, disable_observer)`` and whose ``fq_on``
            is true once ``step_count`` has reached ``enable_fake_quant``.
        """
        # A falsy disable threshold means the observer window never closes.
        window_end = self.disable_observer
        if not window_end:
            window_end = float("inf")

        observing = step_count >= self.enable_observer and step_count < window_end
        fake_quantizing = self.enable_fake_quant <= step_count
        return self._ScheduleState(obs_on=observing, fq_on=fake_quantizing)
# String keys for the quantization config / spec sections; QuantizerConfig
# re-exports them as its _CONFIG_KEY and _SPEC_KEY class attributes.
_QUANTIZATION_CONFIG = "quantization_config"
_QUANTIZATION_SPEC = "quantization_spec"
# Spec mapping keyed by tensor name or index (the shape documented for the
# op/module input and output specs); a None value stands for "no spec".
_ACTIVATION_SPEC_DICT: TypeAlias = dict[str | int, QuantizationSpec | None]
# Spec mapping keyed by name only (the shape documented for the state specs).
_STATE_SPEC_DICT: TypeAlias = dict[str, QuantizationSpec | None]
[docs]
class ExecutionMode(_StrEnum, metaclass=_DeprecatedMemberEnumMeta):
    """Execution modes available for quantization.

    Every member is a string value naming the execution mode that the
    quantizer uses.

    Attributes:
        GRAPH: Graph-based quantization. The model is captured as an FX graph
            with ``torch.export`` and quantization is applied on top of it.
            Built on ``torchao``'s PT2E implementation. Requires the model to
            be exportable via ``torch.export.export``. Recommended default.
        EAGER: Eager-mode quantization operating directly on ``nn.Module``
            without graph capture. Supports dynamic control flow (if/else,
            loops) and is the fallback when a model is not exportable.
    """

    GRAPH = auto()
    EAGER = auto()

    # Old member name -> current member name, declared for the metaclass.
    __deprecated_aliases__: ClassVar[dict[str, str]] = {"PT2E": "GRAPH"}

    if TYPE_CHECKING:
        # Declared only for static type checkers so that the deprecated alias
        # listed above is visible to them; never executed at runtime.
        PT2E: ExecutionMode
        """Deprecated. Use ``ExecutionMode.GRAPH`` instead."""
[docs]
class OpQuantizerConfig(OpCompressionConfig[QuantizationSpec]):
    # Raw docstring: the text contains ``\*`` (an escaped asterisk for the
    # rendered docs), which is an invalid escape sequence in a normal string.
    r"""
    Configuration class for quantization at the operation level.

    This class specifies quantization settings for inputs, outputs, and state
    tensors of individual operations (ops) in a neural network. Each tensor
    can have its own quantization specification.

    Quantization for operations will require the operation to be registered. Even if
    globally all ops are configured to be quantized, ops which are not recognized will
    not be quantized.

    Attributes:
        op_input_spec (dict[str | int, QuantizationSpec | None] | None): Quantization
            specifications for operation inputs. Keys can be either all indices or all
            string names, but not a mix of both. The special key "\*" can be used in both
            cases to refer to all inputs.

            Example keys:
                - int: Input index (e.g., 0 for first input, 1 for second
                  input)
                - str: Named input identifier (e.g., "x", "input_0")
                - "\*": Applies to all inputs for the operation. Other tensors
                  can be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to quantize each
            input. None value represents disabling quantization.

            Default: {"\*": default_activation_quantization_spec()} (int8
            quantization for all inputs)
        op_output_spec (dict[str | int, QuantizationSpec | None] | None): Quantization
            specifications for operation outputs. Keys can be either all indices or all
            string names, but not a mix of both. The special key "\*" can be used in both
            cases to refer to all outputs.

            Example keys:
                - int: Output index (e.g., 0 for first output, 1 for second
                  output)
                - str: Named output identifier (e.g., "y", "output_0")
                - "\*": Applies to all outputs for the operation. Other tensors
                  can be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to quantize each
            output. None value represents disabling quantization.

            Default: {"\*": default_activation_quantization_spec()} (int8
            quantization for all outputs)
        op_state_spec (dict[str, QuantizationSpec | None] | None): Quantization
            specifications for operation state tensors (parameters, buffers, constants).
            Keys can be string names (e.g. "weight", "bias") or "\*" to refer to all
            state inputs.

            Values are QuantizationSpec objects or None defining how to quantize each
            state tensor. None value represents disabling quantization.

            Default: {"weight": default_weight_quantization_spec()} (int8
            quantization for weight inputs)

    Example:
        >>> # Quantize first input, disable first output, quantize weight tensor
        >>> op_config = OpQuantizerConfig(
        ...     op_input_spec={
        ...         0: QuantizationSpec(
        ...             dtype=torch.int8,
        ...             qscheme="symmetric",
        ...             granularity={"type": "per_tensor"},
        ...             fake_quantize_cls="default",
        ...             qparam_calculator_cls="moving_average",
        ...             range_calculator_cls="minmax",
        ...         )
        ...     },
        ...     op_output_spec={
        ...         0: None
        ...     },
        ...     op_state_spec={
        ...         "weight": QuantizationSpec(
        ...             dtype=torch.int4,
        ...             qscheme="symmetric",
        ...             granularity={"type": "per_channel", "axis": 1},
        ...             fake_quantize_cls="default",
        ...             qparam_calculator_cls="default",
        ...             range_calculator_cls="minmax",
        ...         )
        ...     }
        ... )
    """

    @classmethod
    def get_default_output_spec(cls) -> dict[str | int, QuantizationSpec | None]:
        """Provide default output spec for quantization.

        Returns:
            A new dict mapping the all-tensors key to the default activation
            quantization spec.
        """
        return {_ALL_TENSORS: default_activation_quantization_spec()}

    @classmethod
    def get_default_state_spec(cls) -> dict[str, QuantizationSpec | None]:
        """Provide default state spec for quantization.

        Returns:
            A new dict mapping ``"weight"`` to the default weight quantization
            spec.
        """
        return {"weight": default_weight_quantization_spec()}
[docs]
@final
class ModuleQuantizerConfig(ModuleCompressionConfig[OpQuantizerConfig, QuantizationSpec]):
    # Raw docstring: the text contains ``\*`` (an escaped asterisk for the
    # rendered docs), which is an invalid escape sequence in a normal string.
    r"""
    Configuration class for quantization at the module level.

    This class manages quantization settings for an entire module, including:
        - Operation-level configurations (default, by type, by name)
        - Module-level input/output quantization
        - Module-level state (parameter) quantization

    The operation configurations follow a hierarchical precedence:
        1. op_name_config (most specific - applies to operations matching a name
           pattern)
        2. op_type_config (applies to operations of a specific type)
        3. op_input/output/state_spec (least specific - applies to all operations
           not otherwise configured)

    Module-level input, output, and state settings treat the module as an opaque entity,
    setting quantization settings for specified tensors and ignoring op specific
    quantization capabilities. Module-level settings also don't check whether the
    operation receiving the quantized tensor is a registered operation or not.
    Module-level settings will override any op specific settings.

    Attributes:
        op_input_spec (dict[str | int, QuantizationSpec | None] | None): Quantization
            specifications for operation inputs applied to all registered
            operations/patterns within this module that don't have a more specific
            configuration.

            Keys can be either all indices or all string names, but not a mix of both.
            The special key "\*" can be used in both cases to refer to all inputs.

            Example keys:
                - int: Input index (e.g., 0 for first input, 1 for second input)
                - str: Named input identifier (e.g., "x", "input_0")
                - "\*": Applies to all inputs for the operation. Other tensors
                  can be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to quantize each
            input. None value represents disabling quantization.

            Default: {"\*": default_activation_quantization_spec()} (int8
            quantization for all inputs)
        op_output_spec (dict[str | int, QuantizationSpec | None] | None): Quantization
            specifications for operation outputs applied to all registered
            operations/patterns within this module that don't have a more specific
            configuration.

            Keys can be either all indices or all string names, but not a mix of both.
            The special key "\*" can be used in both cases to refer to all outputs.

            Example keys:
                - int: Output index (e.g., 0 for first output, 1 for second
                  output)
                - str: Named output identifier (e.g., "y", "output_0")
                - "\*": Applies to all outputs for the operation. Other tensors
                  can be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None defining how to quantize each
            output. None value represents disabling quantization.

            Default: {"\*": default_activation_quantization_spec()} (int8
            quantization for all outputs)
        op_state_spec (dict[str, QuantizationSpec | None] | None): Quantization
            specifications for operation state tensors (parameters, buffers, constants)
            applied to all registered operations/patterns within this module that don't
            have a more specific configuration.

            Keys can be string names (e.g. "weight", "bias") or "\*" to refer to all
            state inputs.

            Values are QuantizationSpec objects or None defining how to quantize each
            state tensor. None value represents disabling quantization.

            Default: {"weight": default_weight_quantization_spec()} (int8
            quantization for weight inputs)
        op_type_config (dict[str, OpQuantizerConfig | None] | None): Operation
            type-specific configurations. Keys are operation types (e.g.,
            "linear", "conv2d"). Generally speaking, operation types will match the torch
            functional name (https://docs.pytorch.org/docs/stable/nn.functional.html) or
            operation name within the torch namespace
            (https://docs.pytorch.org/docs/stable/torch.html)
            when taking the portion of the name to the right of the last period.

            For example, to refer to a Maxpool 2D operation, take the name used for the
            torch functional, torch.nn.functional.max_pool2d, and use the portion of the
            string after the last period: "max_pool2d".

            OpQuantizerConfig objects or None defining how to quantize operations of
            that type. None value represents disabling quantization.

            Default: {} (empty dict, no type-specific configs)
        op_name_config (dict[str, OpQuantizerConfig | None] | None): Operation
            name-specific configurations. Keys are operation name patterns
            (supports regex matching). Values are OpQuantizerConfig objects or None
            defining how to quantize operations matching those names. None value
            represents disabling quantization.

            Default: {} (empty dict, no name-specific configs)
        module_input_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for module inputs. Module input settings treat
            the module as an opaque entity, setting quantization settings for input
            tensors to the module without checking whether the op receiving
            the input is quantizable. Module input settings override op level
            settings for the op receiving the module input.

            Keys can be either all indices or all string names, but not a mix of both.
            The special key "\*" can be used in both cases to refer to all module inputs.

            Example keys:
                - int: Input index (e.g., 0 for first input, 1 for second
                  input)
                - str: Named input identifier (e.g., "y", "input_0")
                - "\*": Applies to all inputs for the operation. Other tensors
                  can be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None. None value represents disabling
            quantization.

            Default: {} (empty dict, no specific module input settings)
        module_output_spec (dict[str | int, QuantizationSpec | None] | None):
            Quantization specifications for module outputs. Module output settings treat
            the module as an opaque entity, setting quantization settings for
            output tensors to the module without checking whether the op
            receiving the output is quantizable. Module output settings
            override op level settings for the op receiving the module output.

            Keys can be:
                - int: Output index (e.g., 0 for first output)
                - str: Named output identifier
                - "\*": Applies to all outputs for the operation. Other tensors
                  can be explicitly mentioned to override this setting.

            Values are QuantizationSpec objects or None. None value represents disabling
            quantization.

            Default: {} (empty dict, no specific module output settings)
        module_state_spec (dict[str, QuantizationSpec | None] | None): Quantization
            specifications for module state tensors (parameters, buffers, and
            constants). Module state settings will override op state settings for the
            same state tensors.

            Keys can be string names (e.g. "weight", "bias") or "\*" to refer to all
            state inputs.

            Values are QuantizationSpec objects or None. None value represents disabling
            quantization.

            Default: {} (empty dict, no specific module state settings)
        qat_schedule (QATSchedule | None): Optional QAT schedule for controlling
            observer and fake quantization state transitions during training.
            When set, the ``quantizer.step()`` API must be used to advance the
            schedule. See :class:`QATSchedule` for details. When None (default),
            both observer and fake quantization are enabled from the start of
            training.

    Example:
        >>> # Configure a module with default op config and specific settings for
        >>> # linear ops
        >>> module_config = ModuleQuantizerConfig(
        ...     # Omitted op_input/output/state_specs sets default quantization for all
        ...     # ops
        ...     op_type_config={
        ...         "linear": OpQuantizerConfig(
        ...             op_input_spec={
        ...                 0: ...
        ...             },
        ...             op_output_spec={
        ...                 0: ...
        ...             },
        ...             op_state_spec={
        ...                 "weight": ...
        ...             }
        ...         )
        ...     },
        ... )
    """

    def __init_subclass__(cls, **kwargs):
        """Reject any attempt to subclass this config.

        Raises:
            TypeError: Always, naming the offending subclass.
        """
        # Prohibit subclassing due to preset limitation: presets remain bound
        # to the base class. Revisit if subclass support is needed.
        super().__init_subclass__(**kwargs)
        msg = f"{cls.__name__} cannot subclass ModuleQuantizerConfig (marked final)."
        raise TypeError(msg)

    qat_schedule: QATSchedule | None = None

    # Namespace exposing built-in preset constructors for module-level configs.
    # Wired at the bottom of this module after the class is fully defined.
    presets: ClassVar[_ModuleQuantizerConfigPresets]
[docs]
@final
class QuantizerConfig(CompressionConfig[ModuleQuantizerConfig]):
    """Top-level configuration class for quantization.

    This class manages the complete quantization configuration for a neural
    network model, organizing module-level configurations in a hierarchical
    structure. It inherits from CompressionConfig and specializes it for
    quantization using ModuleQuantizerConfig.

    The configuration lookup follows a hierarchical precedence (most to least
    specific):
        1. module_name_configs - Applies to module instances matching a name
           pattern (supports regex)
        2. module_type_configs - Applies to all modules of a specific type (e.g.,
           torch.nn.modules.linear.Linear)
        3. global_config - Default configuration applied to all modules not
           otherwise configured

    Attributes:
        global_config (ModuleQuantizerConfig | None): Default module-level
            quantization configuration applied to all modules that don't have
            a more specific configuration. When QuantizerConfig is initialized
            with no arguments, a default global_config is automatically created with
            standard int8 quantization.

            Setting global_config to None disables quantization by default globally.

            Default: Auto-created with int8 quantization specs when no args
            provided
        module_type_configs (dict[str, ModuleQuantizerConfig | None] | None): Module
            type-specific configurations. Keys are fully-qualified module type
            names (e.g., "torch.nn.modules.linear.Linear",
            "torch.nn.modules.conv.Conv2d"). Values are ModuleQuantizerConfig objects or
            None to disable quantization for that module type.

            Default: {} (empty dict, no type-specific configs)
        module_name_configs (dict[str, ModuleQuantizerConfig | None] | None): Module
            name-specific configurations. Keys are module name patterns
            (supports regex matching, e.g., "model.layer1.*",
            "decoder.layers.0"). Values are ModuleQuantizerConfig objects or
            None to disable quantization for matching modules.

            Default: {} (empty dict, no name-specific configs)
        preserved_attributes (list[str] | None): Names of attributes of the model
            which should be preserved on the prepared and finalized models, even if they
            are not used in the model's forward pass.
        execution_mode (ExecutionMode | str): Specifies which quantization execution
            mode to use. Options are:

            - ExecutionMode.GRAPH / "graph":
              Graph-based quantization using ``torch.export`` and FX graphs, built on
              ``torchao``'s PT2E implementation. Requires the model to be exportable.
            - ExecutionMode.EAGER / "eager":
              Works directly on ``nn.Module`` without converting to a graph representation.
              Supports dynamic control flow (if/else, loops) and doesn't require ``torch.export``.

            Default: ExecutionMode.GRAPH

    Example:
        >>> # Create default quantizer config (auto-creates int8 global
        >>> # config)
        >>> config = QuantizerConfig()
        >>> # config.global_config is automatically created with default int8 specs
        >>>
        >>> # Disable quantization globally
        >>> config = QuantizerConfig(
        ...     global_config=None
        ... )
        >>>
        >>> # Create custom quantizer config with type-specific settings for Linear
        >>> # modules.
        >>> config = QuantizerConfig(
        ...     # Omitted global_config section defaults to int8/int8 weight/activation
        ...     # quantization for all operations
        ...     module_type_configs={
        ...         "torch.nn.modules.linear.Linear": ModuleQuantizerConfig(
        ...             op_input_spec={
        ...                 0: ...
        ...             },
        ...             op_output_spec={
        ...                 0: ...
        ...             },
        ...             op_state_spec={
        ...                 'weight': ...
        ...             }
        ...         )
        ...     },
        ... )
        >>>
        >>> # Load quantizer config from YAML file
        >>> config = QuantizerConfig.from_yaml("config.yaml")

    Notes:
        - When initialized with no arguments, a default configuration is
          created with int8 symmetric quantization for activations and
          weights
        - The from_yaml class method provides an alternative way to create
          configurations from YAML files
        - Setting a config to None explicitly disables quantization for that
          scope
        - More specific configurations (name > type > global) always override
          less specific ones
    """

    def __init_subclass__(cls, **kwargs):
        """Reject any attempt to subclass this config.

        Raises:
            TypeError: Always, naming the offending subclass.
        """
        # Prohibit subclassing due to preset limitation: presets remain bound
        # to the base class. Revisit if subclass support is needed.
        super().__init_subclass__(**kwargs)
        msg = f"{cls.__name__} cannot subclass QuantizerConfig (marked final)."
        raise TypeError(msg)

    # Class attributes for config key pattern used in from_dict/from_yaml
    _CONFIG_KEY: ClassVar[str] = _QUANTIZATION_CONFIG
    _SPEC_KEY: ClassVar[str] = _QUANTIZATION_SPEC

    # Namespace exposing built-in and registered preset constructors.
    # Wired at the bottom of this module after the class is fully defined.
    presets: ClassVar[_QuantizerConfigPresets]

    preserved_attributes: list[str] | None = None
    execution_mode: ExecutionMode = ExecutionMode.GRAPH

    def set_execution_mode(self, mode: ExecutionMode | str) -> Self:
        """Set the quantization execution mode.

        Args:
            mode (ExecutionMode | str): Execution mode to use.
                Accepts an ``ExecutionMode`` member (e.g. ``ExecutionMode.EAGER``)
                or its string value (e.g. ``"graph"``, ``"eager"``).

        Returns:
            Self: This config, for method chaining.

        Raises:
            ValueError: If ``mode`` is a string that is not a valid
                ``ExecutionMode`` value.

        Example:
            >>> config = QuantizerConfig.presets.w4()
            >>> config.set_execution_mode(ExecutionMode.EAGER)
        """
        # Going through the enum constructor normalises string input to a
        # member before it is stored on the config.
        self.execution_mode = ExecutionMode(mode)
        return self
# Preset wiring — after all classes so _presets can import ExecutionMode.
# The import sits at the bottom of the module on purpose (hence the E402 /
# PLC0415 suppressions); presumably _presets imports names from this module,
# so importing it any earlier would be circular — confirm before moving it.
from coreai_opt.quantization.config._presets import ( # noqa: E402, PLC0415
_ModuleQuantizerConfigPresets,
_QuantizerConfigPresets,
)
# Attach one presets namespace to each config class, filling in the ``presets``
# ClassVar that each class declares without a value.
ModuleQuantizerConfig.presets = _ModuleQuantizerConfigPresets(owner_cls=ModuleQuantizerConfig)
QuantizerConfig.presets = _QuantizerConfigPresets(owner_cls=QuantizerConfig)