Turi Create
4.0
#include <toolkits/factorization/ranking_sgd_solver_explicit.hpp>
Public Member Functions
explicit_ranking_sgd_solver(const std::shared_ptr<sgd::sgd_interface_base> &main_interface, const v2::ml_data &train_data, const std::map<std::string, flexible_type> &options)
std::pair<double, double> run_iteration(size_t iteration, sgd::sgd_interface_base *model_iface, const v2::ml_data &data, double step_size)
std::pair<double, double> calculate_objective(sgd::sgd_interface_base *model_iface, const v2::ml_data &data, size_t iteration) const GL_HOT
std::map<std::string, variant_type> run()
Protected Member Functions
std::pair<size_t, size_t> fill_x_buffer_with_users_items(std::vector<std::pair<std::vector<v2::ml_data_entry>, double>> &x_buffer, v2::ml_data_block_iterator &it, size_t n_items, dense_bitset &item_observed) const GL_HOT_INLINE_FLATTEN
template<typename BufferIndexToItemIndexMapper>
GL_HOT_INLINE_FLATTEN void clear_item_observed_buffer(dense_bitset &item_observed, size_t n_rows, size_t n_items, const BufferIndexToItemIndexMapper &map_index) const
virtual void setup(sgd_interface_base *iface)
Static Protected Member Functions
static void add_options(option_manager &options)
Protected Attributes
std::shared_ptr<sgd_interface_base> model_interface
const std::map<std::string, flexible_type> options
Ranking When Target Is Present
When a target is present, the solver fits the model to the targets and, at the same time, penalizes items whose predicted value exceeds value_of_unobserved_items.
Definition at line 28 of file ranking_sgd_solver_explicit.hpp.
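As a rough illustration of this objective, the sketch below scores one observed (user, item, target) triple together with one unobserved item. It is not the Turi Create loss: the squared-error fit term, applying the penalty to a sampled unobserved item, the one-sided squared penalty, and the hypothetical `ranking_regularization` weight are all assumptions. The source states only that items predicted above `value_of_unobserved_items` are penalized.

```cpp
#include <cassert>

// Hypothetical per-sample loss for an explicit-target ranking objective.
// Assumptions: the fit term is squared error, and the ranking penalty is a
// one-sided squared term applied only when the unobserved item's prediction
// exceeds value_of_unobserved_items.
double explicit_ranking_loss(double observed_prediction,
                             double observed_target,
                             double unobserved_prediction,
                             double value_of_unobserved_items,
                             double ranking_regularization) {
  // Fit the model to the observed target.
  double fit_error = observed_prediction - observed_target;
  double loss = fit_error * fit_error;

  // Penalize the unobserved item only if it is predicted too high.
  double excess = unobserved_prediction - value_of_unobserved_items;
  if (excess > 0) {
    loss += ranking_regularization * excess * excess;
  }
  return loss;
}
```

With these assumptions, an unobserved item predicted at or below the threshold contributes nothing, so only over-scored items push the model down.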
explicit_ranking_sgd_solver [inline]
Constructor
Definition at line 33 of file ranking_sgd_solver_explicit.hpp.
add_options [static, protected, inherited]
Inserts the option definitions needed by the common SGD optimization class into an option manager. Intended to be called by subclasses of sgd_solver_base.
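The chaining pattern described here can be sketched with a plain `std::map` standing in for `option_manager`: a subclass first pulls in the common options from its base, then registers its own. The toy types and the option names and defaults (`step_size`, `max_iterations`, `ranking_weight`) are hypothetical and are not taken from Turi Create.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for option_manager: option name -> default value.
using toy_option_manager = std::map<std::string, double>;

struct toy_solver_base {
  // Registers options shared by every SGD-style solver (hypothetical names).
  static void add_options(toy_option_manager& options) {
    options["step_size"] = 0.1;
    options["max_iterations"] = 25;
  }
};

struct toy_ranking_solver : toy_solver_base {
  // A subclass calls the base version first, then adds its own options.
  static void add_options(toy_option_manager& options) {
    toy_solver_base::add_options(options);
    options["ranking_weight"] = 0.5;  // hypothetical subclass-specific option
  }
};
```

Because `add_options` is static, the derived version hides the base one, so the base call must be qualified explicitly.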
calculate_objective [inline, virtual, inherited]
Calculate the objective value of the current state.
[in] | iface_base | The interface class that provides the gradient calculation routines on top of the model definition. This must be downcast to SGDInterface*. |
[in] | row_start | The starting row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
[in] | row_end | The ending row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
Reimplemented from turi::sgd::sgd_solver_base.
Definition at line 277 of file ranking_sgd_solver_base.hpp.
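The requirement that the interface pointer be downcast to `SGDInterface*` can be illustrated with two toy classes. `toy_interface_base` and `toy_model_interface` are hypothetical stand-ins for `sgd_interface_base` and a concrete model interface. This sketch uses `dynamic_cast` so that a pointer of the wrong type yields `nullptr` instead of undefined behavior. The actual solver may use a different cast.

```cpp
#include <cassert>

// Hypothetical polymorphic base, standing in for sgd_interface_base.
struct toy_interface_base {
  virtual ~toy_interface_base() = default;
};

// Hypothetical concrete interface exposing a model-specific routine.
struct toy_model_interface : toy_interface_base {
  double loss_of(double prediction, double target) const {
    double d = prediction - target;
    return d * d;
  }
};

// Takes the base pointer, as calculate_objective does, and recovers the
// concrete interface before calling model-specific code.
// Returns -1.0 if the pointer does not refer to a toy_model_interface.
double objective_via_base(const toy_interface_base* iface_base,
                          double prediction, double target) {
  const auto* iface = dynamic_cast<const toy_model_interface*>(iface_base);
  if (iface == nullptr) return -1.0;
  return iface->loss_of(prediction, target);
}
```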
clear_item_observed_buffer [inline, protected, inherited]
Clear out the item_observed buffer.
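Resets the bits in item_observed that were set while filling the buffer. Depending on the number of items actually used, it either clears individual bits (using map_index to locate each item) or clears the whole bitset.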
Definition at line 656 of file ranking_sgd_solver_base.hpp.
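One plausible reading of this heuristic is sketched below, with `std::vector<bool>` standing in for `dense_bitset`. The function name is hypothetical. The cutoff that chooses between clearing individual bits and clearing the whole mask (`n_rows * 64 < n_items`) is an assumption for illustration, not the value used in Turi Create.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch: clears the observed-item mask after a user's rows
// have been processed. map_index(i) returns the item index stored at
// buffer position i.
template <typename BufferIndexToItemIndexMapper>
void clear_item_observed_buffer_sketch(std::vector<bool>& item_observed,
                                       std::size_t n_rows,
                                       std::size_t n_items,
                                       const BufferIndexToItemIndexMapper& map_index) {
  // Assumed cutoff: only touch individual bits when very few were set.
  if (n_rows * 64 < n_items) {
    for (std::size_t i = 0; i < n_rows; ++i) {
      item_observed[map_index(i)] = false;
    }
  } else {
    // Many bits were set, so resetting the whole mask is cheaper.
    item_observed.assign(item_observed.size(), false);
  }
}
```

The trade-off is that clearing n_rows bits one by one costs O(n_rows), while clearing the whole mask costs O(n_items), so the cheaper path depends on their ratio.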
fill_x_buffer_with_users_items [inline, protected, inherited]
Fill a buffer with (observation, target value) pairs. Because ml_data_block_iterator iterates over blocks that each contain whole users, this buffer is guaranteed to hold all the items rated by a particular user. If no target value is present, a value of 1 is used.
[out] | x_buffer | The buffer where the (observation, target_value) pairs are stored. |
[in,out] | it | The current block iterator. |
[in,out] | item_observed | A mask giving the items observed in the data. |
Definition at line 360 of file ranking_sgd_solver_base.hpp.
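The sketch below shows the shape of this step on a simplified in-memory layout. It assumes rows are grouped by user and that each row holds a user, an item, and an optional target. The function copies one user's rows into the buffer and marks each item in the mask. The `toy_row` type, the use of `std::optional` for a missing target, and returning only the next row position are hypothetical simplifications of the real `ml_data_block_iterator`-based interface. It requires C++17.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical simplified row: one (user, item) observation.
struct toy_row {
  std::size_t user;
  std::size_t item;
  std::optional<double> target;  // empty when there is no target
};

// Copies all rows for the user at rows[start] into x_buffer as
// (item, target) pairs, marks the observed items, and returns the index
// of the first row of the next user. Assumes rows are grouped by user.
std::size_t fill_buffer_for_next_user(
    const std::vector<toy_row>& rows, std::size_t start,
    std::vector<std::pair<std::size_t, double>>& x_buffer,
    std::vector<bool>& item_observed) {
  x_buffer.clear();
  if (start >= rows.size()) return start;
  const std::size_t user = rows[start].user;
  std::size_t pos = start;
  for (; pos < rows.size() && rows[pos].user == user; ++pos) {
    // Use 1 as the target when none is present, as documented above.
    double target = rows[pos].target.value_or(1.0);
    x_buffer.emplace_back(rows[pos].item, target);
    item_observed[rows[pos].item] = true;
  }
  return pos;
}
```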
run [inherited]
The main function, which runs the SGD solver with the current options.
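The sketch below is a minimal, hypothetical version of the control flow such a `run()` method follows. It calls a setup hook once, then calls an iteration function for each pass, and stops early if the objective becomes non-finite. The class `toy_sgd_solver`, its one-parameter model, the constant step size, and the early-stop rule are all assumptions. The real solver reads its settings from the `options` map, and its `run_iteration` returns a `std::pair<double, double>` rather than a single value.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical solver that fits one weight w to minimize (w - target)^2.
class toy_sgd_solver {
 public:
  explicit toy_sgd_solver(double target) : target_(target) {}

  // Mirrors the run() pattern: call setup() once, then iterate.
  // Returns the objective from the last iteration that ran.
  double run(std::size_t max_iterations, double step_size) {
    setup();
    double objective = 0.0;
    for (std::size_t iter = 0; iter < max_iterations; ++iter) {
      objective = run_iteration(iter, step_size);
      if (!std::isfinite(objective)) break;  // stop on numerical trouble
    }
    return objective;
  }

  double weight() const { return w_; }

 private:
  // Called once before any run_iteration(), like setup().
  void setup() { w_ = 0.0; }

  // One "pass": a single gradient step on (w - target)^2.
  // Returns the objective measured before the step.
  double run_iteration(std::size_t /*iteration*/, double step_size) {
    double diff = w_ - target_;
    double objective = diff * diff;
    w_ -= step_size * 2.0 * diff;  // d/dw (w - target)^2 = 2 (w - target)
    return objective;
  }

  double target_;
  double w_ = 0.0;
};
```

With a small step size the weight converges toward the target. A step size that is too large makes the objective grow until it overflows, and the loop then stops.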
run_iteration [inline, virtual, inherited]
Runs a single SGD pass through the data. Implements the virtual function required by the base class sgd_solver_base.
[in] | iteration | The iteration index; what gets reported in the progress message. |
[in] | iface_base | The interface class that provides the gradient calculation routines on top of the model definition. This must be downcast to SGDInterface*. |
[in] | row_start | The starting row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
[in] | row_end | The ending row in the training data to use. In trial mode, we are likely looking at only a subset of the data. |
[in] | trial_mode | If true, immediately return failure on any numerical issues and do not report progress messages. |
Implements turi::sgd::sgd_solver_base.
Definition at line 158 of file ranking_sgd_solver_base.hpp.
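The roles of `row_start`, `row_end`, and `trial_mode` can be sketched as one SGD pass over a toy linear model y ~ w * x. Only those parameter roles follow the descriptions above. The data layout, the squared-error loss, the `std::optional` return used to signal failure, and the progress message are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical training example for the model y ~ w * x.
struct toy_example {
  double x;
  double y;
};

// One SGD pass over rows [row_start, row_end). Returns the summed squared
// error seen during the pass, or std::nullopt if trial_mode is set and the
// weight becomes non-finite.
std::optional<double> run_iteration_sketch(std::size_t iteration,
                                           double& w,
                                           const std::vector<toy_example>& data,
                                           std::size_t row_start,
                                           std::size_t row_end,
                                           double step_size,
                                           bool trial_mode) {
  double total_loss = 0.0;
  for (std::size_t i = row_start; i < row_end && i < data.size(); ++i) {
    double err = w * data[i].x - data[i].y;
    total_loss += err * err;
    w -= step_size * 2.0 * err * data[i].x;  // gradient of err^2 w.r.t. w
    if (trial_mode && !std::isfinite(w)) {
      return std::nullopt;  // fail fast on numerical issues
    }
  }
  if (!trial_mode) {
    // Progress messages are suppressed in trial mode.
    std::cout << "iteration " << iteration << ": loss " << total_loss << "\n";
  }
  return total_loss;
}
```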
setup [inline, protected, virtual, inherited]
Called at the start of a run, before any run_iteration is called.
Definition at line 109 of file sgd_solver_base.hpp.
model_interface [protected, inherited]
The main interface to the model, implementing sgd-specific routines for that model.
Definition at line 96 of file sgd_solver_base.hpp.
options [protected, inherited]
The training options of the solver.
Definition at line 100 of file sgd_solver_base.hpp.