Source code for sad.task.base

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import copy
import json
import logging
import os
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, Optional


class TaskBase(ABC):
    """A task base class that all task subclasses will inherit from.

    A task is the main component in our workflow. For example, when training a
    model, an instance of ``sad.tasks.training.TrainingTask`` will be
    responsible to launch the training job.

    Args:
        config: Task configuration. It is deep-copied, so later changes to the
            caller's dictionary do not affect the task (and vice versa). The
            keys read by this class are ``is_local``, ``is_hc``, ``task_id``
            and ``artifact_dir``.
        input_dir: Input directory. When omitted (or empty), a fresh temporary
            path name is generated.
        output_dir: Output directory. When omitted (or empty), a fresh
            temporary path name is generated.
    """

    def __init__(
        self,
        config: Dict,
        input_dir: Optional[str] = None,
        output_dir: Optional[str] = None,
    ):
        self.config = copy.deepcopy(config)
        self._logger = logging.getLogger(f"task.{self.__class__.__name__}")

        # Local runs place generated paths under the system temp directory;
        # non-local runs place them directly under "/mnt".
        default_parent = None if self.is_local else "/mnt"

        # NOTE(review): ``tempfile.mktemp`` only picks an unused name -- it does
        # not create anything on disk and is documented as deprecated because
        # of the resulting race. It is kept here so the generated paths still
        # do not exist after construction; confirm with callers before
        # switching to ``tempfile.mkdtemp``.
        self.input_dir = input_dir or tempfile.mktemp(prefix="", dir=default_parent)
        self.output_dir = output_dir or tempfile.mktemp(prefix="", dir=default_parent)

    def show_config(self):
        """A function to print the configuration of a running task.

        Logs the configuration as indented JSON, followed by the input and
        output directories, at ``INFO`` level.
        """
        # ``default=str`` keeps this logging helper from raising ``TypeError``
        # when the config holds values JSON cannot encode (sets, paths, ...);
        # JSON-native values are rendered exactly as before.
        config_dump = json.dumps(self.config, indent=2, default=str)
        self.logger.info(f"{self.__class__.__name__} config: \n {config_dump}")
        self.logger.info(f"Input directory: {self.input_dir}")
        self.logger.info(f"Output directory: {self.output_dir}")

    @property
    def input_dir(self) -> str:
        """An absolute path that points to input directory of a running task."""
        return self._input_dir

    @input_dir.setter
    def input_dir(self, input_dir: str):
        # Expand a leading "~" and normalize to an absolute path.
        self._input_dir = os.path.abspath(os.path.expanduser(input_dir))

    @property
    def output_dir(self) -> str:
        """An absolute path that points to output directory of a running task."""
        return self._output_dir

    @output_dir.setter
    def output_dir(self, output_dir: str):
        # Expand a leading "~" and normalize to an absolute path.
        self._output_dir = os.path.abspath(os.path.expanduser(output_dir))

    @property
    def artifact_dir(self) -> str:
        """A path that points to an artifact directory.

        Will be the artifact directory from Bolt when running as a Bolt job
        (read from the ``artifact_dir`` config entry). Otherwise, will be the
        same as ``self.output_dir``.
        """
        return self.config.get("artifact_dir") or self.output_dir

    @property
    def logger(self) -> logging.Logger:
        """A logger instance to manage logging during the life-cycle of a task."""
        return self._logger

    @property
    def is_local(self) -> bool:
        """A boolean flag to indicate whether the task is running in local mode.

        Defaults to ``True`` when ``is_local`` is absent from the config.
        """
        return self.config.get("is_local", True)

    @property
    def is_hc(self) -> bool:
        """A boolean flag to indicate if the task is running in HC.

        Defaults to ``False`` when ``is_hc`` is absent from the config.
        """
        return self.config.get("is_hc", False)

    @property
    def task_id(self) -> Optional[str]:
        """A unique string to identify a running task.

        ``None`` when no ``task_id`` has been configured or assigned.
        """
        return self.config.get("task_id")

    @task_id.setter
    def task_id(self, task_id: str):
        self.config["task_id"] = task_id

    @abstractmethod
    def run(self):
        """Execute the task. Concrete subclasses must override this method."""
        raise NotImplementedError