Turi Create  4.0
od_darknet_yolo_model_trainer.hpp
Go to the documentation of this file.
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef TOOLKITS_OBJECT_DETECTION_OD_DARKNET_YOLO_MODEL_TRAINER_HPP_
9 #define TOOLKITS_OBJECT_DETECTION_OD_DARKNET_YOLO_MODEL_TRAINER_HPP_
10 
11 /**
12  * \file od_darknet_yolo_model_trainer.hpp
13  *
14  * Defines helper functions and the Model subclass for the darknet-yolo
15  * architecture.
16  */
17 
18 #include <ml/neural_net/compute_context.hpp>
19 #include <ml/neural_net/model_backend.hpp>
20 #include <ml/neural_net/model_spec.hpp>
22 
23 namespace turi {
24 namespace object_detection {
25 
26 /**
27  * Returns the neural_net::image_augmenter::options used to configure an
28  * image_augmenter for inference, given darknet-yolo network parameters.
29  */
30 neural_net::image_augmenter::options DarknetYOLOInferenceAugmentationOptions(
31  int batch_size, int output_height, int output_width);  // NOTE(review): output_height/output_width presumably give the image size the augmenter should emit -- confirm in the .cpp.
32 
33 /**
34  * Returns the neural_net::image_augmenter::options used to configure an
35  * image_augmenter for training, given darknet-yolo network parameters.
36  */
37 neural_net::image_augmenter::options DarknetYOLOTrainingAugmentationOptions(
38  int batch_size, int output_height, int output_width);  // NOTE(review): output_height/output_width presumably give the image size the augmenter should emit -- confirm in the .cpp.
39 
40 /**
41  * Encodes the annotations of input_batch (taken by value) into the
42  * EncodedInputBatch format expected by the darknet-yolo network.
43  */
44 EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch,
45  size_t output_height, size_t output_width,  // NOTE(review): size_t here, but int in the DarknetYOLO*AugmentationOptions helpers -- confirm the mismatch is intentional.
46  size_t num_anchors, size_t num_classes);
47 
48 /**
49  * Decodes the raw inference output in batch into an InferenceOutputBatch of structured predictions.
50  */
51 InferenceOutputBatch DecodeDarknetYOLOInference(EncodedBatch batch,
52  float confidence_threshold,  // NOTE(review): presumably the minimum confidence for keeping a prediction -- confirm.
53  float iou_threshold);  // NOTE(review): presumably the IoU cutoff used when suppressing overlapping predictions -- confirm.
54 
55 /**
56  * Transform that wraps a darknet-yolo model_backend so it can serve as a
57  * training-pipeline stage mapping EncodedInputBatch to TrainingOutputBatch.
58  *
59  * \todo Once model_backend exposes support for explicit asynchronous
60  * invocations, this class won't be able to simply use the Transform base class.
61  */
63  : public neural_net::Transform<EncodedInputBatch, TrainingOutputBatch> {
64  public:
65  // Uses base_learning_rate and max_iterations to determine the learning-rate
66  // schedule. All three arguments are stored in the members declared below.
68  std::shared_ptr<neural_net::model_backend> impl, float base_learning_rate,
69  int max_iterations)
70  : impl_(std::move(impl)),  // Moves the caller's shared_ptr into the member instead of copying it.
71  base_learning_rate_(base_learning_rate),
72  max_iterations_(max_iterations) {}
73 
74  TrainingOutputBatch Invoke(EncodedInputBatch input_batch) override;  ///< Transform override; only declared in this header.
75 
76  private:
77  void ApplyLearningRateSchedule(int iteration_id);  // NOTE(review): presumably updates the backend's learning rate for iteration_id from the two schedule fields below -- confirm in the .cpp.
78 
79  std::shared_ptr<neural_net::model_backend> impl_;  ///< Backend whose ownership this wrapper shares.
80  float base_learning_rate_ = 0.f;  ///< Schedule input; overwritten by the constructor.
81  int max_iterations_ = 0;  ///< Schedule input; overwritten by the constructor.
82 };
83 
84 /**
85  * Transform that wraps a darknet-yolo model_backend so it can serve as an
86  * inference-pipeline stage mapping EncodedInputBatch to EncodedBatch.
87  *
88  * \todo Once model_backend exposes support for explicit asynchronous
89  * invocations, this class won't be able to simply use the Transform base class.
90  */
92  : public neural_net::Transform<EncodedInputBatch, EncodedBatch> {
93  public:
95  std::shared_ptr<neural_net::model_backend> impl)
96  : impl_(std::move(impl)) {}  // Moves the caller's shared_ptr into the member instead of copying it.
97 
98  EncodedBatch Invoke(EncodedInputBatch input_batch) override;  ///< Transform override; only declared in this header.
99 
100  private:
101  std::shared_ptr<neural_net::model_backend> impl_;  ///< Backend whose ownership this wrapper shares.
102 };
103 
104 /**
105  * Iterator over std::unique_ptr<Checkpoint> that wraps a darknet-yolo model_backend in order to publish checkpoints.
106  */
108  : public neural_net::Iterator<std::unique_ptr<Checkpoint>> {
109  public:
110  DarknetYOLOCheckpointer(const Config& config,
111  std::shared_ptr<neural_net::model_backend> impl)
112  : config_(config), impl_(std::move(impl)) {}  // Copies config; moves the caller's shared_ptr into impl_.
113 
114  bool HasNext() const override { return impl_ != nullptr; }  // True whenever impl_ is non-null. NOTE(review): nothing in this header resets impl_, so the sequence looks unbounded -- confirm.
115 
116  std::unique_ptr<Checkpoint> Next() override;  // Only declared here. NOTE(review): presumably snapshots the backend state into a new Checkpoint -- confirm in the .cpp.
117 
118  private:
119  Config config_;  ///< Private copy of the constructor's config argument.
120  std::shared_ptr<neural_net::model_backend> impl_;  ///< Backend whose ownership this iterator shares.
121 };
122 
123 /**
124  * Subclass of Checkpoint that generates DarknetYOLOModelTrainer
125  * instances.
126  */
128  public:
129  /**
130  * Initializes a new model, combining the pre-trained warm-start weights with
131  * random initialization for the final layers.
132  */
133  DarknetYOLOCheckpoint(Config config, const std::string& pretrained_model_path, int random_seed);  // NOTE(review): random_seed presumably seeds the random initialization mentioned above -- confirm.
134 
135  /** Loads weights saved from a DarknetYOLOModelTrainer. */
136  DarknetYOLOCheckpoint(Config config, neural_net::float_array_map weights);
137 
138  const Config& config() const override;  ///< Checkpoint override; returns a reference rather than a copy.
139  const neural_net::float_array_map& weights() const override;  ///< Checkpoint override; returns a reference rather than a copy.
140 
141  std::unique_ptr<ModelTrainer> CreateModelTrainer(
142  neural_net::compute_context* context) const override;  // Caller owns the returned trainer. NOTE(review): context is a raw pointer, presumably non-owning -- confirm its lifetime requirements.
143 
144  neural_net::pipeline_spec ExportToCoreML(const std::string& input_name,
145  const std::string& coordinates_name,
146  const std::string& confidence_name, bool use_nms_layer,
147  float iou_threshold,
148  float confidence_threshold) const override;  // Only declared here. NOTE(review): the two thresholds presumably only matter when use_nms_layer is true -- confirm.
149 
150  CheckpointMetadata GetCheckpointMetadata() const override;  ///< Checkpoint override; only declared in this header.
151 
152  /** Returns the config dictionary used to initialize darknet-yolo backends. */
153  neural_net::float_array_map internal_config() const;  // Returned by value, not by reference.
154 
155  /** Returns the weights with the keys expected by the backends. */
156  neural_net::float_array_map internal_weights() const;  // Returned by value, unlike weights().
157 
158  private:
159  Config config_;
160 
161  std::unique_ptr<neural_net::model_spec> model_spec_;  // NOTE(review): presumably populated only by the pretrained_model_path constructor -- confirm. As a unique_ptr member it also prevents implicit copying of this class.
162  neural_net::float_array_map weights_;
163 };
164 
165 /** Subclass of ModelTrainer encapsulating the darknet-yolo architecture. */
167  public:
168  /**
169  * Initializes a model from a checkpoint.
170  */
172  neural_net::compute_context* context);  // NOTE(review): context is a raw pointer, presumably non-owning -- confirm its lifetime requirements.
173 
174  std::shared_ptr<neural_net::Publisher<TrainingOutputBatch>>
175  AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data,
176  size_t batch_size, int offset) override;  // Takes ownership of training_data. NOTE(review): offset is presumably the starting iteration index -- confirm.
177 
178  std::shared_ptr<neural_net::Publisher<EncodedBatch>>
179  AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data,
180  size_t batch_size, float confidence_threshold,
181  float iou_threshold) override;  // Takes ownership of test_data.
182 
183  InferenceOutputBatch DecodeOutputBatch(EncodedBatch batch,
184  float confidence_threshold,
185  float iou_threshold) override;  // NOTE(review): same parameters as DecodeDarknetYOLOInference, so presumably forwards to it -- confirm.
186 
187  std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
188  AsCheckpointPublisher() override;  ///< Returns a Publisher that yields std::unique_ptr<Checkpoint> values.
189 
190  protected:
191  std::shared_ptr<neural_net::Publisher<TrainingOutputBatch>>
192  AsTrainingBatchPublisher(std::shared_ptr<neural_net::Publisher<InputBatch>>
193  augmented_data) override;  // Overload that starts from a Publisher<InputBatch> instead of a data_iterator.
194 
195  private:
196  Config config_;
197  std::shared_ptr<neural_net::model_backend> backend_;  ///< Held by shared_ptr, matching the impl parameter type of the wrapper classes in this header.
198  std::shared_ptr<DataAugmenter> training_augmenter_;  // NOTE(review): presumably configured via DarknetYOLOTrainingAugmentationOptions -- confirm.
199  std::shared_ptr<DataAugmenter> inference_augmenter_;  // NOTE(review): presumably configured via DarknetYOLOInferenceAugmentationOptions -- confirm.
200 };
201 
202 } // namespace object_detection
203 } // namespace turi
204 
205 #endif // TOOLKITS_OBJECT_DETECTION_OD_DARKNET_YOLO_MODEL_TRAINER_HPP_
InferenceOutputBatch DecodeDarknetYOLOInference(EncodedBatch batch, float confidence_threshold, float iou_threshold)
neural_net::image_augmenter::options DarknetYOLOInferenceAugmentationOptions(int batch_size, int output_height, int output_width)
neural_net::image_augmenter::options DarknetYOLOTrainingAugmentationOptions(int batch_size, int output_height, int output_width)
EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch, size_t output_height, size_t output_width, size_t num_anchors, size_t num_classes)