# Source code for sad.model.base

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import copy
import datetime
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Type

from sad.task.dummy import DummyTask
from sad.utils.job import id_generator


class ModelBase(ABC):
    """The abstract model base class.

    It is the class that all concrete model classes will inherit from. The
    model's state lives in ``self.config``; the S3 key and the metrics are
    stored under ``self.config["spec"]``.
    """

    def __init__(self, config: Dict, task: "TrainingTask" = None):
        """Initialize the model.

        Args:
            config (:obj:`dict`): Configuration used to initialize the model.
                It is stored by reference and mutated in place: a ``"spec"``
                entry is created if missing, ``spec["metrics"]`` is reset to an
                empty dict, and ``spec["s3_key_path"]`` is generated when it is
                missing or empty.
            task: Training task in which the model is initialized. A
                ``DummyTask({})`` is used when ``None``.
        """
        self.config = config
        self.task = DummyTask({}) if task is None else task
        # Every instance starts with an empty metrics dictionary.
        self.metrics = dict()
        if not self.s3_key_path:
            # Arguments are evaluated left to right: timestamp first, then id.
            self.s3_key_path = self._build_s3_key_path(
                datetime.datetime.now(), id_generator()
            )
        self.logger = logging.getLogger(f"model.{self.__class__.__name__}")

    @staticmethod
    def _build_s3_key_path(now: datetime.datetime, model_id: str) -> str:
        """Build a date-partitioned key ``model/YYYY/MM/DD/model_<model_id>``.

        S3 keys are always ``/``-delimited, so the parts are joined explicitly
        rather than with ``os.path.join``, whose separator is platform
        dependent.
        """
        return "/".join(
            (
                "model",
                f"{now.year}",
                f"{now.month:02d}",
                f"{now.day:02d}",
                f"model_{model_id}",
            )
        )

    @property
    def config(self) -> Dict:
        """Configuration information that is used to initialize the instance."""
        return self._config

    @config.setter
    def config(self, config: Dict):
        self._config = config

    @property
    def spec(self) -> Dict:
        """A reference to ``"spec"`` field in ``self.config``.

        The field is created as an empty dict when it is missing or ``None``.
        """
        if self.config.get("spec") is None:
            self.config["spec"] = {}
        return self.config["spec"]

    @spec.setter
    def spec(self, spec: Dict):
        self.config["spec"] = spec

    @property
    def task(self) -> "sad.task.training.TrainingTask":
        """An instance of training task associated with current model.

        It is the task instance in which a model is initialized.
        """
        return self._task

    @task.setter
    def task(self, task: "sad.task.training.TrainingTask"):
        self._task = task

    @property
    def working_dir(self) -> str:
        """Alias to ``self.task.output_dir``."""
        return self.task.output_dir

    @property
    def s3_key_path(self) -> str:
        """A S3 key uniquely assigned to a model instance.

        Will be setup during model's instantiation, and populated to
        ``self.spec``. It is the S3 key of the model's remote store if the
        model will be pushed to a S3 bucket.
        """
        return self.spec.get("s3_key_path")

    @s3_key_path.setter
    def s3_key_path(self, s3_key_path: str):
        self.spec["s3_key_path"] = s3_key_path

    @property
    def metrics(self) -> Dict:
        """A dictionary stores metrics of the model.

        Subject to change during model training by callbacks. Stored in
        ``self.spec["metrics"]``; created as an empty dict if missing.
        """
        if "metrics" not in self.spec:
            self.spec["metrics"] = {}
        return self.spec["metrics"]

    @metrics.setter
    def metrics(self, metrics: Dict):
        # Deep copy so later mutation of the caller's dict does not leak in.
        self.spec["metrics"] = copy.deepcopy(metrics)

    @abstractmethod
    def save_checkpoint(self, working_dir: str, checkpoint_id: int):
        """Abstract; presumably persists a checkpoint keyed by ``checkpoint_id``."""
        raise NotImplementedError

    @abstractmethod
    def save(self, working_dir: str, filename: str):
        """Abstract; presumably saves the model to ``working_dir``/``filename``."""
        raise NotImplementedError

    @abstractmethod
    def load(self, working_dir: str, filename: str):
        """Abstract; presumably loads the model from ``working_dir``/``filename``."""
        raise NotImplementedError

    @abstractmethod
    def load_checkpoint(self, working_dir: str, checkpoint_id: int):
        """Abstract; presumably restores the checkpoint ``checkpoint_id``."""
        raise NotImplementedError

    @abstractmethod
    def load_best(self, working_dir: str, criterion: str):
        """Abstract; presumably loads the best checkpoint under ``criterion``."""
        raise NotImplementedError

    @abstractmethod
    def predict(self, inputs: Any) -> Any:
        """Abstract; produce predictions for ``inputs``."""
        raise NotImplementedError

    @abstractmethod
    def reset_parameters(self):
        """Abstract; reset the model's parameters."""
        raise NotImplementedError

    @abstractmethod
    def parameters_for_monitor(self) -> Dict[str, float]:
        """Abstract; return a name-to-value mapping of monitored parameters."""
        raise NotImplementedError
class ModelFactory:
    """A factory class that is responsible to create model instances.

    Model classes are registered by class name through :meth:`register`
    (usable as a class decorator) and instantiated by name via :meth:`produce`.
    """

    #: :class:`logging.Logger`: Class attribute for logging.
    logger = logging.getLogger("model.ModelFactory")

    #: :class:`dict`: Registry dictionary containing a mapping between class
    #: name and class object.
    _registry: Dict[str, Any] = {}

    @classmethod
    def register(cls, wrapped_class: "Type[ModelBase]") -> "Type[ModelBase]":
        """A class decorator that registers a ``sad.model.ModelBase`` subclass.

        The class is stored in ``ModelFactory._registry`` under its
        ``__name__``. If that name is already registered, a warning is logged
        and the existing entry is kept.

        Args:
            wrapped_class: The model class to register.

        Returns:
            ``wrapped_class`` itself, so the method works as a decorator.
        """
        class_name = wrapped_class.__name__
        if class_name in cls._registry:
            cls.logger.warning(f"Model {class_name} already registered, ignoring.")
        else:
            cls._registry[class_name] = wrapped_class
        return wrapped_class

    @classmethod
    def produce(cls, config: Dict, task: "TrainingTask") -> "ModelBase":
        """A class method to create instances of ``sad.model.ModelBase``.

        Args:
            config (:obj:`dict`): Configuration used to initialize instance
                object. Its ``"name"`` field selects the registered class. An
                example is given below::

                    name: SADModel
                    spec:
                      n: 200
                      m: 500
                      k: 100

            task: Training task passed to the model's constructor.

        Returns:
            A new instance of the registered class, built as
            ``cls(config, task)``.

        Raises:
            NotImplementedError: If no model class is registered under
                ``config["name"]``.
        """
        model_name = config.get("name")
        if model_name not in cls._registry:
            cls.logger.error(f"Unable to produce {model_name} model.")
            raise NotImplementedError(f"Model {model_name!r} is not registered.")
        return cls._registry[model_name](config, task)