Turi Create
4.0
#include <toolkits/style_transfer/st_model_trainer.hpp>
Public Member Functions
- virtual bool SupportsLossComponents() const = 0
- virtual std::shared_ptr<neural_net::Publisher<TrainingProgress>> AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data, const std::string& vgg_mlmodel_path, int offset, std::unique_ptr<float> initial_training_loss, neural_net::compute_context* context)
- virtual std::shared_ptr<neural_net::Publisher<DataBatch>> AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data, std::vector<int> style_indices, neural_net::compute_context* context)
- virtual std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>> AsCheckpointPublisher() = 0
Abstract base class for style-transfer model trainers.
Responsible for constructing the model-agnostic portions of the overall training pipeline.
Definition at line 203 of file st_model_trainer.hpp.
Member Function Documentation

AsCheckpointPublisher()  [pure virtual]
Returns a publisher that can be used to request checkpoints.
Implemented in turi::style_transfer::ResNet16ModelTrainer.
AsInferenceBatchPublisher(test_data, style_indices, context)  [virtual]
Given a data iterator, returns a publisher of inference model outputs.
AsTrainingBatchPublisher(training_data, vgg_mlmodel_path, offset, initial_training_loss, context)  [virtual]
Given a data iterator, returns a publisher of training model outputs.
SupportsLossComponents() const  [pure virtual]
Returns true iff the output from the training batch publisher sets the style_loss and content_loss values.
Implemented in turi::style_transfer::ResNet16ModelTrainer.