Source code for sad.trainer.fm

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2020 Apple Inc. All Rights Reserved.
#

import json
import logging
import os

from sad.generator import ImplicitFeedbackGenerator
from sad.model import FMModel

from .base import TrainerBase, TrainerFactory


@TrainerFactory.register
class FMTrainer(TrainerBase):
    """Trainer that fits the factorization-machine model held by an ``FMModel``.

    The trainer is registered with ``TrainerFactory``. Hyper-parameters are
    read from ``self.spec``, which is presumably populated by ``TrainerBase``
    from ``config`` (the base class is not shown here -- verify there).
    """

    def __init__(
        self,
        config: dict,
        model: FMModel,
        generator: ImplicitFeedbackGenerator,
        task: "TrainingTask",
    ):
        """Create the trainer and attach a class-specific logger.

        Args:
            config: Trainer configuration; also what ``save`` writes to disk.
            model: The ``FMModel`` instance to be trained.
            generator: Generator that prepares and exposes the training data.
            task: The owning training task, forwarded to ``TrainerBase``.
        """
        super().__init__(config, model, generator, task)
        self.logger = logging.getLogger(f"trainer.{self.__class__.__name__}")

    @property
    def loss_name(self) -> str:
        """Read directly from ``"loss"`` field in ``self.spec``. Currently can
        take ``"bpr"|"warp"`` two values. Default is ``"bpr"``. Specific to
        ``sad.model.FMModel``."""
        return self.spec.get("loss", "bpr")

    @property
    def n_negative_samples(self) -> int:
        """Read directly from ``"n_negative_samples"`` field in ``self.spec``.
        It means the number of samples that will be drawn for ``"warp"`` loss.
        Default is ``1``."""
        return self.spec.get("n_negative_samples", 1)

    @property
    def w_l2(self) -> float:
        """Weight of L2 regularization to parameters. Read directly from
        ``"w_l2"`` field in ``self.spec``. Default is ``0.01``."""
        return self.spec.get("w_l2", 0.01)

    def save(self, working_dir: str = None):
        """Save trainer configuration as ``trainer_config.json``.

        The file is written to ``<working_dir>/<model.s3_key_path>/``; the
        directory is created if it does not exist.

        Args:
            working_dir: Root directory to write under. When falsy (``None``
                or empty), ``self.working_dir`` is used instead.
        """
        if not working_dir:
            working_dir = self.working_dir
        target_dir = os.path.join(working_dir, self.model.s3_key_path)
        os.makedirs(target_dir, exist_ok=True)
        config_path = os.path.join(target_dir, "trainer_config.json")
        # Use a context manager so the handle is flushed and closed even if
        # serialization fails; the previous bare ``open`` leaked the handle.
        with open(config_path, "w", encoding="utf-8") as config_file:
            json.dump(self.config, config_file)

    def train(self):
        """Prepare the data and fit the underlying FM model.

        Steps, in order:

        1. Ask the generator to prepare its data.
        2. Keep only rows of ``generator.data_df`` with ``rating > 0`` and
           rename ``userID``/``itemID`` to ``user_id``/``item_id``.
        3. Let the model build its FM model from this trainer's settings.
        4. Fit on the ``(user_id, item_id)`` pairs for ``self.n_epochs``
           epochs, wrapped by the ``on_loop_begin``/``on_loop_end`` hooks.
        """
        generator = self.generator
        self.logger.info("Generator begins to prepare data ...")
        generator.prepare()
        self.logger.info("Data preparation done ...")

        model = self.model
        data_df = generator.data_df
        # Only rows with a strictly positive rating are passed on to ``fit``.
        data_df = data_df[data_df["rating"] > 0]
        data_df = data_df.rename(columns={"userID": "user_id", "itemID": "item_id"})

        # The model reads its hyper-parameters from this trainer instance.
        model.initialize_fm_model(self)

        self.on_loop_begin()
        model.fm_model.fit(
            data_df[["user_id", "item_id"]], epochs=self.n_epochs, verbose=True
        )
        self.on_loop_end()

    def load(self, folder: str):
        """Do nothing; loading trainer state is not implemented here.

        Args:
            folder: Unused.
        """
        pass