# Source code for coreai_opt.quantization.quantizer

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

from __future__ import annotations

import warnings
from collections.abc import Callable
from contextlib import contextmanager
from os import PathLike
from typing import Any

import torch
import torch.nn as nn
from torch import fx
from torchao.quantization.pt2e import (
    disable_fake_quant as _torchao_disable_fake_quant,
    disable_observer as _torchao_disable_observer,
    enable_fake_quant as _torchao_enable_fake_quant,
    enable_observer as _torchao_enable_observer,
)

from coreai_opt._utils.config_utils import ConfigLevel as _ConfigLevel
from coreai_opt._utils.export_utils import (
    validate_mmap_backend_and_device as _validate_mmap_backend_and_device,
)
from coreai_opt._utils.torch_utils import get_module_name
from coreai_opt.common import ExportBackend
from coreai_opt.quantization._eager import EagerQuantizer as _EagerQuantizer
from coreai_opt.quantization._graph import GraphQuantizer as _GraphQuantizer
from coreai_opt.quantization.base_quantizer import _BaseQuantizer
from coreai_opt.quantization.config.quantization_config import (
    ExecutionMode,
    QATSchedule,
    QuantizerConfig,
)
from coreai_opt.quantization.spec.fake_quantize import FakeQuantizeImplBase


class Quantizer(_BaseQuantizer):
    """
    Unified quantizer API that provides a single entry point for various
    quantization workflows, including:

    - **Data Types**: Integer (e.g. int8, int4) and floating-point
      (e.g. float8_e4m3fn, float8_e5m2) quantization
    - **Quantization Workflows**: Post-training quantization (PTQ) and
      quantization-aware training (QAT)
    - **Execution Modes**: Graph mode (built on torchao's PT2E) or eager mode

    The quantizer automatically selects the appropriate underlying implementation
    based on the ``execution_mode`` specified in the configuration. Defaults to
    graph mode.

    Some of the key differences between the execution modes are summarized below:

    +-----------------------+---------------------------------+----------------------------+
    | Feature               | Graph Mode (Default)            | Eager Mode                 |
    +=======================+=================================+============================+
    | Input/Output Types    | nn.Module                       | nn.Module -> nn.Module     |
    |                       | -> fx.GraphModule.              |                            |
    +-----------------------+---------------------------------+----------------------------+
    | Module Fusion         | Automatic pattern-based fusion  | Manual fusion required     |
    |                       | (e.g., conv+bn+relu)            |                            |
    +-----------------------+---------------------------------+----------------------------+
    | Control Flow          | Static graph only;              | Supports dynamic           |
    |                       | Requires torch.export           | control flow               |
    |                       | compatible model                | (if/else, loops)           |
    +-----------------------+---------------------------------+----------------------------+
    | Shared Observer Ops   | Handled correctly; ops like     | Not supported; Ops like    |
    |                       | MaxPool that share the same     | MaxPool have independent   |
    |                       | observer across inputs and      | observers for input vs     |
    |                       | outputs are detected and        | output, which can cause    |
    |                       | deduplicated on the graph.      | incorrect quantization.    |
    +-----------------------+---------------------------------+----------------------------+
    | FQ Node Deduplication | Back-to-back fake-quantize      | No deduplication; if both  |
    |                       | nodes on the same tensor are    | the output of one op and   |
    |                       | collapsed into a single node,   | the input of the next are  |
    |                       | avoiding redundant quantization | quantized, two consecutive |
    |                       | on intermediate edges.          | FQ nodes are inserted on   |
    |                       |                                 | that intermediate edge.    |
    +-----------------------+---------------------------------+----------------------------+

    As a result of the above-mentioned differences, the total number of
    fake-quantize nodes inserted by graph and eager mode can differ for the same
    ``QuantizerConfig``. This means the two modes are **not guaranteed to produce
    equivalent quantized models**, and final model performance (accuracy and
    latency) may differ between modes even when using identical configurations.

    Args:
        model: The PyTorch model to quantize.
        config: Quantization configuration. If None, a default configuration with
            int8 weight and activation quantization is created.

    Example:
        >>> from coreai_opt.quantization import Quantizer, QuantizerConfig, ExecutionMode
        >>>
        >>> # PTQ with calibration (default int8, graph mode)
        >>> config = QuantizerConfig()
        >>> quantizer = Quantizer(model, config)
        >>> prepared_model = quantizer.prepare((example_input,))
        >>> with quantizer.calibration_mode():
        ...     for data in calibration_loader:
        ...         prepared_model(data)
        >>> quantized_model = quantizer.finalize()
        >>>
        >>> # QAT workflow (default schedule — observers and fake_quant enabled throughout)
        >>> prepared_model = quantizer.prepare((example_input,))
        >>> with quantizer.training_mode():
        ...     for epoch in range(num_epochs):
        ...         for data, target in train_loader:
        ...             optimizer.zero_grad()
        ...             output = prepared_model(data)
        ...             loss = criterion(output, target)
        ...             loss.backward()
        ...             optimizer.step()
        >>> quantized_model = quantizer.finalize()
        >>>
        >>> # QAT workflow with schedule
        >>> from coreai_opt.quantization import ModuleQuantizerConfig
        >>> from coreai_opt.quantization.config import QATSchedule
        >>> # Enable observers from the start, enable fake quant at the 100th step,
        >>> # and disable observers at the 500th step.
        >>> schedule = QATSchedule(
        ...     enable_observer=0, enable_fake_quant=100, disable_observer=500
        ... )
        >>> config = QuantizerConfig(
        ...     global_config=ModuleQuantizerConfig(qat_schedule=schedule)
        ... )
        >>> quantizer = Quantizer(model, config)
        >>> prepared_model = quantizer.prepare((example_input,))
        >>> with quantizer.training_mode():
        ...     for data, target in train_loader:
        ...         optimizer.zero_grad()
        ...         loss = criterion(prepared_model(data), target)
        ...         loss.backward()
        ...         optimizer.step()
        ...         quantizer.step()
        >>> quantized_model = quantizer.finalize()
    """

    def __init__(
        self,
        model: nn.Module,
        config: QuantizerConfig | None = None,
    ):
        if config is None:
            config = QuantizerConfig()

        execution_mode = config.execution_mode
        self._execution_mode = execution_mode

        # Create the underlying quantizer based on execution mode. This happens
        # before super().__init__() because ``_model`` (below) is a property that
        # delegates to ``self._quantizer``; presumably _BaseQuantizer.__init__
        # touches ``self._model`` — TODO confirm against base_quantizer.
        if execution_mode == ExecutionMode.GRAPH:
            self._quantizer = _GraphQuantizer(model, config)
        elif execution_mode == ExecutionMode.EAGER:
            self._quantizer = _EagerQuantizer(model, config)
        else:
            raise ValueError(f"Unsupported execution mode: {execution_mode}")

        super().__init__(model, config)

        # QAT schedule state
        self._step_count: int = 0
        self._in_training_mode: bool = False
        # Mapping of FakeQuantize module to its corresponding QATSchedule
        self._fq_to_schedule: dict[FakeQuantizeImplBase, QATSchedule] = {}
        # Cached module-name -> FQ-modules map, populated after prepare()
        self._module_to_fqs: dict[str, list[FakeQuantizeImplBase]] = {}

    @property
    def _model(self):
        """Delegate to the underlying quantizer's model."""
        return self._quantizer._model

    @_model.setter
    def _model(self, value):
        """Delegate model setting to the underlying quantizer."""
        self._quantizer._model = value

    def _get_fake_quantize_modules(self) -> dict[str, list]:
        """Delegate to the underlying execution-mode quantizer."""
        return self._quantizer._get_fake_quantize_modules()

    @classmethod
    def get_compressible_op_names(
        cls,
        model: nn.Module | torch.fx.GraphModule,
        execution_mode: ExecutionMode,
    ) -> set[str]:
        """Return op names in *model* that this quantizer can target.

        Dispatches to the appropriate underlying quantizer based on
        *execution_mode*.

        Args:
            model (nn.Module | torch.fx.GraphModule): The model to get
                compressible op names for.
            execution_mode (ExecutionMode): The execution mode.

        Returns:
            set[str]: Op names that can be compressed via quantization.

        Raises:
            ValueError: If *execution_mode* is neither GRAPH nor EAGER.
        """
        if execution_mode == ExecutionMode.GRAPH:
            return _GraphQuantizer.get_compressible_op_names(model)
        if execution_mode == ExecutionMode.EAGER:
            return _EagerQuantizer.get_compressible_op_names(model)
        msg = f"Unknown execution_mode {execution_mode}. Expected 'graph' or 'eager'."
        raise ValueError(msg)

    def _resolve_schedule(self, module_name: str) -> QATSchedule | None:
        """Look up the QAT schedule for a module via the config hierarchy.

        Walks config levels in ``_ConfigLevel.priority_order()``; the first level
        holding a config for *module_name* decides the result, even if that
        config's ``qat_schedule`` is None.
        """
        for level in _ConfigLevel.priority_order():
            config = self._module_config_dict[level].get(module_name)
            if config is not None:
                return config.qat_schedule
        return None

    def _build_fq_to_schedule(self) -> None:
        """Build ``_fq_to_schedule`` from cached config dict + FQ modules.

        Must be called after ``prepare()`` so that FQ modules exist. Requires
        ``_module_config_dict`` to have been populated before ``prepare()``
        (since prepare may modify the module types in Eager mode). Does nothing
        if ``_fq_to_schedule`` is already populated; ``prepare()`` clears it
        first so a re-prepared model gets a fresh mapping.
        """
        if self._fq_to_schedule:
            return
        for module_name, fq_list in self._module_to_fqs.items():
            schedule = self._resolve_schedule(module_name)
            if schedule is not None:
                for fq_mod in fq_list:
                    if fq_mod in self._fq_to_schedule:
                        warnings.warn(
                            f"FakeQuantize module under '{module_name}' is shared "
                            f"with another module that already has a qat_schedule "
                            f"assigned. The existing schedule will be kept.",
                            UserWarning,
                            stacklevel=2,
                        )
                    else:
                        self._fq_to_schedule[fq_mod] = schedule

    def _maybe_apply_qat_schedule(self) -> None:
        """Apply observer/fake-quant state for the current step count.

        No-op when no FQ module has a schedule (the mapping is empty).
        """
        for fq_module, schedule in self._fq_to_schedule.items():
            state = schedule._compute_state(self._step_count)
            fq_module.enable_observer(state.obs_on)
            fq_module.enable_fake_quant(state.fq_on)

    def _validate_no_schedule_configured(self) -> None:
        """Raise RuntimeError if any FQ modules have a qat_schedule."""
        if self._fq_to_schedule:
            raise RuntimeError(
                "Enable/disable APIs for observers or fake quantization cannot be "
                "used with a qat_schedule configured. To use these APIs, make sure "
                "there are no global or module-level qat_schedule configured. For "
                "using the QAT schedule, refer to the step() API."
            )

    def step(self) -> None:
        """
        Advance the QAT schedule by one step and apply observer/fake_quant
        transitions after the step has been incremented.

        Must be called inside a training_mode() context. Increments _step_count
        (monotonically; never reset between training loops), then applies the
        absolute observer/fake_quant state corresponding to the new step count.

        Raises:
            RuntimeError: If called outside a training_mode() context.

        Warns:
            UserWarning: If no qat_schedule is configured on any module.
        """
        if not self._in_training_mode:
            raise RuntimeError("step() must be called inside a training_mode() context.")
        self._step_count += 1
        if not self._fq_to_schedule:
            warnings.warn(
                "step() was called but no qat_schedule is configured on any module. "
                "step() has no effect. Configure a QATSchedule on at least one "
                "ModuleQuantizerConfig to use QAT scheduling.",
                UserWarning,
                stacklevel=2,
            )
            return
        self._maybe_apply_qat_schedule()

    def _maybe_apply_fn_to_fqs(
        self,
        fn: Callable[[nn.Module], None],
        module: nn.Module | None = None,
    ) -> None:
        """Apply fn to FQ modules if no QAT schedule is configured.

        Validates that no schedule is set (raises RuntimeError otherwise).
        If module is None, applies to the entire model. If module is given,
        finds its FQs via the cached ``_module_to_fqs`` by resolving the
        module's name and walking its children.

        Args:
            fn: A torchao function (e.g. ``enable_observer``) to apply.
            module: If None, applies to the entire model. Otherwise, applies only
                to FQs associated with the given module and its children.

        Raises:
            RuntimeError: If any ModuleQuantizerConfig has qat_schedule configured.
            ValueError: If the given module is not found in the model.
        """
        self._validate_no_schedule_configured()
        if module is None:
            self._quantizer._model.apply(fn)
            return
        prefix = get_module_name(self._quantizer._model, module)
        if prefix is None:
            raise ValueError(f"Module {module} is not a submodule of the prepared model.")
        for child_name, _ in module.named_modules():
            # Join only the non-empty parts: ``named_modules()`` yields the module
            # itself with name "", and if ``module`` is the root model its prefix
            # is presumably "" as well. A plain f"{prefix}.{child_name}" would
            # then produce ".child", which never matches a cached key.
            full_name = ".".join(part for part in (prefix, child_name) if part)
            for fq in self._module_to_fqs.get(full_name, []):
                fq.apply(fn)

    def enable_observer(self, module: nn.Module | None = None) -> None:
        """Enable observers on the model or a specific module."""
        self._maybe_apply_fn_to_fqs(_torchao_enable_observer, module)

    def disable_observer(self, module: nn.Module | None = None) -> None:
        """Disable observers on the model or a specific module."""
        self._maybe_apply_fn_to_fqs(_torchao_disable_observer, module)

    def enable_fake_quant(self, module: nn.Module | None = None) -> None:
        """Enable fake quantization on the model or a specific module."""
        self._maybe_apply_fn_to_fqs(_torchao_enable_fake_quant, module)

    def disable_fake_quant(self, module: nn.Module | None = None) -> None:
        """Disable fake quantization on the model or a specific module."""
        self._maybe_apply_fn_to_fqs(_torchao_disable_fake_quant, module)

    def prepare(
        self,
        example_inputs: tuple[Any, ...],
        dynamic_shapes: dict[str, Any] | tuple[Any] | list[Any] | None = None,
        export_with_no_grad: bool = True,
    ) -> nn.Module | fx.GraphModule:
        """
        Prepare the model for quantization by inserting fake quantization modules.

        **Graph Mode:** Exports the model using torch.export, applies quantization
        annotations, and sets up fake quantization modules. Returns an
        fx.GraphModule.

        **Eager Mode:** Uses ``__torch_function__`` to trace model execution and
        insert fake quantizers during the forward pass. Returns an nn.Module.

        **Important Notes:**

        - For weight-only PTQ: The prepared model can be directly finalized
          (prepare() → finalize() workflow).
        - For activation quantization: The prepared model should be calibrated
          using calibration_mode() before finalization to collect statistics and
          achieve good accuracy.

        Args:
            example_inputs: Tuple of example inputs for model tracing. When
                activation quantization is in use, these should be representative
                of the data the model would typically see.
            dynamic_shapes: Dynamic shapes specification (graph mode only).
                Ignored in EAGER mode.
            export_with_no_grad: Whether to export with no_grad (graph mode only).
                Ignored in EAGER mode.

        Returns:
            The prepared model with fake quantization modules inserted, ready for
            calibration or training. This is a data-free PTQ compressed model.

        Note:
            In graph mode, the returned ``fx.GraphModule`` supports calling
            ``.train()`` and ``.eval()``, but with limited effect: only dropout
            and batchnorm ops are affected via FX graph rewriting. User code
            branching on the ``training`` flag and other ops with mode-dependent
            behavior are not affected.
        """
        # Cache config dict before prepare() so that module_type_configs can
        # match original types. After prepare, modules can be modified such
        # that the types no longer match what is given in the config.
        self._module_config_dict = self._config.build_module_config_dict(self._quantizer._model)
        if self._execution_mode == ExecutionMode.EAGER:
            if dynamic_shapes is not None:
                warnings.warn(
                    "dynamic_shapes is only supported in graph mode and will be ignored.",
                    UserWarning,
                    stacklevel=2,
                )
            if not export_with_no_grad:
                warnings.warn(
                    "export_with_no_grad is only supported in graph mode and will be ignored.",
                    UserWarning,
                    stacklevel=2,
                )
            prepared_model = self._quantizer.prepare(example_inputs)
        else:
            prepared_model = self._quantizer.prepare(
                example_inputs,
                dynamic_shapes=dynamic_shapes,
                export_with_no_grad=export_with_no_grad,
            )
        self._module_to_fqs = self._get_fake_quantize_modules()
        # Drop any mapping left by a previous prepare(): it is keyed on the old
        # FQ modules, and _build_fq_to_schedule() skips rebuilding when non-empty.
        self._fq_to_schedule = {}
        self._build_fq_to_schedule()
        return prepared_model

    def _validate_mmap_dir_constraints(
        self,
        model: nn.Module | fx.GraphModule | None,
        backend: ExportBackend,
        mmap_dir: str | PathLike[str] | None,
    ) -> None:
        """Validate that ``mmap_dir`` is compatible with the current execution
        mode, target backend, and model device.

        No-op when ``mmap_dir is None``.

        Raises:
            ValueError: If ``mmap_dir`` is given outside eager execution mode.
                Backend/device checks are delegated to
                ``validate_mmap_backend_and_device``.
        """
        if mmap_dir is None:
            return
        if self._execution_mode != ExecutionMode.EAGER:
            raise ValueError(
                "mmap_dir is only supported in eager execution mode, "
                f"got execution_mode={self._execution_mode}."
            )
        model_to_check = model if model is not None else self._model
        _validate_mmap_backend_and_device(model_to_check, backend, mmap_dir)

    def finalize(
        self,
        model: nn.Module | fx.GraphModule | None = None,
        backend: ExportBackend = ExportBackend.CoreAI,
        *,
        mmap_dir: str | PathLike[str] | None = None,
    ) -> nn.Module | fx.GraphModule:
        """Convert quantized model to backend-specific representations.

        Converts fake quantization modules into backend-specific quantization
        ops. Only call ``finalize`` when exporting to a target backend. For
        torch-based evaluation, use the model returned by ``prepare()`` directly
        rather than calling ``finalize``.

        Backend-specific processing:

        - CoreAI: Prepares for CoreAI export by replacing fake quantization
          modules with Core AI specific PyTorch custom ops.
        - CoreML: Prepares for CoreML export by registering compression metadata
          as buffers and removes fake quantization modules.

        Args:
            model: Optional model to finalize. If None, uses the internal
                prepared model.
            backend: Target export backend for the quantized model. Supports
                CoreAI (default), CoreML, and _TORCH backends.
            mmap_dir (str | PathLike[str] | None): If provided, serialize
                finalized quantized weights to safetensors files under this
                directory and re-load them via mmap. Only supported in eager
                execution mode with the CoreAI backend; raises ``ValueError``
                otherwise. The files in ``mmap_dir`` must remain in place for the
                lifetime of the returned model; removing them invalidates the
                mmap-backed weights.

        Returns:
            The finalized quantized model ready for deployment on the target
            backend.

        Note:
            In graph mode, the returned ``fx.GraphModule`` supports calling
            ``.train()`` and ``.eval()``, but with limited effect: only dropout
            and batchnorm ops are affected via FX graph rewriting. User code
            branching on the ``training`` flag and other ops with mode-dependent
            behavior are not affected.

        Note:
            When ``backend=ExportBackend.CoreAI`` in
            ``execution_mode=ExecutionMode.EAGER``, finalize frees the original
            dense weights.
        """
        self._validate_mmap_dir_constraints(model, backend, mmap_dir)
        return self._quantizer.finalize(model, backend, mmap_dir=mmap_dir)

    @contextmanager
    def calibration_mode(self, model: nn.Module | fx.GraphModule | None = None):
        """
        Context manager for calibration-based post-training quantization.

        When entering this context, observers are enabled to collect statistics
        from calibration data, and fake quantization is disabled to get accurate
        statistics. When exiting, observers are disabled and fake quantization is
        re-enabled for evaluation.

        **When to use:**

        - Required for activation quantization to achieve good accuracy. The
          model returned by prepare() may have poor accuracy with activation
          quantization until it is calibrated with representative data.
        - Not needed for weight-only PTQ (prepare() → finalize() is sufficient).

        Args:
            model: Optional model to setup for calibration. If None, uses the
                internal prepared model.

        Example:
            >>> quantizer = Quantizer(model, config)
            >>> prepared_model = quantizer.prepare(example_inputs)
            >>> # For activation quantization, calibrate to improve accuracy:
            >>> with quantizer.calibration_mode():
            ...     for batch in calibration_dataloader:
            ...         prepared_model(batch)
            >>> finalized_model = quantizer.finalize()

        Raises:
            RuntimeError: If the model has not been prepared.
        """
        with self._quantizer.calibration_mode(model):
            yield

    @contextmanager
    def training_mode(self, model: nn.Module | fx.GraphModule | None = None):
        """
        Context manager for quantization-aware training (QAT) workflow.

        When entering this context, the model is configured for training with
        both observers and fake quantization enabled (default behavior), or with
        the state determined by the current step count if a QATSchedule is
        configured. Entering the context:

        1. Puts the model in training mode (``model.training`` is set to True).
        2. Enables the observers and activates fake quantization.
        3. Lets the observers drive simulated quantization during the
           forward/backward passes.

        When exiting the context, observers are disabled and fake quantization is
        enabled (regardless of schedule). The step count is not reset when
        re-entering training_mode(); it resumes from the last value, so the
        schedule state is restored from the accumulated count. Nested calls to
        training_mode() are not allowed and will raise a RuntimeError.

        **When to use:**

        - For quantization-aware training (QAT) to fine-tune a prepared model.
        - The prepared model from prepare() may have poor accuracy with
          weight-only quantization. Fine-tuning the model with quantization
          enabled helps the weights adapt to the effects of quantization.
        - Calibrating an activation-quantized model did not recover enough
          accuracy. Fine-tuning the weights to adapt to the effect of activation
          (and weight) quantization can help recover the lost accuracy.

        Args:
            model: Optional model to setup for training. If None, uses the
                internal prepared model.

        Example:
            >>> quantizer = Quantizer(model, config)
            >>> prepared_model = quantizer.prepare(example_inputs)
            >>> # Fine-tune with quantization-aware training:
            >>> with quantizer.training_mode():
            ...     # Model is put in training mode
            ...     for epoch in range(num_epochs):
            ...         for batch in train_dataloader:
            ...             # Perform training step
            ...             optimizer.zero_grad()
            ...             loss = loss_fn(prepared_model(batch), targets)
            ...             loss.backward()
            ...             optimizer.step()
            ...             quantizer.step()
            ...
            >>> finalized_model = quantizer.finalize()

        Raises:
            RuntimeError: If the model has not been prepared.
            RuntimeError: If called while already inside a training_mode() context.
            TypeError: If the provided model is not a torch.fx.GraphModule
                (graph mode).
        """
        if self._in_training_mode:
            raise RuntimeError(
                "Cannot enter training_mode() while already inside a "
                "training_mode() context. Nested training_mode() calls are not "
                "supported."
            )
        self._in_training_mode = True
        try:
            with self._quantizer.training_mode(model):
                # Inner training_mode enables obs+fq by default.
                # Override with schedule state if configured.
                self._maybe_apply_qat_schedule()
                yield
        finally:
            # Reset even if the inner context manager raised on entry/exit.
            self._in_training_mode = False