6 #ifndef TURI_XGBOOST_H_ 7 #define TURI_XGBOOST_H_ 9 #include <core/storage/sframe_data/sarray.hpp> 10 #include <core/storage/sframe_data/sframe.hpp> 11 #include <core/data/sframe/gl_sarray.hpp> 12 #include <core/data/sframe/gl_sframe.hpp> 15 #include <ml/ml_data/ml_data.hpp> 18 #include <timer/timer.hpp> 19 #include <core/logging/table_printer/table_printer.hpp> 20 #include <core/export.hpp> 23 #include <toolkits/supervised_learning/supervised_learning.hpp> 24 #include <toolkits/coreml_export/mlmodel_wrapper.hpp> 35 namespace supervised {
/// Storage-mode selector for the xgboost model.
/// It is passed to `_set_storage_mode()` and stored in `storage_mode_`, which defaults to AUTO.
/// NOTE(review): the names suggest in-memory vs. external-memory data handling, with
/// AUTO deferring the choice. Confirm this against the implementation.
enum class storage_mode_enum : int {
  IN_MEMORY = 0,
  EXT_MEMORY = 1,
  AUTO = 2,
};
/// Serialization version of the xgboost model. `get_version()` returns this value.
static constexpr size_t XGBOOST_MODEL_VERSION{9};
// Pure-virtual hook (declared `= 0`): every concrete subclass must implement configure().
// NOTE(review): what configure() sets up is not visible in this excerpt.
59 virtual void configure(
void) = 0;
// Overrides the base-class model_specific_init() hook.
// @param data        ml_data passed to the model (presumably the training set; confirm).
// @param valid_data  ml_data passed as validation data.
71 void model_specific_init(
const ml_data& data,
72 const ml_data& valid_data)
override;
// Overrides init_options(). It receives the option map as name -> flexible_type.
// (`virtual` is redundant next to `override` but harmless.)
82 virtual void init_options(
const std::map<std::string,flexible_type>& _opts)
override;
// Overrides the base-class train() entry point. It takes no arguments.
93 void train(
void)
override;
// Overrides predict() and returns the predictions as a shared sarray<flexible_type>.
// `output_type` defaults to the empty string.
// NOTE(review): original source line 106 (presumably the input-data parameter) is
// missing from this excerpt.
105 std::shared_ptr<sarray<flexible_type>> predict(
107 const std::string& output_type=
"")
override;
// Re-exposes the base-class predict() overloads. Without this using-declaration,
// the override above would hide every other base overload named predict.
112 using supervised_learning_model_base::predict;
// Fragment: the head of this override (original line 121) is missing from this excerpt.
// The visible parameters are:
//   test_data             - a vector of flexible_type values
//   missing_value_action  - defaults to "error"
//   output_type           - defaults to ""
122 const std::vector<flexible_type>& test_data,
123 const std::string& missing_value_action =
"error",
124 const std::string& output_type=
"")
override;
// Non-virtual prediction helper that operates directly on an xgboost DMatrix.
// `output_type` defaults to "". It returns a shared sarray<flexible_type>.
126 std::shared_ptr<sarray<flexible_type>> predict_impl(
127 const ::xgboost::learner::DMatrix& dmat,
128 const std::string& output_type=
"");
// Fills `out_preds` (an out-parameter taken by non-const reference) with predictions for `dmat`.
// `rf_running_rescale_constant` defaults to 0.0.
// NOTE(review): the name suggests a random-forest rescaling factor; confirm this.
// Original line 131 is missing from this excerpt and may hold another parameter.
130 void xgboost_predict(const ::xgboost::learner::DMatrix& dmat,
132 std::vector<float>& out_preds,
133 double rf_running_rescale_constant=0.0);
// Fragment: the head of this override is missing from this excerpt.
// Visible parameters: `rows`, missing_value_action = "error", output_type = "", topk = 5.
144 const std::vector<flexible_type>& rows,
145 const std::string& missing_value_action =
"error",
146 const std::string& output_type=
"",
147 const size_t topk = 5)
override;
// Fragment: DMatrix-based counterpart (non-override) whose head is missing.
// Defaults: output_type = "", topk = 5.
150 const ::xgboost::learner::DMatrix& dmat,
151 const std::string& output_type=
"",
152 const size_t topk = 5);
// Fragment: another override whose head and leading parameters are missing.
// Note that its default topk is 2, not 5.
165 const std::string& output_type=
"",
166 const size_t topk = 2)
override;
// Overrides evaluate(). It returns results as a map from name to variant_type.
// Defaults: evaluation_type = "", with_prediction = false.
// NOTE(review): original line 178 (presumably the data parameter) is missing here.
177 std::map<std::string, variant_type> evaluate(
179 const std::string& evaluation_type=
"",
180 bool with_prediction=
false)
override;
// Non-virtual evaluation helper over a DMatrixMLData. evaluation_type defaults to "".
182 std::map<std::string, variant_type> evaluate_impl(
183 const DMatrixMLData& dmat,
184 const std::string& evaluation_type=
"");
// Returns a shared sarray<flexible_type>.
// NOTE(review): the parameter list and closing of this declaration are missing from the excerpt.
195 std::shared_ptr<sarray<flexible_type>> extract_features(
// Returns a textual dump of the model as a vector of strings.
// `with_stats` toggles the inclusion of statistics.
220 std::vector<std::string> dump(
bool with_stats);
// Same shape as dump(). Presumably each string is JSON-formatted; confirm against the implementation.
221 std::vector<std::string> dump_json(
bool with_stats);
// Body fragment of get_version(): reports the class constant XGBOOST_MODEL_VERSION.
238 return XGBOOST_MODEL_VERSION;
// Query helpers.
// NOTE(review): these read like pure queries but are not declared const. Consider
// making them const if the implementations do not mutate state.
254 bool is_random_forest();
259 size_t num_classes();
// Setter that presumably assigns storage_mode_ (declared below with default AUTO); confirm.
266 void _set_storage_mode(storage_mode_enum mode);
// Setter that presumably assigns num_batches_ (declared below with default 0); confirm.
272 void _set_num_batches(
size_t num_batches);
// Builds a pair of DMatrixMLData handles.
// Presumably (training, validation), matching _init_learner's (ptrain, pvalid) order; confirm.
277 std::pair<std::shared_ptr<DMatrixMLData>, std::shared_ptr<DMatrixMLData>> _init_data();
// Sets up the learner from the train/validation matrices. It can optionally restore
// state from `checkpoint_restore_path` when `restore_from_checkpoint` is true.
279 void _init_learner(std::shared_ptr<DMatrixMLData> ptrain, std::shared_ptr<DMatrixMLData> pvalid,
280 bool restore_from_checkpoint, std::string checkpoint_restore_path);
// Returns the table_printer used for progress output. Its layout depends on whether
// validation data exists.
282 table_printer _init_progress_printer(
bool has_validation_data);
// Number of early-stopping rounds, which depends on whether validation data exists.
284 size_t _get_early_stopping_rounds(
bool has_validation_data);
// Records the per-iteration training state: metrics for train/validation, the progress
// table, and the elapsed training time.
286 void _save_training_state(
size_t iteration,
287 const std::vector<float>& training_metrics,
288 const std::vector<float>& validation_metrics,
289 std::shared_ptr<unity_sframe> progress_table,
290 double training_time);
// Writes a checkpoint to `path`.
292 void _checkpoint(
const std::string& path);
// Restores model state from the checkpoint at `path`.
294 void _restore_from_checkpoint(
const std::string& path);
// Serializes into `oarc`. `save_booster_prediction_buffer` controls whether the booster's
// prediction buffer is included. It is const, so it does not modify the model.
296 void _save(
oarchive& oarc,
bool save_booster_prediction_buffer)
const;
// Shared handle to the underlying xgboost BoostLearner that this model wraps.
304 std::shared_ptr<::xgboost::learner::BoostLearner>
booster_;
// Storage mode, which defaults to AUTO. It can be changed via _set_storage_mode().
306 storage_mode_enum storage_mode_ = storage_mode_enum::AUTO;
// Batch count, which defaults to 0.
// NOTE(review): the meaning of 0 (e.g. "choose automatically") is not visible here.
308 size_t num_batches_ = 0;
// Exports the model to Core ML and returns a shared MLModelWrapper.
// @param is_classifier     selects classifier vs. regressor export
// @param is_random_forest  flags random-forest models. NOTE(review): inside the definition
//                          this parameter shadows the member function is_random_forest().
// @param context           extra export options as a name -> flexible_type map
310 std::shared_ptr<coreml::MLModelWrapper> _export_xgboost_model(
bool is_classifier,
311 bool is_random_forest,
312 const std::map<std::string, flexible_type>& context);
The serialization input archive: given a reference to an istream, it reads from that stream to provide deserialization.
bool support_missing_value() const override
ml_data ml_data_
Internal ML data structure used for training.
std::shared_ptr<::xgboost::learner::BoostLearner > booster_
Shared pointer to the underlying XGBoost `BoostLearner` object that this model wraps.
The serialization output archive: given a reference to an ostream, it writes to that stream to provide serialization.
virtual size_t get_version() const override