// Turi Create 4.0 -- style_transfer.hpp
/* Copyright © 2019 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at
 * https://opensource.org/licenses/BSD-3-Clause
 */
7 
8 #ifndef __TOOLKITS_STYLE_TRANSFER_H_
9 #define __TOOLKITS_STYLE_TRANSFER_H_
10 
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <core/data/sframe/gl_sarray.hpp>
#include <core/data/sframe/gl_sframe.hpp>
#include <core/logging/table_printer/table_printer.hpp>
#include <ml/neural_net/compute_context.hpp>
#include <ml/neural_net/model_backend.hpp>
#include <ml/neural_net/model_spec.hpp>
#include <model_server/lib/extensions/ml_model.hpp>
#include <toolkits/coreml_export/mlmodel_wrapper.hpp>
#include <toolkits/coreml_export/neural_net_models_exporter.hpp>
#include <toolkits/style_transfer/st_model_trainer.hpp>
#include <toolkits/style_transfer/style_transfer_data_iterator.hpp>
24 
25 namespace turi {
26 namespace style_transfer {
27 
28 // TODO: Move these helper functions to st_model_trainer.cpp
29 std::vector<std::pair<flex_int, flex_image>> process_output(
30  const neural_net::shared_float_array& contents, size_t index);
31 neural_net::float_array_map prepare_batch(const std::vector<st_example>& batch,
32  size_t width, size_t height,
33  bool train = true);
34 neural_net::float_array_map prepare_predict(const st_example& example);
35 
36 class EXPORT style_transfer : public ml_model_base {
37  public:
38  void init_options(const std::map<std::string, flexible_type>& opts) override;
39  size_t get_version() const override;
40  void save_impl(oarchive& oarc) const override;
41  void load_version(iarchive& iarc, size_t version) override;
42 
43  virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
44  std::string filename, std::string short_description,
45  std::map<std::string, flexible_type> additional_user_defined,
46  std::map<std::string, flexible_type> opts);
47 
48  void train(gl_sarray style, gl_sarray content,
49  std::map<std::string, flexible_type> opts);
50 
51  virtual void init_training(gl_sarray style, gl_sarray content,
52  std::map<std::string, flexible_type> opts);
53  virtual void resume_training(gl_sarray style, gl_sarray content,
54  std::map<std::string, flexible_type> opts);
55 
56  virtual void iterate_training();
57  virtual void synchronize_training();
58  virtual void finalize_training();
59 
60  gl_sframe predict(variant_type data,
61  std::map<std::string, flexible_type> opts);
62 
63  gl_sframe get_styles(variant_type style_index);
64 
65  void import_from_custom_model(variant_map_type model_data, size_t version);
66 
67  BEGIN_CLASS_MEMBER_REGISTRATION("style_transfer")
68  IMPORT_BASE_CLASS_REGISTRATION(ml_model_base);
69 
70  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::train, "style", "content",
71  "opts");
72 
74  style_transfer::train,
75  "\n"
76  "Options\n"
77  "-------\n"
78  "resnet_mlmodel_path : string\n"
79  " Path to the Resnet CoreML specification with the pre-trained model\n"
80  " parameters.\n"
81  "vgg_mlmodel_path: string\n"
82  " Path to the VGG16 CoreML specification with the pre-trained model\n"
83  " parameters.\n"
84  "num_styles: int\n"
85  " The defined number of styles for the style transfer model\n"
86  "batch_size : int\n"
87  " The number of images per training iteration. If 0, then it will be\n"
88  " automatically determined based on resource availability.\n"
89  "max_iterations : int\n"
90  " The number of training iterations. If 0, then it will be "
91  "automatically\n"
92  " be determined based on the amount of data you provide.\n"
93  "image_width : int\n"
94  " The input image width to the model\n"
95  "image_height : int\n"
96  " The input image height to the model\n");
97 
98  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::init_training, "style",
99  "content", "opts");
100  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::resume_training, "style",
101  "content", "opts");
102  register_defaults("resume_training",
103  {{"opts",
104  to_variant(std::map<std::string, flexible_type>())}});
105 
106  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::iterate_training);
107  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::synchronize_training);
108  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::finalize_training);
109 
110  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::export_to_coreml, "filename",
111  "short_description", "additional_user_defined", "options");
112  register_defaults("export_to_coreml",
113  {{"short_description", ""},
114  {"additional_user_defined", to_variant(std::map<std::string, flexible_type>())},
115  {"options", to_variant(std::map<std::string, flexible_type>())}});
116 
117  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::predict, "data", "options");
118 
119  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::import_from_custom_model,
120  "model_data", "version");
121 
122  REGISTER_CLASS_MEMBER_FUNCTION(style_transfer::get_styles, "style_index");
123  register_defaults("get_styles", {{"style_index", FLEX_UNDEFINED}});
124 
126 
127  protected:
128  // Override points allowing subclasses to inject dependencies
129 
130  // Factory for data_iterator
131  virtual std::unique_ptr<data_iterator> create_iterator(
132  data_iterator::parameters iterator_params) const;
133 
134  std::unique_ptr<data_iterator> create_iterator(gl_sarray content,
135  gl_sarray style, bool repeat,
136  bool training,
137  int random_seed) const;
138 
139  // Factory for compute_context
140  virtual std::unique_ptr<neural_net::compute_context> create_compute_context()
141  const;
142 
143  // Factories for Checkpoint
144  virtual std::unique_ptr<Checkpoint> load_checkpoint(
145  neural_net::float_array_map weights) const;
146 
147  virtual std::unique_ptr<Checkpoint> create_checkpoint(
148  Config config, const std::string& resnet_model_path) const;
149 
150  // Establishes training pipelines from the backend.
151  void connect_trainer(gl_sarray style, gl_sarray content,
152  const std::string& vgg_mlmodel_path,
153  std::unique_ptr<neural_net::compute_context> context);
154 
155  void perform_predict(gl_sarray images, gl_sframe_writer& result,
156  const std::vector<int>& style_idx, bool verbose);
157 
158  // Synchronously loads weights from the backend if necessary
159  const Checkpoint& read_checkpoint() const;
160 
161  Config get_config() const;
162 
163  template <typename T>
164  T read_state(const std::string& key) const {
165  return variant_get_value<T>(get_state().at(key));
166  }
167 
168  template <typename T>
169  typename std::map<std::string, T>::iterator _read_iter_opts(
170  std::map<std::string, T>& opts, const std::string& key) const {
171  auto iter = opts.find(key);
172  if (iter == opts.end())
173  log_and_throw("Expected option \"" + key + "\" not found.");
174  return iter;
175  }
176 
177  template <typename T>
178  T read_opts(std::map<std::string, turi::variant_type>& opts,
179  const std::string& key) const {
180  auto iter = _read_iter_opts<turi::variant_type>(opts, key);
181  return variant_get_value<T>(iter->second);
182  }
183 
184  template <typename T>
185  T read_opts(std::map<std::string, turi::flexible_type>& opts,
186  const std::string& key) const {
187  auto iter = _read_iter_opts<turi::flexible_type>(opts, key);
188  return iter->second.get<T>();
189  }
190 
191  private:
192  // Primary representation for the trained model. Can be null if the model has
193  // been updated since the last checkpoint.
194  mutable std::unique_ptr<Checkpoint> checkpoint_;
195 
196  // Primary dependencies for training. These should be nonnull while training
197  // is in progress.
198  std::shared_ptr<ModelTrainer> model_trainer_;
199  std::shared_ptr<neural_net::FuturesStream<TrainingProgress>>
200  training_futures_;
201  std::shared_ptr<neural_net::FuturesStream<std::unique_ptr<Checkpoint>>>
202  checkpoint_futures_;
203 
204  std::unique_ptr<neural_net::model_spec> m_resnet_spec;
205  std::unique_ptr<neural_net::model_spec> m_vgg_spec;
206 
207  std::unique_ptr<data_iterator> m_training_data_iterator;
208  std::unique_ptr<neural_net::compute_context> m_training_compute_context;
209  std::unique_ptr<neural_net::model_backend> m_training_model;
210 
211  std::unique_ptr<table_printer> training_table_printer_;
212 
213  static gl_sarray convert_types_to_sarray(const variant_type& data);
214 
215  /**
216  * convert_style_indices_to_filter
217  *
218  * This function takes a variant type and converts it into a boolean filter.
219  * The elements at the indices we want to keep are set to a value of `1`, the
220  * elements we don't want to keep are set to a value of `0`.
221  */
222  gl_sarray convert_style_indices_to_filter(const variant_type& data);
223  gl_sframe style_sframe_with_index(gl_sarray styles);
224 
225  flex_int get_max_iterations() const;
226  flex_int get_training_iterations() const;
227  flex_int get_num_classes() const;
228 
229  void infer_derived_options(neural_net::compute_context* context);
230 };
231 
232 } // namespace style_transfer
233 } // namespace turi
234 
235 #endif // __TOOLKITS_STYLE_TRANSFER_H_
/* The lines below appear to be Doxygen cross-reference tooltips captured with
 * this page; they are not part of the header. They show the definitions of
 * names used above:
 *
 * #define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
 * #define REGISTER_CLASS_MEMBER_DOCSTRING(name, docstring)
 * #define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
 * #define IMPORT_BASE_CLASS_REGISTRATION(base_class)
 * boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
 *   Definition: variant.hpp:24
 * #define END_CLASS_MEMBER_REGISTRATION
 * variant_type to_variant(const T &f)
 *   Definition: variant.hpp:308
 * static flexible_type FLEX_UNDEFINED
 */