# Palettization Using Differentiable K-Means

In this tutorial, you learn how to palettize a network trained on MNIST using DKMPalettizer.

## Defining the Network and Dataset

```python
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

def mnist_net(num_classes=10):
    return nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(1, 32, 5, padding='same')),
        ('relu1', nn.ReLU()),
        ('pool1', nn.MaxPool2d(2, stride=2, padding=0)),
        ('bn1', nn.BatchNorm2d(32, eps=0.001, momentum=0.01)),
        ('conv2', nn.Conv2d(32, 64, 5, padding='same')),
        ('relu2', nn.ReLU()),
        ('pool2', nn.MaxPool2d(2, stride=2, padding=0)),
        ('flatten', nn.Flatten()),
        ('dense1', nn.Linear(3136, 1024)),
        ('relu3', nn.ReLU()),
        ('dropout', nn.Dropout(p=0.4)),
        ('dense2', nn.Linear(1024, num_classes)),
        ('softmax', nn.LogSoftmax(dim=1))]))
```


For training, use the MNIST dataset provided by PyTorch. Apply a very simple transformation to the input images to normalize them.

```python
import os

from torchvision import datasets, transforms

def mnist_dataset(data_dir="~/.mnist_palettization_data"):
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    data_path = os.path.expanduser(f"{data_dir}/mnist")
    if not os.path.exists(data_path):
        os.makedirs(data_path)
    train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
    test = datasets.MNIST(data_path, train=False, transform=transform)
    return train, test
```


Initialize the model, the dataset, and the data loaders.

```python
model = mnist_net()

batch_size = 128
train_dataset, test_dataset = mnist_dataset("~/.mnist_data/mnist_palettization")

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
```


## Training the Model Without Palettization

Train the model without applying any palettization.

```python
optimizer = torch.optim.SGD(model.parameters(), lr=0.008)
accuracy_unpalettized = 0.0
num_epochs = 2

def train_step(model, optimizer, train_loader, data, target, batch_idx, epoch, palettizer=None):
    optimizer.zero_grad()
    if palettizer is not None:
        palettizer.step()
    output = model(data)
    loss = F.nll_loss(output, target)
    loss.backward()
    optimizer.step()
    if batch_idx % 100 == 0:
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))

def eval_model(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {:.1f}%\n".format(
            test_loss, accuracy
        )
    )
    return accuracy

for epoch in range(num_epochs):
    # train one epoch
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)

    # evaluate
    accuracy_unpalettized = eval_model(model, test_loader)

print("Accuracy of unpalettized network: {:.1f}%\n".format(accuracy_unpalettized))
```


## Configuring Palettization

Insert palettization layers into the trained model. For this example, apply 4-bit palettization to the conv2 layer, which maps each weight element in that layer to one of $$2^4$$ = 16 clusters.

Note that calling prepare() simply inserts palettization layers into the model. It doesn’t actually palettize the weights. You do that in the next step when you fine-tune the model.
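To build intuition for the representation the palettizer will learn, here is a toy sketch of the palette idea. It is plain PyTorch, not the DKM algorithm, and uses a hand-picked 2-bit look-up table (LUT) to stay small: each weight element is mapped to its nearest palette entry, yielding an integer index tensor plus the LUT.

```python
import torch

# Toy sketch of the palette idea (not the DKM algorithm): an n-bit palette
# stores at most 2**n float values in a look-up table (LUT), and each weight
# element is replaced by a small integer index into that LUT.
n_bits = 2                                    # 2 bits -> 4 clusters, to stay tiny
lut = torch.tensor([-0.5, -0.1, 0.2, 0.7])    # 2**2 = 4 cluster centers
weights = torch.tensor([[0.68, -0.48],
                        [0.19, -0.12]])

# Map every weight element to its nearest palette entry.
indices = torch.argmin((weights.unsqueeze(-1) - lut).abs(), dim=-1)
palettized = lut[indices]

print(indices)     # each entry fits in n_bits
print(palettized)  # the weights the palettized layer actually uses
```

DKM makes this clustering differentiable, so the LUT values and assignments can be learned during fine-tuning rather than fixed up front as in this sketch.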

```python
from coremltools.optimize.torch.palettization import DKMPalettizer, DKMPalettizerConfig

config = DKMPalettizerConfig.from_dict(
    {"module_name_configs": {"conv2": {"n_bits": 4}}}
)
palettizer = DKMPalettizer(model, config)

prepared_model = palettizer.prepare()
```


## Fine-Tuning the Palettized Model

Fine-tune the model with palettization applied. This lets the model learn the weights of the palettized layers in the form of a LUT and indices.

```python
optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.008)
accuracy_palettized = 0.0
num_epochs = 2

for epoch in range(num_epochs):
    prepared_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        train_step(prepared_model, optimizer, train_loader, data, target, batch_idx, epoch, palettizer)

    # evaluate on the test set
    prepared_model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            pred = prepared_model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy_palettized = 100. * correct / len(test_loader.dataset)
```


The evaluation shows that you can palettize the network without losing much accuracy relative to the unpalettized model.

```python
print("Accuracy of unpalettized network: {:.1f}%\n".format(accuracy_unpalettized))
print("Accuracy of palettized network: {:.1f}%\n".format(accuracy_palettized))
```


## Restoring LUT and Indices as Weights

Use finalize() to restore the LUT and indices of the palettized modules as weights in the model.

```python
finalized_model = palettizer.finalize()
```
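As a quick sanity check, you can count the distinct values in a weight tensor; after n-bit palettization there should be at most $$2^n$$ of them. The helper below, count_palette, is our own illustration, not a coremltools API, and is demonstrated on a synthetic tensor:

```python
import torch

def count_palette(weight: torch.Tensor) -> int:
    # Number of distinct values in a weight tensor; after n-bit
    # palettization this should be at most 2 ** n.
    return torch.unique(weight).numel()

# Demonstrate on a synthetic tensor whose entries all come from a 4-entry LUT.
lut = torch.tensor([-0.3, -0.1, 0.1, 0.3])
w = lut[torch.randint(0, 4, (8, 8))]
print(count_palette(w))  # at most 4

# In this tutorial, the analogous check would be:
# assert count_palette(finalized_model.conv2.weight) <= 2 ** 4
```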


## Exporting the Model for On-Device Execution

To deploy the model on device, convert it to a Core ML model.

To export the model with Core ML Tools, first trace the model with an input, and then use the Core ML Tools converter, as described in Converting from PyTorch. The parameter ct.PassPipeline.DEFAULT_PALETTIZATION signals to the converter that a palettized model is being converted, allowing its weights to be represented with a look-up table (LUT) and indices, which have a much smaller footprint on disk than the dense weights.
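The size win is easy to estimate by hand. The back-of-the-envelope calculation below is illustrative only; it assumes float32 dense weights and a single global LUT, and ignores container overhead:

```python
# Rough footprint estimate for this tutorial's conv2 layer
# (32 -> 64 channels, 5x5 kernels), assuming float32 dense weights.
n_weights = 64 * 32 * 5 * 5                 # 51,200 weight elements
dense_bytes = n_weights * 4                 # float32: 4 bytes per element

n_bits = 4
lut_entries = 2 ** n_bits                   # 16 palette values
# 4-bit index per element, plus the float32 LUT itself.
palettized_bytes = n_weights * n_bits // 8 + lut_entries * 4

print(dense_bytes, palettized_bytes)        # roughly an 8x reduction
```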

```python
import coremltools as ct

finalized_model.eval()
example_input = torch.rand(1, 1, 28, 28)
traced_model = torch.jit.trace(finalized_model, example_input)

coreml_model = ct.convert(
    traced_model,
    inputs=[ct.TensorType(shape=example_input.shape)],
    pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION,
    minimum_deployment_target=ct.target.iOS16,
)

coreml_model.save(
    os.path.expanduser("~/.mnist_palettization_data/palettized_model.mlpackage")
)
```

