Turi Create  4.0
linear_svm_opt_interface.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_CLASS_LINEAR_SVM_OPT_INTERFACE_H_
7 #define TURI_CLASS_LINEAR_SVM_OPT_INTERFACE_H_
8 
9 // ML-Data Utils
10 #include <ml/ml_data/ml_data.hpp>
11 
12 // Toolkits
13 #include <toolkits/supervised_learning/standardization-inl.hpp>
14 #include <toolkits/supervised_learning/supervised_learning.hpp>
15 #include <toolkits/supervised_learning/linear_svm.hpp>
16 
17 // Optimization Interface
18 #include <ml/optimization/optimization_interface.hpp>
19 
20 
21 namespace turi {
22 namespace supervised {
23 
24 
25 /*
26  * SVM Solver
27  * *****************************************************************************
28  *
29  */
30 
31  /**
32  *
33  * Scaled Logistic Loss function
34  * --------------------------------
35  *
36  * The SVM is trained with L-BFGS on the modified logistic loss described in
37  * [1], a smooth loss that is much easier to optimize and closely approximates the hinge loss.
38  *
39  * References:
40  *
41  * [1] Modified Logistic Regression: An Approximation to SVM and Its
42  * Applications in Large-Scale Text Categorization. Zhang et al., ICML 2003.
43  *
44  *
45  */
48 
49  protected:
50 
51  ml_data data; /**< Training data (presumably copied from the constructor's _data) */
52  ml_data valid_data; /**< Validation data (presumably copied from _valid_data) */
53  linear_svm smodel; /**< Model held by value; presumably initialized from the constructor's model argument */
54 
55 
56  size_t features; /**< Number of features */
57  size_t examples; /**< Number of examples */
58  size_t primal_variables; /**< Number of primal variables */
59  size_t classes = 2; /**< Number of classes */
60 
61  std::map<int, float> class_weights = {{0,1.0}, {1, 1.0}}; /**< Per-class weights keyed by class index; 1.0 for classes 0 and 1 by default (see set_class_weights) */
62 
63  size_t n_threads; /**< Number of threads (see set_threads) */
64  std::shared_ptr<l2_rescaling> scaler; /**< Feature scaler (l2 rescaling) */
65  bool feature_rescaling = false; /**< Whether feature rescaling is enabled */
66  double gamma = 30; /**< Scale of the scaled logistic loss (see set_gamma) */
67  bool is_dense = false; /**< Is the data dense? */
68 
69  public:
70 
71  /**
72  * Construct the optimization interface for a linear SVM (not a default constructor).
73  * \param[in] _data ML Data containing everything (the training data)
74  * \param[in] _valid_data ML Data used for validation
75  * \param model The linear_svm model (taken by non-const reference)
76  */
78  const ml_data& _valid_data,
79  linear_svm& model);
80 
81 
82  /**
83  * Set the scale for the scaled logistic loss.
84  * \param[in] _gamma Scale (gamma) of the loss; the member defaults to 30.
85  *
86  */
87  void set_gamma(const double _gamma);
88 
89  /**
90  * Default destructor
91  */
93 
94  /**
95  * Set feature scaling
96  */
98 
99  /**
100  * Transform the final solution back to the original scale.
101  * Presumably undoes the rescaling applied through `scaler`; TODO confirm.
102  * \param[in,out] coefs Solution vector, rewritten in place
103  */
104  void rescale_solution(DenseVector& coefs);
105 
106  /**
107  * Set the number of threads
108  * (presumably stored in the n_threads member).
109  * \param[in] _n_threads Number of threads
110  */
111  void set_threads(size_t _n_threads);
112 
113  /**
114  * Get the number of (training) examples for the model
115  *
116  * \returns Number of training examples
117  */
118  size_t num_examples() const;
119 
120  /**
121  * Get the number of validation-set examples for the model
122  *
123  * \returns Number of validation examples
124  */
125  size_t num_validation_examples() const;
126 
127  /**
128  * Get the number of optimization variables (coefficients) in the model
129  *
130  * \returns Number of variables
131  */
132  size_t num_variables() const;
133 
134  /**
135  * Get strings needed to print the header for the progress table.
136  * \returns Header entries; presumably (column name, column width) pairs.
137  * \param[in] stat_names Strings to print at the beginning of the header.
138  */
139  std::vector<std::pair<std::string, size_t>>
140  get_status_header(const std::vector<std::string>& stat_names);
141 
142  /**
143  * Get strings needed to print a row of the progress table.
144  * \returns The cells of one progress-table row.
145  * \param[in] coefs Current vector of model coefficients.
146  * \param[in] stats Stats to print at the beginning of each row.
147  */
148  std::vector<std::string> get_status(const DenseVector& coefs,
149  const std::vector<std::string>& stats);
150 
151  /**
152  * Set the class weights (as a flex_dict which is already validated)
153  * Note: the parameter name shadows the class_weights member.
154  * \param[in] class_weights Validated flex_dict
155  * Key : Index of the class in the target_metadata
156  * Value : Weight applied to that class
157  */
158  void set_class_weights(const flexible_type& class_weights);
159 
160  /**
161  * Get the number of classes in the model
162  * (the classes member defaults to 2).
163  * \returns Number of classes
164  */
165  size_t num_classes() const;
166 
167  double get_validation_accuracy(); /**< Presumably accuracy of the current model on valid_data; TODO confirm */
168  double get_training_accuracy(); /**< Presumably accuracy of the current model on the training data; TODO confirm */
169 
170  /**
171  * Compute first-order statistics (gradient and function value) at the given point.
172  *
173  * \param[in] point Point at which we are computing the stats.
174  * \param[out] gradient Dense gradient
175  * \param[out] function_value Function value
176  * \param[in] mbStart Minibatch start index
177  * \param[in] mbSize Minibatch size (-1 implies all)
178  * Note: mbSize is size_t, so the default -1 converts to SIZE_MAX.
179  */
180  void compute_first_order_statistics(const DenseVector &point, DenseVector&
181  gradient, double & function_value, const size_t mbStart = 0, const size_t
182  mbSize = -1);
183 
184 
185 };
186 
187 
188 
189 } // supervised
190 } // turicreate
191 
192 #endif
std::vector< std::string > get_status(const DenseVector &coefs, const std::vector< std::string > &stats)
std::vector< std::pair< std::string, size_t > > get_status_header(const std::vector< std::string > &stat_names)
linear_svm_scaled_logistic_opt_interface(const ml_data &_data, const ml_data &_valid_data, linear_svm &model)
void set_class_weights(const flexible_type &class_weights)
void compute_first_order_statistics(const DenseVector &point, DenseVector &gradient, double &function_value, const size_t mbStart=0, const size_t mbSize=-1)