Turi Create  4.0
ranking_sgd_solver_base.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SGD_RANKING_SGD_SOLVER_BASE_CLASS_H_
7 #define TURI_SGD_RANKING_SGD_SOLVER_BASE_CLASS_H_
8 
9 #include <map>
10 #include <vector>
11 #include <type_traits>
12 #include <core/util/code_optimization.hpp>
13 #include <toolkits/ml_data_2/ml_data.hpp>
14 #include <toolkits/ml_data_2/ml_data_iterators.hpp>
15 #include <toolkits/sgd/sgd_solver_base.hpp>
16 #include <toolkits/factorization/loss_model_profiles.hpp>
17 
18 namespace turi { namespace factorization {
19 
20 
21 /** The main parts of the ranking sgd solver class. Part of the
22  * functionality is implemented in one of two subclasses,
23  * explicit_ranking_sgd_solver and implicit_ranking_sgd_solver.
24  *
25  * The ranking SGD stuff is broken into two categories; ranking
26  * regularization alongside training to predict a target, and "pure"
27  * ranking in which there is no target. Both require a substantial
28  * amount of bookkeeping to handle the negative item examples. The
29  * details of how these are applied, however, is handled with the
30  * individual subclasses.
31  */
32 template <class SGDInterface>
34  protected:
35 
36  const size_t max_n_threads;
37  double num_sampled_negative_examples;
38  size_t random_seed = 0;
39 
40  protected:
41 
42  /**
43  * Constructor
44  */
46  const std::shared_ptr<sgd::sgd_interface_base>& main_interface,
47  const v2::ml_data& train_data,
48  const std::map<std::string, flexible_type>& options)
49 
50  : sgd::sgd_solver_base(main_interface, train_data, options)
51  , max_n_threads(thread::cpu_count())
52  , num_sampled_negative_examples(options.at("num_sampled_negative_examples"))
53  , random_seed(hash64(options.at("random_seed")))
54  {
55  DASSERT_GE(num_sampled_negative_examples, 1);
56  }
57 
58  ////////////////////////////////////////////////////////////////////////////////
59  //
60  // Typedefs for the x buffer.
61  //
62  // A buffer contains all the items for a given user. This
63  // information is needed for constructing the buffer of negative
64  // examples, as items not in this buffer must be picked for this
65  // process.
66 
67  ////////////////////////////////////////////////////////////////////////////////
68  //
69  // Virtual functions needed to be subclassed.
70 
  /** The main method needed to be implemented by the subclass to
   *  run the ranking sgd stuff.
   *
   *  \param[in] iteration The current iteration index; combined with
   *  the thread index and the solver's random seed to seed the
   *  per-thread RNG (see run_iteration).
   *
   *  \param[in] thread_idx The thread index determining this block.
   *
   *  \param[in] num_threads The number of threads.
   *
   *  \param[in] block_idx The index of the data block this call
   *  should process.
   *
   *  \param[in] num_blocks The total number of blocks the training
   *  data is sliced into (run_iteration uses 16 * cpu_count()).
   *
   *  \param[in] data The v2::ml_data instance we're working with.
   *  Primarily needed for the metadata.
   *
   *  \param[in] iface The working SGD interface.
   *
   *  \param[in] step_size The current SGD step size, set by the
   *  higher level algorithm.
   *
   *  \param[in,out] error_detected If set to true, a numerical error
   *  is detected; all worker threads poll this flag and exit early.
   *
   *  \return (loss, rank_loss) -- loss is the cumulative estimated
   *  loss value for this thread on predicting the training data, and
   *  rank_loss is the weighted loss on the negative examples.
   */
  virtual std::pair<double, double> run_sgd_thread(
      size_t iteration,
      size_t thread_idx, size_t num_threads,
      size_t block_idx, size_t num_blocks,
      const v2::ml_data& data,
      SGDInterface* iface,
      double step_size,
      volatile bool& error_detected) = 0;
104 
105 
106 
  /** Calculate the loss value for the block of data assigned to a
   *  particular thread.
   *
   *  \param[in] thread_idx The thread index determining this block.
   *
   *  \param[in] num_threads The number of threads.
   *
   *  \param[in] data The ml_data instance we're working with.
   *  Primarily needed for the metadata.
   *
   *  \param[in] iface The working SGD interface.
   *
   *  \return (loss, rank_loss) -- loss is the cumulative estimated
   *  loss value for this thread on predicting the training data, and
   *  rank_loss is the weighted loss on the negative examples.
   */
  virtual std::pair<double, double> run_loss_calculation_thread(
      size_t thread_idx, size_t num_threads,
      const v2::ml_data& data,
      SGDInterface* iface) const = 0;
130 
131 
132  public:
133 
134  ////////////////////////////////////////////////////////////////////////////////
135  //
136  // The primary functions that wrap the above virtual functions.
137 
  /**
   *  Run a single SGD pass through the data.  Implementation of base
   *  sgd_solver's required virtual function.
   *
   *  \param[in] iteration The iteration index; what gets reported in
   *  the progress message, and part of each worker's RNG seed.
   *
   *  \param[in] model_iface The interface class that gives the
   *  gradient calculation routines on top of the model definition.
   *  Must be dynamically castable to SGDInterface*.
   *
   *  \param[in] data The training data for this pass.
   *
   *  \param[in] step_size The SGD step size to use this iteration.
   *
   *  \return (objective_value_estimate, reported_training_loss);
   *  both are std::numeric_limits<double>::max() if a numerical
   *  error was detected, and (0, 0) if the model is trivial.
   */
  std::pair<double, double> run_iteration(
      size_t iteration,
      sgd::sgd_interface_base* model_iface,
      const v2::ml_data& data,
      double step_size) {

    ////////////////////////////////////////////////////////////////////////////////
    // Step 1. Set up a few preliminary variables

    // NOTE(review): the cast result is not null-checked; presumably the
    // solver is always constructed with a matching interface type --
    // confirm against callers.
    SGDInterface* iface = dynamic_cast<SGDInterface*>(model_iface);

    const size_t data_size = data.size();

    ////////////////////////////////////////////////////////////////////////////////
    // Step 2. Do one parallel pass through the data

    // We accumulate the loss values from each thread here, calculated
    // from the state before the update is applied.  The total loss
    // value is used to report back an estimate of the current state.
    std::vector<double> loss_values(max_n_threads, 0);

    std::vector<double> rank_loss_values(max_n_threads, 0);

    // The status is governed by a common variable.  If any of the
    // threads detects a numerical error, this flag is set and all the
    // threads soon exit.
    volatile bool error_detected = false;

    iface->setup_iteration(iteration, step_size);

    // Slice up the initial input data, so we take it from different
    // sections each time.  Since ml_data has a block cache manager,
    // processing whole blocks keeps the access pattern cache-friendly.
    size_t num_blocks = 16*thread::cpu_count();

    std::vector<size_t> blocks_to_use(num_blocks);
    std::iota(blocks_to_use.begin(), blocks_to_use.end(), 0);
    random::shuffle(blocks_to_use);

    // Shared work queue: next block index to be claimed by a worker.
    atomic<size_t> current_block = 0;

    in_parallel([&](size_t thread_idx, size_t num_threads) {
        random::seed(hash64(thread_idx, iteration, random_seed));

        while(!error_detected) {
          // Atomically claim the next unprocessed block.
          size_t block_lookup_idx = (++current_block) - 1;

          // we're done in this case.
          if(block_lookup_idx >= num_blocks)
            break;

          double lv, rlv;

          std::tie(lv, rlv)
              = run_sgd_thread(
                  iteration,
                  thread_idx, num_threads,
                  blocks_to_use[block_lookup_idx], num_blocks,
                  data, iface, step_size, error_detected);

          loss_values[thread_idx] += lv;
          rank_loss_values[thread_idx] += rlv;
        }
      });

    ////////////////////////////////////////////////////////////////////////////////
    // Step 3. Check for errors.

    if(error_detected)
      return {std::numeric_limits<double>::max(), std::numeric_limits<double>::max()};

    ////////////////////////////////////////////////////////////////////////////////
    // Step 4. Calculate how well we've done and what the current
    // estimated value is.

    iface->finalize_iteration();

    // max(1, data_size) guards the average against empty data.
    double loss_no_regularization
        = (std::accumulate(loss_values.begin(), loss_values.end(), double(0.0))
           / std::max(size_t(1), data_size));

    double rank_loss
        = (std::accumulate(rank_loss_values.begin(), rank_loss_values.end(), double(0.0))
           / std::max(size_t(1), data_size));

    double regularization_penalty = iface->current_regularization_penalty();
    double objective_value_estimate = loss_no_regularization + rank_loss + regularization_penalty;

    // Is it a trivial model? If so, we can break early.
    if(objective_value_estimate <= 1e-16) {
      return {0, 0};
    }

    double reported_training_loss = iface->loss_model.reported_loss_value(loss_no_regularization);

    ////////////////////////////////////////////////////////////////////////////////
    // Step 5. Return the values; the higher-level solver loop handles
    // progress reporting.

    return {objective_value_estimate, reported_training_loss};
  }
260 
261  ////////////////////////////////////////////////////////////////////////////////
262 
  /** Calculate the objective value of the current state.
   *
   *  \param[in] model_iface The interface class that gives the
   *  gradient calculation routines on top of the model definition.
   *  Must be dynamically castable to SGDInterface*.
   *
   *  \param[in] data The data on which to evaluate the loss.
   *
   *  \param[in] iteration The current iteration index; only used to
   *  seed the per-thread random number generators.
   *
   *  \return (objective_value_estimate, reported_training_loss); both
   *  are std::numeric_limits<double>::max() if a numerical error was
   *  detected.
   */
  std::pair<double, double> calculate_objective(
      sgd::sgd_interface_base* model_iface,
      const v2::ml_data& data,
      size_t iteration) const GL_HOT {

    ////////////////////////////////////////////////////////////////////////////////
    // Step 1. Set up a few preliminary variables

    // NOTE(review): cast result is assumed non-null, as in run_iteration.
    SGDInterface* iface = dynamic_cast<SGDInterface*>(model_iface);

    const size_t data_size = data.size();

    ////////////////////////////////////////////////////////////////////////////////
    // Step 2. Do one parallel pass through the data, calculating the
    // loss value for each data point.

    std::vector<double> loss_values(max_n_threads, 0);
    std::vector<double> rank_loss_values(max_n_threads, 0);

    volatile bool error_detected = false;

    in_parallel([&](size_t thread_idx, size_t num_threads) {
        random::seed(hash64(thread_idx, iteration, random_seed));

        double loss = 0, rank_loss = 0;

        // If the training data has a target
        std::tie(loss, rank_loss) = run_loss_calculation_thread(
            thread_idx, num_threads, data, iface);

        // A non-finite (or sentinel max) loss signals a numerical error.
        if(!std::isfinite(loss) || loss == std::numeric_limits<double>::max()) {
          error_detected = true;
        } else {
          loss_values[thread_idx] = loss;
          rank_loss_values[thread_idx] = rank_loss;
        }
      });

    if(error_detected)
      return {std::numeric_limits<double>::max(), std::numeric_limits<double>::max()};

    ////////////////////////////////////////////////////////////////////////////////
    // Step 3. Calculate the regularization penalty and the rest of
    // the objective stuff.

    // max(1, data_size) guards the average against empty data.
    double loss_no_regularization
        = (std::accumulate(loss_values.begin(), loss_values.end(), double(0.0))
           / std::max(size_t(1), data_size));

    double rank_loss
        = (std::accumulate(rank_loss_values.begin(), rank_loss_values.end(), double(0.0))
           / std::max(size_t(1), data_size));

    double regularization_penalty = iface->current_regularization_penalty();
    double objective_value_estimate = loss_no_regularization + rank_loss + regularization_penalty;

    double reported_training_loss = iface->loss_model.reported_loss_value(loss_no_regularization);

    return {objective_value_estimate, reported_training_loss};
  }
337 
338  protected:
339 
340 
341  ////////////////////////////////////////////////////////////////////////////////
342 
343  /** Fill a buffer with (observation, target value) pairs. Because
344  * of the user-block nature of the ml_data_block_iterator, this
345  * buffer is gauranteed to hold all the items rated by a particular
346  * user. If no target_value is present, then "1" is used.
347  *
348  * \param[out] x_buffer The buffer where the (observation,
349  * target_value) pairs are stored.
350  *
351  * \param[in,out] it The current block iterator.
352  *
353  * \param[in,out] item_observed A mask giving the items observed in
354  * the data.
355  *
356  * \return (n_rows, n_rated_items). The number of rows in the
357  * buffer, and the number of unique rated items.
358  */
359  inline std::pair<size_t, size_t>
361  std::vector<std::pair<std::vector<v2::ml_data_entry>, double> >& x_buffer,
363  size_t n_items,
364  dense_bitset& item_observed) const GL_HOT_INLINE_FLATTEN {
365 
366  size_t n_rows = 0;
367  size_t n_rated_items = 0;
368 
369  do {
370  size_t index = n_rows;
371 
372  if(x_buffer.size() <= index)
373  x_buffer.resize(2*index);
374 
375  auto& x = x_buffer[index].first;
376  it.fill_observation(x);
377 
378  x_buffer[index].second = it.target_value();
379 
380  size_t item = x[1].index;
381  DASSERT_LT(item, n_items);
382 
383  bool old_bit = item_observed.set_bit_unsync(item);
384  if(!old_bit)
385  ++n_rated_items;
386 
387  ++n_rows, ++it;
388  } while(!it.done() && !it.is_start_of_new_block());
389 
390  return std::make_pair(n_rows, n_rated_items);
391  }
392 
393 
394  /********************************************************************************
395 
396  ============================================================
397  Internal bookkeeping routines.
398  ============================================================
399 
400  These routines take care of the bookkeeping surrounding the
401  tracking of items that the users have not rated.
402 
403  ********************************************************************************/
404 
405  /** A processing buffer for the choose_negative_example function so
406  * that we can avoid memory allocations.
407  */
409  std::vector<v2::ml_data_entry> candidate_x;
410  std::vector<size_t> chosen_negative_items;
411  std::vector<size_t> candidate_negative_items;
412 
413  // If a user has rated most of the items, then rejection sampling
414  // doesn't work efficiently. In this case, build a list of the
415  // available items, then simply sample the available items from
416  // that.
417  //
418  // This operation only has to be done once per user; thus we keep
419  // the result in this thread-local buffer. On future attempts, if
420  // the user_of_available_item_list matches the current user, then
421  // we skip the rejection sampling altogether and just draw from
422  // this item list.
423  size_t user_of_available_item_list;
424  std::vector<size_t> available_item_list;
425  std::vector<size_t> available_item_list_chosen_indices;
426  };
427 
428  ////////////////////////////////////////////////////////////////////////////////
429 
430  /* Chooses a negative example to complement the current example,
431  * given a mask of all the observed items. Returns true if
432  * successful and false if no viable negative candidate is found.
433  */
434  inline double choose_negative_example(
435  size_t thread_idx,
436  const v2::ml_data& data,
437  SGDInterface* iface,
438  std::vector<v2::ml_data_entry>& negative_example_x,
439  const std::vector<v2::ml_data_entry>& current_positive_example,
440  const dense_bitset& item_observed,
441  size_t n_rows, size_t n_items,
442  size_t n_rated_items,
444 
445  const size_t ITEM_COLUMN_INDEX = 1;
446 
447  ////////////////////////////////////////////////////////////////////////////////
448  // Goal: pick at most n_points with items not in the candidate
449  // training set.
450 
451  ////////////////////////////////////////////////////////////////////////////////
452  // Step 1: set up the buffers
453 
454  std::vector<size_t>& chosen_negative_items = proc_buf.chosen_negative_items;
455  chosen_negative_items.resize(num_sampled_negative_examples);
456 
457  std::vector<size_t>& candidate_negative_items = proc_buf.candidate_negative_items;
458  candidate_negative_items.resize(num_sampled_negative_examples);
459 
460  size_t n_points_picked = 0;
461  bool remove_from_available_item_list = false;
462 
463  ////////////////////////////////////////////////////////////////////////////////
464  // Step 1: Pick num_sampled_negative_examples from samples that
465  // are not ones the user rated. For efficiency, we need to handle
466  // two cases -- when the user has rated few items and when the
467  // user has rated many of the items.
468 
469  ////////////////////////////////////////
470  // Case 1: Fewer rated items
471  //
472  // If at least 1 / 8 of the items are free, then don't bother to
473  // build the list of available items and just do rejection
474  // sampling. Otherwise, we should build the list of available
475  // items and sample from that.
476 
477  if(8 * (n_items - n_rated_items) > n_items) {
478  while(n_points_picked < num_sampled_negative_examples) {
479  // Get num_sampled_negative_examples candidate points.
480 
481  for(size_t i = 0; i < num_sampled_negative_examples; ++i) {
482  size_t candidate_item = random::fast_uniform<size_t>(0, n_items - 1);
483  item_observed.prefetch(candidate_item);
484  candidate_negative_items[i] = candidate_item;
485  }
486 
487  // Move unobserved items over to the chosen points.
488  for(size_t i = 0;
489  i < num_sampled_negative_examples && n_points_picked < num_sampled_negative_examples;
490  ++i) {
491 
492  size_t candidate_item = candidate_negative_items[i];
493  if(!item_observed.get(candidate_item))
494  chosen_negative_items[n_points_picked++] = candidate_item;
495  }
496  }
497 
498  } else {
499 
500  ////////////////////////////////////////
501  // Case 2: Many rated items
502  //
503  // If the user has rated at least 7 / 8 of the items, then build a
504  // list of the free items and randomly sample from those. This
505  // list is saved in the buffer for all subsequent rounds of
506  // choosing negative example items for this user.
507 
508  size_t user = current_positive_example[0].index;
509 
510  ////////////////////////////////////////
511  // Step 2.1: Build the list of items, if necessary.
512 
513  if(proc_buf.user_of_available_item_list != user
514  || proc_buf.available_item_list.empty() ) {
515 
516  // The list has not been built yet; need to construct it.
517  DASSERT_LT(n_rated_items, n_items);
518 
519  proc_buf.available_item_list.resize(n_items - n_rated_items);
520 
521  size_t current_position = 0;
522 
523  // Do one round of first_zero_bit to get the location of the
524  // first zero bit. Then use next_zero_bit to walk the rest.
525 #ifndef NDEBUG
526  bool has_free_bit =
527 #endif
528  item_observed.first_zero_bit(current_position);
529 
530  DASSERT_TRUE(has_free_bit);
531  proc_buf.available_item_list[0] = current_position;
532 
533  size_t index_count = 1;
534 
535  while(true) {
536  bool found_index = item_observed.next_zero_bit(current_position);
537 
538  if(found_index) {
539  proc_buf.available_item_list[index_count++] = current_position;
540  } else {
541  break;
542  }
543  }
544 
545  DASSERT_EQ(index_count, n_items - n_rated_items);
546 
547  // This flags it as available for use on the next sampling
548  proc_buf.user_of_available_item_list = user;
549  proc_buf.available_item_list.resize(index_count);
550  }
551 
552  ////////////////////////////////////////
553  // Step 2.2: Sample randomly from the free items.
554 
555  proc_buf.available_item_list_chosen_indices.resize(num_sampled_negative_examples);
556  for(size_t i = 0; i < num_sampled_negative_examples; ++i) {
557  size_t idx = random::fast_uniform<size_t>(0, proc_buf.available_item_list.size()-1);
558  chosen_negative_items[i] = proc_buf.available_item_list[idx];
559  proc_buf.available_item_list_chosen_indices[i] = idx;
560  DASSERT_FALSE(item_observed.get(chosen_negative_items[i]));
561  }
562 
563  remove_from_available_item_list = true;
564  n_points_picked = num_sampled_negative_examples;
565  }
566 
567  ////////////////////////////////////////////////////////////////////////////////
568  // Step 3. Check that all the examples chosen are negative ones.
569 
570 #ifndef NDEBUG
571  {
572  DASSERT_EQ(n_points_picked, num_sampled_negative_examples);
573 
574  for(size_t i = 0; i < n_points_picked; ++i) {
575  DASSERT_FALSE(item_observed.get(chosen_negative_items[i]));
576  }
577  }
578 #endif
579 
580  ////////////////////////////////////////////////////////////////////////////////
581  // Step 4: Score all the items; keep track of the highest scoring
582  // one.
583 
584  std::vector<v2::ml_data_entry>& candidate_x = proc_buf.candidate_x;
585  candidate_x = current_positive_example;
586  size_t trim_size = candidate_x.size();
587 
588  if(data.has_side_features()) {
589  // Strip out the side features associated with the item row;
590  // they will get added in later.
591 
592  size_t lb, ub;
593  std::tie(lb, ub) =
594  data.get_side_features()->column_indices_of_side_information_block(ITEM_COLUMN_INDEX);
595 
596  auto new_end = std::remove_if(candidate_x.begin(), candidate_x.end(),
597  [&](const v2::ml_data_entry& v) {
598  return (lb <= v.column_index) && (v.column_index < ub);
599  });
600 
601  trim_size = new_end - candidate_x.begin();
602  }
603 
604  double highest_fx = std::numeric_limits<double>::lowest();
605  size_t chosen_idx = 0;
606 
607  // Build the item segments.
608  for(size_t i = 0; i < n_points_picked; ++i) {
609  candidate_x[1].index = chosen_negative_items[i];
610 
611  // Add in the side information as needed
612  if(data.has_side_features()) {
613  candidate_x.resize(trim_size);
614  data.get_side_features()->add_partial_side_features_to_row(candidate_x, ITEM_COLUMN_INDEX);
615  }
616 
617  double fx_hat = iface->calculate_fx(thread_idx, candidate_x);
618 
619  // We hit a numerical error. Baaaaad.
620  if(!std::isfinite(fx_hat))
621  return NAN;
622 
623  if(fx_hat > highest_fx) {
624  highest_fx = fx_hat;
625  negative_example_x = candidate_x;
626  chosen_idx = i;
627  }
628  }
629 
630  // We hit a numerical error. Baaaaad.
631  if(highest_fx == std::numeric_limits<double>::lowest())
632  return NAN;
633 
634  if(remove_from_available_item_list) {
635  DASSERT_LT(chosen_idx, proc_buf.available_item_list_chosen_indices.size());
636  size_t remove_idx = proc_buf.available_item_list_chosen_indices[chosen_idx];
637  DASSERT_LT(remove_idx, proc_buf.available_item_list.size());
638  std::swap(proc_buf.available_item_list[remove_idx], proc_buf.available_item_list.back());
639  proc_buf.available_item_list.pop_back();
640  }
641 
642  ////////////////////////////////////////////////////////////////////////////////
643  // Step 5: Return the value. Means we are done!
644 
645  return highest_fx;
646  }
647 
648  /** Clear out the item_observed buffer.
649  *
650  * Based on the number of items actually used, deletes stuff.
651  * Defined below.
652  *
653  */
654  template <typename BufferIndexToItemIndexMapper>
657  dense_bitset& item_observed, size_t n_rows, size_t n_items,
658  const BufferIndexToItemIndexMapper& map_index) const {
659 
660  // If the number of on items means that less than 1/8 of the
661  // bit-words are going to be touched, then just clear those
662  // selectively. Thus we optimize it for sparse arrays. Otherwise,
663  // it is faster to just memset the entire array.
664 
665  if(n_rows < n_items / ( 8*sizeof(size_t) * 8) ) {
666  for(size_t i = 0; i < n_rows; ++i) {
667  size_t index = map_index(i);
668  item_observed.clear_word_unsync(index);
669  }
670  } else {
671  item_observed.clear();
672  }
673 
674  DASSERT_TRUE(item_observed.empty());
675  }
676 };
677 
678 
679 }}
680 
681 #endif
ranking_sgd_solver_base(const std::shared_ptr< sgd::sgd_interface_base > &main_interface, const v2::ml_data &train_data, const std::map< std::string, flexible_type > &options)
bool next_zero_bit(size_t &b) const
GL_HOT_INLINE_FLATTEN void clear_item_observed_buffer(dense_bitset &item_observed, size_t n_rows, size_t n_items, const BufferIndexToItemIndexMapper &map_index) const
std::pair< double, double > calculate_objective(sgd::sgd_interface_base *model_iface, const v2::ml_data &data, size_t iteration) const GL_HOT
std::pair< double, double > run_iteration(size_t iteration, sgd::sgd_interface_base *model_iface, const v2::ml_data &data, double step_size)
void prefetch(size_t b) const
Prefetches the word containing the bit b.
void clear_word_unsync(size_t b)
bool get(size_t b) const
Returns the value of the bit b.
void shuffle(std::vector< T > &vec)
Definition: random.hpp:536
static size_t cpu_count()
bool first_zero_bit(size_t &b) const
static uint64_t hash64(const char *s, size_t len)
#define DASSERT_FALSE(cond)
Definition: assertions.hpp:365
void clear()
Sets all bits to 0.
#define GL_HOT_INLINE_FLATTEN
std::pair< size_t, size_t > fill_x_buffer_with_users_items(std::vector< std::pair< std::vector< v2::ml_data_entry >, double > > &x_buffer, v2::ml_data_block_iterator &it, size_t n_items, dense_bitset &item_observed) const GL_HOT_INLINE_FLATTEN
sgd_solver_base(const std::shared_ptr< sgd_interface_base > &model_interface, const v2::ml_data &_train_data, const std::map< std::string, flexible_type > &_options)
void in_parallel(const std::function< void(size_t thread_id, size_t num_threads)> &fn)
Definition: lambda_omp.hpp:35
const std::map< std::string, flexible_type > options
virtual std::pair< double, double > run_loss_calculation_thread(size_t thread_idx, size_t num_threads, const v2::ml_data &data, SGDInterface *iface) const =0
#define DASSERT_TRUE(cond)
Definition: assertions.hpp:364
virtual std::pair< double, double > run_sgd_thread(size_t iteration, size_t thread_idx, size_t num_threads, size_t block_idx, size_t num_blocks, const v2::ml_data &data, SGDInterface *iface, double step_size, volatile bool &error_detected)=0