sad.trainer package

Submodules

sad.trainer.base module

class TrainerBase(config: Dict, model: sad.model.base.ModelBase, generator: sad.generator.base.GeneratorBase, task: TrainingTask)[source]

Bases: sad.callback.caller.CallerProtocol

The abstract trainer base class, from which all concrete trainer classes inherit.

In addition, this class is compliant with sad.callback.CallerProtocol.

add_final_metrics_to_model_metrics(**kwargs)[source]

Class-specific method that adds final metrics to the model’s metrics attribute. After the addition, the model’s metrics include a "final" field with the structure shown below:

metrics = {
    "final": {
        "ll": float,
        "t_sparsity": float,
    },
}

property config: Dict

Configuration information that is used to initialize the instance.

property eval_at_every_step

Read directly from self.spec. The interval, in steps, at which log likelihood is evaluated. A negative value disables step-level evaluation.

Type

int
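A minimal sketch of how a step-level evaluation schedule could be checked, assuming eval_at_every_step is an interval in steps. The helper should_evaluate is hypothetical and not part of the package; it also treats 0 as disabled (an assumption) to avoid a modulo-by-zero.

```python
def should_evaluate(step: int, eval_at_every_step: int) -> bool:
    """Return True if log likelihood should be evaluated at this (1-based) step.

    Hypothetical helper: assumes eval_at_every_step is an interval in steps.
    Negative values (and, by assumption, zero) disable step-level evaluation.
    """
    if eval_at_every_step <= 0:
        return False
    return step % eval_at_every_step == 0


# Steps 1..10 with an interval of 3: evaluate at steps 3, 6 and 9.
print([s for s in range(1, 11) if should_evaluate(s, 3)])  # [3, 6, 9]
```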

property generator: sad.generator.base.GeneratorBase

A reference to the generator instance that the current trainer uses to perform a training task.

abstract load(working_dir: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

property lr

Read directly from self.spec. The learning rate; callbacks may change it during training.

Type

float

property model: sad.model.base.ModelBase

A reference to the model instance that the current trainer trains during the training loop.

property n_epochs: int

The number of training epochs, specific to TrainerBase. Read directly from the "n_epochs" field in self.spec.

property n_iters: int

The number of iterations a trainer runs. An alias of self.n_epochs.

on_loop_end(**kwargs)[source]

Invoked at the end of the training loop. Saves the trainer instance to self.working_dir, and saves self.model and self.generator as well.

This method overrides on_loop_end in sad.callback.CallerProtocol.

abstract save(working_dir: str)[source]

Save a trainer instance for later use.

property spec: Dict

A reference to the "spec" field in self.config. If the field is missing or its value is None, an empty dictionary is used instead.
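The documented fallback can be sketched as follows; resolve_spec is a hypothetical helper that illustrates the behavior and is not the package's implementation.

```python
from typing import Any, Dict, Optional


def resolve_spec(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return config["spec"], or {} if it is missing or None."""
    spec: Optional[Dict[str, Any]] = config.get("spec")
    return spec if spec is not None else {}


print(resolve_spec({"name": "SGDTrainer", "spec": {"w_l1": 0.01}}))  # {'w_l1': 0.01}
print(resolve_spec({"name": "SGDTrainer", "spec": None}))            # {}
print(resolve_spec({"name": "SGDTrainer"}))                          # {}
```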

property stop

A flag indicating whether to stop training. Callbacks may change it during training.

Type

boolean
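Both lr and stop are described as mutable by callbacks. The sketch below illustrates that pattern with a hypothetical plateau callback and a stand-in trainer; the class names, the on_eval_end hook and the halving policy are assumptions and are not part of sad.callback.

```python
from dataclasses import dataclass


@dataclass
class ToyTrainer:
    """Hypothetical stand-in exposing the mutable lr and stop attributes."""
    lr: float = 0.1
    stop: bool = False


class PlateauCallback:
    """Hypothetical callback: multiply lr by `factor` after `patience`
    evaluations without improvement; request a stop once lr < min_lr."""

    def __init__(self, patience: int = 2, factor: float = 0.5, min_lr: float = 1e-3) -> None:
        self.patience = patience
        self.factor = factor
        self.min_lr = min_lr
        self.best = float("-inf")
        self.bad_rounds = 0

    def on_eval_end(self, trainer: ToyTrainer, ll: float) -> None:
        if ll > self.best:  # log likelihood improved: reset the counter
            self.best = ll
            self.bad_rounds = 0
            return
        self.bad_rounds += 1
        if self.bad_rounds >= self.patience:
            trainer.lr *= self.factor  # decay the learning rate
            self.bad_rounds = 0
            if trainer.lr < self.min_lr:
                trainer.stop = True  # signal the training loop to end


trainer = ToyTrainer(lr=0.1)
callback = PlateauCallback(patience=2, factor=0.5, min_lr=0.03)
for ll in [-10.0, -9.0, -9.5, -9.6, -9.7, -9.8]:
    callback.on_eval_end(trainer, ll)
    if trainer.stop:
        break
print(trainer.lr, trainer.stop)  # 0.025 True
```

Keeping the policy in a callback rather than in train() lets the same trainer be reused with different schedules.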

property task: TrainingTask

A reference to the training task instance associated with the current trainer, i.e., the task in which the trainer is initialized.

abstract train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.
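A simplified sketch of the contract above, using Python's abc module: abstract train, save and load methods, plus an on_loop_end hook that saves the trainer once the loop finishes. MiniTrainerBase, ToySGDTrainer and the trainer.json file name are hypothetical; a real trainer also holds a model, a generator and a task.

```python
import abc
import json
import os
import tempfile
from typing import Any, Dict


class MiniTrainerBase(abc.ABC):
    """Hypothetical, simplified mirror of the TrainerBase contract."""

    def __init__(self, config: Dict[str, Any], working_dir: str) -> None:
        self.config = config
        self.spec = config.get("spec") or {}  # empty dict if missing or None
        self.working_dir = working_dir
        self.stop = False  # callbacks may flip this to end training early

    @property
    def n_epochs(self) -> int:
        return int(self.spec.get("n_epochs", 1))

    @abc.abstractmethod
    def train(self) -> None:
        """Run the training loop."""

    @abc.abstractmethod
    def save(self, working_dir: str) -> None:
        """Persist trainer state."""

    @abc.abstractmethod
    def load(self, working_dir: str) -> None:
        """Restore trainer state."""

    def on_loop_end(self, **kwargs: Any) -> None:
        # Save trainer state once the training loop has finished.
        self.save(self.working_dir)


class ToySGDTrainer(MiniTrainerBase):
    """Hypothetical concrete trainer; each 'epoch' only increments a counter."""

    def __init__(self, config: Dict[str, Any], working_dir: str) -> None:
        super().__init__(config, working_dir)
        self.epochs_run = 0

    def train(self) -> None:
        for _ in range(self.n_epochs):
            if self.stop:  # honor an early-stop request from a callback
                break
            self.epochs_run += 1  # placeholder for one pass over the data
        self.on_loop_end()

    def save(self, working_dir: str) -> None:
        with open(os.path.join(working_dir, "trainer.json"), "w") as f:
            json.dump({"config": self.config, "epochs_run": self.epochs_run}, f)

    def load(self, working_dir: str) -> None:
        with open(os.path.join(working_dir, "trainer.json")) as f:
            self.epochs_run = json.load(f)["epochs_run"]


config = {"name": "ToySGDTrainer", "spec": {"n_epochs": 3}}
with tempfile.TemporaryDirectory() as tmp:
    trainer = ToySGDTrainer(config, tmp)
    trainer.train()
    restored = ToySGDTrainer(config, tmp)
    restored.load(tmp)
print(trainer.epochs_run, restored.epochs_run)  # 3 3
```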

property working_dir

Read directly from self.task.output_dir.

Type

str

class TrainerFactory[source]

Bases: object

A factory class responsible for creating trainer instances.

logger = <Logger trainer.TrainerFactory (INFO)>

Class attribute for logging.

Type

logging.Logger

classmethod produce(config: Dict, model: sad.model.base.ModelBase, generator: sad.generator.base.GeneratorBase, task: TrainingTask) sad.trainer.base.TrainerBase[source]

A class-level method that initializes instances of sad.trainer.TrainerBase subclasses.

Parameters
  • config (Dict) –

    Configuration used to initialize instance object. An example is given below:

    name: "SGDTrainer"
    spec:
        w_l1: 0.01
        w_l2: 0.01
        ...
    

  • model (sad.model.ModelBase) – An instance of a model, i.e., the trainable object that the trainer will train.

  • generator (sad.generator.GeneratorBase) – An instance of generator, from which training and validation data are generated.

  • task (sad.tasks.training.TrainingTask) – An instance of training task, from which a trainer is created.

classmethod register(wrapped_class: sad.trainer.base.TrainerBase) sad.trainer.base.TrainerBase[source]

A class-level decorator that registers sad.trainer.TrainerBase subclasses into TrainerFactory.registry.
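The register/produce pair follows a common registry-based factory pattern. The sketch below assumes classes are registered under their class name and selected by the "name" field of the config, as in the "SGDTrainer" configuration example; ToyTrainerBase, ToyTrainerFactory and the stand-in SGDTrainer are hypothetical simplifications that omit the model, generator and task arguments.

```python
import logging
from typing import Any, Dict, Type


class ToyTrainerBase:
    """Hypothetical minimal base; real trainers also take model, generator, task."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self.config = config


class ToyTrainerFactory:
    logger = logging.getLogger("trainer.TrainerFactory")
    registry: Dict[str, Type[ToyTrainerBase]] = {}

    @classmethod
    def register(cls, wrapped_class: Type[ToyTrainerBase]) -> Type[ToyTrainerBase]:
        # Used as a decorator: record the class under its name, return it unchanged.
        cls.registry[wrapped_class.__name__] = wrapped_class
        return wrapped_class

    @classmethod
    def produce(cls, config: Dict[str, Any]) -> ToyTrainerBase:
        name = config["name"]
        if name not in cls.registry:
            raise KeyError(f"Unknown trainer {name!r}; registered: {sorted(cls.registry)}")
        cls.logger.info("Producing trainer %s", name)
        return cls.registry[name](config)


@ToyTrainerFactory.register
class SGDTrainer(ToyTrainerBase):
    """Stand-in for a concrete trainer; only demonstrates registration."""


trainer = ToyTrainerFactory.produce({"name": "SGDTrainer", "spec": {"w_l1": 0.01, "w_l2": 0.01}})
print(type(trainer).__name__)  # SGDTrainer
```

Registration at class-definition time means adding a new trainer requires no change to the factory itself.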

sad.trainer.cornac module

class CornacTrainer(config: dict, model: sad.model.cornac.CornacModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

property lambda_reg

Read directly from self.spec. The lambda regularization parameter used during training. Specific to sad.model.CornacModel.

Type

float

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

sad.trainer.fm module

class FMTrainer(config: dict, model: sad.model.fm.FMModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

property loss_name: str

Read directly from the "loss" field in self.spec. Currently takes one of two values, "bpr" or "warp"; defaults to "bpr". Specific to sad.model.FMModel.
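For reference, the standard BPR (Bayesian Personalized Ranking) pairwise loss for a single (user, positive item, negative item) triple is -ln σ(x_ui - x_uj). The sketch below computes it in a numerically stable way; it illustrates the loss itself and is not taken from FMTrainer's code.

```python
import math


def bpr_loss(score_pos: float, score_neg: float) -> float:
    """BPR loss -ln(sigmoid(x_ui - x_uj)) for one (user, pos, neg) triple."""
    x = score_pos - score_neg
    # Stable softplus(-x), which equals -ln(sigmoid(x)) without overflow.
    if x >= 0:
        return math.log1p(math.exp(-x))
    return -x + math.log1p(math.exp(x))


print(round(bpr_loss(2.0, 2.0), 4))  # 0.6931, i.e. ln 2 for a tie
```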

property n_negative_samples: int

Read directly from the "n_negative_samples" field in self.spec. The number of negative samples drawn for the "warp" loss.
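A sketch of WARP-style negative sampling, assuming n_negative_samples caps the number of draws per positive item. Negatives are drawn until one scores within the margin of the positive; the number of draws yields a rank estimate, which is turned into a loss weight. The function warp_sample and its margin of 1.0 are assumptions for illustration, and for brevity it only excludes the positive item itself rather than all of the user's positives.

```python
import random
from typing import Optional, Sequence, Tuple


def warp_sample(
    scores: Sequence[float],
    pos: int,
    n_negative_samples: int,
    rng: random.Random,
    margin: float = 1.0,
) -> Optional[Tuple[int, float]]:
    """Draw negatives until one violates the margin; return (neg_item, weight).

    Returns None if no violating negative is found within the sampling budget.
    The weight maps the rank estimate floor((N - 1) / trials) through
    L(k) = 1 + 1/2 + ... + 1/k.
    """
    n_items = len(scores)
    candidates = [i for i in range(n_items) if i != pos]
    for trials in range(1, n_negative_samples + 1):
        neg = rng.choice(candidates)
        if scores[neg] > scores[pos] - margin:  # margin violated
            rank = (n_items - 1) // trials
            weight = sum(1.0 / k for k in range(1, rank + 1))
            return neg, weight
    return None


rng = random.Random(42)
print(warp_sample([3.0, 2.5, 0.1, 0.0, -1.0], pos=0, n_negative_samples=10, rng=rng))
```

Violations found on the first draw suggest the positive is ranked poorly, so they receive the largest weight.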

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

property w_l2: float

The weight of L2 regularization on parameters. Read directly from the "w_l2" field in self.spec.

sad.trainer.msft_ncf module

class MSFTRecNCFTrainer(config: dict, model: sad.model.msft_ncf.MSFTRecNCFModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

property i_idxs: List[int]

Read directly from self.spec. A list of items, represented by item indices. Pairwise comparisons over these items by users in self.u_idxs are used to evaluate the model during training. Can be configured to a subset of items for efficiency.

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

property u_idxs: List[int]

Read directly from self.spec. A list of users, represented by user indices, on whom log likelihood is evaluated. Can be configured to a subset of users for efficiency.

sad.trainer.msft_rbm module

class MSFTRecRBMTrainer(config: dict, model: sad.model.msft_rbm.MSFTRecRBMModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

sad.trainer.msft_vae module

class MSFTRecVAETrainer(config: dict, model: sad.model.msft_vae.MSFTRecVAEModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

property beta: float

The beta parameter of the beta-VAE model. Read directly from the "beta" field in self.spec.
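In a beta-VAE, beta scales the KL divergence between the approximate posterior and a standard normal prior. A sketch of that objective for a diagonal Gaussian posterior follows; the function names are hypothetical and the exact loss used by MSFTRecVAEModel may differ.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) for one latent vector."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(
    recon_nll: float, mu: Sequence[float], logvar: Sequence[float], beta: float
) -> float:
    """Negative beta-ELBO: reconstruction NLL plus the beta-weighted KL term."""
    return recon_nll + beta * gaussian_kl(mu, logvar)


print(gaussian_kl([1.0], [0.0]))  # 0.5
```

Values of beta below 1 weaken the pull toward the prior, while values above 1 strengthen it.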

evaluation(scores: numpy.ndarray)[source]

The method that runs the evaluation. During evaluation, a relative score is calculated for each item pair (i, j) in which item i is preferred over item j. The mean and standard deviation of these scores, together with the model’s log likelihood, are then calculated.

Parameters

scores (np.ndarray) – Pre-calculated user-item preference.
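A sketch of the computation described for evaluation, assuming preferences are given as (u, i, j) triples meaning user u prefers item i over item j, and assuming a logistic (sigmoid) pairwise likelihood. evaluate_pairs and the triple format are hypothetical; the actual method may define the likelihood differently and operates on a numpy array.

```python
import math
import statistics
from typing import List, Sequence, Tuple


def evaluate_pairs(
    scores: Sequence[Sequence[float]],
    pairs: List[Tuple[int, int, int]],
) -> Tuple[float, float, float]:
    """Return (mean, std, log_likelihood) of relative scores s[u][i] - s[u][j]."""
    diffs = [scores[u][i] - scores[u][j] for u, i, j in pairs]
    mean = statistics.fmean(diffs)
    std = statistics.pstdev(diffs)  # population standard deviation
    # Sum of log sigmoid(d), computed stably for either sign of d.
    ll = sum(
        -math.log1p(math.exp(-d)) if d >= 0 else d - math.log1p(math.exp(d))
        for d in diffs
    )
    return mean, std, ll


scores = [[2.0, 1.0, 0.0], [0.0, 1.0, 3.0]]
pairs = [(0, 0, 1), (0, 0, 2), (1, 2, 0)]  # relative scores 1.0, 2.0, 3.0
print(evaluate_pairs(scores, pairs))
```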

property evaluation_flag: bool

An attribute specific to MSFTRecVAETrainer. When set to True, relative preference scores are calculated for each item pair in which the i-th item is preferred over the j-th item.

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

sad.trainer.sad module

class SGDTrainer(config: dict, model: sad.model.sad.SADModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

property i_idxs: List[int]

Read directly from self.spec. A list of items, represented by item indices. Pairwise comparisons over these items by users in self.u_idxs are used to evaluate the model during training. Can be configured to a subset of items for efficiency.

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

property u_idxs: List[int]

Read directly from self.spec. A list of users, represented by user indices, on whom log likelihood is evaluated. Can be configured to a subset of users for efficiency.

property w_l1

Read directly from self.spec. The weight of L1 penalty on parameter T in a SAD model.

Type

float

property w_l2

Read directly from self.spec. The weight of the L2 penalty on parameters XI and H in a SAD model.

Type

float
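Taken together, w_l1 and w_l2 suggest a regularization term of the form sketched below. This is an illustration under stated assumptions: the matrices are plain nested lists, the L2 penalty is taken as the squared Frobenius norm, and sad_penalty is a hypothetical helper; the exact form used by the SAD model may differ.

```python
from typing import Sequence

Matrix = Sequence[Sequence[float]]


def sad_penalty(T: Matrix, XI: Matrix, H: Matrix, w_l1: float, w_l2: float) -> float:
    """Hypothetical penalty: w_l1 * ||T||_1 + w_l2 * (||XI||_F^2 + ||H||_F^2)."""
    l1 = sum(abs(v) for row in T for v in row)  # L1 norm of T encourages sparsity
    l2 = sum(v * v for row in XI for v in row) + sum(v * v for row in H for v in row)
    return w_l1 * l1 + w_l2 * l2


T = [[1.0, -2.0], [0.0, 0.5]]  # ||T||_1 = 3.5
XI = [[1.0, 1.0]]              # ||XI||_F^2 = 2
H = [[0.0, 2.0]]               # ||H||_F^2 = 4
print(round(sad_penalty(T, XI, H, w_l1=0.01, w_l2=0.1), 6))  # 0.635
```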

sad.trainer.svd module

class SVDTrainer(config: dict, model: sad.model.svd.SVDModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)[source]

Bases: sad.trainer.base.TrainerBase

load(folder: str)[source]

Load the state of a trainer instance, mostly to continue the training loop of a saved model.

property reg: float

Regularization parameter. Read directly from "reg" field in self.spec.

save(working_dir: Optional[str] = None)[source]

Save trainer configuration.

train()[source]

The main training loop. Concrete trainer classes are responsible for providing their own training logic.

Module contents