Source code for coremltools.optimize.torch.optimization_config

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import re as _re
from collections import OrderedDict as _OrderedDict
from enum import Enum as _Enum
from typing import Any as _Any
from typing import Callable as _Callable
from typing import Dict as _Dict
from typing import List as _List
from typing import Optional as _Optional
from typing import Union as _Union

import torch as _torch
from attr import Factory as _Factory
from attr import define as _define
from attrs import field as _field

from coremltools.optimize.torch._utils.python_utils import DictableDataClass as _DictableDataClass


class QuantizationGranularity(_Enum):
    """
    Enum to denote granularity at which different compression schemes are applied.
    See specific algorithm for more details.
    """
    per_tensor = "per_tensor"
    per_channel = "per_channel"
    per_block = "per_block"

class PalettizationGranularity(_Enum):
    """
    Enum to denote granularity at which different compression schemes are applied.
    See specific algorithm for more details.
    """

    per_tensor = "per_tensor"
    per_grouped_channel = "per_grouped_channel"
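
# Both granularity enums are string valued, so a member can be recovered from its string
# value, e.g. (illustrative sketch only, not part of the original module):
#
#     >>> QuantizationGranularity("per_block")
#     <QuantizationGranularity.per_block: 'per_block'>
#     >>> PalettizationGranularity("per_grouped_channel").value
#     'per_grouped_channel'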

class ModuleOptimizationConfig(_DictableDataClass):
    """
    Abstract base class for module level optimization configs.
    """


@_define
class OptimizationConfig(_DictableDataClass):
    """
    Abstract base class for optimization configs. It maps modules in a model to their
    module level optimization configs, either globally, by module type, or by module name.
    """

    global_config: _Optional[ModuleOptimizationConfig] = None
    module_type_configs: _Dict[
        _Union[_Callable, str], _Optional[ModuleOptimizationConfig]
    ] = _Factory(_OrderedDict)
    module_name_configs: _Dict[str, _Optional[ModuleOptimizationConfig]] = _Factory(_OrderedDict)

    def set_global(
        self, global_config: _Optional[ModuleOptimizationConfig]
    ) -> "OptimizationConfig":
        """
        Set the global config.
        """
        self.global_config = global_config
        return self
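
    # A minimal usage sketch (MyModuleConfig is a hypothetical subclass of
    # ModuleOptimizationConfig, not part of this module); the global config is the fallback
    # used for any module not matched by a module type or module name config:
    #
    #     >>> config = OptimizationConfig().set_global(MyModuleConfig())
    #     >>> config.global_config is not None
    #     True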

    def set_module_type(
        self, object_type: _Union[_Callable, str], opt_config: _Optional[ModuleOptimizationConfig]
    ) -> "OptimizationConfig":
        """
        Set the module level optimization config for a given module type. If a config was
        already set for this module type, the new config overrides it.
        """
        self.module_type_configs[object_type] = opt_config
        return self
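
    # A minimal usage sketch (MyModuleConfig is hypothetical, as above); keys may be the
    # module class itself or its class name as a string, and a ``None`` config excludes
    # modules of that type:
    #
    #     >>> config = OptimizationConfig()
    #     >>> _ = config.set_module_type(_torch.nn.Linear, MyModuleConfig())
    #     >>> _ = config.set_module_type("Conv2d", None)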

    def set_module_name(
        self, module_name: str, opt_config: _Optional[ModuleOptimizationConfig]
    ) -> "OptimizationConfig":
        """
        Set the module level optimization config for a given module instance. If a config was
        already set for this module, the new config overrides it.
        """
        self.module_name_configs[module_name] = opt_config
        return self
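
    # A minimal usage sketch (MyModuleConfig is hypothetical, as above); module names are
    # matched with ``re.fullmatch`` in ``get_module_config``, so a regular expression can
    # cover multiple submodules:
    #
    #     >>> config = OptimizationConfig()
    #     >>> _ = config.set_module_name(r"encoder\.layers\..*\.mlp", MyModuleConfig())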

    def get_module_config(
        self, name: str, module: _torch.nn.Module
    ) -> _Optional[ModuleOptimizationConfig]:
        """
        Get the config for a module: module name configs are matched first (via ``re.fullmatch``
        on ``name``), then module type configs (by class or class name), then the global config.
        """
        for mod_name in self.module_name_configs:
            if _re.fullmatch(mod_name, name):
                return self.module_name_configs[mod_name]
        if type(module) in self.module_type_configs:
            return self.module_type_configs[type(module)]
        elif module.__class__.__name__ in self.module_type_configs:
            return self.module_type_configs[module.__class__.__name__]
        else:
            return self.global_config
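
    # Resolution order illustrated (MyModuleConfig is hypothetical, as above): the first
    # matching name config wins, then type configs, then the global config.
    #
    #     >>> model = _torch.nn.Sequential(_OrderedDict([("fc1", _torch.nn.Linear(4, 4))]))
    #     >>> cfg = OptimizationConfig(
    #     ...     global_config=MyModuleConfig(),
    #     ...     module_name_configs={"fc1": MyModuleConfig()},
    #     ... )
    #     >>> cfg.get_module_config("fc1", model.fc1) is cfg.module_name_configs["fc1"]
    #     True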

    @classmethod
    def from_dict(cls, config_dict: _Dict[str, _Any]) -> _Optional["OptimizationConfig"]:
        """
        Create class from a dictionary of string keys and values.

        Args:
            config_dict (:obj:`dict` of :obj:`str` and values): A nested dictionary of strings
                and values.
        """
        # passing forbid_extra_keys=True doesn't prevent silent failure when keys are mis-spelled
        cls._validate_dict(config_dict)
        return
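
    # Note: this base implementation only validates the dictionary keys and returns ``None``.
    # Concrete subclasses elsewhere in coremltools.optimize.torch (for example,
    # LinearQuantizerConfig) override ``from_dict`` to also construct the config, typically
    # from a nested dictionary shaped like:
    #
    #     >>> config_dict = {
    #     ...     "global_config": {...},
    #     ...     "module_name_configs": {"fc1": {...}},
    #     ... }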

    def _validate_same_params(self, param_names: _List[str]):
        """
        This method validates that all the parameters in param_names
        have the same value across all the module level configs.
        """
        expected_values = None
        if self.global_config is not None:
            expected_values = {
                param_name: getattr(self.global_config, param_name) for param_name in param_names
            }
        for name, config in self.module_type_configs.items():
            if config is not None:
                expected_values = self._validate_expected_value(
                    expected_values, name, config, param_names
                )
        for name, config in self.module_name_configs.items():
            if config is not None:
                expected_values = self._validate_expected_value(
                    expected_values, name, config, param_names
                )

    @staticmethod
    def _validate_expected_value(
        expected_values: _Dict[str, _Any],
        name: str,
        config: ModuleOptimizationConfig,
        param_names: _List[str],
    ):
        if expected_values is None:
            expected_values = {
                param_name: getattr(config, param_name) for param_name in param_names
            }
        for param_name, expected_val in expected_values.items():
            val = getattr(config, param_name)
            if val != expected_val:
                raise ValueError(
                    f"Value of parameter {param_name} must be the same "
                    f"across all module level configs. "
                    f"Expected value: {expected_val}, received: {val} "
                    f"for config {name}."
                )
        return expected_values
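
# A sketch of the cross-config validation above (MyModuleConfig is a hypothetical attrs based
# subclass of ModuleOptimizationConfig with a single ``dtype`` field): if two module level
# configs disagree on a parameter listed in ``param_names``, a ``ValueError`` is raised.
#
#     >>> cfg = OptimizationConfig(
#     ...     global_config=MyModuleConfig(dtype="int8"),
#     ...     module_name_configs={"fc1": MyModuleConfig(dtype="int4")},
#     ... )
#     >>> cfg._validate_same_params(["dtype"])
#     Traceback (most recent call last):
#         ...
#     ValueError: Value of parameter dtype must be the same across all module level configs. ...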


def _structure_from_dict_hook_factory(conversion_cls: _Any) -> _Callable:
    """
    Return a structure hook which converts the values of a module config dictionary into
    instances of ``conversion_cls``, leaving ``None`` values and existing config objects as-is.
    """

    def _structure_from_dict_hook(
        module_type_dict: _Dict[_Union[_Callable, str], _Any], type: _Any
    ):
        return_dict = _OrderedDict()
        for key, value in module_type_dict.items():
            if value is None:
                return_dict[key] = None
            elif isinstance(value, dict):
                return_dict[key] = conversion_cls.from_dict(value)
            else:
                assert isinstance(value, conversion_cls), (
                    "value in module type dict should be either a dict or "
                    "a module config object."
                )
                return_dict[key] = value
        return return_dict

    return _structure_from_dict_hook


def _validate_module_type_keys_factory(supported_modules):
    """
    Return an attrs style validator checking that a ``module_type_configs`` key is one of
    ``supported_modules`` or the name of one of those classes.
    """
    supported_module_names = [cls.__name__ for cls in supported_modules]

    def validate_module_type_key(instance, attribute, value):
        if isinstance(value, str):
            assert value in supported_module_names, (
                f"keys for module_type_configs must be one of "
                f"{supported_module_names}. Received: {value}."
            )
        else:
            assert value in supported_modules, (
                f"keys for module_type_configs must be one of "
                f"{supported_modules}. Received: {value}."
            )

    return validate_module_type_key


def _deprecated_field(message="This field is deprecated"):
    """
    Return an attrs field which defaults to ``None`` and raises a ``DeprecationWarning`` if set.
    """

    def validator(inst, attr, val):
        if val is not None:
            raise DeprecationWarning(message)

    return _field(default=None, validator=validator, on_setattr=validator)
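
# A sketch of how the helper above can be used (MyConfig is hypothetical, not part of this
# module): a removed option is kept as an attribute that raises if it is ever set.
#
#     >>> @_define
#     ... class MyConfig(ModuleOptimizationConfig):
#     ...     old_option: _Optional[int] = _deprecated_field("old_option was removed")
#     >>> MyConfig(old_option=4)
#     Traceback (most recent call last):
#         ...
#     DeprecationWarning: old_option was removed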