Turi Create 4.0
mlc_od_backend.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 #pragma once
8 
9 #include <string>
10 #include <unordered_map>
11 #include <vector>
12 
13 #import <MLCompute/MLCompute.h>
14 
15 #include <ml/neural_net/model_backend.hpp>
16 
17 #include <ml/neural_net/mlc_layer_weights.hpp>
18 
// These types are only forward-declared here: importing/including their full
// headers from this file causes header-inclusion problems. See the
// implementation file for the details.
namespace turi {
namespace object_detection {
class DarknetYOLOCheckpoint;
}  // namespace object_detection
}  // namespace turi
@class TCModelTrainerBackendGraphs;
27 
namespace turi {
namespace neural_net {

/**
 * Object-detection implementation of model_backend built on Apple's MLCompute
 * framework (it holds an MLCTrainingGraph and an MLCInferenceGraph).
 *
 * Only available on macOS 10.16 and later, per the API_AVAILABLE annotation.
 */
class API_AVAILABLE(macos(10.16)) mlc_object_detector_backend : public model_backend {
 public:
  /**
   * Builds the underlying MLCompute graphs directly from a Darknet-YOLO
   * checkpoint, without constructing a backend. Exposed for testing purposes.
   *
   * @param checkpoint The checkpoint to build graphs for.
   * @return The constructed graphs object. NOTE(review): ownership/lifetime of
   *     the returned object is defined by the implementation -- confirm.
   */
  static TCModelTrainerBackendGraphs *create_graphs(
      const turi::object_detection::DarknetYOLOCheckpoint &checkpoint);

  /**
   * Creates a backend that executes on `device`.
   *
   * TODO: No, really, replace this legacy interface with one that just accepts
   * a struct with explicit, named parameters for OD specifically.
   *
   * @param device  The MLCompute device to use.
   * @param n       Presumably the batch size -- TODO confirm.
   * @param c_in    Presumably the input channel count -- TODO confirm.
   * @param h_in    Presumably the input height -- TODO confirm.
   * @param w_in    Presumably the input width -- TODO confirm.
   * @param c_out   Presumably the output channel count -- TODO confirm.
   * @param h_out   Presumably the output height -- TODO confirm.
   * @param w_out   Presumably the output width -- TODO confirm.
   * @param config  Backend configuration values.
   * @param weights Initial model weights.
   */
  mlc_object_detector_backend(MLCDevice *device, size_t n, size_t c_in, size_t h_in, size_t w_in,
                              size_t c_out, size_t h_out, size_t w_out,
                              const float_array_map &config, const float_array_map &weights);

  // model_backend interface

  /// Exports the backend's weights. Const: does not modify the backend.
  float_array_map export_weights() const override;

  /// Sets the learning rate used for training.
  void set_learning_rate(float lr) override;

  /// Runs training on `inputs` and returns the resulting outputs.
  float_array_map train(const float_array_map &inputs) override;

  /// Runs inference on `inputs`. Const: does not modify the backend.
  float_array_map predict(const float_array_map &inputs) const override;

 private:
  /**
   * Builds the MLCompute graphs from the legacy shape/config/weights
   * description taken by the constructor.
   *
   * @param layer_weights Non-const out-parameter. NOTE(review): presumably
   *     populated with the weight storage backing the graphs -- confirm.
   */
  static TCModelTrainerBackendGraphs *create_graphs(size_t n, size_t c_in, size_t h_in,
                                                    size_t w_in, size_t c_out, size_t h_out,
                                                    size_t w_out, const float_array_map &config,
                                                    const float_array_map &weights,
                                                    mlc_layer_weights *layer_weights);

  // TODO: Replace these with an instance of TCModelTrainerBackendGraphs?
  MLCTrainingGraph *training_graph_ = nil;
  MLCInferenceGraph *inference_graph_ = nil;
  MLCTensor *input_ = nil;
  MLCTensor *labels_ = nil;

  mlc_layer_weights layer_weights_;
  std::vector<size_t> output_shape_;
};

}  // namespace neural_net
}  // namespace turi