Turi Create
4.0
#include <toolkits/style_transfer/st_resnet16_model_trainer.hpp>
Public Member Functions
ResNet16ModelTrainer(Config config, neural_net::float_array_map weights)
bool SupportsLossComponents() const override
std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>> AsCheckpointPublisher() override
virtual std::shared_ptr<neural_net::Publisher<TrainingProgress>> AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data, const std::string &vgg_mlmodel_path, int offset, std::unique_ptr<float> initial_training_loss, neural_net::compute_context *context)
virtual std::shared_ptr<neural_net::Publisher<DataBatch>> AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data, std::vector<int> style_indices, neural_net::compute_context *context)
Subclass of ModelTrainer that encapsulates the ResNet-16 architecture.
Definition at line 35 of file st_resnet16_model_trainer.hpp.
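The member list mixes functions that ResNet16ModelTrainer defines itself (marked `override`, documented as "Implements turi::style_transfer::ModelTrainer") with virtual functions it inherits unchanged from ModelTrainer. The sketch below shows that C++ pattern with hypothetical, heavily simplified stand-ins: `DescribeBatches` and `MakeResNet16Trainer` are invented for illustration, and none of this is the Turi Create declaration.

```cpp
#include <memory>
#include <string>

// Hypothetical, simplified stand-ins for the hierarchy on this page;
// these are NOT the Turi Create declarations.
class ModelTrainer {
 public:
  virtual ~ModelTrainer() = default;

  // Pure virtual (documented as "Implements" in the subclass):
  // every concrete trainer must define it.
  virtual bool SupportsLossComponents() const = 0;

  // Virtual with a base-class body: subclasses inherit it unless they
  // override it. Stands in for the inherited batch-publisher factories.
  virtual std::string DescribeBatches() const { return "base-class batches"; }
};

class ResNet16ModelTrainer : public ModelTrainer {
 public:
  // Return value chosen for illustration only; this page does not state
  // what the real override returns.
  bool SupportsLossComponents() const override { return true; }
};

// Hypothetical factory: callers hold the concrete trainer through the
// base-class interface, so virtual dispatch picks the right override.
std::unique_ptr<ModelTrainer> MakeResNet16Trainer() {
  return std::make_unique<ResNet16ModelTrainer>();
}
```

Calling `SupportsLossComponents()` through the returned base pointer reaches the subclass override, while `DescribeBatches()` falls through to the base-class body, mirroring how the inherited batch publishers are reached.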
turi::style_transfer::ResNet16ModelTrainer::ResNet16ModelTrainer(Config config, neural_net::float_array_map weights)
Initializes a model from a checkpoint.
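std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>> turi::style_transfer::ResNet16ModelTrainer::AsCheckpointPublisher() [override, virtual]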
Returns a publisher that can be used to request checkpoints.
Implements turi::style_transfer::ModelTrainer.
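AsCheckpointPublisher() returns a publisher from which checkpoints are requested rather than pushed. A minimal sketch of that request-driven idea follows, using hypothetical stand-ins: `WeightMap` in place of `neural_net::float_array_map`, a simplified `Checkpoint`, and a `CheckpointPublisher` with a `Request` method. It is not the `neural_net::Publisher` interface, which this page does not describe.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for neural_net::float_array_map: layer name -> weights.
using WeightMap = std::map<std::string, std::vector<float>>;

// Hypothetical checkpoint: a snapshot of the weights at some iteration.
struct Checkpoint {
  int iteration = 0;
  WeightMap weights;
};

// Hypothetical pull-style publisher: nothing is produced until a consumer
// requests it, and each request yields exactly `demand` checkpoints.
class CheckpointPublisher {
 public:
  using Producer = std::function<std::unique_ptr<Checkpoint>()>;
  using Consumer = std::function<void(std::unique_ptr<Checkpoint>)>;

  explicit CheckpointPublisher(Producer produce) : produce_(std::move(produce)) {}

  // Produces `demand` checkpoints on request and hands each to `on_checkpoint`.
  void Request(int demand, const Consumer& on_checkpoint) {
    assert(demand > 0);
    for (int i = 0; i < demand; ++i) {
      on_checkpoint(produce_());
    }
  }

 private:
  Producer produce_;
};
```

In this sketch a consumer calls `Request(1, ...)` whenever it wants a snapshot, for example at the end of an epoch. Whether each snapshot copies the weights is up to the producer; a producer that copies the live weights (as in the usage exercised here) keeps earlier checkpoints unaffected by later training updates.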
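std::shared_ptr<neural_net::Publisher<DataBatch>> turi::style_transfer::ModelTrainer::AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data, std::vector<int> style_indices, neural_net::compute_context *context) [virtual, inherited]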
Given a data iterator, return a publisher of inference model outputs.
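std::shared_ptr<neural_net::Publisher<TrainingProgress>> turi::style_transfer::ModelTrainer::AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data, const std::string &vgg_mlmodel_path, int offset, std::unique_ptr<float> initial_training_loss, neural_net::compute_context *context) [virtual, inherited]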
Given a data iterator, return a publisher of training model outputs.
bool turi::style_transfer::ResNet16ModelTrainer::SupportsLossComponents() const [inline, override, virtual]
Returns true iff the output from the training batch publisher sets the style_loss and content_loss values.
Implements turi::style_transfer::ModelTrainer.
Definition at line 42 of file st_resnet16_model_trainer.hpp.
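A caller can use SupportsLossComponents() to decide whether to report the style and content losses alongside the total. The sketch below assumes a hypothetical `ProgressRow` struct and a hypothetical `FormatProgress` helper; the real `TrainingProgress` layout is not documented on this page, so the field names are illustrative only.

```cpp
#include <cstdio>
#include <string>

// Hypothetical per-batch progress record; field names are assumptions,
// not the real TrainingProgress layout.
struct ProgressRow {
  int iteration = 0;
  float loss = 0.0f;          // total training loss
  float style_loss = 0.0f;    // meaningful only if loss components are supported
  float content_loss = 0.0f;  // meaningful only if loss components are supported
};

// Hypothetical helper: formats one progress line and includes the
// style/content breakdown only when the trainer reports that its training
// output sets those values.
std::string FormatProgress(const ProgressRow& row, bool supports_loss_components) {
  char buf[128];
  if (supports_loss_components) {
    std::snprintf(buf, sizeof(buf), "iter %d  loss %.3f  style %.3f  content %.3f",
                  row.iteration, row.loss, row.style_loss, row.content_loss);
  } else {
    std::snprintf(buf, sizeof(buf), "iter %d  loss %.3f", row.iteration, row.loss);
  }
  return std::string(buf);
}
```

Gating on the flag rather than on the field values avoids printing zeros that merely mean "not set" as if they were real loss components.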