Combining Compression Types#
In previous sections on palettization, quantization, and sparsity, we considered how to apply each of these compression techniques to the weights and activations of a model independently. In this section, we describe how the techniques can be combined, which can yield further disk-size savings and latency improvements.
We start by looking at how to take an uncompressed mlpackage and produce a jointly compressed model using the ct.optimize.coreml.* APIs. As discussed in previous sections, this approach may or may not yield a highly accurate model. In some cases, however, it is the best way to get the model into the desired format to test the expected disk-size savings and performance (latency, runtime memory, etc.). Once a model has the desired performance characteristics, a more accurate model can be generated by applying the various data-based optimization methods available in ct.optimize.torch.*. This last topic is covered via a few API code snippets in the sections below.
Combining compression types on an mlpackage#
Joint palettization and quantization#
This means using a lookup table (LUT) whose values have the dtype INT8/UINT8 instead of the default Float16. This can help speed up inference when combined with INT8 activations. For instance, you could take an A16W16 model, quantize the activations to get an A8W16 model, and then quantize the weights to a 4-bit LUT with INT8 dtype to yield an A8W4 model, where “W4” refers to palettized weights with a LUT that has 2^4 entries, each of dtype INT8. When such a model is run on the Neural Engine (on newer SoCs, i.e. A17 Pro, M4, or later), it utilizes the faster int8-int8 compute path.
from coremltools.optimize.coreml import (
    OptimizationConfig,
    OpPalettizerConfig,
    OpLinearQuantizerConfig,
    palettize_weights,
    linear_quantize_weights,
)

# mlmodel: an uncompressed mlpackage, loaded into memory

# first palettize the model
# this will produce a LUT with Float16 values
op_config = OpPalettizerConfig(nbits=4)
config = OptimizationConfig(global_config=op_config)
mlmodel_palettized = palettize_weights(mlmodel, config)

# now apply weight quantization on the model,
# with "joint_compression" set to True.
# this will result in quantizing the LUT to 8 bits.
# (granularity must be set to "per_tensor" for this scenario)
op_config = OpLinearQuantizerConfig(mode="linear_symmetric",
                                    granularity="per_tensor")
linear_weight_quantize_config = OptimizationConfig(global_config=op_config)

mlmodel_palettized_with_8bit_lut = linear_quantize_weights(mlmodel_palettized,
                                                           linear_weight_quantize_config,
                                                           joint_compression=True)
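The snippet above compresses only the weights (W4 with an 8-bit LUT). To also obtain 8-bit activations (the “A8” part of A8W4), the activations need to be quantized as well, which requires calibration data. Below is a minimal sketch, assuming coremltools 8.x, where activation quantization of an mlpackage is exposed through the experimental linear_quantize_activations API; the input name "input", the sample shape, and the number of calibration samples are placeholders to adapt to your model.
import numpy as np

from coremltools.optimize.coreml import OptimizationConfig
from coremltools.optimize.coreml.experimental import (
    OpActivationLinearQuantizerConfig,
    linear_quantize_activations,
)

# calibration data: a list of dicts mapping input names to sample values
# ("input" and the shape below are placeholders for your model's actual inputs)
sample_data = [{"input": np.random.rand(1, 3, 256, 256)} for _ in range(32)]

act_quant_config = OptimizationConfig(
    global_config=OpActivationLinearQuantizerConfig(mode="linear_symmetric")
)
mlmodel_a8w4 = linear_quantize_activations(mlmodel_palettized_with_8bit_lut,
                                           act_quant_config,
                                           sample_data)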
Joint sparsity and quantization#
This means quantizing the non-zero values in the sparse weight tensor to INT8/UINT8 values. This could improve inference speed and disk savings.
from coremltools.optimize.coreml import (
    OptimizationConfig,
    OpMagnitudePrunerConfig,
    OpLinearQuantizerConfig,
    prune_weights,
    linear_quantize_weights,
)

# first prune the model
op_config = OpMagnitudePrunerConfig(target_sparsity=0.80)
config = OptimizationConfig(global_config=op_config)
mlmodel_pruned = prune_weights(mlmodel, config=config)

# now apply weight quantization on the model,
# with "joint_compression" set to True.
# this will result in quantizing the non-zero values to 8 bits.
linear_weight_quantize_config = OptimizationConfig(
    global_config=OpLinearQuantizerConfig(mode="linear_symmetric")
)

mlmodel_pruned_quantized = linear_quantize_weights(mlmodel_pruned,
                                                   linear_weight_quantize_config,
                                                   joint_compression=True)
Joint sparsity and palettization#
This means representing the non-zero values in a sparse weight tensor with indices into a lookup table, i.e., palettizing them.
from coremltools.optimize.coreml import (
    OptimizationConfig,
    OpMagnitudePrunerConfig,
    OpPalettizerConfig,
    prune_weights,
    palettize_weights,
)

# first prune the model
op_config = OpMagnitudePrunerConfig(target_sparsity=0.80)
pruning_config = OptimizationConfig(global_config=op_config)
mlmodel_pruned = prune_weights(mlmodel, config=pruning_config)

# now apply weight palettization on the model,
# with "joint_compression" set to True.
# this will result in palettizing the non-zero values.
palettization_config = OptimizationConfig(global_config=OpPalettizerConfig(nbits=4))

mlmodel_pruned_palettized = palettize_weights(mlmodel_pruned,
                                              palettization_config,
                                              joint_compression=True)
Combining compression types on a Torch model#
Joint palettization and quantization#
This means using a lookup table (LUT) whose values are of the dtype INT8/UINT8 instead of the default Float16.
import torchvision
import torch

import coremltools as ct
from coremltools.optimize.torch.palettization import PostTrainingPalettizerConfig, \
                                                     PostTrainingPalettizer

# load a torch model, e.g. ResNet50
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.eval()

# specify "lut_dtype" as torch.int8;
# when not specified, it defaults to None and a Float16 LUT is constructed
config_dict = {"global_config": {"n_bits": 4, "lut_dtype": torch.int8}}
palettizer_config = PostTrainingPalettizerConfig.from_dict(config_dict)
compressor = PostTrainingPalettizer(model, palettizer_config)
compressed_model = compressor.compress()

# trace and convert the compressed model
traced_model = torch.jit.trace(compressed_model, torch.rand(1, 3, 256, 256))
mlmodel = ct.convert(traced_model,
                     inputs=[ct.TensorType(shape=(1, 3, 256, 256))],
                     minimum_deployment_target=ct.target.macOS15,
                     )
mlmodel.save("model_4bit_palettized_with_8bit_quantized_lut.mlpackage")
Joint sparsity and quantization#
One way to combine sparsity and quantization is to first prune a torch model using the MagnitudePruner class, export it as an mlpackage, and then apply weight quantization (A16W8) on the mlpackage, as shown in the section above. However, if we want to apply pruning and weight-only quantization (A16W8) on the torch model at training time, it can be done as shown below. Note that if the "activation_dtype" argument in ModuleLinearQuantizerConfig is set to its default value of torch.qint8, then activations will also be quantized, giving an A8W8 model (see the config sketch after the example below).
import torchvision
import torch

import coremltools as ct
from coremltools.optimize.torch.quantization import ModuleLinearQuantizerConfig, \
                                                    LinearQuantizerConfig, \
                                                    LinearQuantizer
from coremltools.optimize.torch.pruning import ModuleMagnitudePrunerConfig, \
                                               MagnitudePrunerConfig, \
                                               MagnitudePruner

# Initialize the model, e.g. ResNet50
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.train()

# Prepare model for joint quantization and pruning
quant_config = LinearQuantizerConfig(
    global_config=ModuleLinearQuantizerConfig(
        quantization_scheme="symmetric",
        activation_dtype=torch.float,
    )
)
prune_config = MagnitudePrunerConfig(
    global_config=ModuleMagnitudePrunerConfig(
        target_sparsity=0.8,
    )
)

# The quantizer config needs to be applied before the pruner config
quantizer = LinearQuantizer(model, quant_config)
quant_model = quantizer.prepare(example_inputs=(torch.rand(1, 3, 256, 256),))
pruner = MagnitudePruner(quant_model, prune_config)
# in-place is required to ensure quantizer and pruner are
# operating on the same model
pruned_quant_model = pruner.prepare(inplace=True)

# construct the optimizer over the prepared model's parameters
optimizer = torch.optim.SGD(pruned_quant_model.parameters(), lr=0.01)

n_classes = 1000
batch_size = 5

# run a couple of training iterations with random data
for i in range(2):
    # Dummy data
    inputs = torch.randn(batch_size, 3, 256, 256)          # Batch of samples
    targets = torch.randint(0, n_classes, (batch_size,))   # Target labels

    # Forward pass
    logits = pruned_quant_model(inputs)
    out = torch.nn.LogSoftmax(dim=1)(logits)
    loss = torch.nn.functional.nll_loss(out, targets)
    print(loss)

    # Backward and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    quantizer.step()
    pruner.step()

# finalize the model for export:
# we first finalize the quantizer, followed by the pruner
quant_finalized_model = quantizer.finalize(inplace=True)
finalized_model = pruner.finalize(quant_finalized_model, inplace=True)
finalized_model.eval()

# trace and export to mlpackage
traced_model = torch.jit.trace(finalized_model, torch.rand(1, 3, 256, 256))
mlmodel = ct.convert(traced_model,
                     inputs=[ct.TensorType(shape=(1, 3, 256, 256))],
                     minimum_deployment_target=ct.target.macOS15,
                     )
mlmodel.save("model_torch_pruned_and_quantized.mlpackage")
Joint sparsity and palettization#
Here we apply magnitude pruning to the torch model, followed by data-free palettization.
Note: to apply training-time pruning and palettization (e.g. DKM), follow the same pattern as the section above, replacing the LinearQuantizer with DKMPalettizer (a minimal sketch is included at the end of this section).
import torchvision
import torch

import coremltools as ct
from coremltools.optimize.torch.pruning import ModuleMagnitudePrunerConfig, \
                                               MagnitudePrunerConfig, \
                                               MagnitudePruner
from coremltools.optimize.torch.palettization import PostTrainingPalettizer, \
                                                     PostTrainingPalettizerConfig, \
                                                     ModulePostTrainingPalettizerConfig

# Apply pruning
# Initialize the model, e.g. ResNet50
model = torchvision.models.resnet50(weights="IMAGENET1K_V2")
model.train()

# Prepare model for pruning
prune_config = MagnitudePrunerConfig(
    global_config=ModuleMagnitudePrunerConfig(
        target_sparsity=0.8,
    )
)
pruner = MagnitudePruner(model, prune_config)
pruned_model = pruner.prepare()

# construct the optimizer over the prepared model's parameters
optimizer = torch.optim.SGD(pruned_model.parameters(), lr=0.01)

# run a couple of training iterations with random data
n_classes = 1000
batch_size = 5
for i in range(2):
    inputs = torch.randn(batch_size, 3, 256, 256)           # Batch of samples
    targets = torch.randint(0, n_classes, (batch_size,))    # Target labels
    logits = pruned_model(inputs)
    out = torch.nn.LogSoftmax(dim=1)(logits)
    loss = torch.nn.functional.nll_loss(out, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    pruner.step()

# finalize model
pruned_model = pruner.finalize(pruned_model, inplace=True)

# Apply palettization
palettization_config = PostTrainingPalettizerConfig(
    global_config=ModulePostTrainingPalettizerConfig(
        n_bits=4,
    )
)
palettizer = PostTrainingPalettizer(pruned_model, palettization_config)
joint_compressed_model = palettizer.compress()

# convert the compressed model
joint_compressed_model.eval()
traced_model = torch.jit.trace(joint_compressed_model, torch.rand(1, 3, 256, 256))
mlmodel = ct.convert(traced_model,
                     inputs=[ct.TensorType(shape=(1, 3, 256, 256))],
                     minimum_deployment_target=ct.target.macOS15,
                     )
mlmodel.save("model_torch_pruned_and_palettized.mlpackage")