Turi Create  4.0
linear_svm.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_CLASS_LINEAR_SVM_H_
7 #define TURI_CLASS_LINEAR_SVM_H_
8 
9 // ML-Data Utils
10 #include <ml/ml_data/ml_data.hpp>
11 #include <core/data/sframe/gl_sframe.hpp>
12 
13 // Toolkits
14 #include <toolkits/supervised_learning/supervised_learning.hpp>
15 #include <toolkits/coreml_export/mlmodel_wrapper.hpp>
16 
17 // Optimization Interface
18 #include <ml/optimization/optimization_interface.hpp>
19 
20 #include <core/export.hpp>
21 
22 namespace turi {
23 namespace supervised {
24 
// Forward declaration: in this header the type is only referenced through a
// std::shared_ptr member of linear_svm, so its full definition is not needed here.
class linear_svm_scaled_logistic_opt_interface;
26 
27 /*
28  * SVM Model
29  * ****************************************************************************
30  */
31 
/**
 * Linear SVM classifier model (registered as "classifier_svm").
 *
 */
37 
protected:

// Primal solution: the trained coefficient vector (dense column vector).
Eigen::Matrix<double, Eigen::Dynamic,1> coefs; /**< Primal sol */
// Optimization interface for the scaled-logistic SVM objective.
// NOTE(review): presumably created in model_specific_init()/train() — TODO confirm.
std::shared_ptr<linear_svm_scaled_logistic_opt_interface>
scaled_logistic_svm_interface;
43 
44  public:
45 
46  static constexpr size_t SVM_MODEL_VERSION = 5;
47  /**
48  * Destructor. Make sure bad things don't happen
49  */
50  virtual ~linear_svm();
51 
52  /**
53  * Set the default evaluation metric during model evaluation..
54  */
56  set_evaluation_metric({
57  "accuracy",
58  "confusion_matrix",
59  "f1_score",
60  "precision",
61  "recall",
62  });
63  }
64 
65  /**
66  * Set the default evaluation metric for progress tracking.
67  */
68  void set_default_tracking_metric() override {
69  set_tracking_metric({
70  "accuracy",
71  });
72  }
73 
/**
 * Model-specific initialization, run after the ml_data has been built.
 *
 * \param[in] data       Training data.
 * \param[in] valid_data Validation data.
 */
void model_specific_init(const ml_data& data,
const ml_data& valid_data) override;
83 
84  bool is_classifier() const override { return true; }
85 
/**
 * Train the SVM model.
 */
void train() override;
90 
/**
 * Initialize the model options.
 *
 * \param[in] _opts Options to set, as a string-keyed map of flexible_type values.
 */
void init_options(const std::map<std::string,flexible_type>& _opts) override;
97 
98 
/**
 * Get the model version number.
 * NOTE(review): presumably returns SVM_MODEL_VERSION — TODO confirm.
 */
size_t get_version() const override;
103 
/**
 * Setter for the model coefficients.
 *
 * \param[in] _coefs New coefficient vector.
 */
void set_coefs(const DenseVector& _coefs) override;
108 
/**
 * Serialize the model.
 *
 * \param[in,out] oarc Output archive to write the model to.
 */
void save_impl(turi::oarchive& oarc) const override;
113 
/**
 * Deserialize the model.
 *
 * \param[in,out] iarc    Input archive to read the model from.
 * \param[in]     version Version number of the serialized model being loaded.
 */
void load_version(turi::iarchive& iarc, size_t version) override;
118 
119 
/**
 * Predict for a single example given as a dense vector.
 *
 * \param[in] x           Single example (dense features).
 * \param[in] output_type Type of prediction; defaults to prediction_type_enum::NA.
 *
 * \returns Prediction for the example.
 */
flexible_type predict_single_example(
const DenseVector& x,
const prediction_type_enum& output_type=prediction_type_enum::NA) override;
132 
/**
 * Predict for a single example given as a sparse vector.
 *
 * \param[in] x           Single example (sparse features).
 * \param[in] output_type Type of prediction; defaults to prediction_type_enum::NA.
 *
 * \returns Prediction for the example.
 */
flexible_type predict_single_example(
const SparseVector& x,
const prediction_type_enum& output_type=prediction_type_enum::NA) override;
145 
/**
 * Make classifications using the trained model.
 *
 * \param[in] test_data   Test data (only independent variables).
 * \param[in] output_type Type of classification (future proof); defaults to "".
 * \returns SFrame with "class" and probability (if applicable).
 *
 * \note Already assumes that data is of the right shape.
 */
sframe classify(const ml_data& test_data,
const std::string& output_type="") override;
157 
/**
 * Fast-path classification given rows of flexible_types.
 *
 * \param[in] rows                 List of rows (each row is a flex_dict).
 * \param[in] missing_value_action How to handle missing values; defaults to "error".
 * \returns gl_sframe holding the classification output.
 */
gl_sframe fast_classify(
const std::vector<flexible_type>& rows,
const std::string& missing_value_action ="error") override;
167 
168  /**
169  * Get coefficients for a trained model.
170  */
171  void get_coefficients(DenseVector& _coefs) const{
172  _coefs.resize(coefs.size());
173  _coefs = coefs;
174  }
175 
/**
 * Export the trained model to Core ML.
 *
 * \returns Shared pointer to a wrapper around the exported MLModel.
 */
std::shared_ptr<coreml::MLModelWrapper> export_to_coreml() override;
177 
// Register this class's Python-facing members under the name "classifier_svm".
// NOTE(review): original lines 179-180 (the rest of this registration block) are
// missing from this listing. The file's macro index references
// IMPORT_BASE_CLASS_REGISTRATION(base_class) and END_CLASS_MEMBER_REGISTRATION,
// which presumably belong there; restore them from the upstream header.
BEGIN_CLASS_MEMBER_REGISTRATION("classifier_svm");

};

} // namespace supervised
} // namespace turi

#endif // TURI_CLASS_LINEAR_SVM_H_
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
void set_default_evaluation_metric() override
Definition: linear_svm.hpp:55
#define END_CLASS_MEMBER_REGISTRATION
Eigen::Matrix< double, Eigen::Dynamic, 1 > coefs
Definition: linear_svm.hpp:40
void get_coefficients(DenseVector &_coefs) const
Definition: linear_svm.hpp:171
void set_default_tracking_metric() override
Definition: linear_svm.hpp:68
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80
bool is_classifier() const override
Definition: linear_svm.hpp:84