Turi Create 4.0
sgd_interface.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_SGD_INTERFACE_BASE_H_
#define TURI_SGD_INTERFACE_BASE_H_

#include <map>
#include <string>
#include <limits>
#include <vector>

namespace turi {

class flexible_type;

namespace v2 {
class ml_data;
struct ml_data_entry;
}

namespace sgd {

/** The base class for all the SGD interfaces. This interface governs
 * all the interactions between the SGD solvers and the model.
 *
 * To implement an SGD interface for your model, subclass
 * sgd_interface_base and implement the appropriate methods.
 *
 * Then, on top of this, choose a solver and template it with your
 * interface. The basic solver is basic_sgd_solver, defined in
 * basic_sgd_solver.hpp.
 *
 * Example:
 *
 *   class simple_sgd_interface : public sgd_interface_base {
 *     ...
 *   };
 *
 *   std::shared_ptr<simple_sgd_interface> iface(new simple_sgd_interface);
 *
 *   basic_sgd_solver<simple_sgd_interface> solver(iface, train_data, options);
 *
 *   auto training_status = solver.run();
 */
class sgd_interface_base {
 public:

  virtual ~sgd_interface_base() = default;

  /** Called at the start of optimization, before any other functions
   * are called.
   *
   * Perform any setup needed in light of the data used for training
   * the model. Since ml_data carries some statistics (e.g. maximum
   * row size), these can be saved here for later use.
   */
  virtual void setup(const v2::ml_data& train_data,
                     const std::map<std::string, flexible_type>& options) {}

  /** Called before each pass through the data.
   */
  virtual void setup_iteration(size_t iteration, double step_size) {}

  /** Called at the end of each pass through the data.
   */
  virtual void finalize_iteration() {}

  /** The L2 regularization strength, used for automatically tuning
   * the SGD step size. This value determines an upper bound on the
   * allowed step size, above which the algorithm stops being
   * numerically stable, and it also governs the decrease rate of the
   * step size over iterations.
   */
  virtual double l2_regularization_factor() const { return 0; }
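
  // For illustration only: a model with penalty lambda * ||w||^2, where
  // lambda is held in a hypothetical member `lambda_`, might override
  // this as
  //
  //   double l2_regularization_factor() const override { return lambda_; }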

  /** Gives a hard limit on the SGD step size. Certain algorithms
   * blow up if the step size is too large, and this gives a way to
   * set a hard limit on the step sizes considered.
   */
  virtual double max_step_size() const { return std::numeric_limits<double>::max(); }
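
  // For illustration only: an interface whose updates are known to
  // diverge for steps above 1.0 could return that bound directly:
  //
  //   double max_step_size() const override { return 1.0; }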

  /** If there are any issues with the model, this function can return
   * false to force a reset. It is called once at the end of each
   * iteration.
   *
   * Returns true if the state is numerically stable, and false if
   * any numerical instabilities were detected in this or the previous
   * pass over the data. If this returns false, the solver resets the
   * state and retries the optimization with a smaller step size.
   */
  virtual bool state_is_numerically_stable() const { return true; }
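
  // For illustration only: a model keeping its parameters in a
  // hypothetical vector `w_` could report instability whenever a
  // coefficient stops being finite:
  //
  //   bool state_is_numerically_stable() const override {
  //     for (double v : w_) { if (!std::isfinite(v)) return false; }
  //     return true;
  //   }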

  /** Sets up the optimization run. Called at the beginning of an
   * optimization run, or in the presence of numerical instabilities
   * to reset the solver; optimization is then attempted again with a
   * smaller step size.
   */
  virtual void setup_optimization(size_t random_seed = size_t(-1),
                                  bool _in_trial_mode = false) = 0;

  /** Calculates the value of the objective function, as determined by
   * the loss function, over a full data set, excluding the
   * regularization penalty.
   *
   * When this loss is reported, reported_loss_value(...) is called on
   * this value to get the loss value to print.
   *
   * \param data The data to use in calculating the objective function.
   *
   * \return The loss component of the objective; reported_loss_value()
   * is applied to it for display.
   */
  virtual double calculate_loss(const v2::ml_data& data) const = 0;

  /** The value of the reported loss. apply_sgd_step accumulates
   * estimated loss values across samples; this function is called on
   * that accumulated value to produce the value that is reported.
   *
   * For example, if squared error loss is used,
   * reported_loss_name() could give RMSE, and then
   * reported_loss_value(v) would be std::sqrt(v).
   */
  virtual double reported_loss_value(double accumulated_loss) const = 0;


  /** The name of the loss to report on each iteration.
   *
   * For example, if squared error loss is used,
   * reported_loss_name() could give RMSE, and then
   * reported_loss_value(v) would be std::sqrt(v).
   */
  virtual std::string reported_loss_name() const = 0;
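
  // Following the RMSE example above, a squared-error implementation
  // might be (illustrative only):
  //
  //   double reported_loss_value(double v) const override { return std::sqrt(v); }
  //   std::string reported_loss_name() const override { return "RMSE"; }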


  /** Calculates the current regularization penalty. This is used to
   * compute the objective value, which is interpreted as loss +
   * regularization penalty.
   */
  virtual double current_regularization_penalty() const = 0;
141 
142  /** Apply the sgd step. Called on each data point.
143  */
144  virtual double apply_sgd_step(size_t thread_idx,
145  const std::vector<v2::ml_data_entry>& x,
146  double y,
147  double step_size) = 0;
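
  // For illustration only: with squared-error loss and a hypothetical
  // weight vector `w_`, a step could compute err = <w_, x> - y, update
  //
  //   w_[i] -= step_size * (err * x_i + lambda_ * w_[i]);  // per feature i
  //
  // and return err * err as the estimated loss for this point.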

};

}  // namespace sgd
}  // namespace turi

#endif /* TURI_SGD_INTERFACE_BASE_H_ */
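
To make the contract above concrete, here is a self-contained sketch of the solver/interface handshake. It does not use the real v2::ml_data or basic_sgd_solver; instead, a toy least-squares model exposing the same core member functions is driven by a stripped-down training loop over plain std::vector data. The names point and toy_interface, the option defaults, and the loop itself are assumptions for illustration only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Stand-in for one row of training data: feature vector x, target y.
struct point { std::vector<double> x; double y; };

// A toy linear least-squares model mirroring the sgd_interface_base
// member functions, but over plain vectors instead of v2::ml_data.
class toy_interface {
 public:
  explicit toy_interface(size_t dim, double lambda = 1e-3)
      : w_(dim, 0.0), lambda_(lambda) {}

  void setup_optimization(size_t /*random_seed*/ = size_t(-1),
                          bool /*in_trial_mode*/ = false) {
    std::fill(w_.begin(), w_.end(), 0.0);  // reset to a clean state
  }

  double l2_regularization_factor() const { return lambda_; }
  double max_step_size() const { return 1.0; }

  bool state_is_numerically_stable() const {
    for (double v : w_) { if (!std::isfinite(v)) return false; }
    return true;
  }

  // One SGD step on a single point; returns that point's squared error,
  // which the solver accumulates for reporting.
  double apply_sgd_step(size_t /*thread_idx*/, const std::vector<double>& x,
                        double y, double step_size) {
    double err = -y;
    for (size_t i = 0; i < x.size(); ++i) err += w_[i] * x[i];
    for (size_t i = 0; i < x.size(); ++i)
      w_[i] -= step_size * (err * x[i] + lambda_ * w_[i]);
    return err * err;
  }

  // Mean squared error over a full data set, excluding regularization.
  double calculate_loss(const std::vector<point>& data) const {
    double total = 0;
    for (const point& p : data) {
      double err = -p.y;
      for (size_t i = 0; i < p.x.size(); ++i) err += w_[i] * p.x[i];
      total += err * err;
    }
    return total / data.size();
  }

  double reported_loss_value(double accumulated_loss) const {
    return std::sqrt(accumulated_loss);  // report RMSE for squared error
  }
  std::string reported_loss_name() const { return "rmse"; }

  double current_regularization_penalty() const {
    double s = 0;
    for (double v : w_) s += v * v;
    return lambda_ * s;
  }

 private:
  std::vector<double> w_;
  double lambda_;
};

// A stripped-down stand-in for the solver's main loop.
int main() {
  std::vector<point> data = {{{1, 0}, 1}, {{0, 1}, -1}, {{1, 1}, 0}};
  toy_interface iface(2);
  iface.setup_optimization();
  double step_size = std::min(0.1, iface.max_step_size());
  for (size_t it = 0; it < 100; ++it) {
    double acc = 0;  // accumulated per-point loss, as the solver tracks it
    for (const point& p : data)
      acc += iface.apply_sgd_step(0, p.x, p.y, step_size);
    if (!iface.state_is_numerically_stable()) {
      step_size /= 2;              // shrink the step and reset, as the
      iface.setup_optimization();  // solver does on instability
    } else if (it % 25 == 0) {
      std::cout << "iter " << it << " " << iface.reported_loss_name() << ": "
                << iface.reported_loss_value(acc / data.size()) << "\n";
    }
  }
  std::cout << "final " << iface.reported_loss_name() << ": "
            << iface.reported_loss_value(iface.calculate_loss(data)) << "\n";
}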