Training-Time Quantization

Quantization refers to techniques for performing neural network computations in lower precision than floating point. Quantization can reduce a model’s size, improve its inference latency, and lower its memory bandwidth requirements, because many hardware platforms offer high-performance implementations of quantized operations.

class coremltools.optimize.torch.quantization.ModuleLinearQuantizerConfig(weight_dtype: Union[str, dtype] = torch.qint8, weight_observer=ObserverType.moving_average_min_max, weight_per_channel: bool = True, activation_dtype: Union[str, dtype] = torch.quint8, activation_observer=ObserverType.moving_average_min_max, quantization_scheme=QuantizationScheme.symmetric, milestones: Optional[List[int]] = None)[source]

Configuration class for specifying global and module-level quantization options for the linear quantization algorithm implemented in LinearQuantizer.

The linear quantization algorithm simulates the effects of quantization during training by quantizing and dequantizing the weights and/or activations during the model’s forward pass. The forward and backward pass computations are conducted in float dtype; however, these float values follow the constraints imposed by the qint8 and quint8 dtypes. For more details, please refer to Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference.

For most applications, the only parameters that need to be set are quantization_scheme and milestones.

By default, quantization_scheme is set to QuantizationScheme.symmetric, which means all weights are quantized with zero point as zero, and activations are quantized with zero point as zero for non-negative activations and 128 for all other activations. The weights are quantized using torch.qint8 and activations are quantized using torch.quint8.

The linear quantization algorithm inserts observers for each weight/activation tensor. These observers collect statistics of the tensors’ values, such as the minimum and maximum values they take. These statistics are then used to compute the scale and zero point, which are in turn used for quantizing the weights/activations. By default, the moving_average_min_max observer is used. For more details, please check MinMaxObserver.
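To make this concrete, the following is a small illustrative sketch (not the coremltools implementation) of how observed min/max statistics map to a scale and zero point under the default 8-bit symmetric scheme:

import torch


def symmetric_qparams(observed_min, observed_max, dtype=torch.qint8):
    # Symmetric scheme: the scale is chosen so that the largest absolute
    # observed value maps to the edge of the quantized range, and the zero
    # point is fixed (0 for qint8 weights; 128 for quint8 activations that
    # can be negative, 0 for non-negative activations such as ReLU outputs).
    quant_min, quant_max = (-128, 127) if dtype == torch.qint8 else (0, 255)
    max_abs = max(abs(observed_min), abs(observed_max))
    scale = max_abs / ((quant_max - quant_min) / 2)
    zero_point = 0 if dtype == torch.qint8 else 128
    return scale, zero_point


# For example, a weight tensor observed in [-0.5, 0.25] gets
# scale = 0.5 / 127.5 ≈ 0.0039 and zero point 0.
print(symmetric_qparams(-0.5, 0.25))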

The milestones parameter controls the flow of the quantization algorithm. The example below illustrates its usage in more detail:

from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
    ModuleLinearQuantizerConfig,
)

model = define_model()

config = LinearQuantizerConfig(
    global_config=ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        milestones=[0, 100, 300, 200],
    )
)

quantizer = LinearQuantizer(model, config)

# prepare the model to insert FakeQuantize layers for QAT;
# example_inputs should be a representative batch of model inputs used for tracing
model = quantizer.prepare(example_inputs=example_inputs)

# use quantizer in your PyTorch training loop
for inputs, labels in data:
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    quantizer.step()

# In this example, from step 0 onwards, observers will collect statistics
# of the values of weights/activations. However, between steps 0 and 100,
# effects of quantization will not be simulated. At step 100, quantization
# simulation will begin and at step 300, observer statistics collection will
# stop. A batch norm layer computes the mean and variance of each input batch to
# normalize it during training, and keeps running estimates of these statistics,
# which are then used for normalization during evaluation. At step 200, batch norm
# statistics collection is frozen, and the batch norm layers switch to evaluation
# mode, thus more closely simulating the inference numerics during training time.
Parameters:
  • weight_dtype (torch.dtype) – The dtype to use for quantizing the weights. When dtype is set to torch.float32, the weights corresponding to that layer are not quantized. Defaults to torch.qint8.

  • weight_observer (ObserverType) – Type of observer to use for quantizing weights. Defaults to moving_average_min_max.

  • weight_per_channel (bool) – When True, weights are quantized per channel; otherwise, per tensor. Defaults to True.

  • activation_dtype (torch.dtype) – The dtype to use for quantizing the activations. When dtype is set to torch.float32, the activations corresponding to that layer are not quantized. Defaults to torch.quint8.

  • activation_observer (ObserverType) – Type of observer to use for quantizing activations. Allowed values are min_max and moving_average_min_max. Defaults to moving_average_min_max.

  • quantization_scheme (QuantizationScheme) – Type of quantization configuration to use. When this parameter is set to QuantizationScheme.symmetric, all weights are quantized with zero point as zero, and activations are quantized with zero point as zero for non-negative activations and 128 for all other activations. When it is set to QuantizationScheme.affine, the zero point can be set anywhere in the range of values allowed for the quantized weight/activation. Defaults to QuantizationScheme.symmetric.

  • milestones (list of int) – A list of four integers indicating the milestones to use during quantization. The first milestone corresponds to enabling observers, the second to enabling fake quantization simulation, the third to disabling observers, and the last to freezing batch norm statistics. Defaults to None, in which case the step method of LinearQuantizer is a no-op: all observers and quantization simulation are turned on from the first step, batch norm layers always operate in training mode, and mean and variance statistics collection is never frozen. An illustrative construction using these parameters is sketched after this list.
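As a reference for these parameters, the following sketch constructs a config directly with non-default values; the specific choices are illustrative, not recommendations, and the enum members are assumed from the allowed values listed above:

import torch

from coremltools.optimize.torch.quantization import (
    ModuleLinearQuantizerConfig,
    ObserverType,
    QuantizationScheme,
)

# Illustrative: per-tensor weights, min_max observers, affine scheme
module_config = ModuleLinearQuantizerConfig(
    weight_dtype=torch.qint8,
    weight_observer=ObserverType.min_max,
    weight_per_channel=False,
    activation_dtype=torch.quint8,
    activation_observer=ObserverType.min_max,
    quantization_scheme=QuantizationScheme.affine,
    milestones=[0, 100, 300, 200],
)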

as_dict() Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) ModuleOptimizationConfig

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: Union[IO, str]) ModuleOptimizationConfig

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.
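A brief sketch of these serialization helpers for ModuleLinearQuantizerConfig (the YAML path is hypothetical):

from coremltools.optimize.torch.quantization import ModuleLinearQuantizerConfig

# build a config from a plain dictionary of options
module_config = ModuleLinearQuantizerConfig.from_dict(
    {"quantization_scheme": "affine", "milestones": [0, 100, 300, 200]}
)

# round-trip the config back to a dictionary
config_dict = module_config.as_dict()

# or load it from a YAML file ("quant.yaml" is a hypothetical path)
module_config = ModuleLinearQuantizerConfig.from_yaml("quant.yaml")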

class coremltools.optimize.torch.quantization.LinearQuantizerConfig(global_config: Optional[ModuleLinearQuantizerConfig] = None, module_type_configs: ModuleTypeConfigType = _Nothing.NOTHING, module_name_configs: Dict[str, Optional[ModuleLinearQuantizerConfig]] = _Nothing.NOTHING, non_traceable_module_names: List[str] = [])[source]

Configuration class for specifying how different submodules of a model are quantized by LinearQuantizer.

To disable quantization for a layer or an operation, the module_type_configs or module_name_configs entry corresponding to it can be set to None.

For example:

from coremltools.optimize.torch.quantization import LinearQuantizerConfig

# The following config will enable weight-only quantization for all layers:
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "activation_dtype": "float32",
        }
    }
)

# The following config will disable quantization for all linear layers and
# set quantization mode to weight only quantization for convolution layers:
config = LinearQuantizerConfig.from_dict(
    {
        "module_type_configs": {
            "Linear": None,
            "Conv2d": {
                "activation_dtype": "float32",
            },
        }
    }
)

# The following config will disable quantization for layers named conv1 and conv2:
config = LinearQuantizerConfig.from_dict(
    {
        "module_name_configs": {
            "conv1": None,
            "conv2": None,
        }
    }
)
Parameters:
  • global_config (ModuleLinearQuantizerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleLinearQuantizerConfig) – Module type level configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes.

  • module_name_configs (dict of str to ModuleLinearQuantizerConfig) – Module level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top level module using the module.get_submodule(target) method.

  • non_traceable_module_names (list of str) – Names of modules which cannot be traced using torch.fx.

Note

The quantization_scheme parameter must be the same across all configs.

as_dict() Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) LinearQuantizerConfig[source]

Create class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: Union[IO, str]) OptimizationConfig

Create class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

set_global(global_config: Optional[ModuleOptimizationConfig]) OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: Optional[ModuleOptimizationConfig]) OptimizationConfig

Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.

set_module_type(object_type: Union[Callable, str], opt_config: Optional[ModuleOptimizationConfig]) OptimizationConfig

Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.
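These setters provide a programmatic alternative to the dictionary form shown above; a minimal sketch (the module name "conv1" is hypothetical):

import torch.nn as nn

from coremltools.optimize.torch.quantization import (
    LinearQuantizerConfig,
    ModuleLinearQuantizerConfig,
)

config = LinearQuantizerConfig()

# quantize all supported modules symmetrically by default
config = config.set_global(
    ModuleLinearQuantizerConfig(quantization_scheme="symmetric")
)

# weight-only quantization for all convolution layers
config = config.set_module_type(
    nn.Conv2d, ModuleLinearQuantizerConfig(activation_dtype="float32")
)

# leave the layer named "conv1" unquantized
config = config.set_module_name("conv1", None)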

class coremltools.optimize.torch.quantization.LinearQuantizer(model: Module, config: Optional[LinearQuantizerConfig] = None)[source]

Perform quantization aware training (QAT) of models. This algorithm simulates the effects of quantization during training by quantizing and dequantizing the weights and/or activations during the model’s forward pass. The forward and backward pass computations are conducted in float dtype; however, these float values follow the constraints imposed by the qint8 and quint8 dtypes. Thus, this algorithm adjusts the model’s weights while closely simulating the numerics of quantized inference, allowing the model’s weights to adapt to quantization constraints.

For more details, please refer to Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference.

Example

from collections import OrderedDict

import torch
import torch.nn as nn

from coremltools.optimize.torch.quantization import (
    LinearQuantizer,
    LinearQuantizerConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

loss_fn = define_loss()

# initialize the quantizer
config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {
            "quantization_scheme": "symmetric",
            "milestones": [0, 100, 400, 400],
        }
    }
)

quantizer = LinearQuantizer(model, config)

# prepare the model to insert FakeQuantize layers for QAT;
# a representative example input is needed to trace the model
example_inputs = (torch.randn(1, 1, 28, 28),)
model = quantizer.prepare(example_inputs=example_inputs)

# use quantizer in your PyTorch training loop
for inputs, labels in data:
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    quantizer.step()

# convert operations to their quantized counterparts using parameters learned via QAT
model = quantizer.finalize(inplace=True)
Parameters:
  • model (torch.nn.Module) – Module to be trained.

  • config (LinearQuantizerConfig) – Config that specifies how different submodules in the model will be quantized. The default config is used when None is passed.

finalize(model: Optional[Module] = None, inplace: bool = False) Module[source]

Prepares the model for export.

Parameters:
  • model (torch.nn.Module) – Model to be finalized.

  • inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.

Note

Once the model is finalized with inplace=True, it may not be runnable on the GPU.
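After finalizing, the model can be traced and converted like any other PyTorch model. The following is a minimal sketch of one possible export flow; the input shape and deployment target are assumptions for illustration, and the exact conversion options depend on your model:

import torch
import coremltools as ct

# finalize on a copy, then switch to evaluation mode for export
finalized_model = quantizer.finalize()
finalized_model.eval()

# trace with a representative input (shape is an assumption for this sketch)
example_input = torch.rand(1, 1, 28, 28)
traced_model = torch.jit.trace(finalized_model, example_input)

# convert to Core ML; newer deployment targets are generally required to
# represent activation quantization
coreml_model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
    minimum_deployment_target=ct.target.iOS17,
)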

prepare(example_inputs: Any, inplace: bool = False) Module[source]

Prepares the model for quantization aware training by inserting torch.ao.quantization.FakeQuantize layers in the model in appropriate places.

Parameters:
  • example_inputs (Any) – Example inputs used to trace the model so that quantization layers can be inserted in the appropriate places.

  • inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated; otherwise, a copy of the model is mutated and returned.

Note

This method uses the prepare_qat_fx method to insert quantization layers, and the returned model is a torch.fx.GraphModule. Some models, such as those with dynamic control flow, may not be traceable into a torch.fx.GraphModule. Please follow the directions in Limitations of Symbolic Tracing to update your model before using the LinearQuantizer algorithm.
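If only some submodules resist symbolic tracing, one option is to list them under non_traceable_module_names in LinearQuantizerConfig so that the rest of the model can still be quantized; a sketch (the module name is hypothetical):

config = LinearQuantizerConfig.from_dict(
    {
        "global_config": {"quantization_scheme": "symmetric"},
        # "decoder" is a hypothetical submodule containing dynamic control flow
        "non_traceable_module_names": ["decoder"],
    }
)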

report() _Report[source]

Returns a dictionary with important statistics related to the current state of quantization. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing statistics such as scale, zero point, and number of parameters.
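For example, the report can be inspected during training; the exact statistic keys depend on the module, so they are iterated generically here:

report = quantizer.report()
for module_name, stats in report.items():
    print(module_name, stats)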

step()[source]

Steps through the milestones defined for this quantizer.

The first milestone corresponds to enabling observers, the second to enabling fake quantization simulation, the third to disabling observers, and the last to freezing batch norm statistics.

Note

If the milestones argument is set to None, this method is a no-op.

Note

To skip a particular milestone, its value can be set to float('inf').
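For instance, to enable observers and quantization simulation from the first step while never disabling observers or freezing batch norm statistics, the later milestones can be set to float('inf'); a sketch:

config = LinearQuantizerConfig(
    global_config=ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        # never disable observers and never freeze batch norm statistics
        milestones=[0, 0, float("inf"), float("inf")],
    )
)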

class coremltools.optimize.torch.quantization.ObserverType(value)[source]

An enum indicating the type of observer. Allowed options are moving_average_min_max and min_max.

class coremltools.optimize.torch.quantization.QuantizationScheme(value)[source]

An enum indicating the type of quantization to be performed. Allowed options are symmetric and affine.