Turi Create 4.0
sgd_ranking_interface.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SGD_SGD_RANKING_INTERFACE_H_  // guard named after this file (sgd_ranking_interface.hpp) so it cannot clash with the guard of the included base-interface header
7 #define TURI_SGD_SGD_RANKING_INTERFACE_H_
8 
9 #include <toolkits/sgd/sgd_interface.hpp>
10 
11 namespace turi { namespace sgd {
12 
13 
14 /** The base class for the ranking SGD interfaces. This interface
15  * governs all the interactions between the sgd solvers and the
16  * model.
17  *
18  * To use the ranking SGD solver, implement the following pure virtual methods.
19  */
21  public:
22 
23  /** Apply two SGD steps to the model: one to increase the predicted value
24  * of x_positive and one to decrease the predicted value of x_negative. Pure virtual: every concrete ranking interface must implement it.
25  * @return NOTE(review): not documented in this header; presumably the loss for this positive/negative pair -- confirm against implementations. */
26  virtual double apply_pairwise_sgd_step(
27  size_t thread_idx,  // presumably the index of the solver thread making the call -- TODO confirm
28  const std::vector<ml_data_entry>& x_positive,  // entries of the example whose predicted value should be increased
29  const std::vector<ml_data_entry>& x_negative,  // entries of the example whose predicted value should be decreased
30  double step_size) = 0;  // step size for the SGD steps applied by this call
31 };
32 
33 }}
34 
35 #endif /* TURI_SGD_SGD_INTERFACE_BASE_H_ */
virtual double apply_pairwise_sgd_step(size_t thread_idx, const std::vector< ml_data_entry > &x_positive, const std::vector< ml_data_entry > &x_negative, double step_size)=0