Source code for coremltools.optimize.torch._utils.python_utils

#  Copyright (c) 2024, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

import logging as _logging
from collections import OrderedDict as _OrderedDict
from typing import IO as _IO
from typing import Any as _Any
from typing import Dict as _Dict
from typing import Type as _Type
from typing import Union as _Union

import cattrs as _cattrs
import torch as _torch
import yaml as _yaml
from attr import asdict as _asdict

_logger = _logging.getLogger(__name__)


def get_str(val: _Any):
    """
    Return a string form of ``val``.

    Floats are rendered with exactly five digits after the decimal point;
    every other value is converted with :func:`str`.
    """
    return f"{val:.5f}" if isinstance(val, float) else str(val)


class RegistryMixin:
    """
    Mixin that gives a class a name-to-object registry.

    Objects are added with the :meth:`register` decorator and looked up with
    :meth:`_get_object`.
    """

    # Stays ``None`` until the first ``register`` call; that call assigns a
    # fresh ordered dict on the class ``register`` was invoked through.
    REGISTRY = None

    @classmethod
    def register(cls, name: str):
        """
        Return a decorator that stores the decorated object under ``name``.

        Args:
            name: Key under which the decorated object is stored.

        Returns:
            A decorator which records the object in ``cls.REGISTRY`` and
            returns the object unchanged. If ``name`` is already present,
            a warning is logged and the previous entry is replaced.
        """
        if cls.REGISTRY is None:
            cls.REGISTRY = _OrderedDict()

        def inner_wrapper(wrapped_obj):
            if name in cls.REGISTRY:
                # Fall back to repr() so that overwriting an object without
                # ``__name__`` (e.g. a functools.partial) cannot crash here.
                previous = cls.REGISTRY[name]
                previous_name = getattr(previous, "__name__", repr(previous))
                new_name = getattr(wrapped_obj, "__name__", repr(wrapped_obj))
                _logger.warning(
                    f"Name: {name} is already registered with object: {previous_name} "
                    f"in registry: {cls.__name__}. "
                    f"Over-writing the name with new class: {new_name}"
                )
            cls.REGISTRY[name] = wrapped_obj
            return wrapped_obj

        return inner_wrapper

    @classmethod
    def _get_object(cls, name: str):
        """
        Return the object registered under ``name``.

        Args:
            name: Key to look up in ``cls.REGISTRY``.

        Raises:
            NotImplementedError: If ``name`` is not registered, including the
                case where nothing has been registered yet.
        """
        registry = cls.REGISTRY
        # ``REGISTRY`` is still None before the first ``register`` call; treat
        # that like an empty registry rather than failing on ``name in None``.
        if registry is not None and name in registry:
            return registry[name]
        raise NotImplementedError(
            f"No object is registered with name: {name} in registry {cls.__name__}."
        )


class ClassRegistryMixin(RegistryMixin):
    """
    Registry mixin that exposes lookups through :meth:`get_class`.
    """

    @classmethod
    def get_class(cls, name: str):
        """
        Return whatever ``_get_object`` yields for ``name``.
        """
        registered = cls._get_object(name)
        return registered


class FunctionRegistryMixin(RegistryMixin):
    """
    Registry mixin that exposes lookups through :meth:`get_function`.
    """

    @classmethod
    def get_function(cls, name: str):
        """
        Return whatever ``_get_object`` yields for ``name``.
        """
        registered = cls._get_object(name)
        return registered


class DictableDataClass:
    """
    Utility class that provides convertors to and from Python dict
    """

    @classmethod
    def from_dict(cls, data_dict: _Dict[str, _Any]) -> "DictableDataClass":
        """
        Create class from a dictionary of string keys and values.

        Args:
            data_dict (:obj:`dict` of :obj:`str` and values): A nested dictionary of strings
                and values.

        Returns:
            The object produced by structuring ``data_dict`` into ``cls``.

        Raises:
            ValueError: If ``data_dict`` has a top-level key that is not an
                attribute of ``cls``.
        """
        # Explicitly raise exception for unrecognized keys
        cls._validate_dict(data_dict)
        converter = _cattrs.Converter(forbid_extra_keys=True)
        # Tensor values are handed through as-is instead of being structured.
        converter.register_structure_hook(_torch.Tensor, lambda obj, _: obj)
        return converter.structure_attrs_fromdict(data_dict, cls)

    @classmethod
    def from_yaml(cls, yml: _Union[_IO, str]) -> "DictableDataClass":
        """
        Create class from a yaml stream.

        Args:
            yml: An :py:class:`IO` stream containing yaml or a :obj:`str`
                path to the yaml file.

        Returns:
            The result of :meth:`from_dict` applied to the parsed yaml. Yaml
            that parses to ``None`` is treated as an empty dictionary.

        Raises:
            AssertionError: If the parsed yaml is not a dictionary.
        """
        if isinstance(yml, str):
            # Decode as UTF-8 explicitly so parsing does not depend on the
            # platform's default text encoding.
            with open(yml, "r", encoding="utf-8") as file:
                dict_from_yml = _yaml.safe_load(file)
        else:
            dict_from_yml = _yaml.safe_load(yml)
        if dict_from_yml is None:
            dict_from_yml = {}
        # Raised explicitly (rather than via ``assert``) so the check still
        # runs under ``python -O``; the exception type is kept unchanged.
        if not isinstance(dict_from_yml, dict):
            raise AssertionError(
                "Invalid yaml received. yaml stream should return a dict "
                f"on parsing. Received type: {type(dict_from_yml)}."
            )
        return cls.from_dict(dict_from_yml)

    def as_dict(self) -> _Dict[str, _Any]:
        """
        Returns the config as a dictionary.
        """
        return _asdict(self)

    @classmethod
    def _validate_dict(cls: _Type, config_dict: _Dict[str, _Any]):
        """
        Check that every top-level key of ``config_dict`` names an attribute of ``cls``.

        The check uses ``hasattr``, so any attribute visible on the class
        (methods included) is accepted.

        Raises:
            ValueError: On the first key for which ``hasattr(cls, key)`` is false.
        """
        for key in config_dict:
            if not hasattr(cls, key):
                raise ValueError(f"Found unrecognized key {key} in config_dict: {config_dict}.")