sad.trainer package
Submodules
sad.trainer.base module
- class TrainerBase(config: Dict, model: sad.model.base.ModelBase, generator: sad.generator.base.GeneratorBase, task: TrainingTask)
Bases:
sad.callback.caller.CallerProtocol
The abstract trainer base class, from which all concrete trainer classes inherit. It also complies with sad.callback.CallerProtocol.
- add_final_metrics_to_model_metrics(**kwargs)
Class-specific method that adds final metrics to the model's metrics attribute. Afterwards, the model's metrics will include a "final" field with the structure shown below:
    metrics = {
        "final": {
            "ll": float,
            "t_sparsity": float,
        },
    }
- property config: Dict
Configuration information that is used to initialize the instance.
- property eval_at_every_step
Read directly from self.spec. The interval, in steps, at which the log likelihood is evaluated. A negative number disables evaluation at the step level.
- Type
int
- property generator: sad.generator.base.GeneratorBase
A reference to a generator instance, which will be used by current trainer to perform a training task.
- abstract load(working_dir: str)
Load the state of a trainer instance, mostly to continue the training loop of a saved model.
- property lr
Read directly from self.spec. The learning rate, which callbacks may change during training.
- Type
float
- property model: sad.model.base.ModelBase
A reference to a model instance. The current trainer trains this model during the training loop.
- property n_epochs: int
The number of training epochs, specific to TrainerBase. Read directly from the "n_epochs" field in self.spec.
- property n_iters: int
The number of iterations that a trainer will run. An alias of self.n_epochs.
- on_loop_end(**kwargs)
Invoked at the end of the training loop. Saves the trainer instance to self.working_dir, together with self.model and self.generator. This method overrides on_loop_end in sad.callback.CallerProtocol.
- property spec: Dict
A reference to the "spec" field in self.config. When the field is absent or its value is None, an empty dictionary is used instead.
- property stop
A flag indicating whether to stop training. Callbacks may change it during training.
- Type
boolean
- property task: TrainingTask
A reference to the training task instance associated with the current trainer, i.e., the task in which the trainer is initialized.
- abstract train()
The main training loop. Concrete trainer classes are responsible for providing implementations of their training logic.
- property working_dir
Read directly from self.task.output_dir.
- Type
str
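The TrainerBase properties above (a spec that falls back to an empty dictionary, n_iters as an alias of n_epochs, a negative eval_at_every_step disabling step-level evaluation, and lr and stop being modifiable by callbacks) can be sketched together as follows. This is a minimal, self-contained illustration under stated assumptions, not the sad implementation: SketchTrainerBase, ToyTrainer, halve_lr_and_stop_after_two, the steps_per_epoch argument, the "lr" spec key, and all default values are hypothetical, and callbacks are modeled as plain functions rather than through sad.callback.CallerProtocol.

```python
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Optional


class SketchTrainerBase(ABC):
    """Hypothetical stand-in for TrainerBase; not the actual sad implementation."""

    def __init__(self, config: Dict):
        self._config = config
        # lr and stop are plain attributes so that callbacks can modify them.
        self.lr: float = float(self.spec.get("lr", 0.01))  # key and default are assumptions
        self.stop: bool = False

    @property
    def config(self) -> Dict:
        return self._config

    @property
    def spec(self) -> Dict:
        # Use an empty dict when "spec" is absent or explicitly None.
        spec = self._config.get("spec")
        return {} if spec is None else spec

    @property
    def n_epochs(self) -> int:
        return int(self.spec.get("n_epochs", 1))  # default is an assumption

    @property
    def n_iters(self) -> int:
        # Alias of n_epochs.
        return self.n_epochs

    @property
    def eval_at_every_step(self) -> int:
        # Negative values disable step-level evaluation (default is an assumption).
        return int(self.spec.get("eval_at_every_step", -1))

    @abstractmethod
    def train(self) -> None:
        ...


# Hypothetical callback signature: (trainer, epoch) -> None.
Callback = Callable[[SketchTrainerBase, int], None]


class ToyTrainer(SketchTrainerBase):
    """Hypothetical concrete trainer that only records when it would evaluate."""

    def __init__(self, config: Dict, steps_per_epoch: int,
                 callbacks: Optional[List[Callback]] = None):
        super().__init__(config)
        self.steps_per_epoch = steps_per_epoch
        self.callbacks = callbacks or []
        self.eval_steps: List[int] = []
        self.epochs_run = 0

    def train(self) -> None:
        step = 0
        for epoch in range(self.n_iters):
            for _ in range(self.steps_per_epoch):
                step += 1  # a real trainer would update the model here using self.lr
                k = self.eval_at_every_step
                if k > 0 and step % k == 0:
                    self.eval_steps.append(step)  # stand-in for a log-likelihood evaluation
            self.epochs_run += 1
            for callback in self.callbacks:
                callback(self, epoch)
            if self.stop:  # callbacks may request an early stop
                break


def halve_lr_and_stop_after_two(trainer: SketchTrainerBase, epoch: int) -> None:
    # Example callback: decay the learning rate, then request a stop after epoch 1.
    trainer.lr *= 0.5
    if epoch >= 1:
        trainer.stop = True
```

For example, ToyTrainer({"spec": {"n_epochs": 5, "lr": 0.1, "eval_at_every_step": 3}}, steps_per_epoch=4, callbacks=[halve_lr_and_stop_after_two]) evaluates at steps 3 and 6, halves the learning rate twice to 0.025, and stops after two of the five epochs because the callback sets stop.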
- class TrainerFactory
Bases:
object
A factory class that is responsible for creating trainer instances.
- logger = <Logger trainer.TrainerFactory (INFO)>
Class attribute for logging.
- Type
logging.Logger
- classmethod produce(config: Dict, model: sad.model.base.ModelBase, generator: sad.generator.base.GeneratorBase, task: TrainingTask) -> sad.trainer.base.TrainerBase
A class-level method to initialize instances of sad.trainer.TrainerBase classes.
- Parameters
config (config) – Configuration used to initialize the instance. An example is given below:
    name: "SGDTrainer"
    spec:
      w_l1: 0.01
      w_l2: 0.01
      ...
model (sad.model.ModelBase) – An instance of a model, i.e., the trainable that the trainer will train.
generator (sad.generator.GeneratorBase) – An instance of a generator, from which training and validation data are generated.
task (sad.tasks.training.TrainingTask) – An instance of a training task, from which the trainer is created.
- classmethod register(wrapped_class: sad.trainer.base.TrainerBase) -> sad.trainer.base.TrainerBase
A class-level decorator that decorates sad.trainer.TrainerBase classes and registers them in TrainerFactory.registry.
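The register/produce pair follows a registry-based factory pattern. A minimal sketch, assuming that classes are keyed by their class name and that produce selects the class via the "name" field of the config (as in the "SGDTrainer" example given for produce); SketchFactory and DummyTrainer are hypothetical names, and the real TrainerFactory may differ in how it keys and validates entries.

```python
import logging
from typing import Dict


class SketchFactory:
    """Hypothetical stand-in for TrainerFactory."""

    logger = logging.getLogger("trainer.SketchFactory")
    registry: Dict[str, type] = {}

    @classmethod
    def register(cls, wrapped_class: type) -> type:
        # Store the class under its name and return it unchanged,
        # so that register can be used as a class decorator.
        cls.registry[wrapped_class.__name__] = wrapped_class
        return wrapped_class

    @classmethod
    def produce(cls, config: Dict, *args, **kwargs):
        # Look up the concrete class by the "name" field of the config (an assumption).
        name = config["name"]
        if name not in cls.registry:
            raise KeyError(f"Trainer {name!r} is not registered")
        cls.logger.info("Producing trainer %s", name)
        return cls.registry[name](config, *args, **kwargs)


@SketchFactory.register
class DummyTrainer:
    """Hypothetical trainer used only to exercise the factory."""

    def __init__(self, config, model=None, generator=None, task=None):
        self.config = config
        self.model, self.generator, self.task = model, generator, task
```

Because register returns the class unchanged, the decorated class remains directly usable, and adding a new trainer only requires decorating it; produce never needs to be edited.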
sad.trainer.cornac module
- class CornacTrainer(config: dict, model: sad.model.cornac.CornacModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
- property lambda_reg
Read directly from self.spec. The lambda regularization parameter used during training. Specific to sad.model.CornacModel.
- Type
float
sad.trainer.fm module
- class FMTrainer(config: dict, model: sad.model.fm.FMModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
- load(folder: str)
Load the state of a trainer instance, mostly to continue the training loop of a saved model.
- property loss_name: str
Read directly from the "loss" field in self.spec. Currently takes one of two values, "bpr" or "warp"; the default is "bpr". Specific to sad.model.FMModel.
- property n_negative_samples: int
Read directly from the "n_negative_samples" field in self.spec. The number of negative samples drawn for the "warp" loss.
- train()
The main training loop. Concrete trainer classes are responsible for providing implementations of their training logic.
- property w_l2: float
Weight of the L2 regularization on parameters. Read directly from the "w_l2" field in self.spec.
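For reference, the two losses that loss_name selects between are commonly defined as below: BPR minimizes -log sigmoid(s_ui - s_uj) for a positive item i and a negative item j, while WARP keeps sampling negatives until one violates the margin and weights the hinge loss by an estimate of the positive item's rank. This is a NumPy sketch of the standard definitions, not the FMModel code; treating n_negative_samples as the maximum number of WARP draws per positive, and treating every item other than the positive as a negative, are assumptions.

```python
import numpy as np


def bpr_loss(pos_score: float, neg_score: float) -> float:
    # -log(sigmoid(pos - neg)), computed stably as log(1 + exp(-(pos - neg))).
    return float(np.logaddexp(0.0, -(pos_score - neg_score)))


def warp_loss(scores: np.ndarray, pos_item: int, n_negative_samples: int,
              rng: np.random.Generator) -> float:
    """WARP loss for one (user, positive item) pair.

    scores: the user's scores over all items (at least two items).
    Every item other than pos_item is treated as a negative (an assumption).
    """
    n_items = scores.shape[0]
    pos_score = scores[pos_item]
    for n_trials in range(1, n_negative_samples + 1):
        # Draw uniformly from all items except pos_item.
        neg_item = int(rng.integers(n_items - 1))
        if neg_item >= pos_item:
            neg_item += 1
        margin = 1.0 - pos_score + scores[neg_item]
        if margin > 0:
            # Fewer draws needed => positive item is likely ranked lower => larger weight.
            est_rank = (n_items - 1) // n_trials
            weight = sum(1.0 / k for k in range(1, est_rank + 1))
            return float(weight * margin)
    return 0.0  # no margin violation found within the sampling budget
```

A violation found on the first draw implies a high estimated rank and therefore a large weight; a violation found only after many draws yields a small weight, and none at all yields zero loss.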
sad.trainer.msft_ncf module
- class MSFTRecNCFTrainer(config: dict, model: sad.model.msft_ncf.MSFTRecNCFModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
- property i_idxs: List[int]
Read directly from self.spec. A list of items, represented by item indices. Pairwise comparisons over these items by the users in self.u_idxs are used to evaluate the model during training. Can be configured to a subset of items for efficiency.
- load(folder: str)
Load the state of a trainer instance, mostly to continue the training loop of a saved model.
- train()
The main training loop. Concrete trainer classes are responsible for providing implementations of their training logic.
- property u_idxs: List[int]
Read directly from self.spec. A list of users, represented by user indices, on whom the log likelihood is evaluated. Can be configured to a subset of users for efficiency.
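To make concrete how u_idxs and i_idxs restrict the evaluation, the sketch below enumerates (u, i, j) triples in which user u is assumed to prefer item i over item j. The preference rule (an interacted item is preferred over a non-interacted one), the dense 0/1 interactions matrix, and the function name build_eval_pairs are illustrative assumptions rather than the MSFTRecNCFTrainer logic.

```python
from typing import List, Tuple

import numpy as np


def build_eval_pairs(interactions: np.ndarray, u_idxs: List[int],
                     i_idxs: List[int]) -> List[Tuple[int, int, int]]:
    """Return (u, i, j) triples where user u is assumed to prefer item i over item j.

    interactions: dense (n_users, n_items) 0/1 matrix (an assumption).
    Only users in u_idxs and items in i_idxs are considered.
    """
    pairs: List[Tuple[int, int, int]] = []
    for u in u_idxs:
        row = interactions[u]
        pos = [i for i in i_idxs if row[i] > 0]   # items the user interacted with
        neg = [j for j in i_idxs if row[j] == 0]  # items the user did not interact with
        pairs.extend((u, i, j) for i in pos for j in neg)
    return pairs
```

The number of triples per user is the product of its positive and negative item counts within i_idxs, so restricting u_idxs and i_idxs to subsets keeps evaluation affordable on large catalogs.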
sad.trainer.msft_rbm module
- class MSFTRecRBMTrainer(config: dict, model: sad.model.msft_rbm.MSFTRecRBMModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
sad.trainer.msft_vae module
- class MSFTRecVAETrainer(config: dict, model: sad.model.msft_vae.MSFTRecVAEModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
- property beta: float
The beta parameter in the beta-VAE model. Read directly from the "beta" field in self.spec.
- evaluation(scores: numpy.ndarray)
Runs the evaluation. During evaluation, relative scores are calculated for each item pair (i, j) in which item i is preferred over item j. The mean and standard deviation of these scores, as well as the log likelihood of the model, are then calculated.
- Parameters
scores (np.ndarray) – Pre-calculated user-item preference scores.
- property evaluation_flag: bool
An attribute specific to MSFTRecVAETrainer. When set to True, enables the calculation of relative preference scores for each item pair in which the i-th item is preferred over the j-th item.
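The beta property and the evaluation method can be illustrated with a few small NumPy functions. beta_vae_loss shows the standard beta-VAE objective, a reconstruction term plus beta times the KL divergence between a diagonal Gaussian posterior and a standard normal prior; evaluate_pairs computes, from a user-item score matrix, the relative scores s_ui - s_uj for given (u, i, j) preference triples together with their mean, standard deviation, and a log likelihood. Using log sigmoid(s_ui - s_uj) as the per-pair log likelihood, as well as all function names, are assumptions; MSFTRecVAETrainer may compute these quantities differently.

```python
from typing import Dict, List, Tuple

import numpy as np


def gaussian_kl(mu: np.ndarray, logvar: np.ndarray) -> float:
    # KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over batch and latent dims.
    return float(-0.5 * np.sum(1.0 + logvar - mu ** 2 - np.exp(logvar)))


def beta_vae_loss(recon_nll: float, mu: np.ndarray, logvar: np.ndarray,
                  beta: float) -> float:
    # Reconstruction negative log likelihood plus the beta-weighted KL term.
    return recon_nll + beta * gaussian_kl(mu, logvar)


def evaluate_pairs(scores: np.ndarray,
                   pairs: List[Tuple[int, int, int]]) -> Dict[str, float]:
    """Relative-score statistics for non-empty (u, i, j) triples with i preferred over j."""
    u, i, j = (np.asarray(col) for col in zip(*pairs))
    rel = scores[u, i] - scores[u, j]
    # Sum of log sigmoid(rel), computed stably as -log(1 + exp(-rel)).
    ll = -float(np.sum(np.logaddexp(0.0, -rel)))
    return {"mean": float(rel.mean()), "std": float(rel.std()), "ll": ll}
```

With beta = 1 this reduces to the usual VAE objective; smaller values of beta weaken the pull of the posterior toward the prior.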
sad.trainer.sad module
- class SGDTrainer(config: dict, model: sad.model.sad.SADModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
- property i_idxs: List[int]
Read directly from self.spec. A list of items, represented by item indices. Pairwise comparisons over these items by the users in self.u_idxs are used to evaluate the model during training. Can be configured to a subset of items for efficiency.
- load(folder: str)
Load the state of a trainer instance, mostly to continue the training loop of a saved model.
- train()
The main training loop. Concrete trainer classes are responsible for providing implementations of their training logic.
- property u_idxs: List[int]
Read directly from self.spec. A list of users, represented by user indices, on whom the log likelihood is evaluated. Can be configured to a subset of users for efficiency.
- property w_l1
Read directly from self.spec. The weight of the L1 penalty on parameter T in a SAD model.
- Type
float
- property w_l2
Read directly from self.spec. The weight of the L2 penalty on parameters XI and H in a SAD model.
- Type
float
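Taken literally, w_l1 and w_l2 weight a penalty of the form w_l1 * ||T||_1 + w_l2 * (||XI||^2 + ||H||^2). The sketch below computes that term and, separately, a soft-thresholding step, which is one common way an L1 penalty drives entries of T to exactly zero. The absence of a 1/2 factor on the L2 term, the use of soft-thresholding, and the function names sad_penalty and soft_threshold are assumptions made for illustration; they are not taken from the SGDTrainer source.

```python
import numpy as np


def sad_penalty(XI: np.ndarray, H: np.ndarray, T: np.ndarray,
                w_l1: float, w_l2: float) -> float:
    # L1 penalty on T plus squared L2 (Frobenius) penalty on XI and H.
    # No 1/2 factor on the L2 term is assumed.
    l1 = w_l1 * np.abs(T).sum()
    l2 = w_l2 * (np.sum(XI ** 2) + np.sum(H ** 2))
    return float(l1 + l2)


def soft_threshold(T: np.ndarray, step: float, w_l1: float) -> np.ndarray:
    # Proximal operator of step * w_l1 * ||T||_1: shrinks every entry toward
    # zero by step * w_l1 and sets entries smaller than that exactly to zero.
    return np.sign(T) * np.maximum(np.abs(T) - step * w_l1, 0.0)
```

Exact zeros produced this way are what make a sparsity measure on T meaningful; a plain subgradient step on the L1 term would only push entries close to zero.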
sad.trainer.svd module
- class SVDTrainer(config: dict, model: sad.model.svd.SVDModel, generator: sad.generator.implicit_fb.ImplicitFeedbackGenerator, task: TrainingTask)
Bases:
sad.trainer.base.TrainerBase
- load(folder: str)
Load the state of a trainer instance, mostly to continue the training loop of a saved model.
- property reg: float
Regularization parameter. Read directly from the "reg" field in self.spec.
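One common way a single regularization weight such as reg enters SVD-style training is the biased matrix-factorization SGD update sketched below, in which the same reg shrinks the user and item biases and the latent factors. Whether SVDModel trains this way is an assumption; svd_sgd_step and its arguments are hypothetical.

```python
import numpy as np


def svd_sgd_step(r_ui: float, mu: float, b_u: float, b_i: float,
                 p_u: np.ndarray, q_i: np.ndarray, lr: float, reg: float):
    """One SGD update for biased matrix factorization; returns updated copies."""
    # Prediction error for the observed rating r_ui.
    err = r_ui - (mu + b_u + b_i + p_u @ q_i)
    new_b_u = b_u + lr * (err - reg * b_u)
    new_b_i = b_i + lr * (err - reg * b_i)
    # Both factor updates use the old p_u and q_i on the right-hand side.
    new_p_u = p_u + lr * (err * q_i - reg * p_u)
    new_q_i = q_i + lr * (err * p_u - reg * q_i)
    return new_b_u, new_b_i, new_p_u, new_q_i
```

With a zero prediction error, the update reduces to pure shrinkage of every parameter by a factor of (1 - lr * reg), which is the regularization effect in isolation.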