Applying palettization to an MNIST model
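This tutorial provides a basic introduction to palettizing a model with CoreAI-Opt.
By the end of this tutorial, you should be familiar with data-free k-means palettization, calibration-based Sensitive K-Means (SKM) palettization, and exporting a palettized model to Core AI.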
Setup
We will train a basic CNN model on the MNIST dataset and record its final accuracy as a baseline.
Once the CNN is trained, we will palettize it with coreai-opt, starting with data-free k-means palettization and then moving to Sensitive K-Means (SKM) palettization, which uses calibration data.
```python
import random
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
```
```python
SEED = 1976
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
```
```python
# Used to save intermediate results and datasets
SAVE_DIRECTORY = "."
```
MNIST Dataset download
Helper to download the MNIST dataset with standard normalization applied.
```python
def mnist_transforms() -> transforms.Compose:
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])


def download_mnist_dataset(
    download_path: Path, transform: transforms.Compose | None = None
) -> tuple[datasets.MNIST, datasets.MNIST]:
    if transform is None:
        transform = mnist_transforms()
    train = datasets.MNIST(download_path, train=True, download=True, transform=transform)
    test = datasets.MNIST(download_path, train=False, download=True, transform=transform)
    return train, test
```
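transforms.Normalize((0.1307,), (0.3081,)) standardizes each pixel after ToTensor() has scaled it to [0, 1], computing (x - mean) / std with the usual MNIST mean 0.1307 and standard deviation 0.3081. A minimal stdlib sketch of that per-pixel arithmetic (normalize_pixel is our own illustrative helper, not a torchvision function):

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize_pixel(x: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mirror transforms.Normalize for a single pixel already scaled to [0, 1]."""
    return (x - mean) / std


# ToTensor() maps raw bytes 0..255 to 0.0..1.0, so these are the extreme inputs.
black = normalize_pixel(0.0)
white = normalize_pixel(1.0)
print(f"black -> {black:.4f}, white -> {white:.4f}")
```

After normalization, pixel values span roughly -0.42 to 2.82 rather than 0 to 1.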
Model definition
A simple CNN with a single Conv2d → ReLU → MaxPool block, followed by Flatten and a Linear classifier.
```python
class MnistNetwork(nn.Module):
    def __init__(self, num_classes: int = 10, state_dict: dict | None = None) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 12, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0),
            nn.Flatten(),
            nn.Linear(2352, num_classes),
        )
        if state_dict is not None:
            self.load_state_dict(state_dict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
```
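The 2352 input features of the Linear layer follow from the layer hyperparameters. The sketch below applies PyTorch's output-size formula for Conv2d and MaxPool2d with dilation 1, floor((n + 2p - k) / s) + 1, to a 28×28 MNIST image; conv_out is an illustrative helper, not part of the tutorial's code.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of Conv2d/MaxPool2d with dilation=1 (PyTorch's floor formula)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                                             # MNIST images are 1x28x28
side = conv_out(side, kernel=3, stride=1, padding=1)  # Conv2d(1, 12, 3, padding=1) keeps 28
side = conv_out(side, kernel=2, stride=2, padding=0)  # MaxPool2d(2, stride=2) halves it to 14
channels = 12                                         # output channels of the Conv2d
flat_features = channels * side * side                # Flatten: 12 * 14 * 14
print(side, flat_features)
```

The result matches both the Linear(2352, num_classes) definition and the [1, 2352] Flatten output in the model summary.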
Training and Evaluation
Standard PyTorch training loop and evaluation function that computes accuracy.
```python
def train_step(model, optimizer, loss_fn, inputs, ground_truth) -> float:
    model.train()
    device = next(model.parameters()).device
    inputs = inputs.to(device)
    ground_truth = ground_truth.to(device)
    optimizer.zero_grad()
    predictions = model(inputs)
    loss = loss_fn(predictions, ground_truth)
    loss.backward()
    optimizer.step()
    return loss.item()


def train_epoch(model, train_loader, optimizer, loss_fn) -> float:
    total_loss = 0.0
    for inputs, ground_truth in train_loader:
        loss = train_step(model, optimizer, loss_fn, inputs, ground_truth)
        total_loss += loss
    return total_loss / len(train_loader)


def create_adam_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Adam:
    return torch.optim.Adam(model.parameters(), lr=lr)


def eval_model(model: nn.Module, test_dataloader: DataLoader) -> float:
    model.eval()
    device = next(model.parameters()).device
    num_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, ground_truth in test_dataloader:
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            predictions = model(inputs)
            _, predicted = torch.max(predictions.data, 1)
            total += ground_truth.size(0)
            num_correct += (predicted == ground_truth).sum().item()
    return num_correct / total
```
```python
# Download and instantiate datasets
DOWNLOAD_PATH = Path(SAVE_DIRECTORY) / ".mnist_dataset"
train_dataset, test_dataset = download_mnist_dataset(
    download_path=DOWNLOAD_PATH
)
```
```python
BATCH_SIZE = 128
# Instantiate dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)
```
The CNN model used for this tutorial contains a single Conv2d, ReLU, MaxPool, Flatten, and Linear layer. Here’s the structure:
```python
basic_cnn_model = MnistNetwork(num_classes=10)
# Print summary of model
summary(basic_cnn_model, input_size=(1, 1, 28, 28))
```
```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MnistNetwork                             [1, 10]                   --
├─Sequential: 1-1                        [1, 10]                   --
│    └─Conv2d: 2-1                       [1, 12, 28, 28]           120
│    └─ReLU: 2-2                         [1, 12, 28, 28]           --
│    └─MaxPool2d: 2-3                    [1, 12, 14, 14]           --
│    └─Flatten: 2-4                      [1, 2352]                 --
│    └─Linear: 2-5                       [1, 10]                   23,530
==========================================================================================
Total params: 23,650
Trainable params: 23,650
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.12
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.09
Estimated Total Size (MB): 0.17
==========================================================================================
```
Train baseline model
Let’s train this model so we can get a baseline accuracy. We save the trained weights so we can reload them for each palettization experiment.
```python
EPOCHS = 10
# Train on the Apple-silicon GPU (MPS) when available, otherwise fall back to the CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
basic_cnn_model = basic_cnn_model.to(device)
loss_fn = torch.nn.CrossEntropyLoss()
# Construct the optimizer after moving the model, as the torch.optim docs recommend
optimizer = create_adam_optimizer(basic_cnn_model)
epoch_results = []
for epoch in range(EPOCHS):
    epoch_avg_loss = train_epoch(
        model=basic_cnn_model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn
    )
    epoch_results.append(f" Epoch {epoch + 1}: loss={epoch_avg_loss:.4f}")
basic_cnn_model = basic_cnn_model.cpu().eval()
print("\n".join(epoch_results))
# Save trained weights for reuse in palettization sections
pretrained_state_dict = basic_cnn_model.state_dict()
```
Epoch 1: loss=0.3126
Epoch 2: loss=0.1170
Epoch 3: loss=0.0821
Epoch 4: loss=0.0669
Epoch 5: loss=0.0590
Epoch 6: loss=0.0522
Epoch 7: loss=0.0481
Epoch 8: loss=0.0452
Epoch 9: loss=0.0415
Epoch 10: loss=0.0379
```python
accuracy = eval_model(basic_cnn_model, test_loader)
print(f"Baseline accuracy: {accuracy:.4f}")
```
Baseline accuracy: 0.9799
K-Means Palettization
Palettization compresses a model's weights by clustering them into a small look-up table (LUT) of centroids, also called the palette. Each weight is then replaced by an index into this table, so the weights can be stored using only a few bits per value, plus the LUT itself.
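To make the encoding concrete, here is a toy stdlib sketch that replaces a handful of weights by the index of their nearest entry in a hand-picked 4-entry LUT, then decodes them again by table lookup. The LUT values are illustrative only; k-means palettization learns them from the weights.

```python
def encode(weights: list[float], lut: list[float]) -> list[int]:
    """Replace each weight by the index of its nearest LUT entry."""
    return [min(range(len(lut)), key=lambda k: abs(w - lut[k])) for w in weights]


def decode(indices: list[int], lut: list[float]) -> list[float]:
    """Recover approximate weights by looking each index up in the LUT."""
    return [lut[i] for i in indices]


lut = [-0.5, -0.1, 0.1, 0.5]                 # 4 entries, so every index fits in 2 bits
weights = [0.42, -0.07, 0.13, -0.61, 0.02]
indices = encode(weights, lut)
approx = decode(indices, lut)
index_bits = (len(lut) - 1).bit_length()     # bits needed for the largest index
print(indices, approx, index_bits)
```

Only the indices (2 bits each here) and the 4-entry LUT need to be stored; the decoded weights are approximations of the originals.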
K-means palettization uses standard k-means clustering to compute the LUT. It is data-free — no calibration data is required — and can be applied directly to a trained model, much like weight-only quantization.
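A minimal sketch of computing a 2^n_bits-entry LUT with Lloyd's k-means on a flat list of scalar weights. This illustrates the textbook algorithm only, not coreai-opt's implementation, which may differ in initialization, stopping criteria, and parallelism; the quantile-style initialization below is our own choice.

```python
def kmeans_1d(values: list[float], n_bits: int, iters: int = 50) -> list[float]:
    """Cluster scalar weights into 2**n_bits centroids with Lloyd's algorithm."""
    k = 2 ** n_bits
    ordered = sorted(values)
    # Deterministic init: centroids at evenly spaced positions in the sorted data.
    centroids = [ordered[(2 * i + 1) * len(ordered) // (2 * k)] for i in range(k)]
    for _ in range(iters):
        clusters: list[list[float]] = [[] for _ in range(k)]
        for v in values:
            # Assignment step: each weight joins its nearest centroid.
            nearest = min(range(k), key=lambda c: abs(v - centroids[c]))
            clusters[nearest].append(v)
        # Update step: move each centroid to its members' mean; keep empty clusters in place.
        updated = [sum(m) / len(m) if m else centroids[c] for c, m in enumerate(clusters)]
        if updated == centroids:
            break  # assignments are stable, so the LUT has converged
        centroids = updated
    return sorted(centroids)


weights = [-1.0, -0.9, -0.2, -0.1, 0.3, 0.4, 1.1, 1.2]
lut = kmeans_1d(weights, n_bits=2)
print(lut)
```

With four well-separated pairs of weights and n_bits=2, each of the four centroids settles at the mean of one pair.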
```python
from coreai_opt.palettization import (
    KMeansPalettizer,
    KMeansPalettizerConfig,
    ModuleKMeansPalettizerConfig,
    PalettizationSpec,
)
from coreai_opt.palettization.spec import PerGroupedChannelGranularity

example_inputs = (torch.randn(1, 1, 28, 28),)
```
For this tutorial, we will use 4-bit per-grouped-channel palettization: each group of 2 output channels gets its own look-up table of 2⁴ = 16 centroids. Refer to the PalettizationSpec reference for all options.
```python
N_BITS = 4
```
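A back-of-the-envelope storage estimate for this model's two weight tensors under 4-bit indices with one 16-entry LUT per group of 2 output channels. Two assumptions: LUT entries are stored as 16-bit floats, and biases and any container-format overhead are ignored. The actual .aimodel layout may differ; palettized_bits is an illustrative helper.

```python
import math


def palettized_bits(shape: tuple[int, ...], n_bits: int, group_size: int, lut_entry_bits: int = 16) -> int:
    """Index bits plus one 2**n_bits-entry LUT per group of output channels (axis 0)."""
    index_bits = math.prod(shape) * n_bits
    num_luts = shape[0] // group_size
    lut_bits = num_luts * (2 ** n_bits) * lut_entry_bits  # assumes FP16 LUT entries
    return index_bits + lut_bits


results = {}
# Weight shapes of this tutorial's Conv2d and Linear layers (biases are ignored here).
for name, shape in {"conv": (12, 1, 3, 3), "linear": (10, 2352)}.items():
    fp32_bits = math.prod(shape) * 32
    pal_bits = palettized_bits(shape, n_bits=4, group_size=2)
    results[name] = (fp32_bits, pal_bits)
    print(f"{name}: {fp32_bits} FP32 bits -> {pal_bits} bits ({fp32_bits / pal_bits:.2f}x smaller)")
```

For the 108-weight conv layer, the six LUTs cost more bits than the indices, so the saving stays below 2x; the much larger linear layer approaches the ideal 32 / 4 = 8x.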
KMeansPalettizerConfig describes how to palettize each operation through op_state_spec, which maps a state tensor name (such as "weight") to a PalettizationSpec.
The PalettizationSpec controls the palette: n_bits sets the bits per index, and PerGroupedChannelGranularity(axis=0, group_size=2) gives each group of 2 output channels its own LUT — a finer-grained alternative to the default single per-tensor LUT.
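A small stdlib sketch of how group_size=2 along axis 0 could partition this model's output channels, assuming groups are formed from consecutive channels. channel_groups is a hypothetical helper, not part of coreai-opt, and raising on a non-divisible group size is our choice; how coreai-opt itself handles that case is not covered here.

```python
def channel_groups(num_channels: int, group_size: int) -> list[range]:
    """Split output channels 0..num_channels-1 into consecutive groups that share one LUT each."""
    if num_channels % group_size != 0:
        raise ValueError(f"{num_channels} channels cannot be split into groups of {group_size}")
    return [range(start, start + group_size) for start in range(0, num_channels, group_size)]


conv_groups = channel_groups(12, group_size=2)    # Conv2d: 12 output channels -> 6 LUTs
linear_groups = channel_groups(10, group_size=2)  # Linear: 10 output channels -> 5 LUTs
print([list(g) for g in conv_groups[:2]], len(conv_groups), len(linear_groups))

too_coarse_rejected = False
try:
    channel_groups(12, group_size=16)  # 16: the default group_size of the w4() preset
except ValueError as err:
    too_coarse_rejected = True
    print(err)
```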
```python
kmeans_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
kmeans_model.eval()

weight_spec = PalettizationSpec(
    n_bits=N_BITS, granularity=PerGroupedChannelGranularity(axis=0, group_size=2)
)
kmeans_config = KMeansPalettizerConfig(
    global_config=ModuleKMeansPalettizerConfig(
        op_state_spec={"weight": weight_spec},
    )
)

kmeans_palettizer = KMeansPalettizer(kmeans_model, kmeans_config)
kmeans_prepared = kmeans_palettizer.prepare(example_inputs)
print(f"Prepared k-means palettization with n_bits={N_BITS}")
```
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 83.27it/s]
Prepared k-means palettization with n_bits=4
The built-in KMeansPalettizerConfig.presets.w4() preset builds a similar 4-bit per-grouped-channel config. Its default group_size is 16, which suits larger models; here we use group_size=2 because this toy model's output-channel counts (12 and 10) aren't divisible by 16. KMeansPalettizerConfig.presets.w4(group_size=2) would be an equivalent shortcut for the config above.
After calling prepare(), the model now has fake palettization modules inserted for its weight tensors. These simulate the effect of the LUT during the forward pass, so we can measure accuracy before exporting.
```python
kmeans_accuracy = eval_model(kmeans_prepared, test_loader)
print(f"K-means palettization accuracy: {kmeans_accuracy:.4f}")
```
K-means palettization accuracy: 0.9787
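A common way to simulate palettization is to keep each weight in floating point but snap it to its nearest LUT entry before the forward pass, so downstream computation sees the values the deployed model would use. We assume coreai-opt's fake palettization modules behave equivalently. The stdlib sketch below, with an illustrative hand-picked LUT rather than one computed by coreai-opt, compares a single linear unit's output with original and snapped weights.

```python
def snap_to_lut(weights: list[float], lut: list[float]) -> list[float]:
    """Simulate palettization: replace each weight by its nearest LUT entry, staying in float."""
    return [min(lut, key=lambda c: abs(w - c)) for w in weights]


def neuron(weights: list[float], inputs: list[float]) -> float:
    """A single linear unit without bias: the dot product of weights and inputs."""
    return sum(w * x for w, x in zip(weights, inputs))


lut = [-0.5, -0.1, 0.1, 0.5]
weights = [0.42, -0.07, 0.13, -0.61]
inputs = [1.0, 2.0, -1.0, 0.5]
exact = neuron(weights, inputs)
simulated = neuron(snap_to_lut(weights, lut), inputs)
print(f"exact={exact:.3f} simulated={simulated:.3f} error={abs(exact - simulated):.3f}")
```

Evaluating the prepared model on the test set measures the accumulated effect of such per-weight errors across every layer.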
Sensitive K-Means Palettization
Now, let’s try Sensitive K-Means (SKM) palettization through coreai-opt.
SKM uses calibration data to compute a per-weight importance score (sensitivity). Following the SqueezeLLM method, it runs a backward pass and collects the squared gradients of each weight as its sensitivity. The k-means clustering is then weighted by these sensitivities, moving centroids closer to the weights that matter most for the model’s loss.
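The key change from plain k-means is the centroid update. With sensitivities s_i, a centroid becomes the sensitivity-weighted mean Σ s_i·w_i / Σ s_i of its member weights, which minimizes Σ s_i·(w_i - c)² instead of the unweighted squared error. The toy numbers below are made up to show one highly sensitive weight pulling the centroid toward itself; this illustrates the weighted update, not coreai-opt's code.

```python
def plain_centroid(weights: list[float]) -> float:
    """Standard k-means update: unweighted mean of the cluster's members."""
    return sum(weights) / len(weights)


def weighted_centroid(weights: list[float], sensitivities: list[float]) -> float:
    """Sensitivity-weighted update: the c that minimizes sum(s_i * (w_i - c)**2)."""
    total = sum(sensitivities)
    return sum(s * w for s, w in zip(sensitivities, weights)) / total


cluster = [0.10, 0.20, 0.30, 0.60]
# Hypothetical squared-gradient sensitivities: the last weight matters most to the loss.
sensitivity = [1.0, 1.0, 1.0, 7.0]
plain = plain_centroid(cluster)
weighted = weighted_centroid(cluster, sensitivity)
print(plain, weighted)
```

The plain mean lands at 0.3, while the weighted centroid moves to about 0.48, so the important weight is reconstructed more accurately at the cost of the less sensitive ones.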
```python
skm_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
skm_model.eval()

skm_config = KMeansPalettizerConfig(
    global_config=ModuleKMeansPalettizerConfig(
        op_state_spec={"weight": weight_spec},
    )
)

skm_palettizer = KMeansPalettizer(skm_model, skm_config)
skm_prepared = skm_palettizer.prepare(example_inputs)
print(f"Prepared Sensitive K-Means palettization with n_bits={N_BITS}")
```
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 82.94it/s]
Prepared Sensitive K-Means palettization with n_bits=4
We now enter calibration_mode() with a loss function and feed it representative inputs.
For each batch, skm.step(output, target) computes the loss and runs a backward pass so the palettizer can collect squared gradients as sensitivities. When the context manager exits, the LUTs are recomputed using weighted k-means based on those sensitivities.
```python
import torch.nn.functional as F

NUM_CALIBRATION_BATCHES = 1

with skm_palettizer.calibration_mode(loss_fn=F.cross_entropy) as skm:
    for i, (data, target) in enumerate(train_loader):
        if i >= NUM_CALIBRATION_BATCHES:
            break
        output = skm_prepared(data)
        skm.step(output, target)

print(
    f"Calibrated with {NUM_CALIBRATION_BATCHES} batch(es) ({NUM_CALIBRATION_BATCHES * BATCH_SIZE} samples)"
)
```
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 98.55it/s]
Calibrated with 1 batch(es) (128 samples)
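A stdlib sketch of the squared-gradient sensitivity idea on a toy model y = w·x. To keep the gradient in closed form, it uses squared-error loss, where ∂L/∂w_j = 2(y - t)·x_j, rather than the tutorial's cross-entropy; coreai-opt instead uses the supplied loss_fn and autograd. Summing, rather than averaging, over calibration samples is an assumption here.

```python
def squared_grad_sensitivity(w: list[float], samples: list[tuple[list[float], float]]) -> list[float]:
    """Accumulate (dL/dw_j)**2 over calibration samples for L = (w.x - t)**2."""
    sensitivity = [0.0] * len(w)
    for x, target in samples:
        residual = sum(wj * xj for wj, xj in zip(w, x)) - target
        for j, xj in enumerate(x):
            grad = 2.0 * residual * xj      # closed-form gradient of the squared error
            sensitivity[j] += grad * grad   # squared gradient, summed over samples
    return sensitivity


w = [0.5, -0.25]
calibration = [([1.0, 0.0], 1.0), ([1.0, 2.0], 1.0)]
result = squared_grad_sensitivity(w, calibration)
print(result)
```

The second weight ends up more sensitive because it sees a larger input on the sample where the prediction is furthest off, so weighted k-means would favor reconstructing it accurately.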
```python
skm_accuracy = eval_model(skm_prepared, test_loader)
print(f"Sensitive K-Means palettization accuracy: {skm_accuracy:.4f}")
```
Sensitive K-Means palettization accuracy: 0.9783
Export to Core AI
Once the palettized model is ready, call finalize() to convert the fake palettization modules into the real LUT-based representation for deployment.
Pass ExportBackend.CoreAI to finalize(backend=...) to target the .aimodel format produced by coreai-torch.
We’ll export the Sensitive K-Means model from the previous section.
```python
from coreai_opt import ExportBackend

coreai_model = skm_palettizer.finalize(backend=ExportBackend.CoreAI)
```
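finalize() swaps the simulated float weights for an actual LUT plus indices. The exact Core AI storage layout is not described in this tutorial, so the following is only a hypothetical illustration of where the savings come from: 4-bit indices packed two per byte (low nibble first, our choice) and unpacked again.

```python
def pack_4bit(indices: list[int]) -> bytes:
    """Pack 4-bit indices two per byte (low nibble first); pads odd-length input with 0."""
    if any(not 0 <= i < 16 for i in indices):
        raise ValueError("4-bit indices must be in 0..15")
    padded = indices + [0] * (len(indices) % 2)
    return bytes(lo | (hi << 4) for lo, hi in zip(padded[0::2], padded[1::2]))


def unpack_4bit(data: bytes, count: int) -> list[int]:
    """Inverse of pack_4bit: recover the first `count` indices."""
    out: list[int] = []
    for byte in data:
        out.extend((byte & 0x0F, byte >> 4))
    return out[:count]


indices = [3, 15, 0, 7, 9]
packed = pack_4bit(indices)
print(len(packed), packed.hex(), unpack_4bit(packed, len(indices)))
```

Five indices fit in three bytes, compared with twenty bytes for the same five weights in FP32.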
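The export proceeds in three steps:
1. Trace the model with torch.export.export() to obtain a graph representation, then run decompositions with coreai-torch's get_decomp_table().
2. Apply cast_to_16_bit_precision() to cast the remaining FP32 parameters to FP16 for better on-device performance.
3. Convert the exported program to Core AI format using coreai-torch's TorchConverter, then optimize the resulting program and save it as an .aimodel asset.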
```python
import shutil

from coreai_opt.casting import cast_to_16_bit_precision
from coreai_torch import TorchConverter, get_decomp_table

exported_program = torch.export.export(coreai_model, example_inputs, strict=False)
exported_program = exported_program.run_decompositions(get_decomp_table())
cast_to_16_bit_precision(exported_program)

coreai_program = TorchConverter().add_exported_program(exported_program).to_coreai()
coreai_program.optimize()

output_path = Path(SAVE_DIRECTORY) / "exported_model.aimodel"
if output_path.exists():
    shutil.rmtree(output_path)
coreai_program.save_asset(output_path)
print(f"Exported: {output_path}")
```
Exported: exported_model.aimodel
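cast_to_16_bit_precision() rounds the remaining FP32 parameters to IEEE 754 half precision, which keeps 11 significant bits (10 stored), or roughly three significant decimal digits. Python's struct module supports the half-precision format code "e", so the rounding can be inspected in a self-contained way; to_fp16 is an illustrative helper.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 binary16 (half precision)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


for value in [0.1307, 0.3081, 1.0 / 3.0, 65504.0]:
    half = to_fp16(value)
    print(f"{value} -> {half} (abs error {abs(value - half):.1e})")
```

65504 is the largest finite FP16 value and round-trips exactly, while values such as 1/3 pick up a small rounding error; struct.pack raises OverflowError for magnitudes beyond the FP16 range.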