Turi Create
4.0
#include <toolkits/factorization/ranking_sgd_solver_base.hpp>
Classes | |
struct | neg_sample_proc_buffer |
Public Member Functions | |
std::pair< double, double > | run_iteration (size_t iteration, sgd::sgd_interface_base *model_iface, const v2::ml_data &data, double step_size) |
std::pair< double, double > | calculate_objective (sgd::sgd_interface_base *model_iface, const v2::ml_data &data, size_t iteration) const GL_HOT |
std::map< std::string, variant_type > | run () |
Protected Member Functions | |
ranking_sgd_solver_base (const std::shared_ptr< sgd::sgd_interface_base > &main_interface, const v2::ml_data &train_data, const std::map< std::string, flexible_type > &options) | |
virtual std::pair< double, double > | run_sgd_thread (size_t iteration, size_t thread_idx, size_t num_threads, size_t block_idx, size_t num_blocks, const v2::ml_data &data, SGDInterface *iface, double step_size, volatile bool &error_detected)=0 |
virtual std::pair< double, double > | run_loss_calculation_thread (size_t thread_idx, size_t num_threads, const v2::ml_data &data, SGDInterface *iface) const =0 |
std::pair< size_t, size_t > | fill_x_buffer_with_users_items (std::vector< std::pair< std::vector< v2::ml_data_entry >, double > > &x_buffer, v2::ml_data_block_iterator &it, size_t n_items, dense_bitset &item_observed) const GL_HOT_INLINE_FLATTEN |
template<typename BufferIndexToItemIndexMapper > | |
GL_HOT_INLINE_FLATTEN void | clear_item_observed_buffer (dense_bitset &item_observed, size_t n_rows, size_t n_items, const BufferIndexToItemIndexMapper &map_index) const |
virtual void | setup (sgd_interface_base *iface) |
Static Protected Member Functions | |
static void | add_options (option_manager &options) |
Protected Attributes | |
std::shared_ptr< sgd_interface_base > | model_interface |
const std::map< std::string, flexible_type > | options |
The main parts of the ranking SGD solver class. Part of the functionality is implemented in one of two subclasses, explicit_ranking_sgd_solver and implicit_ranking_sgd_solver.
The ranking SGD functionality is broken into two categories: ranking regularization alongside training to predict a target, and "pure" ranking, in which there is no target. Both require a substantial amount of bookkeeping to handle the negative item examples. The details of how these are applied, however, are handled by the individual subclasses.
Definition at line 33 of file ranking_sgd_solver_base.hpp.
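Both categories need negative examples, that is, items the user has not rated. The sketch below shows one common way to draw them: rejection sampling against the observed-item mask. It assumes the mask is a std::vector<bool> standing in for dense_bitset; the function name sample_negative_item and the max_tries cutoff are hypothetical, and the actual subclasses may sample differently.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical sketch: draw a "negative" item for a user, i.e. an item the
// user has not interacted with. item_observed is a stand-in for the
// dense_bitset mask the solver keeps per user (true = observed).
// Returns item_observed.size() as a sentinel if no unobserved item is found
// within max_tries draws.
std::size_t sample_negative_item(const std::vector<bool>& item_observed,
                                 std::mt19937& rng,
                                 std::size_t max_tries = 32) {
  const std::size_t n_items = item_observed.size();
  assert(n_items > 0);
  std::uniform_int_distribution<std::size_t> dist(0, n_items - 1);
  for (std::size_t t = 0; t < max_tries; ++t) {
    const std::size_t candidate = dist(rng);
    if (!item_observed[candidate]) return candidate;  // accept unobserved item
  }
  return n_items;  // sentinel: give up after max_tries rejections
}
```

Rejection sampling stays cheap while a user has rated only a small fraction of the catalog; the cutoff keeps a user who has rated nearly every item from looping indefinitely.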
ranking_sgd_solver_base() [inline, protected]
Constructor
Definition at line 45 of file ranking_sgd_solver_base.hpp.
add_options() [static, protected, inherited]
Call this function to insert the option definitions needed by the common SGD optimization class into an option manager. Meant to be called by subclasses of sgd_solver_base.
calculate_objective() [inline, virtual]
Calculate the objective value of the current state.
[in] | iface_base | The interface class that provides the gradient calculation routines on top of the model definition. This must be downcast to SGDInterface*. |
[in] | row_start | The starting row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
[in] | row_end | The ending row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
Reimplemented from turi::sgd::sgd_solver_base.
Definition at line 277 of file ranking_sgd_solver_base.hpp.
clear_item_observed_buffer() [inline, protected]
Clear out the item_observed buffer.
Clears the entries of item_observed, choosing how to clear them based on the number of items actually used. The definition appears later in the header.
Definition at line 656 of file ranking_sgd_solver_base.hpp.
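The signature takes n_rows, n_items, and a mapper from buffer index to item index, which suggests clearing only the touched bits when few rows were used. The sketch below illustrates that idea with std::vector<bool> in place of dense_bitset; the 1/8 threshold and the name clear_item_observed_sketch are assumptions, not the header's actual values.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch of selective clearing. item_observed stands in for the
// dense_bitset; map_index maps a buffer row index to the item index that row
// set. When only a few rows were touched, clearing those bits one by one is
// cheaper than wiping all n_items bits; otherwise a full reset is simpler.
template <typename BufferIndexToItemIndexMapper>
void clear_item_observed_sketch(std::vector<bool>& item_observed,
                                std::size_t n_rows, std::size_t n_items,
                                const BufferIndexToItemIndexMapper& map_index) {
  assert(item_observed.size() == n_items);
  if (n_rows < n_items / 8) {  // threshold is an arbitrary assumption
    for (std::size_t i = 0; i < n_rows; ++i) {
      item_observed[map_index(i)] = false;  // clear only the bits rows set
    }
  } else {
    item_observed.assign(n_items, false);  // reset the entire mask
  }
}
```

Passing the mapper as a template parameter, as the real signature does, lets the compiler inline the index lookup inside the clearing loop.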
fill_x_buffer_with_users_items() [inline, protected]
Fill a buffer with (observation, target value) pairs. Because of the user-block nature of the ml_data_block_iterator, this buffer is guaranteed to hold all the items rated by a particular user. If no target_value is present, then "1" is used.
[out] | x_buffer | The buffer where the (observation, target_value) pairs are stored. |
[in,out] | it | The current block iterator. |
[in,out] | item_observed | A mask giving the items observed in the data. |
Definition at line 360 of file ranking_sgd_solver_base.hpp.
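The sketch below shows the buffer-filling idea on simplified data: rows sorted by user replace the block iterator, and each buffer entry is an (item, target) pair rather than a vector of ml_data_entry. The Row type, the name fill_buffer_sketch, and the meaning of the returned pair are assumptions of the sketch; the header does not document what its own returned pair contains.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-in for one observation row.
struct Row {
  std::size_t user;
  std::size_t item;
  std::optional<double> target;  // empty when the data has no target column
};

// Copy every row belonging to the user at rows[pos] into x_buffer and mark
// each of that user's items in item_observed. Missing targets default to 1.0.
// Returns (rows copied, index of the next user's first row); this return
// convention belongs to the sketch only.
std::pair<std::size_t, std::size_t> fill_buffer_sketch(
    std::vector<std::pair<std::size_t, double>>& x_buffer,
    const std::vector<Row>& rows, std::size_t pos,
    std::vector<bool>& item_observed) {
  assert(pos < rows.size());
  x_buffer.clear();
  const std::size_t user = rows[pos].user;
  while (pos < rows.size() && rows[pos].user == user) {
    const Row& r = rows[pos];
    x_buffer.emplace_back(r.item, r.target.value_or(1.0));
    item_observed[r.item] = true;
    ++pos;
  }
  return {x_buffer.size(), pos};
}
```

Because the rows are grouped by user, a single forward scan collects the user's complete item set, which is what makes the item_observed mask usable for negative sampling afterwards.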
run() [inherited]
The main function to run the sgd solver given the current options.
run_iteration() [inline, virtual]
Run a single SGD pass through the data. Implementation of the required virtual function of the base class sgd_solver_base.
[in] | iteration | The iteration index; what gets reported in the progress message. |
[in] | iface_base | The interface class that provides the gradient calculation routines on top of the model definition. This must be downcast to SGDInterface*. |
[in] | row_start | The starting row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
[in] | row_end | The ending row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
[in] | trial_mode | If true, immediately return failure on any numerical issues and do not report progress messages. |
Implements turi::sgd::sgd_solver_base.
Definition at line 158 of file ranking_sgd_solver_base.hpp.
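run_iteration is implemented in this class while run_sgd_thread is left pure virtual for the subclasses, a template-method split. The sketch below shows that dispatch shape with one std::thread per worker. Class and member names are hypothetical, summing the per-thread pairs is an assumption, and the sketch uses std::atomic<bool> where the real signature takes volatile bool&, since volatile alone does not make cross-thread writes well defined in standard C++.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical sketch of the dispatch pattern: the base class launches one
// worker per thread, each worker calls the subclass's per-thread routine, and
// the per-thread pairs are summed (an assumption of this sketch).
class ranking_solver_sketch {
 public:
  virtual ~ranking_solver_sketch() = default;

  std::pair<double, double> run_iteration(std::size_t num_threads,
                                          double step_size) {
    assert(num_threads > 0);
    std::vector<std::pair<double, double>> results(num_threads);
    std::atomic<bool> error_detected{false};
    std::vector<std::thread> workers;
    for (std::size_t t = 0; t < num_threads; ++t) {
      // Each worker writes only its own slot, so no locking is needed.
      workers.emplace_back([&, t] {
        results[t] = run_sgd_thread(t, num_threads, step_size, error_detected);
      });
    }
    for (auto& w : workers) w.join();
    if (error_detected) return {-1.0, -1.0};  // hypothetical failure signal
    std::pair<double, double> total{0.0, 0.0};
    for (const auto& r : results) {
      total.first += r.first;
      total.second += r.second;
    }
    return total;
  }

 protected:
  // Subclasses (explicit or implicit ranking) supply the per-thread work.
  virtual std::pair<double, double> run_sgd_thread(
      std::size_t thread_idx, std::size_t num_threads, double step_size,
      std::atomic<bool>& error_detected) = 0;
};
```

Keeping thread management in the base class means each subclass only has to describe the work for one block of data.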
run_loss_calculation_thread() [protected, pure virtual]
Calculate the loss value for the block of data assigned to a particular thread.
[in] | thread_idx | The thread index determining this block. |
[in] | num_threads | The number of threads. |
[in] | data | The ml_data instance we're working with. Primarily needed for the metadata. |
[in] | it_init | The iterator initializer for the ml_data_block_iterator used for this thread. |
[in] | iface | The working SGD interface. |
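Each thread works on the block of data selected by thread_idx and num_threads. A minimal sketch of an even contiguous split is below; the name thread_row_range and the even-split rule are assumptions, and ml_data_block_iterator may assign blocks differently, for example to keep each user's rows together.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical sketch: the contiguous row range [begin, end) that thread
// thread_idx of num_threads would process if n_rows rows were split evenly.
std::pair<std::size_t, std::size_t> thread_row_range(std::size_t thread_idx,
                                                     std::size_t num_threads,
                                                     std::size_t n_rows) {
  assert(num_threads > 0 && thread_idx < num_threads);
  // Multiplying before dividing spreads the remainder rows across threads.
  const std::size_t begin = (thread_idx * n_rows) / num_threads;
  const std::size_t end = ((thread_idx + 1) * n_rows) / num_threads;
  return {begin, end};
}
```

For 10 rows and 3 threads this yields [0, 3), [3, 6), and [6, 10): the ranges tile all rows with no gaps or overlaps, so per-thread losses can be combined without double counting.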
run_sgd_thread() [protected, pure virtual]
The main method that subclasses must implement to run the ranking SGD updates on the block of data assigned to a particular thread.
[in] | thread_idx | The thread index determining this block. |
[in] | num_threads | The number of threads. |
[in] | data | The v2::ml_data instance we're working with. Primarily needed for the metadata. |
[in] | it_init | The iterator initializer for the ml_data_block_iterator used for this thread. |
[in] | iface | The working SGD interface. |
[in] | step_size | The current SGD step size, set by the higher level algorithm. |
[in,out] | error_detected | Set to true when a numerical error is detected. |
setup() [inline, protected, virtual, inherited]
Called at the start of a run, before any run_iteration is called.
Definition at line 109 of file sgd_solver_base.hpp.
model_interface [protected, inherited]
The main interface to the model, implementing SGD-specific routines for that model.
Definition at line 96 of file sgd_solver_base.hpp.
options [protected, inherited]
The training options of the solver.
Definition at line 100 of file sgd_solver_base.hpp.