Source code for sad.callback.w_l1_scheduler

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import logging

import numpy as np

from .base import CallbackBase, CallbackFactory

logger = logging.getLogger("callback.w_l1_scheduler")


def exp_rise(w_l1: float, rate: float) -> float:
    """A scheduling function to calculate new weight of L1 regularization
    with exponential rise.

    Args:
        w_l1 (:obj:`float`): Current weight of L1 regularization.
        rate (:obj:`float`): The rate of rise. When activated, ``w_l1`` will
            be changed by multiplying ``exp(rate)``. A negative ``rate``
            therefore shrinks the magnitude of the weight.

    Returns:
        :obj:`float`: Updated weight of L1 regularization, as a built-in
        Python ``float``.
    """
    # ``np.exp`` yields a NumPy scalar; convert so callers always receive a
    # plain Python float, as the return annotation promises.
    new_w_l1 = float(w_l1 * np.exp(rate))
    logger.info(f"w_l1 updated {w_l1:.02e} -> {new_w_l1:.02e} "
                "by exponential rise.")
    return new_w_l1
def step(w_l1: float, new_w_l1: float) -> float:
    """A scheduling function to switch the weight of L1 regularization to a
    new value in a single step.

    Args:
        w_l1 (:obj:`float`): Current weight of L1 regularization.
        new_w_l1 (:obj:`float`): New weight of L1 regularization.

    Returns:
        :obj:`float`: Updated weight of L1 regularization, i.e.
        ``new_w_l1``; ``w_l1`` itself is returned when the two are equal.
    """
    if w_l1 == new_w_l1:
        # Nothing to change; return early without logging.
        return w_l1
    logger.info(f"w_l1 updated {w_l1:.02e} -> {new_w_l1:.02e} by step scheme.")
    return new_w_l1
@CallbackFactory.register
class WeightL1SchedulerCallback(CallbackBase):
    """A callback class that is responsible for updating the weight of L1
    regularization during training.

    Instances of this class are managed by callers compliant with
    ``sad.caller.CallerProtocol``, which create them during their
    initialization.

    Configurations for this callback are provided under
    ``trainer:spec:callbacks:``. An example is shown below::

        trainer:
          name: SGDTrainer
          spec:
            n_iters: 20
            w_l1: 0.1
            w_l2: 0.0
            u_idxs: [0, 1, 2, 3]
            callbacks:
            - name: "WeightL1SchedulerCallback"
              spec:
                scheme: "exp_rise"
                rate: -0.1
                start: 0.5
    """

    @property
    def scheme(self) -> str:
        """The scheme of how weight of L1 regularization will be changed.
        Currently can take ``"exp_rise"|"step"``. Will read directly from
        ``"scheme"`` field in ``self.spec``; defaults to ``"exp_rise"``.
        """
        return self.spec.get("scheme", "exp_rise")

    @property
    def start(self) -> int:
        """A non-negative number suggesting when to start to apply changes to
        weight of L1 regularization. When ``0 < start < 1``, it will be
        treated as a proportion, suggesting ``w_l1`` will be subject to change
        when ``iter_idx >= int(n_iters * start)``. Otherwise,
        ``iter_idx >= int(start)`` will be the condition. Defaults to ``0``.
        """
        start = self.spec.get("start", 0)
        if 0 < start < 1:
            # Fractional values are a proportion of the caller's n_iters.
            start = int(self.caller.n_iters * start)
        return int(start)

    @property
    def every(self) -> int:
        """Number of iterations between two consecutive updates. ``1`` means
        weight of L1 regularization is subject to change at every iteration.
        Will read directly from ``"every"`` field in ``self.spec``; defaults
        to ``1``.

        Raises:
            ValueError: If the configured value is not positive.
        """
        every = self.spec.get("every", 1)
        if every <= 0:
            # A zero period would otherwise surface as an obscure
            # ZeroDivisionError in ``iter_idx % every``.
            raise ValueError(
                f"'every' of WeightL1SchedulerCallback must be positive, "
                f"got {every!r}."
            )
        return every

    @property
    def rate(self) -> float:
        """The rate of rise. Effective when ``self.scheme`` is set to
        ``"exp_rise"``. When activated, weight of L1 regularization will be
        changed by multiplying its value by ``exp(rate)``. Will read directly
        from ``"rate"`` field in ``self.spec``; defaults to ``0``.
        """
        return self.spec.get("rate", 0)

    @property
    def new_w_l1(self) -> float:
        """The new weight of L1 regularization. Effective when
        ``self.scheme`` is set to ``"step"``. When activated, ``w_l1`` will be
        changed to ``self.new_w_l1``. Will read directly from ``"new_w_l1"``
        field in ``self.spec``; defaults to the caller's current ``w_l1``.
        """
        return self.spec.get("new_w_l1", self.caller.w_l1)

    def on_loop_begin(self, **kwargs):
        """Not applicable to this class."""
        pass

    def on_loop_end(self, **kwargs):
        """Not applicable to this class."""
        pass

    def on_iter_begin(self, iter_idx: int, **kwargs):
        """Will be called to determine whether to attempt to update the
        weight of L1 regularization when an iteration begins.

        An update is attempted when ``iter_idx >= self.start`` and
        ``iter_idx`` is a multiple of ``self.every``. The new value is
        written back to ``self.caller.w_l1``.

        Args:
            iter_idx (:obj:`int`): The index of iteration, 0-based.

        Raises:
            ValueError: If ``"every"`` in ``self.spec`` is not positive.
        """
        start = self.start
        every = self.every
        caller = self.caller
        if (iter_idx >= start) and (iter_idx % every == 0):
            scheme = self.scheme
            if scheme == "exp_rise":
                new_w_l1 = exp_rise(caller.w_l1, self.rate)
            elif scheme == "step":
                new_w_l1 = step(caller.w_l1, self.new_w_l1)
            else:
                # Unknown scheme: keep the current weight, but make the
                # misconfiguration visible instead of silently ignoring it.
                logger.warning(
                    f"Unknown w_l1 scheduling scheme {scheme!r}; "
                    "w_l1 is left unchanged."
                )
                new_w_l1 = caller.w_l1
            caller.w_l1 = new_w_l1

    def on_iter_end(self, iter_idx: int, **kwargs):
        """Not applicable to this class."""
        pass

    def on_step_begin(self, iter_idx: int, step_idx: int, **kwargs):
        """Not applicable to this class."""
        pass

    def on_step_end(self, iter_idx: int, step_idx: int, **kwargs):
        """Currently a no-op for this class."""
        pass

    def save(self, folder: str):
        """No-op: this callback does not save anything to ``folder``."""
        pass

    def load(self, folder: str):
        """No-op: this callback does not load anything from ``folder``."""
        pass