Turi Create  4.0
drawing_classifier.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 #ifndef TURI_DRAWING_CLASSIFIER_H_
8 #define TURI_DRAWING_CLASSIFIER_H_
9 
10 #include <core/data/sframe/gl_sframe.hpp>
11 #include <core/logging/table_printer/table_printer.hpp>
12 #include <ml/neural_net/compute_context.hpp>
13 #include <ml/neural_net/model_backend.hpp>
14 #include <ml/neural_net/model_spec.hpp>
15 #include <model_server/lib/extensions/ml_model.hpp>
16 #include <model_server/lib/variant.hpp>
17 #include <toolkits/coreml_export/mlmodel_wrapper.hpp>
18 #include <toolkits/coreml_export/neural_net_models_exporter.hpp>
19 #include <toolkits/drawing_classifier/dc_data_iterator.hpp>
20 
21 namespace turi {
22 namespace drawing_classifier {
23 
24 class EXPORT drawing_classifier : public ml_model_base {
25  public:
26  static const size_t DRAWING_CLASSIFIER_VERSION;
27 
28  drawing_classifier() = default;
29 
30  /**
31  * ml_model_base interface
32  */
33 
34  void init_options(const std::map<std::string, flexible_type>& opts) override;
35 
36  size_t get_version() const override;
37 
38  void save_impl(oarchive& oarc) const override;
39 
40  void load_version(iarchive& iarc, size_t version) override;
41 
42  void import_from_custom_model(variant_map_type model_data, size_t version);
43 
44  /**
45  * Interface exposed via Unity server
46  */
47 
48  void train(gl_sframe data, std::string target_column_name,
49  std::string feature_column_name, variant_type validation_data,
50  std::map<std::string, flexible_type> opts);
51 
52  gl_sarray predict(gl_sframe data, std::string output_type = "probability");
53 
54  gl_sframe predict_topk(gl_sframe data,
55  std::string output_type = "probability", size_t k = 5);
56 
57  variant_map_type evaluate(gl_sframe data, std::string metric);
58 
59  std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
60  std::string filename, std::string short_description,
61  const std::map<std::string, flexible_type>& additional_user_defined,
62  bool use_default_spec = false);
63 
64  // Support for iterative training.
65  // TODO: Expose via forthcoming C-API checkpointing mechanism?
66  virtual void init_training(gl_sframe data, std::string target_column_name,
67  std::string feature_column_name,
68  variant_type validation_data,
69  std::map<std::string, flexible_type> opts);
70 
71  virtual void iterate_training(bool show_loss);
72 
73  BEGIN_CLASS_MEMBER_REGISTRATION("drawing_classifier")
74 
75  IMPORT_BASE_CLASS_REGISTRATION(ml_model_base);
76 
77  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::train, "data",
78  "target_column_name", "feature_column_name",
79  "validation_data", "options");
80 
81  register_defaults("train",
82  {{"validation_data", to_variant(std::string("auto"))},
83  {"options",
84  to_variant(std::map<std::string, flexible_type>())}});
85 
87  drawing_classifier::train,
88  "----------\n"
89  "data : SFrame\n"
90  " Input data, which consists of columns named by the\n"
91  " feature_column_name and target_column_name parameters, used for\n"
92  " training the Drawing Classifier."
93  "target_column_name : string\n"
94  " Name of the column containing the target variable. The values in "
95  " this column must be of string type.\n"
96  "feature_column_name : string\n"
97  " Name of the column containing the input drawings.\n"
98  " The feature column can contain either bitmap-based drawings or\n"
99  " stroke-based drawings. Bitmap-based drawing input can be a\n"
100  " grayscale tc.Image of any size.\n"
101  "\n"
102  " Stroke-based drawing input must be in the following format:\n"
103  " Every drawing must be represented by a list of strokes, where each\n"
104  " stroke must be a list of points in the order in which they were\n"
105  " drawn on the canvas.\n"
106  "\n"
107  " Each point must be a dictionary with two keys,\n"
108  " \"x\" and \"y\", and their\n"
109  " respective values must be numerical, i.e. either integer or float.\n"
110  "validatation_data : SFrame or string\n"
111  " A dataset for monitoring the model's generalization performance to\n"
112  " prevent the model from overfitting to the training data.\n"
113  "\n"
114  " For each row of the progress table, accuracy is measured over the\n"
115  " provided training dataset and the `validation_data`. The format of\n"
116  " this SFrame must be the same as the training set.\n"
117  "\n"
118  " When set to 'auto', a validation set is automatically sampled from "
119  "the\n"
120  " training data (if the training data has > 100 sessions).\n"
121  "options : dict\n"
122  "\n"
123  "Options\n"
124  "-------\n"
125  "max_iterations : int\n"
126  " Maximum number of iterations/epochs made over the data during the\n"
127  " training phase. The default is 500 iterations.\n"
128  "batch_size : int\n"
129  " Number of sequence chunks used per training step. Must be greater "
130  "than\n"
131  " the number of GPUs in use. The default is 32.\n"
132  "random_seed : int\n"
133  " The given seed is used for random weight initialization and\n"
134  " sampling during training\n");
135 
136  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::predict, "data",
137  "output_type");
138 
139  register_defaults("predict", {{"output_type", std::string("class")}});
140 
142  drawing_classifier::predict,
143  "----------\n"
144  "data : SFrame\n"
145  " The drawing(s) on which to perform drawing classification.\n"
146  " If dataset is an SFrame, it must have a column with the same name\n"
147  " as the feature column during training. Additional columns are\n"
148  " ignored.\n"
149  " If the data is a single drawing, it can be either of type\n"
150  " tc.Image, in which case it is a bitmap-based drawing input,\n"
151  " or of type list, in which case it is a stroke-based drawing input.\n"
152  "output_type : {\"class\", \"probability_vector\"}, optional\n"
153  " Form of each prediction which is one of:\n"
154  " - \"probability_vector\": Prediction probability associated with \n"
155  " each class as a vector. The probability of first class (sorted\n"
156  " alphanumerically by name of the class in the training set) is in\n"
157  " position 0 of the vector, the second in position 1 and so on.\n"
158  " - \"class\": Class prediction. This returns the class with maximum\n"
159  " probability.\n");
160 
161  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::predict_topk, "data",
162  "output_type", "k");
163 
164  register_defaults("predict_topk",
165  {{"output_type", std::string("probability")}});
166 
168  drawing_classifier::predict_topk,
169  "----------\n"
170  "data : SFrame\n"
171  " Dataset of new observations.\n"
172  " SFrame must include columns with the same\n"
173  " names as the features used for model training, but does not\n"
174  " require a target column. Additional columns are ignored."
175  "output_type : {\"probability\", \"rank\"}, optional\n"
176  " Form of each prediction which is one of:\n"
177  " - \"probability\": Probability associated with each label in the\n"
178  " prediction\n"
179  " - \"rank\": Rank associated with each label in the prediction.\n"
180  "k : int\n"
181  " Number of classes to return for each input example.\n");
182 
183  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::evaluate, "data",
184  "metric");
185 
186  register_defaults("evaluate", {{"metric", std::string("auto")}});
187 
189  drawing_classifier::evaluate,
190  "----------\n"
191  "data : SFrame\n"
192  " Dataset of new observations. Must include columns with the same\n"
193  " names as the features used for model training, but does not\n"
194  " require a target column. Additional columns are ignored.\n"
195  "metric : str, optional\n"
196  " Name of the evaluation metric. Possible values are:\n"
197  " - 'auto' : Returns all available metrics\n"
198  " - 'accuracy' : Classification accuracy (micro average)\n"
199  " - 'auc' : Area under the ROC curve (macro average)\n"
200  " - 'precision' : Precision score (macro average)\n"
201  " - 'recall' : Recall score (macro average)\n"
202  " - 'f1_score' : F1 score (macro average)\n"
203  " - 'log_loss' : Log loss\n"
204  " - 'confusion_matrix' : An SFrame with counts of possible\n"
205  " prediction/true label combinations.\n"
206  " - 'roc_curve' : An SFrame containing information needed for\n"
207  " an ROC curve\n");
208 
209  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::export_to_coreml,
210  "filename", "short_description", "additional_user_defined");
211  register_defaults("export_to_coreml",
212  {{"short_description", ""},
213  {"additional_user_defined", to_variant(std::map<std::string, flexible_type>())}});
214 
215  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::init_training, "data",
216  "target_column_name", "feature_column_name",
217  "validation_data", "options");
218  register_defaults("init_training",
219  {{"validation_data", to_variant(gl_sframe())},
220  {"options",
221  to_variant(std::map<std::string, flexible_type>())}});
222 
223  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::iterate_training);
224 
225  REGISTER_CLASS_MEMBER_FUNCTION(drawing_classifier::import_from_custom_model,
226  "model_data", "version");
227 
229 
230  protected:
231  // Constructor allowing tests to set the initial state of this class and to
232  // inject dependencies.
233  drawing_classifier(
234  const std::map<std::string, variant_type>& initial_state,
235  std::unique_ptr<neural_net::model_spec> nn_spec,
236  std::unique_ptr<neural_net::compute_context> training_compute_context,
237  std::unique_ptr<data_iterator> training_data_iterator,
238  std::unique_ptr<neural_net::model_backend> training_model)
239  : nn_spec_(std::move(nn_spec)),
240  training_data_iterator_(std::move(training_data_iterator)),
241  training_compute_context_(std::move(training_compute_context)),
242  training_model_(std::move(training_model)) {
243  add_or_update_state(initial_state);
244  }
245 
246  // Factory for data_iterator
247  virtual std::unique_ptr<data_iterator> create_iterator(
248  data_iterator::parameters iterator_params) const;
249 
250  // Factory for compute_context
251  virtual std::unique_ptr<neural_net::compute_context> create_compute_context()
252  const;
253 
254  // Returns the initial neural network to train
255  virtual std::unique_ptr<neural_net::model_spec> init_model(
256  bool use_random_init) const;
257 
258  virtual std::tuple<gl_sframe, gl_sframe> init_data(
259  gl_sframe data, variant_type validation_data) const;
260 
261  virtual std::tuple<float, float> compute_validation_metrics(
262  size_t num_classes, size_t batch_size);
263 
264  virtual void init_table_printer(bool has_validation, bool show_loss);
265 
266  template <typename T>
267  T read_state(const std::string& key) const {
268  try {
269  return variant_get_value<T>(get_state().at(key));
270  } catch (const std::out_of_range& e) {
271  std::stringstream ss;
272  ss << e.what() << std::endl;
273  ss << "from read_state for '" << key << "'" << std::endl;
274  throw std::out_of_range(ss.str().c_str());
275  }
276  }
277 
278  std::unique_ptr<neural_net::model_spec> clone_model_spec_for_test() const {
279  if (nn_spec_) {
280  return std::unique_ptr<neural_net::model_spec>(
281  new neural_net::model_spec(nn_spec_->get_coreml_spec()));
282  }
283  else {
284  return nullptr;
285  }
286  }
287 
288  gl_sframe perform_inference(data_iterator* data) const;
289  gl_sarray get_predictions_class(const gl_sarray& predictions_prob,
290  const flex_list& class_labels);
291 
292  private:
293  /**
294  * by design, this is NOT virtual;
295  * this calls the virtual create_iterator(parameters) in the end.
296  **/
297  std::unique_ptr<data_iterator> create_iterator(gl_sframe data, bool is_train,
298  flex_list class_labels) const;
299 
300  // Primary representation for the trained model.
301  std::unique_ptr<neural_net::model_spec> nn_spec_;
302 
303  // Primary dependencies for training. These should be nonnull while training
304  // is in progress.
305  gl_sframe training_data_; // TODO: Avoid storing gl_sframe AND data_iterator.
306  gl_sframe validation_data_;
307  std::unique_ptr<data_iterator> training_data_iterator_;
308  std::unique_ptr<data_iterator> validation_data_iterator_;
309  std::unique_ptr<neural_net::compute_context> training_compute_context_;
310  std::unique_ptr<neural_net::model_backend> training_model_;
311  // Nonnull while training is in progress, if progress printing is enabled.
312  std::unique_ptr<table_printer> training_table_printer_;
313 };
314 
315 } // namespace drawing_classifier
316 } // namespace turi
317 
318 #endif // TURI_DRAWING_CLASSIFIER_H_
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_DOCSTRING(name, docstring)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
STL namespace.
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
Definition: variant.hpp:24
#define END_CLASS_MEMBER_REGISTRATION
variant_type to_variant(const T &f)
Definition: variant.hpp:308
std::vector< flexible_type > flex_list