Turi Create  4.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
mlmodel_image_feature_extractor.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef MLMODEL_IMAGE_FEATURE_EXTRACTOR_HPP
7 #define MLMODEL_IMAGE_FEATURE_EXTRACTOR_HPP
8 
9 #include <toolkits/image_deep_feature_extractor/image_feature_extractor.hpp>
10 
11 #include <memory>
12 #include <string>
13 
14 namespace turi {
15 namespace image_deep_feature_extractor {
16 
17 // Implementation of image_feature_extractor that instantiates an actual MLModel
18 // to process each image.
19 class mlmodel_image_feature_extractor: public image_feature_extractor {
20 public:
21  // Constructs an instance produced by downloading a known MLModel and using
22  // the appropriate layer for the feature values. If necessary, the model will
23  // be downloaded to `download_path`. Supported `model_name` values are:
24  // "resnet-50", "squeezenet_v1.1"
25  mlmodel_image_feature_extractor(const std::string& model_name,
26  const std::string& download_path);
27 
28  // image_feature_extractor interface
29  const CoreML::Specification::Model& coreml_spec() const override;
30  gl_sarray extract_features(gl_sarray images, bool verbose, size_t batch_size) const override;
31 
32 private:
33  // Use PIMPL pattern to hide Objective C from this C++ header.
34  struct impl;
35 
36  std::unique_ptr<impl> m_impl;
37 };
38 
39 } // image_deep_feature_extractor
40 } // turi
41 
42 #endif // MLMODEL_IMAGE_FEATURE_EXTRACTOR_HPP