Training-Time Pruning

Pruning a model is the process of sparsifying the weight matrices of the model’s layers, thereby reducing its storage size. You can also use pruning to reduce a model’s inference latency and power consumption.

Magnitude Pruning

class coremltools.optimize.torch.pruning.ModuleMagnitudePrunerConfig(scheduler: Union[PolynomialDecayScheduler, ConstantSparsityScheduler] = ConstantSparsityScheduler(begin_step=0), initial_sparsity: float = 0.0, target_sparsity: float = 0.5, granularity: str = 'per_scalar', block_size: int = 1, n_m_ratio: Optional[Tuple[int, int]] = None, dim: int = 1, param_name: str = 'weight')

Configuration class for specifying global and module-level pruning options for the magnitude pruning algorithm implemented in MagnitudePruner.

This class supports four different modes of sparsity:

1. Unstructured sparsity: This is the default sparsity mode used by MagnitudePruner. It is activated when block_size = 1, n_m_ratio = None and granularity = per_scalar. In this mode, the n weights with the lowest absolute values are set to 0, where n = floor(size_of_weight_tensor * target_sparsity). For example, given the following:

  • weight = [0.3, -0.2, -0.01, 0.05]

  • target_sparsity = 0.75

The pruned weight would be [0.3, 0, 0, 0]
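As an illustration, the unstructured rule can be sketched in plain Python, treating the weight as a flat list. unstructured_prune is a hypothetical helper written for this page, not part of coremltools; the library itself computes masks on PyTorch tensors.

```python
import math


def unstructured_prune(weight, target_sparsity):
    """Hypothetical helper: zero the n smallest-magnitude entries of a flat weight list."""
    # n = floor(size_of_weight_tensor * target_sparsity)
    n = math.floor(len(weight) * target_sparsity)
    # Indices ordered from smallest to largest absolute value
    order = sorted(range(len(weight)), key=lambda i: abs(weight[i]))
    to_zero = set(order[:n])
    return [0 if i in to_zero else w for i, w in enumerate(weight)]


pruned = unstructured_prune([0.3, -0.2, -0.01, 0.05], target_sparsity=0.75)
print(pruned)  # [0.3, 0, 0, 0]
```

Here n = floor(4 * 0.75) = 3, so the three entries with the smallest magnitude (-0.01, 0.05, -0.2) are zeroed.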

2. Block structured sparsity: This mode is activated when block_size > 1 and n_m_ratio = None. In this mode, the weight matrix is first reshaped to a rank 2 matrix by folding all dimensions >= 1 into a single dimension. Then, blocks of size block_size along the 0-th dimension, which have the lowest L2 norm, are set to 0. The number of blocks which are zeroed out is determined by the target_sparsity parameter. The blocks are chosen in a non-overlapping fashion.

For example:

# Given a 4 x 2 weight with the following values, and block_size = 2:
[
    [1, 3],
    [-6, -7],
    [0, 3],
    [-9, 2],
]

# The L2 norm is computed along the 0-th dimension for blocks of size 2:
[
    [6.08, 7.62],
    [9.00, 3.61],
]

# Then the blocks with the smallest norms are pruned. With target_sparsity = 0.5,
# 2 of the 4 blocks are pruned: those with L2 norms 6.08 and 3.61, i.e. the
# block [1, -6] (rows 0-1 of column 0) and the block [3, 2] (rows 2-3 of
# column 1). The final pruned tensor is:
[
    [0, 3],
    [0, -7],
    [0, 0],
    [-9, 0],
]
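The block-structured rule can be sketched the same way in plain Python. block_prune is a hypothetical helper, not the coremltools implementation. It assumes the weight is already rank 2, that the row count is divisible by block_size (so the zero padding described under the block_size parameter is skipped), and that floor(num_blocks * target_sparsity) blocks are pruned.

```python
import math


def block_prune(weight, block_size, target_sparsity):
    """Hypothetical helper: block-structured magnitude pruning of a rank 2 weight.

    Assumes the number of rows is divisible by block_size (no padding is done here).
    Returns the pruned weight and the L2 norm of every block.
    """
    rows, cols = len(weight), len(weight[0])
    # norms[(b, c)] is the L2 norm of rows b*block_size .. (b+1)*block_size - 1 in column c
    norms = {}
    for b in range(rows // block_size):
        for c in range(cols):
            block = [weight[b * block_size + r][c] for r in range(block_size)]
            norms[(b, c)] = math.sqrt(sum(v * v for v in block))
    # Zero out the floor(num_blocks * target_sparsity) blocks with the smallest norm
    n_blocks = math.floor(len(norms) * target_sparsity)
    pruned = [row[:] for row in weight]  # copy so the input is left untouched
    for b, c in sorted(norms, key=norms.get)[:n_blocks]:
        for r in range(block_size):
            pruned[b * block_size + r][c] = 0
    return pruned, norms


weight = [[1, 3], [-6, -7], [0, 3], [-9, 2]]
pruned, norms = block_prune(weight, block_size=2, target_sparsity=0.5)
print(pruned)  # [[0, 3], [0, -7], [0, 0], [-9, 0]]
```

Each block spans block_size consecutive rows of a single column, which is why the mask is constant along the 0-th (output channel) dimension within a block.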

3. n:m structured sparsity: This mode is activated when n_m_ratio != None. Similar to block structured sparsity, in this mode, the weight matrix is reshaped to a rank 2 matrix. Then, out of non-overlapping blocks of size m along dimension 0 or dimension 1, the n elements with the smallest absolute values are set to 0. The dimension along which the blocks are chosen is controlled by the dim parameter, which defaults to 1. For linear layers, dim = 1 and ratios where m is a factor of 16 (e.g. 3:4, 7:8, etc.) are recommended to obtain latency gains, specifically for models running on the CPU.

For example:

# Given a 4 x 4 weight of
[
    [3, 4, 7, 6],
    [1, 8, -3, -8],
    [-2, -3, -4, 0],
    [5, 4, -3, -2],
]

# For n_m_ratio = (1, 2) with dim = 1 (default), the resulting pruned weight is
[
    [0, 4, 7, 0],
    [0, 8, 0, -8],
    [0, -3, -4, 0],
    [5, 0, -3, 0],
]
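A plain-Python sketch of n:m pruning with dim = 1 on a rank 2 weight follows. n_m_prune is a hypothetical helper, not the library's code; it assumes each row length is divisible by m and does not model how coremltools handles a remainder.

```python
def n_m_prune(weight, n, m):
    """Hypothetical helper: n:m pruning of a rank 2 weight with dim = 1.

    In each row, every non-overlapping group of m consecutive entries has its
    n smallest-magnitude entries set to 0. Assumes the row length is divisible by m.
    """
    pruned = []
    for row in weight:
        new_row = list(row)
        for start in range(0, len(row), m):
            group = range(start, start + m)
            # Zero the n entries of this group with the smallest absolute value
            for i in sorted(group, key=lambda j: abs(row[j]))[:n]:
                new_row[i] = 0
        pruned.append(new_row)
    return pruned


weight = [
    [3, 4, 7, 6],
    [1, 8, -3, -8],
    [-2, -3, -4, 0],
    [5, 4, -3, -2],
]
print(n_m_prune(weight, n=1, m=2))
```

Running it reproduces the 1:2 result above. For dim = 0, the same helper could be applied to the transposed weight and the result transposed back.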

4. General structured sparsity: This mode is activated when granularity is set to one of per_channel or per_kernel. It only applies to weights of rank >= 3. For example, a rank 4 weight tensor of shape [C_o x C_i x H x W] can be thought of as C_o matrices of shape [C_i x H x W] or C_o*C_i matrices of shape [H x W]. per_channel granularity sets some of the [C_i x H x W] matrices to 0, whereas per_kernel granularity sets some of the [H x W] matrices to 0.

When granularity is per_channel, the weight tensor is reshaped to a rank 2 matrix, where all dimensions >= 1 are folded into a single dimension. Then the L2 norm of each row is computed, and the weights in the n rows with the smallest L2 norms are set to 0 to achieve target_sparsity.

For example:

# Given a 2 x 2 x 1 x 2 weight, granularity = per_channel,
[
    [
        [[2, -1]],
        [[-3, 2]],
    ],
    [
        [[5, -2]],
        [[-1, -3]],
    ],
]

# It is first reshaped to shape 2 x 4, i.e.:
[
    [2, -1, -3, 2],
    [5, -2, -1, -3],
]

# Then L2 norm is computed for each row of the matrix:
[4.2426, 6.2450]

# Finally, to achieve target_sparsity = 0.5, one of the two rows is pruned. Since the
# first norm is smaller, the corresponding row is set to 0, resulting in the pruned weight:
[
    [
        [[0, 0]],
        [[0, 0]],
    ],
    [
        [[5, -2]],
        [[-1, -3]],
    ],
]
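The per_channel computation can be sketched in plain Python on nested lists. per_channel_prune, flatten, and zeros_like are hypothetical helpers written for this illustration, not coremltools functions, and the number of pruned channels is assumed to be floor(C_o * target_sparsity).

```python
import math


def flatten(x):
    """Flatten an arbitrarily nested list of numbers into a flat list."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def zeros_like(x):
    """Return a nested list with the same shape as x, filled with zeros."""
    if isinstance(x, list):
        return [zeros_like(item) for item in x]
    return 0


def per_channel_prune(weight, target_sparsity):
    """Hypothetical helper: zero whole output channels (weight[i]) with the smallest L2 norm."""
    # Folding dims >= 1 into one row per output channel is the same as flattening weight[i]
    norms = [math.sqrt(sum(v * v for v in flatten(channel))) for channel in weight]
    n = math.floor(len(weight) * target_sparsity)
    to_zero = set(sorted(range(len(weight)), key=lambda i: norms[i])[:n])
    pruned = [zeros_like(ch) if i in to_zero else ch for i, ch in enumerate(weight)]
    return pruned, norms


weight = [
    [[[2, -1]], [[-3, 2]]],
    [[[5, -2]], [[-1, -3]]],
]
pruned, norms = per_channel_prune(weight, target_sparsity=0.5)
print([round(v, 4) for v in norms])  # [4.2426, 6.245]
```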

When granularity is per_kernel, the weight tensor is reshaped to a rank 3 tensor, where all dimensions >= 2 are folded into a single dimension. Then the L2 norm is computed for every vector along the last dimension (dim = 2), and the weights in the n vectors with the smallest L2 norms are set to 0 to achieve target_sparsity.

For the same example as before, setting granularity = per_kernel gives:

# The original 2 x 2 x 1 x 2 weight matrix is reshaped into shape 2 x 2 x 2, i.e.:
[
    [[2, -1], [-3, 2]],
    [[5, -2], [-1, -3]],
]

# Then L2 norm is computed for each of the 4 vectors of size 2, [2, -1], [-3, 2], etc.:
[
    [2.2361, 3.6056],
    [5.3852, 3.1623],
]

# Finally, to achieve target_sparsity = 0.5, two of the four kernels are pruned.
# Since the first and last norms are the smallest, the corresponding [H x W]
# kernels are set to 0, resulting in the pruned weight:
[
    [
        [[0, 0]],
        [[-3, 2]],
    ],
    [
        [[5, -2]],
        [[0, 0]],
    ],
]
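A plain-Python sketch of per_kernel pruning for a rank 4 weight of shape [C_o x C_i x H x W] is shown below. per_kernel_prune, flatten, and zeros_like are hypothetical helpers, not coremltools functions, and the number of pruned kernels is assumed to be floor(C_o * C_i * target_sparsity).

```python
import math


def flatten(x):
    """Flatten an arbitrarily nested list of numbers into a flat list."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


def zeros_like(x):
    """Return a nested list with the same shape as x, filled with zeros."""
    if isinstance(x, list):
        return [zeros_like(item) for item in x]
    return 0


def per_kernel_prune(weight, target_sparsity):
    """Hypothetical helper: zero whole [H x W] kernels (weight[o][i]) with the smallest L2 norm."""
    # One norm per (output channel, input channel) pair, over the folded H x W entries
    norms = {
        (o, i): math.sqrt(sum(v * v for v in flatten(kernel)))
        for o, channel in enumerate(weight)
        for i, kernel in enumerate(channel)
    }
    n = math.floor(len(norms) * target_sparsity)
    to_zero = set(sorted(norms, key=norms.get)[:n])
    pruned = [
        [zeros_like(k) if (o, i) in to_zero else k for i, k in enumerate(channel)]
        for o, channel in enumerate(weight)
    ]
    return pruned, norms


weight = [
    [[[2, -1]], [[-3, 2]]],
    [[[5, -2]], [[-1, -3]]],
]
pruned, norms = per_kernel_prune(weight, target_sparsity=0.5)
print(pruned)  # [[[[0, 0]], [[-3, 2]]], [[[5, -2]], [[0, 0]]]]
```

Compared with per_channel, the unit being scored and zeroed is one [H x W] kernel rather than a whole output channel, so different input channels of the same filter can be kept or pruned independently.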
Parameters:
  • scheduler (PruningScheduler) – A pruning scheduler which specifies how the sparsity should be changed over the course of the training. Defaults to a constant sparsity scheduler, which sets the sparsity to target_sparsity at step 0.

  • initial_sparsity (float) – Desired fraction of zeroes at the beginning of the training process. Defaults to 0.0.

  • target_sparsity (float) – Desired fraction of zeroes at the end of the training process. Defaults to 0.5.

  • granularity (str) – Specifies the granularity at which the pruning mask will be computed. Can be one of per_channel, per_kernel or per_scalar. Defaults to per_scalar.

  • block_size (int) – Block size for inducing block sparsity within the mask. This is applied on the output channel dimension of the parameter (the 0-th dimension). A larger block size may yield lower latency than a smaller one for models running on certain compute units, such as the Neural Engine. block_size must be greater than 1 to enable block sparsity, and must be at most half the number of output channels. When the number of output channels is not divisible by the block size, the weight matrix is padded with zeros to compute the pruning mask and then un-padded to the original size. Defaults to 1.

  • n_m_ratio (tuple of int) – A tuple of two integers which specify how n:m pruning should be applied. In n:m pruning, out of every m elements, the n with the lowest magnitude are set to zero. When n_m_ratio is not None, block_size, granularity, and initial_sparsity should be 1, per_scalar, and 0.0 respectively. The value of target_sparsity is ignored and the actual target sparsity is determined by the n:m ratio. For more information, see Learning N:M Fine-Grained Structured Sparse Neural Networks From Scratch. Defaults to None, which means n:m sparsity is not used.

  • dim (int) – Dimension along which blocks of m elements are chosen when applying n:m sparsity. This parameter is only used when n_m_ratio is not None. Defaults to 1.

  • param_name (str) – The name of the parameter to be pruned. Defaults to weight.

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → ModuleOptimizationConfig

Create an instance of the class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: Union[IO, str]) → ModuleOptimizationConfig

Create an instance of the class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

class coremltools.optimize.torch.pruning.MagnitudePrunerConfig(global_config: Optional[ModuleMagnitudePrunerConfig] = None, module_type_configs: ModuleTypeConfigType = _Nothing.NOTHING, module_name_configs: Dict[str, Optional[ModuleMagnitudePrunerConfig]] = _Nothing.NOTHING)

Configuration class for specifying how different submodules in a model are pruned by MagnitudePruner.

Parameters:
  • global_config (ModuleMagnitudePrunerConfig) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.

  • module_type_configs (dict of str to ModuleMagnitudePrunerConfig) – Module type level configs applied to a specific module class, such as torch.nn.Linear. The keys can be either strings or module classes. If the config for a module type is set to None, modules of that type are not pruned.

  • module_name_configs (dict of str to ModuleMagnitudePrunerConfig) – Module level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top level module using the module.get_submodule(target) method. If the config for a module name is set to None, that module is not pruned.

as_dict() → Dict[str, Any]

Returns the config as a dictionary.

classmethod from_dict(config_dict: Dict[str, Any]) → MagnitudePrunerConfig

Create an instance of the class from a dictionary of string keys and values.

Parameters:

config_dict (dict of str and values) – A nested dictionary of strings and values.

classmethod from_yaml(yml: Union[IO, str]) → OptimizationConfig

Create an instance of the class from a yaml stream.

Parameters:

yml – An IO stream containing yaml or a str path to the yaml file.

set_global(global_config: Optional[ModuleOptimizationConfig]) → OptimizationConfig

Set the global config.

set_module_name(module_name: str, opt_config: Optional[ModuleOptimizationConfig]) → OptimizationConfig

Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.

set_module_type(object_type: Union[Callable, str], opt_config: Optional[ModuleOptimizationConfig]) → OptimizationConfig

Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.

class coremltools.optimize.torch.pruning.MagnitudePruner(model: Module, config: Optional[MagnitudePrunerConfig] = None)

A pruning algorithm based on “To prune, or not to prune: exploring the efficacy of pruning for model compression”. It extends the idea in the paper to different kinds of structured sparsity modes, in addition to unstructured sparsity. To achieve the desired sparsity, this algorithm sorts the elements of a module’s weight matrix by magnitude and sets all elements whose magnitude is below a threshold to zero.

Four different modes of sparsity are supported, encompassing both structured and unstructured sparsity. For details on how to select these different sparsity modes, please see ModuleMagnitudePrunerConfig.

Example

import torch
from collections import OrderedDict
from coremltools.optimize.torch.pruning import MagnitudePruner, MagnitudePrunerConfig

# define model and loss function
model = torch.nn.Sequential(
    OrderedDict(
        [
            ("conv1", torch.nn.Conv2d(3, 32, 3, padding="same")),
            ("conv2", torch.nn.Conv2d(32, 32, 3, padding="same")),
        ]
    )
)
loss_fn = torch.nn.MSELoss()  # placeholder: substitute the loss for your task

# synthetic (inputs, labels) batches standing in for a real data loader
data = [(torch.randn(4, 3, 16, 16), torch.randn(4, 32, 16, 16)) for _ in range(10)]

# initialize pruner and configure it
# we only prune the first conv layer
config = MagnitudePrunerConfig.from_dict(
    {
        "module_name_configs": {
            "conv1": {
                "scheduler": {"update_steps": [3, 5, 7]},
                "target_sparsity": 0.75,
                "granularity": "per_channel",
            },
        }
    }
)

pruner = MagnitudePruner(model, config)

# insert pruning layers in the model
model = pruner.prepare()

# create the optimizer after prepare(), which returns a modified copy of the model by default
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for inputs, labels in data:
    optimizer.zero_grad()
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    pruner.step()

# commit pruning masks to model parameters
pruner.finalize(inplace=True)
Parameters:
  • model (torch.nn.Module) – Model on which the pruner will act.

  • config (MagnitudePrunerConfig) – Config which specifies how different submodules in the model will be configured for pruning. Default config is used when passed as None.

finalize(model: Optional[Module] = None, inplace: bool = False) → Module

Prepares the model for export. Removes pruning forward pre-hooks attached to submodules and commits pruning changes to pruned module parameters by multiplying the pruning masks with the parameter matrix.

Parameters:
  • model (nn.Module) – model to finalize

  • inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned.

prepare(inplace: bool = False) → Module

Prepares the model for pruning.

Parameters:

inplace (bool) – If True, model transformations are carried out in-place and the original module is mutated, otherwise a copy of the model is mutated and returned.

report() → _Report

Returns a dictionary with important statistics related to the current state of pruning. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing statistics such as unstructured_weight_sparsity, number of parameters, etc. The dictionary also contains a global key with the same statistics aggregated over all the modules set up for pruning.

step()

Steps through the pruning schedule once. At every call to step(), an internal step counter is incremented by one.

Pruning scheduler

The coremltools.optimize.torch.pruning.pruning_scheduler submodule contains classes that implement pruning schedules, which can be used for changing the sparsity of pruning masks applied by various types of pruning algorithms to prune neural network parameters.

class coremltools.optimize.torch.pruning.pruning_scheduler.PruningScheduler

Bases: ABC

An abstraction for implementing schedules to be used for changing the sparsity of pruning masks applied by various types of pruning algorithms to module parameters over the course of the training.

class coremltools.optimize.torch.pruning.pruning_scheduler.PolynomialDecayScheduler(update_steps: Union[List[int], str, Tensor], power: int = 3)

Bases: PruningScheduler

A pruning scheduler inspired by the paper “To prune or not to prune”.

If \(t\) is in \(update\_steps\), the scheduler sets the sparsity at step \(t\) using the formula:

\[sparsity_t = target\_sparsity + (initial\_sparsity - target\_sparsity) * (1 - \frac{update\_index}{total\_number\_of\_updates}) ^ {power}\]

Otherwise, it keeps the sparsity at its previous value.

Here, \(update\_index\) is the index of \(t\) in the \(update\_steps\) array and \(total\_number\_of\_updates\) is the length of \(update\_steps\) array.

Parameters:
  • update_steps (list of int or str) – The indices of optimization steps at which pruning should be performed. This can be passed in as a string representing the range, such as range(start_index, end_index, step_size).

  • power (int, optional) – Exponent to be used in the sparsity function. Defaults to 3.
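The following plain-Python sketch evaluates the formula above over a list of update steps. It is not the library's implementation: polynomial_decay_sparsity is a hypothetical helper, it accepts only a list (not the range(...) string form), and it assumes update_index is 1-based so that the last entry of update_steps yields exactly target_sparsity, matching the description of target_sparsity as the sparsity at the end of training. The text above does not pin down the indexing, so the library's intermediate values may differ.

```python
def polynomial_decay_sparsity(step, prev_sparsity, update_steps, initial_sparsity,
                              target_sparsity, power=3):
    """Hypothetical helper: sparsity after `step` under polynomial decay.

    Assumes update_index is 1-based, so the last update step reaches target_sparsity.
    """
    if step not in update_steps:
        # Not an update step: keep the sparsity at its previous value
        return prev_sparsity
    update_index = update_steps.index(step) + 1
    total_number_of_updates = len(update_steps)
    decay = (1 - update_index / total_number_of_updates) ** power
    return target_sparsity + (initial_sparsity - target_sparsity) * decay


# Trace the schedule for update_steps = [3, 5, 7], initial_sparsity = 0.0, target_sparsity = 0.75
schedule = []
sparsity = 0.0
for step in range(9):
    sparsity = polynomial_decay_sparsity(step, sparsity, [3, 5, 7], 0.0, 0.75)
    schedule.append(round(sparsity, 4))
print(schedule)  # [0.0, 0.0, 0.0, 0.5278, 0.5278, 0.7222, 0.7222, 0.75, 0.75]
```

The update steps and target sparsity mirror the conv1 config in the MagnitudePruner example. Because of the cubic power, most of the sparsity is introduced at the first update, and later updates make progressively smaller changes.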

compute_sparsity(step_count: int, prev_sparsity: float, config: ModuleOptimizationConfig) → float

Compute the sparsity at the next step given the previous sparsity and the module optimization config.

Parameters:
  • step_count (int) – Current step count.

  • prev_sparsity (float) – Sparsity at previous step.

  • config (ModuleOptimizationConfig) – Optimization config for the module which contains information such as target sparsity and initial sparsity.

class coremltools.optimize.torch.pruning.pruning_scheduler.ConstantSparsityScheduler(begin_step: int)

Bases: PruningScheduler

A pruning schedule with constant sparsity throughout training.

Sparsity is set to zero initially and to target_sparsity at step begin_step.

Parameters:

begin_step (int) – step at which to begin pruning.
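A minimal plain-Python sketch of this schedule, assuming the sparsity stays at target_sparsity for every step at or after begin_step. constant_sparsity is a hypothetical helper, not the library's compute_sparsity.

```python
def constant_sparsity(step, begin_step, target_sparsity):
    """Hypothetical helper: 0 before begin_step, target_sparsity from begin_step onward."""
    return target_sparsity if step >= begin_step else 0.0


schedule = [constant_sparsity(step, begin_step=2, target_sparsity=0.5) for step in range(5)]
print(schedule)  # [0.0, 0.0, 0.5, 0.5, 0.5]
```

With the default ConstantSparsityScheduler(begin_step=0) used by ModuleMagnitudePrunerConfig, target_sparsity therefore applies from the very first step.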

compute_sparsity(step_count: int, prev_sparsity: float, config: ModuleOptimizationConfig) → float

Compute the sparsity at the next step given the previous sparsity and the module optimization config.

Parameters:
  • step_count (int) – Current step count.

  • prev_sparsity (float) – Sparsity at previous step.

  • config (ModuleOptimizationConfig) – Optimization config for the module which contains information such as target sparsity and initial sparsity.