Turi Create 4.0
mlmodel_wrapper.hpp
1 #ifndef __TC_ML_MODEL_WRAPPER_HPP_
2 #define __TC_ML_MODEL_WRAPPER_HPP_
3 
4 #include <memory>
5 
6 #include <model_server/lib/extensions/model_base.hpp>
7 #include <model_server/lib/toolkit_class_macros.hpp>
8 
9 // Forward declare CoreML::Model in lieu of including problematic protocol
10 // buffer headers.
11 namespace CoreML {
12 class Model;
13 }
14 
15 namespace turi {
16 namespace coreml {
17 
18 class MLModelWrapper : public model_base {
19  public:
20  MLModelWrapper() {}
21  MLModelWrapper(std::shared_ptr<CoreML::Model> model) : m_model(std::move(model)) {}
22 
23  void save(const std::string& path_to_save_file);
24 
25  void add_metadata(const std::map<std::string, flexible_type>& context_metadata);
26 
27  std::shared_ptr<CoreML::Model> coreml_model() const { return m_model; }
28 
29  private:
30  std::shared_ptr<CoreML::Model> m_model;
31 
32  BEGIN_CLASS_MEMBER_REGISTRATION("_MLModelWrapper")
33  REGISTER_CLASS_MEMBER_FUNCTION(MLModelWrapper::save, "path")
34  REGISTER_CLASS_MEMBER_FUNCTION(MLModelWrapper::add_metadata,
35  "context_metadata")
37 };
38 } // namespace coreml
39 } // namespace turi
40 
41 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
void add_metadata(CoreML::Specification::Model &model_spec, const std::map< std::string, flexible_type > &context)
STL namespace.
#define END_CLASS_MEMBER_REGISTRATION