.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples/dkm_palettization.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_dkm_palettization.py:

.. _palettization_tutorial:

Palettization Using Differentiable K-Means
==========================================

.. GENERATED FROM PYTHON SOURCE LINES 11-17

In this tutorial, you learn how to palettize a network trained on MNIST using
:py:class:`~.palettizer.DKMPalettizer`. Learn more about other palettization
options in the coremltools Training-Time Palettization documentation.

.. GENERATED FROM PYTHON SOURCE LINES 20-24

Defining the Network and Dataset
--------------------------------

First, define your network:

.. GENERATED FROM PYTHON SOURCE LINES 24-49

.. code-block:: default

    from collections import OrderedDict

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def mnist_net(num_classes=10):
        return nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(1, 32, 5, padding='same')),
            ('relu1', nn.ReLU()),
            ('pool1', nn.MaxPool2d(2, stride=2, padding=0)),
            ('bn1', nn.BatchNorm2d(32, eps=0.001, momentum=0.01)),
            ('conv2', nn.Conv2d(32, 64, 5, padding='same')),
            ('relu2', nn.ReLU()),
            ('pool2', nn.MaxPool2d(2, stride=2, padding=0)),
            ('flatten', nn.Flatten()),
            ('dense1', nn.Linear(3136, 1024)),
            ('relu3', nn.ReLU()),
            ('dropout', nn.Dropout(p=0.4)),
            ('dense2', nn.Linear(1024, num_classes)),
            # dim=1 makes the log-softmax explicit over the class dimension;
            # calling nn.LogSoftmax() without dim is deprecated in PyTorch.
            ('softmax', nn.LogSoftmax(dim=1))]))

.. GENERATED FROM PYTHON SOURCE LINES 50-53

For training, use the MNIST dataset provided by PyTorch's ``torchvision``
package. Apply a simple transformation to normalize the input images.

.. GENERATED FROM PYTHON SOURCE LINES 53-71

.. code-block:: default

    import os

    from torchvision import datasets, transforms


    def mnist_dataset(data_dir="~/.mnist_palettization_data"):
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        data_path = os.path.expanduser(f"{data_dir}/mnist")
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
        test = datasets.MNIST(data_path, train=False, transform=transform)
        return train, test

.. GENERATED FROM PYTHON SOURCE LINES 72-73

Initialize the model and the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 73-81

.. code-block:: default

    model = mnist_net()

    batch_size = 128
    train_dataset, test_dataset = mnist_dataset("~/.mnist_data/mnist_palettization")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
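Before training, you can optionally sanity-check the pipeline by pulling a
single batch from the loader and running a forward pass. This check is not
part of the original tutorial; the expected shapes follow from the network
definition and batch size above.

.. code-block:: python

    # Optional sanity check (not in the original tutorial): confirm that one
    # batch flows through the untrained network with the expected shapes.
    model.eval()  # avoid updating BatchNorm running stats during this check
    data, target = next(iter(train_loader))
    with torch.no_grad():
        output = model(data)
    # torch.Size([128, 1, 28, 28]) and torch.Size([128, 10]):
    # one 28x28 grayscale image per sample, one log-probability per class.
    print(data.shape, output.shape)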
.. GENERATED FROM PYTHON SOURCE LINES 82-86

Training the Model Without Palettization
----------------------------------------

Train the model without applying any palettization.

.. GENERATED FROM PYTHON SOURCE LINES 86-139

.. code-block:: default

    optimizer = torch.optim.SGD(model.parameters(), lr=0.008)
    accuracy_unpalettized = 0.0
    num_epochs = 2


    def train_step(model, optimizer, train_loader, data, target, batch_idx, epoch, palettizer=None):
        optimizer.zero_grad()
        if palettizer is not None:
            palettizer.step()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


    def eval_model(model, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {:.1f}%\n".format(
                test_loss, accuracy
            )
        )
        return accuracy


    for epoch in range(num_epochs):
        # train one epoch
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)

        # evaluate
        accuracy_unpalettized = eval_model(model, test_loader)

    print("Accuracy of unpalettized network: {:.1f}%\n".format(accuracy_unpalettized))

.. GENERATED FROM PYTHON SOURCE LINES 140-152

Configuring Palettization
-------------------------

Insert palettization layers into the trained model. For this example, apply
``4-bit`` palettization to the ``conv2`` layer, which maps each weight element
in that layer to one of :math:`2^4 = 16` cluster centers.

Note that calling :py:meth:`~.palettization.DKMPalettizer.prepare` simply
inserts palettization layers into the model. It doesn't actually palettize the
weights; that happens in the next step, when you fine-tune the model.

.. GENERATED FROM PYTHON SOURCE LINES 152-162

.. code-block:: default

    from coremltools.optimize.torch.palettization import DKMPalettizer, DKMPalettizerConfig

    config = DKMPalettizerConfig.from_dict(
        {"module_name_configs": {"conv2": {"n_bits": 4}}}
    )

    palettizer = DKMPalettizer(model, config)
    prepared_model = palettizer.prepare()
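The config above targets a single layer by name. As a sketch of what the same
``from_dict`` API supports, you could instead scope palettization by module
type. The ``module_type_configs`` key and the 2-bit setting below are
illustrative assumptions based on the ``DKMPalettizerConfig`` API, not
settings used by this tutorial.

.. code-block:: python

    # Sketch only (not used in this tutorial): palettize every Linear layer
    # with 2 bits while keeping the 4-bit setting for conv2.
    variant_config = DKMPalettizerConfig.from_dict({
        "module_type_configs": {"Linear": {"n_bits": 2}},
        "module_name_configs": {"conv2": {"n_bits": 4}},
    })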
.. GENERATED FROM PYTHON SOURCE LINES 163-168

Fine-Tuning the Palettized Model
--------------------------------

Fine-tune the model with palettization applied. During fine-tuning, the
palettized layers learn their weights in the form of a look-up table (LUT)
and indices.

.. GENERATED FROM PYTHON SOURCE LINES 168-181

.. code-block:: default

    optimizer = torch.optim.SGD(prepared_model.parameters(), lr=0.008)
    accuracy_palettized = 0.0
    num_epochs = 2

    for epoch in range(num_epochs):
        prepared_model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(prepared_model, optimizer, train_loader, data, target, batch_idx, epoch, palettizer)

        # evaluate
        accuracy_palettized = eval_model(prepared_model, test_loader)

.. GENERATED FROM PYTHON SOURCE LINES 182-184

The evaluation shows that you can palettize the network with only a small loss
in accuracy.

.. GENERATED FROM PYTHON SOURCE LINES 184-188

.. code-block:: default

    print("Accuracy of unpalettized network: {:.1f}%\n".format(accuracy_unpalettized))
    print("Accuracy of palettized network: {:.1f}%\n".format(accuracy_palettized))

.. GENERATED FROM PYTHON SOURCE LINES 189-194

Restoring LUT and Indices as Weights
------------------------------------

Use :py:meth:`~.palettization.Palettizer.finalize` to restore the LUT and
indices of the palettized modules as dense weights in the model.

.. GENERATED FROM PYTHON SOURCE LINES 194-197

.. code-block:: default

    finalized_model = palettizer.finalize()

.. GENERATED FROM PYTHON SOURCE LINES 198-210

Exporting the Model for On-Device Execution
-------------------------------------------

To deploy the model on device, convert it to a Core ML model: first trace the
model with an example input, and then use the Core ML Tools converter, as
described in the coremltools documentation on converting from PyTorch. Passing
``pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION`` signals to the
converter that a palettized model is being converted, and allows its weights
to be represented using a look-up table (LUT) and indices, which have a much
smaller footprint on disk than dense weights.

.. GENERATED FROM PYTHON SOURCE LINES 210-226

.. code-block:: default

    import coremltools as ct

    finalized_model.eval()
    example_input = torch.rand(1, 1, 28, 28)
    traced_model = torch.jit.trace(finalized_model, example_input)

    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(shape=example_input.shape)],
        pass_pipeline=ct.PassPipeline.DEFAULT_PALETTIZATION,
        minimum_deployment_target=ct.target.iOS16,
    )

    # Expand "~" explicitly; MLModel.save does not perform tilde expansion.
    coreml_model.save(os.path.expanduser("~/.mnist_palettization_data/palettized_model.mlpackage"))
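As a final optional check (not part of the original tutorial), you can confirm
that palettization took effect by counting the distinct values in the
finalized ``conv2`` weight; with 4-bit palettization there should be at most
16. This sketch assumes the finalized model preserves the ``conv2`` submodule
name from the original network.

.. code-block:: python

    # Optional verification sketch: a 4-bit palettized layer should contain
    # at most 2**4 = 16 distinct weight values after finalize().
    n_unique = torch.unique(finalized_model.conv2.weight).numel()
    print(f"Distinct values in conv2 weight: {n_unique}")  # expect <= 16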