Turi Create 4.0
turi::sgd::sgd_solver_base Class Reference (abstract)

#include <toolkits/sgd/sgd_solver_base.hpp>

Public Member Functions

std::map< std::string, variant_type > run ()
 

Protected Member Functions

 sgd_solver_base (const std::shared_ptr< sgd_interface_base > &model_interface, const v2::ml_data &_train_data, const std::map< std::string, flexible_type > &_options)
 
virtual void setup (sgd_interface_base *iface)
 
virtual std::pair< double, double > run_iteration (size_t iteration, sgd_interface_base *iface, const v2::ml_data &data, double step_size)=0
 
virtual std::pair< double, double > calculate_objective (sgd_interface_base *iface, const v2::ml_data &data, size_t iteration) const
 

Static Protected Member Functions

static void add_options (option_manager &options)
 

Protected Attributes

std::shared_ptr< sgd_interface_base > model_interface
 
const std::map< std::string, flexible_type > options
 

Detailed Description

The base solver class for all the general SGD methods.

This class provides the high-level functionality shared by the SGD methods. Each particular version of SGD is implemented through the run_iteration method, which is called to make one pass through the data on a particular block of data points.

Definition at line 35 of file sgd_solver_base.hpp.
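The division of labor described above, where the base class drives the passes and each SGD variant supplies run_iteration, can be sketched as follows. This is a minimal, self-contained illustration under stated assumptions, not Turi's code: toy_solver_base, toy_plain_sgd, and the fixed iteration count and step size are hypothetical, and the real run() reads its settings from the options map and works on v2::ml_data through an sgd_interface_base.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical, heavily simplified stand-in for sgd_solver_base: the base
// class owns the outer loop, and subclasses supply one pass over the data.
class toy_solver_base {
 public:
  virtual ~toy_solver_base() = default;

  // Runs a fixed number of passes at a fixed step size and returns the
  // (objective, loss) pair reported by the last pass.
  std::pair<double, double> run(std::size_t num_iterations, double step_size) {
    std::pair<double, double> last{0.0, 0.0};
    for (std::size_t it = 0; it < num_iterations; ++it) {
      last = run_iteration(it, step_size);
    }
    return last;
  }

 protected:
  // One pass through the data; each SGD variant implements this.
  virtual std::pair<double, double> run_iteration(std::size_t iteration,
                                                  double step_size) = 0;
};

// Hypothetical plain-SGD variant fitting y = w * x by least squares.
class toy_plain_sgd : public toy_solver_base {
 public:
  toy_plain_sgd(std::vector<double> x, std::vector<double> y)
      : x_(std::move(x)), y_(std::move(y)) {
    assert(x_.size() == y_.size());  // one target per input
  }
  double weight() const { return w_; }

 protected:
  std::pair<double, double> run_iteration(std::size_t /*iteration*/,
                                          double step_size) override {
    double loss = 0.0;
    for (std::size_t i = 0; i < x_.size(); ++i) {
      double residual = w_ * x_[i] - y_[i];
      loss += 0.5 * residual * residual;   // loss measured before this update
      w_ -= step_size * residual * x_[i];  // gradient of 0.5*r^2 w.r.t. w
    }
    // No regularizer in this sketch, so the objective equals the loss.
    return {loss, loss};
  }

 private:
  std::vector<double> x_, y_;
  double w_ = 0.0;
};
```

For example, constructing toy_plain_sgd with x = {1, 2, 3} and y = {2, 4, 6} and calling run(200, 0.05) drives weight() toward 2. Keeping the loop in the base class means every variant shares the same driver and only the per-pass update differs.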

Constructor & Destructor Documentation

◆ sgd_solver_base()

turi::sgd::sgd_solver_base::sgd_solver_base ( const std::shared_ptr< sgd_interface_base > &  model_interface,
const v2::ml_data &  _train_data,
const std::map< std::string, flexible_type > &  _options 
)
protected

Constructs the solver from the model interface, the training data, and the solver options.

Member Function Documentation

◆ add_options()

static void turi::sgd::sgd_solver_base::add_options ( option_manager &  options)
static protected

Inserts the option definitions needed by the common SGD optimization class into an option manager. Intended to be called by subclasses of sgd_solver_base.
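A sketch of how a subclass might chain option registration, assuming option_manager can be modeled as a map from option name to default value. toy_option_manager, toy_momentum_sgd, the option names, and their defaults are hypothetical placeholders, not Turi's actual option definitions.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for option_manager: option name -> default value.
using toy_option_manager = std::map<std::string, double>;

class toy_solver_base {
 protected:
  // Registers options every SGD variant shares (names are illustrative).
  static void add_options(toy_option_manager& options) {
    options["max_iterations"] = 10;
    options["step_size"] = 0.1;
  }
};

class toy_momentum_sgd : public toy_solver_base {
 public:
  // A subclass first pulls in the common options, then adds its own.
  static void add_options(toy_option_manager& options) {
    toy_solver_base::add_options(options);
    // The shared options are registered before the variant-specific ones.
    assert(options.count("step_size") == 1);
    options["momentum"] = 0.9;
  }
};
```

Calling toy_momentum_sgd::add_options on an empty manager leaves it with the two shared entries plus "momentum", so each variant declares only what is specific to it.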

◆ calculate_objective()

virtual std::pair<double, double> turi::sgd::sgd_solver_base::calculate_objective ( sgd_interface_base *  iface,
const v2::ml_data &  data,
size_t  iteration 
) const
protected virtual

Called to calculate the current objective value on the data. By default, this calls calculate_loss() and adds current_regularizer_value(), the current regularization term, from the current interface. It can be overridden if needed; for example, when optimizing ranking functions, the loss function does not fit into the standard framework laid out by the model's calculate_fx function.

Returns
(objective value, reportable training loss)

Reimplemented in turi::factorization::ranking_sgd_solver_base< SGDInterface >.
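A minimal sketch of the default objective and of an override, assuming the reportable training loss is the unregularized loss (the description above does not say how it is derived). toy_interface, toy_solver_base, toy_ranking_solver, and the doubling in the override are hypothetical; the real calculate_objective also receives the data and the iteration number.

```cpp
#include <cassert>
#include <utility>

// Hypothetical minimal interface; here it simply reports fixed numbers.
struct toy_interface {
  double loss = 0.0;         // stand-in for calculate_loss()
  double regularizer = 0.0;  // stand-in for current_regularizer_value()
  double calculate_loss() const { return loss; }
  double current_regularizer_value() const { return regularizer; }
};

class toy_solver_base {
 public:
  virtual ~toy_solver_base() = default;

  // Default: objective = loss + regularization term.
  // Returns (objective value, reportable training loss).
  virtual std::pair<double, double> calculate_objective(
      const toy_interface& iface) const {
    // A penalty term is assumed non-negative in this sketch.
    assert(iface.current_regularizer_value() >= 0.0);
    double loss = iface.calculate_loss();
    return {loss + iface.current_regularizer_value(), loss};
  }
};

// A variant whose loss does not fit the standard framework overrides the
// hook. The doubling below is purely illustrative, not a real ranking loss.
class toy_ranking_solver : public toy_solver_base {
 public:
  std::pair<double, double> calculate_objective(
      const toy_interface& iface) const override {
    double ranking_loss = 2.0 * iface.calculate_loss();
    return {ranking_loss + iface.current_regularizer_value(), ranking_loss};
  }
};
```

Because the hook is virtual, a driver holding a base-class reference picks up the override automatically, which is how a ranking solver can replace the objective without changing the surrounding loop.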

◆ run()

std::map<std::string, variant_type> turi::sgd::sgd_solver_base::run ( )

The main entry point; runs the SGD solver with the current options.
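The call order for run(), with setup() invoked once before any run_iteration and then one run_iteration per pass, can be illustrated with a hypothetical logging subclass. The 1/(1+t) step-size decay, the result keys, and the use of double in place of variant_type are assumptions made for this sketch; the documentation does not specify how run() chooses step sizes or what its result map contains. setup() here also takes no interface argument, unlike the real one.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch of run()'s control flow: setup() once, then one
// run_iteration() per pass, then a map of results.
class toy_solver_base {
 public:
  virtual ~toy_solver_base() = default;

  std::map<std::string, double> run(std::size_t max_iterations,
                                    double initial_step_size) {
    assert(initial_step_size > 0.0);  // a non-positive step makes no progress
    setup();  // called before any run_iteration()
    std::pair<double, double> result{0.0, 0.0};
    for (std::size_t it = 0; it < max_iterations; ++it) {
      // Hypothetical 1/(1+t) decay schedule.
      double step_size = initial_step_size / static_cast<double>(1 + it);
      result = run_iteration(it, step_size);
    }
    return {{"objective", result.first},
            {"training_loss", result.second},
            {"num_iterations", static_cast<double>(max_iterations)}};
  }

 protected:
  virtual void setup() {}  // default: nothing to prepare
  virtual std::pair<double, double> run_iteration(std::size_t iteration,
                                                  double step_size) = 0;
};

// Records each call so the order can be inspected.
class toy_logging_solver : public toy_solver_base {
 public:
  std::vector<std::string> calls;
  std::vector<double> step_sizes;

 protected:
  void setup() override { calls.push_back("setup"); }
  std::pair<double, double> run_iteration(std::size_t iteration,
                                          double step_size) override {
    calls.push_back("iteration " + std::to_string(iteration));
    step_sizes.push_back(step_size);
    double fake_loss = 1.0 / static_cast<double>(1 + iteration);
    return {fake_loss, fake_loss};
  }
};
```

Running toy_logging_solver with three iterations records "setup" followed by "iteration 0" through "iteration 2", confirming that any state prepared in setup() is available to every pass.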

◆ run_iteration()

virtual std::pair<double, double> turi::sgd::sgd_solver_base::run_iteration ( size_t  iteration,
sgd_interface_base *  iface,
const v2::ml_data &  data,
double  step_size 
)
protected pure virtual

Called to run one iteration of the SGD algorithm on the training data.

Parameters
[in] iteration  The iteration number of the current pass through the data.
[in] iface  A pointer to the interface class. It can be downcast to the true SGDInterface class for use in the actual code.
[in] data  The data to make the pass over.
[in] step_size  The step size to use for this pass through the data.
Returns
A pair – (objective_value, loss)

Implemented in turi::factorization::ranking_sgd_solver_base< SGDInterface >.
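Because run_iteration receives the interface as sgd_interface_base *, an implementation recovers the concrete interface type with a downcast. The sketch below mirrors the templated pattern of ranking_sgd_solver_base< SGDInterface > on hypothetical types: toy_interface_base, toy_linear_interface, toy_solver_base, and toy_basic_sgd are illustrative, and a vector of (x, y) pairs stands in for v2::ml_data.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for v2::ml_data: a list of (x, y) points.
using toy_data = std::vector<std::pair<double, double>>;

// Hypothetical stand-in for sgd_interface_base.
struct toy_interface_base {
  virtual ~toy_interface_base() = default;
};

// Hypothetical concrete interface for the model y ~ w * x.
struct toy_linear_interface : toy_interface_base {
  double w = 0.0;
  double loss(double x, double y) const {
    double r = w * x - y;
    return 0.5 * r * r;
  }
  void apply_gradient(double x, double y, double step_size) {
    w -= step_size * (w * x - y) * x;  // d/dw of 0.5 * (w*x - y)^2
  }
};

// Base solver with the pure virtual hook (public here only for easy calling).
class toy_solver_base {
 public:
  virtual ~toy_solver_base() = default;
  virtual std::pair<double, double> run_iteration(std::size_t iteration,
                                                  toy_interface_base* iface,
                                                  const toy_data& data,
                                                  double step_size) = 0;
};

// Solver templated on the true interface type.
template <typename SGDInterface>
class toy_basic_sgd : public toy_solver_base {
 public:
  std::pair<double, double> run_iteration(std::size_t /*iteration*/,
                                          toy_interface_base* iface,
                                          const toy_data& data,
                                          double step_size) override {
    // Downcast to the interface type this solver was instantiated with.
    auto* typed = dynamic_cast<SGDInterface*>(iface);
    assert(typed != nullptr && "interface type does not match the solver");
    double total_loss = 0.0;
    for (const auto& point : data) {
      total_loss += typed->loss(point.first, point.second);  // before update
      typed->apply_gradient(point.first, point.second, step_size);
    }
    // No regularizer in this sketch, so the objective equals the loss.
    return {total_loss, total_loss};
  }
};
```

Using dynamic_cast with an assert catches a mismatched interface in debug builds; a static_cast would skip the runtime check when the type is guaranteed by construction.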

◆ setup()

virtual void turi::sgd::sgd_solver_base::setup ( sgd_interface_base *  iface)
inline protected virtual

Called at the start of a run, before any run_iteration is called.

Definition at line 109 of file sgd_solver_base.hpp.

Member Data Documentation

◆ model_interface

std::shared_ptr<sgd_interface_base> turi::sgd::sgd_solver_base::model_interface
protected

The main interface to the model, implementing the SGD-specific routines for that model.

Definition at line 96 of file sgd_solver_base.hpp.

◆ options

const std::map<std::string, flexible_type> turi::sgd::sgd_solver_base::options
protected

The training options of the solver.

Definition at line 100 of file sgd_solver_base.hpp.


The documentation for this class was generated from the following file: