// st_model_trainer.hpp — Turi Create 4.0
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef TOOLKITS_STYLE_TRANSFER_ST_MODEL_TRAINER_HPP_
9 #define TOOLKITS_STYLE_TRANSFER_ST_MODEL_TRAINER_HPP_
10 
11 /**
12  * \file st_model_trainer.hpp
13  *
14  * Defines the value types representing each stage of a style-transfer training
15  * pipeline, and the virtual interface for arbitrary style-transfer models.
16  */
17 
#include <ml/neural_net/combine.hpp>
#include <ml/neural_net/compute_context.hpp>
#include <ml/neural_net/model_spec.hpp>
#include <toolkits/style_transfer/style_transfer_data_iterator.hpp>
22 
23 namespace turi {
24 namespace style_transfer {
25 
26 class ModelTrainer;
27 
/**
 * Represents one batch of content/style image pairs.
 *
 * Also used for inference, in which case the "style" image is the stylized
 * output.
 */
struct DataBatch {
  /** The serial number for this batch, starting with 1. */
  int iteration_id = 0;

  /** The content/style image pairs comprising this batch. */
  std::vector<st_example> examples;
};
40 
/**
 * Represents the immediate (model-specific) input or output of a model backend,
 * using the generic float_array_map representation.
 *
 * \todo Define types for input and output batches that don't rely on the
 * arbitrary keys in float_array_map.
 */
struct EncodedBatch {
  /** The serial number of the DataBatch this encoding corresponds to. */
  int iteration_id = 0;

  /** Backend tensors, keyed by backend-specific string names. */
  neural_net::float_array_map encoded_data;
};
53 
54 /** EncodedBatch that also records the style index used for inference. */
56  int style_index = -1;
57 };
58 
/** Represents the output conveyed to the user. */
struct TrainingProgress {
  /** The serial number of the training iteration this progress describes. */
  int iteration_id = 0;

  /** Training loss, smoothed across iterations (see ProgressUpdater). */
  float smoothed_loss = 0.f;

  // These are only set if the ModelTrainer returns true for
  // SupportsLossComponents().
  // TODO: Should these also be smoothed?
  float style_loss = 0.f;
  float content_loss = 0.f;
};
71 
/** Model-agnostic parameters for style transfer. */
struct Config {
  /** Determines the number of style images used during training. */
  int num_styles = 1;

  /**
   * The target number of training iterations to perform.
   *
   * If -1, then this target should be computed heuristically.
   */
  int max_iterations = -1;

  /** The number of images to process per training batch. */
  int batch_size = 1;

  /** The height of images passed into the training backend. */
  int training_image_height = 256;

  /** The width of images passed into the training backend. */
  int training_image_width = 256;

  /** Random seed used to initialize the model. */
  int random_seed = 0;
};
96 
/**
 * Wrapper adapting style_transfer::data_iterator to the Iterator interface.
 */
class DataIterator : public neural_net::Iterator<DataBatch> {
 public:
  /**
   * \param impl The style_transfer::data_iterator to wrap
   * \param batch_size The number of images to request from impl for each batch.
   * \param offset The number of batches to skip. The first batch produced will
   *     have an iteration_id one more than the offset.
   *
   * \todo style_transfer::data_iterator needs to support specifying the
   *     offset (and doing the right thing with random seeding)
   */
  DataIterator(std::unique_ptr<data_iterator> impl, size_t batch_size,
               int offset = 0)
      : impl_(std::move(impl)),
        batch_size_(batch_size),
        last_iteration_id_(offset) {}

  /** True while the wrapped iterator still has batches to produce. */
  bool HasNext() const override { return impl_->has_next_batch(); }

  /** Returns the next batch (defined out of line). */
  DataBatch Next() override;

 private:
  std::unique_ptr<data_iterator> impl_;
  size_t batch_size_ = 1;
  int last_iteration_id_ = 0;  // Next ID starts at 1, not 0, by default.
};
126 
/**
 * Wrapper around DataIterator that duplicates each batch, with each duplicate
 * writing a different style index into every example for each duplicate.
 */
class InferenceDataIterator : public neural_net::Iterator<DataBatch> {
 public:
  /**
   * \param base_iterator The source of batches to duplicate.
   * \param style_indices The style indices to apply to each base batch.
   */
  InferenceDataIterator(std::shared_ptr<DataIterator> base_iterator,
                        std::vector<int> style_indices);

  bool HasNext() const override;
  DataBatch Next() override;

 private:
  std::shared_ptr<DataIterator> base_iterator_;
  std::vector<int> style_indices_;
  std::vector<int>::const_iterator next_style_;  // Position in style_indices_.
  DataBatch current_batch_;  // Base batch currently being duplicated.
};
145 
146 /**
147  * Converts raw training output to user-visible progress updates.
148  */
150  : public neural_net::Transform<EncodedBatch, TrainingProgress> {
151  public:
152  ProgressUpdater(std::unique_ptr<float> smoothed_loss)
153  : smoothed_loss_(std::move(smoothed_loss)) {}
154 
155  TrainingProgress Invoke(EncodedBatch output_batch) override;
156 
157  private:
158  std::unique_ptr<float> smoothed_loss_;
159 };
160 
/**
 * A representation of all the parameters needed to reconstruct a model.
 *
 * \todo Include optimizer state to allow training to resume seamlessly.
 */
class Checkpoint {
 public:
  Checkpoint(Config config, neural_net::float_array_map weights)
      : config_(std::move(config)), weights_(std::move(weights)) {}

  virtual ~Checkpoint() = default;

  /** The model-agnostic configuration stored with this checkpoint. */
  const Config& config() const { return config_; }

  /** The raw model weights stored with this checkpoint. */
  const neural_net::float_array_map& weights() const { return weights_; }

  /** Loads the checkpoint into an active ModelTrainer instance. */
  virtual std::unique_ptr<ModelTrainer> CreateModelTrainer() const = 0;

  /**
   * Returns the CoreML spec corresponding to the current model.
   *
   * The first layer of the model should have a single input: the image to
   * stylize. The last layer of the model should have a single output: the
   * stylized image.
   */
  virtual neural_net::model_spec ExportToCoreML() const = 0;

 protected:
  /** Extracts the raw weights from a model spec, consuming the spec. */
  static neural_net::float_array_map ExtractWeights(
      std::unique_ptr<neural_net::model_spec> nn_spec);

 private:
  Config config_;
  neural_net::float_array_map weights_;
};
196 
197 /**
198  * Abstract base class for style-transfer model trainers.
199  *
200  * Responsible for constructing the model-agnostic portions of the overall
201  * training pipeline.
202  */
204  public:
205  ModelTrainer(Config config) : config_(std::move(config)) {}
206 
207  virtual ~ModelTrainer() = default;
208 
209  const Config& config() const { return config_; }
210 
211  /**
212  * Returns true iff the output from the training batch publisher sets the
213  * style_loss and content_loss values.
214  */
215  virtual bool SupportsLossComponents() const = 0;
216 
217  /** Given a data iterator, return a publisher of training model outputs. */
218  virtual std::shared_ptr<neural_net::Publisher<TrainingProgress>>
219  AsTrainingBatchPublisher(std::unique_ptr<data_iterator> training_data,
220  const std::string& vgg_mlmodel_path, int offset,
221  std::unique_ptr<float> initial_training_loss,
222  neural_net::compute_context* context);
223 
224  /** Given a data iterator, return a publisher of inference model outputs. */
225  virtual std::shared_ptr<neural_net::Publisher<DataBatch>>
226  AsInferenceBatchPublisher(std::unique_ptr<data_iterator> test_data,
227  std::vector<int> style_indices,
228  neural_net::compute_context* context);
229 
230  /** Returns a publisher that can be used to request checkpoints. */
231  virtual std::shared_ptr<neural_net::Publisher<std::unique_ptr<Checkpoint>>>
232  AsCheckpointPublisher() = 0;
233 
234  protected:
235  // TODO: Style transfer backends should support both training and inference.
236  // Then we would only need one.
237  virtual std::shared_ptr<neural_net::model_backend> CreateTrainingBackend(
238  const std::string& vgg_mlmodel_path,
239  neural_net::compute_context* context) = 0;
240  virtual std::shared_ptr<neural_net::model_backend> CreateInferenceBackend(
241  neural_net::compute_context* context) = 0;
242 
243  private:
244  Config config_;
245 };
246 
/**
 * Converts native images into tensors that can be fed into the model backend.
 *
 * \param batch The content/style image pairs to encode.
 * \param width Target image width for the training backend.
 * \param height Target image height for the training backend.
 */
EncodedBatch EncodeTrainingBatch(DataBatch batch, int width, int height);
251 
252 /**
253  * Converts native images into tensors that can be fed into the model backend.
254  */
256 
257 /**
258  * Converts the raw output from an inference backend into images.
259  */
261 
262 } // namespace style_transfer
263 } // namespace turi
264 
265 #endif // TOOLKITS_STYLE_TRANSFER_ST_MODEL_TRAINER_HPP_