Turi Create  4.0
turi::sgd::sgd_interface_base Class Reference (abstract)

#include <toolkits/sgd/sgd_interface.hpp>

Public Member Functions

virtual void setup(const v2::ml_data& train_data, const std::map<std::string, flexible_type>& options)
virtual void setup_iteration(size_t iteration, double step_size)
virtual void finalize_iteration()
virtual double l2_regularization_factor() const
virtual double max_step_size() const
virtual bool state_is_numerically_stable() const
virtual void setup_optimization(size_t random_seed = size_t(-1), bool _in_trial_mode = false) = 0
virtual double calculate_loss(const v2::ml_data& data) const = 0
virtual double reported_loss_value(double accumulative_loss) const = 0
virtual std::string reported_loss_name() const = 0
virtual double current_regularization_penalty() const = 0
virtual double apply_sgd_step(size_t thread_idx, const std::vector<v2::ml_data_entry>& x, double y, double step_size) = 0

Detailed Description

The base class for all SGD interfaces. This interface governs all interactions between the SGD solvers and the model.

To make a model trainable by the SGD solvers, subclass sgd_interface_base, implement its pure virtual methods, and override the optional inline virtual methods as needed.

Then choose a solver and instantiate its template with your interface class. The basic solver is basic_sgd_solver, defined in basic_sgd_solver.hpp.

Example:

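// Subclass the interface and implement its pure virtual methods.
class simple_sgd_interface : public sgd_interface_base { ... };

// The solver receives the interface through a shared pointer.
auto iface = std::make_shared<simple_sgd_interface>();

// Template the solver on the interface type.
basic_sgd_solver<simple_sgd_interface> solver(iface, train_data, options);

auto training_status = solver.run();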

Definition at line 48 of file sgd_interface.hpp.
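The call sequence implied by the member documentation on this page can be illustrated with a self-contained sketch. It does not use Turi: Entry, Row, ToySgdInterface, LinearSquaredLoss, and run_toy_sgd are hypothetical stand-ins for v2::ml_data_entry, v2::ml_data, sgd_interface_base, and basic_sgd_solver. The ordering (setup, then setup_optimization, then setup_iteration, apply_sgd_step over every row, finalize_iteration, and state_is_numerically_stable once per pass) follows this page; the fixed step size, single thread, and halve-and-reset policy are assumptions made for brevity.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for v2::ml_data_entry and one row of v2::ml_data.
struct Entry { std::size_t index; double value; };
struct Row { std::vector<Entry> x; double y; };

// Hypothetical, reduced mirror of sgd_interface_base (not the Turi class):
// thread_idx, options, and the reporting methods are omitted.
class ToySgdInterface {
 public:
  virtual ~ToySgdInterface() = default;
  virtual void setup(const std::vector<Row>&) {}
  virtual void setup_iteration(std::size_t, double) {}
  virtual void finalize_iteration() {}
  virtual bool state_is_numerically_stable() const { return true; }
  virtual void setup_optimization() = 0;
  virtual double calculate_loss(const std::vector<Row>& data) const = 0;
  virtual double apply_sgd_step(const std::vector<Entry>& x, double y,
                                double step_size) = 0;
};

// Linear model (prediction = sum of w[index] * value) with squared loss.
class LinearSquaredLoss : public ToySgdInterface {
 public:
  explicit LinearSquaredLoss(std::size_t n_features) : w_(n_features, 0.0) {}

  // Reset the weights; the driver calls this at the start and after instability.
  void setup_optimization() override { w_.assign(w_.size(), 0.0); }

  double apply_sgd_step(const std::vector<Entry>& x, double y,
                        double step_size) override {
    const double err = predict(x) - y;
    // Gradient of 0.5 * err^2 with respect to w[index] is err * value.
    for (const Entry& e : x) w_[e.index] -= step_size * err * e.value;
    return err * err;  // squared error measured before the update
  }

  // Mean squared error over the data set (no regularization term here).
  double calculate_loss(const std::vector<Row>& data) const override {
    double total = 0.0;
    for (const Row& r : data) {
      const double err = predict(r.x) - r.y;
      total += err * err;
    }
    return data.empty() ? 0.0 : total / static_cast<double>(data.size());
  }

  bool state_is_numerically_stable() const override {
    for (double v : w_)
      if (!std::isfinite(v)) return false;
    return true;
  }

  const std::vector<double>& weights() const { return w_; }

 private:
  double predict(const std::vector<Entry>& x) const {
    double p = 0.0;
    for (const Entry& e : x) p += w_[e.index] * e.value;
    return p;
  }
  std::vector<double> w_;
};

// Hypothetical driver calling the interface in the documented order.
// On instability it halves the step size and resets via setup_optimization().
double run_toy_sgd(ToySgdInterface& iface, const std::vector<Row>& data,
                   double step_size, std::size_t n_iterations) {
  iface.setup(data);
  iface.setup_optimization();
  for (std::size_t it = 0; it < n_iterations; ++it) {
    iface.setup_iteration(it, step_size);
    for (const Row& r : data) iface.apply_sgd_step(r.x, r.y, step_size);
    iface.finalize_iteration();
    if (!iface.state_is_numerically_stable()) {
      step_size *= 0.5;
      iface.setup_optimization();
    }
  }
  return iface.calculate_loss(data);
}
```

Trained on the points y = 2x for x = 1, 2, 3 with a step size of 0.05, the single weight converges to 2 and the returned loss is essentially zero.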

Member Function Documentation

◆ apply_sgd_step()

virtual double turi::sgd::sgd_interface_base::apply_sgd_step(size_t thread_idx, const std::vector<v2::ml_data_entry>& x, double y, double step_size)
pure virtual
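No description accompanies apply_sgd_step(). Judging from its signature and the reported_loss_value() notes, it presumably performs one SGD update for a single example with features x and target y, and returns an estimated loss for that example. The sketch below assumes exactly that, using logistic loss. LogisticModel and Entry are hypothetical, and the sketch is not safe to call from several threads at once, since every call writes the same weight vector.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for v2::ml_data_entry.
struct Entry { std::size_t index; double value; };

// Hypothetical logistic-regression interface fragment, labels y in {0, 1}.
class LogisticModel {
 public:
  explicit LogisticModel(std::size_t n_features) : w_(n_features, 0.0) {}

  // Same parameter list as apply_sgd_step(). thread_idx is unused here; a real
  // implementation might use it to select per-thread scratch space.
  // Returns this example's log loss, measured before the update.
  double apply_sgd_step(std::size_t /*thread_idx*/, const std::vector<Entry>& x,
                        double y, double step_size) {
    double margin = 0.0;
    for (const Entry& e : x) margin += w_[e.index] * e.value;
    const double p = 1.0 / (1.0 + std::exp(-margin));

    // Gradient of the log loss with respect to w[index] is (p - y) * value.
    for (const Entry& e : x) w_[e.index] -= step_size * (p - y) * e.value;

    const double eps = 1e-12;  // guards against log(0)
    return -(y * std::log(p + eps) + (1.0 - y) * std::log(1.0 - p + eps));
  }

 private:
  std::vector<double> w_;
};
```

With all weights at zero the first call returns ln 2, and repeated steps on the same positive example drive the returned loss down.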

◆ calculate_loss()

virtual double turi::sgd::sgd_interface_base::calculate_loss ( const v2::ml_data &  data) const
pure virtual

Calculates the loss over a full data set, excluding the regularization penalty.

When this loss is reported, reported_loss_value(...) is applied to it to obtain the value to print.

Parameters
data  The data to use in calculating the loss.
Returns
The loss over the full data set, without the regularization penalty.

Implemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

◆ current_regularization_penalty()

virtual double turi::sgd::sgd_interface_base::current_regularization_penalty ( ) const
pure virtual

Calculates the current regularization penalty. The solver adds it to the loss to compute the objective value, which is interpreted as loss + regularization penalty.

Implemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.
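The relationship between calculate_loss() and current_regularization_penalty() can be written out directly: the objective is the penalty-free loss plus the penalty. The sketch assumes an L2 penalty of the form (λ/2)·‖w‖²; l2_penalty and objective_value are hypothetical helper names, and the factorization models may scale their penalty differently.

```cpp
#include <vector>

// Hypothetical L2 penalty (lambda / 2) * sum_i w_i^2; the 1/2 scaling is an
// assumption chosen so that the penalty's gradient is lambda * w.
double l2_penalty(const std::vector<double>& w, double lambda) {
  double sum_sq = 0.0;
  for (double v : w) sum_sq += v * v;
  return 0.5 * lambda * sum_sq;
}

// Objective = loss + regularization penalty. `loss` plays the role of
// calculate_loss() (penalty excluded) and l2_penalty() that of
// current_regularization_penalty().
double objective_value(double loss, const std::vector<double>& w, double lambda) {
  return loss + l2_penalty(w, lambda);
}
```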

◆ finalize_iteration()

virtual void turi::sgd::sgd_interface_base::finalize_iteration ( )
inline virtual

Called at the end of each pass through the data.

Reimplemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

Definition at line 69 of file sgd_interface.hpp.

◆ l2_regularization_factor()

virtual double turi::sgd::sgd_interface_base::l2_regularization_factor ( ) const
inline virtual

Returns the L2 regularization factor, which the solver uses to automatically tune the SGD step size and its decrease rate over iterations. This value is also used to determine an upper bound on the allowed SGD step size, above which the algorithm is no longer numerically stable.

Reimplemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

Definition at line 77 of file sgd_interface.hpp.
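One common way a solver can use an L2 factor λ, sketched here as an assumption rather than Turi's actual formula, is the decaying schedule η_t = η_0 / (1 + η_0·λ·t), together with a cap on η_0 tied to 1/λ. If the L2 term contributes λ·w to the gradient, each step multiplies the weights by (1 − η·λ), which stops being a pure shrinkage once η·λ reaches 1; this is one concrete reason λ bounds the usable step size. scheduled_step_size is a hypothetical name.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical schedule eta_t = eta0 / (1 + eta0 * lambda * t), a standard
// choice for L2-regularized SGD; not necessarily Turi's formula.
// Assuming the L2 term adds lambda * w to the gradient, each step scales w by
// (1 - eta * lambda), so eta * lambda must stay below 1 to keep that factor
// positive. The 0.5 safety factor in the cap is an arbitrary choice.
double scheduled_step_size(double eta0, double lambda, std::size_t t) {
  if (lambda > 0.0) eta0 = std::min(eta0, 0.5 / lambda);
  return eta0 / (1.0 + eta0 * lambda * static_cast<double>(t));
}
```

With λ = 0 the schedule degenerates to a constant step size, and larger λ makes the step size both smaller at the start and faster to decay.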

◆ max_step_size()

virtual double turi::sgd::sgd_interface_base::max_step_size ( ) const
inline virtual

Returns a hard upper limit on the SGD step size. Some models diverge when the step size is too large, and this provides a way to cap the step sizes the solver considers.

Reimplemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

Definition at line 83 of file sgd_interface.hpp.
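When the solver searches over step sizes, max_step_size() acts as a hard filter on the candidates. The sketch below assumes a geometric grid of candidates; candidate_step_sizes and its parameters are hypothetical and not part of Turi's API.

```cpp
#include <vector>

// Hypothetical candidate generator for step-size tuning: start at `initial`,
// shrink geometrically by `factor`, and drop any candidate above the hard
// limit that max_step_size() would report.
std::vector<double> candidate_step_sizes(double initial, double factor,
                                         int count, double max_step_size) {
  std::vector<double> candidates;
  double s = initial;
  for (int i = 0; i < count; ++i, s *= factor) {
    if (s <= max_step_size) candidates.push_back(s);
  }
  return candidates;
}
```

For example, starting at 1.0 and halving four times with a limit of 0.3 keeps only 0.25 and 0.125.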

◆ reported_loss_name()

virtual std::string turi::sgd::sgd_interface_base::reported_loss_name ( ) const
pure virtual

The name of the loss to report on each iteration.

For example, if squared error loss is used, reported_loss_name() could give RMSE, and then reported_loss_value(v) would be std::sqrt(v).

Implemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

◆ reported_loss_value()

virtual double turi::sgd::sgd_interface_base::reported_loss_value ( double  accumulative_loss) const
pure virtual

The value of the reported loss. The estimated loss values from apply_sgd_step are accumulated across samples, and this function is called with the accumulated value to obtain the value to report.

For example, if squared error loss is used, reported_loss_name() could give RMSE, and then reported_loss_value(v) would be std::sqrt(v).

Implemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.
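The RMSE example above, written as code. SquaredErrorReporting is a hypothetical class mirroring the two reporting methods. It assumes the value passed in is already a mean squared error, as implied by reported_loss_value(v) being std::sqrt(v); if the accumulated value were a raw sum, it would first need to be divided by the number of samples.

```cpp
#include <cmath>
#include <string>

// Hypothetical squared-error reporter mirroring reported_loss_name() and
// reported_loss_value(). Assumes the input is a mean squared error.
struct SquaredErrorReporting {
  std::string reported_loss_name() const { return "RMSE"; }

  double reported_loss_value(double accumulative_loss) const {
    return std::sqrt(accumulative_loss);
  }
};
```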

◆ setup()

virtual void turi::sgd::sgd_interface_base::setup(const v2::ml_data& train_data, const std::map<std::string, flexible_type>& options)
inline virtual

Called at the start of optimization, before any other functions are called.

Perform any setup that depends on the data used to train the model. Since ml_data tracks some statistics (e.g. the maximum row size), these can be cached for later use.

Reimplemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

Definition at line 60 of file sgd_interface.hpp.
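A minimal sketch of caching a data statistic in setup(). The stand-in Row type (hypothetical, like Entry and CachingInterface) carries no metadata, so the sketch scans the rows to find the maximum row size, whereas ml_data is described above as tracking such statistics itself. The cached value is then used to reserve a scratch buffer once instead of on every step.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for v2::ml_data_entry and one row of v2::ml_data.
struct Entry { std::size_t index; double value; };
using Row = std::vector<Entry>;

// Hypothetical interface fragment: cache the maximum row size in setup()
// so per-step work can reuse a preallocated buffer.
class CachingInterface {
 public:
  void setup(const std::vector<Row>& train_data) {
    max_row_size_ = 0;
    for (const Row& r : train_data) max_row_size_ = std::max(max_row_size_, r.size());
    buffer_.reserve(max_row_size_);
  }

  std::size_t max_row_size() const { return max_row_size_; }
  std::size_t buffer_capacity() const { return buffer_.capacity(); }

 private:
  std::size_t max_row_size_ = 0;
  std::vector<double> buffer_;  // scratch space, sized once in setup()
};
```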

◆ setup_iteration()

virtual void turi::sgd::sgd_interface_base::setup_iteration(size_t iteration, double step_size)
inline virtual

Called before each pass through the data.

Reimplemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

Definition at line 65 of file sgd_interface.hpp.

◆ setup_optimization()

virtual void turi::sgd::sgd_interface_base::setup_optimization(size_t random_seed = size_t(-1), bool _in_trial_mode = false)
pure virtual

Sets up the optimization run. Called at the beginning of an optimization run, and again to reset the state when numerical instabilities are detected; in that case, optimization is retried with a smaller step size.

Implemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.
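The reset-and-retry behavior described for setup_optimization() can be sketched as a loop. The callback run_attempt is a hypothetical stand-in for one optimization run (calling setup_optimization() and then iterating), and the halving factor and attempt limit are assumptions; the documentation only says that a smaller step size is used.

```cpp
#include <cstddef>
#include <functional>

// Hypothetical retry policy: run_attempt(step) resets the model (the role of
// setup_optimization()) and trains; it returns false if the state became
// numerically unstable. Each failure halves the step size.
// Returns the step size that succeeded, or 0.0 if every attempt failed.
double train_with_backoff(const std::function<bool(double)>& run_attempt,
                          double step_size, std::size_t max_attempts) {
  for (std::size_t attempt = 0; attempt < max_attempts; ++attempt) {
    if (run_attempt(step_size)) return step_size;
    step_size *= 0.5;
  }
  return 0.0;
}
```

If training is stable only for step sizes up to 0.3, starting at 1.0 fails at 1.0 and 0.5 and succeeds at 0.25.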

◆ state_is_numerically_stable()

virtual bool turi::sgd::sgd_interface_base::state_is_numerically_stable ( ) const
inline virtual

If there are any issues with the model, this function can return false to force a reset. It is called once at the end of each iteration.

Returns true if the state is numerically stable, and false if any numerical instabilities were detected now or during the previous pass over the data. If it returns false, the state is reset (see setup_optimization()) and optimization is retried with a smaller step size.

Reimplemented in turi::factorization::factorization_sgd_interface< GLMModel, _LossModelProfile, _regularization_type >.

Definition at line 94 of file sgd_interface.hpp.
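A typical body for state_is_numerically_stable() checks the model parameters for NaN or infinite values. weights_are_stable is a hypothetical helper, and its magnitude bound is an arbitrary illustrative threshold. An implementation could also set a flag inside apply_sgd_step() to report instabilities from the previous pass; this sketch only inspects the current weights.

```cpp
#include <cmath>
#include <vector>

// Sketch of a stability check: report instability if any parameter is NaN,
// infinite, or larger in magnitude than a (hypothetical) bound.
bool weights_are_stable(const std::vector<double>& w, double max_abs = 1e150) {
  for (double v : w)
    if (!std::isfinite(v) || std::abs(v) > max_abs) return false;
  return true;
}
```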

