Source code for coreai_opt.quantization.spec.factory

# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause

from __future__ import annotations

from coreai_opt._utils.spec_utils import PartialConstructor as _PartialConstructor
from coreai_opt.config.spec import (
    CompressionComponentFactoryBase,
    CompressionTargetTensor,
)

from .fake_quantize import FakeQuantizeImplBase
from .qparams_calculator import (
    MovingAverageQParamsCalculator,
    QParamsCalculatorBase,
    StaticQParamsCalculator,
    _DefaultQParamsCalculator,
)
from .range_calculator import RangeCalculatorBase
from .spec import QuantizationSpec


class QuantizationComponentFactory(CompressionComponentFactoryBase):
    """
    Factory class for creating quantization components from QuantizationSpec.

    This factory eliminates circular dependencies between QuantizationSpec and
    component classes (FakeQuantizeImplBase, QParamsCalculatorBase,
    RangeCalculatorBase) by centralizing the creation logic.
    """

    @classmethod
    def create_range_calculator(cls, spec: QuantizationSpec) -> RangeCalculatorBase:
        """
        Create a RangeCalculatorBase instance from a QuantizationSpec.

        Args:
            spec: QuantizationSpec instance containing configuration

        Returns:
            RangeCalculatorBase instance configured from the spec
        """
        # Standard arguments for range calculator
        common_args = {
            "granularity": spec.granularity,
        }

        # Automatically detect and include any extra arguments
        extra_args = spec.get_extra_args()

        # Create instance with all arguments
        return spec.range_calculator_cls(**common_args, **extra_args)

    @classmethod
    def create_qparams_calculator(
        cls, spec: QuantizationSpec, quantization_target: CompressionTargetTensor
    ) -> QParamsCalculatorBase:
        """
        Create a QParamsCalculatorBase instance from a QuantizationSpec.

        Args:
            spec: QuantizationSpec instance containing configuration
            quantization_target: The target tensor for quantization
                (weight/activation)

        Returns:
            QParamsCalculatorBase instance configured from the spec

        Raises:
            ValueError: If the spec uses the "default" calculator marker and
                ``quantization_target`` is not WEIGHT, LUT, or ACTIVATION.
        """
        # Resolve "default" marker class based on quantization target
        qparam_calculator_cls = spec.qparam_calculator_cls
        if qparam_calculator_cls is _DefaultQParamsCalculator:
            if quantization_target in (
                CompressionTargetTensor.WEIGHT,
                CompressionTargetTensor.LUT,
            ):
                qparam_calculator_cls = StaticQParamsCalculator
            elif quantization_target == CompressionTargetTensor.ACTIVATION:
                qparam_calculator_cls = MovingAverageQParamsCalculator
            else:
                raise ValueError(
                    f"Unsupported quantization target: {quantization_target}. "
                    "Expected WEIGHT, ACTIVATION, or LUT."
                )

        # Create range calculator first
        range_calculator = cls.create_range_calculator(spec)

        # Standard arguments for qparams calculator
        common_args = {
            "dtype": spec.dtype,
            "qscheme": spec.qscheme,
            "granularity": spec.granularity,
            "target_dtype": spec.target_dtype,
            "quant_min": spec.quant_min,
            "quant_max": spec.quant_max,
            "range_calculator": range_calculator,
            "float_range": spec.float_range,
            "scale_dtype": spec.scale_dtype,
        }

        # Automatically detect and include any extra arguments
        extra_args = spec.get_extra_args()

        # Create instance with all arguments
        return qparam_calculator_cls(**common_args, **extra_args)

    @classmethod
    def construct(
        cls, spec: QuantizationSpec | None, target: CompressionTargetTensor
    ) -> FakeQuantizeImplBase | None:
        """
        Create a fake quantizer instance from a QuantizationSpec.

        This method implements the base class interface and delegates to
        create_fake_quantizer.

        Args:
            spec: QuantizationSpec instance containing configuration
            target: The target tensor for compression (weight or activation)

        Returns:
            FakeQuantizeImplBase instance configured from the spec, or None if
            spec is None
        """
        if spec is None:
            return None
        return cls.create_fake_quantizer(spec, target)

    @classmethod
    def construct_partial(
        cls, spec: QuantizationSpec | None, target: CompressionTargetTensor
    ) -> _PartialConstructor[FakeQuantizeImplBase] | None:
        """
        Create a fake quantizer partial object for deferred construction.

        This method implements the base class interface and delegates to
        create_fake_quantizer_partial.

        Args:
            spec: QuantizationSpec instance containing configuration
            target: The target tensor for compression (weight or activation)

        Returns:
            PartialConstructor: A partial object for deferred construction, or
            None if spec is None
        """
        if spec is None:
            return None
        return cls.create_fake_quantizer_partial(spec, target)

    @staticmethod
    def _fake_quantizer_common_args(
        spec: QuantizationSpec, quantization_target: CompressionTargetTensor
    ) -> dict[str, object]:
        """
        Build the keyword arguments shared by both fake quantizer paths.

        Single source of truth for ``create_fake_quantizer`` and
        ``create_fake_quantizer_partial`` so the eager and deferred
        construction paths cannot drift apart. ``qparams_calculator`` is
        deliberately excluded because each path supplies it differently.

        Args:
            spec: QuantizationSpec instance containing configuration
            quantization_target: The target tensor for quantization

        Returns:
            A new dict of keyword arguments read from the spec, plus
            ``quantization_target``.
        """
        return {
            "dtype": spec.dtype,
            "qscheme": spec.qscheme,
            "qformulation": spec.qformulation,
            "granularity": spec.granularity,
            "target_dtype": spec.target_dtype,
            "quant_min": spec.quant_min,
            "quant_max": spec.quant_max,
            "quantization_target": quantization_target,
            "n_bits": spec.n_bits,
        }

    @classmethod
    def create_fake_quantizer(
        cls, spec: QuantizationSpec, quantization_target: CompressionTargetTensor
    ) -> FakeQuantizeImplBase:
        """
        Create a FakeQuantizeImplBase instance from a QuantizationSpec.

        This method automatically detects any extra arguments in the spec
        beyond the base QuantizationSpec fields and passes them to the fake
        quantizer constructor.

        Args:
            spec: QuantizationSpec instance containing configuration
            quantization_target: The target tensor for quantization

        Returns:
            FakeQuantizeImplBase instance configured from the spec

        Example:
            >>> spec = QuantizationSpec(...)
            >>> fake_quantize = QuantizationComponentFactory.create_fake_quantizer(
            ...     spec, quantization_target=CompressionTargetTensor.WEIGHT
            ... )

            >>> extended_spec = ExtraArgQuantizationSpec(eps=0.1, ...)
            >>> fake_quantize = QuantizationComponentFactory.create_fake_quantizer(
            ...     extended_spec,
            ...     quantization_target=CompressionTargetTensor.ACTIVATION
            ... )
        """
        # For direct instantiation, create qparams calculator immediately
        qparams_calculator = cls.create_qparams_calculator(spec, quantization_target)

        # Standard arguments that all fake quantizers need
        common_args = cls._fake_quantizer_common_args(spec, quantization_target)
        common_args["qparams_calculator"] = qparams_calculator

        # Automatically detect and include any extra arguments
        extra_args = spec.get_extra_args()

        # Create instance with all arguments
        return spec.fake_quantize_cls(**common_args, **extra_args)

    @classmethod
    def create_fake_quantizer_partial(
        cls, spec: QuantizationSpec, quantization_target: CompressionTargetTensor
    ) -> _PartialConstructor[FakeQuantizeImplBase]:
        """
        Create a fake quantizer partial object for deferred construction by
        the graph-mode prepare API (torchao PT2E).

        Args:
            spec: QuantizationSpec instance containing configuration
            quantization_target: The target tensor for quantization

        Returns:
            PartialConstructor: A partial object that can be used by the
            graph-mode prepare API to construct fake quantizer instances.
            Each call to the partial will create a new instance with its own
            qparams_calculator.
        """
        # For partial construction, we need to defer qparams_calculator
        # creation to ensure each instance gets its own calculator

        # Standard arguments that all fake quantizers need
        # (excluding qparams_calculator)
        common_args = cls._fake_quantizer_common_args(spec, quantization_target)

        # Automatically detect and include any extra arguments
        extra_args = spec.get_extra_args()

        # Create a factory function that creates qparams_calculator on each call
        def qparams_calculator_factory():
            return cls.create_qparams_calculator(spec, quantization_target)

        # Create partially constructed class obj with callable args
        # for qparams_calculator
        return spec.fake_quantize_cls.with_args(
            **common_args, **extra_args
        ).with_callable_args(qparams_calculator=qparams_calculator_factory)