Turi Create  4.0
linear_regression.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_REGR_LINEAR_REGRESSION_H_
7 #define TURI_REGR_LINEAR_REGRESSION_H_
8 
9 // ML-Data Utils
10 #include <ml/ml_data/metadata.hpp>
11 
12 
13 // Toolkits
14 #include <toolkits/supervised_learning/supervised_learning.hpp>
15 #include <toolkits/coreml_export/mlmodel_wrapper.hpp>
16 
17 // Optimization Interface
18 #include <ml/optimization/optimization_interface.hpp>
19 
20 #include <core/export.hpp>
21 
22 namespace turi {
23 namespace supervised {
24 
25 class linear_regression_opt_interface;
26 
27 /*
28  * Linear Regression Model
29  * ****************************************************************************
30  */
31 
32 
33 /**
34  * Linear regression model class definition.
35  *
36  */
38 
39 
protected:

/**
 * Optimization interface object for this model. linear_regression_opt_interface
 * is only forward-declared in this header; presumably it is created during
 * training (NOTE(review): confirm in the implementation file).
 */
std::shared_ptr<linear_regression_opt_interface> lr_interface;

public:

/**
 * Version number of this model's serialized format (presumably what
 * get_version() reports and load_version() receives -- confirm in the .cpp).
 */
static constexpr size_t LINEAR_REGRESSION_MODEL_VERSION = 4;
Eigen::Matrix<double, Eigen::Dynamic,1> coefs; /**< Model coefficients; copied out by get_coefficients(). */
/**
 * Standard errors of the coefficients (assumed one entry per element of
 * coefs -- TODO confirm against the training code).
 */
Eigen::Matrix<double, Eigen::Dynamic,1> std_err;
49 
50  /**
51  * Destructor.
52  */
54 
55 
/**
 * Perform the initialization that is specific to linear regression.
 *
 * \param[in] data       ML-Data object created by the init function.
 * \param[in] valid_data ML-Data object holding validation data (inferred from
 *                       the name; presumably empty when no validation set is
 *                       supplied -- confirm against the base class).
 */
void model_specific_init(const ml_data& data, const ml_data& valid_data) override;
63 
/**
 * Initialize the model's options.
 *
 * \param[in] _options Map of option values to set, keyed by string.
 */
void init_options(const std::map<std::string,flexible_type>& _options) override;
70 
/**
 * Get the model version number.
 *
 * \returns The version number of this model (presumably
 *          LINEAR_REGRESSION_MODEL_VERSION -- confirm in the .cpp).
 */
size_t get_version() const override;
75 
76  bool is_classifier() const override { return false; }
77 
/**
 * Train the linear regression model.
 */
void train() override;
82 
/**
 * Setter for the model coefficients.
 *
 * \param[in] _coefs Coefficient values to install in the model.
 */
void set_coefs(const DenseVector& _coefs) override;
87 
/**
 * Serialize the model.
 *
 * \param[in,out] oarc Output archive the model is written to.
 */
void save_impl(turi::oarchive& oarc) const override;
92 
/**
 * Load (deserialize) the model.
 *
 * \param[in,out] iarc    Input archive the model is read from.
 * \param[in]     version Version number of the serialized model (presumably
 *                        the value written alongside save_impl output --
 *                        confirm in the base class).
 */
void load_version(turi::iarchive& iarc, size_t version) override;
97 
/**
 * Predict for a single example given as a dense vector.
 *
 * \param[in] x           Single example, as a dense vector.
 * \param[in] output_type Type of prediction (defaults to
 *                        prediction_type_enum::NA).
 *
 * \returns Prediction for the example.
 */
flexible_type predict_single_example(
    const DenseVector& x,
    const prediction_type_enum& output_type=prediction_type_enum::NA) override;
110 
/**
 * Predict for a single example given as a sparse vector.
 *
 * \param[in] x           Single example, as a sparse vector.
 * \param[in] output_type Type of prediction (defaults to
 *                        prediction_type_enum::NA).
 *
 * \returns Prediction for the example.
 */
flexible_type predict_single_example(
    const SparseVector& x,
    const prediction_type_enum& output_type=prediction_type_enum::NA) override;
123 
124  /**
125  * Get coefficients for a trained model.
126  */
127  void get_coefficients(DenseVector& _coefs) const{
128  _coefs.resize(coefs.size());
129  _coefs = coefs;
130  }
131 
/**
 * Export this model to Core ML.
 *
 * \returns Shared pointer to a coreml::MLModelWrapper for the exported model.
 */
std::shared_ptr<coreml::MLModelWrapper> export_to_coreml() override;
133 
134  BEGIN_CLASS_MEMBER_REGISTRATION("regression_linear_regression");
137 };
138 
139 } // supervised
140 } // turi
141 
142 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80
void get_coefficients(DenseVector &_coefs) const
Eigen::Matrix< double, Eigen::Dynamic, 1 > coefs