Turi Create  4.0
decision_tree.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_DECISION_TREE_H_
7 #define TURI_DECISION_TREE_H_
8 // unity xgboost
9 #include <toolkits/supervised_learning/xgboost.hpp>
10 #include <toolkits/coreml_export/mlmodel_wrapper.hpp>
11 #include <core/export.hpp>
12 
13 namespace turi {
14 namespace supervised {
15 namespace xgboost {
16 
17 class EXPORT decision_tree_regression: public xgboost_model {
18 
19  public:
20 
21  /**
22  * Set one of the options in the algorithm.
23  *
24  * This values is checked against the requirements given by the option
25  * instance. Options that are not present use default options.
26  *
27  * \param[in] opts Options to set
28  */
29  void init_options(const std::map<std::string,flexible_type>& _opts) override;
30 
31  bool is_classifier() const override { return false; }
32 
33  /**
34  * Configure booster from options
35  */
36  void configure(void) override;
37 
38  std::shared_ptr<coreml::MLModelWrapper> export_to_coreml() override;
39 
40  BEGIN_CLASS_MEMBER_REGISTRATION("decision_tree_regression");
41  IMPORT_BASE_CLASS_REGISTRATION(supervised_learning_model_base);
43 };
44 
45 class EXPORT decision_tree_classifier: public xgboost_model {
46 
47  public:
48 
49  /**
50  * Initialize things that are specific to your model.
51  *
52  * \param[in] data ML-Data object created by the init function.
53  *
54  */
55  void model_specific_init(const ml_data& data,
56  const ml_data& valid_data) override;
57 
58  /**
59  * Set one of the options in the algorithm.
60  *
61  * This values is checked against the requirements given by the option
62  * instance. Options that are not present use default options.
63  *
64  * \param[in] opts Options to set
65  */
66  void init_options(const std::map<std::string, flexible_type>& _opts) override;
67 
68  bool is_classifier() const override { return true; }
69 
70  /**
71  * Configure booster from options
72  */
73  void configure(void) override;
74 
75  /**
76  * Set the default evaluation metric during model evaluation..
77  */
78  void set_default_evaluation_metric() override {
79  set_evaluation_metric({
80  "accuracy",
81  "auc",
82  "confusion_matrix",
83  "f1_score",
84  "log_loss",
85  "precision",
86  "recall",
87  "roc_curve",
88  });
89  }
90 
91  /**
92  * Set the default evaluation metric for progress tracking.
93  */
94  void set_default_tracking_metric() override {
95  set_tracking_metric({
96  "accuracy", "log_loss"
97  });
98  }
99 
100  std::shared_ptr<coreml::MLModelWrapper> export_to_coreml() override;
101 
102  BEGIN_CLASS_MEMBER_REGISTRATION("decision_tree_classifier");
103  IMPORT_BASE_CLASS_REGISTRATION(supervised_learning_model_base);
105 
106 
107 };
108 
109 } // namespace xgboost
110 } // namespace supervised
111 } // namespace turi
112 #endif
// Empty stub definitions of the class-member registration macros: each one
// expands to nothing.
// NOTE(review): these appear after the include guard's #endif, so they are
// not protected by it. They presumably exist only for documentation tooling;
// confirm they are never compiled alongside the real macro definitions.
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION