8 #ifndef __TOOLKITS_STYLE_TRANSFER_H_ 9 #define __TOOLKITS_STYLE_TRANSFER_H_ 13 #include <core/data/sframe/gl_sarray.hpp> 14 #include <core/data/sframe/gl_sframe.hpp> 15 #include <core/logging/table_printer/table_printer.hpp> 16 #include <ml/neural_net/compute_context.hpp> 17 #include <ml/neural_net/model_backend.hpp> 18 #include <ml/neural_net/model_spec.hpp> 19 #include <model_server/lib/extensions/ml_model.hpp> 20 #include <toolkits/coreml_export/mlmodel_wrapper.hpp> 21 #include <toolkits/coreml_export/neural_net_models_exporter.hpp> 23 #include <toolkits/style_transfer/style_transfer_data_iterator.hpp> 26 namespace style_transfer {
29 std::vector<std::pair<flex_int, flex_image>> process_output(
30 const neural_net::shared_float_array& contents,
size_t index);
31 neural_net::float_array_map prepare_batch(
const std::vector<st_example>& batch,
32 size_t width,
size_t height,
34 neural_net::float_array_map prepare_predict(
const st_example& example);
36 class EXPORT style_transfer :
public ml_model_base {
38 void init_options(
const std::map<std::string, flexible_type>& opts)
override;
39 size_t get_version()
const override;
40 void save_impl(oarchive& oarc)
const override;
41 void load_version(iarchive& iarc,
size_t version)
override;
43 virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
44 std::string filename, std::string short_description,
45 std::map<std::string, flexible_type> additional_user_defined,
46 std::map<std::string, flexible_type> opts);
48 void train(gl_sarray style, gl_sarray content,
49 std::map<std::string, flexible_type> opts);
51 virtual void init_training(gl_sarray style, gl_sarray content,
52 std::map<std::string, flexible_type> opts);
53 virtual void resume_training(gl_sarray style, gl_sarray content,
54 std::map<std::string, flexible_type> opts);
56 virtual void iterate_training();
57 virtual void synchronize_training();
58 virtual void finalize_training();
61 std::map<std::string, flexible_type> opts);
65 void import_from_custom_model(variant_map_type model_data,
size_t version);
74 style_transfer::train,
78 "resnet_mlmodel_path :
string\n"
79 " Path to the Resnet
CoreML specification with the pre-trained model\n"
81 "vgg_mlmodel_path:
string\n"
82 " Path to the VGG16
CoreML specification with the pre-trained model\n"
85 " The defined number of styles for the style transfer model\n"
87 " The number of images per training iteration. If 0, then it will be\n"
88 " automatically determined based on resource availability.\n"
89 "max_iterations :
int\n"
90 " The number of training iterations. If 0, then it will be "
92 " be determined based on the amount of data you provide.\n"
94 " The input image width to the model\n"
95 "image_height :
int\n"
96 " The input image height to the model\n");
102 register_defaults("resume_training",
104 to_variant(std::map<std::string, flexible_type>())}});
111 "short_description",
"additional_user_defined",
"options");
112 register_defaults(
"export_to_coreml",
113 {{
"short_description",
""},
114 {
"additional_user_defined",
to_variant(std::map<std::string, flexible_type>())},
115 {
"options",
to_variant(std::map<std::string, flexible_type>())}});
120 "model_data",
"version");
123 register_defaults(
"get_styles", {{
"style_index",
FLEX_UNDEFINED}});
131 virtual std::unique_ptr<data_iterator> create_iterator(
132 data_iterator::parameters iterator_params)
const;
134 std::unique_ptr<data_iterator> create_iterator(gl_sarray content,
135 gl_sarray style,
bool repeat,
137 int random_seed)
const;
140 virtual std::unique_ptr<neural_net::compute_context> create_compute_context()
144 virtual std::unique_ptr<Checkpoint> load_checkpoint(
145 neural_net::float_array_map weights)
const;
147 virtual std::unique_ptr<Checkpoint> create_checkpoint(
148 Config config,
const std::string& resnet_model_path)
const;
151 void connect_trainer(gl_sarray style, gl_sarray content,
152 const std::string& vgg_mlmodel_path,
153 std::unique_ptr<neural_net::compute_context> context);
155 void perform_predict(gl_sarray images, gl_sframe_writer& result,
156 const std::vector<int>& style_idx,
bool verbose);
159 const Checkpoint& read_checkpoint()
const;
161 Config get_config()
const;
163 template <
typename T>
164 T read_state(
const std::string& key)
const {
165 return variant_get_value<T>(get_state().at(key));
168 template <
typename T>
169 typename std::map<std::string, T>::iterator _read_iter_opts(
170 std::map<std::string, T>& opts,
const std::string& key)
const {
171 auto iter = opts.find(key);
172 if (iter == opts.end())
173 log_and_throw(
"Expected option \"" + key +
"\" not found.");
177 template <
typename T>
178 T read_opts(std::map<std::string, turi::variant_type>& opts,
179 const std::string& key)
const {
180 auto iter = _read_iter_opts<turi::variant_type>(opts, key);
181 return variant_get_value<T>(iter->second);
184 template <
typename T>
185 T read_opts(std::map<std::string, turi::flexible_type>& opts,
186 const std::string& key)
const {
187 auto iter = _read_iter_opts<turi::flexible_type>(opts, key);
188 return iter->second.get<T>();
194 mutable std::unique_ptr<Checkpoint> checkpoint_;
198 std::shared_ptr<ModelTrainer> model_trainer_;
199 std::shared_ptr<neural_net::FuturesStream<TrainingProgress>>
201 std::shared_ptr<neural_net::FuturesStream<std::unique_ptr<Checkpoint>>>
204 std::unique_ptr<neural_net::model_spec> m_resnet_spec;
205 std::unique_ptr<neural_net::model_spec> m_vgg_spec;
207 std::unique_ptr<data_iterator> m_training_data_iterator;
208 std::unique_ptr<neural_net::compute_context> m_training_compute_context;
209 std::unique_ptr<neural_net::model_backend> m_training_model;
211 std::unique_ptr<table_printer> training_table_printer_;
213 static gl_sarray convert_types_to_sarray(
const variant_type& data);
222 gl_sarray convert_style_indices_to_filter(
const variant_type& data);
223 gl_sframe style_sframe_with_index(gl_sarray styles);
225 flex_int get_max_iterations()
const;
226 flex_int get_training_iterations()
const;
229 void infer_derived_options(neural_net::compute_context* context);
235 #endif // __TOOLKITS_STYLE_TRANSFER_H_
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_DOCSTRING(name, docstring)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
#define END_CLASS_MEMBER_REGISTRATION
variant_type to_variant(const T &f)
static flexible_type FLEX_UNDEFINED