8 #ifndef TOOLKITS_OBJECT_DETECTION_OD_MODEL_TRAINER_HPP_ 9 #define TOOLKITS_OBJECT_DETECTION_OD_MODEL_TRAINER_HPP_ 20 #include <ml/neural_net/compute_context.hpp> 21 #include <ml/neural_net/model_spec.hpp> 22 #include <toolkits/object_detection/od_data_iterator.hpp> 25 namespace object_detection {
34 std::vector<neural_net::labeled_image> examples;
43 neural_net::shared_float_array
images;
46 std::vector<std::vector<neural_net::image_annotation>>
annotations;
60 neural_net::shared_float_array images;
61 neural_net::shared_float_array labels;
65 std::vector<std::vector<neural_net::image_annotation>> annotations;
68 std::vector<std::pair<size_t, size_t>> image_sizes;
75 neural_net::shared_float_array loss;
81 float smoothed_loss = 0.f;
91 neural_net::float_array_map encoded_data;
93 std::vector<std::vector<neural_net::image_annotation>> annotations;
94 std::vector<std::pair<size_t, size_t>> image_sizes;
101 std::vector<std::vector<neural_net::image_annotation>> predictions;
103 std::vector<std::vector<neural_net::image_annotation>> annotations;
104 std::vector<std::pair<size_t, size_t>> image_sizes;
114 int max_iterations = -1;
124 int output_height = 13;
127 int output_width = 13;
130 int num_classes = -1;
136 size_t num_predictions = 0;
139 std::string model_type =
"";
142 float evaluate_confidence = 0.f;
145 float predict_confidence = 0.f;
148 float nms_threshold = 0.f;
151 bool use_most_confident_class =
false;
163 virtual const Config& config()
const = 0;
164 virtual const neural_net::float_array_map& weights()
const = 0;
167 virtual std::unique_ptr<ModelTrainer> CreateModelTrainer(
178 const std::string& coordinates_name,
179 const std::string& confidence_name,
180 bool use_nms_layer,
float iou_threshold,
181 float confidence_threshold)
const = 0;
201 : impl_(
std::move(impl)),
202 batch_size_(batch_size),
203 last_iteration_id_(offset) {}
205 bool HasNext()
const override {
return impl_->has_next_batch(); }
210 std::unique_ptr<data_iterator> impl_;
211 size_t batch_size_ = 32;
212 int last_iteration_id_ = 0;
218 DataAugmenter(std::unique_ptr<neural_net::image_augmenter> impl)
219 : impl_(std::move(impl)) {}
224 std::unique_ptr<neural_net::image_augmenter> impl_;
238 : smoothed_loss_(std::move(smoothed_loss)) {}
243 std::unique_ptr<float> smoothed_loss_;
257 ModelTrainer(std::unique_ptr<neural_net::image_augmenter> augmenter);
266 virtual std::shared_ptr<neural_net::Publisher<TrainingOutputBatch>>
267 AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data,
268 size_t batch_size,
int offset);
275 virtual std::shared_ptr<neural_net::Publisher<EncodedBatch>>
276 AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data,
277 size_t batch_size,
float confidence_threshold,
278 float iou_threshold) = 0;
288 float confidence_threshold,
289 float iou_threshold) = 0;
292 virtual std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
293 AsCheckpointPublisher() = 0;
300 virtual std::shared_ptr<neural_net::Publisher<TrainingOutputBatch>>
301 AsTrainingBatchPublisher(
305 std::shared_ptr<DataAugmenter> augmenter_;
311 #endif // TOOLKITS_OBJECT_DETECTION_OD_MODEL_TRAINER_HPP_
bool HasNext() const override
DataIterator(std::unique_ptr< data_iterator > impl, size_t batch_size, int offset=0)