Turi Create 4.0
turi::factorization::explicit_ranking_sgd_solver< SGDInterface > Class Template Reference [final]

#include <toolkits/factorization/ranking_sgd_solver_explicit.hpp>

Public Member Functions

 explicit_ranking_sgd_solver (const std::shared_ptr< sgd::sgd_interface_base > &main_interface, const v2::ml_data &train_data, const std::map< std::string, flexible_type > &options)
 
std::pair< double, double > run_iteration (size_t iteration, sgd::sgd_interface_base *model_iface, const v2::ml_data &data, double step_size)
 
std::pair< double, double > calculate_objective (sgd::sgd_interface_base *model_iface, const v2::ml_data &data, size_t iteration) const GL_HOT
 
std::map< std::string, variant_type > run ()
 

Protected Member Functions

std::pair< size_t, size_t > fill_x_buffer_with_users_items (std::vector< std::pair< std::vector< v2::ml_data_entry >, double > > &x_buffer, v2::ml_data_block_iterator &it, size_t n_items, dense_bitset &item_observed) const GL_HOT_INLINE_FLATTEN
 
template<typename BufferIndexToItemIndexMapper >
GL_HOT_INLINE_FLATTEN void clear_item_observed_buffer (dense_bitset &item_observed, size_t n_rows, size_t n_items, const BufferIndexToItemIndexMapper &map_index) const
 
virtual void setup (sgd_interface_base *iface)
 

Static Protected Member Functions

static void add_options (option_manager &options)
 

Protected Attributes

std::shared_ptr< sgd_interface_base > model_interface
 
const std::map< std::string, flexible_type > options
 

Detailed Description

template<class SGDInterface>
class turi::factorization::explicit_ranking_sgd_solver< SGDInterface >

Ranking When Target Is Present


When the target is present, fit the model to the observed targets while simultaneously penalizing items whose predicted value exceeds value_of_unobserved_items.

Definition at line 28 of file ranking_sgd_solver_explicit.hpp.
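As a rough illustration of the objective described above, the sketch below scores one observed (user, item, target) triple together with one unobserved item. It assumes a squared-error fit term, a one-sided squared penalty on the unobserved item, and a hypothetical weight ranking_weight; these are illustrative choices, not the loss used in ranking_sgd_solver_explicit.hpp.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical per-sample loss for "ranking when the target is present".
// Assumptions (not taken from the Turi Create source):
//   * the fit term is squared error against the observed target;
//   * an unobserved item is penalized only when its prediction exceeds
//     value_of_unobserved_items, using a one-sided squared penalty;
//   * ranking_weight trades off the two terms.
double explicit_ranking_loss(double predicted_observed,
                             double target,
                             double predicted_unobserved,
                             double value_of_unobserved_items,
                             double ranking_weight) {
  assert(ranking_weight >= 0.0);

  // Fit term: pull the observed item's prediction toward its target.
  const double fit_error = predicted_observed - target;
  const double fit_loss = fit_error * fit_error;

  // Ranking term: zero unless the unobserved item is predicted above
  // the reference value for unobserved items.
  const double excess =
      std::max(0.0, predicted_unobserved - value_of_unobserved_items);
  const double ranking_loss = excess * excess;

  return fit_loss + ranking_weight * ranking_loss;
}
```

Under these assumptions, an unobserved item predicted at or below value_of_unobserved_items contributes nothing, so the penalty only pushes down unobserved items that are scored too high.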

Constructor & Destructor Documentation

◆ explicit_ranking_sgd_solver()

template<class SGDInterface >
turi::factorization::explicit_ranking_sgd_solver< SGDInterface >::explicit_ranking_sgd_solver ( const std::shared_ptr< sgd::sgd_interface_base > &  main_interface,
const v2::ml_data &  train_data,
const std::map< std::string, flexible_type > &  options 
)
[inline]

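Constructor. Constructs the solver from the main model interface, the training data, and the map of training options.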

Definition at line 33 of file ranking_sgd_solver_explicit.hpp.

Member Function Documentation

◆ add_options()

static void turi::sgd::sgd_solver_base::add_options ( option_manager &  options)
[static], [protected], [inherited]

Insert the option definitions needed by the common SGD optimization class into an option manager. Subclasses of sgd_solver_base are meant to call this function when registering their own options.
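The registration pattern described above can be sketched as follows: a base solver registers the shared SGD options, and a subclass calls that static function before adding its own. OptionRegistry, BaseSolver, RankingSolver, and every option name and default here are hypothetical stand-ins; the real option_manager interface is not reproduced.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for an option manager: option name -> default value.
struct OptionRegistry {
  std::map<std::string, double> defaults;

  void create_option(const std::string& name, double default_value) {
    // Each option should be registered exactly once.
    assert(defaults.count(name) == 0);
    defaults[name] = default_value;
  }
};

class BaseSolver {
 protected:
  // Registers the options shared by every SGD-based solver (hypothetical names).
  static void add_options(OptionRegistry& options) {
    options.create_option("max_iterations", 50);
    options.create_option("step_size", 0.01);
  }
};

class RankingSolver : public BaseSolver {
 public:
  // Subclasses call the base registration first, then add their own options.
  static void add_options(OptionRegistry& options) {
    BaseSolver::add_options(options);
    options.create_option("value_of_unobserved_items", 0.0);
  }
};
```

Keeping the shared registration in one static function means each subclass only lists the options it adds, and the assert catches an option accidentally registered twice.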

◆ calculate_objective()

template<class SGDInterface >
std::pair<double, double> turi::factorization::ranking_sgd_solver_base< SGDInterface >::calculate_objective ( sgd::sgd_interface_base *  model_iface,
const v2::ml_data &  data,
size_t  iteration 
) const
[inline], [virtual], [inherited]

Calculate the objective value of the current state.

Parameters
[in] model_iface: The interface class that provides the gradient calculation routines on top of the model definition. It must be downcast to SGDInterface*.
[in] data: The data on which to compute the objective. In trial mode, this is likely only a subset of the training data.
[in] iteration: The index of the current iteration.

Reimplemented from turi::sgd::sgd_solver_base.

Definition at line 277 of file ranking_sgd_solver_base.hpp.
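A minimal sketch of evaluating an objective over a set of rows follows. It assumes the returned pair holds (objective including regularization, mean loss without regularization); that reading of the pair, the Observation type, and the L2 penalty are assumptions for illustration, not documented behavior of ranking_sgd_solver_base.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical observation: a prediction already computed by the model and
// the target it should match.
struct Observation {
  double prediction;
  double target;
};

// Returns (objective, mean_loss). Assumed layout: the first value adds an L2
// penalty on the model parameters to the mean squared error; the second is
// the mean squared error alone.
std::pair<double, double> calculate_objective_sketch(
    const std::vector<Observation>& rows,
    const std::vector<double>& parameters,
    double l2_regularization) {
  assert(!rows.empty());

  // Mean squared error over all rows.
  double total_loss = 0.0;
  for (const Observation& row : rows) {
    const double err = row.prediction - row.target;
    total_loss += err * err;
  }
  const double mean_loss = total_loss / static_cast<double>(rows.size());

  // Squared L2 norm of the parameters.
  double squared_norm = 0.0;
  for (double p : parameters) squared_norm += p * p;

  return {mean_loss + l2_regularization * squared_norm, mean_loss};
}
```

Reporting the unregularized loss separately lets progress messages show a value that is comparable across different regularization settings.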

◆ clear_item_observed_buffer()

template<class SGDInterface >
template<typename BufferIndexToItemIndexMapper >
GL_HOT_INLINE_FLATTEN void turi::factorization::ranking_sgd_solver_base< SGDInterface >::clear_item_observed_buffer ( dense_bitset &  item_observed,
size_t  n_rows,
size_t  n_items,
const BufferIndexToItemIndexMapper &  map_index 
) const
[inline], [protected], [inherited]

Clear out the item_observed buffer.

Clears the entries of item_observed that were set, choosing how to clear them based on the number of items actually used.

Definition at line 656 of file ranking_sgd_solver_base.hpp.
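One way such a clear can depend on the number of items used is sketched below: when few rows were buffered, unset only their bits, located through a mapper from buffer index to item index; otherwise reset the whole mask. std::vector<bool> stands in for dense_bitset, and the threshold n_items / 64 is an assumed heuristic, not the value used in Turi Create.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical version of clearing an item-observed mask.
// item_observed: one flag per item (stand-in for dense_bitset).
// map_index: maps a buffer row index to the item index that row set.
template <typename BufferIndexToItemIndexMapper>
void clear_item_observed_sketch(std::vector<bool>& item_observed,
                                std::size_t n_rows,
                                std::size_t n_items,
                                const BufferIndexToItemIndexMapper& map_index) {
  assert(item_observed.size() == n_items);

  // Assumed heuristic: clearing bits one by one is cheaper than resetting
  // the whole mask only when n_rows is small relative to n_items.
  if (n_rows < n_items / 64) {
    for (std::size_t i = 0; i < n_rows; ++i) {
      item_observed[map_index(i)] = false;
    }
  } else {
    std::fill(item_observed.begin(), item_observed.end(), false);
  }
}
```

Taking the mapper as a template parameter lets the caller pass a lambda over whatever buffer layout it uses without a virtual call per row.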

◆ fill_x_buffer_with_users_items()

template<class SGDInterface >
std::pair<size_t, size_t> turi::factorization::ranking_sgd_solver_base< SGDInterface >::fill_x_buffer_with_users_items ( std::vector< std::pair< std::vector< v2::ml_data_entry >, double > > &  x_buffer,
v2::ml_data_block_iterator &  it,
size_t  n_items,
dense_bitset &  item_observed 
) const
[inline], [protected], [inherited]

Fill a buffer with (observation, target value) pairs. Because ml_data_block_iterator traverses the data in per-user blocks, the buffer is guaranteed to hold all the items rated by a particular user. If no target value is present, 1 is used.

Parameters
[out] x_buffer: The buffer where the (observation, target_value) pairs are stored.
[in,out] it: The current block iterator.
[in] n_items: The total number of items.
[in,out] item_observed: A mask giving the items observed in the data.
Returns
(n_rows, n_rated_items): the number of rows written to the buffer and the number of unique rated items.

Definition at line 360 of file ranking_sgd_solver_base.hpp.
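To make the contract concrete, the sketch below fills a buffer from one user's block of rows, defaulting the target to 1 when none is present, marking each item in an observed mask, and returning (n_rows, n_rated_items). Row, the use of std::optional for a missing target, and std::vector<bool> in place of dense_bitset are simplifications; the real function reads ml_data_entry vectors through ml_data_block_iterator.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical simplified row: the item index and an optional target.
struct Row {
  std::size_t item;
  std::optional<double> target;
};

// Fills x_buffer with (item, target) pairs for one user's block of rows,
// marks each item in item_observed, and returns (n_rows, n_rated_items).
std::pair<std::size_t, std::size_t> fill_buffer_for_user(
    const std::vector<Row>& user_block,
    std::vector<std::pair<std::size_t, double>>& x_buffer,
    std::vector<bool>& item_observed) {
  x_buffer.clear();
  std::size_t n_rated_items = 0;

  for (const Row& row : user_block) {
    assert(row.item < item_observed.size());

    // A missing target is treated as 1, as documented for the real function.
    x_buffer.emplace_back(row.item, row.target.value_or(1.0));

    // Count each distinct item only the first time it is seen.
    if (!item_observed[row.item]) {
      item_observed[row.item] = true;
      ++n_rated_items;
    }
  }
  return {x_buffer.size(), n_rated_items};
}
```

Because the mask is left set on return, a later step can sample unobserved items by rejecting any item whose bit is on, and the mask must then be cleared before the next user.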

◆ run()

std::map<std::string, variant_type> turi::sgd::sgd_solver_base::run ( )
[inherited]

The main entry point: run the SGD solver with the current options.
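The relationship between run(), setup(), and run_iteration() documented on this page can be sketched as a template-method loop: run() calls the setup hook once, then one SGD pass per iteration. The class names, fixed iteration count, constant step size, and the (objective, loss) reading of the returned pair are assumptions; everything else the real sgd_solver_base does is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical base class mirroring the shape of run / setup / run_iteration.
class SolverSketch {
 public:
  virtual ~SolverSketch() = default;

  // Calls setup once, then one pass per iteration; returns the
  // per-iteration pairs reported by run_iteration.
  std::vector<std::pair<double, double>> run(std::size_t max_iterations,
                                             double step_size) {
    setup();
    std::vector<std::pair<double, double>> history;
    for (std::size_t it = 0; it < max_iterations; ++it) {
      history.push_back(run_iteration(it, step_size));
    }
    return history;
  }

 protected:
  // Hook called at the start of a run, before any run_iteration.
  virtual void setup() {}

  // One SGD pass; subclasses must implement it.
  virtual std::pair<double, double> run_iteration(std::size_t iteration,
                                                  double step_size) = 0;
};

// Toy subclass: minimizes f(w) = (w - 3)^2 with one gradient step per pass.
class QuadraticSolver : public SolverSketch {
 public:
  double w = 0.0;
  bool was_set_up = false;

 protected:
  void setup() override {
    was_set_up = true;
    w = 0.0;
  }

  std::pair<double, double> run_iteration(std::size_t /*iteration*/,
                                          double step_size) override {
    assert(was_set_up);
    const double grad = 2.0 * (w - 3.0);
    w -= step_size * grad;
    const double loss = (w - 3.0) * (w - 3.0);
    return {loss, loss};
  }
};
```

With step_size 0.25, each pass halves the distance from w to 3, so twenty passes leave w within a few millionths of the minimum.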

◆ run_iteration()

template<class SGDInterface >
std::pair<double, double> turi::factorization::ranking_sgd_solver_base< SGDInterface >::run_iteration ( size_t  iteration,
sgd::sgd_interface_base *  model_iface,
const v2::ml_data &  data,
double  step_size 
)
[inline], [virtual], [inherited]

Run a single SGD pass through the data. This implements the virtual function that the base sgd_solver_base class requires.

Parameters
[in] iteration: The iteration index; this is what gets reported in the progress message.
[in] model_iface: The interface class that provides the gradient calculation routines on top of the model definition. It must be downcast to SGDInterface*.
[in] data: The data to pass over. In trial mode, this is likely only a subset of the training data.
[in] step_size: The SGD step size to use for this pass.
In trial mode, the pass should return failure immediately on any numerical issue and should not report progress messages.

Implements turi::sgd::sgd_solver_base.

Definition at line 158 of file ranking_sgd_solver_base.hpp.
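One possible shape for a single ranking pass is sketched below for a tiny matrix-factorization model: for each observed (user, item, target) it takes an SGD step on the squared error, then samples one item the user has not rated and, if its prediction exceeds value_of_unobserved_items, takes a step that pushes it down. The model, uniform negative sampling, loss, returned pair, and all names are assumptions for illustration, not the algorithm in ranking_sgd_solver_base.hpp.

```cpp
#include <cassert>
#include <cstddef>
#include <random>
#include <set>
#include <utility>
#include <vector>

struct Rating {
  std::size_t user, item;
  double target;
};

// Hypothetical tiny factorization model: score(u, i) = <U[u], V[i]>.
struct FactorModel {
  std::vector<std::vector<double>> U, V;

  double predict(std::size_t u, std::size_t i) const {
    double s = 0.0;
    for (std::size_t k = 0; k < U[u].size(); ++k) s += U[u][k] * V[i][k];
    return s;
  }

  // SGD step on 0.5 * residual^2 with respect to U[u] and V[i],
  // where residual = prediction - desired value.
  void step(std::size_t u, std::size_t i, double residual, double lr) {
    for (std::size_t k = 0; k < U[u].size(); ++k) {
      const double uk = U[u][k];  // use the pre-update value for V's step
      U[u][k] -= lr * residual * V[i][k];
      V[i][k] -= lr * residual * uk;
    }
  }
};

// One pass; returns (mean squared error on observed rows,
// mean ranking penalty on sampled unobserved items).
std::pair<double, double> ranking_pass_sketch(
    FactorModel& model, const std::vector<Rating>& data,
    double value_of_unobserved_items, double step_size, std::mt19937& rng) {
  assert(!data.empty() && !model.V.empty());
  const std::size_t n_items = model.V.size();

  // Items rated by each user, used to reject observed items when sampling.
  std::vector<std::set<std::size_t>> rated(model.U.size());
  for (const Rating& r : data) rated[r.user].insert(r.item);

  std::uniform_int_distribution<std::size_t> pick(0, n_items - 1);
  double fit_total = 0.0, rank_total = 0.0;

  for (const Rating& r : data) {
    // Fit the observed target.
    const double residual = model.predict(r.user, r.item) - r.target;
    fit_total += residual * residual;
    model.step(r.user, r.item, residual, step_size);

    // Sample an unobserved item; skip if the user has rated every item.
    if (rated[r.user].size() == n_items) continue;
    std::size_t neg;
    do {
      neg = pick(rng);
    } while (rated[r.user].count(neg) != 0);

    // One-sided penalty: only push down items scored too high.
    const double excess =
        model.predict(r.user, neg) - value_of_unobserved_items;
    if (excess > 0.0) {
      rank_total += excess * excess;
      model.step(r.user, neg, excess, step_size);
    }
  }
  const double n = static_cast<double>(data.size());
  return {fit_total / n, rank_total / n};
}
```

Sampling one unobserved item per observed row keeps each pass linear in the number of observations; the real solver instead uses the per-user buffers and item_observed mask documented above.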

◆ setup()

virtual void turi::sgd::sgd_solver_base::setup ( sgd_interface_base *  iface)
[inline], [protected], [virtual], [inherited]

Called at the start of a run, before any run_iteration is called.

Definition at line 109 of file sgd_solver_base.hpp.

Member Data Documentation

◆ model_interface

std::shared_ptr<sgd_interface_base> turi::sgd::sgd_solver_base::model_interface
[protected], [inherited]

The main interface to the model, implementing SGD-specific routines for that model.

Definition at line 96 of file sgd_solver_base.hpp.

◆ options

const std::map<std::string, flexible_type> turi::sgd::sgd_solver_base::options
[protected], [inherited]

The training options of the solver.

Definition at line 100 of file sgd_solver_base.hpp.


The documentation for this class was generated from the following file: