Pruning
Pruning a model is the process of sparsifying the weight matrices of the model’s layers, thereby reducing its storage size. You can also use pruning to reduce a model’s inference latency and power consumption.
Magnitude Pruning
- class coremltools.optimize.torch.pruning.ModuleMagnitudePrunerConfig(scheduler: PolynomialDecayScheduler | ConstantSparsityScheduler = ConstantSparsityScheduler(begin_step=0), initial_sparsity: float = 0.0, target_sparsity: float = 0.5, granularity: str = 'per_scalar', block_size: int = 1, n_m_ratio: Tuple[int, int] | None = None, dim: int = 1, param_name: str = 'weight')[source]
Configuration class for specifying global and module-level pruning options for the magnitude pruning algorithm implemented in `MagnitudePruner`.

This class supports four different modes of sparsity:

1. Unstructured sparsity: This is the default sparsity mode used by `MagnitudePruner`. It is activated when `block_size = 1`, `n_m_ratio = None`, and `granularity = per_scalar`. In this mode, the `n` weights with the lowest absolute values are set to 0, where `n = floor(size_of_weight_tensor * target_sparsity)`.

For example, given the following:

```python
weight = [0.3, -0.2, -0.01, 0.05]
target_sparsity = 0.75
```

The pruned weight would be `[0.3, 0, 0, 0]`.
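The following is a minimal sketch in plain PyTorch of how such an unstructured magnitude-pruning mask can be computed for the example above. It is illustrative only, not the library's internal implementation:

```python
import torch

weight = torch.tensor([0.3, -0.2, -0.01, 0.05])
target_sparsity = 0.75

# number of elements to zero out
n = int(weight.numel() * target_sparsity)

# zero out the n elements with the smallest absolute values
mask = torch.ones_like(weight)
_, prune_idx = torch.topk(weight.abs(), k=n, largest=False)
mask[prune_idx] = 0.0

print(weight * mask)  # tensor([0.3000, 0.0000, 0.0000, 0.0000])
```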
2. Block structured sparsity: This mode is activated when `block_size > 1` and `n_m_ratio = None`. In this mode, the weight matrix is first reshaped to a rank 2 matrix by folding all dimensions `>= 1` into a single dimension. Then, blocks of size `block_size` along the 0-th dimension which have the lowest `L2` norm are set to 0. The number of blocks which are zeroed out is determined by the `target_sparsity` parameter. The blocks are chosen in a non-overlapping fashion.

For example:
```python
# Given a 4 x 2 weight with the following value, and block_size = 2.
[
    [1, 3],
    [-6, -7],
    [0, 3],
    [-9, 2],
]

# L2 norm is computed along the 0-th dimension for blocks of size 2:
[
    [6.08, 7.62],
    [9.00, 3.61],
]

# Then the smallest values are picked to prune. So if target_sparsity = 0.5,
# the blocks with L2 norm values of 6.08 and 3.61 are pruned, and the
# elements belonging to those blocks are set to 0. The final pruned tensor is:
[
    [0, 3],
    [0, -7],
    [0, 0],
    [-9, 0],
]
```
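A minimal PyTorch sketch of this block-wise computation (illustrative only, not the library's implementation):

```python
import torch

weight = torch.tensor([[1.0, 3.0], [-6.0, -7.0], [0.0, 3.0], [-9.0, 2.0]])
block_size, target_sparsity = 2, 0.5

# group rows into non-overlapping blocks of size block_size and compute per-block L2 norms
blocks = weight.reshape(-1, block_size, weight.shape[1])  # shape (2, 2, 2)
block_norms = blocks.norm(dim=1)  # [[6.08, 7.62], [9.00, 3.61]]

# zero out the blocks with the smallest norms to reach the target sparsity
n_prune = int(block_norms.numel() * target_sparsity)
_, prune_idx = torch.topk(block_norms.flatten(), k=n_prune, largest=False)
mask = torch.ones(block_norms.numel())
mask[prune_idx] = 0.0
mask = mask.reshape_as(block_norms).repeat_interleave(block_size, dim=0)

print(weight * mask)  # [[0, 3], [0, -7], [0, 0], [-9, 0]]
```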
3. n:m structured sparsity: This mode is activated when `n_m_ratio != None`. Similar to block structured sparsity, in this mode the weight matrix is reshaped to a rank 2 matrix. Then, out of non-overlapping blocks of size `m` along the 0-th or 1-st dimension, the `n` elements with the smallest absolute value are set to 0. The dimension along which the blocks are chosen is controlled by the `dim` parameter, and it defaults to `1`. For linear layers, `dim = 1` and ratios where `m` is a factor of 16 (e.g. `3:4`, `7:8`, etc.) are recommended to get latency gains for models executing specifically on the CPU.

For example:
```python
# Given a 4 x 4 weight of
[
    [3, 4, 7, 6],
    [1, 8, -3, -8],
    [-2, -3, -4, 0],
    [5, 4, -3, -2],
]

# For n_m_ratio = (1, 2) with dim = 1 (default), the resulting pruned weight is
[
    [0, 4, 7, 0],
    [0, 8, 0, -8],
    [0, -3, -4, 0],
    [5, 0, -3, 0],
]
```
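A minimal PyTorch sketch of this n:m masking (illustrative only, not the library's implementation):

```python
import torch

weight = torch.tensor(
    [[3.0, 4.0, 7.0, 6.0], [1.0, 8.0, -3.0, -8.0], [-2.0, -3.0, -4.0, 0.0], [5.0, 4.0, -3.0, -2.0]]
)
n, m = 1, 2  # n_m_ratio = (1, 2): zero out the 1 smallest-magnitude element in every group of 2

# split each row (dim = 1) into non-overlapping groups of m consecutive elements
groups = weight.reshape(-1, m)
_, prune_idx = torch.topk(groups.abs(), k=n, dim=1, largest=False)
mask = torch.ones_like(groups).scatter_(1, prune_idx, 0.0)

print(weight * mask.reshape_as(weight))
# [[0, 4, 7, 0], [0, 8, 0, -8], [0, -3, -4, 0], [5, 0, -3, 0]]
```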
4. General structured sparsity: This mode is activated when `granularity` is set to one of `per_channel` or `per_kernel`. It only applies to weights of rank `>= 3`. For example, a rank 4 weight matrix of shape `[C_o x C_i x H x W]` can be thought of as `C_o` matrices of shape `[C_i x H x W]` or `C_o * C_i` matrices of size `[H x W]`. `per_channel` granularity sets some of the `[C_i x H x W]` matrices to 0, whereas `per_kernel` granularity sets some of the `[H x W]` matrices to 0.

When granularity is `per_channel`, the weight matrix is reshaped to a rank 2 matrix, where all dimensions `>= 1` are folded into a single dimension. Then the `L2` norm is computed for all rows, and the weights corresponding to the `n` smallest `L2` norm rows are set to 0 to achieve `target_sparsity`.

For example:
```python
# Given a 2 x 2 x 1 x 2 weight, granularity = per_channel,
[
    [
        [[2, -1]],
        [[-3, 2]],
    ],
    [
        [[5, -2]],
        [[-1, -3]],
    ],
]

# It is first reshaped to shape 2 x 4, i.e.:
[
    [2, -1, -3, 2],
    [5, -2, -1, -3],
]

# Then L2 norm is computed for each row of the matrix:
[4.2426, 6.2450]

# Finally, to achieve target sparsity = 0.5, since the first element is
# smaller, the corresponding row is set to 0, resulting in the pruned weight:
[
    [
        [[0, 0]],
        [[0, 0]],
    ],
    [
        [[5, -2]],
        [[-1, -3]],
    ],
]
```
When granularity is `per_kernel`, the weight matrix is reshaped to a rank 3 matrix, where all dimensions `>= 2` are folded into a single dimension. Then the `L2` norm is computed for all vectors along the last dimension, `dim = 2`, and the weights corresponding to the `n` smallest `L2` norm vectors are set to 0 to achieve `target_sparsity`.

For the same example as before, setting granularity to `per_kernel` will achieve:

```python
# The original 2 x 2 x 1 x 2 weight matrix is reshaped into shape 2 x 2 x 2, i.e.:
[
    [[2, -1], [-3, 2]],
    [[5, -2], [-1, -3]],
]

# Then L2 norm is computed for each of the 4 vectors of size 2, [2, -1], [-3, 2], etc.:
[
    [2.2361, 3.6056],
    [5.3852, 3.1623],
]

# Finally, to achieve target sparsity = 0.5, since the first and last elements are
# smallest, the corresponding vectors in the weights are set to 0,
# resulting in the pruned weight:
[
    [
        [[0, 0]],
        [[-3, 2]],
    ],
    [
        [[5, -2]],
        [[0, 0]],
    ],
]
```
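The per_channel and per_kernel behavior described above can be sketched in plain PyTorch as follows (illustrative only, not the library's implementation):

```python
import torch

weight = torch.tensor([[[[2.0, -1.0]], [[-3.0, 2.0]]], [[[5.0, -2.0]], [[-1.0, -3.0]]]])


def structured_mask(w, fold_from_dim, target_sparsity):
    folded = w.flatten(start_dim=fold_from_dim)  # fold all dims >= fold_from_dim together
    norms = folded.norm(dim=-1)  # L2 norm of each folded vector
    n = int(norms.numel() * target_sparsity)  # number of vectors to zero out
    _, prune_idx = torch.topk(norms.flatten(), k=n, largest=False)
    mask = torch.ones(norms.numel())
    mask[prune_idx] = 0.0
    # reshape so the mask broadcasts over the folded dimensions of the weight
    return mask.reshape(tuple(norms.shape) + (1,) * (w.dim() - fold_from_dim))


per_channel = weight * structured_mask(weight, fold_from_dim=1, target_sparsity=0.5)
per_kernel = weight * structured_mask(weight, fold_from_dim=2, target_sparsity=0.5)
```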
- Parameters:
  - scheduler (`PruningScheduler`) – A pruning scheduler which specifies how the sparsity should be changed over the course of the training. Defaults to a constant sparsity scheduler which sets the sparsity to `target_sparsity` at step `0`.
  - initial_sparsity (`float`) – Desired fraction of zeroes at the beginning of the training process. Defaults to `0.0`.
  - target_sparsity (`float`) – Desired fraction of zeroes at the end of the training process. Defaults to `0.5`.
  - granularity (`str`) – Specifies the granularity at which the pruning mask will be computed. Can be one of `per_channel`, `per_kernel` or `per_scalar`. Defaults to `per_scalar`.
  - block_size (`int`) – Block size for inducing block sparsity within the mask. This is applied on the output channel dimension of the parameter (the 0-th dimension). A larger block size may be beneficial for latency compared to a smaller block size, for models running on certain compute units such as the Neural Engine. `block_size` must be greater than `1` to enable block sparsity, and must be at most half the number of output channels. When the number of output channels is not divisible by the block size, the weight matrix is padded with zeros to compute the pruning mask and then un-padded to the original size. Defaults to `1`.
  - n_m_ratio (`tuple` of `int`) – A tuple of two integers which specify how `n:m` pruning should be applied. In `n:m` pruning, out of every `m` elements, the `n` with the lowest magnitude are set to zero. When `n_m_ratio` is not `None`, `block_size`, `granularity`, and `initial_sparsity` should be `1`, `per_scalar`, and `0.0`, respectively. The value of `target_sparsity` is ignored and the actual target sparsity is determined by the `n:m` ratio. For more information, see Learning N:M Fine-Grained Structured Sparse Neural Networks From Scratch. Defaults to `None`, which means `n:m` sparsity is not used.
  - dim (`int`) – Dimension along which blocks of `m` elements are chosen when applying `n:m` sparsity. This parameter is only used when `n_m_ratio` is not `None`. Defaults to `1`.
  - param_name (`str`) – The name of the parameter to be pruned. Defaults to `weight`.
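For instance, a config requesting 3:4 structured sparsity can be built directly from the constructor described above (a sketch; the update step values below are arbitrary):

```python
from coremltools.optimize.torch.pruning import ModuleMagnitudePrunerConfig
from coremltools.optimize.torch.pruning.pruning_scheduler import PolynomialDecayScheduler

# prune 3 out of every 4 consecutive elements along dim 1, updating the mask
# at optimization steps 100, 200 and 300
module_config = ModuleMagnitudePrunerConfig(
    scheduler=PolynomialDecayScheduler(update_steps=[100, 200, 300]),
    n_m_ratio=(3, 4),
    dim=1,
)
```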
- as_dict() → Dict[str, Any]
  Returns the config as a dictionary.
- classmethod from_dict(data_dict: Dict[str, Any]) → DictableDataClass
  Create class from a dictionary of string keys and values.
  - Parameters:
    - data_dict (`dict` of `str` and values) – A nested dictionary of strings and values.
- classmethod from_yaml(yml: IO | str) → DictableDataClass
  Create class from a yaml stream.
  - Parameters:
    - yml – An `IO` stream containing yaml or a `str` path to the yaml file.
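As a sketch, the same options can also be loaded from a yaml file; the file name and its contents below are hypothetical:

```python
from coremltools.optimize.torch.pruning import ModuleMagnitudePrunerConfig

# contents of pruner_config.yaml (hypothetical):
#   target_sparsity: 0.75
#   granularity: per_channel
module_config = ModuleMagnitudePrunerConfig.from_yaml("pruner_config.yaml")
```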
- class coremltools.optimize.torch.pruning.MagnitudePrunerConfig(global_config: ModuleMagnitudePrunerConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, ModuleMagnitudePrunerConfig | None] = NOTHING)[source]
Configuration class for specifying how different submodules in a model are pruned by `MagnitudePruner`.

- Parameters:
  - global_config (`ModuleMagnitudePrunerConfig`) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.
  - module_type_configs (`dict` of `str` to `ModuleMagnitudePrunerConfig`) – Module type level configs applied to a specific module class, such as `torch.nn.Linear`. The keys can be either strings or module classes. If `module_type_config` is set to `None` for a module type, it wouldn't get pruned.
  - module_name_configs (`dict` of `str` to `ModuleMagnitudePrunerConfig`) – Module level configs applied to specific modules. The name of the module must be a fully qualified name that can be used to fetch it from the top level module using the `module.get_submodule(target)` method. If `module_name_config` is set to `None` for a module, it wouldn't get pruned.
- as_dict() → Dict[str, Any]
  Returns the config as a dictionary.
- classmethod from_dict(config_dict: Dict[str, Any]) → MagnitudePrunerConfig [source]
  Create class from a dictionary of string keys and values.
  - Parameters:
    - config_dict (`dict` of `str` and values) – A nested dictionary of strings and values.
- classmethod from_yaml(yml: IO | str) → DictableDataClass
  Create class from a yaml stream.
  - Parameters:
    - yml – An `IO` stream containing yaml or a `str` path to the yaml file.
- set_global(global_config: ModuleOptimizationConfig | None) → OptimizationConfig
  Set the global config.
- set_module_name(module_name: str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig
  Set the module level optimization config for a given module instance. If the module level optimization config for an existing module was already set, the new config will override the old one.
- set_module_type(object_type: Callable | str, opt_config: ModuleOptimizationConfig | None) → OptimizationConfig
  Set the module level optimization config for a given module type. If the module level optimization config for an existing module type was already set, the new config will override the old one.
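A sketch of composing a config with these setters, based on the signatures above (the module name "classifier" is hypothetical and depends on the model):

```python
import torch
from coremltools.optimize.torch.pruning import (
    MagnitudePrunerConfig,
    ModuleMagnitudePrunerConfig,
)

# 50% unstructured sparsity everywhere, per_channel sparsity for Conv2d layers,
# and no pruning for the module named "classifier"
config = MagnitudePrunerConfig(global_config=ModuleMagnitudePrunerConfig(target_sparsity=0.5))
config = config.set_module_type(
    torch.nn.Conv2d,
    ModuleMagnitudePrunerConfig(target_sparsity=0.5, granularity="per_channel"),
)
config = config.set_module_name("classifier", None)
```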
- class coremltools.optimize.torch.pruning.MagnitudePruner(model: Module, config: MagnitudePrunerConfig | None = None)[source]
A pruning algorithm based on the paper “To prune, or not to prune: exploring the efficacy of pruning for model compression”. It extends the idea in the paper to different kinds of structured sparsity modes, in addition to unstructured sparsity. In order to achieve the desired sparsity, this algorithm sorts a module’s weight matrix by the magnitude of its elements and sets all elements less than a threshold to zero.

Four different modes of sparsity are supported, encompassing both structured and unstructured sparsity. For details on how to select these different sparsity modes, please see `ModuleMagnitudePrunerConfig`.

Example
```python
import torch
from collections import OrderedDict
from coremltools.optimize.torch.pruning import MagnitudePruner, MagnitudePrunerConfig

# define model and loss function
model = torch.nn.Sequential(
    OrderedDict(
        [
            ("conv1", torch.nn.Conv2d(3, 32, 3, padding="same")),
            ("conv2", torch.nn.Conv2d(32, 32, 3, padding="same")),
        ]
    )
)
loss_fn = define_loss()  # define the loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # any optimizer can be used

# initialize pruner and configure it
# we only prune the first conv layer
config = MagnitudePrunerConfig.from_dict(
    {
        "module_name_configs": {
            "conv1": {
                "scheduler": {"update_steps": [3, 5, 7]},
                "target_sparsity": 0.75,
                "granularity": "per_channel",
            },
        }
    }
)

pruner = MagnitudePruner(model, config)

# insert pruning layers in the model
model = pruner.prepare()

# data is an iterable of (inputs, labels) training batches
for inputs, labels in data:
    output = model(inputs)
    loss = loss_fn(output, labels)
    loss.backward()
    optimizer.step()
    pruner.step()

# commit pruning masks to model parameters
pruner.finalize(inplace=True)
```
- Parameters:
  - model (`torch.nn.Module`) – Model on which the pruner will act.
  - config (`MagnitudePrunerConfig`) – Config which specifies how different submodules in the model will be configured for pruning. Default config is used when passed as `None`.
- finalize(model: Module | None = None, inplace: bool = False) → Module
  Prepares the model for export. Removes pruning forward pre-hooks attached to submodules and commits pruning changes to pruned module parameters by multiplying the pruning masks with the parameter matrix.
  - Parameters:
    - model (`nn.Module`) – Model to finalize.
    - inplace (`bool`) – If `True`, model transformations are carried out in-place and the original module is mutated; otherwise a copy of the model is mutated and returned.
- prepare(inplace: bool = False) → Module [source]
  Prepares the model for pruning.
  - Parameters:
    - inplace (`bool`) – If `True`, model transformations are carried out in-place and the original module is mutated; otherwise a copy of the model is mutated and returned.
- report() → _Report
  Returns a dictionary with important statistics related to the current state of pruning. Each key in the dictionary corresponds to a module name, and the value is a dictionary containing statistics such as `unstructured_weight_sparsity` and the number of parameters. It also contains a `global` key containing the same statistics aggregated over all the modules set up for pruning.
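For example, a sketch of inspecting the pruning statistics during training (the exact keys present depend on the modules being pruned):

```python
# after one or more calls to pruner.step()
report = pruner.report()
for module_name, stats in report.items():
    print(module_name, stats)
```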
Pruning scheduler
The `coremltools.optimize.torch.pruning.pruning_scheduler` submodule contains classes that implement pruning schedules, which can be used for changing the sparsity of pruning masks applied by various types of pruning algorithms to prune neural network parameters.
- class coremltools.optimize.torch.pruning.pruning_scheduler.PruningScheduler[source]
Bases: `ABC`
An abstraction for implementing schedules to be used for changing the sparsity of pruning masks applied by various types of pruning algorithms to module parameters over the course of the training.
- class coremltools.optimize.torch.pruning.pruning_scheduler.PolynomialDecayScheduler(update_steps: List[int] | str | Tensor, power: int = 3)[source]
Bases: `PruningScheduler`
A pruning scheduler inspired by the paper “To prune or not to prune”.
If \(t\) is in \(update\_steps\), it sets the sparsity at step \(t\) using the formula:

\[sparsity_t = target\_sparsity + (initial\_sparsity - target\_sparsity) * \left(1 - \frac{update\_index}{total\_number\_of\_updates}\right)^{power}\]

Otherwise, the sparsity is kept at its previous value. Here, \(update\_index\) is the index of \(t\) in the \(update\_steps\) array, and \(total\_number\_of\_updates\) is the length of the \(update\_steps\) array.
- Parameters:
  - update_steps (`list` of `int` or `str`) – The indices of optimization steps at which pruning should be performed. This can be passed in as a string representing the range, such as `range(start_index, end_index, step_size)`.
  - power (`int`, optional) – Exponent to be used in the sparsity function. Defaults to `3`.
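For example, the following two schedulers are intended to be equivalent, using a list and a string range specification respectively (a sketch based on the parameter description above):

```python
from coremltools.optimize.torch.pruning.pruning_scheduler import PolynomialDecayScheduler

scheduler = PolynomialDecayScheduler(update_steps=[0, 100, 200, 300], power=3)
scheduler = PolynomialDecayScheduler(update_steps="range(0, 400, 100)", power=3)
```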
- compute_sparsity(step_count: int, prev_sparsity: float, config: ModuleOptimizationConfig) → float [source]
  Compute the sparsity at the next step given the previous sparsity and the module optimization config.
  - Parameters:
    - step_count (`int`) – Current step count.
    - prev_sparsity (`float`) – Sparsity at previous step.
    - config (`ModuleOptimizationConfig`) – Optimization config for the module, which contains information such as target sparsity and initial sparsity.
- class coremltools.optimize.torch.pruning.pruning_scheduler.ConstantSparsityScheduler(begin_step: int)[source]
Bases: `PruningScheduler`

A pruning schedule with constant sparsity throughout training.

Sparsity is set to zero initially and to `target_sparsity` at step `begin_step`.

- Parameters:
  - begin_step (`int`) – Step at which to begin pruning.
- compute_sparsity(step_count: int, prev_sparsity: float, config: ModuleOptimizationConfig) → float [source]
  Compute the sparsity at the next step given the previous sparsity and the module optimization config.
  - Parameters:
    - step_count (`int`) – Current step count.
    - prev_sparsity (`float`) – Sparsity at previous step.
    - config (`ModuleOptimizationConfig`) – Optimization config for the module, which contains information such as target sparsity and initial sparsity.
SparseGPT
- class coremltools.optimize.torch.layerwise_compression.LayerwiseCompressorConfig(layers: List[Module | str] | ModuleList | None = None, global_config: LayerwiseCompressionAlgorithmConfig | None = None, module_type_configs: ModuleTypeConfigType = NOTHING, module_name_configs: Dict[str, LayerwiseCompressionAlgorithmConfig | None] = NOTHING, input_cacher: str = 'default', calibration_nsamples: int = 128)[source]
Configuration class for specifying how different submodules of a model are compressed by `LayerwiseCompressor`. Note that only sequential models are supported.

- Parameters:
  - layers (`list` of `torch.nn.Module` or `str`) – List of layers to be compressed. When items in the list are `str`, the string can be a regex or the exact name of the module. The layers listed should be immediate child modules of the parent container `torch.nn.Sequential` model, and they should be contiguous. That is, the output of layer `n` should be the input to layer `n+1`.
  - global_config (`ModuleGPTQConfig` or `ModuleSparseGPTConfig`) – Config to be applied globally to all supported modules. Missing values are chosen from the default config.
  - module_type_configs (`dict` of `str` to `ModuleGPTQConfig` or `ModuleSparseGPTConfig`) – Module type configs applied to a specific module class, such as `torch.nn.Linear`. The keys can be either strings or module classes.
  - module_name_configs (`dict` of `str` to `ModuleGPTQConfig` or `ModuleSparseGPTConfig`) – Module-level configs applied to specific modules. The name of the module must either be a regex or a fully qualified name that can be used to fetch it from the top level module using the `module.get_submodule(target)` method.
  - input_cacher (`str` or `FirstLayerInputCacher`) – Cacher object that caches inputs which are then fed to the first layer set up for compression.
  - calibration_nsamples (`int`) – Number of samples to be used for calibration.
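A sketch of constructing such a config programmatically for SparseGPT (the layer name pattern below is hypothetical and depends on the model being compressed):

```python
from coremltools.optimize.torch.layerwise_compression import LayerwiseCompressorConfig
from coremltools.optimize.torch.layerwise_compression.algorithms import ModuleSparseGPTConfig

# apply SparseGPT at 50% sparsity to all layers whose names match the regex
config = LayerwiseCompressorConfig(
    layers=["model.layers.*"],
    global_config=ModuleSparseGPTConfig(target_sparsity=0.5),
    calibration_nsamples=128,
)
```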
- as_dict() → Dict[str, Any]
  Returns the config as a dictionary.
- classmethod from_dict(config_dict: Dict[str, Any]) → LayerwiseCompressorConfig [source]
  Create class from a dictionary of string keys and values.
  - Parameters:
    - config_dict (`dict` of `str` and values) – A nested dictionary of strings and values.
- classmethod from_yaml(yml: IO | str) → DictableDataClass
  Create class from a yaml stream.
  - Parameters:
    - yml – An `IO` stream containing yaml or a `str` path to the yaml file.
- class coremltools.optimize.torch.layerwise_compression.LayerwiseCompressor(model: Module, config: LayerwiseCompressorConfig)[source]
A post-training compression algorithm which compresses a sequential model layer by layer by minimizing the quantization error while quantizing the weights. The implementation supports two variations of this algorithm: GPTQ and SparseGPT.
At a high level, it compresses weights of a model layer by layer by minimizing the L2 norm of the difference between the original activations and activations obtained from compressing the weights of a layer. The activations are computed using a few samples of training data.
Only sequential models are supported, where the output of one layer feeds into the input of the next layer.
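Conceptually, the per-layer objective can be sketched as follows (illustrative only, not the library's implementation):

```python
import torch


def layerwise_reconstruction_error(layer, compressed_layer, calibration_inputs):
    # L2 norm of the difference between the original activations and the
    # activations produced with the compressed weights, summed over samples
    return sum(
        torch.linalg.norm(layer(x) - compressed_layer(x)) for x in calibration_inputs
    )
```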
For HuggingFace models, disable the `use_cache` config. This is used to speed up decoding, but to generalize the forward pass for `LayerwiseCompressor` algorithms across all model types, the behavior must be disabled.

Example
```python
import torch.nn as nn
from collections import OrderedDict

from coremltools.optimize.torch.layerwise_compression import (
    LayerwiseCompressor,
    LayerwiseCompressorConfig,
)

model = nn.Sequential(
    OrderedDict(
        {
            "conv": nn.Conv2d(1, 20, (3, 3)),
            "relu1": nn.ReLU(),
            "conv2": nn.Conv2d(20, 20, (3, 3)),
            "relu2": nn.ReLU(),
        }
    )
)

dataloader = load_calibration_data()  # define the calibration data loader

# initialize the quantizer
config = LayerwiseCompressorConfig.from_dict(
    {
        "global_config": {
            "algorithm": "gptq",
            "weight_dtype": "int4",
        },
        "input_cacher": "default",
        "calibration_nsamples": 16,
    }
)

compressor = LayerwiseCompressor(model, config)

compressed_model = compressor.compress(dataloader, device="cpu")
```
- Parameters:
  - model (`torch.nn.Module`) – Module to be compressed.
  - config (`LayerwiseCompressorConfig`) – Config that specifies how different submodules in the model will be compressed.
- compress(dataloader: Iterable, device: str, inplace: bool = False) → Module [source]
  Compresses the model using samples from `dataloader`.
  - Parameters:
    - dataloader (`Iterable`) – An iterable where each element is an input to the model to be compressed.
    - device (`str`) – Device string for the device to run compression on.
    - inplace (`bool`) – If `True`, model transformations are carried out in-place and the original module is mutated; otherwise a copy of the model is mutated and returned. Defaults to `False`.
- class coremltools.optimize.torch.layerwise_compression.algorithms.ModuleSparseGPTConfig(target_sparsity: float = 0.5, n_m_ratio: Tuple[int, int] | None = None, weight_dtype: str | dtype = 'uint8', quantization_granularity='per_channel', quantization_scheme='symmetric', enable_normal_float: bool = False, hessian_dampening: float = 0.01, processing_group_size: int = 128, algorithm: str = 'sparse_gpt')[source]
Bases: `LayerwiseCompressionAlgorithmConfig`
Configuration class for specifying global and module-level compression options for the Sparse Generative Pre-Trained Transformer (SparseGPT) algorithm.
- Parameters:
  - target_sparsity (`float`) – Fraction of weight elements to set to `0`. Defaults to `0.5`.
  - n_m_ratio (`tuple` of `int`) – A tuple of two integers which specify how `n:m` pruning should be applied. In `n:m` pruning, out of every `m` elements, the `n` with the lowest magnitude are set to zero. When `n_m_ratio` is not `None`, the value of `target_sparsity` is ignored and the actual target sparsity is determined by the `n:m` ratio.
  - weight_dtype (`torch.dtype`) – The dtype to use for quantizing the weights. The number of bits used for quantization is inferred from the dtype. When the dtype is set to `torch.float32`, the weights corresponding to that layer are not quantized. Defaults to `uint8`.
  - quantization_granularity (`QuantizationGranularity`) – Specifies the granularity at which quantization parameters will be computed. Can be one of `per_channel`, `per_tensor` or `per_block`. When using `per_block`, the `block_size` argument must be specified. Defaults to `per_channel`.
  - quantization_scheme (`QuantizationScheme`) – Type of quantization configuration to use. When this parameter is set to `QuantizationScheme.symmetric`, all weights are quantized with the zero point as zero. When it is set to `QuantizationScheme.affine`, the zero point can be set anywhere in the range of values allowed for the quantized weight. Defaults to `QuantizationScheme.symmetric`.
  - enable_normal_float (`bool`) – When `True`, the normal float format is used for quantization. It is only supported when `weight_dtype` is equal to `int3` or `int4`.
  - hessian_dampening (`float`) – Dampening factor added to the diagonal of the Hessian used by the GPTQ algorithm. Defaults to `0.01`.
  - processing_group_size (`int`) – The weights are updated in blocks of size `processing_group_size`. Defaults to `128`.
- class coremltools.optimize.torch.layerwise_compression.algorithms.SparseGPT(layer: Module, config: ModuleSparseGPTConfig)[source]
Bases: `OBSCompressionAlgorithm`

A post-training compression algorithm based on the paper “SparseGPT: Massive Language Models Can be Accurately Pruned in One-Shot”.

- Parameters:
  - layer (`torch.nn.Module`) – Module to be compressed.
  - config (`ModuleSparseGPTConfig`) – Config specifying hyper-parameters for the SparseGPT algorithm.
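A sketch of running SparseGPT through LayerwiseCompressor, following the example above (model and dataloader are assumed to be defined as in that example):

```python
from coremltools.optimize.torch.layerwise_compression import (
    LayerwiseCompressor,
    LayerwiseCompressorConfig,
)

# select the SparseGPT algorithm and prune 50% of the weights in each supported layer
config = LayerwiseCompressorConfig.from_dict(
    {
        "global_config": {
            "algorithm": "sparse_gpt",
            "target_sparsity": 0.5,
        },
        "calibration_nsamples": 16,
    }
)

compressor = LayerwiseCompressor(model, config)
compressed_model = compressor.compress(dataloader, device="cpu")
```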