Turi Create  4.0
one_shot_object_detector.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef TURI_ONE_SHOT_OBJECT_DETECTOR_H_
9 #define TURI_ONE_SHOT_OBJECT_DETECTOR_H_
10 
#include <map>
#include <memory>
#include <string>

#include <core/data/sframe/gl_sframe.hpp>
#include <model_server/lib/extensions/ml_model.hpp>
#include <toolkits/coreml_export/mlmodel_wrapper.hpp>
#include <toolkits/object_detection/object_detector.hpp>
16 
17 namespace turi {
18 namespace one_shot_object_detection {
19 
20 class EXPORT one_shot_object_detector : public ml_model_base {
21  public:
22  // Constructor
23  one_shot_object_detector();
24 
25  // Interface exposed via Unity server
26 
27  // TODO: augment -> train
28  gl_sframe augment(const gl_sframe &data, const std::string &image_column_name,
29  const std::string &target_column_name,
30  const gl_sarray &backgrounds,
31  std::map<std::string, flexible_type> &options);
32 
33  BEGIN_CLASS_MEMBER_REGISTRATION("one_shot_object_detector")
34 
35  IMPORT_BASE_CLASS_REGISTRATION(ml_model_base);
36 
37  REGISTER_CLASS_MEMBER_FUNCTION(one_shot_object_detector::augment, "data",
38  "image_column_name", "target_column_name",
39  "backgrounds", "options");
40 
42 
43  private:
44  // Obsolete until we actually use the object_detector::train.
45  // Leaving it here anyway unless we decide we should remove it.
46  std::unique_ptr<turi::object_detection::object_detector> model_;
47 };
48 
49 } // namespace one_shot_object_detection
50 } // namespace turi
51 
52 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
STL namespace.
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION