10 #include <unordered_map> 13 #import <MLCompute/MLCompute.h> 15 #include <ml/neural_net/model_backend.hpp> 17 #include <ml/neural_net/mlc_layer_weights.hpp> 21 @class TCModelTrainerBackendGraphs;
23 namespace object_detection {
24 class DarknetYOLOCheckpoint;
29 namespace neural_net {
31 class API_AVAILABLE(macos(10.16)) mlc_object_detector_backend :
public model_backend {
35 static TCModelTrainerBackendGraphs *create_graphs(
40 mlc_object_detector_backend(MLCDevice *device,
size_t n,
size_t c_in,
size_t h_in,
size_t w_in,
41 size_t c_out,
size_t h_out,
size_t w_out,
42 const float_array_map &config,
const float_array_map &weights);
45 float_array_map export_weights()
const override;
46 void set_learning_rate(
float lr)
override;
47 float_array_map train(
const float_array_map &inputs)
override;
48 float_array_map predict(
const float_array_map &inputs)
const override;
51 static TCModelTrainerBackendGraphs *create_graphs(
size_t n,
size_t c_in,
size_t h_in,
size_t w_in,
52 size_t c_out,
size_t h_out,
size_t w_out,
53 const float_array_map &config,
54 const float_array_map &weights,
55 mlc_layer_weights *layer_weights);
58 MLCTrainingGraph *training_graph_ = nil;
59 MLCInferenceGraph *inference_graph_ = nil;
60 MLCTensor *input_ = nil;
61 MLCTensor *labels_ = nil;
63 mlc_layer_weights layer_weights_;
64 std::vector<size_t> output_shape_;