Turi Create 4.0
basic_sgd_solver.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_SGD_BASIC_SGD_SOLVER_CLASS_H_
#define TURI_SGD_BASIC_SGD_SOLVER_CLASS_H_

#include <algorithm>
#include <cmath>
#include <map>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include <toolkits/ml_data_2/ml_data.hpp>
#include <toolkits/ml_data_2/ml_data_iterators.hpp>
#include <toolkits/sgd/sgd_solver_base.hpp>
#include <toolkits/sgd/sgd_interface.hpp>

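// Some platform headers (notably Windows' COM headers) define `interface`
// as a macro; undefine it so the constructor parameter below can use the name.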
#ifdef interface
#undef interface
#endif

namespace turi {

class option_manager;

namespace sgd {

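/** The basic SGD solver: each iteration makes one parallel, randomized
 *  pass through the training data, with threads claiming data blocks
 *  dynamically and shuffling samples locally before applying SGD steps.
 *
 *  A minimal usage sketch (hypothetical: `my_interface` stands in for a
 *  concrete SGDInterface implementation; construction details will vary):
 *
 *    auto iface = std::make_shared<my_interface>();
 *    basic_sgd_solver<my_interface> solver(iface, train_data, options);
 */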
template <class SGDInterface>
class basic_sgd_solver : public sgd_solver_base {
 private:
  const size_t max_num_threads = 0;

  // Per-thread buffers of (observation, target) pairs, kept as members so
  // they can be reused across iterations without reallocation.
  std::vector<std::vector<std::pair<std::vector<v2::ml_data_entry>, double> > > x_buffers;

 public:

  /** Constructor.
   *
   *  Constructs the solver from the model interface, the training data,
   *  and the option map, and allocates one sample buffer per thread.
   */
  basic_sgd_solver(const std::shared_ptr<sgd_interface_base>& interface,
                   const v2::ml_data& train_data,
                   const std::map<std::string, flexible_type>& options)
      : sgd_solver_base(interface, train_data, options)
      , max_num_threads(thread::cpu_count())
  {
    x_buffers.resize(max_num_threads);
  }

  /** Call this function to insert the option definitions needed for the
   *  basic sgd solver into an option manager.
   */
  static void add_options(option_manager& options) {
    // No additional options beyond those in the base sgd solver class.
  }

  /** Run a single SGD pass through the data.
   *
   *  Returns a pair (objective value estimate, reported training loss);
   *  returns (NAN, NAN) if a non-finite loss value was encountered.
   */
  std::pair<double, double> run_iteration(
      size_t iteration,
      sgd_interface_base* iface_base,
      const v2::ml_data& data,
      double step_size) GL_HOT {

    ////////////////////////////////////////////////////////////////////////////////
    // Set up a few preliminary variables.

    SGDInterface* iface = dynamic_cast<SGDInterface*>(iface_base);

    // Shared abort flag: set by any thread that encounters a non-finite
    // loss value, and checked by all threads.
    volatile bool error_detected = false;

    // One loss accumulator per thread, combined after the parallel pass.
    std::vector<double> loss_values(max_num_threads, 0);

    iface->setup_iteration(iteration, step_size);

    ////////////////////////////////////////////////////////////////////////////////
    // Do one parallel pass through the data, randomizing as much as possible.

    // Slice up the input data into blocks so that we take it from
    // different sections each time; since ml_data manages a block cache,
    // reading whole blocks stays cheap while the visit order is shuffled.
    size_t num_blocks = 16 * max_num_threads;

    std::vector<size_t> blocks_to_use(num_blocks);
    std::iota(blocks_to_use.begin(), blocks_to_use.end(), 0);
    random::shuffle(blocks_to_use);

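    // Threads claim the next unprocessed block via an atomic counter, so
    // work is distributed dynamically rather than partitioned up front.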
    atomic<size_t> current_block = 0;

    size_t block_size = options.at("sgd_sampling_block_size");

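    // Each thread loops: claim a block, read up to block_size observations
    // into its buffer, shuffle the buffer, then apply one SGD step per
    // observation.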
    in_parallel([&](size_t thread_idx, size_t num_threads) GL_GCC_ONLY(GL_HOT_FLATTEN) {

        // Get the buffers.
        std::vector<std::pair<std::vector<v2::ml_data_entry>, double> >& x_buffer = x_buffers[thread_idx];
        x_buffer.resize(block_size);

        loss_values[thread_idx] = 0;

        while(!error_detected) {
          // Atomically claim the next block index; we're done once all
          // blocks have been claimed.
          size_t block_lookup_idx = (++current_block) - 1;

          if(block_lookup_idx >= num_blocks)
            break;

          size_t block = blocks_to_use[block_lookup_idx];

          // Get the iterator for this block.
          auto it = data.get_iterator(block, num_blocks);

          while(!error_detected && !it.done()) {

            // Fill up the buffer with up to block_size observations.
            size_t idx = 0;

            for(idx = 0; idx < block_size && !it.done(); ++idx, ++it) {
              it.fill_observation(x_buffer[idx].first);
              x_buffer[idx].second = it.target_value();
            }

            size_t n = idx;

            // Shuffle the filled portion so samples are visited in random
            // order even though they were read sequentially.
            random::shuffle(x_buffer.begin(), x_buffer.begin() + n);

            for(size_t i = 0; i < n && !error_detected; ++i) {
              std::vector<v2::ml_data_entry>& x = x_buffer[i].first;
              double y = x_buffer[i].second;

              // Do a gradient step. The returned loss value is the one at
              // the current point, before the sgd step is performed.
              double current_loss_value = iface->apply_sgd_step(thread_idx, x, y, step_size);

              loss_values[thread_idx] += current_loss_value;

              if(!std::isfinite(loss_values[thread_idx])) {
                logstream(LOG_INFO) << "SGD: Non-finite loss value in thread "
                                    << thread_idx << std::endl;
                error_detected = true;
              }
            }
          }
        }
      });

    // If any thread hit a non-finite loss, report failure for this pass.
    if(error_detected)
      return {NAN, NAN};

    // Finalize the iteration.
    iface->finalize_iteration();

    // Average the per-thread loss totals over the number of observations;
    // the max(1, n) guard avoids dividing by zero on empty data.
    double loss_no_regularization
        = (std::accumulate(loss_values.begin(), loss_values.end(), double(0.0))
           / std::max(size_t(1), data.size()));

    double regularization_penalty = iface->current_regularization_penalty();

    double objective_value_estimate = loss_no_regularization + regularization_penalty;

    // Is it a trivial model? If the objective is already at numerical
    // zero, we can stop early.
    if(objective_value_estimate <= 1e-16)
      return {0, 0};

    double reported_training_loss = iface->reported_loss_value(loss_no_regularization);

    return {objective_value_estimate, reported_training_loss};
  }

};

}  // namespace sgd
}  // namespace turi

#endif  // TURI_SGD_BASIC_SGD_SOLVER_CLASS_H_