.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples/linear_quantization.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_linear_quantization.py:


.. _linear_quantization_tutorial:

Linear Quantization
===================

.. GENERATED FROM PYTHON SOURCE LINES 11-17

In this tutorial, you learn how to train a simple convolutional neural network on
`MNIST `_ using :py:class:`~.quantization.LinearQuantizer`.

Learn more about other quantization techniques in the coremltools
`Training-Time Quantization Documentation `_.

.. GENERATED FROM PYTHON SOURCE LINES 19-23

Network and Dataset Definition
------------------------------
First define your network, which consists of a single convolution layer
followed by a dense (linear) layer.

.. GENERATED FROM PYTHON SOURCE LINES 23-47

.. code-block:: default

    from collections import OrderedDict

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def mnist_net(num_classes=10):
        return nn.Sequential(
            OrderedDict(
                [
                    ("conv", nn.Conv2d(1, 12, 3, padding=1)),
                    ("relu", nn.ReLU()),
                    ("pool", nn.MaxPool2d(2, stride=2, padding=0)),
                    ("flatten", nn.Flatten()),
                    # 12 channels x 14 x 14 (28x28 input after 2x2 pooling) = 2352 features
                    ("dense", nn.Linear(2352, num_classes)),
                    # Log-probabilities over the class dimension, as F.nll_loss expects
                    ("softmax", nn.LogSoftmax(dim=1)),
                ]
            )
        )

.. GENERATED FROM PYTHON SOURCE LINES 48-51

Use the `MNIST dataset provided by PyTorch `_
for training. Apply a very simple transformation to the input
images to normalize them.

.. GENERATED FROM PYTHON SOURCE LINES 51-69

.. code-block:: default

    import os

    from torchvision import datasets, transforms


    def mnist_dataset(data_dir="~/.mnist_qat_data"):
        # Convert images to tensors, then normalize with the MNIST mean and std
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

        data_path = os.path.expanduser(f"{data_dir}/mnist")
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
        test = datasets.MNIST(data_path, train=False, transform=transform)
        return train, test

.. GENERATED FROM PYTHON SOURCE LINES 70-71

Next, initialize the model and the dataset.

.. GENERATED FROM PYTHON SOURCE LINES 71-81

.. code-block:: default

    model = mnist_net()

    batch_size = 128
    train_dataset, test_dataset = mnist_dataset()
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

.. GENERATED FROM PYTHON SOURCE LINES 82-85

Training the Model Without Quantization
---------------------------------------
Train the model without any quantization applied.

.. GENERATED FROM PYTHON SOURCE LINES 85-143

.. code-block:: default

    optimizer = torch.optim.Adam(model.parameters(), eps=1e-07)
    accuracy_unquantized = 0.0
    num_epochs = 4


    def train_step(model, optimizer, train_loader, data, target, batch_idx, epoch):
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        # Log progress every 100 batches
        if batch_idx % 100 == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


    def eval_model(model, test_loader):
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                # Sum the per-sample losses so they can be averaged over the whole test set
                test_loss += F.nll_loss(output, target, reduction="sum").item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100.0 * correct / len(test_loader.dataset)

        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {:.1f}%\n".format(
                test_loss, accuracy
            )
        )
        return accuracy


    for epoch in range(num_epochs):
        # train one epoch
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)

        # evaluate
        accuracy_unquantized = eval_model(model, test_loader)


    print("Accuracy of unquantized network: {:.1f}%\n".format(accuracy_unquantized))

.. GENERATED FROM PYTHON SOURCE LINES 144-155

Insert Quantization Layers in the Model
---------------------------------------

Install :py:class:`~.quantization.LinearQuantizer` in the trained model.

Create an instance of the :py:class:`~.quantization.LinearQuantizerConfig` class
to specify quantization parameters.
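
For intuition about what the quantizer simulates: linear (affine) quantization maps a
float value ``x`` to an integer ``q = clip(round(x / scale) + zero_point, qmin, qmax)``
and back to ``x_hat = (q - zero_point) * scale``. Fake quantization layers perform this
quantize-then-dequantize round trip during training, so the network learns to tolerate the
rounding error. The NumPy sketch below illustrates the idea under stated assumptions:
unsigned 8-bit integers (``qmin=0``, ``qmax=255``), a single per-tensor ``scale`` and
``zero_point`` derived from observed min/max statistics, and the hypothetical helpers
``compute_qparams`` and ``fake_quantize``. It is not the coremltools implementation, whose
defaults (for example, signed or per-channel weight quantization) may differ.

.. code-block:: python

    import numpy as np


    def compute_qparams(x_min, x_max, qmin=0, qmax=255):
        """Hypothetical helper: per-tensor affine parameters from observed min/max."""
        # Widen the range to include 0 so that zero is exactly representable
        x_min = min(x_min, 0.0)
        x_max = max(x_max, 0.0)
        scale = (x_max - x_min) / (qmax - qmin)
        if scale == 0.0:
            scale = 1.0  # degenerate all-zero tensor
        zero_point = int(round(qmin - x_min / scale))
        zero_point = max(qmin, min(qmax, zero_point))
        return scale, zero_point


    def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
        """Hypothetical helper: quantize to integers, then map back to floats."""
        q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
        return (q - zero_point) * scale


    x = np.array([-1.0, -0.5, 0.0, 0.3, 2.0], dtype=np.float32)
    scale, zero_point = compute_qparams(float(x.min()), float(x.max()))
    x_fq = fake_quantize(x, scale, zero_point)

    print("scale:", scale, "zero_point:", zero_point)
    print("max abs error:", np.abs(x - x_fq).max())

Values inside the observed range come back within half a quantization step
(``scale / 2``) of the original, and zero maps back to exactly zero.
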

In the configuration below, ``milestones=[0, 1, 2, 1]`` refers to the following:

* *Index 0*: At step 0, observers start collecting statistics of the values of the tensors being quantized.
* *Index 1*: At step 1, quantization simulation begins.
* *Index 2*: At step 2, observers stop collecting statistics and the quantization parameters are frozen.
* *Index 3*: At step 1, batch normalization layers stop updating their running mean and variance and start running in inference mode.

The milestones are counted in calls to :py:meth:`~.quantization.LinearQuantizer.step`, not in epochs.

.. GENERATED FROM PYTHON SOURCE LINES 155-168

.. code-block:: default

    from coremltools.optimize.torch.quantization import (
        LinearQuantizer,
        LinearQuantizerConfig,
        ModuleLinearQuantizerConfig,
    )

    global_config = ModuleLinearQuantizerConfig(milestones=[0, 1, 2, 1])
    config = LinearQuantizerConfig(global_config=global_config)

    quantizer = LinearQuantizer(model, config)

.. GENERATED FROM PYTHON SOURCE LINES 169-171

Next, call :py:meth:`~.quantization.LinearQuantizer.prepare` to insert fake quantization
layers in the model.

.. GENERATED FROM PYTHON SOURCE LINES 171-174

.. code-block:: default

    qmodel = quantizer.prepare(example_inputs=torch.randn(1, 1, 28, 28))

.. GENERATED FROM PYTHON SOURCE LINES 175-180

Fine-Tuning the Model
---------------------
The next step is to fine-tune the model with quantization applied. Call
:py:meth:`~.quantization.LinearQuantizer.step` to step through the
quantization milestones.

.. GENERATED FROM PYTHON SOURCE LINES 180-195

.. code-block:: default

    optimizer = torch.optim.Adam(qmodel.parameters(), eps=1e-07)
    accuracy_quantized = 0.0
    num_epochs = 4

    for epoch in range(num_epochs):
        # Train one epoch. Put qmodel (not the original model) in training mode,
        # because eval_model() switches it to eval mode at the end of every epoch.
        qmodel.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            # Advance the quantization milestones once per batch
            quantizer.step()
            train_step(qmodel, optimizer, train_loader, data, target, batch_idx, epoch)

        # evaluate
        accuracy_quantized = eval_model(qmodel, test_loader)

.. GENERATED FROM PYTHON SOURCE LINES 196-201

The evaluation shows that you can train a quantized network without a significant loss
in model accuracy.
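
Because the fine-tuning loop calls ``quantizer.step()`` once per batch, the milestones
``[0, 1, 2, 1]`` are all reached within the first few batches of the first epoch. With the
60,000 MNIST training images and ``batch_size=128``, one epoch is 469 batches, so almost all
of the fine-tuning runs with frozen quantization parameters. The sketch below tabulates
that timeline with a hypothetical helper, ``milestones_reached``, which is not part of
coremltools. It assumes an event happens once the step count has reached its milestone;
the exact comparison inside coremltools may be off by one step.

.. code-block:: python

    import math

    # The four QAT events, in the order of the milestones list described above
    EVENTS = (
        "observers start",
        "fake quantization starts",
        "observers stop, qparams frozen",
        "batch norm frozen",
    )


    def milestones_reached(step_count, milestones):
        """Hypothetical helper: events whose milestone step_count has reached.

        Assumes an event occurs once step_count >= its milestone value.
        """
        return [event for event, m in zip(EVENTS, milestones) if step_count >= m]


    milestones = [0, 1, 2, 1]
    batches_per_epoch = math.ceil(60000 / 128)  # MNIST training set, batch_size=128

    for step in range(4):
        print(step, milestones_reached(step, milestones))
    print("batches per epoch:", batches_per_epoch)

To have the milestones count epochs instead, the same helper logic applies with
``step_count`` incremented once per epoch rather than once per batch.
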

In practice, quantization of more complex models can be lossy and degrade validation
accuracy. In such cases, you can choose not to quantize certain layers that are less
amenable to quantization.

.. GENERATED FROM PYTHON SOURCE LINES 201-208

.. code-block:: default

    print("Accuracy of quantized network: {:.1f}%\n".format(accuracy_quantized))
    print("Accuracy of unquantized network: {:.1f}%\n".format(accuracy_unquantized))

    # Check that quantization costs at most 2 percentage points of accuracy
    np.testing.assert_allclose(accuracy_quantized, accuracy_unquantized, atol=2)

.. GENERATED FROM PYTHON SOURCE LINES 209-218

Finalizing the Model for Export
-------------------------------

The example shows that you can quantize the model with a few code changes to your
existing PyTorch training code. Now you can deploy this model on a device.

To finalize the model for export, call :py:meth:`~.quantization.LinearQuantizer.finalize`
on the quantizer. This folds the quantization parameters, such as scale and zero point,
into the weights.

.. GENERATED FROM PYTHON SOURCE LINES 218-222

.. code-block:: default

    qmodel.eval()
    quantized_model = quantizer.finalize()

.. GENERATED FROM PYTHON SOURCE LINES 223-232

Exporting the Model for On-Device Execution
-------------------------------------------

To deploy the model, convert it to a Core ML model.

Follow the same steps in Core ML Tools as for exporting a regular PyTorch model
(for details, see `Converting from PyTorch `_).
Setting ``minimum_deployment_target=ct.target.iOS17`` is necessary here because
activation quantization ops are supported only on iOS 17 and later.

.. GENERATED FROM PYTHON SOURCE LINES 232-245

.. code-block:: default

    import coremltools as ct

    example_input = torch.rand(1, 1, 28, 28)
    traced_model = torch.jit.trace(quantized_model, example_input)

    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(shape=example_input.shape)],
        minimum_deployment_target=ct.target.iOS17,
    )

    coreml_model.save("~/.mnist_qat_data/quantized_model.mlpackage")


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.000 seconds)


.. _sphx_glr_download__examples_linear_quantization.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: linear_quantization.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: linear_quantization.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_