# Copyright 2026 Apple Inc.
#
# Use of this source code is governed by a BSD-3-Clause license that can
# be found in the LICENSE file or at https://opensource.org/licenses/BSD-3-Clause
"""Fake quantization implementation base class and default implementation."""
from __future__ import annotations
import logging
from abc import abstractmethod
from typing import Any
import torch
from torch.autograd import Function
from torchao.quantization.pt2e import FakeQuantizeBase
from coreai_opt._utils.spec_utils import (
PartialConstructor as _PartialConstructor,
with_args as _with_args,
)
from coreai_opt._utils.torch_utils import (
get_n_bits_from_dtype as _get_n_bits_from_dtype,
is_float4_dtype as _is_float4_dtype,
is_float8_dtype as _is_float8_dtype,
is_float_quant_dtype as _is_float_quant_dtype,
)
from coreai_opt.config.spec import CompressionSimulatorBase, CompressionTargetTensor
from coreai_opt.quantization._utils import get_quantization_shapes as _get_quantization_shapes
from coreai_opt.quantization.spec.errors import _BlockSizeMismatchError
from .granularity import QuantizationGranularity
from .qformulation import QuantizationFormulation
from .qparams_calculator import QParamsCalculatorBase
from .qscheme import QuantizationScheme
__all__ = ["FakeQuantizeImplBase"]
logger = logging.getLogger(__name__)
class FakeQuantizeImplBase(CompressionSimulatorBase, FakeQuantizeBase):
    """
    Base class for implementing fake quantization.

    ``forward`` obtains ``(scale, zero_point, minval)`` from the configured
    ``qparams_calculator`` and, when fake quantization is enabled, applies the
    subclass-provided fused quantize → dequantize to the input tensor.

    Subclasses must implement :meth:`quantize`, :meth:`dequantize` and
    :meth:`_fused_fake_quant_dequant`.
    """

    def __init__(
        self,
        dtype: torch.dtype,
        qscheme: QuantizationScheme,
        qformulation: QuantizationFormulation,
        granularity: QuantizationGranularity,
        target_dtype: torch.dtype,
        quant_min: int | float,
        quant_max: int | float,
        qparams_calculator: QParamsCalculatorBase,
        quantization_target: CompressionTargetTensor,
        n_bits: int | None = None,
        **kwargs,
    ):
        """
        Store the quantization configuration on the module.

        Args:
            dtype: The quantization dtype. Also used to infer ``n_bits`` when
                ``n_bits`` is not given.
            qscheme: The quantization scheme (stored as-is).
            qformulation: The quantization formulation (stored as-is).
            granularity: The quantization granularity; exposed through the
                ``granularity`` property.
            target_dtype: The dtype ``quantize`` casts to when asked to cast to
                the target dtype.
            quant_min: Lower bound of the quantized range.
            quant_max: Upper bound of the quantized range.
            qparams_calculator: Callable module that produces
                ``(scale, zero_point, minval)``.
            quantization_target: The kind of tensor being quantized; reported in
                the warning emitted when the module disables itself.
            n_bits: Bit width. When ``None`` it is derived from ``dtype``.
            **kwargs: Accepted for call-site compatibility; not used by this
                constructor.
        """
        super().__init__()
        self.dtype = dtype
        self.qscheme = qscheme
        self.qformulation = qformulation
        self._granularity = granularity
        self.target_dtype = target_dtype
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.qparams_calculator = qparams_calculator
        self.quantization_target = quantization_target
        # Tracks whether this module has switched itself off (see _warn_and_disable).
        self.register_buffer("_disabled", torch.tensor(False))
        # Infer n_bits from dtype if not provided
        if n_bits is None:
            n_bits = _get_n_bits_from_dtype(dtype)
        self.n_bits = n_bits

    @property
    def granularity(self) -> QuantizationGranularity:
        """Getter for granularity."""
        return self._granularity

    @granularity.setter
    def granularity(self, granularity: QuantizationGranularity) -> None:
        """Update granularity for the fake quantize class and its qparams calculator.

        Can only be performed before the first forward pass.
        """
        # The calculator is updated first, so if it rejects the value this
        # module's own granularity is left untouched.
        self.qparams_calculator.granularity = granularity
        self._granularity = granularity

    def is_disabled(self) -> bool:
        """Return True if fake quantization has been disabled."""
        return bool(self._disabled.item())

    def _warn_and_disable(self, error: _BlockSizeMismatchError) -> None:
        """Log a warning and permanently disable this module."""
        logger.warning(
            "Tensor (target: %s) incompatible with block size "
            "configuration: %s. Skipping quantization.",
            self.quantization_target,
            error,
        )
        self._disabled.fill_(True)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Performs fake quantization of the given tensor using the qparams
        (scale, zero point, minval) computed by the QParamsCalculator.

        Returns the input unchanged when the module is disabled, when the
        qparams calculator raises a block-size mismatch (which also disables
        the module), or when fake quantization is not enabled.
        """
        if self._disabled.item():
            return tensor

        if self.observer_enabled[0] == 1:
            # Call the forward function of the qparams_calculator to collect
            # observer statistics when the observer is enabled.
            # Use no_grad to prevent gradients flowing through the scale/zp
            # computation path. Gradients should be computed through the actual
            # QDQ path only.
            with torch.no_grad():
                try:
                    scale, zero_point, minval = self.qparams_calculator(tensor)
                except _BlockSizeMismatchError as e:
                    self._warn_and_disable(e)
                    return tensor
        else:
            # When the observer is not enabled, call the get_qparams
            # function to retrieve the stored statistics
            scale, zero_point, minval = self.qparams_calculator.get_qparams()

        if self.fake_quant_enabled[0] == 1:
            # Cast incoming tensor to fp32 to perform qdq operations in high precision.
            # Cast the tensor to return back to the original dtype.
            orig_dtype = tensor.dtype
            tensor = tensor.to(torch.float32)
            return self._fused_fake_quant_dequant(tensor, scale, zero_point, minval).to(orig_dtype)
        return tensor

    @abstractmethod
    def quantize(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor | None,
        minval: torch.Tensor | None,
        cast_to_target_dtype: bool = True,
    ) -> torch.Tensor:
        """
        Given a tensor, scale and zero point, perform quantization of the tensor based
        on the configuration in the ``QuantizationSpec``.

        Args:
            tensor: The tensor to quantize
            scale: The scale to use for quantization
            zero_point: The zero point computed by the qparams calculator
                (None for floating-point dtypes).
            minval: The minimum representable float value of the observed
                range, computed by the qparams calculator
                (None for floating-point dtypes).
            cast_to_target_dtype: If True, the quantized tensor is cast to the target_dtype.
                Otherwise, the values of the tensor are quantized to appropriate bins but
                the dtype used to represent the quantized tensor remains the same as the
                original tensor. This allows fake quantization to capture the quantization
                error while allowing gradients to backpropagate.
        """
        pass

    @abstractmethod
    def dequantize(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor | None,
        minval: torch.Tensor | None,
        output_dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """
        Given a quantized tensor, the scale and zero point used to perform quantization,
        perform de-quantization of the tensor based on the configuration in the
        ``QuantizationSpec`` and return it as a tensor with dtype as ``output_dtype``.

        Args:
            tensor: The tensor to dequantize
            scale: The scale to use for dequantization
            zero_point: The zero point computed by the qparams calculator
                (None for floating-point dtypes).
            minval: The minimum representable float value of the observed
                range, computed by the qparams calculator
                (None for floating-point dtypes).
            output_dtype: The dtype to use for the dequantized tensor
        """
        pass

    @abstractmethod
    def _fused_fake_quant_dequant(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor | None,
        minval: torch.Tensor | None,
    ) -> torch.Tensor:
        """Fused quantize → dequantize as a single autograd node with STE gradient.

        Expects the input tensor to already be in fp32. Returns an fp32 tensor;
        the caller is responsible for casting to the desired output dtype.
        """
        pass

    @classmethod
    def with_args(cls, **kwargs: Any) -> _PartialConstructor[FakeQuantizeImplBase]:
        """Return a partial constructor for ``cls`` with ``kwargs`` pre-bound."""
        # This is needed for compatibility with torch prepare_pt2e
        fake_quant_constructor = _with_args(cls, **kwargs)
        # need to assign the correct module to fake_quantize
        # constructors to satisfy public v private requirements
        fake_quant_constructor.__module__ = f"{cls.__module__}.{cls.__name__}"
        return fake_quant_constructor

    def calculate_qparams(self) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """
        Returns the computed (scale, zero_point, minval).

        ``zero_point`` and ``minval`` are None for floating-point dtypes.
        """
        return self.qparams_calculator.get_qparams()

    def set_export_mode(self, enabled: bool = True) -> None:
        """
        Set or unset export mode.

        The flag is forwarded to the qparams calculator.
        """
        self.qparams_calculator.set_export_mode(enabled=enabled)

    def convert(self, model: torch.fx.GraphModule, observer_node: torch.fx.Node) -> None:
        """No-op: keep fake quant nodes intact during convert_pt2e.

        If this method is not present, torchao's convert method will try to replace
        fake quant nodes with its standard quantize/dequantize ops and fails in the process
        """
@FakeQuantizeImplBase.register("default")
class _DefaultFakeQuantizeImpl(FakeQuantizeImplBase):
    """Fake-quantize implementation registered under the name ``"default"``.

    Integer dtypes are handled with the ZP or MINVAL formulation; float4/float8
    dtypes are handled by scale-only quantization. The element-wise math lives in
    the module-level kernels; the methods here only take care of blockwise
    reshaping and offset selection.
    """

    def _blockwise_operands(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Size, list[int]]:
        """View ``tensor`` and ``scale`` in the shapes used for blockwise math.

        Returns ``(tensor_blockwise, scale_reduced, original_shape, reduced_shape)``
        where the first two are views with the blockwise / reduced shapes reported
        by ``_get_quantization_shapes`` for this module's granularity.
        """
        block_size = self.granularity.get_block_size(tensor.shape)
        original_shape, blockwise_shape, reduced_shape = _get_quantization_shapes(
            tensor, block_size
        )
        blocks = tensor.view(blockwise_shape)
        block_scale = scale.view(reduced_shape)
        return blocks, block_scale, original_shape, reduced_shape

    def _select_int_offsets(
        self,
        tensor: torch.Tensor,
        zero_point: torch.Tensor,
        minval: torch.Tensor,
        reduced_shape: list[int],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(quant_offset, float_offset)`` for the configured formulation.

        ``quant_offset`` is applied in the quantized (integer) domain and
        ``float_offset`` in the float domain:

        - ZP: ``quant_offset`` is the zero point viewed as ``reduced_shape`` and
          ``float_offset`` is a scalar zero created from ``tensor``.
        - MINVAL: ``quant_offset`` is a scalar holding ``quant_min`` (in the zero
          point's dtype) and ``float_offset`` is ``minval`` viewed as
          ``reduced_shape`` and cast to ``tensor.dtype``.

        Raises:
            NotImplementedError: for any other formulation.
        """
        formulation = self.qformulation
        if formulation == QuantizationFormulation.ZP:
            quant_offset = zero_point.view(reduced_shape)
            float_offset = tensor.new_zeros(())
            return quant_offset, float_offset
        if formulation == QuantizationFormulation.MINVAL:
            quant_offset = tensor.new_full((), self.quant_min, dtype=zero_point.dtype)
            float_offset = minval.view(reduced_shape).to(tensor.dtype)
            return quant_offset, float_offset
        raise NotImplementedError(f"Unknown qformulation: {self.qformulation}")

    def quantize(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor | None,
        minval: torch.Tensor | None,
        cast_to_target_dtype: bool = True,
    ) -> torch.Tensor:
        """Quantize ``tensor`` with the given qparams.

        The math is carried out in fp32. The result is cast to ``self.target_dtype``
        when ``cast_to_target_dtype`` is True, otherwise back to the dtype of the
        incoming tensor. ``zero_point`` and ``minval`` must both be None for
        floating-point quantization dtypes and both be set for integer dtypes.
        """
        # Remember the caller's dtype before promoting to fp32 for the arithmetic.
        source_dtype = tensor.dtype
        fp32_tensor = tensor.to(torch.float32)

        if _is_float_quant_dtype(self.dtype):
            assert zero_point is None, "zero_point must be None for floating-point quantization"
            assert minval is None, "minval must be None for floating-point quantization"
            quantized = self._quantize_float(fp32_tensor, scale)
        else:
            assert zero_point is not None, "zero_point must not be None for integer quantization"
            assert minval is not None, "minval must not be None for integer quantization"
            quantized = self._quantize_int(fp32_tensor, scale, zero_point, minval)

        if cast_to_target_dtype:
            return quantized.to(self.target_dtype)
        return quantized.to(source_dtype)

    def dequantize(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor | None,
        minval: torch.Tensor | None,
        output_dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Dequantize ``tensor`` with the given qparams and return it as ``output_dtype``.

        The math is carried out in fp32.
        """
        fp32_tensor = tensor.to(torch.float32)

        # NOTE(review): this branch keys on ``self.target_dtype`` whereas ``quantize``
        # and ``_fused_fake_quant_dequant`` key on ``self.dtype`` — confirm that the
        # two always agree on float-vs-integer.
        if not _is_float_quant_dtype(self.target_dtype):
            # Integer dequantization
            assert zero_point is not None, "zero_point must not be None for integer dequantization"
            assert minval is not None, "minval must not be None for integer dequantization"
            return self._dequantize_int(fp32_tensor, scale, zero_point, minval, output_dtype)

        assert zero_point is None, "zero_point must be None for floating-point dequantization"
        assert minval is None, "minval must be None for floating-point dequantization"
        return self._dequantize_float(fp32_tensor, scale, output_dtype)

    def _quantize_int(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        minval: torch.Tensor,
    ) -> torch.Tensor:
        """
        Blockwise integer quantization. The math is in the module-level
        :func:`_quantize_int`; the offsets come from :meth:`_select_int_offsets`.

        Values are snapped to quantized levels but no dtype cast is performed here.
        """
        blocks, block_scale, original_shape, reduced_shape = self._blockwise_operands(
            tensor, scale
        )
        quant_offset, float_offset = self._select_int_offsets(
            blocks, zero_point, minval, reduced_shape
        )
        levels, _ = _quantize_int(
            blocks, block_scale, quant_offset, float_offset, self.quant_min, self.quant_max
        )
        return levels.view(original_shape)

    def _dequantize_int(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        minval: torch.Tensor,
        output_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Blockwise integer dequantization; math in module-level :func:`_dequantize_int`."""
        blocks, block_scale, original_shape, reduced_shape = self._blockwise_operands(
            tensor, scale
        )
        quant_offset, float_offset = self._select_int_offsets(
            blocks, zero_point, minval, reduced_shape
        )
        restored = _dequantize_int(blocks, block_scale, quant_offset, float_offset)
        return restored.view(original_shape).to(output_dtype)

    def _quantize_float(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
    ) -> torch.Tensor:
        """
        Blockwise floating-point quantization:
        cast_to_low_precision(clamp(input / scale, min, max))
        """
        blocks, block_scale, original_shape, _ = self._blockwise_operands(tensor, scale)
        quantized, _ = _quantize_float(
            blocks, block_scale, self.quant_min, self.quant_max, self.dtype
        )
        return quantized.view(original_shape)

    def _dequantize_float(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        output_dtype: torch.dtype,
    ) -> torch.Tensor:
        """Blockwise floating-point dequantization: input * scale"""
        blocks, block_scale, original_shape, _ = self._blockwise_operands(tensor, scale)
        restored = _dequantize_float(blocks, block_scale)
        return restored.view(original_shape).to(output_dtype)

    def _fused_fake_quant_dequant(
        self,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor | None,
        minval: torch.Tensor | None,
    ) -> torch.Tensor:
        """Fused quantize → dequantize as a single autograd node with STE gradient.

        Dispatches to the int or float fused STE class based on ``self.dtype``.
        The reshaping to blockwise form happens inside the autograd function, so
        only the shapes are computed here.
        """
        block_size = self.granularity.get_block_size(tensor.shape)
        shapes = _get_quantization_shapes(tensor, block_size)
        original_shape, blockwise_shape, reduced_shape = shapes

        if not _is_float_quant_dtype(self.dtype):
            quant_offset, float_offset = self._select_int_offsets(
                tensor, zero_point, minval, reduced_shape
            )
            return _FusedFakeQuantizeIntSTE.apply(
                tensor,
                scale,
                quant_offset,
                float_offset,
                self.quant_min,
                self.quant_max,
                original_shape,
                blockwise_shape,
                reduced_shape,
            )

        return _FusedFakeQuantizeFloatSTE.apply(
            tensor,
            scale,
            self.quant_min,
            self.quant_max,
            self.dtype,
            original_shape,
            blockwise_shape,
            reduced_shape,
        )
def _qdq_int(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    quant_offset: torch.Tensor,
    float_offset: torch.Tensor,
    quant_min: int,
    quant_max: int,
    original_shape: torch.Size,
    blockwise_shape: list[int],
    reduced_shape: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused quantize → dequantize for integer types.

    ``tensor`` is viewed as ``blockwise_shape`` and ``scale`` as ``reduced_shape``
    before the math is applied.

    Returns:
        ``(dequantized, mask)``. ``dequantized`` is a fresh tensor of shape
        ``original_shape``. ``mask`` is the boolean tensor produced by
        :func:`_quantize_int`: True where the pre-clamp value was inside
        ``[quant_min, quant_max]``, False where it was clamped.
    """
    blocks = tensor.view(blockwise_shape)
    block_scale = scale.view(reduced_shape)
    levels, in_range = _quantize_int(
        blocks, block_scale, quant_offset, float_offset, quant_min, quant_max
    )
    restored = _dequantize_int(levels, block_scale, quant_offset, float_offset)
    # Hand back a copy rather than the view of the intermediate result.
    return restored.view(original_shape).clone(), in_range
class _FusedFakeQuantizeIntSTE(Function):
    """
    Fused fake quantize + dequantize for integer types with STE gradient.

    Handles blockwise reshaping internally so the entire fake-quantize operation
    (reshape → quantize → dequantize → reshape back) is a single autograd node.

    Fusing into one node reduces QAT memory: intermediate tensors (scaled, rounded,
    clamped) are local to forward and freed immediately instead of being retained by
    the autograd graph. Only a boolean mask (1 byte/element) is saved for backward,
    replacing multiple float32 intermediates (4 bytes/element each).
    """

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        quant_offset: torch.Tensor,
        float_offset: torch.Tensor,
        quant_min: int,
        quant_max: int,
        original_shape: torch.Size,
        blockwise_shape: list[int],
        reduced_shape: list[int],
    ) -> torch.Tensor:
        """Run the fused integer QDQ and stash what backward needs.

        Saves the in-range mask returned by ``_qdq_int`` and the original shape
        on ``ctx``; returns the quantized-dequantized tensor.
        """
        dequantized, mask = _qdq_int(
            tensor,
            scale,
            quant_offset,
            float_offset,
            quant_min,
            quant_max,
            original_shape,
            blockwise_shape,
            reduced_shape,
        )
        ctx.save_for_backward(mask)
        ctx.original_shape = original_shape
        return dequantized

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None, None, None, None, None, None]:
        """Straight-through gradient: pass ``grad_output`` where the mask is True.

        Only ``tensor`` receives a gradient; the other eight forward inputs get
        ``None``.
        """
        (mask,) = ctx.saved_tensors
        # Reshape grad to the mask's (blockwise) shape to apply it, then reshape
        # back. ``reshape`` is used instead of ``view`` because autograd does not
        # guarantee that ``grad_output`` (or the masked product) has a layout a
        # view can express; ``reshape`` is a view whenever that is possible.
        grad_blockwise = grad_output.reshape(mask.shape)
        grad_input = (grad_blockwise * mask).reshape(ctx.original_shape)
        return (
            grad_input,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def _qdq_float(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    quant_min: float,
    quant_max: float,
    dtype: torch.dtype,
    original_shape: torch.Size,
    blockwise_shape: list[int],
    reduced_shape: list[int],
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fused quantize → dequantize for float types.

    ``tensor`` is viewed as ``blockwise_shape`` and ``scale`` as ``reduced_shape``
    before the math is applied.

    Returns:
        ``(dequantized, mask)``. ``dequantized`` is a fresh tensor of shape
        ``original_shape``. ``mask`` is the boolean tensor produced by
        :func:`_quantize_float`: True where the scaled value was inside
        ``[quant_min, quant_max]``, False where it was clamped.
    """
    blocks = tensor.view(blockwise_shape)
    block_scale = scale.view(reduced_shape)
    low_precision, in_range = _quantize_float(blocks, block_scale, quant_min, quant_max, dtype)
    restored = _dequantize_float(low_precision, block_scale)
    # Hand back a copy rather than the view of the intermediate result.
    return restored.view(original_shape).clone(), in_range
class _FusedFakeQuantizeFloatSTE(Function):
    """
    Fused fake quantize + dequantize for float types with STE gradient.

    Handles blockwise reshaping internally so the entire fake-quantize operation
    (reshape → quantize → dequantize → reshape back) is a single autograd node.

    Fusing into one node reduces QAT memory: intermediate tensors (scaled, rounded,
    clamped) are local to forward and freed immediately instead of being retained by
    the autograd graph. Only a boolean mask (1 byte/element) is saved for backward,
    replacing multiple float32 intermediates (4 bytes/element each).
    """

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        scale: torch.Tensor,
        quant_min: float,
        quant_max: float,
        dtype: torch.dtype,
        original_shape: torch.Size,
        blockwise_shape: list[int],
        reduced_shape: list[int],
    ) -> torch.Tensor:
        """Run the fused float QDQ and stash what backward needs.

        Saves the in-range mask returned by ``_qdq_float`` and the original shape
        on ``ctx``; returns the quantized-dequantized tensor.
        """
        dequantized, mask = _qdq_float(
            tensor,
            scale,
            quant_min,
            quant_max,
            dtype,
            original_shape,
            blockwise_shape,
            reduced_shape,
        )
        ctx.save_for_backward(mask)
        ctx.original_shape = original_shape
        return dequantized

    @staticmethod
    def backward(
        ctx: Any, grad_output: torch.Tensor
    ) -> tuple[torch.Tensor, None, None, None, None, None, None, None]:
        """Straight-through gradient: pass ``grad_output`` where the mask is True.

        Only ``tensor`` receives a gradient; the other seven forward inputs get
        ``None``.
        """
        (mask,) = ctx.saved_tensors
        # Reshape grad to the mask's (blockwise) shape to apply it, then reshape
        # back. ``reshape`` is used instead of ``view`` because autograd does not
        # guarantee that ``grad_output`` (or the masked product) has a layout a
        # view can express; ``reshape`` is a view whenever that is possible.
        grad_blockwise = grad_output.reshape(mask.shape)
        grad_input = (grad_blockwise * mask).reshape(ctx.original_shape)
        return (
            grad_input,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )
def _quantize_int(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    quant_offset: torch.Tensor,
    float_offset: torch.Tensor,
    quant_min: int,
    quant_max: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Integer quantization kernel:

        clamp(round((tensor - float_offset) / scale) + quant_offset, quant_min, quant_max)

    The two offsets let one kernel serve both formulations:

    - ZP: ``quant_offset = zero_point``, ``float_offset = 0``
    - MINVAL: ``quant_offset = quant_min``, ``float_offset = minval``

    No integer cast is performed; the result keeps the dtype of the arithmetic.

    Returns:
        ``(quantized, mask)`` where ``mask`` is a boolean tensor that is True
        wherever the value before clamping was already inside
        ``[quant_min, quant_max]`` (i.e. the element was NOT clamped).
    """
    # The division allocates the working buffer; everything after is in place.
    levels = torch.div(tensor - float_offset, scale)
    levels.round_().add_(quant_offset)

    # The mask must be taken before clamping, otherwise every element is in range.
    in_range = levels.ge(quant_min)
    in_range.bitwise_and_(levels.le(quant_max))

    levels.clamp_(quant_min, quant_max)
    return levels, in_range
def _dequantize_int(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    quant_offset: torch.Tensor,
    float_offset: torch.Tensor,
) -> torch.Tensor:
    """
    Integer dequantization kernel:

        (tensor - quant_offset) * scale + float_offset

    This undoes :func:`_quantize_int`; the ZP / MINVAL meaning of the two
    offsets is described in that function's docstring.
    """
    # Remove the quantized-domain offset, rescale, then restore the float-domain offset.
    centered = tensor - quant_offset
    rescaled = centered * scale
    return rescaled + float_offset
def _quantize_float(
    tensor: torch.Tensor,
    scale: torch.Tensor,
    quant_min: float,
    quant_max: float,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Float quantization kernel: cast_decast(clamp(tensor / scale, min, max))

    Returns:
        ``(quantized, mask)`` where ``mask`` is a boolean tensor that is True
        wherever the scaled value was already inside ``[quant_min, quant_max]``
        (i.e. the element was NOT clamped).

    Raises:
        ValueError: if ``dtype`` is neither a float8 nor a float4 dtype.
    """
    scaled = tensor / scale

    # The mask must be taken before clamping, otherwise every element is in range.
    in_range = scaled.ge(quant_min)
    in_range.bitwise_and_(scaled.le(quant_max))
    scaled.clamp_(min=quant_min, max=quant_max)

    if _is_float8_dtype(dtype):
        rounded = _fp8_forward(scaled, dtype)
    elif _is_float4_dtype(dtype):
        rounded = _fp4_forward(scaled)
    else:
        raise ValueError(f"Expected float4/float8 dtype, got {dtype}")
    return rounded, in_range
def _fp8_forward(tensor: torch.Tensor, dtype: torch.dtype):
# Hardcoding return dtype to torch.float32 - all callers of this private method already cast
# the incoming tensor to float32.
return tensor.to(dtype).to(torch.float32)
def _fp4_forward(tensor: torch.Tensor) -> torch.Tensor:
    """Perform tensor quantization for fp4 dtype.

    Converts ``tensor`` with torchao's ``f32_to_f4_unpacked`` and immediately
    converts the result back with ``f4_unpacked_to_f32``.
    """
    # Imported lazily, inside the function, as in the original code.
    from torchao.prototype.mx_formats.kernels import (  # noqa: PLC0415
        f4_unpacked_to_f32,
        f32_to_f4_unpacked,
    )

    return f4_unpacked_to_f32(f32_to_f4_unpacked(tensor))
def _dequantize_float(
    tensor: torch.Tensor,
    scale: torch.Tensor,
) -> torch.Tensor:
    """Float dequantization kernel: multiply the quantized values by ``scale``."""
    return torch.mul(tensor, scale)