7 #ifndef TURI_DRAWING_CLASSIFIER_H_ 8 #define TURI_DRAWING_CLASSIFIER_H_ 10 #include <core/data/sframe/gl_sframe.hpp> 11 #include <core/logging/table_printer/table_printer.hpp> 12 #include <ml/neural_net/compute_context.hpp> 13 #include <ml/neural_net/model_backend.hpp> 14 #include <ml/neural_net/model_spec.hpp> 15 #include <model_server/lib/extensions/ml_model.hpp> 16 #include <model_server/lib/variant.hpp> 17 #include <toolkits/coreml_export/mlmodel_wrapper.hpp> 18 #include <toolkits/coreml_export/neural_net_models_exporter.hpp> 19 #include <toolkits/drawing_classifier/dc_data_iterator.hpp> 22 namespace drawing_classifier {
// drawing_classifier: a model class derived from ml_model_base. The member
// declarations below cover versioned (de)serialization, training, inference,
// evaluation and Core ML export.
// NOTE(review): this copy of the header is an extraction with dropped lines
// (comments, access specifiers, some parameters) and stray line-number tokens
// such as "24 "; it does not compile as shown.
24 class EXPORT drawing_classifier :
public ml_model_base {
// Serialization version; its value is defined outside this view (presumably
// in the .cpp file).
26 static const size_t DRAWING_CLASSIFIER_VERSION;
// Defaulted constructor: every std::unique_ptr member starts out null.
28 drawing_classifier() =
default;
// Overrides from the model base hierarchy: option initialization, version
// reporting, and archive save/load (load_version receives the stored version
// number alongside the input archive).
34 void init_options(
const std::map<std::string, flexible_type>& opts)
override;
36 size_t get_version()
const override;
38 void save_impl(oarchive& oarc)
const override;
40 void load_version(iarchive& iarc,
size_t version)
override;
// Builds the model from a variant map of model data plus a version number
// (presumably for models produced outside this toolkit -- confirm in the .cpp).
42 void import_from_custom_model(variant_map_type model_data,
size_t version);
// Trains on `data`, with target_column_name naming the label column and
// feature_column_name naming the drawing column. validation_data is a
// variant; the default registered for "train" is the string "auto".
48 void train(gl_sframe data, std::string target_column_name,
49 std::string feature_column_name,
variant_type validation_data,
50 std::map<std::string, flexible_type> opts);
// Predictions for the drawings in `data`, returned as an SArray.
// NOTE(review): the C++ default output_type here is "probability", but the
// default registered for "predict" is "class" and the predict docstring lists
// {"class", "probability_vector"}; confirm which values the implementation
// accepts and align the defaults.
52 gl_sarray predict(gl_sframe data, std::string output_type =
"probability");
// Top-k predictions as an SFrame; the registered docstring documents
// output_type as "probability" or "rank". k defaults to 5.
54 gl_sframe predict_topk(gl_sframe data,
55 std::string output_type =
"probability",
size_t k = 5);
// Computes the requested metric(s) on `data` and returns them as a variant map.
57 variant_map_type evaluate(gl_sframe data, std::string metric);
// Exports the model as Core ML, wrapped in an MLModelWrapper.
// NOTE(review): use_default_spec appears not to be among the argument names
// registered for the Python-facing export_to_coreml, so it is presumably only
// reachable from C++.
59 std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
60 std::string filename, std::string short_description,
61 const std::map<std::string, flexible_type>& additional_user_defined,
62 bool use_default_spec =
false);
// Training entry points split into setup and per-iteration steps (presumably
// init_training is called once, then iterate_training repeatedly). The
// declaration line between the two init_training lines below is missing from
// this view.
66 virtual void init_training(gl_sframe data, std::string target_column_name,
67 std::string feature_column_name,
69 std::map<std::string, flexible_type> opts);
71 virtual void iterate_training(
bool show_loss);
// What appears to be the tail of the member-function registration for train
// (its macro head is not visible): the Python-facing argument names, in order.
78 "target_column_name", "feature_column_name",
79 "validation_data", "options");
// Registered defaults for train: validation_data = "auto". The key of the
// second entry (an empty option map) is missing from this view -- presumably
// "options".
81 register_defaults("train",
82 {{
"validation_data",
to_variant(std::string(
"auto"))},
84 to_variant(std::map<std::string, flexible_type>())}});
// Docstring registered for drawing_classifier::train. The registration macro
// head and some literal lines are missing from this view. Adjacent string
// literals are concatenated, so each entry must end with "\n"; otherwise the
// next parameter name is glued onto the previous sentence.
// NOTE(review): "> 100 sessions" and "sequence chunks" read like wording carried
// over from another toolkit's docstring; confirm against this model's behavior.
87 drawing_classifier::train,
90 " Input data, which consists of columns named by the\n"
91 " feature_column_name and target_column_name parameters, used for\n"
92 " training the Drawing Classifier.\n"
93 "target_column_name : string\n"
94 " Name of the column containing the target variable. The values in"
95 " this column must be of string type.\n"
96 "feature_column_name : string\n"
97 " Name of the column containing the input drawings.\n"
98 " The feature column can contain either bitmap-based drawings or\n"
99 " stroke-based drawings. Bitmap-based drawing input can be a\n"
100 " grayscale tc.Image of any size.\n"
102 " Stroke-based drawing input must be in the following format:\n"
103 " Every drawing must be represented by a list of strokes, where each\n"
104 " stroke must be a list of points in the order in which they were\n"
105 " drawn on the canvas.\n"
107 " Each point must be a dictionary with two keys,\n"
108 " \"x\" and \"y\", and their\n"
109 " respective values must be numerical, i.e. either integer or float.\n"
110 "validation_data : SFrame or string\n"
111 " A dataset for monitoring the model's generalization performance to\n"
112 " prevent the model from overfitting to the training data.\n"
114 " For each row of the progress table, accuracy is measured over the\n"
115 " provided training dataset and the `validation_data`. The format of\n"
116 " this SFrame must be the same as the training set.\n"
118 " When set to 'auto', a validation set is automatically sampled from "
120 " training data (if the training data has > 100 sessions).\n"
125 "max_iterations : int\n"
126 " Maximum number of iterations/epochs made over the data during the\n"
127 " training phase. The default is 500 iterations.\n"
129 " Number of sequence chunks used per training step. Must be greater "
131 " the number of GPUs in use. The default is 32.\n"
132 "random_seed : int\n"
133 " The given seed is used for random weight initialization and\n"
134 " sampling during training.\n");
// Registered default for predict: output_type = "class".
139 register_defaults(
"predict", {{
"output_type", std::string(
"class")}});
// Docstring for drawing_classifier::predict. The macro head and several literal
// lines, including the closing of this registration, are missing from this view.
// NOTE(review): the C++ predict declaration defaults output_type to
// "probability", which is not one of the documented {"class",
// "probability_vector"} values; one literal below also has a stray space
// before its "\n".
142 drawing_classifier::predict,
145 " The drawing(s) on which to perform drawing classification.\n" 146 " If dataset is an SFrame, it must have a column with the same name\n" 147 " as the feature column during training. Additional columns are\n" 149 " If the data is a single drawing, it can be either of type\n" 150 " tc.Image, in which case it is a bitmap-based drawing input,\n" 151 " or of type list, in which case it is a stroke-based drawing input.\n" 152 "output_type : {\"class\", \"probability_vector\"}, optional\n" 153 " Form of each prediction which is one of:\n" 154 " - \"probability_vector\": Prediction probability associated with \n" 155 " each class as a vector. The probability of first class (sorted\n" 156 " alphanumerically by name of the class in the training set) is in\n" 157 " position 0 of the vector, the second in position 1 and so on.\n" 158 " - \"class\": Class prediction. This returns the class with maximum\n" 164 register_defaults(
// (register_defaults continued) predict_topk default: output_type = "probability".
"predict_topk",
165 {{
"output_type", std::string(
"probability")}});
// Docstring registered for drawing_classifier::predict_topk. The macro head and
// the "k : int" heading line are missing from this view. Each literal ends with
// "\n" so that parameter entries render on separate lines; without it,
// "output_type" would be glued onto "Additional columns are ignored.".
168 drawing_classifier::predict_topk,
171 " Dataset of new observations.\n"
172 " SFrame must include columns with the same\n"
173 " names as the features used for model training, but does not\n"
174 " require a target column. Additional columns are ignored.\n"
175 "output_type : {\"probability\", \"rank\"}, optional\n"
176 " Form of each prediction which is one of:\n"
177 " - \"probability\": Probability associated with each label in the\n"
179 " - \"rank\": Rank associated with each label in the prediction.\n"
181 " Number of classes to return for each input example.\n");
// Registered default for evaluate: metric = "auto" (all available metrics).
186 register_defaults(
"evaluate", {{
"metric", std::string(
"auto")}});
// Docstring for drawing_classifier::evaluate. The macro head and its trailing
// lines are missing from this view. Every metric listed here (accuracy, auc,
// precision, recall, f1_score, log_loss, confusion_matrix, roc_curve) compares
// predictions with true labels, so the dataset must contain the target column.
// The old text claimed the target column was not required.
189 drawing_classifier::evaluate,
192 " Dataset of new observations. Must include columns with the same\n"
193 " names as the feature and target columns used for model training.\n"
194 " Additional columns are ignored.\n"
195 "metric : str, optional\n"
196 " Name of the evaluation metric. Possible values are:\n"
197 " - 'auto' : Returns all available metrics\n"
198 " - 'accuracy' : Classification accuracy (micro average)\n"
199 " - 'auc' : Area under the ROC curve (macro average)\n"
200 " - 'precision' : Precision score (macro average)\n"
201 " - 'recall' : Recall score (macro average)\n"
202 " - 'f1_score' : F1 score (macro average)\n"
203 " - 'log_loss' : Log loss\n"
204 " - 'confusion_matrix' : An SFrame with counts of possible\n"
205 " prediction/true label combinations.\n"
206 " - 'roc_curve' : An SFrame containing information needed for\n"
// The argument names below appear to belong to the export_to_coreml
// registration; its macro head is not visible in this view.
210 "filename",
"short_description",
"additional_user_defined");
// Registered defaults for export_to_coreml: short_description = "" and
// additional_user_defined = an empty map. No default is registered for
// filename here, so it is presumably required from Python.
211 register_defaults(
"export_to_coreml",
212 {{
"short_description",
""},
213 {
"additional_user_defined",
to_variant(std::map<std::string, flexible_type>())}});
// Argument names for what appears to be the init_training registration (macro
// head not visible in this view).
216 "target_column_name",
"feature_column_name",
217 "validation_data",
"options");
// Registered defaults for init_training: validation_data = an empty gl_sframe.
// This differs from train, whose registered default is "auto". The second
// entry's key is missing from this view (presumably "options"); it defaults to
// an empty map.
218 register_defaults(
"init_training",
219 {{
"validation_data",
to_variant(gl_sframe())},
221 to_variant(std::map<std::string, flexible_type>())}});
// Argument names for what appears to be the import_from_custom_model
// registration; they match that method's (model_data, version) parameters.
226 "model_data",
"version");
// Parameter list and body of a constructor that receives pre-built
// collaborators instead of creating them. Its head line is not visible;
// presumably it exists so tests can inject fakes.
// It takes ownership of the model spec, compute context, training data iterator
// and backend (each is passed as std::unique_ptr by value and moved into a
// member), then seeds model state from initial_state.
// The mem-initializers follow the members' declaration order (nn_spec_,
// training_data_iterator_, training_compute_context_, training_model_). That
// order differs from the parameter order but matches initialization order.
234 const std::map<std::string, variant_type>& initial_state,
235 std::unique_ptr<neural_net::model_spec> nn_spec,
236 std::unique_ptr<neural_net::compute_context> training_compute_context,
237 std::unique_ptr<data_iterator> training_data_iterator,
238 std::unique_ptr<neural_net::model_backend> training_model)
239 : nn_spec_(
std::move(nn_spec)),
240 training_data_iterator_(
std::move(training_data_iterator)),
241 training_compute_context_(
std::move(training_compute_context)),
242 training_model_(
std::move(training_model)) {
// Merge the supplied key/value pairs into the model's state.
243 add_or_update_state(initial_state);
// Overridable (virtual) factory and helper hooks. Presumably they are virtual so
// tests can substitute fakes for the iterator, compute context and model.
// create_iterator: builds a data iterator from the given iterator parameters.
247 virtual std::unique_ptr<data_iterator> create_iterator(
248 data_iterator::parameters iterator_params)
const;
// create_compute_context: the end of this declaration is missing from this view.
251 virtual std::unique_ptr<neural_net::compute_context> create_compute_context()
// init_model: builds the network spec. use_random_init presumably selects
// random weight initialization; confirm in the .cpp.
255 virtual std::unique_ptr<neural_net::model_spec> init_model(
256 bool use_random_init)
const;
// init_data: returns a pair of SFrames, presumably (training, validation).
// Its parameter lines are missing from this view.
258 virtual std::tuple<gl_sframe, gl_sframe> init_data(
// compute_validation_metrics: returns two floats (presumably a loss and an
// accuracy; confirm in the .cpp).
261 virtual std::tuple<float, float> compute_validation_metrics(
262 size_t num_classes,
size_t batch_size);
// init_table_printer: sets up the training progress table. The flags presumably
// control whether validation and loss columns are shown.
264 virtual void init_table_printer(
bool has_validation,
bool show_loss);
// Reads the model-state entry `key` and converts it to T via
// variant_get_value<T>. If the lookup throws std::out_of_range (presumably a
// missing key in the map returned by get_state()), the handler rethrows
// std::out_of_range with the original message plus the offending key.
// NOTE(review): the `try {` line and the closing braces are missing from this
// view. std::stringstream needs <sstream>, which is not among the visible
// includes. std::endl only adds newlines to the message ('\n' would do). The
// .c_str() on the thrown message is unnecessary.
266 template <
typename T>
267 T read_state(
const std::string& key)
const {
269 return variant_get_value<T>(get_state().at(key));
270 }
catch (
const std::out_of_range& e) {
// Preserve the underlying message and add which key was requested.
271 std::stringstream ss;
272 ss << e.what() << std::endl;
273 ss <<
"from read_state for '" << key <<
"'" << std::endl;
274 throw std::out_of_range(ss.str().c_str());
// Test helper: returns a new model_spec constructed from
// nn_spec_->get_coreml_spec(), so the caller gets its own object instead of
// the model's spec.
// NOTE(review): nn_spec_ is dereferenced without a null check, and it is null
// on a default-constructed instance. std::make_unique would avoid the raw `new`
// if the project builds as C++14 or later. The closing brace is not visible here.
278 std::unique_ptr<neural_net::model_spec> clone_model_spec_for_test()
const {
280 return std::unique_ptr<neural_net::model_spec>(
281 new neural_net::model_spec(nn_spec_->get_coreml_spec()));
// Internal helpers; several parameter lines are missing from this view.
// perform_inference: presumably runs the network over the batches produced by
// `data` (a non-owning pointer) and collects the outputs into an SFrame.
288 gl_sframe perform_inference(data_iterator* data)
const;
// get_predictions_class: presumably maps per-class probabilities to class labels.
289 get_predictions_class(
const gl_sarray& predictions_prob,
// create_iterator overload: builds an iterator over `data`; is_train presumably
// enables training-mode behavior such as shuffling. Confirm in the .cpp.
297 std::unique_ptr<data_iterator> create_iterator(gl_sframe data,
bool is_train,
// Owned state. nn_spec_ holds the network spec. The remaining members hold
// training-related state: the training/validation SFrames and their iterators,
// the compute context, the backend model and the progress-table printer.
// All std::unique_ptr members are null until assigned.
// NOTE(review): member declarations at original lines 302-304 and 311 are
// missing from this view.
301 std::unique_ptr<neural_net::model_spec> nn_spec_;
305 gl_sframe training_data_;
306 gl_sframe validation_data_;
307 std::unique_ptr<data_iterator> training_data_iterator_;
308 std::unique_ptr<data_iterator> validation_data_iterator_;
309 std::unique_ptr<neural_net::compute_context> training_compute_context_;
310 std::unique_ptr<neural_net::model_backend> training_model_;
312 std::unique_ptr<table_printer> training_table_printer_;
// NOTE(review): the include guard ends on the next line. Everything after it
// (the registration macro signatures and the variant_type / to_variant /
// flex_list declarations) looks like extraction residue from a documentation
// index rather than part of this header. Remove it if this file is restored as
// real source.
318 #endif // TURI_DRAWING_CLASSIFIER_H_ #define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_DOCSTRING(name, docstring)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
#define END_CLASS_MEMBER_REGISTRATION
variant_type to_variant(const T &f)
std::vector< flexible_type > flex_list