Turi Create 4.0
accelerated_gradient-inl.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_ACCELERATED_GRADIENT_H_
#define TURI_ACCELERATED_GRADIENT_H_

// Types
#include <core/data/flexible_type/flexible_type.hpp>
#include <Eigen/Core>

// Optimization
#include <ml/optimization/utils.hpp>
#include <ml/optimization/optimization_interface.hpp>
#include <ml/optimization/regularizer_interface.hpp>
#include <ml/optimization/line_search-inl.hpp>
#include <core/logging/table_printer/table_printer.hpp>


// TODO: List of todo's for this file
//------------------------------------------------------------------------------
// 1. FISTA's proximal abstraction for regularizers.
// 2. Perf improvement for sparse gradients.
namespace turi {

namespace optimization {

/**
 * \ingroup group_optimization
 * \addtogroup FISTA FISTA
 * \{
 */

/**
 *
 * Solve a first_order_opt_interface model with a dense accelerated
 * gradient method.
 *
 * The algorithm is FISTA with backtracking (Beck and Teboulle 2009);
 * details are on page 194 of [1].
 *
 * \param[in,out] model Model with first order optimization interface.
 * \param[in] init_point Starting point for the solver.
 * \param[in,out] opts Solver options.
 * \param[in] reg Shared ptr to an interface to a regularizer.
 * \returns stats Solver return stats.
 * \tparam Vector Sparse or dense gradient representation.
 *
 * \note FISTA is an accelerated gradient method; Nesterov's accelerated
 * gradient method is another option worth trying.
 *
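 * The per-iteration updates implemented below, with step size s and the
 * regularizer's proximal operator prox_s (cf. Eq. (4.1)-(4.3) in [1]), are:
 *
 *     x_k     = prox_s( y_k - s * \grad f(y_k) )
 *     t_{k+1} = ( 1 + sqrt(1 + 4 * t_k^2) ) / 2
 *     y_{k+1} = x_k + ((t_k - 1) / t_{k+1}) * (x_k - x_{k-1})
 *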
 * References:
 *
 * [1] Beck, Amir, and Marc Teboulle. "A fast iterative shrinkage-thresholding
 * algorithm for linear inverse problems." SIAM Journal on Imaging Sciences 2.1
 * (2009): 183-202. http://mechroom.technion.ac.il/~becka/papers/71654.pdf
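 *
 * Example (an illustrative sketch; "model" and "init_point" stand for any
 * object implementing first_order_opt_interface and a starting DenseVector,
 * and the option keys shown are the ones this solver reads):
 *
 * \code
 * std::map<std::string, flexible_type> opts = {
 *     {"max_iterations", 10},
 *     {"convergence_threshold", 1e-5},
 *     {"step_size", 1.0}};
 * solver_return stats = accelerated_gradient(model, init_point, opts);
 * \endcode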
 */
template <typename Vector = DenseVector>
inline solver_return accelerated_gradient(first_order_opt_interface& model,
    const DenseVector& init_point,
    std::map<std::string, flexible_type>& opts,
    const std::shared_ptr<regularizer_interface> reg=NULL){

  // Benchmarking utils.
  timer tmr;
  double start_time = tmr.current_time();
  logprogress_stream << "Starting Accelerated Gradient (FISTA)" << std::endl;
  logprogress_stream << "--------------------------------------------------------" << std::endl;
  std::stringstream ss;
  ss.str("");

  // First iteration will take longer. Warn the user.
  logprogress_stream << "Tuning step size. First iteration could take longer"
                     << " than subsequent iterations." << std::endl;

  // Print progress
  table_printer printer(
      model.get_status_header({"Iteration", "Passes", "Step size", "Elapsed Time"}));
  printer.print_header();

  // Step 1: Algorithm option init
  // ------------------------------------------------------------------------

  size_t iter_limit = opts["max_iterations"];
  double convergence_threshold = opts["convergence_threshold"];
  size_t iters = 0;
  double step_size = opts["step_size"];
  solver_return stats;

  // Store previous point and gradient information
  DenseVector point = init_point;          // Initial point
  DenseVector delta_point = point;         // Step taken
  DenseVector y = point;                   // Momentum
  DenseVector xp = point;                  // Point in the previous iteration
  DenseVector x = point;                   // Point in the current iteration
  delta_point.setZero();

  // First compute the residual. Sometimes the starting point is already the
  // solution; in that case, there is no need to waste time performing a step
  // of the algorithm.
  Vector gradient(point.size());
  double fy;
  model.compute_first_order_statistics(y, gradient, fy);
  double residual = compute_residual(gradient);
  stats.num_passes++;

  std::vector<std::string> stat_info = {std::to_string(iters),
                                        std::to_string(stats.num_passes),
                                        std::to_string(step_size),
                                        std::to_string(tmr.current_time())};
  std::vector<std::string> row = model.get_status(point, stat_info);
  printer.print_progress_row_strs(iters, row);

  // Value of the parameter t in iterations k-1 and k
  double t = 1;                            // t_k
  double tp = 1;                           // t_{k-1}
  double fply, Qply;

  // NaN checking!
  if (std::isnan(residual) || std::isinf(residual)){
    stats.status = OPTIMIZATION_STATUS::OPT_NUMERIC_OVERFLOW;
  }

  // Step 2: Algorithm starts here
  // ------------------------------------------------------------------------
  // While not converged
  while((residual >= convergence_threshold) && (iters < iter_limit)){

    // Auto-tune the step size.
    while (step_size > LS_ZERO){

      // FISTA with backtracking
      // Equation 4: Page 194 of [1]

      // Test point
      // point = prox(y - \grad_f(y) * s)   (where s is the step size)
      point = y - gradient * step_size;
      if(reg != NULL)
        reg->apply_proximal_operator(point, step_size);
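
      // For instance, with an L1 regularizer g(x) = lambda * ||x||_1, the
      // proximal operator applied here is elementwise soft-thresholding:
      //   prox_s(v)_i = sign(v_i) * max(|v_i| - s * lambda, 0).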

      // Compute
      // f(point)
      fply = model.compute_function_value(point);
      stats.func_evals++;

      // Compute the quadratic model value
      // Q(point, y) = f(y) + delta_point^T \grad_f(y)
      //                    + 0.5 * |delta_point|^2 / s
      delta_point = (point - y);
      Qply = fy + delta_point.dot(gradient) + 0.5
                      * delta_point.squaredNorm() / step_size;
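
      // Accept the trial point once f(point) <= Q(point, y), i.e. once the
      // quadratic model Q upper-bounds f at the trial point (the
      // backtracking condition of [1]).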

      if (fply < Qply){
        break;
      }

      // Reduce step size until a sufficient decrease is satisfied.
      step_size /= 1.5;

    }

    // FISTA Iteration
    // Equation 4: Page 193 of [1]
    x = point;
    t = (1 + sqrt(1 + 4*tp*tp))/2;
    y = x + (tp - 1)/t * (x - xp);
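    // Here tp plays the role of t_k and t of t_{k+1}; the momentum weight
    // (t_k - 1)/t_{k+1} grows toward 1, which is what accelerates the plain
    // proximal gradient step.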

    delta_point = x - xp;
    xp = x;
    tp = t;

    // Numerical error: Insufficient progress.
    if (delta_point.norm() <= OPTIMIZATION_ZERO){
      stats.status = OPTIMIZATION_STATUS::OPT_NUMERIC_UNDERFLOW;
      break;
    }
    // Numerical error: Numerical overflow. (Step size was too large.)
    if (!delta_point.array().isFinite().all()) {
      stats.status = OPTIMIZATION_STATUS::OPT_NUMERIC_OVERFLOW;
      break;
    }

    // Compute the residual norm (to check for convergence).
    model.compute_first_order_statistics(y, gradient, fy);
    stats.num_passes++;
    // The convergence criterion stops when no progress is being made: the
    // residual is computed from delta_point rather than the gradient.
    residual = compute_residual(delta_point);
    iters++;

    // Check for NaN/Inf in the function value.
    if(std::isinf(fy) || std::isnan(fy)) {
      stats.status = OPTIMIZATION_STATUS::OPT_NUMERIC_OVERFLOW;
      break;
    }

    // Print progress
    stat_info = {std::to_string(iters),
                 std::to_string(stats.num_passes),
                 std::to_string(step_size),
                 std::to_string(tmr.current_time())};
    row = model.get_status(point, stat_info);
    printer.print_progress_row_strs(iters, row);

    // Log info for debugging.
    logstream(LOG_INFO) << "Iters (" << iters << ") "
                        << "Passes (" << stats.num_passes << ") "
                        << "Residual (" << residual << ") "
                        << "Loss (" << fy << ") "
                        << std::endl;
  }
  printer.print_footer();

  // Step 3: Return optimization model status.
  // ------------------------------------------------------------------------
  if (stats.status == OPTIMIZATION_STATUS::OPT_UNSET) {
    if (iters < iter_limit){
      stats.status = OPTIMIZATION_STATUS::OPT_OPTIMAL;
    } else {
      stats.status = OPTIMIZATION_STATUS::OPT_ITERATION_LIMIT;
    }
  }
  stats.iters = static_cast<int>(iters);
  stats.residual = residual;
  stats.func_value = fy;
  stats.gradient = gradient;
  stats.solve_time = tmr.current_time() - start_time;
  stats.solution = point;
  stats.progress_table = printer.get_tracked_table();

  // Display solver stats
  log_solver_summary_stats(stats);

  return stats;
}

/// \}

} // namespace optimization

} // namespace turi

#endif