# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause
from __future__ import annotations
import warnings
from abc import abstractmethod
import torch
import torch.nn as nn
from torchao.quantization.quant_primitives import (
choose_qparams_affine_with_min_max,
)
from coreai_opt._utils.registry_utils import (
ClassRegistryMixin as _ClassRegistryMixin,
)
from coreai_opt._utils.torch_utils import (
E8M0_EXPONENT_BIAS as _E8M0_EXPONENT_BIAS,
F32_MIN_NORMAL as _F32_MIN_NORMAL,
FP_DTYPE_TO_MAX_POW2 as _FP_DTYPE_TO_MAX_POW2,
)
from .granularity import QuantizationGranularity
from .qscheme import QuantizationScheme
from .range_calculator import RangeCalculatorBase
[docs]
class QParamsCalculatorBase(_ClassRegistryMixin, nn.Module):
"""
Base class for implementing logic to calculate quantization parameters
(scale, zero_point, minval) given min/max values.
"""
scale: torch.Tensor
zero_point: torch.Tensor | None
minval: torch.Tensor | None
[docs]
def __init__(
self,
dtype: torch.dtype,
qscheme: QuantizationScheme,
granularity: QuantizationGranularity,
target_dtype: torch.dtype,
quant_min: int,
quant_max: int,
range_calculator: RangeCalculatorBase,
float_range: tuple[float | None, float | None],
scale_dtype: torch.dtype | None = None, # primarily for FP4
**kwargs,
):
super().__init__()
self.scale_dtype = scale_dtype
self.dtype = dtype
self.qscheme = qscheme
self._granularity = granularity
self.target_dtype = target_dtype
self.quant_min = quant_min
self.quant_max = quant_max
self.range_calculator = range_calculator
self.float_range = float_range
self.register_buffer("scale", torch.empty(0))
if dtype.is_floating_point:
self.register_buffer("zero_point", None)
self.register_buffer("minval", None)
else:
self.register_buffer("zero_point", torch.empty(0, dtype=torch.int32))
self.register_buffer("minval", torch.empty(0))
self._initialized = False
self._export_mode = False
# This is added to address MLIR limitation where
# tensor after q-dq op is not casted to incoming tensor dtype
self._compute_dtype_for_export = torch.float32
# Resolved non-negative axis for per-channel export paths.
# Ellipsis sentinel means unresolved; resolved on first forward() call.
self._resolved_axis: int | None = ... # type: ignore[assignment]
@property
def granularity(self) -> QuantizationGranularity:
"""Getter for granularity."""
return self._granularity
@granularity.setter
def granularity(self, granularity: QuantizationGranularity) -> None:
"""Update granularity for this calculator and its range calculator.
Can only be performed before the first forward pass.
"""
if self._initialized:
raise RuntimeError(
"Cannot change granularity after observer has been initialized. "
"Granularity must be set before the first forward pass."
)
self._granularity = granularity
self.range_calculator.granularity = granularity
def _resolve_axis(self, tensor_ndim: int) -> None:
"""Resolve axis to non-negative on first call, delegating to granularity.
Caches the result in ``_resolved_axis``. Each QParamsCalculator instance
is per-node, so tensor rank is consistent across calls.
Args:
tensor_ndim: Rank of the input tensor.
"""
if self._resolved_axis is not ...:
return
self._resolved_axis = QuantizationGranularity._resolve_axis(self.granularity, tensor_ndim)
def _get_tensor_with_granularity_from_scalar(
self, scalar: float, input_tensor: torch.Tensor
) -> torch.Tensor:
"""
Return a tensor with dimensions equal to num blocks in each dimension, comprised
of values equal to scalar.
"""
block_size_list = self.granularity.get_block_size(input_tensor.shape)
num_blocks_list = [
inp_size // block_size
for inp_size, block_size in zip(input_tensor.shape, block_size_list, strict=True)
]
return torch.full(size=num_blocks_list, fill_value=scalar, device=input_tensor.device)
def _get_min_and_max_val(self, tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Return min and max tensors computed from range calculator statistics and/or
taken from float_range setting.
"""
min_val = (
self._get_tensor_with_granularity_from_scalar(self.float_range[0], tensor)
if self.float_range[0] is not None
else None
)
max_val = (
self._get_tensor_with_granularity_from_scalar(self.float_range[1], tensor)
if self.float_range[1] is not None
else None
)
if min_val is None or max_val is None:
computed_min, computed_max = self.range_calculator(tensor)
if min_val is None:
min_val = torch.clamp(computed_min, max=0)
if max_val is None:
max_val = torch.clamp(computed_max, min=0)
return min_val, max_val
def _compute_e8m0_scale(self, max_abs: torch.Tensor) -> torch.Tensor:
"""
Compute power-of-2 scale in e8m0 format using FLOOR mode.
References:
- OCP Microscaling Formats (MX) Specification:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
- torchao implementation:
https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/mx_tensor.py
"""
target_max_pow2 = _FP_DTYPE_TO_MAX_POW2.get(self.dtype)
if target_max_pow2 is None:
raise ValueError(
f"Unsupported dtype for e8m0 scale computation: {self.dtype}. "
f"Supported: {list(_FP_DTYPE_TO_MAX_POW2.keys())}"
)
max_abs_fp32 = max_abs.to(torch.float32)
max_abs_int32 = max_abs_fp32.view(torch.int32)
# Extract biased exponent from float32 (bits 23-30)
extracted_pow2 = ((max_abs_int32 >> 23) & 0xFF) - _E8M0_EXPONENT_BIAS
# Scale exponent = data exponent - target_max_pow2
scale_e8m0_unbiased = extracted_pow2 - target_max_pow2
# Clamp to e8m0 representable range
scale_e8m0_unbiased = torch.clamp(
scale_e8m0_unbiased,
min=-_E8M0_EXPONENT_BIAS,
max=_E8M0_EXPONENT_BIAS + 1,
)
# Convert biased e8m0 back to float32: scale = 2^(unbiased_exponent)
scale_e8m0_biased = scale_e8m0_unbiased + _E8M0_EXPONENT_BIAS
scale_fp32 = (scale_e8m0_biased.to(torch.int32) << 23).view(torch.float32)
# Clamp to minimum normal float32 to avoid denormals
scale_fp32 = torch.clamp(scale_fp32, min=_F32_MIN_NORMAL)
return scale_fp32
def _compute_scale_zero_point_minval(
self, tensor: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Supports two scale computation modes based on ``scale_dtype``:
1. **Default (``scale_dtype=None``)**: Uses torchao's
``choose_qparams_affine_with_min_max`` custom op to compute scale and
zero point. Supports integer dtypes and FP8 dtypes.
2. **e8m0 scales** (``scale_dtype=torch.float8_e8m0fnu``): Scales are
constrained to powers of 2 following the OCP Microscaling (MX)
specification (FLOOR mode):
scale = 2^(floor(log2(max_abs)) - target_max_pow2)
where ``target_max_pow2`` is the largest power-of-2 component of the
target dtype's maximum representable value:
- FP4 E2M1: max = 6.0 = 1.5 * 2^2, target_max_pow2 = 2
- FP8 E4M3: max = 448.0 = 1.75 * 2^8, target_max_pow2 = 8
- FP8 E5M2: max = 57344.0 = 1.75 * 2^15, target_max_pow2 = 15
"""
# e8m0 path: power-of-2 scales
if self.scale_dtype == torch.float8_e8m0fnu:
max_abs = torch.maximum(torch.abs(min_val), torch.abs(max_val))
return self._compute_e8m0_scale(max_abs), None, None
# Default path: torchao handles integer and FP8 dtypes
scale, zero_point = choose_qparams_affine_with_min_max(
min_val=min_val,
max_val=max_val,
mapping_type=QuantizationScheme._to_mapping_type(self.qscheme),
block_size=self.granularity.get_block_size(tensor.shape),
target_dtype=self.target_dtype,
quant_min=self.quant_min,
quant_max=self.quant_max,
eps=torch.finfo(torch.float32).eps,
zero_point_dtype=torch.int32,
)
# ``minval`` is the minimum representable float value for the
# observed range. The fake-quantize / export layer decides whether
# the ZP or MINVAL formulation is in use and selects between
# ``zero_point`` and ``minval`` accordingly; the calculator just
# provides both.
if self.qscheme in [
QuantizationScheme.SYMMETRIC,
QuantizationScheme.SYMMETRIC_WITH_CLIPPING,
]:
minval = -torch.max(torch.abs(min_val), torch.abs(max_val))
else:
# Asymmetric: ``minval`` is ``min(min_val, 0)``. This is a
# defensive no-op — ``min_val`` is always ``<= 0`` here because:
# - the spec validator rejects ``float_range[0] > 0`` at
# construction (see ``QuantizationSpec.validate_float_range``);
# - the computed-range path clamps to ``<= 0`` in
# ``_get_min_and_max_val``.
minval = torch.min(min_val, torch.zeros_like(min_val))
# For FP dtypes, neither zero_point nor minval is used (symmetric
# quantization with no offset). The buffers are registered as None.
if self.dtype.is_floating_point:
zero_point = None
minval = None
return scale, zero_point, minval
def _initialize_state(
self,
tensor: torch.Tensor,
min_val: torch.Tensor,
max_val: torch.Tensor,
) -> None:
"""Hook for subclass-specific initialization on the first forward pass.
Tensor shape and device are not known at construction time, so stateful
subclasses override this to resize their buffers (e.g.
``running_min``/``running_max``) to match the observed input."""
[docs]
def compute_qparams(
self,
tensor: torch.Tensor,
min_val: torch.Tensor,
max_val: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Given the observed min/max range, return ``(scale, zero_point, minval)``.
The default implementation directly computes qparams from the given
range via ``_compute_scale_zero_point_minval``. This is the correct behavior
for stateless calculators (e.g. ``StaticQParamsCalculator``).
Stateful calculators override this via ``RunningRangeMixin`` to update
running-range buffers before computing qparams from the smoothed range.
"""
return self._compute_scale_zero_point_minval(tensor, min_val, max_val)
[docs]
def forward(
self, tensor: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Compute scale, zero point, and minval from the input tensor.
On the first forward pass, initializes internal buffers using the
observed tensor shape and device. Delegates the actual qparams
calculation to ``compute_qparams``.
"""
self._resolve_axis(tensor.ndim)
scale = self.scale.clone()
zero_point = self.zero_point.clone() if self.zero_point is not None else None
minval = self.minval.clone() if self.minval is not None else None
# Optimization to skip range updates if both min and max are frozen.
if (
self._initialized
and self.float_range[0] is not None
and self.float_range[1] is not None
):
return scale, zero_point, minval
if not self._export_mode:
min_val, max_val = self._get_min_and_max_val(tensor)
if not self._initialized:
self._compute_dtype_for_export = tensor.dtype
self._initialize_state(tensor, min_val, max_val)
scale, zero_point, minval = self.compute_qparams(tensor, min_val, max_val)
if not self._initialized:
self.scale = torch.empty(scale.shape, dtype=self.scale.dtype, device=tensor.device)
# Only resize zero_point if it exists (None for FP4/FP8)
if zero_point is not None:
self.zero_point = torch.empty(
zero_point.shape,
dtype=self.zero_point.dtype,
device=tensor.device,
)
if minval is not None:
self.minval = torch.empty(
minval.shape,
dtype=self.minval.dtype,
device=tensor.device,
)
self._initialized = True
self.scale.copy_(scale.detach())
if zero_point is not None:
self.zero_point.copy_(zero_point.detach())
if minval is not None:
self.minval.copy_(minval.detach())
return scale, zero_point, minval
[docs]
def get_qparams(self) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Return the computed scale, zero point and minval.
For FP4/FP8/floating-point quantization, zero_point and minval are None.
"""
if not self._initialized:
warnings.warn(
"Using default scale and zero point. Call forward pass to compute actual values.",
stacklevel=1,
)
return self.scale, self.zero_point, self.minval
[docs]
def set_export_mode(self, enabled: bool = True) -> None:
self._export_mode = enabled
@QParamsCalculatorBase.register("default")
class _DefaultQParamsCalculator(QParamsCalculatorBase):
"""
Marker class for context-dependent qparam calculator resolution.
This class should not be used directly. When "default" is specified for
qparam_calculator_cls, the QuantizationComponentFactory will resolve it to:
- StaticQParamsCalculator for weight quantization
- MovingAverageQParamsCalculator for activation quantization
Raises:
RuntimeError: If __init__ is called, indicating the factory didn't resolve it
"""
def __init__(self, **kwargs):
raise RuntimeError(
"_DefaultQParamsCalculator is a marker class and must be resolved by "
"QuantizationComponentFactory before use. This error indicates the factory "
"did not properly resolve 'default' to the appropriate qparam calculator "
"based on quantization target (weight or activation)."
)
[docs]
@QParamsCalculatorBase.register("static")
class StaticQParamsCalculator(QParamsCalculatorBase):
"""
Computes scale and zero point using min/max values from the current tensor.
This QParamsCalculator directly uses the min/max range from each forward pass to compute
quantization parameters. So in that sense, it does not maintain any "history" and
only computes the min/max based off of the current (most recent) tensor input.
This QParamsCalculator is typically used for weight quantization. In case of PTQ based
workflows the weights are fixed and during QAT, the min/max range is calculated using the
most recent weight tensor value.
Uses the base-class default ``compute_qparams`` which
directly delegates to ``_compute_scale_zero_point_minval`` without any running state.
"""
# Inherits base-class' default: compute_qparams
# which directly computes qparams from current min/max with no running state.
# ``# type: ignore`` comments are used where the mixin accesses
# attributes and methods provided by ``QParamsCalculatorBase`` /
# ``nn.Module``, which mypy cannot resolve from the mixin class alone.
[docs]
class RunningRangeMixin:
"""Mixin for calculators that maintain running min/max range buffers.
Provides ``running_min`` and ``running_max`` buffers, first-forward
initialization via ``_initialize_state``, and a
``compute_qparams`` implementation that delegates
the range-update rule to the abstract ``update_running_range`` hook.
Subclasses that want to re-use the logic of computing quantization
parameters but with different ways of updating the running statistics
can override the ``update_running_range`` method.
Must appear before ``QParamsCalculatorBase`` in the MRO so that its
``compute_qparams`` and ``_initialize_state``
take precedence over the base-class defaults.
"""
running_min: torch.Tensor
running_max: torch.Tensor
[docs]
def __init__(self, *args: object, **kwargs: object) -> None:
super().__init__(*args, **kwargs)
self.register_buffer("running_min", torch.empty(0)) # type: ignore[attr-defined]
self.register_buffer("running_max", torch.empty(0)) # type: ignore[attr-defined]
def _initialize_state(
self,
tensor: torch.Tensor,
min_val: torch.Tensor,
max_val: torch.Tensor,
) -> None:
"""Resize and move running_min/running_max buffers to match input
tensor shape and device on the first forward pass."""
self.running_min = self.running_min.to(device=tensor.device, dtype=min_val.dtype)
self.running_min.resize_(min_val.shape)
self.running_max = self.running_max.to(device=tensor.device, dtype=max_val.dtype)
self.running_max.resize_(max_val.shape)
[docs]
def compute_qparams(
self,
tensor: torch.Tensor,
min_val: torch.Tensor,
max_val: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""Update running range, persist to buffers, then compute qparams."""
if self._initialized: # type: ignore[attr-defined]
running_min, running_max = self.update_running_range(min_val, max_val)
else:
running_min, running_max = min_val, max_val
self.running_min.data.copy_(running_min)
self.running_max.data.copy_(running_max)
return self._compute_scale_zero_point_minval(tensor, running_min, running_max) # type: ignore[attr-defined, no-any-return]
[docs]
@abstractmethod
def update_running_range(
self, min_val: torch.Tensor, max_val: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return ``(updated_min, updated_max)`` using subclass-specific rule."""
[docs]
@QParamsCalculatorBase.register("moving_average")
class MovingAverageQParamsCalculator(RunningRangeMixin, QParamsCalculatorBase):
"""
Computes the scale and zero point using a moving average of the range.
Maintains ``running_min`` and ``running_max`` buffers that are updated each
forward pass using exponential moving average (EMA):
a_{i} = c * x_{i} + (1 - c) * a_{i-1}
where ``c`` is the ``averaging_constant``.
"""
[docs]
def __init__(
self,
dtype: torch.dtype,
qscheme: QuantizationScheme,
granularity: QuantizationGranularity,
target_dtype: torch.dtype,
quant_min: int,
quant_max: int,
range_calculator: RangeCalculatorBase,
float_range: list[float | None],
averaging_constant: float = 1e-2,
**kwargs,
):
super().__init__(
dtype=dtype,
qscheme=qscheme,
granularity=granularity,
target_dtype=target_dtype,
quant_min=quant_min,
quant_max=quant_max,
range_calculator=range_calculator,
float_range=float_range,
**kwargs,
)
self.averaging_constant = averaging_constant
[docs]
def update_running_range(
self, min_val: torch.Tensor, max_val: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
# Exponential moving average of the min and max values
# a_{i} = c * x_{i} + (1-c) * a_{i-1}
# Reference: https://en.wikipedia.org/wiki/Exponential_smoothing
running_min = (
self.averaging_constant * min_val + (1 - self.averaging_constant) * self.running_min
)
running_max = (
self.averaging_constant * max_val + (1 - self.averaging_constant) * self.running_max
)
return running_min, running_max
[docs]
@QParamsCalculatorBase.register("global_minmax")
class GlobalMinMaxQParamsCalculator(RunningRangeMixin, QParamsCalculatorBase):
"""Computes scale and zero point by tracking the running min/max.
Maintains ``running_min`` and ``running_max`` buffers that are updated each
forward pass via element-wise minimum and maximum:
running_min = min(running_min, x_min)
running_max = max(running_max, x_max)
"""
[docs]
def update_running_range(
self, min_val: torch.Tensor, max_val: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
running_min = torch.minimum(self.running_min.detach(), min_val)
running_max = torch.maximum(self.running_max.detach(), max_val)
return running_min, running_max