Source code for sad.callback.base

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import logging
from abc import ABC, abstractmethod
from typing import Dict, Type

from .caller import CallerProtocol


class CallbackBase(ABC):
    """Abstract parent of every concrete callback.

    Instances are owned by a caller object that satisfies ``CallerProtocol``;
    at present, ``sad.trainer.TrainerBase`` classes fill that role. Callbacks
    are created while the caller itself is being initialized, and each one is
    configured from an entry under ``caller:spec:callbacks:``, for example::

        trainer:
          name: SGDTrainer
          spec:
            n_iters: 20
            w_l1: 0.1
            w_l2: 0.0
            u_idxs: [0, 1, 2, 3]
            callbacks:
              - name: "CheckpointingCallback"
                spec:
                  start: 10
                  every: 1

    Every ``on_*`` hook below is abstract, so a subclass must override all of
    them before it can be instantiated.
    """

    def __init__(self, config: Dict, caller: CallerProtocol):
        self.config = config
        # Going through the ``caller`` property setter also registers this
        # callback with the caller (see the setter below).
        self.caller = caller
        owner_name = self.__class__.__name__
        self.logger = logging.getLogger(f"callback.{owner_name}")

    @property
    def config(self) -> Dict:
        """Dict: The configuration dictionary this callback was built from."""
        return self._config

    @config.setter
    def config(self, config: dict):
        self._config = config

    @property
    def spec(self) -> Dict:
        """Dict: The ``"spec"`` entry of ``self.config``.

        When that entry is missing or set to ``None``, a new empty dictionary
        is stored under ``"spec"`` first, so the value returned is always the
        object held inside ``self.config``.
        """
        current = self.config.get("spec")
        if current is None:
            current = self.config["spec"] = {}
        return current

    @spec.setter
    def spec(self, spec: Dict):
        self.config["spec"] = spec

    @property
    def caller(self) -> CallerProtocol:
        """CallerProtocol: The object that drives this callback, e.g. an
        instance of ``sad.trainer.TrainerBase``."""
        return self._caller

    @caller.setter
    def caller(self, caller: CallerProtocol):
        # Store the back-reference, then announce ourselves to the caller.
        self._caller = caller
        caller.register_callback(self)

    @abstractmethod
    def on_loop_begin(self, **kwargs):
        """Hook the caller invokes as its main loop starts.

        For ``sad.trainer.TrainerBase`` the main loop is the training loop.
        """
        raise NotImplementedError()

    @abstractmethod
    def on_loop_end(self, **kwargs):
        """Hook the caller invokes once its main loop has finished."""
        raise NotImplementedError()

    @abstractmethod
    def on_iter_begin(self, iter_idx: int, **kwargs):
        """Hook the caller invokes at the start of each iteration.

        During training, an iteration could be one epoch.

        Args:
            iter_idx (:obj:`int`): Zero-based iteration index.
        """
        raise NotImplementedError()

    @abstractmethod
    def on_iter_end(self, iter_idx: int, **kwargs):
        """Hook the caller invokes at the end of each iteration.

        Args:
            iter_idx (:obj:`int`): Zero-based iteration index.
        """
        raise NotImplementedError()

    @abstractmethod
    def on_step_begin(self, iter_idx: int, step_idx: int, **kwargs):
        """Hook the caller invokes at the start of each step.

        During training, a step could be a single gradient update computed
        from one minibatch.

        Args:
            iter_idx (:obj:`int`): Zero-based iteration index.
            step_idx (:obj:`int`): Zero-based step index.
        """
        raise NotImplementedError()

    @abstractmethod
    def on_step_end(self, iter_idx: int, step_idx: int, **kwargs):
        """Hook the caller invokes when a step completes.

        Args:
            iter_idx (:obj:`int`): Zero-based iteration index.
            step_idx (:obj:`int`): Zero-based step index.
        """
        raise NotImplementedError()


class CallbackFactory:
    """A factory class that is responsible to create callback instances.

    Concrete callback classes are registered by name through the
    :meth:`register` decorator and later instantiated from a configuration
    dictionary by :meth:`produce`.
    """

    logger = logging.getLogger("callback.CallbackFactory")
    """:class:`logging.Logger`: Class attribute for logging."""

    _registry = {}
    """:class:`dict`: Registry dictionary containing a mapping between class name and class object."""

    @classmethod
    def register(cls, wrapped_class: Type[CallbackBase]) -> Type[CallbackBase]:
        """Class decorator that records a ``sad.callback.CallbackBase``
        subclass in ``CallbackFactory._registry``.

        The class is keyed by its ``__name__``. When that name is already
        registered, a warning is logged and the existing entry is kept.

        Args:
            wrapped_class (:obj:`type`): The callback class being decorated.

        Returns:
            :obj:`type`: ``wrapped_class`` itself, unchanged, so the decorator
            can be stacked with others.
        """
        class_name = wrapped_class.__name__
        if class_name in cls._registry:
            cls.logger.warning(f"Callback {class_name} already registered, Ignoring.")
        else:
            cls._registry[class_name] = wrapped_class
        return wrapped_class

    @classmethod
    def produce(cls, config: Dict, caller: CallerProtocol) -> CallbackBase:
        """A class method to initialize instances of ``sad.callback.CallbackBase``.

        Args:
            config (:obj:`dict`): Configuration used to initialize an instance
                object. Its ``"name"`` entry selects the registered class. An
                example is given below::

                    name: "EarlyStoppingCallback"
                    spec:
                      allow_incomplete_epoch: False

            caller (:obj:`sad.callback.CallerProtocol`): An instance of a class
                that is compliant with ``CallerProtocol``. Currently
                ``sad.trainer.TrainerBase`` is of this class type. A callback
                instance is created with its caller; during the caller's loop,
                callback methods will be invoked.

        Returns:
            :obj:`sad.callback.CallbackBase`: The newly constructed callback,
            built as ``registered_class(config, caller)``.

        Raises:
            NotImplementedError: If ``config["name"]`` is missing or does not
                match any registered callback class.
        """
        callback_name = config.get("name")
        callback_class = cls._registry.get(callback_name)
        if callback_class is None:
            cls.logger.error(f"Unable to produce {callback_name} callback.")
            # Same exception type as before so existing handlers still match,
            # but now with an actionable message instead of an empty one.
            raise NotImplementedError(
                f"Callback {callback_name!r} is not registered; "
                f"registered callbacks: {sorted(cls._registry)}"
            )
        return callback_class(config, caller)