Turi Create 4.0
turi::object_detection::ModelTrainer Class Reference (abstract)

#include <toolkits/object_detection/od_model_trainer.hpp>

Public Member Functions

virtual std::shared_ptr< neural_net::Publisher< TrainingOutputBatch > > AsTrainingBatchPublisher (std::unique_ptr< data_iterator > training_data, size_t batch_size, int offset)
 
virtual std::shared_ptr< neural_net::Publisher< EncodedBatch > > AsInferenceBatchPublisher (std::unique_ptr< data_iterator > test_data, size_t batch_size, float confidence_threshold, float iou_threshold)=0
 
virtual InferenceOutputBatch DecodeOutputBatch (EncodedBatch batch, float confidence_threshold, float iou_threshold)=0
 
virtual std::shared_ptr< neural_net::Publisher< std::unique_ptr< Checkpoint > > > AsCheckpointPublisher ()=0
 

Detailed Description

Abstract base class for object-detection model trainers.

Responsible for constructing the model-agnostic portions of the overall training pipeline.

Definition at line 252 of file od_model_trainer.hpp.

Member Function Documentation

◆ AsCheckpointPublisher()

virtual std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint> > > turi::object_detection::ModelTrainer::AsCheckpointPublisher ( )
pure virtual

Returns a publisher that can be used to request checkpoints.

Implemented in turi::object_detection::DarknetYOLOModelTrainer.

◆ AsInferenceBatchPublisher()

virtual std::shared_ptr<neural_net::Publisher<EncodedBatch> > turi::object_detection::ModelTrainer::AsInferenceBatchPublisher ( std::unique_ptr< data_iterator > test_data,
size_t batch_size,
float confidence_threshold,
float iou_threshold
)
pure virtual

Given a data iterator, return a publisher of inference model outputs.

Implemented in turi::object_detection::DarknetYOLOModelTrainer.

◆ AsTrainingBatchPublisher()

virtual std::shared_ptr<neural_net::Publisher<TrainingOutputBatch> > turi::object_detection::ModelTrainer::AsTrainingBatchPublisher ( std::unique_ptr< data_iterator > training_data,
size_t batch_size,
int offset
)
virtual

Given a data iterator, return a publisher of model outputs.

Reimplemented in turi::object_detection::DarknetYOLOModelTrainer.

◆ DecodeOutputBatch()

virtual InferenceOutputBatch turi::object_detection::ModelTrainer::DecodeOutputBatch ( EncodedBatch batch,
float confidence_threshold,
float iou_threshold
)
pure virtual

Convert the raw output of the inference batch publisher into structured predictions.

Implemented in turi::object_detection::DarknetYOLOModelTrainer.


The documentation for this class was generated from the following file: