8 #ifndef TOOLKITS_STYLE_TRANSFER_ST_RESNET16_MODEL_TRAINER_HPP_ 9 #define TOOLKITS_STYLE_TRANSFER_ST_RESNET16_MODEL_TRAINER_HPP_ 14 namespace style_transfer {
45 std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
46 AsCheckpointPublisher()
override;
49 std::shared_ptr<neural_net::model_backend> CreateTrainingBackend(
50 const std::string& vgg_mlmodel_path,
53 std::shared_ptr<neural_net::model_backend> CreateInferenceBackend(
59 std::shared_ptr<neural_net::model_backend> training_backend;
62 neural_net::float_array_map weights;
65 static neural_net::float_array_map GetWeights(
const ModelState& state);
68 std::shared_ptr<ModelState> state_;
74 #endif // TOOLKITS_STYLE_TRANSFER_ST_RESNET16_MODEL_TRAINER_HPP_
neural_net::model_spec ExportToCoreML() const override
std::unique_ptr< ModelTrainer > CreateModelTrainer() const override
bool SupportsLossComponents() const override
ResNet16Checkpoint(Config config, const std::string &resnet_mlmodel_path)