6 #ifndef TURI_REGR_LOGISTIC_REGRESSION_H_ 7 #define TURI_REGR_LOGISTIC_REGRESSION_H_ 10 #include <ml/ml_data/ml_data.hpp> 13 #include <toolkits/supervised_learning/supervised_learning.hpp> 14 #include <toolkits/coreml_export/mlmodel_wrapper.hpp> 17 #include <ml/optimization/optimization_interface.hpp> 19 #include <core/export.hpp> 22 namespace supervised {
/**
 * Forward declaration of the optimization-interface class used by the
 * logistic regression model. In the visible code it is only held through a
 * std::shared_ptr member (lr_interface), so the full definition is not
 * needed at this point.
 *
 * NOTE(review): the leading numeric token on this line ("24") looks like a
 * line number fused in by an extraction tool; it is not valid C++. Confirm
 * against the real header.
 */
24 class logistic_regression_opt_interface;
// Shared handle to the optimization interface for this model.
// NOTE(review): presumably created during initialization/training — confirm.
41 std::shared_ptr<logistic_regression_opt_interface> lr_interface;
// Model coefficients, stored as a dynamically sized Eigen column vector of doubles.
42 Eigen::Matrix<double, Eigen::Dynamic,1>
coefs;
// Standard errors, stored in the same vector type as coefs.
// NOTE(review): presumably one entry per coefficient — confirm.
43 Eigen::Matrix<double, Eigen::Dynamic,1> std_err;
// Number of classes; default-initialized to 0.
45 size_t num_classes = 0;
// Number of coefficients; default-initialized to 0.
// NOTE(review): the leading numbers (41, 42, 43, 45, 46) look like extraction
// artifacts rather than code — confirm against the real header.
46 size_t num_coefficients= 0;
/**
 * Compile-time version constant for the logistic regression model (value 6).
 *
 * NOTE(review): presumably this is the serialization/format version reported
 * by get_version(). Confirm in the implementation.
 */
49 static constexpr
size_t LOGISTIC_REGRESSION_MODEL_VERSION = 6;
63 set_evaluation_metric({
/**
 * Override of a base-class hook that receives the training data and a
 * validation data set, both as ml_data.
 *
 * @param data        Training data.
 * @param valid_data  Validation data.
 *
 * NOTE(review): the name suggests model-specific setup performed before
 * training. Confirm the exact responsibilities in the implementation.
 */
90 void model_specific_init(
const ml_data& data,
const ml_data& valid_data)
override;
/**
 * Override that receives the model options as a map from std::string keys to
 * flexible_type values.
 *
 * @param _options  Option map; keys are presumably option names — confirm.
 */
99 void init_options(
const std::map<std::string,flexible_type>& _options)
override;
/**
 * Const override that returns the model version as a size_t.
 *
 * NOTE(review): presumably returns LOGISTIC_REGRESSION_MODEL_VERSION — confirm.
 */
104 size_t get_version()
const override;
/**
 * Override of the training entry point; takes no arguments and returns
 * nothing.
 *
 * NOTE(review): presumably fits and stores the coefficients — confirm.
 */
110 void train()
override;
/**
 * Override that sets the model coefficients from a DenseVector.
 *
 * @param _coefs  New coefficient vector.
 *
 * NOTE(review): presumably overwrites the coefs member — confirm.
 */
115 void set_coefs(
const DenseVector& _coefs)
override;
137 const DenseVector& x,
138 const prediction_type_enum& output_type=prediction_type_enum::NA)
override;
147 const std::vector<flexible_type>& rows,
148 const std::string& missing_value_action =
"error",
149 const std::string& output_type=
"",
150 const size_t topk = 5)
override;
162 const SparseVector& x,
163 const prediction_type_enum& output_type=prediction_type_enum::NA)
override;
169 _coefs.resize(coefs.size());
/**
 * Override that exports the model, returning a shared pointer to a
 * coreml::MLModelWrapper.
 *
 * NOTE(review): whether a null pointer can be returned (e.g. for unsupported
 * configurations) is not visible here — confirm in the implementation.
 */
173 std::shared_ptr<coreml::MLModelWrapper> export_to_coreml()
override;
// Function-like macro taking one argument, python_facing_classname
// (presumably the class name exposed to Python — confirm).
// NOTE(review): as written, this line has an empty replacement list, so every
// use expands to nothing. It looks like a documentation-index listing of the
// macro's signature rather than its real definition. Confirm against the
// header that actually defines it before relying on this line.
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
// The serialization input archive object: given a reference to an istream, it reads from that istream, providing deserialization capabilities.
void set_default_tracking_metric() override
bool is_classifier() const override
// Registration-related macros: one takes a base_class argument, the other takes
// no arguments.
// NOTE(review): as written, both lines have empty replacement lists, so any use
// expands to nothing. They look like documentation-index listings of macro
// signatures rather than the real definitions. Confirm against the header that
// defines them.
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION
void set_default_evaluation_metric() override
Eigen::Matrix< double, Eigen::Dynamic, 1 > coefs
void get_coefficients(DenseVector &_coefs) const
// The serialization output archive object: given a reference to an ostream, it writes to that ostream, providing serialization capabilities.