Turi Create  4.0
boosted_trees.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_BOOSTED_TREES_H_
7 #define TURI_BOOSTED_TREES_H_
8 // unity xgboost
9 #include <toolkits/supervised_learning/xgboost.hpp>
10 #include <toolkits/coreml_export/mlmodel_wrapper.hpp>
11 
12 #include <core/export.hpp>
13 
14 namespace turi {
15 namespace supervised {
16 namespace xgboost {
17 
18 /**
19  * Boosted trees regression.
20  *
21  */
23 
24  public:
25 
26  /**
27  * Set one of the options in the algorithm.
28  *
29  * This values is checked against the requirements given by the option
30  * instance. Options that are not present use default options.
31  *
32  * \param[in] opts Options to set
33  */
34  void init_options(const std::map<std::string,flexible_type>& _opts) override;
35 
36  bool is_classifier() const override { return false; }
37 
38  /**
39  * Configure booster from options
40  */
41  void configure(void) override;
42 
43  std::shared_ptr<coreml::MLModelWrapper> export_to_coreml() override;
44 
45 
46  BEGIN_CLASS_MEMBER_REGISTRATION("boosted_trees_regression");
49 
50 };
51 
52 /**It can also be used to predict the class of
53  * Boosted trees classifier.
54  *
55  */
56 class EXPORT boosted_trees_classifier : public xgboost_model {
57 
58  public:
59 
60  /**
61  * Initialize things that are specific to your model.
62  *
63  * \param[in] data ML-Data object created by the init function.
64  *
65  */
66  void model_specific_init(const ml_data& data,
67  const ml_data& valid_data) override;
68 
69  /**
70  * Set one of the options in the algorithm.
71  *
72  * This values is checked against the requirements given by the option
73  * instance. Options that are not present use default options.
74  *
75  * \param[in] opts Options to set
76  */
77  void init_options(const std::map<std::string, flexible_type>& _opts) override;
78 
79  bool is_classifier() const override { return true; }
80 
81  /**
82  * Configure booster from options
83  */
84  void configure(void) override;
85 
86  /**
87  * Set the default evaluation metric during model evaluation..
88  */
90  set_evaluation_metric({
91  "accuracy",
92  "auc",
93  "confusion_matrix",
94  "f1_score",
95  "log_loss",
96  "precision",
97  "recall",
98  "roc_curve",
99  });
100  }
101 
102  /**
103  * Set the default evaluation metric for progress tracking.
104  */
105  void set_default_tracking_metric() override {
106  set_tracking_metric({
107  "accuracy", "log_loss"
108  });
109  }
110 
111  std::shared_ptr<coreml::MLModelWrapper> export_to_coreml() override;
112 
113  BEGIN_CLASS_MEMBER_REGISTRATION("boosted_trees_classifier");
116 
117 };
118 
119 } // namespace xgboost
120 } // namespace supervised
121 } // namespace turi
122 #endif
// Empty fallback definitions of the class-registration macros.
// NOTE(review): these look like documentation-generator stubs rather than the
// real definitions. Each is guarded so it never redefines (and conflicts
// with) a definition that is already in scope.
#ifndef BEGIN_CLASS_MEMBER_REGISTRATION
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#endif
#ifndef IMPORT_BASE_CLASS_REGISTRATION
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#endif
#ifndef END_CLASS_MEMBER_REGISTRATION
#define END_CLASS_MEMBER_REGISTRATION
#endif