Turi Create 4.0
#include <toolkits/object_detection/od_darknet_yolo_model_trainer.hpp>
Public Member Functions
DarknetYOLOCheckpoint (Config config, const std::string &pretrained_model_path, int random_seed)
DarknetYOLOCheckpoint (Config config, neural_net::float_array_map weights)
std::unique_ptr< ModelTrainer > CreateModelTrainer (neural_net::compute_context *context) const override
neural_net::pipeline_spec ExportToCoreML (const std::string &input_name, const std::string &coordinates_name, const std::string &confidence_name, bool use_nms_layer, float iou_threshold, float confidence_threshold) const override
neural_net::float_array_map internal_config () const
neural_net::float_array_map internal_weights () const
Subclass of Checkpoint that generates DarknetYOLOModelTrainer instances.
Definition at line 127 of file od_darknet_yolo_model_trainer.hpp.
turi::object_detection::DarknetYOLOCheckpoint::DarknetYOLOCheckpoint (Config config, const std::string &pretrained_model_path, int random_seed)
Initializes a new model, combining the pre-trained warm-start weights with random initialization for the final layers.
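The warm-start idea described above can be sketched in isolation. This is a minimal illustration, not the Turi Create implementation: `WeightMap` is a hypothetical stand-in for `neural_net::float_array_map`, `InitializeWarmStart` is a hypothetical helper, and the uniform distribution and its range are assumptions chosen only to show how a seed makes the random part reproducible.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <random>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for neural_net::float_array_map:
// layer name -> flattened weight values.
using WeightMap = std::map<std::string, std::vector<float>>;

// Copies every pre-trained entry unchanged, then fills each listed final
// layer with values drawn from a generator seeded with random_seed.
// The same seed therefore reproduces the same initialization.
WeightMap InitializeWarmStart(
    const WeightMap& pretrained,
    const std::map<std::string, std::size_t>& final_layer_sizes,
    int random_seed) {
  WeightMap weights = pretrained;
  std::mt19937 engine(static_cast<std::mt19937::result_type>(random_seed));
  // Assumed distribution; the real initializer may differ.
  std::uniform_real_distribution<float> dist(-0.05f, 0.05f);
  for (const auto& entry : final_layer_sizes) {
    assert(entry.second > 0);  // an empty layer would be a caller error
    std::vector<float> values(entry.second);
    for (float& v : values) v = dist(engine);
    weights[entry.first] = std::move(values);
  }
  return weights;
}
```

Calling the helper twice with the same seed yields identical maps, which is the property a `random_seed` argument is usually meant to provide.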
turi::object_detection::DarknetYOLOCheckpoint::DarknetYOLOCheckpoint (Config config, neural_net::float_array_map weights)
Loads weights saved from a DarknetYOLOModelTrainer.
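std::unique_ptr< ModelTrainer > turi::object_detection::DarknetYOLOCheckpoint::CreateModelTrainer (neural_net::compute_context *context) const [override, virtual]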
Loads the checkpoint into an active ModelTrainer instance.
Implements turi::object_detection::Checkpoint.
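The relationship between a checkpoint and the trainers it creates is a factory-method arrangement, which can be sketched with simplified stand-in types. Everything below is hypothetical (`SketchCheckpoint`, `SketchTrainer`, `ComputeContext`, `WeightMap`); it mirrors only the shape of `CreateModelTrainer`, not the actual Turi Create classes.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins; none of these are the Turi Create types.
using WeightMap = std::map<std::string, std::vector<float>>;
struct ComputeContext {};

// Abstract trainer: owns mutable training state derived from a checkpoint.
class SketchTrainer {
 public:
  virtual ~SketchTrainer() = default;
  virtual const WeightMap& weights() const = 0;
};

// Abstract checkpoint: an immutable snapshot that can spawn trainers.
class SketchCheckpoint {
 public:
  virtual ~SketchCheckpoint() = default;
  virtual std::unique_ptr<SketchTrainer> CreateModelTrainer(
      ComputeContext* context) const = 0;
};

class YoloSketchTrainer : public SketchTrainer {
 public:
  YoloSketchTrainer(WeightMap weights, ComputeContext* context)
      : weights_(std::move(weights)), context_(context) {}
  const WeightMap& weights() const override { return weights_; }
  ComputeContext* context() const { return context_; }

 private:
  WeightMap weights_;
  ComputeContext* context_;
};

class YoloSketchCheckpoint : public SketchCheckpoint {
 public:
  explicit YoloSketchCheckpoint(WeightMap weights)
      : weights_(std::move(weights)) {
    assert(!weights_.empty());  // a snapshot with no weights is a caller error
  }

  // Each call copies the stored weights into a fresh, independent trainer,
  // so the checkpoint itself is never modified.
  std::unique_ptr<SketchTrainer> CreateModelTrainer(
      ComputeContext* context) const override {
    return std::make_unique<YoloSketchTrainer>(weights_, context);
  }

 private:
  WeightMap weights_;
};
```

Because the method is `const` and returns a `std::unique_ptr`, the caller owns the new trainer outright and the same checkpoint can be reused to start several training sessions.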
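neural_net::pipeline_spec turi::object_detection::DarknetYOLOCheckpoint::ExportToCoreML (const std::string &input_name, const std::string &coordinates_name, const std::string &confidence_name, bool use_nms_layer, float iou_threshold, float confidence_threshold) const [override, virtual]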
Returns the CoreML spec corresponding to the current model.
The result must be a pipeline that accepts an image input and yields at least two outputs, with the input and outputs named as given. The outputs must be suitable for passing directly into a NonMaximumSuppression model.
Implements turi::object_detection::Checkpoint.
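To make the roles of `iou_threshold` and `confidence_threshold` concrete, here is a generic greedy non-maximum suppression sketch. It is not the Core ML NonMaximumSuppression implementation and not code from Turi Create: `Box`, `IoU`, and `GreedyNms` are hypothetical names, and each box carries a single confidence value rather than per-class scores, which is a simplifying assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical box type: corner coordinates plus one confidence score.
struct Box {
  float x_min, y_min, x_max, y_max;
  float confidence;
};

// Intersection over union of two axis-aligned boxes, in [0, 1].
float IoU(const Box& a, const Box& b) {
  float iw = std::max(0.0f, std::min(a.x_max, b.x_max) - std::max(a.x_min, b.x_min));
  float ih = std::max(0.0f, std::min(a.y_max, b.y_max) - std::max(a.y_min, b.y_min));
  float inter = iw * ih;
  float area_a = (a.x_max - a.x_min) * (a.y_max - a.y_min);
  float area_b = (b.x_max - b.x_min) * (b.y_max - b.y_min);
  float uni = area_a + area_b - inter;
  return uni > 0.0f ? inter / uni : 0.0f;
}

// Greedy NMS: drop boxes below confidence_threshold, then keep boxes in
// descending confidence order, discarding any box whose IoU with an
// already-kept box exceeds iou_threshold.
std::vector<Box> GreedyNms(std::vector<Box> boxes, float iou_threshold,
                           float confidence_threshold) {
  assert(iou_threshold >= 0.0f && iou_threshold <= 1.0f);
  boxes.erase(std::remove_if(boxes.begin(), boxes.end(),
                             [&](const Box& b) {
                               return b.confidence < confidence_threshold;
                             }),
              boxes.end());
  std::sort(boxes.begin(), boxes.end(), [](const Box& a, const Box& b) {
    return a.confidence > b.confidence;
  });
  std::vector<Box> kept;
  for (const Box& candidate : boxes) {
    bool suppressed = false;
    for (const Box& k : kept) {
      if (IoU(candidate, k) > iou_threshold) {
        suppressed = true;
        break;
      }
    }
    if (!suppressed) kept.push_back(candidate);
  }
  return kept;
}
```

In this sketch, a lower `iou_threshold` suppresses overlapping detections more aggressively, while a higher `confidence_threshold` discards more low-scoring boxes before suppression begins.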
neural_net::float_array_map turi::object_detection::DarknetYOLOCheckpoint::internal_config () const
Returns the config dictionary used to initialize darknet-yolo backends.
neural_net::float_array_map turi::object_detection::DarknetYOLOCheckpoint::internal_weights () const
Returns the weights with the keys expected by the backends.