Turi Create  4.0
image_classification.hpp
1 #ifndef TURI_ANNOTATIONS_IMAGE_CLASSIFICATION_HPP
2 #define TURI_ANNOTATIONS_IMAGE_CLASSIFICATION_HPP
3 
4 #include <core/export.hpp>
5 
6 #include <chrono>
7 #include <future>
8 #include <iostream>
9 #include <thread>
10 
11 #include <visualization/annotation/annotation_base.hpp>
12 #include <model_server/lib/variant.hpp>
13 #include <toolkits/nearest_neighbors/unity_nearest_neighbors.hpp>
14 
15 #include "build/format/cpp/annotate.pb.h"
16 #include "build/format/cpp/data.pb.h"
17 
18 namespace turi {
19 namespace annotate {
20 
21 class ImageClassification : public AnnotationBase {
22 public:
23  ImageClassification() : AnnotationBase(){};
24  ImageClassification(const std::shared_ptr<unity_sframe> &data,
25  const std::vector<std::string> &data_columns,
26  const std::string &annotation_column);
27 
28  ~ImageClassification() {};
29 
30  annotate_spec::MetaData metaData() override;
31 
32  annotate_spec::Data getItems(size_t start, size_t end) override;
33 
34  annotate_spec::Annotations getAnnotations(size_t start, size_t end) override;
35 
36  bool setAnnotations(const annotate_spec::Annotations &annotations) override;
37 
38  void cast_annotations() override;
39 
40  void background_work() override;
41 
42  void addAnnotationColumn() override;
43 
44  void checkDataSet() override;
45 
46  annotate_spec::Similarity get_similar_items(size_t index,
47  size_t k = 7) override;
48 
49  BEGIN_CLASS_MEMBER_REGISTRATION("ImageClassification");
50  IMPORT_BASE_CLASS_REGISTRATION(AnnotationBase);
52 
53 private:
54 #ifdef __APPLE__
55  size_t m_feature_batch_size = 16;
56  image_deep_feature_extractor::image_deep_feature_extractor_toolkit m_extractor;
57 #endif
58  gl_sarray m_feature_sarray;
59  variant_map_type m_nn_model;
60  std::shared_ptr<std::thread> featurizer_thread;
61  gl_sarray m_image_feature_extraction_sarray;
62  std::shared_ptr<gl_sarray_writer> m_writer;
63 
64 
65  void _addAnnotationToSFrame(size_t index, std::string label);
66  void _addAnnotationToSFrame(size_t index, int label);
67  void _create_nearest_neighbors_model();
68  void _createFeaturesExtractor();
69  bool _stepFeaturesExtractor();
70 
71  std::shared_ptr<unity_sarray> _filterDataSFrame(size_t &start, size_t &end);
72  std::shared_ptr<unity_sarray> _filterAnnotationSFrame(size_t &start,
73  size_t &end);
74 };
75 
76 std::shared_ptr<ImageClassification> create_image_classification_annotation(
77  const std::shared_ptr<unity_sframe> &data,
78  const std::vector<std::string> &data_columns,
79  const std::string &annotation_column);
80 
81 } // namespace annotate
82 } // namespace turi
83 #endif
// The three class-registration macros below are defined with empty
// replacement lists, so wherever these definitions are in effect every use
// of them expands to nothing.
// NOTE(review): they sit after the include guard's #endif and look like
// documentation-generator stubs (e.g. Doxygen PREDEFINED) rather than the
// real toolkit macros; if compiled alongside the real definitions they would
// trigger macro-redefinition diagnostics — confirm their origin.
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION