Turi Create 4.0
#include <toolkits/sgd/sgd_solver_base.hpp>
Public Member Functions
  std::map<std::string, variant_type> run()

Protected Member Functions
  sgd_solver_base(const std::shared_ptr<sgd_interface_base> &model_interface, const v2::ml_data &_train_data, const std::map<std::string, flexible_type> &_options)
  virtual void setup(sgd_interface_base *iface)
  virtual std::pair<double, double> run_iteration(size_t iteration, sgd_interface_base *iface, const v2::ml_data &data, double step_size) = 0
  virtual std::pair<double, double> calculate_objective(sgd_interface_base *iface, const v2::ml_data &data, size_t iteration) const

Static Protected Member Functions
  static void add_options(option_manager &options)

Protected Attributes
  std::shared_ptr<sgd_interface_base> model_interface
  const std::map<std::string, flexible_type> options
The base solver class for all general SGD methods.
This class provides the high-level functionality shared by the SGD methods. Specific SGD variants are implemented through the run_iteration() method, which is called to make one pass through a particular block of data points.
Definition at line 35 of file sgd_solver_base.hpp.
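To illustrate this division of labor, here is a minimal, self-contained C++ sketch. It does not use Turi Create: `ToyInterface`, `Point`, `ToySolverBase`, and `PlainSgdSolver` are hypothetical stand-ins for `sgd_interface_base`, the training data, `sgd_solver_base`, and a concrete solver, and the meaning given to the returned pair (sum of squared errors, number of points) is an assumption made only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical stand-in for sgd_interface_base: a 1-D model y ~ w * x.
struct ToyInterface {
  double w = 0.0;
};

// Hypothetical stand-in for one (x, y) training point.
struct Point {
  double x, y;
};

// Hypothetical analogue of sgd_solver_base: subclasses supply the per-pass update.
class ToySolverBase {
 public:
  virtual ~ToySolverBase() = default;

  // Same shape as run_iteration(): one pass over `data` with a given step size.
  // Pair contents are illustrative only: (sum of squared errors, number of points).
  virtual std::pair<double, double> run_iteration(std::size_t iteration,
                                                  ToyInterface* iface,
                                                  const std::vector<Point>& data,
                                                  double step_size) = 0;
};

// A plain SGD variant: one gradient step per data point, in the given order.
class PlainSgdSolver : public ToySolverBase {
 public:
  std::pair<double, double> run_iteration(std::size_t /*iteration*/,
                                          ToyInterface* iface,
                                          const std::vector<Point>& data,
                                          double step_size) override {
    assert(iface != nullptr);
    double sse = 0.0;
    for (const Point& p : data) {
      double err = iface->w * p.x - p.y;        // residual before this update
      sse += err * err;
      iface->w -= step_size * 2.0 * err * p.x;  // gradient of err^2 w.r.t. w
    }
    return {sse, static_cast<double>(data.size())};
  }
};
```

The base class only fixes the signature, so another variant could, for instance, shuffle the points or use mini-batches inside the same override.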
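turi::sgd::sgd_solver_base::sgd_solver_base(const std::shared_ptr<sgd_interface_base> &model_interface, const v2::ml_data &_train_data, const std::map<std::string, flexible_type> &_options)  [protected]
The constructor. Takes the model interface, the training data, and the solver options.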
static void turi::sgd::sgd_solver_base::add_options(option_manager &options)  [static, protected]
Inserts the option definitions needed by the common SGD optimization class into an option manager. Meant to be called by subclasses of sgd_solver_base.
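A sketch of the intended calling pattern, assuming a subclass registers the shared options before its own. `ToyOptionManager` and its `create_option()` method are hypothetical stand-ins, not the real `option_manager` API, and the option names and defaults are invented for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for option_manager: records each option's default value.
struct ToyOptionManager {
  std::map<std::string, double> defaults;

  void create_option(const std::string& name, double default_value) {
    bool inserted = defaults.emplace(name, default_value).second;
    assert(inserted && "option defined twice");
    (void)inserted;  // avoid an unused-variable warning when NDEBUG is set
  }
};

// Analogue of add_options(): registers the options every SGD solver shares.
// Names and defaults are hypothetical.
struct ToySolverBase {
  static void add_options(ToyOptionManager& options) {
    options.create_option("max_iterations", 10);
    options.create_option("step_size", 0.1);
  }
};

// A hypothetical subclass adds the shared options first, then its own.
struct ToyRankingSolver : ToySolverBase {
  static void add_options(ToyOptionManager& options) {
    ToySolverBase::add_options(options);
    options.create_option("ranking_margin", 1.0);
  }
};
```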
virtual std::pair<double, double> turi::sgd::sgd_solver_base::calculate_objective(sgd_interface_base *iface, const v2::ml_data &data, size_t iteration) const  [protected, virtual]
Called to calculate the current objective value on the data. By default, it calls calculate_loss() on the current interface and adds current_regularizer_value(), the function that returns the current regularization term. It can be overridden if needed; for example, when optimizing ranking functions, the loss function does not fit into the standard framework laid out by the model's calculate_fx function.
Reimplemented in turi::factorization::ranking_sgd_solver_base< SGDInterface >.
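The default objective described above can be sketched as follows. `ToyInterface` and `ToySolverBase` are hypothetical stand-ins; the mean-squared-error loss and the L2 penalty `lambda * w * w` are assumptions chosen only to make the example concrete. The real method returns `std::pair<double, double>`, whose two components are not documented on this page, so the sketch returns a single `double`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model interface. The method names echo the text above;
// their signatures here are simplified assumptions.
struct ToyInterface {
  double w = 0.0;
  double lambda = 0.0;  // assumed L2 regularization strength

  // Mean squared error of the model y ~ w * x.
  double calculate_loss(const std::vector<double>& xs,
                        const std::vector<double>& ys) const {
    assert(xs.size() == ys.size());
    double total = 0.0;
    for (std::size_t i = 0; i < xs.size(); ++i) {
      double err = w * xs[i] - ys[i];
      total += err * err;
    }
    return xs.empty() ? 0.0 : total / xs.size();
  }

  // Current value of the regularization term.
  double current_regularizer_value() const { return lambda * w * w; }
};

class ToySolverBase {
 public:
  virtual ~ToySolverBase() = default;

  // Default objective: loss plus regularization term.
  virtual double calculate_objective(const ToyInterface& iface,
                                     const std::vector<double>& xs,
                                     const std::vector<double>& ys) const {
    return iface.calculate_loss(xs, ys) + iface.current_regularizer_value();
  }
};
```

A subclass whose objective is not a plain loss plus penalty, such as a ranking solver, would override `calculate_objective()` rather than change the interface.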
std::map<std::string, variant_type> turi::sgd::sgd_solver_base::run()
The main entry point: runs the SGD solver using the current options.
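This page does not spell out what `run()` does internally. One plausible shape, consistent with the descriptions of `setup()` and `run_iteration()` here, is sketched below: call `setup()` once, then call `run_iteration()` once per pass with the current step size. The constant step size, the fixed iteration count, the result keys, and the `ToyInterface`, `ToySolverBase`, and `MeanSolver` names are all assumptions; the real method returns `std::map<std::string, variant_type>` and takes its settings from the solver options.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model interface holding a single parameter.
struct ToyInterface {
  double w = 0.0;
};

// Hypothetical analogue of sgd_solver_base showing one possible run() flow.
class ToySolverBase {
 public:
  ToySolverBase(std::size_t max_iterations, double step_size)
      : max_iterations_(max_iterations), step_size_(step_size) {}
  virtual ~ToySolverBase() = default;

  // setup() once, then one run_iteration() per pass; result keys are invented.
  std::map<std::string, double> run(ToyInterface* iface,
                                    const std::vector<double>& data) {
    assert(iface != nullptr);
    setup(iface);
    double last_sse = 0.0;
    for (std::size_t it = 0; it < max_iterations_; ++it) {
      last_sse = run_iteration(it, iface, data, step_size_).first;
    }
    return {{"num_iterations", static_cast<double>(max_iterations_)},
            {"last_pass_sse", last_sse}};
  }

 protected:
  // Called once before the first iteration; does nothing by default.
  virtual void setup(ToyInterface* /*iface*/) {}

  virtual std::pair<double, double> run_iteration(
      std::size_t iteration, ToyInterface* iface,
      const std::vector<double>& data, double step_size) = 0;

 private:
  std::size_t max_iterations_;
  double step_size_;
};

// Estimates the mean of `data` by SGD on the squared error (w - x)^2.
class MeanSolver : public ToySolverBase {
 public:
  using ToySolverBase::ToySolverBase;
  bool setup_called = false;

 protected:
  void setup(ToyInterface* iface) override {
    iface->w = 0.0;  // start every run from the same point
    setup_called = true;
  }

  std::pair<double, double> run_iteration(std::size_t /*iteration*/,
                                          ToyInterface* iface,
                                          const std::vector<double>& data,
                                          double step_size) override {
    double sse = 0.0;
    for (double x : data) {
      double err = iface->w - x;
      sse += err * err;                   // error before this point's update
      iface->w -= step_size * 2.0 * err;  // gradient of (w - x)^2
    }
    return {sse, static_cast<double>(data.size())};
  }
};
```

With a constant step size of 0.1 on the points 1 and 3, the estimate settles near, but not exactly at, the mean 2, because the last point visited in each pass pulls the estimate toward itself.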
virtual std::pair<double, double> turi::sgd::sgd_solver_base::run_iteration(size_t iteration, sgd_interface_base *iface, const v2::ml_data &data, double step_size)  [protected, pure virtual]
Called to run one iteration (one pass) of the SGD algorithm over the training data.
Parameters:
  [in] iteration   The iteration number of the current pass through the data.
  [in] iface       A pointer to the interface class. It can be downcast to the true SGDInterface class for use in the actual code.
  [in] step_size   The step size to use for this pass through the data.
Implemented in turi::factorization::ranking_sgd_solver_base< SGDInterface >.
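Because `iface` arrives as a base-class pointer, a solver that knows its concrete interface type has to downcast it. A minimal sketch with hypothetical types (`ToyInterfaceBase`, `LinearInterface`, `ToySolver`) and a toy objective 0.5 * w^2; the return type is simplified to `double`.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical polymorphic base, playing the role of sgd_interface_base.
struct ToyInterfaceBase {
  virtual ~ToyInterfaceBase() = default;
};

// Hypothetical concrete interface with a single parameter w.
struct LinearInterface : ToyInterfaceBase {
  double w = 0.0;
};

// A solver parameterized by its concrete interface type, in the spirit of
// ranking_sgd_solver_base<SGDInterface>.
template <typename SGDInterface>
struct ToySolver {
  double run_iteration(std::size_t /*iteration*/, ToyInterfaceBase* iface_base,
                       double step_size) {
    // Debug-only check that the object really has the expected type.
    assert(dynamic_cast<SGDInterface*>(iface_base) != nullptr);
    // Downcast from the base pointer to the concrete interface.
    SGDInterface* iface = static_cast<SGDInterface*>(iface_base);

    double grad = iface->w;  // gradient of the toy objective 0.5 * w^2
    iface->w -= step_size * grad;
    return 0.5 * iface->w * iface->w;  // objective after the update
  }
};
```

`static_cast` has no runtime cost but is only correct if the object really is an `SGDInterface`; the `dynamic_cast` inside `assert` verifies this in debug builds only.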
virtual void turi::sgd::sgd_solver_base::setup(sgd_interface_base *iface)  [inline, protected, virtual]
Called at the start of a run, before the first call to run_iteration().
Definition at line 109 of file sgd_solver_base.hpp.
std::shared_ptr<sgd_interface_base> turi::sgd::sgd_solver_base::model_interface  [protected]
The main interface to the model, implementing SGD-specific routines for that model.
Definition at line 96 of file sgd_solver_base.hpp.
const std::map<std::string, flexible_type> turi::sgd::sgd_solver_base::options  [protected]
The training options of the solver.
Definition at line 100 of file sgd_solver_base.hpp.