8 #ifndef TOOLKITS_OBJECT_DETECTION_OD_DARKNET_YOLO_MODEL_TRAINER_HPP_ 9 #define TOOLKITS_OBJECT_DETECTION_OD_DARKNET_YOLO_MODEL_TRAINER_HPP_ 18 #include <ml/neural_net/compute_context.hpp> 19 #include <ml/neural_net/model_backend.hpp> 20 #include <ml/neural_net/model_spec.hpp> 24 namespace object_detection {
31 int batch_size,
int output_height,
int output_width);
38 int batch_size,
int output_height,
int output_width);
45 size_t output_height,
size_t output_width,
46 size_t num_anchors,
size_t num_classes);
52 float confidence_threshold,
68 std::shared_ptr<neural_net::model_backend> impl,
float base_learning_rate,
70 : impl_(std::move(impl)),
71 base_learning_rate_(base_learning_rate),
72 max_iterations_(max_iterations) {}
// Declaration only; the definition is not part of this header excerpt.
// Takes the current iteration index and returns nothing.
// NOTE(review): given the name and the neighbouring base_learning_rate_ /
// max_iterations_ members, this presumably adjusts the backend's learning
// rate as training progresses -- confirm against the definition.
77 void ApplyLearningRateSchedule(
int iteration_id);
// Shared handle to the neural-net backend; the constructor moves its `impl`
// argument into this member.
79 std::shared_ptr<neural_net::model_backend> impl_;
// Defaults to 0.f; the constructor overwrites it with `base_learning_rate`.
80 float base_learning_rate_ = 0.f;
// Defaults to 0; the constructor overwrites it with `max_iterations`.
81 int max_iterations_ = 0;
95 std::shared_ptr<neural_net::model_backend> impl)
96 : impl_(std::move(impl)) {}
// Shared handle to the neural-net backend; the constructor moves its `impl`
// argument into this member.
101 std::shared_ptr<neural_net::model_backend> impl_;
111 std::shared_ptr<neural_net::model_backend> impl)
112 : config_(config), impl_(std::move(impl)) {}
114 bool HasNext()
const override {
return impl_ !=
nullptr; }
// Declaration only (override); the definition is not in this excerpt.
// Returns an owning pointer to a Checkpoint, so the caller takes ownership.
// NOTE(review): how this interacts with HasNext() (e.g. whether impl_ is
// released after the call) cannot be determined from here -- confirm in the
// implementation file.
116 std::unique_ptr<Checkpoint> Next()
override;
// Shared handle to the neural-net backend; moved in from the constructor's
// `impl` argument and compared against nullptr by HasNext().
120 std::shared_ptr<neural_net::model_backend> impl_;
// Declaration only (const override); defined elsewhere.
// Returns a const reference to a Config, so the referenced object must
// outlive the caller's use of it.
138 const Config& config()
const override;
// Declaration only (const override); defined elsewhere.
// Returns a const reference to a float_array_map.
// NOTE(review): presumably refers to the weights_ member declared in this
// class -- confirm against the definition.
139 const neural_net::float_array_map& weights()
const override;
141 std::unique_ptr<ModelTrainer> CreateModelTrainer(
145 const std::string& coordinates_name,
146 const std::string& confidence_name,
bool use_nms_layer,
148 float confidence_threshold)
const override;
// Declaration only (const, non-override); defined elsewhere.
// Returns a float_array_map by value.
// NOTE(review): presumably the configuration in the form the backend
// expects -- confirm against the definition.
153 neural_net::float_array_map internal_config()
const;
// Declaration only (const, non-override); defined elsewhere.
// Returns a float_array_map by value.
// NOTE(review): presumably the weights in the form the backend expects --
// confirm against the definition.
156 neural_net::float_array_map internal_weights()
const;
// Exclusively owned model specification.
161 std::unique_ptr<neural_net::model_spec> model_spec_;
// Weight map stored by value.
// NOTE(review): likely what weights() exposes -- confirm against the
// definition.
162 neural_net::float_array_map weights_;
// Declaration only (override); defined elsewhere.
// Takes ownership of `training_data` (passed by value as a unique_ptr) and
// returns a shared Publisher of TrainingOutputBatch values.
// NOTE(review): `batch_size` is presumably the number of examples per batch
// and `offset` a starting iteration/batch index -- confirm against the
// definition.
174 std::shared_ptr<neural_net::Publisher<TrainingOutputBatch>>
175 AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data,
176 size_t batch_size,
int offset)
override;
// Declaration only (override); defined elsewhere.
// Takes ownership of `test_data` (passed by value as a unique_ptr) and
// returns a shared Publisher of EncodedBatch values.
// NOTE(review): `confidence_threshold` and `iou_threshold` are presumably
// forwarded to prediction filtering / non-maximum suppression -- confirm
// against the definition.
178 std::shared_ptr<neural_net::Publisher<EncodedBatch>>
179 AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data,
180 size_t batch_size,
float confidence_threshold,
181 float iou_threshold)
override;
184 float confidence_threshold,
185 float iou_threshold)
override;
// Declaration only (override); defined elsewhere.
// Takes no arguments and returns a shared Publisher whose values are owning
// pointers to Checkpoint objects.
187 std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
188 AsCheckpointPublisher()
override;
191 std::shared_ptr<neural_net::Publisher<TrainingOutputBatch>>
193 augmented_data)
override;
// Shared handle to the neural-net backend used by this trainer.
197 std::shared_ptr<neural_net::model_backend> backend_;
// Shared DataAugmenter instances.
// NOTE(review): by name, one is used for the training path and one for the
// inference path -- confirm where each is constructed and consumed.
198 std::shared_ptr<DataAugmenter> training_augmenter_;
199 std::shared_ptr<DataAugmenter> inference_augmenter_;
205 #endif // TOOLKITS_OBJECT_DETECTION_OD_DARKNET_YOLO_MODEL_TRAINER_HPP_
InferenceOutputBatch DecodeDarknetYOLOInference(EncodedBatch batch, float confidence_threshold, float iou_threshold)
bool HasNext() const override
neural_net::image_augmenter::options DarknetYOLOInferenceAugmentationOptions(int batch_size, int output_height, int output_width)
neural_net::image_augmenter::options DarknetYOLOTrainingAugmentationOptions(int batch_size, int output_height, int output_width)
EncodedInputBatch EncodeDarknetYOLO(InputBatch input_batch, size_t output_height, size_t output_width, size_t num_anchors, size_t num_classes)