Applying quantization to an MNIST model

This tutorial provides a basic introduction to quantizing a model with CoreAI-Opt.

By the end of this tutorial, you should be familiar with the following:

  1. How to apply CoreAI-Opt’s Weight-Only quantization

  2. How to apply CoreAI-Opt’s Weight + Activation quantization

  3. How to apply CoreAI-Opt’s Quantization-Aware Training

  4. How to export CoreAI-Opt quantized models to Core AI

Setup

We will train a basic CNN on the MNIST dataset and record its accuracy as a baseline.

Once the model is trained, we will quantize it with coreai-opt in three stages: data-free weight-only quantization, then weight + activation quantization using calibration data, and finally Quantization-Aware Training (QAT).

[23]:
import random
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
[24]:
SEED = 1976

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
[25]:
# Used to save intermediate results and datasets
SAVE_DIRECTORY = "."

MNIST Dataset download

Helper to download the MNIST dataset with standard normalization applied.

[26]:
def mnist_transforms() -> transforms.Compose:
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])


def download_mnist_dataset(
    download_path: Path, transform: transforms.Compose | None = None
) -> tuple[datasets.MNIST, datasets.MNIST]:
    if transform is None:
        transform = mnist_transforms()
    train = datasets.MNIST(download_path, train=True, download=True, transform=transform)
    test = datasets.MNIST(download_path, train=False, download=True, transform=transform)
    return train, test

Model definition

A simple CNN with a single Conv2d → ReLU → MaxPool block, followed by Flatten and a Linear classifier.

[27]:
class MnistNetwork(nn.Module):
    def __init__(self, num_classes: int = 10, state_dict: dict | None = None) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 12, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0),
            nn.Flatten(),
            nn.Linear(2352, num_classes),
        )
        if state_dict is not None:
            self.load_state_dict(state_dict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

Training and Evaluation

Standard PyTorch training loop and evaluation function that computes accuracy.

[28]:
def train_step(model, optimizer, loss_fn, inputs, ground_truth) -> float:
    model.train()
    device = next(model.parameters()).device
    inputs = inputs.to(device)
    ground_truth = ground_truth.to(device)
    optimizer.zero_grad()
    predictions = model(inputs)
    loss = loss_fn(predictions, ground_truth)
    loss.backward()
    optimizer.step()
    return loss.item()


def train_epoch(model, train_loader, optimizer, loss_fn) -> float:
    total_loss = 0.0
    for inputs, ground_truth in train_loader:
        loss = train_step(model, optimizer, loss_fn, inputs, ground_truth)
        total_loss += loss
    return total_loss / len(train_loader)


def create_adam_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Adam:
    return torch.optim.Adam(model.parameters(), lr=lr)


def eval_model(model: nn.Module, test_dataloader: DataLoader) -> float:
    model.eval()
    device = next(model.parameters()).device
    num_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, ground_truth in test_dataloader:
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            predictions = model(inputs)
            _, predicted = torch.max(predictions.data, 1)
            total += ground_truth.size(0)
            num_correct += (predicted == ground_truth).sum().item()
    return num_correct / total
[29]:
# Download and instantiate datasets
DOWNLOAD_PATH = Path(SAVE_DIRECTORY) / ".mnist_dataset"

train_dataset, test_dataset = download_mnist_dataset(
    download_path=DOWNLOAD_PATH
)
[30]:
BATCH_SIZE = 128

# Instantiate dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

The CNN model used for this tutorial contains a single Conv2d, ReLU, MaxPool2d, Flatten, and Linear layer. Here’s the structure:

[31]:
basic_cnn_model = MnistNetwork(num_classes=10)

# Print summary of model
summary(basic_cnn_model, input_size=(1, 1, 28, 28))
[31]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MnistNetwork                             [1, 10]                   --
├─Sequential: 1-1                        [1, 10]                   --
│    └─Conv2d: 2-1                       [1, 12, 28, 28]           120
│    └─ReLU: 2-2                         [1, 12, 28, 28]           --
│    └─MaxPool2d: 2-3                    [1, 12, 14, 14]           --
│    └─Flatten: 2-4                      [1, 2352]                 --
│    └─Linear: 2-5                       [1, 10]                   23,530
==========================================================================================
Total params: 23,650
Trainable params: 23,650
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.12
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.09
Estimated Total Size (MB): 0.17
==========================================================================================

Train unquantized model

Let’s train the unquantized model to establish a baseline accuracy. We then save the trained weights so that each quantization experiment starts from the same pretrained model.

[32]:
EPOCHS = 10
# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

loss_fn = torch.nn.CrossEntropyLoss()

# Move the model to the device before creating the optimizer
basic_cnn_model = basic_cnn_model.to(DEVICE)
optimizer = create_adam_optimizer(basic_cnn_model)

epoch_results = []
for epoch in range(EPOCHS):
    epoch_avg_loss = train_epoch(
        model=basic_cnn_model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn
    )
    epoch_results.append(f"  Epoch {epoch + 1}: loss={epoch_avg_loss:.4f}")

basic_cnn_model = basic_cnn_model.cpu().eval()
print("\n".join(epoch_results))

# Save trained weights for reuse in quantization sections
pretrained_state_dict = basic_cnn_model.state_dict()
  Epoch 1: loss=0.3126
  Epoch 2: loss=0.1170
  Epoch 3: loss=0.0821
  Epoch 4: loss=0.0669
  Epoch 5: loss=0.0590
  Epoch 6: loss=0.0522
  Epoch 7: loss=0.0481
  Epoch 8: loss=0.0452
  Epoch 9: loss=0.0415
  Epoch 10: loss=0.0379
[33]:
accuracy = eval_model(basic_cnn_model, test_loader)
print(f"Baseline accuracy: {accuracy:.4f}")
Baseline accuracy: 0.9799

Weight Only Quantization

Weight-only quantization compresses model weights to a lower precision (e.g., INT8) while keeping activations in full precision (FP32). This is the simplest form of quantization — it requires no calibration data and can be applied directly to a trained model.
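As a rough back-of-the-envelope check of what weight-only INT8 saves on this particular model, the sketch below counts parameter bytes at FP32 versus INT8. It assumes one byte per INT8 weight, keeps biases at FP32 (mirroring the op_state_spec={"weight": ...} config used below), and ignores the small overhead of storing quantization scales, which a real format will add.

```python
# Parameter counts for MnistNetwork, matching the torchinfo summary above:
#   Conv2d(1, 12, 3):  12 * 1 * 3 * 3 weights + 12 biases
#   Linear(2352, 10):  10 * 2352 weights      + 10 biases
conv_weights, conv_biases = 12 * 1 * 3 * 3, 12
linear_weights, linear_biases = 10 * 2352, 10

weights = conv_weights + linear_weights
biases = conv_biases + linear_biases

FP32_BYTES, INT8_BYTES = 4, 1

# Unquantized: every parameter stored as FP32
fp32_total = (weights + biases) * FP32_BYTES
# Weight-only INT8 (assumption): weights take 1 byte each, biases stay FP32,
# quantization scales are not counted
int8_total = weights * INT8_BYTES + biases * FP32_BYTES

print(f"parameters: {weights + biases}")
print(f"FP32 size: {fp32_total} bytes")
print(f"INT8 weight-only size: {int8_total} bytes ({fp32_total / int8_total:.2f}x smaller)")
```

The ratio lands just under 4x because nearly all of the model's parameters are weights rather than biases.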

[34]:
from coreai_opt.quantization import (
    ModuleQuantizerConfig,
    QuantizationSpec,
    Quantizer,
    QuantizerConfig,
)

example_inputs = (torch.randn(1, 1, 28, 28),)

For this tutorial, we will use the INT8 dtype for weights. Refer to the QuantizationSpec reference for all options.

[35]:
WEIGHT_DTYPE = "int8"

QuantizerConfig takes a ModuleQuantizerConfig (passed below as global_config), which describes how to quantize each operation through three spec fields:

  • op_state_spec: state tensors of the op (weights, biases).

  • op_input_spec: input activations — tensors flowing into the op.

  • op_output_spec: output activations — tensors flowing out of the op.

For weight-only quantization, only the weight state spec is populated; setting op_input_spec and op_output_spec to None tells the quantizer to leave activations in full precision.

[36]:
wo_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
wo_model.eval()

weight_spec = QuantizationSpec(dtype=WEIGHT_DTYPE)

wo_config = QuantizerConfig(
    global_config=ModuleQuantizerConfig(
        op_state_spec={"weight": weight_spec},
        op_input_spec=None,
        op_output_spec=None,
    )
)

wo_quantizer = Quantizer(wo_model, wo_config)
wo_prepared = wo_quantizer.prepare(example_inputs)
print(f"Prepared weight-only quantization with dtype={WEIGHT_DTYPE}")
Prepared weight-only quantization with dtype=int8

Note: coreai-opt also offers QuantizerConfig.presets.w8() as a shortcut for the above config.

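prepare() returns a model with FakeQuantize modules inserted for the weight tensors. These simulate quantization during the forward pass: each weight is rounded to the nearest value on the INT8 grid and immediately converted back to floating point, so the model still runs in FP32 while experiencing the quantization error.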
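To make "simulate quantization" concrete, here is a minimal NumPy sketch of fake quantization for one weight tensor. It assumes symmetric, per-tensor INT8 with scale = max|w| / 127; the exact scheme coreai-opt uses by default (symmetric or affine, per-tensor or per-channel) is not stated in this tutorial, and the function name is hypothetical.

```python
import numpy as np


def fake_quantize_symmetric_int8(w: np.ndarray) -> np.ndarray:
    """Hypothetical helper: quantize to INT8 and immediately dequantize back to float."""
    qmin, qmax = -127, 127  # symmetric range; -128 left unused
    scale = float(np.abs(w).max()) / qmax
    if scale == 0.0:
        return w.copy()
    q = np.clip(np.round(w / scale), qmin, qmax)  # integer grid values, still stored as float
    return (q * scale).astype(w.dtype)  # back to float, now snapped to the INT8 grid


rng = np.random.default_rng(0)
# Random stand-in with the same shape as the Linear layer's weight
weights = rng.normal(0.0, 0.1, size=(10, 2352)).astype(np.float32)
fq_weights = fake_quantize_symmetric_int8(weights)

scale = float(np.abs(weights).max()) / 127
max_error = float(np.abs(weights - fq_weights).max())
print(f"scale={scale:.6f}, max abs error={max_error:.6f} (bound: scale / 2 = {scale / 2:.6f})")
print(f"distinct values: {np.unique(fq_weights).size} (at most 255)")
```

The output tensor is still FP32, but it takes at most 255 distinct values and each entry is off by at most half a quantization step, which is why accuracy barely moves.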

[37]:
wo_accuracy = eval_model(wo_prepared, test_loader)
print(f"Weight-only PTQ accuracy: {wo_accuracy:.4f}")
Weight-only PTQ accuracy: 0.9800

Weight and Activation Quantization

Now, let’s try using weight + activation quantization through coreai-opt.

Weight + activation quantization compresses both the model’s weights and its intermediate activation tensors to a lower precision. Because both operands of operations such as matrix multiplication are then low-precision, this can provide a greater speedup than weight-only quantization. It does, however, require calibration data: representative inputs are run through the model so that appropriate scale and zero-point values can be computed for the quantized activations.
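A minimal sketch of how scale and zero-point can be derived from an observed activation range follows. It uses one common scheme, asymmetric (affine) INT8 with a min/max range; whether coreai-opt’s default INT8 activation spec is affine or symmetric is an assumption not confirmed here, and all function names are hypothetical.

```python
def affine_qparams(x_min: float, x_max: float, qmin: int = -128, qmax: int = 127) -> tuple[float, int]:
    """Scale and zero-point for asymmetric (affine) INT8 quantization of [x_min, x_max]."""
    # Widen the range to include 0.0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if x_max == x_min:
        return 1.0, 0  # degenerate all-zero range
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = qmin - round(x_min / scale)
    return scale, max(qmin, min(qmax, zero_point))


def quantize(x: float, scale: float, zero_point: int, qmin: int = -128, qmax: int = 127) -> int:
    # Map a real value to the nearest integer code, clamped to the INT8 range
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    # Map an integer code back to the real value it represents
    return (q - zero_point) * scale


# Range of a post-ReLU activation, as an observer might record it during calibration
scale, zero_point = affine_qparams(0.0, 6.0)
for x in (0.0, 1.0, 6.0):
    q = quantize(x, scale, zero_point)
    print(f"x={x:.1f} -> q={q:4d} -> x_hat={dequantize(q, scale, zero_point):.4f}")
```

For a non-negative range such as a ReLU output, the zero-point moves to the bottom of the INT8 range so that all 256 codes cover [0, 6] instead of wasting half of them on negative values.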

For this tutorial, we will use the INT8 dtype for activations. Refer to the QuantizationSpec reference for all options.

[38]:
ACTIVATION_DTYPE = "int8"

Two changes from the weight-only config:

  • op_input_spec and op_output_spec are now populated, so activations are quantized too.

  • The "*" key applies the same INT8 spec to every operation input/output.
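One way to picture how these spec maps apply to a single operation is the plain-Python sketch below. It is a hypothetical model of the lookup, not coreai-opt’s implementation: specs are reduced to dtype strings, None means "leave in full precision", and the rule that an exact key wins over "*" is an assumption.

```python
from typing import Optional

# Hypothetical stand-in for a QuantizationSpec: just its dtype string
Spec = str


def resolve_spec(spec_map: Optional[dict[str, Spec]], name: str) -> Optional[Spec]:
    """Look up the spec for one tensor of an op; None means 'keep full precision'."""
    if spec_map is None:
        return None  # e.g. op_input_spec=None: no activations quantized
    if name in spec_map:  # assumption: an exact key takes precedence over "*"
        return spec_map[name]
    return spec_map.get("*")  # "*" matches any input/output; missing -> None


weight_only = {"state": {"weight": "int8"}, "input": None, "output": None}
weight_act = {"state": {"weight": "int8"}, "input": {"*": "int8"}, "output": {"*": "int8"}}

for label, cfg in (("weight-only", weight_only), ("weight + activation", weight_act)):
    print(
        label,
        "| weight:", resolve_spec(cfg["state"], "weight"),
        "| bias:", resolve_spec(cfg["state"], "bias"),
        "| input:", resolve_spec(cfg["input"], "input"),
        "| output:", resolve_spec(cfg["output"], "output"),
    )
```

Under this reading, biases stay in full precision in both configs, since only a "weight" key appears in the state spec.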

[39]:
wa_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
wa_model.eval()

activation_spec = QuantizationSpec(dtype=ACTIVATION_DTYPE)

wa_config = QuantizerConfig(
    global_config=ModuleQuantizerConfig(
        op_state_spec={"weight": weight_spec},
        op_input_spec={"*": activation_spec},
        op_output_spec={"*": activation_spec},
    )
)

wa_quantizer = Quantizer(wa_model, wa_config)
wa_prepared = wa_quantizer.prepare(example_inputs)
print(
    f"Prepared weight + activation quantization (weight={WEIGHT_DTYPE}, activation={ACTIVATION_DTYPE})"
)
Prepared weight + activation quantization (weight=int8, activation=int8)

We now enter calibration_mode() and run representative inputs through the prepared model to populate the scale and zero-point values.

The calibration_mode() context manager enables range observers (to collect activation statistics) while disabling fake quantization (so the forward pass is numerically identical to the unquantized model). After exiting the context, observers are frozen and fake quantization is re-enabled.
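The toggling described above can be sketched with a toy fake-quant node and a context manager. Everything here is hypothetical (ToyFakeQuant, a running max-abs observer, symmetric per-tensor INT8); it only mirrors the behaviour this tutorial describes for calibration_mode(), not coreai-opt’s internals.

```python
from contextlib import contextmanager

import numpy as np


class ToyFakeQuant:
    """Hypothetical per-tensor INT8 fake-quant node with a running max-abs observer."""

    def __init__(self) -> None:
        self.observer_enabled = False
        self.fake_quant_enabled = True
        self.max_abs = 0.0  # statistic collected while the observer is enabled

    @property
    def scale(self) -> float:
        # Symmetric INT8 scale derived from the observed range
        return self.max_abs / 127 if self.max_abs > 0 else 1.0

    def __call__(self, x: np.ndarray) -> np.ndarray:
        if self.observer_enabled:
            self.max_abs = max(self.max_abs, float(np.abs(x).max()))
        if not self.fake_quant_enabled:
            return x  # pass-through: numerically identical to the float model
        return np.clip(np.round(x / self.scale), -127, 127) * self.scale


@contextmanager
def calibration_mode(nodes):
    """Hypothetical: observers on and fake quant off inside; observers frozen on exit."""
    for node in nodes:
        node.observer_enabled, node.fake_quant_enabled = True, False
    try:
        yield
    finally:
        for node in nodes:
            node.observer_enabled, node.fake_quant_enabled = False, True


rng = np.random.default_rng(0)
fq = ToyFakeQuant()
batch = rng.normal(size=(128, 10))

with calibration_mode([fq]):
    calibrated_out = fq(batch)  # observed, but not quantized

frozen_scale = fq.scale
quantized_out = fq(10 * batch)  # larger inputs after calibration get clipped...
print("scale after calibration:", frozen_scale)
print("scale after more data:  ", fq.scale)  # ...and do not move the frozen scale
```

Freezing the observer on exit is what keeps evaluation deterministic: later inputs outside the calibrated range are clipped rather than silently widening the range.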

[40]:
NUM_CALIBRATION_BATCHES = 1

with wa_quantizer.calibration_mode():
    for i, (data, _) in enumerate(train_loader):
        if i >= NUM_CALIBRATION_BATCHES:
            break
        wa_prepared(data)

print(
    f"Calibrated with {NUM_CALIBRATION_BATCHES} batch(es) ({NUM_CALIBRATION_BATCHES * BATCH_SIZE} samples)"
)
Calibrated with 1 batch(es) (128 samples)
[41]:
wa_accuracy = eval_model(wa_prepared, test_loader)
print(f"Weight + Activation accuracy after prepare + calibration: {wa_accuracy:.4f}")
Weight + Activation accuracy after prepare + calibration: 0.9798

Quantization Aware Training

Now, let’s try to fine-tune the weight + activation quantized model with QAT to recover any accuracy lost from quantization.

The QAT training loop wraps each epoch in wa_quantizer.training_mode():

  • During the context: fake quantization is enabled and observers are enabled (collecting activation statistics during training batches).

  • On exit: observers are frozen; fake quantization stays enabled, so the model is ready for evaluation with quant-aware weights.
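Training through fake quantization needs a gradient for the rounding step, whose true derivative is zero almost everywhere. A common choice in QAT is the straight-through estimator (STE), which treats rounding as the identity in the backward pass and zeroes the gradient where values were clipped. The NumPy sketch below writes that out by hand; whether coreai-opt uses exactly this STE variant is not stated in this tutorial, and both function names are hypothetical.

```python
import numpy as np


def fake_quant_forward(x: np.ndarray, scale: float, qmin: int = -127, qmax: int = 127) -> np.ndarray:
    # Forward pass: snap to the INT8 grid, then map back to float
    return np.clip(np.round(x / scale), qmin, qmax) * scale


def fake_quant_backward_ste(x: np.ndarray, grad_out: np.ndarray, scale: float,
                            qmin: int = -127, qmax: int = 127) -> np.ndarray:
    """Straight-through estimator: pass the gradient where x was in range, zero where clipped."""
    in_range = (x / scale >= qmin) & (x / scale <= qmax)
    return grad_out * in_range


scale = 0.01  # INT8 grid covers roughly [-1.27, 1.27]
x = np.array([-2.0, -0.5, 0.004, 0.5, 1.5])
y = fake_quant_forward(x, scale)
grad_x = fake_quant_backward_ste(x, np.ones_like(x), scale)
print("forward :", y)
print("grad(x) :", grad_x)
```

The small value 0.004 rounds to zero in the forward pass yet still receives a full gradient, so small weight updates can accumulate across steps until they cross a grid boundary.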

[42]:
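QAT_EPOCHS = 5
# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

# Move the prepared model to the device before creating the optimizer
wa_prepared = wa_prepared.to(DEVICE)
qat_optimizer = create_adam_optimizer(wa_prepared, lr=1e-4)
qat_loss_fn = torch.nn.CrossEntropyLoss()

for epoch in range(QAT_EPOCHS):
    with wa_quantizer.training_mode():
        epoch_loss = train_epoch(
            model=wa_prepared,
            train_loader=train_loader,
            optimizer=qat_optimizer,
            loss_fn=qat_loss_fn,
        )

    qat_acc = eval_model(wa_prepared, test_loader)
    print(f"  Epoch {epoch + 1}: loss={epoch_loss:.4f}, accuracy={qat_acc:.4f}")

# Move the fine-tuned model back to CPU (in place) before export
wa_prepared = wa_prepared.cpu()
print(f"\nQAT final accuracy: {qat_acc:.4f}")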
  Epoch 1: loss=0.0276, accuracy=0.9832
  Epoch 2: loss=0.0261, accuracy=0.9828
  Epoch 3: loss=0.0253, accuracy=0.9835
  Epoch 4: loss=0.0249, accuracy=0.9837
  Epoch 5: loss=0.0242, accuracy=0.9830

QAT final accuracy: 0.9830

Export to Core AI

Once the quantized model is ready, call finalize() to convert the fake quantization modules into the real quantized representation for deployment.

Pass ExportBackend.CoreAI to finalize(backend=...) to target the .aimodel format produced by coreai-torch.

We’ll export the fine-tuned weight + activation model from the previous section.

[43]:
from coreai_opt import ExportBackend

coreai_model = wa_quantizer.finalize(backend=ExportBackend.CoreAI)
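Conceptually, converting fake quantization into a real quantized representation means storing the integer codes and the scale instead of FP32 values that merely sit on the INT8 grid. The NumPy sketch below illustrates that idea under the same symmetric per-tensor INT8 assumption used earlier; the actual layout finalize() produces for the Core AI backend is not described in this tutorial.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in with the same shape as the Linear layer's weight
w = rng.normal(0.0, 0.1, size=(10, 2352)).astype(np.float32)

# What the prepared model simulates: FP32 values snapped to the INT8 grid
scale = np.float32(np.abs(w).max() / 127)
w_fake_quant = np.clip(np.round(w / scale), -127, 127) * scale

# What a finalized model can store instead: INT8 codes plus one FP32 scale
w_int8 = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
w_restored = w_int8.astype(np.float32) * scale  # dequantize at runtime

print("stored dtype:", w_int8.dtype, "| bytes:", w_int8.nbytes, "vs", w.nbytes)
print("matches fake-quant output:", np.array_equal(w_restored, w_fake_quant))
```

Because the dequantized INT8 codes reproduce the fake-quantized values exactly, the accuracy measured on the prepared model is a faithful preview of the finalized one.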

The export proceeds in three steps:

  • Trace the model with torch.export.export() to obtain a graph representation, then run decompositions using coreai_torch’s get_decomp_table().

  • Apply cast_to_16_bit_precision() to cast the remaining FP32 parameters to FP16 for better on-device performance.

  • Convert the exported program to Core AI format with coreai_torch.TorchConverter, then optimize it and save it as an .aimodel asset.
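To get a feel for what the FP16 cast trades away, the NumPy sketch below casts a random FP32 array to FP16. It is independent of coreai-opt and does not show what cast_to_16_bit_precision() does internally; it only illustrates that the cast halves storage and, for values in FP16’s normal range, keeps the relative rounding error within the unit roundoff 2**-11.

```python
import numpy as np

rng = np.random.default_rng(0)
params_fp32 = rng.normal(0.0, 0.1, size=10_000).astype(np.float32)
params_fp16 = params_fp32.astype(np.float16)  # round to the nearest FP16 value

# Compare only values in FP16's normal range (|x| >= 2**-14); smaller ones are subnormal
normal = np.abs(params_fp32) >= 2.0**-14
x = params_fp32[normal]
rel_err = np.abs(params_fp16[normal].astype(np.float32) - x) / np.abs(x)

print("bytes:", params_fp32.nbytes, "->", params_fp16.nbytes)
print(f"max relative error: {rel_err.max():.2e} (unit roundoff 2**-11 = {2.0**-11:.2e})")
```

An error of this size is far below the INT8 quantization step, which is why casting the remaining FP32 parameters is usually a safe final step.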

[44]:
import shutil

from coreai_opt.casting import cast_to_16_bit_precision
from coreai_torch import TorchConverter, get_decomp_table

exported_program = torch.export.export(coreai_model, example_inputs, strict=False)
exported_program = exported_program.run_decompositions(get_decomp_table())
cast_to_16_bit_precision(exported_program)

coreai_program = TorchConverter().add_exported_program(exported_program).to_coreai()
coreai_program.optimize()

output_path = Path(SAVE_DIRECTORY) / "exported_model.aimodel"
if output_path.exists():
    shutil.rmtree(output_path)
coreai_program.save_asset(output_path)
print(f"Exported: {output_path}")
Exported: exported_model.aimodel