Applying palettization to an MNIST model

In this tutorial, we provide a basic introduction to palettizing a model with CoreAI-Opt.

By the end of this tutorial, you should be familiar with the following:

  1. How to apply CoreAI-Opt’s k-means palettization

  2. How to apply CoreAI-Opt’s Sensitive K-Means (SKM) palettization

  3. How to export CoreAI-Opt palettized models to Core AI


Setup

We will train a basic CNN model on the MNIST dataset and measure its final accuracy.

Once the model is trained, we will palettize it with coreai-opt, starting with data-free k-means palettization and then moving on to Sensitive K-Means (SKM) palettization, which uses calibration data.

[21]:
import random
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets, transforms
[22]:
SEED = 1976

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
[23]:
# Used to save intermediate results and datasets
SAVE_DIRECTORY = "."

MNIST Dataset download

A helper that downloads the MNIST train and test splits and, by default, applies the standard MNIST normalization.

[24]:
def mnist_transforms() -> transforms.Compose:
    return transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])


def download_mnist_dataset(
    download_path: Path, transform: transforms.Compose | None = None
) -> tuple[datasets.MNIST, datasets.MNIST]:
    if transform is None:
        transform = mnist_transforms()
    train = datasets.MNIST(download_path, train=True, download=True, transform=transform)
    test = datasets.MNIST(download_path, train=False, download=True, transform=transform)
    return train, test

Model definition

A simple CNN with a single Conv2d → ReLU → MaxPool block, followed by Flatten and a Linear classifier.

[25]:
class MnistNetwork(nn.Module):
    def __init__(self, num_classes: int = 10, state_dict: dict | None = None) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 12, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, stride=2, padding=0),
            nn.Flatten(),
            nn.Linear(2352, num_classes),
        )
        if state_dict is not None:
            self.load_state_dict(state_dict)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

Training and Evaluation

Standard PyTorch training loop and evaluation function that computes accuracy.

[26]:
def train_step(model, optimizer, loss_fn, inputs, ground_truth) -> float:
    model.train()
    device = next(model.parameters()).device
    inputs = inputs.to(device)
    ground_truth = ground_truth.to(device)
    optimizer.zero_grad()
    predictions = model(inputs)
    loss = loss_fn(predictions, ground_truth)
    loss.backward()
    optimizer.step()
    return loss.item()


def train_epoch(model, train_loader, optimizer, loss_fn) -> float:
    total_loss = 0.0
    for inputs, ground_truth in train_loader:
        loss = train_step(model, optimizer, loss_fn, inputs, ground_truth)
        total_loss += loss
    return total_loss / len(train_loader)


def create_adam_optimizer(model: nn.Module, lr: float = 1e-3) -> torch.optim.Adam:
    return torch.optim.Adam(model.parameters(), lr=lr)


def eval_model(model: nn.Module, test_dataloader: DataLoader) -> float:
    model.eval()
    device = next(model.parameters()).device
    num_correct = 0
    total = 0
    with torch.no_grad():
        for inputs, ground_truth in test_dataloader:
            inputs = inputs.to(device)
            ground_truth = ground_truth.to(device)
            predictions = model(inputs)
            predicted = predictions.argmax(dim=1)
            total += ground_truth.size(0)
            num_correct += (predicted == ground_truth).sum().item()
    return num_correct / total
[27]:
# Download and instantiate datasets
DOWNLOAD_PATH = Path(SAVE_DIRECTORY) / ".mnist_dataset"

train_dataset, test_dataset = download_mnist_dataset(
    download_path=DOWNLOAD_PATH
)
[28]:
BATCH_SIZE = 128

# Instantiate dataloaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE)

The CNN model used for this tutorial contains a single Conv2d, ReLU, MaxPool, Flatten, and Linear layer. Here’s the structure:

[29]:
basic_cnn_model = MnistNetwork(num_classes=10)

# Print summary of model
summary(basic_cnn_model, input_size=(1, 1, 28, 28))
[29]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MnistNetwork                             [1, 10]                   --
├─Sequential: 1-1                        [1, 10]                   --
│    └─Conv2d: 2-1                       [1, 12, 28, 28]           120
│    └─ReLU: 2-2                         [1, 12, 28, 28]           --
│    └─MaxPool2d: 2-3                    [1, 12, 14, 14]           --
│    └─Flatten: 2-4                      [1, 2352]                 --
│    └─Linear: 2-5                       [1, 10]                   23,530
==========================================================================================
Total params: 23,650
Trainable params: 23,650
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 0.12
==========================================================================================
Input size (MB): 0.00
Forward/backward pass size (MB): 0.08
Params size (MB): 0.09
Estimated Total Size (MB): 0.17
==========================================================================================

Train baseline model

Let’s train this model so we can get a baseline accuracy. We save the trained weights so we can reload them for each palettization experiment.

[30]:
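EPOCHS = 10

# Use the Apple-silicon GPU (MPS) when available, otherwise fall back to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
basic_cnn_model = basic_cnn_model.to(device)

loss_fn = torch.nn.CrossEntropyLoss()
# Following the usual PyTorch convention, create the optimizer after moving the model
optimizer = create_adam_optimizer(basic_cnn_model)

epoch_results = []
for epoch in range(EPOCHS):
    epoch_avg_loss = train_epoch(
        model=basic_cnn_model, train_loader=train_loader, optimizer=optimizer, loss_fn=loss_fn
    )
    epoch_results.append(f"  Epoch {epoch + 1}: loss={epoch_avg_loss:.4f}")

# Move back to CPU for evaluation, palettization, and export
basic_cnn_model = basic_cnn_model.cpu().eval()
print("\n".join(epoch_results))

# Save trained weights for reuse in palettization sections
pretrained_state_dict = basic_cnn_model.state_dict()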
  Epoch 1: loss=0.3126
  Epoch 2: loss=0.1170
  Epoch 3: loss=0.0821
  Epoch 4: loss=0.0669
  Epoch 5: loss=0.0590
  Epoch 6: loss=0.0522
  Epoch 7: loss=0.0481
  Epoch 8: loss=0.0452
  Epoch 9: loss=0.0415
  Epoch 10: loss=0.0379
[31]:
accuracy = eval_model(basic_cnn_model, test_loader)
print(f"Baseline accuracy: {accuracy:.4f}")
Baseline accuracy: 0.9799

K-Means Palettization

Palettization compresses a model’s weights by clustering them into a small look-up table (LUT) of centroids, called the palette. Each weight is then replaced by an index into this table, so with an n-bit palette (2ⁿ entries) each weight is stored in just n bits.
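To make the storage format concrete, here is a minimal NumPy sketch using a hand-picked 2-bit LUT (the weight values and the LUT are made up for illustration). Each weight is mapped to the index of its nearest centroid, and decompressing is a plain table lookup:

```python
import numpy as np

# Made-up weights and a hand-picked 2-bit palette (2**2 = 4 centroids)
weights = np.array([-0.9, -0.45, -0.1, 0.05, 0.12, 0.5, 0.8, 0.95], dtype=np.float32)
lut = np.array([-0.7, 0.0, 0.5, 0.9], dtype=np.float32)

# Compress: store only the index of the nearest centroid for each weight
indices = np.abs(weights[:, None] - lut[None, :]).argmin(axis=1).astype(np.uint8)

# Decompress: look each index up in the LUT
reconstructed = lut[indices]

n_bits = 2
original_bits = weights.size * 32                     # FP32 weights
palettized_bits = weights.size * n_bits + lut.size * 32  # indices + FP32 LUT
print(indices)  # [0 0 1 1 1 2 3 3]
print(reconstructed)
print(original_bits, palettized_bits)
```

On a tensor this small the LUT dominates the cost (144 bits versus 256); the savings come from large tensors, where the fixed LUT overhead is spread over many indices.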

K-means palettization uses standard k-means clustering to compute the LUT. It is data-free — no calibration data is required — and can be applied directly to a trained model, much like weight-only quantization.
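The LUT itself can be fit with ordinary Lloyd’s k-means over the weight values. The sketch below is a generic 1-D implementation with quantile-spaced initialization, an illustrative choice rather than coreai-opt’s actual algorithm; `kmeans_lut` is a hypothetical helper name. It needs only the weights, which is what makes the method data-free:

```python
import numpy as np


def kmeans_lut(weights: np.ndarray, n_bits: int, n_iters: int = 50):
    """Fit a 2**n_bits-entry LUT to a weight tensor with 1-D Lloyd's k-means.

    Hypothetical helper for illustration; returns (lut, indices).
    """
    flat = np.asarray(weights, dtype=np.float64).ravel()
    k = 2**n_bits
    # Initialize centroids at evenly spaced quantiles of the weight distribution
    lut = np.quantile(flat, np.linspace(0.0, 1.0, k))
    for _ in range(n_iters):
        # Assignment step: index of the nearest centroid for every weight
        indices = np.abs(flat[:, None] - lut[None, :]).argmin(axis=1)
        # Update step: move each non-empty cluster's centroid to the mean of its members
        new_lut = lut.copy()
        for j in range(k):
            members = flat[indices == j]
            if members.size > 0:
                new_lut[j] = members.mean()
        if np.allclose(new_lut, lut):
            break
        lut = new_lut
    indices = np.abs(flat[:, None] - lut[None, :]).argmin(axis=1)
    return lut, indices


# Usage: palettize a random conv-shaped weight to 4 bits (16 centroids)
rng = np.random.default_rng(1976)
w = rng.normal(0.0, 0.1, size=(12, 1, 3, 3))
lut, idx = kmeans_lut(w, n_bits=4)
w_hat = lut[idx].reshape(w.shape)
mse = float(np.mean((w - w_hat) ** 2))
print(f"LUT size: {lut.size}, reconstruction MSE: {mse:.2e}")
```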

[32]:
from coreai_opt.palettization import (
    KMeansPalettizer,
    KMeansPalettizerConfig,
    ModuleKMeansPalettizerConfig,
    PalettizationSpec,
)
from coreai_opt.palettization.spec import PerGroupedChannelGranularity

example_inputs = (torch.randn(1, 1, 28, 28),)

For this tutorial, we will use 4-bit per-grouped-channel palettization: each group of 2 output channels gets its own look-up table of 2⁴ = 16 centroids. Refer to the PalettizationSpec reference for all options.
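Here is that arithmetic worked out for this model as a rough storage estimate. It assumes that only the two weight tensors are palettized and that each LUT entry is stored as FP16; both are assumptions for illustration, not a description of the exported file layout:

```python
import math

N_BITS = 4
GROUP_SIZE = 2
LUT_ENTRY_BYTES = 2  # assumption: LUT entries stored as FP16

# Weight shapes of the two palettized layers in MnistNetwork (axis 0 = output channels)
weight_shapes = {"conv": (12, 1, 3, 3), "linear": (10, 2352)}


def palettized_size(shape, n_bits=N_BITS, group_size=GROUP_SIZE):
    num_weights = math.prod(shape)
    num_luts = shape[0] // group_size        # one LUT per group of output channels
    index_bytes = num_weights * n_bits // 8  # n_bits per weight
    lut_bytes = num_luts * 2**n_bits * LUT_ENTRY_BYTES
    return {"num_luts": num_luts, "index_bytes": index_bytes,
            "lut_bytes": lut_bytes, "fp16_bytes": num_weights * 2}


report = {name: palettized_size(shape) for name, shape in weight_shapes.items()}
for name, r in report.items():
    print(name, r)

palettized = sum(r["index_bytes"] + r["lut_bytes"] for r in report.values())
fp16 = sum(r["fp16_bytes"] for r in report.values())
print(f"{palettized} bytes palettized vs {fp16} bytes in FP16 ({fp16 / palettized:.2f}x smaller)")
```

Under these assumptions the conv layer actually grows (246 bytes versus 216), because its 108 weights are too few to amortize six 16-entry LUTs; nearly all of the savings come from the Linear layer.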

[33]:
N_BITS = 4

KMeansPalettizerConfig takes a global_config of type ModuleKMeansPalettizerConfig, whose op_state_spec maps a state tensor name (such as "weight") to the PalettizationSpec describing how to palettize that tensor.

The PalettizationSpec controls the palette: n_bits sets the bits per index, and PerGroupedChannelGranularity(axis=0, group_size=2) gives each group of 2 output channels its own LUT — a finer-grained alternative to the default single per-tensor LUT.
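The grouping itself is just a reshape along axis 0. The NumPy sketch below splits a weight into blocks of `group_size` output channels and fits one LUT per block. To stay short it uses quantile-spaced centroids as a stand-in for k-means, and `palettize_per_group` is a hypothetical name, not a coreai-opt function:

```python
import numpy as np


def palettize_per_group(weight: np.ndarray, n_bits: int, group_size: int):
    """Give every block of `group_size` output channels (axis 0) its own LUT."""
    out_channels = weight.shape[0]
    if out_channels % group_size != 0:
        raise ValueError(f"{out_channels} output channels not divisible by {group_size}")
    # Row g holds all weights of output channels [g*group_size, (g+1)*group_size)
    groups = weight.reshape(out_channels // group_size, -1)
    k = 2**n_bits
    luts = np.empty((groups.shape[0], k))
    indices = np.empty(groups.shape, dtype=np.int64)
    for g, values in enumerate(groups):
        # Stand-in for k-means: centroids at evenly spaced quantiles of this group
        luts[g] = np.quantile(values, np.linspace(0.0, 1.0, k))
        indices[g] = np.abs(values[:, None] - luts[g][None, :]).argmin(axis=1)
    # Look up each weight's centroid in its own group's LUT
    reconstructed = np.take_along_axis(luts, indices, axis=1).reshape(weight.shape)
    return luts, indices, reconstructed


rng = np.random.default_rng(0)
conv_w = rng.normal(size=(12, 1, 3, 3))
luts, indices, conv_hat = palettize_per_group(conv_w, n_bits=4, group_size=2)
print(luts.shape, indices.shape)  # (6, 16) (6, 18): 6 LUTs, 18 weights per group
```

With `group_size=16` the same function raises for both layers of this model, since neither 12 nor 10 is divisible by 16; this is the divisibility constraint behind the tutorial’s choice of group_size=2.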

[34]:
kmeans_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
kmeans_model.eval()

weight_spec = PalettizationSpec(
    n_bits=N_BITS, granularity=PerGroupedChannelGranularity(axis=0, group_size=2)
)

kmeans_config = KMeansPalettizerConfig(
    global_config=ModuleKMeansPalettizerConfig(
        op_state_spec={"weight": weight_spec},
    )
)

kmeans_palettizer = KMeansPalettizer(kmeans_model, kmeans_config)
kmeans_prepared = kmeans_palettizer.prepare(example_inputs)
print(f"Prepared k-means palettization with n_bits={N_BITS}")
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 83.27it/s]
Prepared k-means palettization with n_bits=4

The built-in KMeansPalettizerConfig.presets.w4() preset builds a similar 4-bit per-grouped-channel config. Its default group_size is 16, which suits larger models; here we use group_size=2 because this toy model’s output-channel counts (12 and 10) aren’t divisible by 16. KMeansPalettizerConfig.presets.w4(group_size=2) would be an equivalent shortcut for the config above.

Calling prepare() returns a model with fake palettization modules inserted for its weight tensors. These simulate the effect of the LUT during the forward pass, so we can measure accuracy before exporting.
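Conceptually, a fake palettization module swaps the float weight for its LUT reconstruction inside forward() while keeping ordinary float math. The PyTorch sketch below illustrates the idea for a single Linear layer with one per-tensor LUT; `FakePalettizedLinear` is a hypothetical class, not coreai-opt’s module, and it uses quantile-spaced centroids instead of k-means for brevity:

```python
import torch
import torch.nn.functional as F
from torch import nn


class FakePalettizedLinear(nn.Module):
    """Hypothetical sketch: run a Linear layer with LUT-reconstructed weights."""

    def __init__(self, linear: nn.Linear, n_bits: int) -> None:
        super().__init__()
        self.linear = linear
        flat = linear.weight.detach().flatten()
        # Per-tensor LUT (quantile-spaced stand-in for k-means), fixed at construction
        self.register_buffer("lut", torch.quantile(flat, torch.linspace(0, 1, 2**n_bits)))
        # Index of the nearest LUT entry for every weight
        self.register_buffer("indices", (flat[:, None] - self.lut[None, :]).abs().argmin(dim=1))

    def palettized_weight(self) -> torch.Tensor:
        return self.lut[self.indices].view_as(self.linear.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Same math as nn.Linear, but with the palettized weight
        return F.linear(x, self.palettized_weight(), self.linear.bias)


torch.manual_seed(0)
layer = nn.Linear(32, 10)
fake = FakePalettizedLinear(layer, n_bits=4)
x = torch.randn(8, 32)
with torch.no_grad():
    max_diff = (fake(x) - layer(x)).abs().max().item()
print(f"distinct weight values: {fake.palettized_weight().unique().numel()}, "
      f"max output diff: {max_diff:.4f}")
```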

[35]:
kmeans_accuracy = eval_model(kmeans_prepared, test_loader)
print(f"K-means palettization accuracy: {kmeans_accuracy:.4f}")
K-means palettization accuracy: 0.9787

Sensitive K-Means Palettization

Now, let’s try Sensitive K-Means (SKM) palettization through coreai-opt.

SKM uses calibration data to compute a per-weight importance score (sensitivity). Following the SqueezeLLM method, it runs a backward pass and collects the squared gradients of each weight as its sensitivity. The k-means clustering is then weighted by these sensitivities, moving centroids closer to the weights that matter most for the model’s loss.
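The effect of sensitivity weighting is easiest to see in the centroid update: instead of the plain mean of its assigned weights, each centroid becomes their sensitivity-weighted mean. A minimal NumPy sketch of weighted 1-D k-means follows; `weighted_kmeans_1d` is a hypothetical helper, and the values and sensitivities are made up:

```python
import numpy as np


def weighted_kmeans_1d(values, sensitivities, k: int, n_iters: int = 50) -> np.ndarray:
    """1-D Lloyd's k-means where each value pulls its centroid in proportion
    to its sensitivity. Hypothetical helper for illustration."""
    values = np.asarray(values, dtype=np.float64)
    sens = np.asarray(sensitivities, dtype=np.float64)
    centroids = np.quantile(values, np.linspace(0.0, 1.0, k))
    for _ in range(n_iters):
        assign = np.abs(values[:, None] - centroids[None, :]).argmin(axis=1)
        for j in range(k):
            mask = assign == j
            # Skip empty clusters and clusters whose members all have zero sensitivity
            if sens[mask].sum() > 0:
                centroids[j] = np.average(values[mask], weights=sens[mask])
    return centroids


weights = np.array([0.0, 0.2, 1.0])
# Equal sensitivities: the single centroid is the plain mean, 0.4
print(weighted_kmeans_1d(weights, [1.0, 1.0, 1.0], k=1))
# Last weight 8x as sensitive: centroid moves toward it, (0.2 + 8.0) / 10 = 0.82
print(weighted_kmeans_1d(weights, [1.0, 1.0, 8.0], k=1))
```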

[36]:
skm_model = MnistNetwork(num_classes=10, state_dict=pretrained_state_dict)
skm_model.eval()

skm_config = KMeansPalettizerConfig(
    global_config=ModuleKMeansPalettizerConfig(
        op_state_spec={"weight": weight_spec},
    )
)

skm_palettizer = KMeansPalettizer(skm_model, skm_config)
skm_prepared = skm_palettizer.prepare(example_inputs)
print(f"Prepared Sensitive K-Means palettization with n_bits={N_BITS}")
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 82.94it/s]
Prepared Sensitive K-Means palettization with n_bits=4

The config is identical to the k-means one; what changes is that we now enter calibration_mode() with a loss function and feed it representative inputs.

For each batch, skm.step(output, target) computes the loss and runs a backward pass so the palettizer can collect squared gradients as sensitivities. When the context manager exits, the LUTs are recomputed using weighted k-means based on those sensitivities.
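For intuition about what is gathered during calibration, here is a plain-PyTorch sketch that accumulates squared gradients over a few batches. It is not coreai-opt’s implementation: `collect_squared_grads` is a hypothetical helper, and it squares each batch’s gradient, a coarser statistic than squaring per-sample gradients (which of the two coreai-opt uses is not covered here):

```python
import torch
import torch.nn.functional as F
from torch import nn


def collect_squared_grads(model: nn.Module, batches, loss_fn=F.cross_entropy) -> dict:
    """Sum of squared parameter gradients over calibration batches (hypothetical helper)."""
    sensitivities = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    model.eval()
    for inputs, targets in batches:
        model.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()  # populates p.grad for every parameter
        for name, p in model.named_parameters():
            if p.grad is not None:
                sensitivities[name] += p.grad.detach() ** 2
    model.zero_grad()
    return sensitivities


torch.manual_seed(0)
tiny = nn.Sequential(nn.Flatten(), nn.Linear(4, 3))
calib = [(torch.randn(8, 2, 2), torch.randint(0, 3, (8,))) for _ in range(2)]
sens = collect_squared_grads(tiny, calib)
print({name: tuple(s.shape) for name, s in sens.items()})
```

Each entry has the same shape as its parameter, giving one non-negative sensitivity score per weight that a weighted k-means step can use as per-point weights.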

[37]:
import torch.nn.functional as F

NUM_CALIBRATION_BATCHES = 1

with skm_palettizer.calibration_mode(loss_fn=F.cross_entropy) as skm:
    for i, (data, target) in enumerate(train_loader):
        if i >= NUM_CALIBRATION_BATCHES:
            break
        output = skm_prepared(data)
        skm.step(output, target)

print(
    f"Calibrated with {NUM_CALIBRATION_BATCHES} batch(es) ({NUM_CALIBRATION_BATCHES * BATCH_SIZE} samples)"
)
Palettizing layers (num_workers=1): 100%|██████████| 2/2 [00:00<00:00, 98.55it/s]
Calibrated with 1 batch(es) (128 samples)

[38]:
skm_accuracy = eval_model(skm_prepared, test_loader)
print(f"Sensitive K-Means palettization accuracy: {skm_accuracy:.4f}")
Sensitive K-Means palettization accuracy: 0.9783

Export to Core AI

Once the palettized model is ready, call finalize() to convert the fake palettization modules into the real LUT-based representation for deployment.
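Part of what makes the deployed form smaller is that each index occupies only n_bits. The NumPy sketch below packs 4-bit indices two per byte and unpacks them again; it is a generic illustration with hypothetical helper names and does not describe the actual layout of the .aimodel format:

```python
import numpy as np


def pack_4bit(indices) -> np.ndarray:
    """Pack 4-bit indices (0..15) two per byte, low nibble first."""
    idx = np.asarray(indices, dtype=np.uint8).ravel()
    if idx.max(initial=0) > 15:
        raise ValueError("4-bit indices must be in 0..15")
    if idx.size % 2:
        idx = np.append(idx, np.uint8(0))  # pad to an even count
    return (idx[0::2] | (idx[1::2] << 4)).astype(np.uint8)


def unpack_4bit(packed: np.ndarray, count: int) -> np.ndarray:
    """Inverse of pack_4bit; `count` drops any padding nibble."""
    out = np.empty(packed.size * 2, dtype=np.uint8)
    out[0::2] = packed & 0x0F
    out[1::2] = packed >> 4
    return out[:count]


indices = np.array([1, 15, 0, 7, 3], dtype=np.uint8)
packed = pack_4bit(indices)
print(packed, packed.nbytes)  # 5 indices fit in 3 bytes
assert np.array_equal(unpack_4bit(packed, indices.size), indices)
```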

Pass ExportBackend.CoreAI to finalize(backend=...) to target the .aimodel format produced by coreai-torch.

We’ll export the Sensitive K-Means model from the previous section.

[39]:
from coreai_opt import ExportBackend

coreai_model = skm_palettizer.finalize(backend=ExportBackend.CoreAI)

The export proceeds in three steps:

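  • Trace the model with torch.export.export() to obtain a graph representation, then decompose it with run_decompositions(get_decomp_table()) from coreai_torch.

  • Apply cast_to_16_bit_precision() to cast the remaining FP32 parameters to FP16, halving their storage and improving on-device performance.

  • Convert the exported program to Core AI format with coreai_torch.TorchConverter, optimize it, and save it as an .aimodel asset.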
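The casting step trades a little precision for half the storage. This NumPy sketch shows that effect on a small FP32 tensor standing in for a parameter left unpalettized (for example, a bias vector; exactly which tensors cast_to_16_bit_precision() touches is not shown here):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for an FP32 parameter that was not palettized, e.g. a 10-element bias
bias_fp32 = rng.normal(0.0, 0.1, size=10).astype(np.float32)
bias_fp16 = bias_fp32.astype(np.float16)

print(bias_fp32.nbytes, bias_fp16.nbytes)  # 40 20
# FP16 keeps 11 significant bits, so rounding error is tiny relative to these values
max_abs_err = float(np.max(np.abs(bias_fp16.astype(np.float32) - bias_fp32)))
print(f"max abs rounding error: {max_abs_err:.2e}")
```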

[40]:
import shutil

from coreai_opt.casting import cast_to_16_bit_precision
from coreai_torch import TorchConverter, get_decomp_table

exported_program = torch.export.export(coreai_model, example_inputs, strict=False)
exported_program = exported_program.run_decompositions(get_decomp_table())
cast_to_16_bit_precision(exported_program)

coreai_program = TorchConverter().add_exported_program(exported_program).to_coreai()
coreai_program.optimize()

output_path = Path(SAVE_DIRECTORY) / "exported_model.aimodel"
if output_path.exists():
    shutil.rmtree(output_path)
coreai_program.save_asset(output_path)
print(f"Exported: {output_path}")
Exported: exported_model.aimodel