Source code for coremltools.optimize.torch.palettization.post_training_palettization

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import logging as _logging
from collections import OrderedDict as _OrderedDict
from typing import Any as _Any
from typing import Callable as _Callable
from typing import Dict as _Dict
from typing import NewType as _NewType
from typing import Optional as _Optional
from typing import Tuple as _Tuple
from typing import Union as _Union

import cattrs as _cattrs
import torch as _torch
from attr import define as _define
from attr import field as _field
from attrs import validators as _validators

from coremltools.optimize.torch._utils.k_means import KMeansConfig as _KMeansConfig
from coremltools.optimize.torch._utils.k_means import (
    KMeansSupportedModulesRegistry as _KMeansSupportedModulesRegistry,
)
from coremltools.optimize.torch._utils.k_means import ParallelKMeans as _ParallelKMeans
from coremltools.optimize.torch._utils.k_means import SequentialKMeans as _SequentialKMeans
from coremltools.optimize.torch._utils.report_utils import (
    compute_post_training_report as _compute_post_training_report,
)
from coremltools.optimize.torch._utils.torch_utils import (
    maybe_convert_str_to_dtype as _maybe_convert_str_to_dtype,
)
from coremltools.optimize.torch._utils.torch_utils import (
    maybe_convert_str_to_mod_type as _maybe_convert_str_to_mod_type,
)
from coremltools.optimize.torch._utils.validation_utils import (
    validate_param_config as _validate_param_config,
)
from coremltools.optimize.torch.base_model_optimizer import (
    BasePostTrainingModelOptimizer as _BasePostTrainingModelOptimizer,
)
from coremltools.optimize.torch.base_model_optimizer import _Report
from coremltools.optimize.torch.optimization_config import (
    ModuleOptimizationConfig as _ModuleOptimizationConfig,
)
from coremltools.optimize.torch.optimization_config import OptimizationConfig as _OptimizationConfig
from coremltools.optimize.torch.optimization_config import (
    PalettizationGranularity,
    _structure_from_dict_hook_factory,
)

_logger = _logging.getLogger(__name__)


@_define
class ModulePostTrainingPalettizerConfig(_ModuleOptimizationConfig):
    """
    Configuration class for specifying global and module-level palettization options for
    the :py:class:`PostTrainingPalettizer` algorithm.

    Args:
        n_bits (:obj:`int`): Number of bits to use for palettizing the weights. Defaults to ``4``.
        lut_dtype (:py:class:`torch.dtype`): The dtype to use for representing each element in
            lookup tables. When the value is ``None``, no quantization is performed. Supported
            values are :py:class:`torch.int8` and :py:class:`torch.uint8`. Defaults to ``None``.
        granularity (:py:class:`PalettizationGranularity`): Granularity for palettization.
            One of ``per_tensor`` or ``per_grouped_channel``. Defaults to ``per_tensor``.
        group_size (:obj:`int`): The number of channels in a group. Only effective when
            granularity is ``per_grouped_channel``.
        channel_axis (:obj:`int`): The channel axis along which groups of channels are formed.
            Only effective when granularity is ``per_grouped_channel``. Defaults to the output
            channel axis.
        cluster_dim (:obj:`int`): The dimension of centroids for each lookup table. The centroid
            is a scalar by default. When ``cluster_dim > 1``, it indicates 2-D clustering, and
            each ``cluster_dim``-length vector of weights along the output channel is palettized
            using the same 2-D centroid. The length of each entry in the lookup tables is equal
            to ``cluster_dim``.
        enable_per_channel_scale (:obj:`bool`): When set to ``True``, weights are normalized along
            the output channels using per-channel scales before being palettized. This is not
            supported with ``cluster_dim > 1``.

    This class supports two different configurations to structure the palettization:

    1. **Per-tensor palettization**: This is the default configuration, where the whole tensor
       shares a single lookup table. The ``granularity`` is set to ``per_tensor``, and
       ``group_size`` is ``None``.

    2. **Per-grouped-channel palettization**: In this configuration, ``group_size`` channels
       along ``channel_axis`` share the same lookup table. For example, for a weight matrix of
       shape ``(16, 25)``, if we provide ``group_size = 8``, the shape of the lookup table would
       be ``(2, 2^n_bits)``.

    .. note::
        Grouping is currently only supported along either the input or output channel axis.
    """

    n_bits: _Optional[int] = _field(
        default=4, validator=_validators.optional(_validators.instance_of(int))
    )
    lut_dtype: _torch.dtype = _field(
        default=None,
        converter=lambda val: _maybe_convert_str_to_dtype(val) if val else val,
        validator=_validators.optional(
            [
                _validators.instance_of(_torch.dtype),
                _validators.in_([_torch.int8, _torch.uint8]),
            ]
        ),
    )
    granularity: PalettizationGranularity = _field(
        default="per_tensor",
        converter=PalettizationGranularity,
        validator=_validators.in_(PalettizationGranularity),
    )
    group_size: _Optional[int] = _field(
        default=None, validator=_validators.optional(_validators.instance_of(int))
    )
    channel_axis: int = _field(
        default=0,
        validator=_validators.optional([_validators.instance_of(int), _validators.in_([0, 1])]),
    )
    cluster_dim: _Optional[int] = _field(
        default=None, validator=_validators.optional(_validators.instance_of(int))
    )
    enable_per_channel_scale: _Optional[bool] = _field(
        default=False, validator=_validators.optional(_validators.instance_of(bool))
    )

    @group_size.validator
    def per_grouped_channel_granularity(self, attribute, value):
        if self.granularity == PalettizationGranularity.per_grouped_channel:
            assert (
                value is not None
            ), "group_size has to be specified along with per_grouped_channel granularity."
            assert value > 0, "group_size should be greater than zero"
        else:
            assert (
                value is None
            ), "group_size can't be specified along with per_tensor granularity."

    @cluster_dim.validator
    def no_per_channel_scale(self, attribute, value):
        if value and value > 1:
            assert (
                self.enable_per_channel_scale == False
            ), f"Enabling per_channel_scale is not supported for cluster_dim={value} larger than 1"
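
# Example (illustrative sketch, not part of the library source): a per-grouped-channel
# configuration in which groups of 8 output channels share one lookup table, with the
# lookup-table entries quantized to torch.int8. The specific values are assumptions
# chosen for illustration.
#
#     per_group_config = ModulePostTrainingPalettizerConfig(
#         n_bits=4,
#         granularity="per_grouped_channel",
#         group_size=8,
#         channel_axis=0,
#         lut_dtype=_torch.int8,
#     )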

_ModuleTypeConfigType = _NewType(
    "ModuleTypeConfigType",
    _Dict[_Union[_Callable, str], _Optional[ModulePostTrainingPalettizerConfig]],
)

@_define
class PostTrainingPalettizerConfig(_OptimizationConfig):
    """
    Configuration class for specifying how different submodules of a model should be
    post-training palettized by :py:class:`PostTrainingPalettizer`.

    Args:
        global_config (:py:class:`ModulePostTrainingPalettizerConfig`): Config to be applied
            globally to all supported modules.
        module_type_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingPalettizerConfig`):
            Module type configs applied to a specific module class, such as
            :py:class:`torch.nn.Linear`. The keys can be either strings or module classes.
        module_name_configs (:obj:`dict` of :obj:`str` to :py:class:`ModulePostTrainingPalettizerConfig`):
            Module name configs applied to specific modules. This can be a dictionary with module
            names pointing to their corresponding :py:class:`ModulePostTrainingPalettizerConfig`.
    """

    global_config: _Optional[ModulePostTrainingPalettizerConfig] = _field(
        default=None,
        validator=_validators.optional(_validators.instance_of(ModulePostTrainingPalettizerConfig)),
    )
    module_type_configs: _ModuleTypeConfigType = _field(
        factory=_OrderedDict,
        validator=_validators.deep_mapping(
            key_validator=_validators.instance_of((str, _Callable)),
            value_validator=_validators.optional(
                _validators.instance_of(ModulePostTrainingPalettizerConfig)
            ),
            mapping_validator=_validators.instance_of(dict),
        ),
    )
    module_name_configs: _Dict[str, _Optional[ModulePostTrainingPalettizerConfig]] = _field(
        factory=_OrderedDict,
        validator=_validators.deep_mapping(
            key_validator=_validators.instance_of(str),
            value_validator=_validators.optional(
                _validators.instance_of(ModulePostTrainingPalettizerConfig)
            ),
            mapping_validator=_validators.instance_of(dict),
        ),
    )

    def __attrs_post_init__(self):
        if (
            self.global_config is None
            and len(self.module_type_configs) == 0
            and len(self.module_name_configs) == 0
        ):
            self.global_config = ModulePostTrainingPalettizerConfig()
        self.module_type_configs = {
            _maybe_convert_str_to_mod_type(key): val
            for key, val in self.module_type_configs.items()
        }

    @classmethod
    def from_dict(cls, config_dict: _Dict[str, _Any]) -> "PostTrainingPalettizerConfig":
        super().from_dict(config_dict)
        converter = _cattrs.Converter(forbid_extra_keys=True)
        converter.register_structure_hook(
            _ModuleTypeConfigType,
            _structure_from_dict_hook_factory(ModulePostTrainingPalettizerConfig),
        )
        return converter.structure_attrs_fromdict(config_dict, cls)
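
# Example (illustrative sketch): building a config from a dictionary, combining a global
# 4-bit setting with a per-module-type override for ``torch.nn.Linear`` and a module that
# is left uncompressed by name. The module name "lm_head" is a hypothetical example.
#
#     config = PostTrainingPalettizerConfig.from_dict(
#         {
#             "global_config": {"n_bits": 4},
#             "module_type_configs": {
#                 "Linear": {
#                     "n_bits": 6,
#                     "granularity": "per_grouped_channel",
#                     "group_size": 8,
#                 },
#             },
#             "module_name_configs": {
#                 "lm_head": None,  # a None config skips palettization for this module
#             },
#         }
#     )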

class PostTrainingPalettizer(_BasePostTrainingModelOptimizer):
    """
    Perform post-training palettization on a torch model. After palettization, all the weights
    in supported layers point to elements in a lookup table, computed by running a k-means
    operation on each weight tensor.

    Example:
        .. code-block:: python

            from collections import OrderedDict

            import torch.nn as nn

            from coremltools.optimize.torch.palettization import (
                PostTrainingPalettizerConfig,
                PostTrainingPalettizer,
            )

            model = nn.Sequential(
                OrderedDict(
                    {
                        "conv": nn.Conv2d(1, 20, (3, 3)),
                        "relu1": nn.ReLU(),
                        "conv2": nn.Conv2d(20, 20, (3, 3)),
                        "relu2": nn.ReLU(),
                    }
                )
            )

            # initialize the palettizer
            config = PostTrainingPalettizerConfig.from_dict(
                {
                    "global_config": {
                        "n_bits": 4,
                    },
                }
            )

            ptpalettizer = PostTrainingPalettizer(model, config)
            palettized_model = ptpalettizer.compress()

    Args:
        model (:obj:`torch.nn.Module`): Module to be compressed.
        config (:py:class:`PostTrainingPalettizerConfig`): Config that specifies how
            different submodules in the model will be palettized.
    """

    _supported_modules: _Tuple = _KMeansSupportedModulesRegistry.get_supported_modules()

    def __init__(self, model: _torch.nn.Module, config: PostTrainingPalettizerConfig = None):
        config = PostTrainingPalettizerConfig() if config is None else config
        super().__init__(model, config)

    def compress(self, num_kmeans_workers: int = 1, inplace: bool = False) -> _torch.nn.Module:
        """
        The compress method performs a k-means operation on all supported modules.

        Args:
            num_kmeans_workers (:obj:`int`): Number of worker processes used for
                performing post-training palettization. Defaults to ``1``.
            inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and
                the original module is mutated, otherwise a copy of the model is mutated and
                returned. Defaults to ``False``.
        """
        self._model = super().compress(inplace=inplace)
        kmeans_config_dict = dict()
        for name, submodule in self._model.named_modules():
            submod_config = self._config.get_module_config(name, submodule)
            if submod_config is None:
                continue

            k_means_module_cls = _KMeansSupportedModulesRegistry.get_kmeans_module(submodule)
            if k_means_module_cls is None:
                continue

            for param_name in k_means_module_cls.parameter_names:
                # Validate configuration for parameter
                param = submodule.get_parameter(param_name)
                updated_config = _validate_param_config(
                    name + "." + param_name,
                    param,
                    submodule,
                    submod_config,
                    ["palettization_group_size", "palettization_cluster_dim"],
                )
                if not updated_config:
                    continue

                if name not in kmeans_config_dict:
                    kmeans_config_dict[name] = {}

                kmeans_config_dict[name][param_name] = _KMeansConfig(
                    n_bits=updated_config.n_bits,
                    axis=updated_config.channel_axis,
                    lut_dtype=updated_config.lut_dtype,
                    block_size=updated_config.group_size,
                    cluster_dim=updated_config.cluster_dim,
                    enable_per_channel_scale=updated_config.enable_per_channel_scale,
                )

        if num_kmeans_workers > 1:
            return _ParallelKMeans.cluster_weights(
                self._model, kmeans_config_dict, num_workers=num_kmeans_workers
            )
        else:
            return _SequentialKMeans.cluster_weights(self._model, kmeans_config_dict)

    def report(self) -> _Report:
        return _compute_post_training_report(
            self._uncompressed_model,
            self._model,
            supported_modules=self._supported_modules,
        )
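
# Example (illustrative sketch): running compression with multiple k-means worker
# processes and inspecting the per-layer report. ``model`` and ``config`` are assumed
# to be constructed as in the class docstring above.
#
#     palettizer = PostTrainingPalettizer(model, config)
#     palettized_model = palettizer.compress(num_kmeans_workers=4, inplace=False)
#     print(palettizer.report())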