8 #ifndef TOOLKITS_STYLE_TRANSFER_ST_MODEL_TRAINER_HPP_ 9 #define TOOLKITS_STYLE_TRANSFER_ST_MODEL_TRAINER_HPP_ 19 #include <ml/neural_net/compute_context.hpp> 20 #include <ml/neural_net/model_spec.hpp> 21 #include <toolkits/style_transfer/style_transfer_data_iterator.hpp> 24 namespace style_transfer {
38 std::vector<st_example> examples;
51 neural_net::float_array_map encoded_data;
63 float smoothed_loss = 0.f;
68 float style_loss = 0.f;
69 float content_loss = 0.f;
82 int max_iterations = -1;
88 int training_image_height = 256;
91 int training_image_width = 256;
113 : impl_(
std::move(impl)),
114 batch_size_(batch_size),
115 last_iteration_id_(offset) {}
117 bool HasNext()
const override {
return impl_->has_next_batch(); }
122 std::unique_ptr<data_iterator> impl_;
123 size_t batch_size_ = 1;
124 int last_iteration_id_ = 0;
134 std::vector<int> style_indices);
136 bool HasNext()
const override;
140 std::shared_ptr<DataIterator> base_iterator_;
141 std::vector<int> style_indices_;
142 std::vector<int>::const_iterator next_style_;
153 : smoothed_loss_(std::move(smoothed_loss)) {}
158 std::unique_ptr<float> smoothed_loss_;
169 : config_(std::move(config)), weights_(std::move(weights)) {}
173 const Config& config()
const {
return config_; }
174 const neural_net::float_array_map& weights()
const {
return weights_; }
177 virtual std::unique_ptr<ModelTrainer> CreateModelTrainer()
const = 0;
189 static neural_net::float_array_map ExtractWeights(
190 std::unique_ptr<neural_net::model_spec> nn_spec);
194 neural_net::float_array_map weights_;
209 const Config& config()
const {
return config_; }
215 virtual bool SupportsLossComponents()
const = 0;
218 virtual std::shared_ptr<neural_net::Publisher<TrainingProgress>>
219 AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data,
220 const std::string& vgg_mlmodel_path,
int offset,
221 std::unique_ptr<float> initial_training_loss,
225 virtual std::shared_ptr<neural_net::Publisher<DataBatch>>
226 AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data,
227 std::vector<int> style_indices,
231 virtual std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
232 AsCheckpointPublisher() = 0;
237 virtual std::shared_ptr<neural_net::model_backend> CreateTrainingBackend(
238 const std::string& vgg_mlmodel_path,
240 virtual std::shared_ptr<neural_net::model_backend> CreateInferenceBackend(
265 #endif // TOOLKITS_STYLE_TRANSFER_ST_MODEL_TRAINER_HPP_
EncodedInferenceBatch EncodeInferenceBatch(DataBatch batch)
EncodedBatch EncodeTrainingBatch(DataBatch batch, int width, int height)
DataIterator(std::unique_ptr< data_iterator > impl, size_t batch_size, int offset=0)
DataBatch DecodeInferenceBatch(EncodedInferenceBatch batch)
bool HasNext() const override