Turi Create  4.0
ml_model.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_UNITY_ML_MODEL_HPP
7 #define TURI_UNITY_ML_MODEL_HPP
8 
9 #include <model_server/lib/variant.hpp>
10 #include <model_server/lib/unity_base_types.hpp>
11 #include <model_server/lib/toolkit_util.hpp>
12 #include <model_server/lib/toolkit_function_specification.hpp>
13 #include <model_server/lib/toolkit_class_macros.hpp>
14 
15 #include <model_server/lib/extensions/model_base.hpp>
16 #include <model_server/lib/extensions/option_manager.hpp>
17 
18 #include <core/export.hpp>
19 
20 namespace turi {
21 
/**
 * ml_model model base class.
 * ---------------------------------------
 *
 * Base class for handling machine learning models. This class is meant to
 * be a guide to aid model writing and not a hard and fast rule of how the
 * code must be structured.
 *
 * Each machine learning C++ toolkit contains the following:
 *
 * *) state: This is the key-value map that stores the "model" attributes.
 *    The value is of type "variant_type", which is fully interfaced
 *    with Python. You can add basic types, vectors, SFrames, etc.
 *
 *
 * *) options: Option manager which keeps track of default options, current
 *    options, option ranges, types, etc. It must be initialized only
 *    once, in the init_options() function.
 *
 *
 * Functions that should always be implemented. Here are some notes about
 * each of these functions that may help guide you in writing your model.
 *
 * *) name: Get the name of this model. You might think that this is silly,
 *    but the name holds the key to everything. The unity_server can construct
 *    model_base objects, and they can be cast to a model of this type.
 *    The name determines how the casting happens. The init_models()
 *    function in unity_server.cpp will give you an idea of how
 *    this interface works.
 *
 * *) save: Save the model with the turicreate oarc (output archive). Turi is
 *    a server-client module. DO NOT SAVE ANYTHING on the client side. Make
 *    sure that everything is on the server side. For example, you might be
 *    tempted to keep options that the user provides on the client side, but
 *    DO NOT do that, because save and load will break things for you!
 *
 * *) load: Load the model with the turicreate iarc (input archive).
 *
 * *) version: Get the version of this model.
 *
 *
 */
64 class EXPORT ml_model_base: public model_base {
65 
66  public:
67 
68  static constexpr size_t ML_MODEL_BASE_VERSION = 0;
69 
70  // virtual destructor
71  inline virtual ~ml_model_base() { }
72 
73  /**
74  * Set one of the options in the algorithm. Use the option manager to set
75  * these options. If the option does not satisfy the conditions that the
76  * option manager has imposed on it. Errors will be thrown.
77  *
78  * \param[in] options Options to set
79  */
80  virtual void init_options(const std::map<std::string,flexible_type>& _options) {};
81 
82 
83  /**
84  * Methods with already meaningful default implementations.
85  * -------------------------------------------------------------------------
86  */
87 
88 
89  /**
90  * Lists all the keys accessible in the "model" map.
91  *
92  * \returns List of keys in the model map.
93  * \ref model_base for details.
94  *
95  * Python side interface
96  * ------------------------
97  *
98  * This is the function that the list_fields should call in python.
99  */
100  std::vector<std::string> list_fields();
101 
102 
103  /**
104  * Returns the value of a particular key from the state.
105  *
106  * \returns Value of a key
107  * \ref model_base for details.
108  *
109  * Python side interface
110  * ------------------------
111  *
112  * From the python side, this is interfaced with the get() function or the
113  * [] operator in python.
114  *
115  */
116  const variant_type& get_value_from_state(std::string key);
117 
118 
119  /**
120  * Get current options.
121  *
122  * \returns Dictionary containing current options.
123  *
124  * Python side interface
125  * ------------------------
126  * Interfaces with the get_current_options function in the Python side.
127  */
128  const std::map<std::string, flexible_type>& get_current_options() const;
129 
130  /**
131  * Get default options.
132  *
133  * \returns Dictionary with default options.
134  *
135  * Python side interface
136  * ------------------------
137  * Interfaces with the get_default_options function in the Python side.
138  */
139  std::map<std::string, flexible_type> get_default_options() const;
140 
141  /**
142  * Returns the value of an option. Throws an error if the option does not
143  * exist.
144  *
145  * \param[in] name Name of the option to get.
146  */
147  const flexible_type& get_option_value(const std::string& name) const;
148 
149  /**
150  * Get model.
151  *
152  * \returns Model map.
153  */
154  const std::map<std::string, variant_type>& get_state() const;
155 
156  /**
157  * Is this model trained.
158  *
159  * \returns True if already trained.
160  */
161  bool is_trained() const;
162 
163  /**
164  * Set one of the options in the algorithm.
165  *
166  * The value are checked with the requirements given by the option
167  * instance.
168  *
169  * \param[in] name Name of the option.
170  * \param[in] value Value for the option.
171  */
172  void set_options(const std::map<std::string, flexible_type>& _options);
173 
174  /**
175  * Append the key value store of the model.
176  *
177  * \param[in] dict Options (Key-Value pairs) to set
178  */
179  void add_or_update_state(const std::map<std::string, variant_type>& dict);
180 
181  /** Returns the option information struct for each of the set
182  * parameters.
183  */
184  const std::vector<option_handling::option_info>& get_option_info() const;
185 
186  // Code to perform the registration for the rest of the tools.
188 
190 
194  "field");
200 
202 
203  protected:
204 
205  option_manager options; /* Option manager */
206  std::map<std::string, variant_type> state; /**< All things python */
207 
208 
209 };
210 
211 
212 namespace ml_model_sdk {
213 
214 /**
215  * Obtains the registration for the toolkit.
216  */
217 std::vector<toolkit_function_specification> get_toolkit_function_registration();
218 
219 /**
220  * Call the default options using a registered model.
221  *
222  * \param[in] name Name of the model registered in the class.
223  */
224 std::map<std::string, variant_type> _toolkits_get_default_options(
225  std::string model_name);
226 
227 } // namespace ml_model_sdk
228 
229 } // namespace turi
230 
231 #endif
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
#define BEGIN_BASE_CLASS_MEMBER_REGISTRATION()
std::map< std::string, variant_type > state
Definition: ml_model.hpp:206
virtual void init_options(const std::map< std::string, flexible_type > &_options)
Definition: ml_model.hpp:80
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
Definition: variant.hpp:24
#define END_CLASS_MEMBER_REGISTRATION
const variant_type & get_value_from_state(std::string key)
std::vector< std::string > list_fields()
#define REGISTER_NAMED_CLASS_MEMBER_FUNCTION(name, function,...)
const flexible_type & get_option_value(const std::string &name) const
const std::map< std::string, variant_type > & get_state() const
std::map< std::string, flexible_type > get_default_options() const
void set_options(const std::map< std::string, flexible_type > &_options)
bool is_trained() const