Turi Create 4.0
turi::style_transfer::ResNet16ModelTrainer Class Reference

#include <toolkits/style_transfer/st_resnet16_model_trainer.hpp>

Public Member Functions

 ResNet16ModelTrainer (Config config, neural_net::float_array_map weights)
 
bool SupportsLossComponents () const override
 
std::shared_ptr< neural_net::Publisher< std::unique_ptr< Checkpoint > > > AsCheckpointPublisher () override
 
virtual std::shared_ptr< neural_net::Publisher< TrainingProgress > > AsTrainingBatchPublisher (std::unique_ptr< data_iterator > training_data, const std::string &vgg_mlmodel_path, int offset, std::unique_ptr< float > initial_training_loss, neural_net::compute_context *context)
 
virtual std::shared_ptr< neural_net::Publisher< DataBatch > > AsInferenceBatchPublisher (std::unique_ptr< data_iterator > test_data, std::vector< int > style_indices, neural_net::compute_context *context)
 

Detailed Description

Subclass of ModelTrainer encapsulating the resnet-16 architecture.

Definition at line 35 of file st_resnet16_model_trainer.hpp.
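As a rough illustration of the subclassing relationship, the sketch below uses hypothetical, heavily simplified stand-ins for ModelTrainer and ResNet16ModelTrainer. The method ArchitectureName() and the factory MakeTrainer() are invented for this example and are not part of the Turi Create API. The real interface returns neural_net::Publisher objects, which are omitted here.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Hypothetical, simplified stand-in for the abstract ModelTrainer interface.
class ModelTrainer {
 public:
  virtual ~ModelTrainer() = default;
  // Hypothetical method, used only to show virtual dispatch.
  virtual std::string ArchitectureName() const = 0;
};

// Hypothetical stand-in for the concrete subclass documented on this page.
class ResNet16ModelTrainer : public ModelTrainer {
 public:
  std::string ArchitectureName() const override { return "resnet-16"; }
};

// Hypothetical factory: callers hold only a base-class pointer, so the
// architecture is chosen once here and every later call dispatches virtually.
std::unique_ptr<ModelTrainer> MakeTrainer() {
  return std::make_unique<ResNet16ModelTrainer>();
}
```

Keeping callers on the base-class pointer is what lets another architecture be swapped in by changing only the factory.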

Constructor & Destructor Documentation

◆ ResNet16ModelTrainer()

turi::style_transfer::ResNet16ModelTrainer::ResNet16ModelTrainer ( Config config,
neural_net::float_array_map weights
)

Initializes a model from a checkpoint.
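The constructor takes both arguments by value. A common idiom for such "sink" parameters is to move them into members, so callers that pass temporaries (or use std::move) never copy the weights. The sketch below assumes that idiom. Config, FloatArrayMap (a stand-in for neural_net::float_array_map) and ResNet16ModelTrainerSketch are hypothetical simplifications; the actual constructor body is not shown on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for the documented parameter types.
struct Config {
  int num_styles = 1;  // hypothetical field
};
// Assumption: weights are named float arrays; plain vectors are used here.
using FloatArrayMap = std::map<std::string, std::vector<float>>;

class ResNet16ModelTrainerSketch {
 public:
  // Arguments are taken by value (as in the documented signature) and then
  // moved into members, avoiding a second copy of potentially large arrays.
  ResNet16ModelTrainerSketch(Config config, FloatArrayMap weights)
      : config_(std::move(config)), weights_(std::move(weights)) {}

  const Config& config() const { return config_; }

  // Number of named weight arrays restored from the checkpoint.
  std::size_t num_weight_arrays() const { return weights_.size(); }

 private:
  Config config_;
  FloatArrayMap weights_;
};
```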

Member Function Documentation

◆ AsCheckpointPublisher()

std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint> > > turi::style_transfer::ResNet16ModelTrainer::AsCheckpointPublisher ( )
[override] [virtual]

Returns a publisher that can be used to request checkpoints.

Implements turi::style_transfer::ModelTrainer.
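A publisher that "can be used to request checkpoints" suggests demand-driven delivery: a checkpoint is produced when a subscriber asks for one, so it reflects the trainer's state at request time rather than at the time the publisher was created. The sketch below assumes that reading. PullPublisher, TrainerState and this Checkpoint struct are hypothetical; the real neural_net::Publisher has its own subscription interface, which is not reproduced here.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <utility>

// Hypothetical, minimal pull-based publisher: nothing is produced until a
// caller requests it. Not the real neural_net::Publisher interface.
template <typename T>
class PullPublisher {
 public:
  explicit PullPublisher(std::function<T()> produce)
      : produce_(std::move(produce)) {}

  // Produces one fresh value per request; nothing is computed in advance.
  T Request() { return produce_(); }

 private:
  std::function<T()> produce_;
};

// Hypothetical checkpoint: just a snapshot of the iteration count.
struct Checkpoint {
  int iteration;
};

class TrainerState {
 public:
  void Step() { ++iteration_; }

  // Each request snapshots the current state. Capturing `this` assumes the
  // trainer outlives the publisher it hands out.
  std::shared_ptr<PullPublisher<std::unique_ptr<Checkpoint>>> AsCheckpointPublisher() {
    return std::make_shared<PullPublisher<std::unique_ptr<Checkpoint>>>(
        [this] { return std::make_unique<Checkpoint>(Checkpoint{iteration_}); });
  }

 private:
  int iteration_ = 0;
};
```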

◆ AsInferenceBatchPublisher()

virtual std::shared_ptr<neural_net::Publisher<DataBatch> > turi::style_transfer::ModelTrainer::AsInferenceBatchPublisher ( std::unique_ptr< data_iterator > test_data,
std::vector< int > style_indices,
neural_net::compute_context *context
)
[virtual] [inherited]

Given a data iterator over the test data and a list of style indices, returns a publisher of inference model outputs.
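This reference does not specify how style_indices combine with the input batches. Assuming each input image is stylized once per requested style, the hypothetical sketch below pairs every image with every style index, so N images and S styles yield N * S outputs. For brevity it collects batches eagerly into a vector, whereas the documented method returns a lazy publisher. DataIterator, DataBatch and RunInference are illustrative stand-ins, not the library types.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical output batch: image_ids[i] was rendered with style_ids[i].
struct DataBatch {
  std::vector<int> image_ids;
  std::vector<int> style_ids;
};

// Hypothetical iterator yielding fixed-size batches of image ids.
class DataIterator {
 public:
  DataIterator(std::vector<int> ids, std::size_t batch_size)
      : ids_(std::move(ids)), batch_size_(batch_size) {
    assert(batch_size_ > 0);
  }

  bool HasNext() const { return pos_ < ids_.size(); }

  std::vector<int> Next() {
    std::size_t end = std::min(pos_ + batch_size_, ids_.size());
    std::vector<int> batch(ids_.begin() + static_cast<std::ptrdiff_t>(pos_),
                           ids_.begin() + static_cast<std::ptrdiff_t>(end));
    pos_ = end;
    return batch;
  }

 private:
  std::vector<int> ids_;
  std::size_t batch_size_;
  std::size_t pos_ = 0;
};

// Assumption: every image in a batch is paired with every requested style.
std::vector<DataBatch> RunInference(DataIterator& it,
                                    const std::vector<int>& style_indices) {
  std::vector<DataBatch> outputs;
  while (it.HasNext()) {
    DataBatch batch;
    for (int image : it.Next()) {
      for (int style : style_indices) {
        batch.image_ids.push_back(image);
        batch.style_ids.push_back(style);
      }
    }
    outputs.push_back(std::move(batch));
  }
  return outputs;
}
```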

◆ AsTrainingBatchPublisher()

virtual std::shared_ptr<neural_net::Publisher<TrainingProgress> > turi::style_transfer::ModelTrainer::AsTrainingBatchPublisher ( std::unique_ptr< data_iterator > training_data,
const std::string & vgg_mlmodel_path,
int offset,
std::unique_ptr< float > initial_training_loss,
neural_net::compute_context *context
)
[virtual] [inherited]

Given a data iterator over the training data, returns a publisher of training model outputs, reported as TrainingProgress values.
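The roles of offset and initial_training_loss are not described here. One plausible reading, assumed in the sketch below, is that offset is the iteration to resume from and that a non-null initial_training_loss seeds a smoothed (exponentially averaged) loss, while nullptr starts the average fresh. ProgressTracker, its decay constant of 0.9, and this TrainingProgress struct are hypothetical.

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical, simplified progress record.
struct TrainingProgress {
  int iteration_id;
  float smoothed_loss;
};

// Assumed semantics: `offset` is the last completed iteration, and a
// non-null `initial_training_loss` seeds the running average of the loss.
class ProgressTracker {
 public:
  ProgressTracker(int offset, std::unique_ptr<float> initial_training_loss)
      : iteration_(offset), smoothed_loss_(std::move(initial_training_loss)) {}

  TrainingProgress Record(float batch_loss) {
    ++iteration_;
    if (!smoothed_loss_) {
      // No prior history: start the average at the first observed loss.
      smoothed_loss_ = std::make_unique<float>(batch_loss);
    } else {
      // Exponential moving average of the per-batch loss.
      *smoothed_loss_ = kDecay * *smoothed_loss_ + (1.0f - kDecay) * batch_loss;
    }
    return {iteration_, *smoothed_loss_};
  }

 private:
  static constexpr float kDecay = 0.9f;  // hypothetical smoothing factor
  int iteration_;
  std::unique_ptr<float> smoothed_loss_;
};
```

Passing the seed as a nullable std::unique_ptr<float> (as the documented signature does) is one way to express "optional value"; std::optional<float> would be the C++17 alternative.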

◆ SupportsLossComponents()

bool turi::style_transfer::ResNet16ModelTrainer::SupportsLossComponents ( ) const
[inline] [override] [virtual]

Returns true iff the TrainingProgress values emitted by the training batch publisher have their style_loss and content_loss values set.

Implements turi::style_transfer::ModelTrainer.

Definition at line 42 of file st_resnet16_model_trainer.hpp.
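Callers that report progress can use SupportsLossComponents() to decide whether the style and content losses are meaningful. The sketch below, built on a hypothetical TrainingProgress struct and FormatProgress helper (neither taken from the library), prints the loss breakdown only when the flag is set.

```cpp
#include <cassert>
#include <string>

// Hypothetical progress record; the component losses are only meaningful
// when the trainer reports SupportsLossComponents() == true.
struct TrainingProgress {
  int iteration_id = 0;
  float smoothed_loss = 0.0f;
  float style_loss = 0.0f;    // set only by trainers that support it
  float content_loss = 0.0f;  // set only by trainers that support it
};

// Formats one progress line, appending the loss breakdown only if the
// trainer promises to populate it.
std::string FormatProgress(const TrainingProgress& p,
                           bool supports_loss_components) {
  std::string line = "iter " + std::to_string(p.iteration_id) +
                     " loss " + std::to_string(p.smoothed_loss);
  if (supports_loss_components) {
    line += " style " + std::to_string(p.style_loss) +
            " content " + std::to_string(p.content_loss);
  }
  return line;
}
```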


The documentation for this class was generated from the following file: