Turi Create 4.0
turi::style_transfer::ModelTrainer Class Reference (abstract)

#include <toolkits/style_transfer/st_model_trainer.hpp>

Public Member Functions

virtual bool SupportsLossComponents () const =0
 
virtual std::shared_ptr< neural_net::Publisher< TrainingProgress > > AsTrainingBatchPublisher (std::unique_ptr< data_iterator > training_data, const std::string &vgg_mlmodel_path, int offset, std::unique_ptr< float > initial_training_loss, neural_net::compute_context *context)
 
virtual std::shared_ptr< neural_net::Publisher< DataBatch > > AsInferenceBatchPublisher (std::unique_ptr< data_iterator > test_data, std::vector< int > style_indices, neural_net::compute_context *context)
 
virtual std::shared_ptr< neural_net::Publisher< std::unique_ptr< Checkpoint > > > AsCheckpointPublisher ()=0
 

Detailed Description

Abstract base class for style-transfer model trainers.

Responsible for constructing the model-agnostic portions of the overall training pipeline.

Definition at line 203 of file st_model_trainer.hpp.

Member Function Documentation

◆ AsCheckpointPublisher()

virtual std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint> > > turi::style_transfer::ModelTrainer::AsCheckpointPublisher ( )
pure virtual

Returns a publisher that can be used to request checkpoints.

Implemented in turi::style_transfer::ResNet16ModelTrainer.

◆ AsInferenceBatchPublisher()

virtual std::shared_ptr<neural_net::Publisher<DataBatch> > turi::style_transfer::ModelTrainer::AsInferenceBatchPublisher ( std::unique_ptr< data_iterator >  test_data,
std::vector< int >  style_indices,
neural_net::compute_context *  context 
)
virtual

Given a data iterator, return a publisher of inference model outputs.

◆ AsTrainingBatchPublisher()

virtual std::shared_ptr<neural_net::Publisher<TrainingProgress> > turi::style_transfer::ModelTrainer::AsTrainingBatchPublisher ( std::unique_ptr< data_iterator >  training_data,
const std::string &  vgg_mlmodel_path,
int  offset,
std::unique_ptr< float >  initial_training_loss,
neural_net::compute_context *  context 
)
virtual

Given a data iterator, return a publisher of training model outputs.

◆ SupportsLossComponents()

virtual bool turi::style_transfer::ModelTrainer::SupportsLossComponents ( ) const
pure virtual

Returns true if and only if the output from the training batch publisher (see AsTrainingBatchPublisher) sets the style_loss and content_loss values.

Implemented in turi::style_transfer::ResNet16ModelTrainer.


The documentation for this class was generated from the following file: st_model_trainer.hpp