Turi Create 4.0
sgd_solver_base.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_SGD_SGD_SOLVER_CLASS_H_
#define TURI_SGD_SGD_SOLVER_CLASS_H_

#include <map>
#include <cmath>
#include <vector>
#include <toolkits/sgd/sgd_interface.hpp>
#include <core/logging/assertions.hpp>
#include <core/data/flexible_type/flexible_type.hpp>
#include <model_server/lib/variant.hpp>
#include <core/storage/sframe_interface/unity_sframe.hpp>

namespace turi {

class option_manager;

namespace v2 {
class ml_data;
}

namespace sgd {

/** The base solver class for all the general SGD methods.
 *
 * This class provides the high-level functionality for the sgd
 * methods. Particular versions of SGD are implemented using the
 * run_iteration method, which is called to do one pass through the
 * data on a particular block of data points.
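 *
 * A minimal sketch of the subclassing pattern described above; the
 * class name my_sgd_solver is hypothetical, and the actual update
 * logic is elided:
 * \code
 * class my_sgd_solver : public sgd_solver_base {
 *  public:
 *   my_sgd_solver(const std::shared_ptr<sgd_interface_base>& iface,
 *                 const v2::ml_data& data,
 *                 const std::map<std::string, flexible_type>& opts)
 *       : sgd_solver_base(iface, data, opts) {}
 *
 *  protected:
 *   // One pass through a block of data; returns (objective, loss).
 *   std::pair<double, double> run_iteration(
 *       size_t iteration, sgd_interface_base* iface,
 *       const v2::ml_data& data, double step_size) override {
 *     // ... apply the sgd updates through the model interface ...
 *     return {0.0, 0.0};
 *   }
 * };
 * \endcode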
 */
class sgd_solver_base {
 public:

  virtual ~sgd_solver_base() = default;

 protected:

  /** The constructor.
   */
  sgd_solver_base(const std::shared_ptr<sgd_interface_base>& model_interface,
                  const v2::ml_data& _train_data,
                  const std::map<std::string, flexible_type>& _options);

  /** Call the following function to insert the option definitions
   * needed for the common sgd optimization class into an option
   * manager. Meant to be called by the subclasses of
   * sgd_solver_base.
   */
  static void add_options(option_manager& options);

 private:

  /** Can't copy this class, so delete the copy constructor.
   */
  sgd_solver_base(const sgd_solver_base&) = delete;

  /** A reference to the training data, passed as a parameter to the
   * run_iteration method of the subclass.
   */
  const v2::ml_data& train_data;

 public:

  ////////////////////////////////////////////////////////////////////////////////

  /** The main function to run the sgd solver given the current
   * options.
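   *
   * A hypothetical usage sketch, assuming a concrete subclass
   * my_sgd_solver like the one outlined in the class documentation
   * above:
   * \code
   * my_sgd_solver solver(iface, train_data, opts);
   * std::map<std::string, variant_type> results = solver.run();
   * \endcode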
   */
  std::map<std::string, variant_type> run();

 private:

  ////////////////////////////////////////////////////////////////////////////////
  // Particular methods of running things

  /** Run the sgd algorithm with a fixed step size. If divergence is
   * detected, then retry with a smaller step size and warn the
   * user.
   *
   * If the sgd_step_size parameter is set to a value greater than
   * zero, the run() method above will run this solver.
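   *
   * A rough sketch of the retry loop described above; the helpers
   * run_all_iterations() and diverged() are hypothetical stand-ins
   * for the actual logic:
   * \code
   * double step_size = sgd_step_size;
   * auto result = run_all_iterations(step_size);
   * while (diverged(result)) {
   *   step_size /= 2;  // warn the user, then retry with a smaller step
   *   result = run_all_iterations(step_size);
   * }
   * \endcode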
   */
  std::map<std::string, variant_type> run_fixed_sgd_step_size(double sgd_step_size);

 protected:

  /** The main interface to the model, implementing sgd-specific
   * routines for that model.
   */
  std::shared_ptr<sgd_interface_base> model_interface;

  /** The training options of the solver.
   */
  const std::map<std::string, flexible_type> options;

  ////////////////////////////////////////////////////////////////////////////////
  // Virtual methods that need to be implemented by the subclasses.

  /** Called at the start of a run, before any run_iteration is
   * called.
   */
  virtual void setup(sgd_interface_base* iface) {}

  /** Called to run one iteration of the SGD algorithm on the training
   * data.
   *
   * \param[in] iteration The iteration number of the current pass
   * through the data.
   *
   * \param[in] iface A pointer to the interface class. This can be
   * downcast to the true SGD interface class for use in the actual
   * code.
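   *
   * For example, a subclass working with a concrete interface type
   * might recover it like this (my_sgd_interface is hypothetical):
   * \code
   * auto* typed_iface = dynamic_cast<my_sgd_interface*>(iface);
   * DASSERT_TRUE(typed_iface != nullptr);
   * \endcode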
   *
   * \param[in] data The block of training data for this pass.
   *
   * \param[in] step_size The step size to use for this pass through
   * the data.
   *
   * \return A pair -- (objective_value, loss)
   */
  virtual std::pair<double, double> run_iteration(
      size_t iteration,
      sgd_interface_base* iface,
      const v2::ml_data& data,
      double step_size) = 0;

  /** Called to calculate the current objective value for the data.
   * Defaults to calling calculate_loss() + current_regularizer_value()
   * on the current interface to get the regularization term; however,
   * it can be overridden if need be. (For example, for optimizing
   * ranking functions, the loss function doesn't fit into the
   * standard framework laid out by the model's calculate_fx
   * function.)
   *
   * \return (objective value, reportable training loss)
   */
  virtual std::pair<double, double> calculate_objective(
      sgd_interface_base* iface, const v2::ml_data& data, size_t iteration) const;

 private:

  ////////////////////////////////////////////////////////////////////////////////
  //
  // SGD helper functions.
  //
  // The functions below implement specific parts of the sgd solver.

  /** Runs the sgd algorithm with step size tuning to find the best
   * value. Returns the best value.
   *
   * If the sgd_step_size parameter is set to zero, the run() method
   * above will run this to determine the best step size.
   *
   * The tuning is done by first running the model on a small subset
   * of the data with several different step sizes. The best step
   * size is chosen as the step size for running the full algorithm.
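   *
   * Conceptually, the tuning loop looks something like the sketch
   * below; the candidate list and the try_step_size() helper are
   * hypothetical:
   * \code
   * double best_step_size = 0;
   * double best_objective = std::numeric_limits<double>::max();
   * for (double ss : {1.0, 0.5, 0.25, 0.125}) {
   *   double obj = try_step_size(ss);  // a few passes on a small subset
   *   if (obj < best_objective) {
   *     best_objective = obj;
   *     best_step_size = ss;
   *   }
   * }
   * return best_step_size;
   * \endcode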
   */
  double compute_initial_sgd_step_size();

  /** Gets the initial objective value (objective, reportable_training_loss)
   * for the problem. Used to tune the sgd step size.
   */
  std::pair<double, double> get_initial_objective_value(const v2::ml_data& data) const;

  /** Calculates a reasonable step size for the current sample. We
   * return the smaller of two step sizes. The first is
   * gamma / (1 + n * lm * gamma) -- the step size dictated in
   *
   * Léon Bottou: Stochastic Gradient Tricks, Neural Networks, Tricks
   * of the Trade, Reloaded, 430–445, Edited by Grégoire Montavon,
   * Genevieve B. Orr and Klaus-Robert Müller, Lecture Notes in
   * Computer Science (LNCS 7700), Springer, 2012.
   *
   * The second step size is gamma / (1 + iteration)^step_size_decrease_rate.
   * This is the more standard sgd step size that works with
   * non-regularized values.
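   *
   * In code, the rule above amounts to the following, with gamma the
   * initial step size, lm the l2 regularization strength, and n the
   * iteration number:
   * \code
   * double s1 = gamma / (1.0 + n * lm * gamma);
   * double s2 = gamma / std::pow(1.0 + n, step_size_decrease_rate);
   * double step_size = std::min(s1, s2);
   * \endcode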
   */
  double calculate_iteration_sgd_step_size(
      size_t iteration,
      double initial_sgd_step_size, double step_size_decrease_rate,
      double l2_regularization);

  /** Tests whether a model has converged or not by looking at changes
   * in the last few iterations of the data. Values are passed in with
   * a map of options.
   *
   * The technique looks at the max, min, and mean of the loss function
   * in the last sgd_convergence_interval iterations. If the
   * difference of the max and the mean, divided by std::max(1, mean),
   * is less than sgd_convergence_threshhold, then we assume the model
   * is converged.
   *
   * Setting sgd_convergence_interval to 0 or
   * sgd_convergence_threshhold to 0 disables this test, forcing the
   * algorithm to run for the full max_iterations.
   */
  bool sgd_test_model_convergence(const std::vector<double>& sgd_path);

  /** Adjusts the step size dynamically based on whether things are
   * converging or not. Returns the new step size.
   *
   * If the loss value has gone steadily down in all of the last
   * sgd_convergence_interval iterations, then the step size is
   * increased such that if this happens over a full
   * sgd_convergence_interval iterations, then the step size is
   * doubled. If a one-sided t-test on the differences between all the
   * previous loss values does not show that it is decreasing with
   * confidence > 95%, then the step size is decreased by the same
   * amount.
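   *
   * The per-iteration multiplier implied by the doubling rule above
   * would look something like this sketch:
   * \code
   * // After sgd_convergence_interval consecutive improving
   * // iterations, the cumulative effect is a factor of 2:
   * double factor = std::pow(2.0, 1.0 / sgd_convergence_interval);
   * sgd_step_size *= loss_steadily_decreasing ? factor : 1.0 / factor;
   * \endcode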
   */
  double sgd_adjust_step_size(const std::vector<double>& sgd_path, double sgd_step_size);

};

}}

#endif /* TURI_SGD_SGD_SOLVER_CLASS_H_ */