Turi Create  4.0
model_base.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_UNITY_MODEL_BASE_HPP
7 #define TURI_UNITY_MODEL_BASE_HPP
8 
9 #include <functional>
10 #include <map>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 
15 #include <core/export.hpp>
16 
17 #include <core/system/cppipc/ipc_object_base.hpp>
18 #include <core/system/cppipc/magic_macros.hpp>
19 #include <core/storage/serialization/serialization_includes.hpp>
20 #include <model_server/lib/toolkit_class_macros.hpp>
21 #include <model_server/lib/variant.hpp>
22 
23 namespace turi {
24 
25 class model_proxy;
26 
27 /**
28  * The base class from which all new toolkit classes must inherit.
29  *
30  * This class defines a generic object interface, listing properties and
31  * callable methods, so that instances can be naturally wrapped and exposed to
32  * other languages, such as Python.
33  *
34  * Subclasses should use the macros defined in toolkit_class_macros.hpp to
35  * declare the desired properties and methods, and to define their
36  * implementations. See that file for details and examples.
37  *
38  * Subclasses that wish to support saving and loading should also override the
39  * save_impl, load_version, and get_version functions below.
40  */
41 // TODO: Clean up the relationship between the save/load interface defined here
42 // and re-declared in ml_model_base.
43 // TODO: Remove the inheritance from ipc_object_base once the cppipc code has
44 // been disentangled and removed.
45 class EXPORT model_base: public cppipc::ipc_object_base {
46  public:
47  // TODO: Remove this type alias once this class stops inheriting from
48  // ipc_object_base.
50 
51  virtual ~model_base();
52 
53  // These public member functions define the communication between model_base
54  // instances and the unity runtime. Subclasses define the behavior of their
55  // instances using the protected interface below.
56 
57  /**
58  * Returns the name of the toolkit class, as exposed to client code. For
59  * example, the Python proxy for this instance will have a type with this
60  * name.
61  *
62  * Note: this function is typically overridden using the
63  * BEGIN_CLASS_MEMBER_REGISTRATION macro.
64  */
65  virtual const char* name() = 0;
66 
67  /**
68  * Returns a unique identifier for the toolkit class. It can be *any* unique
69  * ID. The UID is only used at runtime (to determine the concrete type of an
70  * arbitrary model_base instance) and is never stored.
71  *
72  * Note: this function is typically overridden using the
73  * BEGIN_CLASS_MEMBER_REGISTRATION macro.
74  */
75  virtual const std::string& uid() = 0;
76 
77  void save(oarchive& oarc) const {
78  oarc << get_version();
79  save_impl(oarc);
80  }
81 
82  /**
83  * Serializes the toolkit class.
84  * Must save the class to the file format
85  * version matching that of get_version().
86  */
87  virtual void save_impl(oarchive& oarc) const {
88  // A subclass needs to override these methods if it has any data that needs to be serialized.
89  // Otherwise this is a valid serialization of an empty model.
90  };
91 
92  /**
93  * Loads a toolkit class previously saved at a particular version number.
94  * Should raise an exception on failure.
95  */
96  virtual void load_version(iarchive& iarc, size_t version) {
97  // A subclass needs to override these methods if it has any data that needs to be serialized.
98  // Otherwise this is a valid serialization of an empty model.
99  } ;
100 
101 
102 
103  void load(iarchive& iarc) {
104  size_t version = 0;
105  iarc >> version;
106  load_version(iarc, version);
107  }
108 
109 
110 
111  /**
112  * Save a toolkit class to disk.
113  *
114  * \param url The destination url to store the class.
115  * \param sidedata Any additional side information
116  */
117  void save_to_url(const std::string& url, const variant_map_type& side_data = {});
118 
119 
120  /**
121  * Save a toolkit class to a data stream.
122  */
123  void save_model_to_data(std::ostream& out);
124 
125 
126  /**
127  * Returns the current version of the toolkit class for this instance, for
128  * serialization purposes.
129  */
130  virtual size_t get_version() const { return 0; }
131 
132  /**
133  * Lists all the registered functions.
134  * Returns a map of function name to array of argument names for the function.
135  */
136  const std::map<std::string, std::vector<std::string> >& list_functions();
137 
138  /**
139  * Lists all the get-table properties of the class.
140  */
141  const std::vector<std::string>& list_get_properties();
142 
143  /**
144  * Lists all the set-table properties of the class.
145  */
146  const std::vector<std::string>& list_set_properties();
147 
148  /**
149  * Calls a user defined function.
150  */
151  variant_type call_function(const std::string& function, variant_map_type argument);
152 
153  /**
154  * Reads a property.
155  */
156  variant_type get_property(const std::string& property);
157 
158  /**
159  * Sets a property. The new value of the property should appear in the
160  * argument map under the key "value".
161  */
162  variant_type set_property(const std::string& property, variant_map_type argument);
163 
164  /**
165  * Returns the toolkit documentation for a function or property.
166  */
167  const std::string& get_docstring(const std::string& symbol);
168 
169  /** Declare the base registration function. This class has to be handled
170  * specially; the macros don't work here due to the override declarations.
171  *
172  */
173  virtual void perform_registration();
174 
175  // TODO: Remove this vestigial macro invocation once the dependency on cppipc
176  // has been removed.
178  REGISTRATION_END
179 
180  protected:
181  using impl_fn = std::function<variant_type(model_base*, variant_map_type)>;
182 
183  // The macros defined in toolkit_class_macros.h use these functions to
184  // conveniently define this instance's collection of client-level methods
185  // and properties.
186 
187 
188  // Used to ensure that perform_registration is called once for each
189  // instance.
190  bool is_registered() const { return m_registered; }
191  void set_registered() { m_registered = true; }
192 
193  /**
194  * Adds a function with the specified name, and argument list.
195  */
196  void register_function(std::string fnname,
197  const std::vector<std::string>& arguments, impl_fn fn);
198 
199  /**
200  * Registers default argument values
201  */
202  void register_defaults(const std::string& fnname,
203  const variant_map_type& arguments);
204 
205  /**
206  * Adds a property setter with the specified name.
207  */
208  void register_setter(const std::string& propname, impl_fn setfn);
209 
210  /**
211  * Adds a property getter with the specified name.
212  */
213  void register_getter(const std::string& propname, impl_fn getfn);
214 
215  /**
216  * Adds a docstring for the specified function or property name.
217  */
218  void register_docstring(
219  const std::pair<std::string, std::string>& fnname_docstring);
220 
221  private:
222  // whether perform registration has been called
223  bool m_registered = false;
224  // a description of all the function arguments. This is returned by
225  // list_functions().
226  std::map<std::string, std::vector<std::string>> m_function_args;
227 
228  // default arguments, if any
229  std::map<std::string, variant_map_type> m_function_default_args;
230  // The implementation of each function
231  std::map<std::string, impl_fn> m_function_list;
232  // The implementation of each setter function
233  std::map<std::string, impl_fn> m_set_property_list;
234  mutable std::vector<std::string> m_set_property_cache;
235 
236  // The implementation of each getter function
237  std::map<std::string, impl_fn> m_get_property_list;
238  mutable std::vector<std::string> m_get_property_cache;
239 
240  // The docstring for each symbol
241  std::map<std::string, std::string> m_docstring;
242 
243  // Internal helper functions
244  inline void _check_registration();
245 
246  template <typename T>
247  GL_COLD_NOINLINE_ERROR void _raise_not_found_error(
248  const std::string& name, const std::map<std::string, T>& m);
249 
250  std::string _make_method_name(const std::string& function);
251 
252 
253 };
254 
255 // TODO: Remove this proxy subclass once the dependency on cppipc has been
256 // removed.
257 #ifndef DISABLE_TURI_CPPIPC_PROXY_GENERATION
258 /**
259  * Explicitly implemented proxy object.
260  *
261  */
262 class model_proxy : public model_base {
263  public:
265 
266  inline model_proxy(cppipc::comm_client& comm,
267  bool auto_create = true,
268  size_t object_id = (size_t)(-1)):
269  proxy(comm, auto_create, object_id){ }
270 
271  inline void save(turi::oarchive& oarc) const {
272  oarc << proxy.get_object_id();
273  }
274 
275  inline size_t __get_object_id() const {
276  return proxy.get_object_id();
277  }
278 
279  inline void load(turi::iarchive& iarc) {
280  size_t objid; iarc >> objid;
281  proxy.set_object_id(objid);
282  }
283 
284  virtual size_t get_version() const {
285  std_log_and_throw(std::runtime_error,"Calling Unreachable Function");
286  }
287 
288  const std::string& uid() {
289  std_log_and_throw(std::runtime_error, "Calling Unreachable Function");
290  }
292  std_log_and_throw(std::runtime_error,"Calling Unreachable Function");
293  }
294 
295  /**
296  * Serializes the model. Must save the model to the file format version
297  * matching that of get_version()
298  */
299  virtual void save_impl(oarchive& oarc) const {
300  std_log_and_throw(std::runtime_error, "Calling Unreachable Function");
301  }
302 
303  /**
304  * Loads a model previously saved at a particular version number.
305  * Should raise an exception on failure.
306  */
307  void load_version(iarchive& iarc, size_t version) {
308  std_log_and_throw(std::runtime_error, "Calling Unreachable Function");
309  }
310 
311  BOOST_PP_SEQ_FOR_EACH(__GENERATE_PROXY_CALLS__, model_base,
312  __ADD_PARENS__(
313  (const char*, name, )
314  ))
315 };
316 #endif
317 } // namespace turi
318 
319 #endif // TURI_UNITY_MODEL_BASE_HPP
virtual size_t get_version() const
Definition: model_base.hpp:130
void load_version(iarchive &iarc, size_t version)
Definition: model_base.hpp:307
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
virtual void save_impl(oarchive &oarc) const
Definition: model_base.hpp:299
const std::string & uid()
Definition: model_base.hpp:288
void perform_registration()
Definition: model_base.hpp:291
boost::make_recursive_variant< flexible_type, std::shared_ptr< unity_sgraph_base >, dataframe_t, std::shared_ptr< model_base >, std::shared_ptr< unity_sframe_base >, std::shared_ptr< unity_sarray_base >, std::map< std::string, boost::recursive_variant_ >, std::vector< boost::recursive_variant_ >, boost::recursive_wrapper< function_closure_info > >::type variant_type
Definition: variant.hpp:24
size_t get_object_id() const
#define REGISTRATION_BEGIN(name)
void set_object_id(size_t object_id)
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80
virtual void load_version(iarchive &iarc, size_t version)
Definition: model_base.hpp:96
#define GL_COLD_NOINLINE_ERROR
virtual size_t get_version() const
Definition: model_base.hpp:284
virtual void save_impl(oarchive &oarc) const
Definition: model_base.hpp:87