.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "_examples/magnitude_pruning.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr__examples_magnitude_pruning.py:

.. _magnitude_pruning_tutorial:

Magnitude Pruning
=================

.. GENERATED FROM PYTHON SOURCE LINES 11-17

In this tutorial, you learn how to train a simple convolutional neural network on
`MNIST `_ using :py:class:`~.pruning.MagnitudePruner`. Learn more about other
pruners and schedulers in the coremltools `Training-Time Pruning Documentation `_.

.. GENERATED FROM PYTHON SOURCE LINES 19-23

Network and Dataset Definition
------------------------------

First, define your network. It consists of a single convolution layer, followed by a
ReLU activation, max pooling, and a dense (linear) layer that produces log-probabilities
for the ``num_classes`` digit classes.

.. GENERATED FROM PYTHON SOURCE LINES 23-45

.. code-block:: default

    from collections import OrderedDict

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    def mnist_net(num_classes=10):
        return nn.Sequential(
            OrderedDict(
                [
                    ('conv', nn.Conv2d(1, 12, 3, padding='same')),  # 1x28x28 -> 12x28x28
                    ('relu', nn.ReLU()),
                    ('pool', nn.MaxPool2d(2, stride=2, padding=0)),  # -> 12x14x14
                    ('flatten', nn.Flatten()),  # -> 12 * 14 * 14 = 2352 features
                    ('dense', nn.Linear(2352, num_classes)),
                    # dim=1 normalizes over the class dimension (avoids the implicit-dim warning)
                    ('softmax', nn.LogSoftmax(dim=1)),
                ]
            )
        )

.. GENERATED FROM PYTHON SOURCE LINES 46-49

Use the `MNIST dataset provided by PyTorch `_ for training. Apply a simple
transformation that converts each input image to a tensor and normalizes it.
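
As a quick sanity check on the input pipeline and on the dense layer's ``2352`` input
features, the following sketch normalizes a batch the way ``transforms.Normalize`` does,
computing ``(x - mean) / std`` with the mean ``0.1307`` and standard deviation ``0.3081``
that this tutorial's dataset code uses, and then passes it through the layers of
``mnist_net()`` that come before the dense layer. Random pixel values stand in for real
MNIST images (an assumption so the snippet runs without downloading data), and the
``normalize`` helper is a hypothetical name used only for illustration.

.. code-block:: python

    from collections import OrderedDict

    import torch
    import torch.nn as nn

    # Mean and standard deviation passed to transforms.Normalize in this tutorial
    MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


    def normalize(images: torch.Tensor) -> torch.Tensor:
        """Hypothetical helper: the (x - mean) / std arithmetic Normalize applies per channel."""
        return (images - MNIST_MEAN) / MNIST_STD


    # The layers of mnist_net() before the dense layer, repeated so this block runs alone
    features = nn.Sequential(
        OrderedDict(
            [
                ("conv", nn.Conv2d(1, 12, 3, padding="same")),  # (N, 1, 28, 28) -> (N, 12, 28, 28)
                ("relu", nn.ReLU()),
                ("pool", nn.MaxPool2d(2, stride=2, padding=0)),  # -> (N, 12, 14, 14)
                ("flatten", nn.Flatten()),  # -> (N, 12 * 14 * 14) = (N, 2352)
            ]
        )
    )

    # Random pixels in [0, 1] stand in for the output of transforms.ToTensor()
    batch = normalize(torch.rand(4, 1, 28, 28))
    with torch.no_grad():
        flat = features(batch)
    print(flat.shape)  # torch.Size([4, 2352])

The ``padding='same'`` convolution keeps the 28x28 spatial size and the 2x2 max pooling
halves it, so the flattened feature vector has ``12 * 14 * 14 = 2352`` entries, which is
why the dense layer is declared as ``nn.Linear(2352, num_classes)``.
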

.. GENERATED FROM PYTHON SOURCE LINES 49-66

.. code-block:: default

    import os

    from torchvision import datasets, transforms


    def mnist_dataset(data_dir="~/.mnist_pruning_data"):
        # Convert images to tensors, then normalize with the MNIST mean and standard deviation
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        data_path = os.path.expanduser(f"{data_dir}/mnist")
        os.makedirs(data_path, exist_ok=True)
        train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
        test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
        return train, test

.. GENERATED FROM PYTHON SOURCE LINES 67-68

Next, initialize the model and the datasets, and wrap the datasets in data loaders.

.. GENERATED FROM PYTHON SOURCE LINES 68-77

.. code-block:: default

    model = mnist_net()

    batch_size = 128
    train_dataset, test_dataset = mnist_dataset()
    # Shuffle the training data every epoch; keep the test order fixed
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

.. GENERATED FROM PYTHON SOURCE LINES 78-81

Training the Model Without Pruning
----------------------------------

First, train the model without any pruning applied. Its test accuracy serves as the
baseline for the pruned model.

.. GENERATED FROM PYTHON SOURCE LINES 81-133

.. code-block:: default

    optimizer = torch.optim.Adam(model.parameters(), eps=1e-07)
    accuracy_unpruned = 0.0
    num_epochs = 4


    def train_step(model, optimizer, train_loader, data, target, batch_idx, epoch):
        # One optimization step on a single batch
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)  # the model outputs log-probabilities
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


    def eval_model(model, test_loader):
        # Print the average test loss and return the test accuracy in percent
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                output = model(data)
                test_loss += F.nll_loss(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)

        print(
            "\nTest set: Average loss: {:.4f}, Accuracy: {:.1f}%\n".format(
                test_loss, accuracy
            )
        )
        return accuracy


    for epoch in range(num_epochs):
        # train one epoch
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)

        # evaluate
        accuracy_unpruned = eval_model(model, test_loader)


    print("Accuracy of unpruned network: {:.1f}%\n".format(accuracy_unpruned))

.. GENERATED FROM PYTHON SOURCE LINES 134-146

Installing the Pruner in the Model
----------------------------------

Install :py:class:`~.pruning.MagnitudePruner` in the trained model.

First, construct a :py:class:`~.pruning.pruning_scheduler.PruningScheduler`, which
specifies how the sparsity of the pruned layers evolves over the course of training.
For this tutorial, use a :py:class:`~.pruning.PolynomialDecayScheduler`, which is based on
the gradual pruning schedule introduced in the paper `"To prune or not to prune" `_.
Begin pruning at step ``0`` and update the sparsity every ``100`` steps, up to step ``800``.
With a batch size of ``128``, one epoch over the 60,000 MNIST training images takes 469
steps, so every update happens within the two epochs of fine-tuning. As you step through
this scheduler, the sparsity of the pruned modules increases gradually from its initial
value to its target value.

.. GENERATED FROM PYTHON SOURCE LINES 146-151

.. code-block:: default

    from coremltools.optimize.torch.pruning import PolynomialDecayScheduler

    # Update the sparsity at steps 0, 100, 200, ..., 800
    scheduler = PolynomialDecayScheduler(update_steps=list(range(0, 900, 100)))
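
The schedule from "To prune or not to prune" can be written as
``s_t = s_f + (s_i - s_f) * (1 - progress) ** 3``, where ``s_i`` is the initial sparsity,
``s_f`` is the target sparsity, and ``progress`` runs from 0 at the first update to 1 at
the last one. The following pure-Python sketch evaluates that formula at the nine update
steps used above, assuming an initial sparsity of 0 and a target of 0.8 (the dense layer's
target). It also assumes that progress is the index of the update divided by the number of
updates minus one; this mapping is an illustrative assumption, and coremltools may index
its steps differently.

.. code-block:: python

    def polynomial_decay_sparsity(
        update_index: int,
        num_updates: int,
        initial_sparsity: float = 0.0,
        target_sparsity: float = 0.8,
        power: int = 3,
    ) -> float:
        """Sparsity after the update_index-th scheduled update (0-based).

        Illustrative sketch of the polynomial decay formula; the progress
        mapping is an assumption, not the coremltools implementation.
        """
        if num_updates <= 1:
            return target_sparsity
        progress = min(update_index / (num_updates - 1), 1.0)
        return target_sparsity + (initial_sparsity - target_sparsity) * (1.0 - progress) ** power


    # Same update steps as the scheduler above: 0, 100, ..., 800
    update_steps = list(range(0, 900, 100))
    for i, step in enumerate(update_steps):
        sparsity = polynomial_decay_sparsity(i, len(update_steps))
        print(f"step {step:3d}: sparsity {sparsity:.3f}")

The printed values rise quickly at first and level off near the target. This matches the
motivation in the paper: prune rapidly while redundant weights are abundant, then prune
fewer weights per update as fewer remain.
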

.. GENERATED FROM PYTHON SOURCE LINES 152-159

Next, create an instance of the :py:class:`~.pruning.MagnitudePrunerConfig` class to
specify how you want different submodules to be pruned. Set the target sparsity of the
convolution layer to ``70 %`` and that of the dense layer to ``80 %``. This shows that
different layers can be pruned to different sparsity levels. In practice, the sparsity
level of each layer is a hyperparameter that you need to tune for your requirements and
for how well the layer tolerates sparsification. Passing the scheduler created above to
each module configuration makes the sparsity of those modules follow the polynomial decay
schedule.

.. GENERATED FROM PYTHON SOURCE LINES 159-175

.. code-block:: default

    from coremltools.optimize.torch.pruning import (
        MagnitudePruner,
        MagnitudePrunerConfig,
        ModuleMagnitudePrunerConfig,
    )

    # Per-module-type settings: 70% sparsity for Conv2d and 80% for Linear,
    # both reached gradually according to `scheduler`
    conv_config = ModuleMagnitudePrunerConfig(target_sparsity=0.7, scheduler=scheduler)
    linear_config = ModuleMagnitudePrunerConfig(target_sparsity=0.8, scheduler=scheduler)

    config = MagnitudePrunerConfig().set_module_type(torch.nn.Conv2d, conv_config)
    config = config.set_module_type(torch.nn.Linear, linear_config)

    pruner = MagnitudePruner(model, config)

.. GENERATED FROM PYTHON SOURCE LINES 176-182

Next, call :py:meth:`~.pruning.MagnitudePruner.prepare` to insert pruning
``forward pre hooks`` into the modules configured above. A forward pre hook is called
before each call to the module's ``forward`` method. These hooks multiply the parameter
by a pruning mask: a tensor with the same shape as the parameter, in which each element
is either ``1`` or ``0``.

.. GENERATED FROM PYTHON SOURCE LINES 182-185

.. code-block:: default

    pruner.prepare(inplace=True)

.. GENERATED FROM PYTHON SOURCE LINES 186-191

Fine-Tuning the Pruned Model
----------------------------

Next, fine-tune the model with pruning applied. To step through the pruning schedule,
call :py:meth:`~.pruning.MagnitudePruner.step` on the pruner after every call to
``optimizer.step()``.
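
To make the roles of ``prepare()``, ``step()``, and ``finalize()`` more concrete, the
following sketch hand-rolls the same idea with plain PyTorch on a single ``nn.Linear``
layer. ``ToyMagnitudePruner`` and ``magnitude_mask`` are hypothetical names; this is not
the coremltools implementation, which manages schedules, configurations, and many modules
and may store masks differently. For simplicity, the hook here zeroes the masked weights
in place, and each ``step()`` call moves to the next sparsity in a hard-coded list instead
of consulting a scheduler.

.. code-block:: python

    import torch
    import torch.nn as nn


    def magnitude_mask(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
        """Return a 0/1 mask that zeroes the `sparsity` fraction of smallest-magnitude entries."""
        num_zeros = int(round(sparsity * weight.numel()))
        mask = torch.ones_like(weight)
        if num_zeros > 0:
            # Flat indices of the num_zeros entries with the smallest |w|
            smallest = torch.argsort(weight.abs().flatten())[:num_zeros]
            mask.view(-1)[smallest] = 0.0
        return mask


    class ToyMagnitudePruner:
        """Hypothetical single-module pruner; NOT the coremltools MagnitudePruner."""

        def __init__(self, module: nn.Module, sparsity_schedule):
            self.module = module
            self.schedule = list(sparsity_schedule)  # sparsity to use at each step() call
            self.num_steps = 0
            self.mask = torch.ones_like(module.weight)
            # Analogous to prepare(): the pre hook runs before every module.forward call
            self.handle = module.register_forward_pre_hook(self._apply_mask)

        def _apply_mask(self, module, inputs):
            with torch.no_grad():
                module.weight.mul_(self.mask)

        def step(self):
            # Analogous to pruner.step(): recompute the mask for the next scheduled sparsity
            if self.num_steps < len(self.schedule):
                sparsity = self.schedule[self.num_steps]
                self.mask = magnitude_mask(self.module.weight.detach(), sparsity)
            self.num_steps += 1

        def finalize(self):
            # Analogous to finalize(): bake the mask into the weight and remove the hook
            with torch.no_grad():
                self.module.weight.mul_(self.mask)
            self.handle.remove()


    torch.manual_seed(0)
    layer = nn.Linear(10, 10)
    toy_pruner = ToyMagnitudePruner(layer, sparsity_schedule=[0.5, 0.8])
    for _ in range(2):
        toy_pruner.step()
        layer(torch.randn(4, 10))  # the forward pass triggers the pre hook
    toy_pruner.finalize()
    print((layer.weight == 0).float().mean().item())  # fraction of zero weights

Because the mask is recomputed from the current weight magnitudes, weights that were
already zeroed stay among the smallest and remain pruned as the sparsity grows.
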

.. GENERATED FROM PYTHON SOURCE LINES 191-206

.. code-block:: default

    optimizer = torch.optim.Adam(model.parameters(), eps=1e-07)
    accuracy_pruned = 0.0
    num_epochs = 2

    for epoch in range(num_epochs):
        # train one epoch
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            train_step(model, optimizer, train_loader, data, target, batch_idx, epoch)
            pruner.step()  # advance the pruning schedule after every optimizer step

        # evaluate
        accuracy_pruned = eval_model(model, test_loader)

.. GENERATED FROM PYTHON SOURCE LINES 207-212

The evaluation shows that the pruned network reaches nearly the same accuracy as the
unpruned network; the check below allows the two accuracies to differ by about 2
percentage points. In practice, for more complex models, there is a trade-off between
sparsity and the validation accuracy the model can achieve. Where the best point on this
trade-off curve lies depends on the model and the task.

.. GENERATED FROM PYTHON SOURCE LINES 212-218

.. code-block:: default

    print("Accuracy of pruned network: {:.1f}%\n".format(accuracy_pruned))
    print("Accuracy of unpruned network: {:.1f}%\n".format(accuracy_unpruned))

    # Fail if the accuracies differ by more than about 2 percentage points
    np.testing.assert_allclose(accuracy_pruned, accuracy_unpruned, atol=2)

.. GENERATED FROM PYTHON SOURCE LINES 219-229

Finalizing the Model for Export
-------------------------------

The example shows that you can prune the model with only a few changes to your existing
PyTorch training code. Before you deploy the model on a device, finalize it for export by
calling :py:meth:`~.pruning.MagnitudePruner.finalize` on the pruner. This removes all the
forward pre hooks attached to the submodules, freezes the state of the pruner, and
multiplies each pruning mask into its corresponding weight, so that the pruned values are
stored as zeros in the weights themselves.

.. GENERATED FROM PYTHON SOURCE LINES 229-233

.. code-block:: default

    model.eval()
    pruner.finalize(inplace=True)

.. GENERATED FROM PYTHON SOURCE LINES 234-244

Exporting the Model for On-Device Execution
-------------------------------------------

To deploy the model, convert it to a Core ML model.
Follow the same steps in Core ML Tools that you use to export a regular PyTorch model
(for details, see `Converting from PyTorch `_). Setting
``pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING`` tells the converter that the model is
pruned, which allows the weights to be represented as sparse matrices that have a smaller
memory footprint than dense matrices.

.. GENERATED FROM PYTHON SOURCE LINES 244-258

.. code-block:: default

    import os

    import coremltools as ct

    # Trace the finalized model with an example input of shape (batch, channel, height, width)
    example_input = torch.rand(1, 1, 28, 28)
    traced_model = torch.jit.trace(model, example_input)

    coreml_model = ct.convert(
        traced_model,
        inputs=[ct.TensorType(shape=example_input.shape)],
        pass_pipeline=ct.PassPipeline.DEFAULT_PRUNING,  # represent pruned weights sparsely
        minimum_deployment_target=ct.target.iOS16,
    )

    # Expand "~" explicitly so the package is saved in the tutorial's data directory
    coreml_model.save(os.path.expanduser("~/.mnist_pruning_data/pruned_model.mlpackage"))
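
Because ``finalize()`` multiplies each pruning mask into its weight, you can read the
achieved sparsity directly from the finalized PyTorch model. The helper below,
``layer_sparsity`` (a hypothetical name, not part of coremltools), reports the fraction
of exactly-zero weights in each ``Conv2d`` and ``Linear`` submodule. It is demonstrated
on a small toy model so the snippet runs on its own. Calling ``layer_sparsity(model)`` on
the finalized tutorial model should report values close to the configured targets of
0.7 for ``conv`` and 0.8 for ``dense``, assuming the schedule reached its final update.

.. code-block:: python

    import torch
    import torch.nn as nn


    def layer_sparsity(model: nn.Module) -> dict:
        """Fraction of exactly-zero weights in each Conv2d and Linear submodule."""
        report = {}
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                weight = module.weight.detach()
                report[name] = (weight == 0).float().mean().item()
        return report


    # Toy model whose first layer has half of its weights set to zero
    toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    with torch.no_grad():
        toy[0].weight[:2].zero_()  # zero the first 2 of 4 rows: 8 of 16 weights
    print(layer_sparsity(toy))
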