Applying joint palettization and activation quantization to an MNIST model¶
In this tutorial, we will demonstrate how to combine weight palettization with activation quantization (joint P4A8 compression) using CoreAI-Opt.
By the end of this tutorial, you should be familiar with the following:
Table of Contents:
Setup¶
We will use a basic CNN model, train it on the MNIST dataset, and observe its final accuracy.
Once the CNN model is trained, we will apply joint compression: first palettizing the weights (with LUT quantization), then quantizing the activations to int8.
[20]:
import random
from pathlib import Path
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
[21]:
SEED = 1976
# Seed the Python and NumPy RNGs, then PyTorch's, so notebook runs are reproducible.
# torch.manual_seed stays last so the cell still displays the generator it returns.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)
torch.manual_seed(SEED)
[21]:
<torch._C.Generator at 0x10e364c30>
[22]:
# Root directory for intermediate artifacts: the MNIST download and the exported model.
SAVE_DIRECTORY: str = "."
MNIST Dataset download¶
Helper to download the MNIST dataset with standard normalization applied.
[23]:
def mnist_transforms() -> transforms.Compose:
    """Build the MNIST preprocessing pipeline.

    Converts each sample to a tensor, then normalizes it with
    mean 0.1307 and standard deviation 0.3081.
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    return transforms.Compose(steps)
def download_mnist_dataset(
    download_path: Path, transform: transforms.Compose | None = None
) -> tuple[datasets.MNIST, datasets.MNIST]:
    """Download (if not already present) and load the MNIST train and test splits.

    Args:
        download_path: Directory passed to ``datasets.MNIST`` as its root.
        transform: Preprocessing applied to every sample; ``mnist_transforms()``
            is used when omitted.

    Returns:
        A ``(train, test)`` pair of datasets.
    """
    transform = mnist_transforms() if transform is None else transform
    # Build the train split first, then the test split.
    train, test = (
        datasets.MNIST(download_path, train=is_train, download=True, transform=transform)
        for is_train in (True, False)
    )
    return train, test
Model definition¶
A simple CNN with a single Conv2d → ReLU → MaxPool block, followed by Flatten and a Linear classifier.
[24]:
class MnistNetwork(nn.Module):
def __init__(self, num_classes: int = 10, state_dict: dict | None = None) -> None:
super().__init__()
self.model = nn.Sequential(
nn.Conv2d(1, 12, 3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2, stride=2, padding=0),
nn.Flatten(),
nn.Linear(2352, num_classes),
)
if state_dict is not None:
self.load_state_dict(state_dict)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.model(x)
Training and Evaluation¶
Standard PyTorch training loop and evaluation function that computes accuracy.
[25]:
def train_step(model, optimizer, loss_fn, inputs, ground_truth) -> float:
    """Run one optimization step on a single batch and return its loss.

    The model is switched to training mode, and the batch is moved to the device
    of the model's first parameter. Gradients are zeroed before the backward pass.

    Returns:
        The batch loss as a Python float, computed before the optimizer update.
    """
    model.train()
    target_device = next(model.parameters()).device
    batch, labels = inputs.to(target_device), ground_truth.to(target_device)

    optimizer.zero_grad()
    batch_loss = loss_fn(model(batch), labels)
    batch_loss.backward()
    optimizer.step()
    return batch_loss.item()
def train_epoch(model, train_loader, optimizer, loss_fn) -> float:
    """Train ``model`` for one pass over ``train_loader`` and return the mean batch loss.

    Each ``(inputs, ground_truth)`` batch is handed to ``train_step``.

    Args:
        model: Module being trained.
        train_loader: Iterable of ``(inputs, ground_truth)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, ground_truth)`` returning a scalar loss.

    Returns:
        The unweighted mean of the per-batch losses.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    for inputs, ground_truth in train_loader:
        total_loss += train_step(model, optimizer, loss_fn, inputs, ground_truth)
        num_batches += 1
    # Count batches while iterating instead of calling len(train_loader). This also
    # supports iterables without __len__ and avoids a bare ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    return total_loss / num_batches
def create_adam_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Adam:
    """Return an Adam optimizer over every parameter of ``model`` with learning rate ``lr``."""
    params = model.parameters()
    optimizer = torch.optim.Adam(params, lr=lr)
    return optimizer
def eval_model(model: nn.Module, test_dataloader: DataLoader) -> float:
    """Compute the top-1 accuracy of ``model`` over ``test_dataloader``.

    The model is switched to eval mode (and left there). Inference runs under
    ``torch.no_grad()`` on the device of the model's first parameter.

    Args:
        model: Classifier returning scores with classes along dimension 1.
        test_dataloader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of samples whose highest-scoring class equals the label.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    device = next(model.parameters()).device
    num_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, ground_truth in test_dataloader:
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            predictions = model(inputs)
            # Highest-scoring class per sample; avoids the legacy `.data` accessor.
            predicted = predictions.argmax(dim=1)
            total += ground_truth.size(0)
            num_correct += (predicted == ground_truth).sum().item()
    if total == 0:
        raise ValueError("test_dataloader yielded no samples; cannot compute accuracy")
    return num_correct / total
[26]:
# Fetch (if needed) and load both MNIST splits into a hidden folder under SAVE_DIRECTORY.
DOWNLOAD_PATH = Path(SAVE_DIRECTORY).joinpath(".mnist_dataset")
train_dataset, test_dataset = download_mnist_dataset(DOWNLOAD_PATH)
100%|██████████| 9.91M/9.91M [00:01<00:00, 8.79MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 393kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.03MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 3.59MB/s]
[27]:
BATCH_SIZE = 128
# Training batches are reshuffled every epoch; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
The CNN model used for this tutorial contains a single Conv2d, ReLU, MaxPool, Flatten, and Linear layer. Here’s the structure:
[28]:
# Fresh, untrained 10-class model; show a layer-by-layer summary for one 1x1x28x28 input.
basic_cnn_model = MnistNetwork(10)
summary(basic_cnn_model, input_size=(1, 1, 28, 28))
[28]:
==========================================================================================
Layer (type:depth-idx) Output Shape Param #
==========================================================================================
MnistNetwork [1, 10] --
├─Sequential: 1-1 [1, 10] --
│ └─Conv2d: 2-1 [1, 12, 28, 28] 120
│ └─ReLU: 2-2 [1, 12, 28, 28] --
│ └─MaxPool2d: 2-3 [1, 12, 14, 14] --
│ └─Flatten: 2-4 [1, 2352] --
│ └─Linear: 2-5 [1, 10] 23,530
==========================================================================================
Total params: 23,650
Trainable params: 23,650
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.12
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.09
Estimated Total Size (MB): 0.17
==========================================================================================
Train baseline model¶
Let’s train this model so we can get a baseline accuracy. We save the trained weights so we can reload them for each compression experiment.
[29]:
import copy

EPOCHS = 10
loss_fn = torch.nn.CrossEntropyLoss()
# Prefer Apple's Metal (MPS) backend when it is available; otherwise fall back to CPU
# so the notebook also runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"
basic_cnn_model = basic_cnn_model.to(device)
# Build the optimizer once the model's parameters are on their training device.
optimizer = create_adam_optimizer(basic_cnn_model)
epoch_results = []
for epoch in range(EPOCHS):
    epoch_avg_loss = train_epoch(
        model=basic_cnn_model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn
    )
    epoch_results.append(f" Epoch {epoch + 1}: loss={epoch_avg_loss:.4f}")
basic_cnn_model = basic_cnn_model.cpu().eval()
print("\n".join(epoch_results))
# Save trained weights for reuse in compression sections. state_dict() returns
# references to the live tensors, so deep-copy them to keep the snapshot independent
# of any later in-place change to basic_cnn_model.
pretrained_state_dict = copy.deepcopy(basic_cnn_model.state_dict())
Epoch 1: loss=0.3126
Epoch 2: loss=0.1170
Epoch 3: loss=0.0821
Epoch 4: loss=0.0669
Epoch 5: loss=0.0590
Epoch 6: loss=0.0522
Epoch 7: loss=0.0481
Epoch 8: loss=0.0452
Epoch 9: loss=0.0415
Epoch 10: loss=0.0379
[30]:
# Top-1 accuracy of the uncompressed baseline on the MNIST test split.
accuracy = eval_model(model=basic_cnn_model, test_dataloader=test_loader)
print(f"Baseline accuracy: {accuracy:.4f}")
Baseline accuracy: 0.9799
K-Means Palettization + Activation Quantization¶
In this example, we combine weight palettization with activation quantization. We palettize the weights so that the entries of the look-up table (LUT) are quantized to INT8, and then we quantize the activations to INT8. With INT8 LUT entries and INT8 activations, all operations may be able to run entirely in INT8 arithmetic, which can provide a speedup on certain Apple platforms. Refer to the Joint Compression documentation for more information.
The workflow applies the two compressors sequentially: palettize the weights first, finalize for Core AI, and then quantize the activations of the palettized model.
[31]:
from coreai_opt import ExportBackend
from coreai_opt.palettization import (
KMeansPalettizer,
KMeansPalettizerConfig,
ModuleKMeansPalettizerConfig,
PalettizationSpec,
)
from coreai_opt.palettization.spec import PerGroupedChannelGranularity
from coreai_opt.quantization import (
ModuleQuantizerConfig,
QuantizationSpec,
Quantizer,
QuantizerConfig,
)
from coreai_opt.quantization.spec import QuantizationScheme
# One random MNIST-shaped sample, wrapped in a tuple, used when preparing the compressors.
_example_image = torch.randn(1, 1, 28, 28)
example_inputs = (_example_image,)
We use 4-bit per-grouped-channel palettization with LUT quantization: each group of 2 output channels gets its own look-up table of 2⁴ = 16 centroids, and the LUT entries are quantized to int8 for additional compression.
[32]:
# 4-bit palette indices (2**4 = 16 LUT entries per group); calibrate on a single batch.
N_BITS, NUM_CALIBRATION_BATCHES = 4, 1
PalettizationSpec controls the palette: n_bits sets the bits per index, PerGroupedChannelGranularity(axis=0, group_size=2) gives each group of 2 output channels its own LUT, and lut_qspec quantizes the LUT entries to int8.
[33]:
# Reload the pretrained baseline weights into a fresh model for this experiment.
model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
model.eval()

# The LUT entries are quantized to int8, and each group of 2 output channels (axis 0)
# gets its own N_BITS-bit palette.
lut_qspec = QuantizationSpec(dtype="int8")
weight_spec = PalettizationSpec(
    n_bits=N_BITS,
    granularity=PerGroupedChannelGranularity(axis=0, group_size=2),
    lut_qspec=lut_qspec,
)
module_config = ModuleKMeansPalettizerConfig(op_state_spec={"weight": weight_spec})
config = KMeansPalettizerConfig(global_config=module_config)

palettizer = KMeansPalettizer(model, config)
prepared = palettizer.prepare(example_inputs)
print(f"Prepared k-means palettization with n_bits={N_BITS}, LUT quantized to int8")
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 91.17it/s]
Prepared k-means palettization with n_bits=4, LUT quantized to int8
The palettizer must be finalized before the quantizer can be applied. Finalizing converts the fake-palettization modules into the backend-specific representation.
Note: The backend specified here must match the backend used for the subsequent activation quantization step.
[34]:
# Finalize palettization for the Core AI backend. This must be the same backend used by
# the activation-quantization step that follows.
palettized_model = palettizer.finalize(backend=ExportBackend.CoreAI)
Now we apply activation-only quantization. Setting op_state_spec=None is critical: it turns off weight quantization so that only activations are quantized. The weights are already palettized, so quantizing them again on top of the palette would be redundant. We quantize the inputs and outputs of every operation to symmetric int8.
[35]:
import itertools

activation_spec = QuantizationSpec(
    dtype=torch.int8,
    qscheme=QuantizationScheme.SYMMETRIC,
)
quant_config = QuantizerConfig(
    global_config=ModuleQuantizerConfig(
        op_state_spec=None,  # weights are already palettized: quantize activations only
        op_input_spec={"*": activation_spec},
        op_output_spec={"*": activation_spec},
    )
)
quantizer = Quantizer(palettized_model, quant_config)
quant_prepared = quantizer.prepare(example_inputs)
# Calibration only needs forward passes, so autograd is disabled. islice stops after
# NUM_CALIBRATION_BATCHES without loading an extra batch. The batches and samples
# actually seen are counted, so the report stays correct for short or partial batches.
calibration_batches = 0
calibration_samples = 0
with quantizer.calibration_mode(), torch.no_grad():
    for data, _ in itertools.islice(train_loader, NUM_CALIBRATION_BATCHES):
        quant_prepared(data)
        calibration_batches += 1
        calibration_samples += data.size(0)
print(f"Calibrated activations with {calibration_batches} batch(es) ({calibration_samples} samples)")
Calibrated activations with 1 batch(es) (128 samples)
With both palettization and activation quantization applied, let’s measure the joint compression accuracy.
[36]:
# Accuracy of the palettized + activation-quantized (P4A8) model on the test split.
joint_accuracy = eval_model(model=quant_prepared, test_dataloader=test_loader)
print(f"K-Means P4A8 joint compression accuracy: {joint_accuracy:.4f}")
K-Means P4A8 joint compression accuracy: 0.9782
Export to Core AI¶
Once the jointly compressed model is ready, call finalize() on the quantizer to produce the final model for deployment. We export the k-means joint-compressed model to Core AI.
The export proceeds in three steps: trace the model with torch.export.export() and decompose it with the coreai-torch decomposition table; cast the remaining FP32 parameters to FP16 with cast_to_16_bit_precision(); then convert it to Core AI format using coreai-torch, optimize it, and save it.
[37]:
# Finalize activation quantization for Core AI, producing the jointly compressed model
# that is exported below.
coreai_model = quantizer.finalize(backend=ExportBackend.CoreAI)
[38]:
import shutil
from coreai_opt.casting import cast_to_16_bit_precision
from coreai_torch import TorchConverter, get_decomp_table

# 1. Trace the finalized model and decompose it with the coreai-torch decomposition table.
exported_program = torch.export.export(coreai_model, example_inputs, strict=False)
exported_program = exported_program.run_decompositions(get_decomp_table())
# 2. Cast remaining FP32 parameters to FP16. The return value is ignored, so this relies
#    on the call updating exported_program in place (TODO confirm against coreai_opt docs).
cast_to_16_bit_precision(exported_program)
# 3. Convert to Core AI, optimize, and save.
coreai_program = TorchConverter().add_exported_program(exported_program).to_coreai()
coreai_program.optimize()
output_path = Path(SAVE_DIRECTORY) / "exported_model.aimodel"
# Remove any previous export first. shutil.rmtree only handles directories, so a plain
# file left at this path is unlinked instead.
if output_path.is_dir():
    shutil.rmtree(output_path)
elif output_path.exists():
    output_path.unlink()
coreai_program.save_asset(output_path)
print(f"Exported: {output_path}")
Exported: exported_model.aimodel