// Basic (vanilla) SGD solver: runs stochastic gradient descent passes over
// an ml_data set through a templated SGDInterface.  NOTE(review): the bare
// numeric tokens fused into these lines ("6", "7", "25", ...) are line
// numbers from the original file, left over from the extraction that
// produced this copy; gaps in them mark lines missing from this extract.
6 #ifndef TURI_SGD_BASIC_SGD_SOLVER_CLASS_H_ 7 #define TURI_SGD_BASIC_SGD_SOLVER_CLASS_H_ 10 #include <toolkits/ml_data_2/ml_data.hpp> 11 #include <toolkits/ml_data_2/ml_data_iterators.hpp> 12 #include <toolkits/sgd/sgd_solver_base.hpp> 13 #include <toolkits/sgd/sgd_interface.hpp> 25 template <
class SGDInterface>
26 class basic_sgd_solver :
public sgd_solver_base {
// Number of worker threads; initialized once in the constructor from
// thread::cpu_count() (see the constructor's initializer list), hence const.
28 const size_t max_num_threads = 0;
// Per-thread observation buffers: x_buffers[thread_idx] holds one sampling
// block of (feature-vector, target-value) pairs, reused across iterations
// so the hot loop does not reallocate.
31 std::vector<std::vector<std::pair<std::vector<v2::ml_data_entry>,
double> > > x_buffers;
39 basic_sgd_solver(
const std::shared_ptr<sgd_interface_base>& interface,
40 const v2::ml_data& train_data,
41 const std::map<std::string, flexible_type>& options)
44 , max_num_threads(thread::cpu_count())
46 x_buffers.resize(max_num_threads);
// Register any solver-specific options with the option manager.
// NOTE(review): only this opening line (original line 52) is visible in
// this extract; the function body and its closing brace are not shown.
52 static void add_options(option_manager& options) {
/** Run one full SGD pass (iteration) over the training data.
 *
 *  \param iface_base  Interface to the model; downcast to SGDInterface.
 *  \param data        The training data to iterate over.
 *  \param step_size   The SGD step size used for this iteration.
 *
 *  \return {objective_value_estimate, reported_training_loss}.
 *
 *  NOTE(review): several original lines are missing from this extract —
 *  the jumps in the embedded original line numbers (68->73, 79->87,
 *  95->100, 124->132, 146->158, 166->169...) mark the gaps.  In
 *  particular the parameter `iteration` (original line 65), the loop
 *  variable `idx`, the buffered-row count `n`, and the parallel-launch
 *  wrapper that binds `thread_idx`/`num_threads` are all declared in
 *  lines not visible here; `thread_idx` is presumably bound by an
 *  in_parallel([&](size_t thread_idx, size_t num_threads){...}) launch
 *  (the in_parallel signature appears at the end of this extract) —
 *  confirm against the full file.
 */
64 std::pair<double, double> run_iteration(
66 sgd_interface_base* iface_base,
67 const v2::ml_data& data,
68 double step_size) GL_HOT {
// Downcast so the hot loop below calls apply_sgd_step through the
// statically known concrete interface type.
73 SGDInterface* iface =
dynamic_cast<SGDInterface*
>(iface_base);
// Shared abort flag: any thread that observes a non-finite loss sets
// this and all threads drain out of their loops.
75 volatile bool error_detected =
false;
// Per-thread loss accumulators; summed at the end of the iteration.
77 std::vector<double> loss_values(max_num_threads, 0);
79 iface->setup_iteration(iteration, step_size);
// Carve the data into more blocks than threads so threads can claim
// blocks dynamically and stay load-balanced.
87 size_t num_blocks = 16*max_num_threads;
89 std::vector<size_t> blocks_to_use(num_blocks);
90 std::iota(blocks_to_use.begin(), blocks_to_use.end(), 0);
// Shared cursor: each thread atomically claims the next unused block.
93 atomic<size_t> current_block = 0;
95 size_t block_size = options.at(
"sgd_sampling_block_size");
// ---- per-thread body begins here; the launch that binds thread_idx is
// in the lines missing between original lines 95 and 100 ----
// Reuse this thread's persistent buffer; the resize is a no-op after
// the first iteration.
100 std::vector<std::pair<std::vector<v2::ml_data_entry>,
double> >& x_buffer = x_buffers[thread_idx];
101 x_buffer.resize(block_size);
103 loss_values[thread_idx] = 0;
105 while(!error_detected) {
// Atomically claim a block index; (++x) - 1 gives the pre-increment value.
106 size_t block_lookup_idx = (++current_block) - 1;
// All blocks claimed: this thread is done.  (The break/return on the
// original lines 110-111 is not visible in this extract.)
109 if(block_lookup_idx >= num_blocks)
112 size_t block = blocks_to_use[block_lookup_idx];
// Stream the rows of this block.
115 auto it = data.get_iterator(block, num_blocks);
117 while(!error_detected && !it.done()) {
// Fill the buffer with up to block_size (features, target) rows.
// (`idx` is declared in a line not visible here.)
122 for(idx = 0; idx < block_size && !it.done(); ++idx, ++it) {
123 it.fill_observation(x_buffer[idx].first);
124 x_buffer[idx].second = it.target_value();
// Apply one SGD step per buffered observation.  (`n` — presumably the
// number of rows actually buffered, i.e. the final `idx` — is set in
// the lines missing between original lines 124 and 132.)
132 for(
size_t i = 0; i < n && !error_detected; ++i) {
133 std::vector<v2::ml_data_entry>& x = x_buffer[i].first;
134 double y = x_buffer[i].second;
136 double current_loss_value = iface->apply_sgd_step(thread_idx, x, y, step_size);
142 loss_values[thread_idx] += current_loss_value;
// A non-finite accumulated loss means the steps are diverging; log it
// and raise the shared flag so every thread stops.
144 if(!std::isfinite(loss_values[thread_idx]) ) {
145 logstream(
LOG_INFO) <<
"SGD: Non-finite loss value in thread " << thread_idx << std::endl;
146 error_detected =
true;
// (Closing braces of the loops and of the parallel section — original
// lines 147..157 — are not visible in this extract.)
158 iface->finalize_iteration();
// Mean per-row loss over all threads; max(1, size) guards the division
// when the data set is empty.
160 double loss_no_regularization
161 = (std::accumulate(loss_values.begin(), loss_values.end(), double(0.0))
162 / std::max(
size_t(1), data.size()));
164 double regularization_penalty = iface->current_regularization_penalty();
166 double objective_value_estimate = loss_no_regularization + regularization_penalty;
// Near-zero objective branch: its body (original lines 170-171) is not
// visible in this extract.
169 if(objective_value_estimate <= 1e-16)
// Let the interface translate the internal loss into the externally
// reported training-loss value.
172 double reported_training_loss = iface->reported_loss_value(loss_no_regularization);
174 return {objective_value_estimate, reported_training_loss};
// Cross-reference residue from the extraction — signatures of helpers and
// base-class members referenced by the solver above, kept for reference:
//   void shuffle(std::vector< T > &vec)
//   static void add_options(option_manager &options)
//   sgd_solver_base(const std::shared_ptr< sgd_interface_base > &model_interface, const v2::ml_data &_train_data, const std::map< std::string, flexible_type > &_options)
//   void in_parallel(const std::function< void(size_t thread_id, size_t num_threads)> &fn)