Turi Create  4.0
mps_od_backend.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef MPS_OD_BACKEND_HPP_
9 #define MPS_OD_BACKEND_HPP_
10 
11 #include <ml/neural_net/mps_graph_cnnmodule.h>
12 #include <ml/neural_net/model_backend.hpp>
13 
14 namespace turi {
15 namespace neural_net {
16 
17 /**
18  * Model backend for object detection that uses separate mps_graph_cnn_module
19  * instances for training and for inference, since a single
20  * mps_graph_cnn_module doesn't currently support doing both.
21  */
22 class mps_od_backend : public model_backend {  // model_backend implementation holding two mps_graph_cnn_module instances.
23  public:
24  struct parameters {  // Settings bundle accepted (by value) by the constructor. No field has a default initializer.
25  std::shared_ptr<mps_command_queue> command_queue;  ///< Shared-ownership handle to an mps_command_queue.
26  int n;  ///< Presumably the batch size -- TODO confirm against the implementation.
27  int c_in;  ///< Presumably the number of input channels -- TODO confirm.
28  int h_in;  ///< Presumably the input height -- TODO confirm.
29  int w_in;  ///< Presumably the input width -- TODO confirm.
30  int c_out;  ///< Presumably the number of output channels -- TODO confirm.
31  int h_out;  ///< Presumably the output height -- TODO confirm.
32  int w_out;  ///< Presumably the output width -- TODO confirm.
33  float_array_map config;  ///< Presumably network/training configuration values; schema not visible here -- TODO confirm.
34  float_array_map weights;  ///< Presumably the initial model weights; schema not visible here -- TODO confirm.
35  };
36 
37  mps_od_backend(parameters params);  // Takes `params` by value. NOTE(review): not `explicit`, so a `parameters` converts implicitly.
38 
39  // Training
40  void set_learning_rate(float lr) override;  // Overrides the model_backend virtual; non-const.
41  float_array_map train(const float_array_map& inputs) override;  // Non-const: takes a map of inputs, returns a map of results.
42 
43  // Inference
44  float_array_map predict(const float_array_map& inputs) const override;  // const, unlike train().
45 
46  float_array_map export_weights() const override;  // const; returns a float_array_map by value.
47 
48  private:
49  void ensure_training_module();  // Presumably creates training_module_ on demand -- TODO confirm in the .cpp.
50  void ensure_prediction_module() const;  // const-qualified; can still assign prediction_module_ because that member is mutable.
51 
52  parameters params_;  // Presumably the constructor argument, retained for building the modules -- TODO confirm.
53 
54  std::unique_ptr<mps_graph_cnn_module> training_module_;  // Exclusively owned; null until assigned.
55 
56  // Cleared whenever the training module is updated.
57  mutable std::unique_ptr<mps_graph_cnn_module> prediction_module_;  // mutable so const members (e.g. predict) may set or reset it.
58 };
59 
60 } // namespace neural_net
61 } // namespace turi
62 
63 #endif // MPS_OD_BACKEND_HPP_
float_array_map export_weights() const override
float_array_map predict(const float_array_map &inputs) const override
float_array_map train(const float_array_map &inputs) override
void set_learning_rate(float lr) override