Source code for sad.generator.base

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import json
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Set, Tuple

import numpy as np

from sad.model import ModelBase


[docs]class GeneratorBase(ABC): """A generator base class that all concrete generator classes should inherit from. A generator class should also be an iterable, by implementing ``__iter__`` method. The way a generator works is that file(s) containing training/validation samples will be first added to the generator by calling ``self.add(file)``. Then by calling ``self.prepare()``, the generator is informed that all files have been added, and it is the time to get ready to iterate through the files and produce samples. At this point, one can use the generator in following manner:: for features, targets in my_generator: # fit my model To only iterate through training samples, one can do:: for features, targets in my_generator.get_trn(): # fit my model Same applies to validation. """ def __init__(self, config: Dict, model: ModelBase, task: "TrainingTask"): self.config = config self.model = model self.task = task self.input_files = [] self.logger = logging.getLogger(f"generator.{self.__class__.__name__}") @property def config(self) -> Dict: """Configuration information that is used to initialize the generator instance.""" return self._config @config.setter def config(self, config: Dict): self._config = config @property def spec(self) -> Dict: """A reference to ``"spec"`` field in ``self.config``. If such field does not exist or its value is ``None``, an empty dictionary will be created.""" if self.config.get("spec") is None: self.config["spec"] = {} return self.config["spec"] @spec.setter def spec(self, spec: Dict): self.config["spec"] = spec @property def model(self) -> ModelBase: """A trainable model instance, which will be trained using samples produced by current generator instance.""" return self._model @model.setter def model(self, model: ModelBase): self._model = model @property def task(self) -> "sad.tasks.training.TrainingTask": """An instance of training task associated with the generator. It is the task in which current generator is initialized. """ return self._task @task.setter def task(self, task: "sad.tasks.training.TrainingTask"): self._task = task @property def mode(self): """:obj:`str`: The mode of how generator works. Currently supports two configurations: ``"random|iteration"``. 1. ``"random"``: When working in this mode, a number of ``self.u_batch`` random users will be selected (with replacement) from entire user set in an iteration. For items, a number of ``self.i_batch`` positive items that each user has interacted with will be randomly (with replacement) generated. Same number of negative items that user hasn't interacted with will be randomly generated as well, producing triplets of samples in the format of (user, item i (interacted), item j (non-interacted)). 2. ``"iteration"``: When working in this mode, all users will be iterated through in a randomized order. Same to items. For each positive user-item interaction, a number of ``self.n_negatives`` non-interacted items will be randomly selected. """ return self.spec.get("mode", "random") @property def u_batch(self) -> int: """The number of random users that will be chosen when working in ``"random"`` mode. Read directly from ``"u_batch"`` field in ``self.spec``. When not configured, it will be set to 20% users.""" u_batch = self.spec.get("u_batch") if not u_batch: n = self.model.n u_batch = int(0.2 * n) return u_batch @property def i_batch(self) -> int: """The number of random items that will be chosen when working in ``"random"`` mode. Read directly from ``"i_batch"`` field in ``self.spec``. When not configured, it will be set to 20% items.""" i_batch = self.spec.get("i_batch") if not i_batch: m = self.model.m i_batch = int(0.2 * m) return i_batch @property def n_negatives(self) -> int: """The number of negative samples will be drawn for each positive user-item interaction. Read directly from ``"n_negatives"`` field in ``self.spec``. Valid when the generator is performing in ``"iteration"`` mode. Default to five.""" return self.spec.get("n_negatives", 5) @property def batch_size(self) -> int: """Batch size when generating samples in minibatch.""" batch_size = self.spec.get("batch_size", 128) return batch_size @property def user_idx_to_id(self) -> Dict[int, str]: """A dictionary with keys being user indices from zero to ``n_users-1``, and values being their ids. Will be set after ``self.prepare()`` is called.""" return self._user_idx_to_id @user_idx_to_id.setter def user_idx_to_id(self, user_idx_to_id: Dict[int, str]): self._user_idx_to_id = user_idx_to_id @property def user_id_to_idx(self) -> Dict[str, int]: """A dictionary with keys being user id and values being the index. It is the inverse mapping of ``self.user_idx_to_id``.""" return self._user_id_to_idx @user_id_to_idx.setter def user_id_to_idx(self, user_id_to_idx: Dict[str, int]): self._user_id_to_idx = user_id_to_idx @property def item_idx_to_id(self) -> Dict[int, str]: """A dictionary with keys being item indices from zero to ``n_items-1``, and values being their ids. Will be set after ``self.prepare()`` is called.""" return self._item_idx_to_id @item_idx_to_id.setter def item_idx_to_id(self, item_idx_to_id: Dict[int, str]): self._item_idx_to_id = item_idx_to_id @property def item_id_to_idx(self) -> Dict[str, int]: """A dictionary with keys being item id and values being the index. It is the inverse mapping of ``self.item_idx_to_id``.""" return self._item_id_to_idx @item_id_to_idx.setter def item_id_to_idx(self, item_id_to_idx: Dict[str, int]): self._item_id_to_idx = item_id_to_idx @property def uidx_to_iidxs_tuple(self) -> Dict[int, Tuple[Set[int], Set[int]]]: """A dictionary mapping from user idx to a tuple in which the first element is a set of item idxs the user has interacted with, and the second one is a set of non-interacted item idxs. Will be set after ``self.prepare()`` is called.""" return self._uidx_to_iidxs_tuple @uidx_to_iidxs_tuple.setter def uidx_to_iidxs_tuple( self, uidx_to_iidxs_tuple: Dict[int, Tuple[Set[int], Set[int]]] ): self._uidx_to_iidxs_tuple = uidx_to_iidxs_tuple @property def data(self) -> Dict[str, List[str]]: """A dictionary with keys being user ids and values being a list of item ids that user has interacted with. Lists of complete users and items will be inferred from it. Will be set after ``self.prepare()`` is called. """ return self._data @data.setter def data(self, data: dict): self._data = data @property def tensor(self) -> np.ndarray: """A three way array with shape of ``n x m x m`` where ``n`` is the number of users, and ``m`` is the number of items. A value of ``1`` at location ``(u, i, j)`` suggests ``u``-th user prefers ``i``-th item over ``j``-th item. ``-1`` suggests the opposite. A value of ``0`` means no information available to determine the preference of the two items. Value will be optionally set after ``self.prepare()`` is called, depending on the value of ``self.tensor_flag``, for the purpose of saving memory. """ return self._tensor @tensor.setter def tensor(self, tensor: np.ndarray): self._tensor = tensor @property def tensor_flag(self) -> bool: """A boolean flag to indicate if three way data tensor ``self.tensor`` will be constructed. ``False`` will stop creating the tensor to save memory consumption. """ return self.spec.get("tensor_flag", True) @property def user_idx_to_preference(self) -> Dict[int, Dict[Tuple[str, str], int]]: """A dictionary contains a mapping between user idx and item pairs that the user prefer one over the other. The item pairs are stored in a dictionary as well, with key being a tuple of two item ids, and value being ``1``.""" return self._user_idx_to_preference @user_idx_to_preference.setter def user_idx_to_preference( self, user_idx_to_preference: Dict[int, Dict[Tuple[str, str], int]] ): self._user_idx_to_preference = user_idx_to_preference @property def input_files(self) -> List[str]: """A list of files from where samples will be read.""" return self._input_files @input_files.setter def input_files(self, input_files: List[str]): self._input_files = input_files @property def output_dir(self) -> str: """Read directly from ``self.task.output_dir``.""" return self.task.output_dir @property def input_dir(self) -> str: """Read directly from ``self.task.input_dir``.""" return self.task.input_dir
[docs] @abstractmethod def prepare(self): """A method to inform generator to setup things in order to be prepared for generating samples. Concrete subclasses are responsible to implement this method. """ raise NotImplementedError
[docs] @abstractmethod def get_trn(self) -> Iterator[Any]: """Interface to generator samples for model training. Returns: :obj:`Iterator[Any]`: An iterable that training samples will be iterated through in mini-batches. """ raise NotImplementedError
[docs] @abstractmethod def get_val_or_not(self) -> Iterator[Any]: """Interface to generator samples for validating model. Returns: :obj:`Iterator[Any]`: An iterable that validation samples will be iterated through in mini-batches. """ raise NotImplementedError
[docs] def add(self, filename: str): """A method to add a local file to generator. The local file contains data from which mini-batches of training/validation samples will be read. Args: filename (:obj:`str`): A file path pointing the file. """ if not os.path.exists(filename): self.logger.warning( f"Unable to add {filename} to generator, file does not exist." ) return self.input_files.append(filename)
[docs] def save(self, working_dir: str): """Save generator's configuration to a folder. Args: working_dir (:obj:`str`): A local path where the configuration of the generator will be saved. """ if not working_dir: working_dir = self.output_dir model_s3_key_path = self.model.s3_key_path filename = "generator_config.json" os.makedirs(os.path.join(working_dir, model_s3_key_path), exist_ok=True) with open(os.path.join(working_dir, model_s3_key_path, filename), "w") as fout: json.dump(self.config, fout)
[docs]class GeneratorFactory: """A factory class that is responsible to create generator instances.""" logger = logging.getLogger("generator.GeneratorFactory") """:obj:`logging.Logger`: Class attribute for logging.""" _registry = dict() """:obj:`dict`: Registry dictionary containing a mapping between class name and class object."""
[docs] @classmethod def register(cls, wrapped_class: GeneratorBase) -> GeneratorBase: class_name = wrapped_class.__name__ if class_name in cls._registry: cls.logger.warning(f"Generator {class_name} already registered, Ignoring.") return wrapped_class cls._registry[class_name] = wrapped_class return wrapped_class
[docs] @classmethod def produce( cls, config: Dict, model: ModelBase, task: "TrainingTask" ) -> GeneratorBase: generator_name = config.get("name") if generator_name not in cls._registry: cls.logger.error(f"Unable to produce {generator_name} generator.") raise NotImplementedError return cls._registry[generator_name](config, model, task)