Applying quantization to an MNIST model
This tutorial provides a basic introduction to quantizing a model with coreai-opt.
By the end of this tutorial, you should be familiar with weight-only quantization, calibration-based weight + activation quantization, Quantization Aware Training (QAT), and exporting a quantized model to Core AI.
Setup
We will train a basic CNN on the MNIST dataset and record its test accuracy as a baseline.
Once the model is trained, we will quantize it with coreai-opt in three stages: data-free weight-only quantization, calibration-based weight + activation quantization, and finally Quantization Aware Training (QAT).
[23]:
import random
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
[24]:
SEED = 1976
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
[25]:
# Used to save intermediate results and datasets
SAVE_DIRECTORY = "."
MNIST Dataset download
Helper to download the MNIST dataset with standard normalization applied.
[26]:
def mnist_transforms() -> transforms.Compose:
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
def download_mnist_dataset(
    download_path: Path, transform: transforms.Compose | None = None
) -> tuple[datasets.MNIST, datasets.MNIST]:
    if transform is None:
        transform = mnist_transforms()
    train = datasets.MNIST(download_path, train=True, download=True, transform=transform)
    test = datasets.MNIST(download_path, train=False, download=True, transform=transform)
    return train, test
Model definition
A simple CNN with a single Conv2d → ReLU → MaxPool block, followed by Flatten and a Linear classifier.
[27]:
class MnistNetwork(nn.Module):
    def __init__(self, num_classes: int = 10, state_dict: dict | None = None) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 12, 3, padding=1),  # 1x28x28 -> 12x28x28
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0),  # 12x28x28 -> 12x14x14
            nn.Flatten(),  # 12 * 14 * 14 = 2352 features
            nn.Linear(2352, num_classes),
        )
        # Optionally restore previously trained weights
        if state_dict is not None:
            self.load_state_dict(state_dict)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
Training and Evaluation
Standard PyTorch training loop and evaluation function that computes accuracy.
[28]:
def train_step(model, optimizer, loss_fn, inputs, ground_truth) -> float:
    model.train()
    device = next(model.parameters()).device
    inputs = inputs.to(device)
    ground_truth = ground_truth.to(device)
    optimizer.zero_grad()
    predictions = model(inputs)
    loss = loss_fn(predictions, ground_truth)
    loss.backward()
    optimizer.step()
    return loss.item()
def train_epoch(model, train_loader, optimizer, loss_fn) -> float:
    total_loss = 0.0
    for inputs, ground_truth in train_loader:
        loss = train_step(model, optimizer, loss_fn, inputs, ground_truth)
        total_loss += loss
    return total_loss / len(train_loader)
def create_adam_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Adam:
    return torch.optim.Adam(model.parameters(), lr=lr)
def eval_model(model: nn.Module, test_dataloader: DataLoader) -> float:
    model.eval()
    device = next(model.parameters()).device
    num_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, ground_truth in test_dataloader:
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            predictions = model(inputs)
            _, predicted = torch.max(predictions.data, 1)
            total += ground_truth.size(0)
            num_correct += (predicted == ground_truth).sum().item()
    return num_correct / total
[29]:
# Download and instantiate datasets
DOWNLOAD_PATH = Path(SAVE_DIRECTORY) / ".mnist_dataset"
train_dataset, test_dataset = download_mnist_dataset(
download_path=DOWNLOAD_PATH
)
[30]:
BATCH_SIZE = 128
# Instantiate dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
The CNN model used for this tutorial contains a single Conv2d, ReLU, MaxPool2d, Flatten, and Linear layer. Here’s the structure:
[31]:
basic_cnn_model = MnistNetwork(num_classes=10)
# Print summary of model
summary(basic_cnn_model, input_size=(1, 1, 28, 28))
[31]:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
MnistNetwork [1, 10] --
├─Sequential: 1-1 [1, 10] --
│ └─Conv2d: 2-1 [1, 12, 28, 28] 120
│ └─ReLU: 2-2 [1, 12, 28, 28] --
│ └─MaxPool2d: 2-3 [1, 12, 14, 14] --
│ └─Flatten: 2-4 [1, 2352] --
│ └─Linear: 2-5 [1, 10] 23,530
==========================================================================================
Total params: 23,650
Trainable params: 23,650
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.12
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.09
Estimated Total Size (MB): 0.17
==========================================================================================
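The output shapes and parameter counts in this summary follow from standard convolution arithmetic. The sketch below recomputes them in plain Python; the helper `conv2d_out` is written here for illustration and is not part of torchinfo or PyTorch.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size along one dimension for Conv2d/MaxPool2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


# Conv2d(1, 12, 3, padding=1): 28x28 -> 28x28 ("same" padding)
conv_hw = conv2d_out(28, kernel=3, stride=1, padding=1)
# MaxPool2d(2, stride=2, padding=0): 28x28 -> 14x14
pool_hw = conv2d_out(conv_hw, kernel=2, stride=2, padding=0)
# Flatten: channels * height * width
flat_features = 12 * pool_hw * pool_hw

# Parameter counts are weights + biases
conv_params = 12 * 1 * 3 * 3 + 12        # out_ch * in_ch * kH * kW + one bias per out_ch
linear_params = flat_features * 10 + 10  # in_features * out_features + one bias per class

print(conv_hw, pool_hw, flat_features)                          # 28 14 2352
print(conv_params, linear_params, conv_params + linear_params)  # 120 23530 23650
```

This is also why the classifier is declared as `nn.Linear(2352, num_classes)`: changing the padding or pooling would change `flat_features` and require a matching change there.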
Train unquantized model
Let’s train this model (unquantized) so we can get a baseline accuracy. We save the trained weights so we can reload them for each quantization experiment.
[32]:
EPOCHS = 10
# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
loss_fn = torch.nn.CrossEntropyLoss()
# Move the model to the device before creating its optimizer
basic_cnn_model = basic_cnn_model.to(DEVICE)
optimizer = create_adam_optimizer(basic_cnn_model)
epoch_results = []
for epoch in range(EPOCHS):
    epoch_avg_loss = train_epoch(
        model=basic_cnn_model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn
    )
    epoch_results.append(f" Epoch {epoch + 1}: loss={epoch_avg_loss:.4f}")
basic_cnn_model = basic_cnn_model.cpu().eval()
print("\n".join(epoch_results))
# Save trained weights for reuse in quantization sections
pretrained_state_dict = basic_cnn_model.state_dict()
Epoch 1: loss=0.3126
Epoch 2: loss=0.1170
Epoch 3: loss=0.0821
Epoch 4: loss=0.0669
Epoch 5: loss=0.0590
Epoch 6: loss=0.0522
Epoch 7: loss=0.0481
Epoch 8: loss=0.0452
Epoch 9: loss=0.0415
Epoch 10: loss=0.0379
[33]:
accuracy = eval_model(basic_cnn_model, test_loader)
print(f"Baseline accuracy: {accuracy:.4f}")
Baseline accuracy: 0.9799
Weight-Only Quantization
Weight-only quantization compresses model weights to a lower precision (e.g., INT8) while keeping activations in full precision (FP32). This is the simplest form of quantization: it requires no calibration data and can be applied directly to a trained model.
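To make "compresses the weights" concrete, here is a minimal NumPy sketch of symmetric, per-output-channel INT8 quantization applied to a random stand-in for the tutorial's Linear weight (shape 10 x 2352). The granularity (per-channel), the symmetric range, and round-to-nearest are assumptions chosen for illustration; coreai-opt's INT8 defaults may differ.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random stand-in for the Linear layer's FP32 weight: (out_features, in_features)
weight = rng.normal(0.0, 0.05, size=(10, 2352)).astype(np.float32)

# Symmetric per-output-channel INT8: one scale per row, zero-point fixed at 0
qmax = 127
scale = np.abs(weight).max(axis=1, keepdims=True) / qmax  # shape (10, 1)
q_weight = np.clip(np.round(weight / scale), -qmax, qmax).astype(np.int8)

# Dequantize to measure the rounding error the model will see
dequant = q_weight.astype(np.float32) * scale
max_err = np.abs(weight - dequant).max()

fp32_bytes = weight.nbytes                                       # 10 * 2352 * 4
int8_bytes = q_weight.nbytes + scale.astype(np.float32).nbytes   # 10 * 2352 * 1 + 10 * 4
print(fp32_bytes, int8_bytes)  # 94080 23560
```

Storage drops by roughly 4x (the per-row scales add only 40 bytes), and because nothing is clipped, each weight moves by at most half of its row's scale.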
[34]:
from coreai_opt.quantization import (
ModuleQuantizerConfig,
QuantizationSpec,
Quantizer,
QuantizerConfig,
)
example_inputs = (torch.randn(1, 1, 28, 28),)
For this tutorial, we will use the INT8 dtype. Refer to the QuantizationSpec reference for all options.
[35]:
WEIGHT_DTYPE = "int8"
QuantizerConfig describes how to quantize each operation through three spec keys:
op_state_spec: state tensors of the op (weights, biases).
op_input_spec: input activations, i.e. tensors flowing into the op.
op_output_spec: output activations, i.e. tensors flowing out of the op.
For weight-only quantization, only the weight state spec is populated; setting op_input_spec and op_output_spec to None tells the quantizer to leave activations in full precision.
[36]:
wo_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
wo_model.eval()
weight_spec = QuantizationSpec(dtype=WEIGHT_DTYPE)
wo_config = QuantizerConfig(
global_config=ModuleQuantizerConfig(
op_state_spec={"weight": weight_spec},
op_input_spec=None,
op_output_spec=None,
)
)
wo_quantizer = Quantizer(wo_model, wo_config)
wo_prepared = wo_quantizer.prepare(example_inputs)
print(f"Prepared weight-only quantization with dtype={WEIGHT_DTYPE}")
Prepared weight-only quantization with dtype=int8
Note:
coreai-opt also offers QuantizerConfig.presets.w8() as a shortcut for the above config.
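After calling prepare(), the model has FakeQuantize modules inserted for its weight tensors. These simulate quantization during the forward pass: each weight is quantized to INT8 and immediately dequantized back to floating point, so the computation still runs in FP32 but sees the rounding error that INT8 storage introduces.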
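The quantize-then-dequantize round trip can be reproduced with PyTorch's built-in `torch.fake_quantize_per_tensor_affine`. This sketch compares it with a manual round-and-clamp on a tensor shaped like the tutorial's Conv2d weight. It only illustrates the general mechanism; the symmetric per-tensor parameters are an assumption, and coreai-opt's FakeQuantize modules may use different granularity or ranges.

```python
import torch

torch.manual_seed(0)
w = torch.randn(12, 1, 3, 3)  # same shape as the tutorial's Conv2d weight

# Symmetric per-tensor INT8 parameters (an illustrative choice)
quant_min, quant_max = -128, 127
scale = w.abs().max().item() / 127
zero_point = 0

# Manual "fake" quantization: quantize to integers, then dequantize back to float
inv_scale = 1.0 / scale
q = torch.clamp(torch.round(w * inv_scale) + zero_point, quant_min, quant_max)
fake_manual = (q - zero_point) * scale

# PyTorch's built-in op performs the same round trip and also returns float32
fake_builtin = torch.fake_quantize_per_tensor_affine(w, scale, zero_point, quant_min, quant_max)

print(fake_builtin.dtype)                         # torch.float32
print(torch.allclose(fake_manual, fake_builtin))  # True
```

The output is still a float32 tensor, but it can take at most 256 distinct values, which is exactly the information an INT8 weight can carry.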
[37]:
wo_accuracy = eval_model(wo_prepared, test_loader)
print(f"Weight-only PTQ accuracy: {wo_accuracy:.4f}")
Weight-only PTQ accuracy: 0.9800
Weight and Activation Quantization
Now, let’s try using weight + activation quantization through coreai-opt.
Weight + activation quantization compresses both the model’s weights and its intermediate activation tensors to a lower precision. On hardware with fast low-precision arithmetic, this can provide a greater speedup than weight-only quantization, but it requires calibration data: representative inputs are run through the model so that appropriate scale and zero-point values can be computed for the quantized activations.
For this tutorial, we will use the INT8 dtype for activations. Refer to the QuantizationSpec reference for all options.
[38]:
ACTIVATION_DTYPE = "int8"
Two changes from the weight-only config:
op_input_spec and op_output_spec are now populated, so activations are quantized too.
The "*" key applies the same INT8 spec to every operation input/output.
[39]:
wa_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
wa_model.eval()
activation_spec = QuantizationSpec(dtype=ACTIVATION_DTYPE)
wa_config = QuantizerConfig(
global_config=ModuleQuantizerConfig(
op_state_spec={"weight": weight_spec},
op_input_spec={"*": activation_spec},
op_output_spec={"*": activation_spec},
)
)
wa_quantizer = Quantizer(wa_model, wa_config)
wa_prepared = wa_quantizer.prepare(example_inputs)
print(
f"Prepared weight + activation quantization (weight={WEIGHT_DTYPE}, activation={ACTIVATION_DTYPE})"
)
Prepared weight + activation quantization (weight=int8, activation=int8)
We now enter calibration_mode() and run representative inputs through the prepared model to populate the scale and zero-point values.
The calibration_mode() context manager enables range observers (to collect activation statistics) while disabling fake quantization (so the forward pass is numerically identical to the unquantized model). After exiting the context, observers are frozen and fake quantization is re-enabled.
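The mode switching described above can be modeled with a toy fake-quantize object and a `contextlib.contextmanager`. Everything here (`ToyFakeQuant`, its flags, and this `calibration_mode` function) is hypothetical and only mirrors the described behavior; it is not coreai-opt's implementation, and the quantization step is deliberately simplified.

```python
from collections.abc import Iterator
from contextlib import contextmanager


class ToyFakeQuant:
    """Hypothetical fake-quantize module with a min/max observer (not coreai-opt's class)."""

    def __init__(self) -> None:
        self.observer_enabled = False
        self.fake_quant_enabled = True
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def scale(self) -> float:
        # 256 levels spread over the observed range; 1.0 if nothing was observed yet
        span = self.max_val - self.min_val
        return span / 255 if span > 0 else 1.0

    def __call__(self, values: list[float]) -> list[float]:
        if self.observer_enabled:
            self.min_val = min(self.min_val, min(values))
            self.max_val = max(self.max_val, max(values))
        if not self.fake_quant_enabled:
            return list(values)  # pass-through: identical to the unquantized model
        s = self.scale()
        # Simplified fake quantization: snap to the grid (no zero-point, no clamping)
        return [round(v / s) * s for v in values]


@contextmanager
def calibration_mode(fq: ToyFakeQuant) -> Iterator[ToyFakeQuant]:
    # On entry: collect statistics and bypass quantization
    fq.observer_enabled, fq.fake_quant_enabled = True, False
    try:
        yield fq
    finally:
        # On exit: freeze the observer and re-enable fake quantization
        fq.observer_enabled, fq.fake_quant_enabled = False, True


fq = ToyFakeQuant()
with calibration_mode(fq):
    calibrated_out = fq([0.0, 1.234, 2.55])  # returned unchanged, but the range is recorded
frozen_range = (fq.min_val, fq.max_val)
quantized_out = fq([1.234, 5.0])  # observer is frozen, so 5.0 does not widen the range
```

Using `try`/`finally` guarantees the flags are restored even if a calibration batch raises an exception.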
[40]:
NUM_CALIBRATION_BATCHES = 1
with wa_quantizer.calibration_mode():
    for i, (data, _) in enumerate(train_loader):
        if i >= NUM_CALIBRATION_BATCHES:
            break
        wa_prepared(data)
print(
    f"Calibrated with {NUM_CALIBRATION_BATCHES} batch(es) ({NUM_CALIBRATION_BATCHES * BATCH_SIZE} samples)"
)
Calibrated with 1 batch(es) (128 samples)
[41]:
wa_accuracy = eval_model(wa_prepared, test_loader)
print(f"Weight + Activation accuracy after prepare + calibration: {wa_accuracy:.4f}")
Weight + Activation accuracy after prepare + calibration: 0.9798
Quantization Aware Training
Now, let’s try to fine-tune the weight + activation quantized model with QAT to recover any accuracy lost from quantization.
The QAT training loop wraps each epoch in wa_quantizer.training_mode():
During the context: fake quantization is enabled and observers are enabled (collecting activation statistics during training batches).
On exit: observers are frozen; fake quantization stays enabled, so the model is ready for evaluation with quant-aware weights.
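Training through fake quantization raises a gradient problem: rounding has zero gradient almost everywhere. QAT implementations commonly use a straight-through estimator (STE), which applies the quantized value in the forward pass but treats quantization as the identity in the backward pass. The sketch below shows a generic STE; whether coreai-opt uses exactly this rule is an assumption. Some implementations, including PyTorch's built-in fake-quantize ops, additionally zero the gradient for values that were clipped, which this sketch omits.

```python
import torch


def fake_quant_ste(x: torch.Tensor, scale: float, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    """Fake-quantize x; the backward pass uses the straight-through estimator."""
    q = torch.clamp(torch.round(x / scale), qmin, qmax) * scale
    # Forward value equals q; (q - x).detach() carries no gradient, so d(out)/dx = 1
    return x + (q - x).detach()


x = torch.tensor([0.013, -0.27, 0.5], requires_grad=True)
y = fake_quant_ste(x, scale=0.01)
y.sum().backward()

print(y.detach())  # values snapped to the 0.01 grid
print(x.grad)      # tensor([1., 1., 1.]): gradients pass straight through
```

The optimizer therefore keeps updating the underlying full-precision weights, while the loss is computed on their quantized versions.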
[42]:
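QAT_EPOCHS = 5
# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU
qat_device = "mps" if torch.backends.mps.is_available() else "cpu"
# Move the model to the device before creating its optimizer
wa_prepared.to(qat_device)
qat_optimizer = create_adam_optimizer(wa_prepared, lr=1e-4)
for epoch in range(QAT_EPOCHS):
    # Fake quantization and observers are both enabled inside training_mode()
    with wa_quantizer.training_mode():
        epoch_loss = train_epoch(
            model=wa_prepared,
            train_loader=train_loader,
            optimizer=qat_optimizer,
            loss_fn=torch.nn.CrossEntropyLoss(),
        )
    # Observers are frozen again here, so evaluation does not change quantization ranges
    qat_acc = eval_model(wa_prepared, test_loader)
    print(f" Epoch {epoch + 1}: loss={epoch_loss:.4f}, accuracy={qat_acc:.4f}")
qat_prepared = wa_prepared.cpu()
print(f"\nQAT final accuracy: {qat_acc:.4f}")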
Epoch 1: loss=0.0276, accuracy=0.9832
Epoch 2: loss=0.0261, accuracy=0.9828
Epoch 3: loss=0.0253, accuracy=0.9835
Epoch 4: loss=0.0249, accuracy=0.9837
Epoch 5: loss=0.0242, accuracy=0.9830
QAT final accuracy: 0.9830
Export to Core AI
Once the quantized model is ready, call finalize() to convert the fake quantization modules into the real quantized representation for deployment.
Pass ExportBackend.CoreAI to finalize(backend=...) to target the .aimodel format produced by coreai-torch.
We’ll export the fine-tuned weight + activation model from the previous section.
[43]:
from coreai_opt import ExportBackend
coreai_model = wa_quantizer.finalize(backend=ExportBackend.CoreAI)
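The export proceeds in three steps:
1. Trace the model with torch.export.export() to obtain a graph representation, then decompose it with the coreai-torch decomposition table (get_decomp_table()).
2. Apply cast_to_16_bit_precision() to cast the remaining FP32 parameters to FP16, which halves their size and typically improves on-device performance.
3. Convert the exported program to Core AI format using TorchConverter from coreai-torch, then optimize and save it.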
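The effect of step 2 can be pictured with an eager-mode analogue: after quantization, the INT8 weights should stay INT8, and only the leftover floating-point tensors (such as scales and biases) move to FP16. This sketch uses PyTorch's `nn.Module.half()`, which casts only floating-point parameters and buffers. The `TinyQuantizedLinear` layer is hypothetical, and `cast_to_16_bit_precision()` operates on the exported program, so its exact rules are not reproduced here.

```python
import torch
from torch import nn


class TinyQuantizedLinear(nn.Module):
    """Hypothetical layer holding INT8 weights plus floating-point scale and bias."""

    def __init__(self, in_features: int = 2352, out_features: int = 10) -> None:
        super().__init__()
        self.register_buffer(
            "weight_int8", torch.randint(-128, 128, (out_features, in_features)).to(torch.int8)
        )
        self.register_buffer("weight_scale", torch.rand(out_features, 1))
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize on the fly in the dtype of the floating-point scale
        weight = self.weight_int8.to(self.weight_scale.dtype) * self.weight_scale
        return x @ weight.T + self.bias


def tensor_bytes(module: nn.Module) -> int:
    """Total storage of a module's parameters and buffers, in bytes."""
    return sum(t.numel() * t.element_size() for t in [*module.parameters(), *module.buffers()])


layer = TinyQuantizedLinear()
bytes_before = tensor_bytes(layer)  # 23520 (int8) + 40 (fp32 scale) + 40 (fp32 bias)
# half() casts floating-point parameters and buffers only; the INT8 buffer is untouched
layer.half()
bytes_after = tensor_bytes(layer)   # 23520 (int8) + 20 (fp16 scale) + 20 (fp16 bias)
print(layer.weight_int8.dtype, layer.weight_scale.dtype, layer.bias.dtype)
```

For this layer the savings are small because the weights are already INT8; the cast matters more for any tensors that were left unquantized.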
[44]:
import shutil
from coreai_opt.casting import cast_to_16_bit_precision
from coreai_torch import TorchConverter, get_decomp_table
exported_program = torch.export.export(coreai_model, example_inputs, strict=False)
exported_program = exported_program.run_decompositions(get_decomp_table())
cast_to_16_bit_precision(exported_program)
coreai_program = TorchConverter().add_exported_program(exported_program).to_coreai()
coreai_program.optimize()
output_path = Path(SAVE_DIRECTORY) / "exported_model.aimodel"
if output_path.exists():
shutil.rmtree(output_path)
coreai_program.save_asset(output_path)
print(f"Exported: {output_path}")
Exported: exported_model.aimodel