// Turi Create 4.0 — object_detector.hpp
// (Doxygen source-page header; retained from extraction as a comment.)
1 /* Copyright © 2018 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 
7 #ifndef TURI_OBJECT_DETECTION_OBJECT_DETECTOR_H_
8 #define TURI_OBJECT_DETECTION_OBJECT_DETECTOR_H_
9 
#include <functional>
#include <map>
#include <memory>
#include <queue>

#include <core/data/sframe/gl_sframe.hpp>
#include <core/logging/table_printer/table_printer.hpp>
#include <ml/neural_net/compute_context.hpp>
#include <ml/neural_net/image_augmentation.hpp>
#include <ml/neural_net/model_backend.hpp>
#include <ml/neural_net/model_spec.hpp>
#include <model_server/lib/extensions/ml_model.hpp>
#include <toolkits/coreml_export/mlmodel_wrapper.hpp>
#include <toolkits/object_detection/od_data_iterator.hpp>
#include <toolkits/object_detection/od_model_trainer.hpp>
25 
26 namespace turi {
27 namespace object_detection {
28 
29 class EXPORT object_detector: public ml_model_base {
30  public:
31  object_detector() = default;
32 
33  // ml_model_base interface
34 
35  void init_options(const std::map<std::string, flexible_type>& opts) override;
36  size_t get_version() const override;
37  void save_impl(oarchive& oarc) const override;
38  void load_version(iarchive& iarc, size_t version) override;
39 
40  // Interface exposed via Unity server
41 
42  void train(gl_sframe data, std::string annotations_column_name,
43  std::string image_column_name, variant_type validation_data,
44  std::map<std::string, flexible_type> opts);
45  variant_type evaluate(gl_sframe data, std::string metric,
46  std::string output_type,
47  std::map<std::string, flexible_type> opts);
48  variant_type predict(variant_type data,
49  std::map<std::string, flexible_type> opts);
50  virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
51  std::string filename, std::string short_description,
52  std::map<std::string, flexible_type> additional_user_defined,
53  std::map<std::string, flexible_type> opts);
54  void import_from_custom_model(variant_map_type model_data, size_t version);
55 
56  // Support for iterative training.
57  virtual void init_training(gl_sframe data,
58  std::string annotations_column_name,
59  std::string image_column_name,
60  variant_type validation_data,
61  std::map<std::string, flexible_type> opts);
62  virtual void resume_training(gl_sframe data, variant_type validation_data);
63  virtual void iterate_training();
64  virtual void synchronize_training();
65  virtual void finalize_training(bool compute_final_metrics);
66 
67  // Register with Unity server
68 
69  BEGIN_CLASS_MEMBER_REGISTRATION("object_detector")
70 
71  IMPORT_BASE_CLASS_REGISTRATION(ml_model_base);
72 
73  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::train, "data",
74  "annotations_column_name",
75  "image_column_name", "validation_data",
76  "options");
77  register_defaults("train",
78  {{"validation_data", to_variant(gl_sframe())},
79  {"options",
80  to_variant(std::map<std::string, flexible_type>())}});
82  object_detector::train,
83  "\n"
84  "Options\n"
85  "-------\n"
86  "mlmodel_path : string\n"
87  " Path to the CoreML specification with the pre-trained model parameters.\n"
88  "batch_size: int\n"
89  " The number of images per training iteration. If 0, then it will be\n"
90  " automatically determined based on resource availability.\n"
91  "max_iterations : int\n"
92  " The number of training iterations. If 0, then it will be automatically\n"
93  " be determined based on the amount of data you provide.\n"
94  );
95 
96  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::init_training, "data",
97  "annotations_column_name", "image_column_name",
98  "validation_data", "options");
99  register_defaults("init_training",
100  {{"validation_data", to_variant(gl_sframe())},
101  {"options",
102  to_variant(std::map<std::string, flexible_type>())}});
103 
104  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::resume_training, "data",
105  "validation_data");
106  register_defaults("resume_training",
107  {{"validation_data", to_variant(gl_sframe())}});
108 
109  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::iterate_training);
110  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::synchronize_training);
111  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::finalize_training,
112  "compute_final_metrics");
113  register_defaults("finalize_training", {{"compute_final_metrics", true}});
114 
115  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::evaluate, "data", "metric",
116  "output_type", "options");
117  register_defaults("evaluate",
118  {
119  {"metric", std::string("auto")},
120  {"output_type", std::string("dict")},
121  {"options",
122  to_variant(std::map<std::string, flexible_type>())},
123  });
124 
125  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::predict, "data", "options");
126  register_defaults("predict",{});
127 
128  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::export_to_coreml, "filename",
129  "short_description", "additional_user_defined", "options");
130  register_defaults("export_to_coreml",
131  {{"short_description", ""},
132  {"additional_user_defined", to_variant(std::map<std::string, flexible_type>())},
133  {"options", to_variant(std::map<std::string, flexible_type>())}});
134 
136  object_detector::export_to_coreml,
137  "\n"
138  "Options\n"
139  "-------\n"
140  "include_non_maximum_suppression : bool\n"
141  " A boolean value \"True\" or \"False\" to indicate the use of Non Maximum Suppression.\n"
142  "iou_threshold: double\n"
143  " The allowable IOU overlap between bounding box detections for the same object.\n"
144  " If no value is specified, a default value of 0.45 is used.\n"
145  "confidence_threshold : double\n"
146  " The minimum required object confidence score per bounding box detection.\n"
147  " All bounding box detections with object confidence score lower than\n"
148  " the confidence_threshold are eliminiated. If no value is specified,\n"
149  " a default value of 0.25 is used.\n"
150  );
151 
152  REGISTER_CLASS_MEMBER_FUNCTION(object_detector::import_from_custom_model,
153  "model_data", "version");
154 
155  // TODO: Remainder of interface: predict, etc.
156 
158 
159  protected:
160  // Constructor allowing tests to set the initial state of this class.
161  object_detector(std::map<std::string, variant_type> initial_state,
162  neural_net::float_array_map initial_weights) {
163  load(std::move(initial_state), std::move(initial_weights));
164  }
165 
166  // Resets the internal state. Used by deserialization code and unit tests.
167  void load(std::map<std::string, variant_type> state,
168  neural_net::float_array_map weights);
169 
170  // Assumes state already loaded.
171  virtual std::unique_ptr<Checkpoint> load_checkpoint(
172  neural_net::float_array_map weights) const;
173 
174  // Synchronously loads weights from the backend if necessary.
175  const Checkpoint& read_checkpoint() const;
176 
177  // Override points allowing subclasses to inject dependencies
178 
179  // Factory for data_iterator
180  virtual std::unique_ptr<data_iterator> create_iterator(
181  data_iterator::parameters iterator_params) const;
182 
183  std::unique_ptr<data_iterator> create_iterator(
184  gl_sframe data, std::vector<std::string> class_labels, bool repeat,
185  bool is_training) const;
186 
187  // Factory for compute_context
188  virtual
189  std::unique_ptr<neural_net::compute_context> create_compute_context() const;
190 
191  // Factories for ModelTrainer
192  virtual std::unique_ptr<ModelTrainer> create_trainer(
193  const Config& config, const std::string& pretrained_model_path, int random_seed,
194  std::unique_ptr<neural_net::compute_context> context) const;
195  virtual std::unique_ptr<ModelTrainer> create_inference_trainer(
196  const Checkpoint& checkpoint,
197  std::unique_ptr<neural_net::compute_context> context) const;
198 
199  // Establishes training pipelines from the backend.
200  void connect_trainer(std::unique_ptr<ModelTrainer> trainer,
201  std::unique_ptr<data_iterator> iterator, int batch_size);
202 
203  virtual std::vector<neural_net::image_annotation> convert_yolo_to_annotations(
204  const neural_net::float_array& yolo_map,
205  const std::vector<std::pair<float, float>>& anchor_boxes,
206  float min_confidence);
207 
208  virtual variant_type perform_evaluation(gl_sframe data, std::string metric,
209  std::string output_type,
210  float confidence_threshold,
211  float iou_threshold);
212 
213  void perform_predict(
214  gl_sframe data,
215  std::function<void(const std::vector<neural_net::image_annotation>&,
216  const std::vector<neural_net::image_annotation>&,
217  const std::pair<float, float>&)>
218  consumer,
219  float confidence_threshold, float iou_threshold);
220 
221  // When true, shows all metadata, when false shows less metadata
222  virtual bool should_export_all_metadata() const;
223 
224  // Utility code
225 
226  template <typename T>
227  T read_state(const std::string& key) const {
228  return variant_get_value<T>(get_state().at(key));
229  }
230 
231  private:
232  neural_net::float_array_map strip_fwd(
233  const neural_net::float_array_map& params) const;
234 
235  flex_int get_max_iterations() const;
236  flex_int get_training_iterations() const;
237  flex_int get_num_classes() const;
238 
239  static variant_type convert_map_to_types(const variant_map_type& result_map,
240  const std::string& output_type,
241  const flex_list& class_labels);
242  static gl_sframe convert_types_to_sframe(const variant_type& data,
243  const std::string& column_name);
244 
245  // Sets certain user options heuristically (from the data).
246  void infer_derived_options(neural_net::compute_context* context,
247  data_iterator* iterator);
248 
249  // Waits until the number of pending patches is at most `max_pending`.
250  void wait_for_training_batches(size_t max_pending = 0);
251 
252  // Computes and records training/validation metrics.
253  void update_model_metrics(gl_sframe data, gl_sframe validation_data);
254 
255  // Primary representation for the trained model. Can be null if the model has
256  // been updated since the last checkpoint.
257  mutable std::unique_ptr<Checkpoint> checkpoint_;
258 
259  // Primary dependencies for training. These should be nonnull while training
260  // is in progress.
261  gl_sframe training_data_; // TODO: Avoid storing gl_sframe AND data_iterator.
262  gl_sframe validation_data_;
263  std::shared_ptr<neural_net::FuturesStream<TrainingOutputBatch>>
264  training_futures_;
265  std::shared_ptr<neural_net::FuturesStream<std::unique_ptr<Checkpoint>>>
266  checkpoint_futures_;
267 
268  // Nonnull while training is in progress, if progress printing is enabled.
269  std::unique_ptr<table_printer> training_table_printer_;
270 
271  std::queue<std::future<std::unique_ptr<TrainingOutputBatch>>>
272  pending_training_batches_;
273 
274  struct inference_batch : neural_net::image_augmenter::result {
275  std::vector<std::pair<float, float>> image_dimensions_batch;
276  };
277 };
278 
279 } // object_detection
280 } // turi
281 
282 #endif // TURI_OBJECT_DETECTION_OBJECT_DETECTOR_H_
// NOTE(review): The lines below are Doxygen tooltip residue accidentally
// appended after the include guard's #endif during extraction. Left here as a
// comment (re-defining the registration macros to nothing at this point would
// break any later use). Preserved verbatim:
//   #define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
//   #define REGISTER_CLASS_MEMBER_DOCSTRING(name, docstring)
//   #define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
//   #define IMPORT_BASE_CLASS_REGISTRATION(base_class)
//   boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
//   Definition: variant.hpp:24
//   #define END_CLASS_MEMBER_REGISTRATION
//   variant_type to_variant(const T &f)
//   Definition: variant.hpp:308
//   std::vector< flexible_type > flex_list