Source code for coremltools.optimize.torch.layerwise_compression.layerwise_compressor

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

# Original implementation from https://github.com/IST-DASLab/sparsegpt
# Copyright 2023 IST Austria Distributed Algorithms and Systems Lab. All Rights Reserved.

import logging as _logging
import re as _re
from collections import OrderedDict as _OrderedDict
from contextlib import contextmanager as _contextmanager
from typing import Any as _Any
from typing import Callable as _Callable
from typing import Dict as _Dict
from typing import Iterable as _Iterable
from typing import List as _List
from typing import NewType as _NewType
from typing import Optional as _Optional
from typing import Tuple as _Tuple
from typing import Union as _Union

import cattrs as _cattrs
import torch as _torch
import torch.nn as _nn
from attr import define as _define
from attr import field as _field
from attrs import validators as _validators

from coremltools.optimize.torch._utils.metadata_utils import (
    register_metadata_version as _register_metadata_version,
)
from coremltools.optimize.torch._utils.report_utils import (
    compute_post_training_report as _compute_post_training_report,
)
from coremltools.optimize.torch._utils.torch_utils import get_atomic_layers as _get_atomic_layers
from coremltools.optimize.torch._utils.torch_utils import get_eval_model as _get_eval_model
from coremltools.optimize.torch.base_model_optimizer import (
    BaseDataCalibratedModelOptimizer as _BaseDataCalibratedModelOptimizer,
)
from coremltools.optimize.torch.base_model_optimizer import _Report
from coremltools.optimize.torch.layerwise_compression.algorithms import (
    LayerwiseCompressionAlgorithm as _LayerwiseCompressionAlgorithm,
)
from coremltools.optimize.torch.layerwise_compression.algorithms import (
    LayerwiseCompressionAlgorithmConfig as _LayerwiseCompressionAlgorithmConfig,
)
from coremltools.optimize.torch.layerwise_compression.input_cacher import (
    FirstLayerInputCacher as _FirstLayerInputCacher,
)
from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig

_logger = _logging.getLogger(__name__)


_ModuleTypeConfigType = _NewType(
    "ModuleTypeConfigType",
    _Dict[_Union[_Callable, str], _Optional[_LayerwiseCompressionAlgorithmConfig]],
)
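
# Illustrative shape of a ``module_type_configs`` mapping (an assumption for clarity,
# not part of the original source): keys are module classes or class names, values are
# algorithm configs such as ``ModuleGPTQConfig`` (or ``None`` to leave that module type
# uncompressed), e.g. ``{_torch.nn.Linear: <gptq config>, "Conv2d": None}``.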


_SUPPORTED_MODULES = [_torch.nn.Conv2d, _torch.nn.Linear]


@_define
class LayerwiseCompressorConfig(_OptimizationConfig):
    """
    Configuration class for specifying how different submodules of a model are compressed
    by :py:class:`LayerwiseCompressor`. Note that only sequential models are supported.

    Args:
        layers (:obj:`list` of :py:class:`torch.nn.Module` or :obj:`str`): List of layers to be
            compressed. When items in the list are :obj:`str`, the string can be a regex or the
            exact name of the module. The layers listed should be immediate child modules of the
            parent container :py:class:`torch.nn.Sequential` model, and they should be contiguous.
            That is, the output of layer ``n`` should be the input to layer ``n+1``.
        global_config (:py:class:`ModuleGPTQConfig` or :py:class:`ModuleSparseGPTConfig`): Config to be
            applied globally to all supported modules. Missing values are chosen from the default config.
        module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModuleGPTQConfig` or
            :py:class:`ModuleSparseGPTConfig`): Module type configs applied to a specific module class,
            such as :py:class:`torch.nn.Linear`. The keys can be either strings or module classes.
        module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModuleGPTQConfig` or
            :py:class:`ModuleSparseGPTConfig`): Module-level configs applied to specific modules.
            The name of the module must either be a regex or a fully qualified name that can be used
            to fetch it from the top level module using the ``module.get_submodule(target)`` method.
        input_cacher (:obj:`str` or :py:class:`FirstLayerInputCacher`): Cacher object that caches inputs
            which are then fed to the first layer set up for compression.
        calibration_nsamples (:obj:`int`): Number of samples to be used for calibration.
    """

    layers: _Optional[_Union[_List[_Union[_nn.Module, str]], _nn.ModuleList]] = _field(
        default=None,
        validator=_validators.optional(
            _validators.deep_iterable(
                member_validator=_validators.instance_of((_nn.Module, str)),
                iterable_validator=_validators.instance_of((list, _nn.ModuleList)),
            )
        ),
    )
    global_config: _Optional[_LayerwiseCompressionAlgorithmConfig] = _field(
        default=None,
        validator=_validators.optional(
            _validators.instance_of(_LayerwiseCompressionAlgorithmConfig)
        ),
    )
    module_type_configs: _ModuleTypeConfigType = _field(
        factory=_OrderedDict,
        validator=_validators.deep_mapping(
            key_validator=_validators.instance_of((str, _Callable)),
            value_validator=_validators.optional(
                _validators.instance_of(_LayerwiseCompressionAlgorithmConfig)
            ),
            mapping_validator=_validators.instance_of(dict),
        ),
    )
    module_name_configs: _Dict[str, _Optional[_LayerwiseCompressionAlgorithmConfig]] = _field(
        factory=_OrderedDict,
        validator=_validators.deep_mapping(
            key_validator=_validators.instance_of(str),
            value_validator=_validators.optional(
                _validators.instance_of(_LayerwiseCompressionAlgorithmConfig)
            ),
            mapping_validator=_validators.instance_of(dict),
        ),
    )
    input_cacher: str = _field(default="default", converter=_FirstLayerInputCacher.get_class)
    calibration_nsamples: int = _field(default=128, validator=_validators.instance_of(int))

    @classmethod
    def from_dict(cls, config_dict: _Dict[str, _Any]) -> "LayerwiseCompressorConfig":
        super().from_dict(config_dict)
        converter = _cattrs.Converter(forbid_extra_keys=True)
        converter.register_structure_hook(
            _Optional[_Union[_List[_Union[_nn.Module, str]], _nn.ModuleList]],
            lambda obj, type: obj,
        )
        converter.register_structure_hook(
            _LayerwiseCompressionAlgorithmConfig,
            lambda obj, type: _LayerwiseCompressionAlgorithmConfig.get_class(
                obj["algorithm"]
            ).from_dict(obj),
        )
        converter.register_structure_hook(
            _ModuleTypeConfigType,
            lambda module_type_config, type: {
                key: _LayerwiseCompressionAlgorithmConfig.get_class(val["algorithm"]).from_dict(val)
                if val is not None
                else None
                for key, val in module_type_config.items()
            },
        )
        return converter.structure_attrs_fromdict(config_dict, cls)

    def get_layers(self, model: _nn.Module):
        if self.layers is None:
            for module_name, module in model.named_children():
                yield module_name, module
        else:
            yielded = set()
            for module_name, module in model.named_modules(remove_duplicate=True):
                for layer in self.layers:
                    if isinstance(layer, str) and _re.fullmatch(layer, module_name):
                        if module_name not in yielded:
                            yielded.add(module_name)
                            yield module_name, module
                    elif module == layer:
                        if module_name not in yielded:
                            yielded.add(module_name)
                            yield module_name, module
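

# A usage sketch (illustrative only, not part of the library source). It shows how the
# ``layers`` regexes, ``global_config`` and ``module_name_configs`` fields fit together;
# the module names ("model.layers.*") are assumptions about a hypothetical sequential
# model, and treating a ``None`` override as "skip this module" follows the Optional
# value type of ``module_name_configs``.
def _example_build_config() -> "LayerwiseCompressorConfig":  # pragma: no cover
    return LayerwiseCompressorConfig.from_dict(
        {
            # contiguous child layers of the sequential model, selected by regex
            "layers": [r"model\.layers\.\d+"],
            # default algorithm applied to every supported module (Conv2d, Linear)
            "global_config": {"algorithm": "gptq", "weight_dtype": "int4"},
            # per-module override by regex; None leaves matching modules uncompressed
            "module_name_configs": {r"model\.layers\.0\..*": None},
            "calibration_nsamples": 64,
        }
    )
# ``config.get_layers(model)`` then yields ``(name, module)`` pairs for every module whose
# name fully matches one of the ``layers`` regexes (or matches by identity when ``layers``
# holds module objects).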


@_contextmanager
def _set_torch_flags():
    # TODO: Copied from original implementation; determine if this is necessary
    cuda_matmul_tf32 = _torch.backends.cuda.matmul.allow_tf32
    cudnn_allow_tf32 = _torch.backends.cudnn.allow_tf32
    try:
        _torch.backends.cuda.matmul.allow_tf32 = False
        _torch.backends.cudnn.allow_tf32 = False
        yield
    finally:
        _torch.backends.cuda.matmul.allow_tf32 = cuda_matmul_tf32
        _torch.backends.cudnn.allow_tf32 = cudnn_allow_tf32


class LayerwiseCompressor(_BaseDataCalibratedModelOptimizer):
    """
    A post-training compression algorithm which compresses a sequential model layer by layer
    by minimizing the quantization error while quantizing the weights. The implementation
    supports two variations of this algorithm:

    1) `Generative Pre-Trained Transformer Quantization (GPTQ) <https://arxiv.org/pdf/2210.17323.pdf>`_
    2) `Sparse Generative Pre-Trained Transformer (SparseGPT) <https://arxiv.org/pdf/2301.00774.pdf>`_

    At a high level, it compresses the weights of a model layer by layer by minimizing the L2 norm
    of the difference between the original activations and the activations obtained from compressing
    the weights of a layer. The activations are computed using a few samples of training data.

    Only sequential models are supported, where the output of one layer feeds into the input of
    the next layer.

    For HuggingFace models, disable the ``use_cache`` config. This is used to speed up decoding,
    but to generalize the forward pass for :py:class:`LayerwiseCompressor` algorithms across all
    model types, the behavior must be disabled.

    Example:

        .. code-block:: python

            from collections import OrderedDict

            import torch.nn as nn

            from coremltools.optimize.torch.layerwise_compression import (
                LayerwiseCompressor,
                LayerwiseCompressorConfig,
            )

            model = nn.Sequential(
                OrderedDict(
                    {
                        "conv": nn.Conv2d(1, 20, (3, 3)),
                        "relu1": nn.ReLU(),
                        "conv2": nn.Conv2d(20, 20, (3, 3)),
                        "relu2": nn.ReLU(),
                    }
                )
            )

            dataloader = load_calibration_data()

            # initialize the quantizer
            config = LayerwiseCompressorConfig.from_dict(
                {
                    "global_config": {
                        "algorithm": "gptq",
                        "weight_dtype": "int4",
                    },
                    "input_cacher": "default",
                    "calibration_nsamples": 16,
                }
            )

            compressor = LayerwiseCompressor(model, config)

            compressed_model = compressor.compress(dataloader)

    Args:
        model (:obj:`torch.nn.Module`): Module to be compressed.
        config (:py:class:`LayerwiseCompressorConfig`): Config that specifies how
            different submodules in the model will be compressed.
    """

    _supported_modules: _Tuple = tuple(_SUPPORTED_MODULES)

    def __init__(self, model: _nn.Module, config: LayerwiseCompressorConfig):
        super().__init__(model, config)
        self._input_cacher = self._config.input_cacher(
            self._model,
            self._config.layers,
        )

    @staticmethod
    def _forward_layer(layer, inputs, kwarg_inputs, outputs) -> _List:
        """
        Perform a forward pass on the layer and store its outputs.
        """
        for j, inp in enumerate(inputs):
            if isinstance(inp, _torch.Tensor):
                inp = (inp,)
            outputs[j] = layer(*inp, **kwarg_inputs)
        return outputs

    def _get_cached_inputs(
        self, dataloader: _Iterable, device: str
    ) -> _Tuple[_List[_torch.Tensor], _Dict[str, _torch.Tensor]]:
        """
        Cache the inputs and keyword arguments up to the first layer set up for compression.
        """
        inputs, kwarg_inputs = self._input_cacher.cache(
            dataloader=dataloader,
            nsamples=self._config.calibration_nsamples,
            device=device,
        )
        return inputs, kwarg_inputs

    def _get_layers_to_compress(self) -> _Dict[str, _nn.Module]:
        """
        Returns the layers to be compressed.
        """
        return self._config.get_layers(self._model)

    def _init_and_config_layer(
        self, atomic_layer_name, atomic_layer
    ) -> _Optional[_LayerwiseCompressionAlgorithm]:
        """
        Initializes and configures the compression algorithm for a given atomic layer.

        Returns the initialized and configured compression algorithm object.
        """
        layer_config = self._config.get_module_config(atomic_layer_name, atomic_layer)
        if layer_config is not None:
            algo_class = _LayerwiseCompressionAlgorithm.get_class(layer_config.algorithm)
            try:
                return algo_class(atomic_layer, layer_config)
            except ValueError as error:
                _logger.info(f"Skipping compression for {atomic_layer_name}. Reason={error}")
        return None

    def _register_activation_processing_hook(
        self, atomic_layer, compressor_obj
    ) -> _torch.utils.hooks.RemovableHandle:
        """
        Registers a forward hook on the layer for performing computation using the inputs
        to acquire statistics.

        Returns the handle for the forward hook.
        """

        def activation_processing_hook(_, inp, out):
            compressor_obj.add_batch(inp[0].data, out.data)

        return atomic_layer.register_forward_hook(activation_processing_hook)

    @_torch.no_grad()
    def _compress_impl(self, dataloader: _Iterable, device: str) -> _nn.Module:
        """
        Compresses a model layerwise using the following steps:

        1) Compute inputs to the first layer which is set up for compression, using the input cacher.
        2) For each layer, find submodules which are supported for compression and install
           compression hooks.
        3) Run a forward pass through each layer, compute activation statistics and use them to
           compress weights.
        4) Compute updated outputs using the compressed weights to propagate quantization error to
           the next layer, and set them up as inputs to the next layer.
        """
        inputs, kwarg_inputs = self._get_cached_inputs(dataloader, device)
        outputs = [None for _ in inputs]

        # compress the layers one by one
        for layer_idx, (parent_layer_name, layer) in enumerate(self._get_layers_to_compress()):
            layer.to(device)

            atomic_layers_dict = _get_atomic_layers(
                layer,
                layer_types=self._supported_modules,
                name_prefix=parent_layer_name,
            )

            # dict mapping layer_name -> compression algorithm object
            compression_algo_objects_dict = dict()
            # list of forward hook handles
            layer_hooks = []

            for atomic_layer_name, atomic_layer in atomic_layers_dict.items():
                obj = self._init_and_config_layer(atomic_layer_name, atomic_layer)

                if obj is not None:
                    compression_algo_objects_dict[atomic_layer_name] = obj
                    layer_hooks.append(
                        self._register_activation_processing_hook(atomic_layer, obj)
                    )

            # Compute statistics on the activations using the activation processing hooks
            outputs = self._forward_layer(
                layer,
                inputs,
                kwarg_inputs,
                outputs,
            )

            # Remove the activation processing hooks
            for h in layer_hooks:
                h.remove()

            # compress the layers
            _logger.info(f"Layer {layer_idx}")
            for (
                atomic_layer_name,
                compressor_algo,
            ) in compression_algo_objects_dict.items():
                _logger.info(f"Compressing {atomic_layer_name}")
                compressor_algo.compress()
                compressor_algo.cleanup()

            del compression_algo_objects_dict

            # feed the previous layer's outputs to this layer
            outputs = self._forward_layer(
                layer,
                inputs,
                kwarg_inputs,
                outputs,
            )

            # free memory
            layer.cpu()
            del layer
            _torch.cuda.empty_cache()

            # interchange inputs and outputs
            inputs, outputs = outputs, inputs

        _register_metadata_version(self._model)
        return self._model

    def compress(self, dataloader: _Iterable, device: str, inplace: bool = False) -> _nn.Module:
        """
        Compresses the model using samples from ``dataloader``.

        Args:
            dataloader (:py:class:`Iterable`): An iterable where each element
                is an input to the model to be compressed.
            device (:obj:`str`): Device string for the device to run compression on.
            inplace (:obj:`bool`): If ``True``, model transformations are carried out
                in-place and the original module is mutated, otherwise a copy of the
                model is mutated and returned. Defaults to ``False``.
        """
        self._model = super().compress(dataloader=dataloader, inplace=inplace)
        with _get_eval_model(self._model):
            with _set_torch_flags():
                return self._compress_impl(dataloader, device)

    def report(self) -> _Report:
        return _compute_post_training_report(
            self._uncompressed_model,
            self._model,
            supported_modules=self._supported_modules,
        )
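

# An end-to-end sketch (illustrative only, not part of the library source). The
# ``"cuda:0"`` device string and the ``calibration_loader`` argument are assumptions;
# any torch device string and any iterable of model inputs can be used.
def _example_compress_and_report(
    model: _nn.Module,
    config: LayerwiseCompressorConfig,
    calibration_loader: _Iterable,
):  # pragma: no cover
    compressor = LayerwiseCompressor(model, config)
    # Run layerwise compression on the GPU; a copy of the model is returned because
    # ``inplace`` defaults to ``False``.
    compressed_model = compressor.compress(calibration_loader, device="cuda:0")
    # Per-module summary comparing the compressed model with the uncompressed one.
    report = compressor.report()
    return compressed_model, report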