Source code for sad.utils.misc

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import logging
from typing import Any, Dict

from scipy.special import logit

_logger = logging.getLogger("utils.misc")


def my_logit(value: float, EPS: float = 1e-10) -> float:
    """Compute the logit of ``value`` after clamping it away from 0 and 1.

    The logit is evaluated by :func:`scipy.special.logit`. Inputs that fall
    outside ``[EPS, 1-EPS]`` are replaced by the nearest bound first.

    Args:
        value (:obj:`float`): The number to transform, nominally a
            probability in ``(0, 1)``. Anything below ``EPS`` is treated as
            ``EPS`` and anything above ``1-EPS`` is treated as ``1-EPS``.
        EPS (:obj:`float`): Small positive margin that defines the clamping
            interval ``[EPS, 1-EPS]``. Default to ``1e-10``.

    Returns:
        :obj:`float`: The logit of the (possibly clamped) ``value``.

    """
    upper_bound = 1 - EPS

    # The lower bound is tested first, so it takes precedence if both
    # comparisons could hold.
    if value < EPS:
        return logit(EPS)
    if value > upper_bound:
        return logit(upper_bound)

    # Already inside the safe interval (a NaN also lands here, since both
    # comparisons above are false for it).
    return logit(value)


def update_dict_recursively(dict_a: Dict, dict_b: Dict):
    """Merge ``dict_b`` into ``dict_a`` in place, descending into nested dicts.

    For every key of ``dict_b``: when both dictionaries hold a ``dict`` under
    that key, the two sub-dictionaries are merged recursively. In every other
    case the value from ``dict_b`` is stored in ``dict_a``, either replacing
    whatever was there or adding a new key. Values are stored as-is (no copy
    is made).

    Args:
        dict_a (:obj:`dict`): The dictionary that receives the content. It is
            modified in place.
        dict_b (:obj:`dict`): The dictionary whose entries are merged into
            ``dict_a``.

    Return:
        The same ``dict_a`` object that was passed in, after the merge.

    """
    for key, incoming in dict_b.items():
        # ``get`` yields ``None`` for an absent key, which is not a dict, so a
        # missing key falls through to the plain assignment below.
        existing = dict_a.get(key)
        both_are_dicts = isinstance(existing, dict) and isinstance(incoming, dict)

        if both_are_dicts:
            # Merge the nested level; the merged sub-dict is stored back under
            # the same key.
            dict_a[key] = update_dict_recursively(existing, incoming)
        else:
            # New key, or at least one side is not a dict: take dict_b's value.
            dict_a[key] = incoming

    return dict_a