#  Source code for coremltools.optimize.torch.palettization.palettizer

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import logging as _logging
from typing import Dict as _Dict
from typing import Optional as _Optional

import torch as _torch
import torch.nn as _nn
from torch.ao.quantization import FakeQuantize as _FakeQuantize

from coremltools.optimize.torch._typing import ParamsDict as _ParamsDict
from coremltools.optimize.torch._utils.math_utils import rmse_error as _rmse_error
from coremltools.optimize.torch._utils.metadata_utils import (
    register_metadata_version as _register_metadata_version,
)
from coremltools.optimize.torch._utils.torch_utils import get_eval_model as _get_eval_model
from coremltools.optimize.torch._utils.validation_utils import (
    validate_param_config as _validate_param_config,
)
from coremltools.optimize.torch.base_model_optimizer import (
    BaseTrainingTimeModelOptimizer as _BaseTrainingTimeModelOptimizer,
)
from coremltools.optimize.torch.base_model_optimizer import _Report
from coremltools.optimize.torch.palettization._custom_conversion import (
    PALETTIZATION_CONVERT_DICT as _PALETTIZATION_CONVERT_DICT,
)
from coremltools.optimize.torch.palettization._supported_modules import (
    _get_palettization_qat_mappings,
)
from coremltools.optimize.torch.palettization._supported_modules import (
    get_palettizable_parameters as _get_palettizable_parameters,
)
from coremltools.optimize.torch.palettization.fake_palettize import FakePalettize as _FakePalettize
from coremltools.optimize.torch.palettization.palettization_config import (
    DEFAULT_PALETTIZATION_ADVANCED_OPTIONS as _DEFAULT_PALETTIZATION_ADVANCED_OPTIONS,
)
from coremltools.optimize.torch.palettization.palettization_config import (
    DEFAULT_PALETTIZATION_SCHEME as _DEFAULT_PALETTIZATION_SCHEME,
)
from coremltools.optimize.torch.palettization.palettization_config import (
    DKMPalettizerConfig as _DKMPalettizerConfig,
)
from coremltools.optimize.torch.palettization.palettization_config import (
    ModuleDKMPalettizerConfig as _ModuleDKMPalettizerConfig,
)

_logger = _logging.getLogger(__name__)


class Palettizer(_BaseTrainingTimeModelOptimizer):
    pass


class DKMPalettizer(Palettizer):
    """
    A palettization algorithm based on `"DKM: Differentiable K-Means Clustering Layer for Neural
    Network Compression" <https://arxiv.org/pdf/2108.12659.pdf>`_. It clusters the weights using a
    differentiable version of ``k-means``, allowing the lookup table (LUT) and indices of
    palettized weights to be learnt using a gradient-based optimization algorithm such as SGD.

    Example:
        .. code-block:: python

            import torch
            from coremltools.optimize.torch.palettization import (
                DKMPalettizer,
                DKMPalettizerConfig,
                ModuleDKMPalettizerConfig,
            )

            # code that defines the pytorch model, loss and optimizer.
            model, loss_fn, optimizer = create_model_loss_and_optimizer()

            # initialize the palettizer
            config = DKMPalettizerConfig(global_config=ModuleDKMPalettizerConfig(n_bits=4))

            palettizer = DKMPalettizer(model, config)

            # prepare the model to insert FakePalettize layers for palettization
            model = palettizer.prepare(inplace=True)

            # use palettizer in your PyTorch training loop
            for inputs, labels in data:
                output = model(inputs)
                loss = loss_fn(output, labels)
                loss.backward()
                optimizer.step()
                palettizer.step()

            # fold LUT and indices into weights
            model = palettizer.finalize(inplace=True)

    Args:
        model (:py:class:`torch.nn.Module`): Model on which the palettizer will act.
        config (:py:class:`DKMPalettizerConfig`): Config which specifies how different submodules
            in the model will be configured for palettization. Default config is used when passed
            as ``None``.
    """

    def __init__(self, model: _nn.Module, config: _Optional[_DKMPalettizerConfig] = None):
        config = _DKMPalettizerConfig() if config is None else config
        super().__init__(model, config)
        self._milestones = {}
        self._supported_modules = _get_palettization_qat_mappings()

    def _palettize_supported_modules(self):
        """
        Method to palettize supported modules.
        """
        for name, submodule in self._model.named_modules(remove_duplicate=True):
            config = self._config.get_module_config(name, submodule)
            if type(submodule) in self._supported_modules:
                if config is not None:
                    submod_configs = config if isinstance(config, list) else [config]
                    for submod_config in submod_configs:
                        if all(
                            param.numel() > submod_config.weight_threshold
                            for param, _ in _get_palettizable_parameters(submodule)
                        ):
                            module_level_advanced_options = self._get_module_level_advanced_options(
                                submodule, submod_config
                            )
                            n_bits = (
                                submod_config.n_bits
                                if submod_config.n_bits is not None
                                else _DEFAULT_PALETTIZATION_SCHEME[type(submodule)]["n_bits"]
                            )
                            cluster_dim = (
                                submod_config.cluster_dim
                                if submod_config.cluster_dim is not None
                                else _DEFAULT_PALETTIZATION_SCHEME[type(submodule)]["cluster_dim"]
                            )
                            enable_per_channel_scale = (
                                submod_config.enable_per_channel_scale
                                if submod_config.enable_per_channel_scale is not None
                                else _DEFAULT_PALETTIZATION_SCHEME[type(submodule)][
                                    "enable_per_channel_scale"
                                ]
                            )
                            updated_config = None
                            for param, param_name in _get_palettizable_parameters(submodule):
                                updated_config = _validate_param_config(
                                    name + "." + param_name,
                                    param,
                                    submodule,
                                    submod_config,
                                    [
                                        "palettization_cluster_dim",
                                        "palettization_group_size",
                                    ],
                                    module_level_advanced_options,
                                )
                                if not updated_config:
                                    break
                            if not updated_config:
                                continue

                            self._palettize_module(
                                submodule,
                                n_bits,
                                cluster_dim,
                                enable_per_channel_scale,
                                updated_config.group_size,
                                updated_config.quant_min,
                                updated_config.quant_max,
                                updated_config.lut_dtype,
                                updated_config.dtype,
                                updated_config.quantize_activations,
                                module_level_advanced_options,
                            )
                            self._milestones[name] = updated_config.milestone

    @staticmethod
    def _palettize_module(
        module: _nn.Module,
        n_bits: int,
        cluster_dim: int,
        enable_per_channel_scale: bool,
        group_size: _Optional[int],
        quant_min: int,
        quant_max: int,
        lut_dtype: str,
        dtype: _torch.dtype,
        quantize_activations: bool,
        advanced_options: _Dict,
    ):
        """
        Method to palettize a module.
        """
        fq_activation = _nn.Identity
        fq_weight = _FakePalettize.with_args(
            observer=_torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(
                quant_min=quant_min, quant_max=quant_max, dtype=dtype
            ),
            n_bits=n_bits,
            cluster_dim=cluster_dim,
            enable_per_channel_scale=enable_per_channel_scale,
            group_size=group_size,
            quant_min=quant_min,
            quant_max=quant_max,
            lut_dtype=lut_dtype,
            advanced_options=advanced_options,
        )
        if quantize_activations:
            fq_activation = _FakeQuantize.with_args(
                observer=_torch.quantization.MovingAveragePerChannelMinMaxObserver.with_args(
                    quant_min=quant_min, quant_max=quant_max, dtype=dtype
                ),
                quant_min=quant_min,
                quant_max=quant_max,
            )

        module.qconfig = _torch.quantization.QConfig(activation=fq_activation, weight=fq_weight)

    @staticmethod
    def _get_module_level_advanced_options(
        module: _nn.Module, module_level_config: _ModuleDKMPalettizerConfig
    ) -> _ParamsDict:
        """
        Returns advanced_options for a module, reading each option from the module-level
        palettization config. When ``lut_dtype`` is ``"oc_last"``, ``cluster_permute`` is
        derived from the module's weight dimensions instead.
        """
        module_level_advanced_options = {}
        for key in _DEFAULT_PALETTIZATION_ADVANCED_OPTIONS.keys():
            if key == "cluster_permute" and module_level_config.lut_dtype == "oc_last":
                cluster_permute = list(range(module.weight.dim()))
                cluster_permute = cluster_permute[1:] + cluster_permute[:1]
                module_level_advanced_options[key] = cluster_permute
            else:
                module_level_advanced_options[key] = getattr(module_level_config, key)
        return module_level_advanced_options
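
    # Illustrative sketch (not part of the original source): how a per-module-type
    # configuration reaches ``_palettize_supported_modules`` above. This assumes the
    # standard ``DKMPalettizerConfig`` fields (``global_config``, ``module_type_configs``,
    # ``module_name_configs``); the exact options accepted depend on the installed
    # coremltools version.
    #
    #     config = DKMPalettizerConfig(
    #         module_type_configs={
    #             torch.nn.Linear: ModuleDKMPalettizerConfig(n_bits=2, cluster_dim=1),
    #         }
    #     )
    #     palettizer = DKMPalettizer(model, config)
    #     # get_module_config(name, submodule) then resolves the matching entry for
    #     # each submodule during prepare().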

    def prepare(self, inplace: bool = False) -> _nn.Module:
        """
        Prepares a model for palettization aware training by inserting :py:class:`FakePalettize`
        layers in appropriate places as specified by the config.

        Args:
            inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and
                the original module is mutated, otherwise a copy of the model is mutated and
                returned.
        """
        self._model = self._get_model_for_compression(inplace)
        self._model.train()
        self._palettize_supported_modules()
        qat_mappings = _get_palettization_qat_mappings()
        self._model = _torch.quantization.prepare_qat(
            self._model, mapping=qat_mappings, inplace=True
        )
        return self._model

    def finalize(self, model: _Optional[_nn.Module] = None, inplace: bool = False) -> _nn.Module:
        """
        Removes :py:class:`FakePalettize` layers from a model and creates new model weights
        from the ``LUT`` and ``indices`` buffers.

        This function is called to prepare a palettized model for export using
        `coremltools <https://coremltools.readme.io/docs>`_.

        Args:
            model (:obj:`nn.Module`): model to finalize.
            inplace (:obj:`bool`): If ``True``, model transformations are carried out in-place and
                the original module is mutated; otherwise, a copy of the model is mutated and
                returned.
        """
        # When no model is passed, finalize the palettizer's own model and keep
        # its reference pointing at the converted model.
        finalizing_own_model = model is None
        if finalizing_own_model:
            model = self._model
        model.eval()
        finalized_model = _torch.quantization.convert(
            model, convert_custom_config_dict=_PALETTIZATION_CONVERT_DICT, inplace=inplace
        )
        if finalizing_own_model:
            self._model = finalized_model
        _register_metadata_version(finalized_model)
        return finalized_model
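
    # Illustrative sketch (not part of the original source): exporting a finalized model
    # with coremltools. ``example_input`` and the conversion arguments are placeholders;
    # palettized (LUT) weights generally require a deployment target of iOS16 or newer.
    #
    #     import coremltools as ct
    #
    #     finalized = palettizer.finalize()
    #     traced = torch.jit.trace(finalized, example_input)
    #     mlmodel = ct.convert(
    #         traced,
    #         inputs=[ct.TensorType(shape=example_input.shape)],
    #         minimum_deployment_target=ct.target.iOS16,
    #     )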

    def step(self):
        """
        Step through the palettizer. When the number of times ``step`` is called equals a
        module's ``milestone``, palettization is enabled for that module.
        """
        for name, module in self._model.named_modules():
            if name in self._milestones:
                if self._step_count == self._milestones[name]:
                    self._enable_fake_palett_impl(module, True)
                    self._init_prune_threshold_and_module_wise_target_sparsity(module)
                if self._step_count > self._milestones[name]:
                    self._update_prune_threshold(module)
        self._step_count += 1
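
    # For example (illustrative, assuming the inherited step counter starts at zero):
    # with ``milestone=2`` for a module, the first two calls to ``step()`` leave its
    # ``FakePalettize`` layer disabled, the third call enables it and initializes the
    # prune threshold, and every later call keeps adjusting the threshold via
    # ``_update_prune_threshold``.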

    @staticmethod
    def _init_prune_threshold_and_module_wise_target_sparsity(module: _torch.nn.Module):
        # Only applies to modules carrying a pruning re-parametrization
        # (``weight_mask`` / ``weight_orig``): the target sparsity is read off the
        # existing mask and an initial prune threshold is derived from the magnitude
        # of the pruned-out weights.
        if hasattr(module, "weight_fake_quant") and hasattr(module, "weight_mask"):
            non_zero_weights = module.weight_mask.count_nonzero().item()
            total_weights = _torch.numel(module.weight_mask)
            target_module_level_sparsity = 1 - non_zero_weights / total_weights
            inverse_mask = (module.weight_mask + 1) % 2
            n_bits = module.weight_fake_quant.n_bits
            cluster_dim = module.weight_fake_quant.cluster_dim
            add_extra_centroid = module.weight_fake_quant.add_extra_centroid
            n_clusters = 2 ** int(n_bits) + int(add_extra_centroid)
            prune_threshold_init = _torch.abs(inverse_mask * module.weight_orig).max() / (
                total_weights / cluster_dim / n_clusters
            )
            module.weight_fake_quant.prune_threshold = prune_threshold_init
            module.weight_fake_quant._target_module_level_sparsity = target_module_level_sparsity

    @staticmethod
    def _update_prune_threshold(module: _torch.nn.Module):
        # Nudges the learnable prune threshold so that the sparsity of the palettized
        # weights tracks the target module-level sparsity recorded above.
        if hasattr(module, "weight_fake_quant") and hasattr(module, "weight_mask"):
            weight_detached = module.weight.detach()
            qweight = module.weight_fake_quant.palettize(weight_detached)
            sparsity = 1 - qweight.count_nonzero() / qweight.numel()
            prune_ratio = float(module.weight_fake_quant._target_module_level_sparsity) / (
                sparsity + 1e-7
            )
            if prune_ratio > 0 and abs(prune_ratio - 1) > 0.01:
                prune_multiplier = max(min(prune_ratio, 1.25), 0.9)
                module.weight_fake_quant.prune_threshold *= prune_multiplier

    def enable_fake_palett(self, flag: bool):
        _logger.info(
            f"[{type(self).__name__}] " + ("enable" if flag else "disable") + " fake_palett"
        )
        for _, module in self._model.named_modules():
            self._enable_fake_palett_impl(module, flag)

    @staticmethod
    def _enable_fake_palett_impl(module: _torch.nn.Module, flag: bool):
        def enable_fn(mod):
            if hasattr(mod, "enable_fake_palett"):
                mod.enable_fake_palett(flag)

        module.apply(enable_fn)

    def report(self) -> _Report:
        """
        Returns a dictionary with important statistics related to the current state of
        palettization. Each key in the dictionary corresponds to a module name, and the
        value is a dictionary containing the statistics, such as number of clusters and
        cluster dimension, number of parameters, and so on.
        """
        report = _Report()
        with _get_eval_model(self._model) as model:
            with _torch.no_grad():
                for name, module in model.named_modules():
                    module_summary = dict()
                    if hasattr(module, "weight_fake_quant"):
                        module_summary["device"] = module.weight.device
                        qweight = module.weight_fake_quant.forward(module.weight.detach())
                        lut_dtype = module.weight_fake_quant.lut_dtype
                        cluster_permute = module.weight_fake_quant.cluster_permute
                        module_summary["error"] = _rmse_error(
                            module.weight.detach(), qweight
                        ).item()
                        n_clusters = module.weight_fake_quant.n_clusters
                        module_summary["#params"] = int(_torch.numel(qweight))
                        cluster_dim = module.weight_fake_quant.cluster_dim
                        module_summary["#dtype"] = (
                            f":num_clusters: {n_clusters} <{lut_dtype, cluster_permute}> "
                            f"dim={cluster_dim}"
                        )
                        report[name] = module_summary
        return report
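

# Illustrative usage sketch (not part of the original source): inspecting per-module
# palettization statistics during training. ``model``, ``data``, ``loss_fn`` and
# ``optimizer`` are assumed to be defined as in the class-level docstring above.
#
#     palettizer = DKMPalettizer(model, DKMPalettizerConfig())
#     model = palettizer.prepare(inplace=True)
#     for inputs, labels in data:
#         loss_fn(model(inputs), labels).backward()
#         optimizer.step()
#         palettizer.step()
#
#     for module_name, stats in palettizer.report().items():
#         print(module_name, stats["error"], stats["#params"], stats["#dtype"])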