10 #import <MLCompute/MLCompute.h> 12 #include <ml/neural_net/mlc_layer_weights.hpp> 13 #include <ml/neural_net/model_backend.hpp> 16 namespace neural_net {
18 class API_AVAILABLE(macos(10.16)) mlc_drawing_classifier_backend :
public model_backend {
public:
// Constructs the MLCompute-backed drawing-classifier backend.
//
// device       MLCompute device to use (presumably the device the training
//              and inference graphs are compiled for; confirm in the .mm).
// weights      Initial parameter values keyed by name; presumably the same
//              key scheme export_weights() produces. TODO confirm.
// batch_size   Number of examples per batch; assumed to size the graph
//              tensors. Verify against the implementation.
// num_classes  Number of output classes.
//
// NOTE(review): an explicit `public:` is required here. Members declared in a
// `class` before any access specifier are private, which would make this
// constructor (and the model_backend overrides) inaccessible to callers.
mlc_drawing_classifier_backend(MLCDevice *device,
                               const float_array_map &weights,
                               size_t batch_size, size_t num_classes);
// model_backend override: returns the backend's parameters as a
// float_array_map (presumably the current learned weights; confirm).
// Declared const, so exporting does not modify the backend.
float_array_map export_weights() const override;
// model_backend override: sets the learning rate to `lr`.
// Presumably this takes effect for subsequent train() calls; confirm in the
// implementation.
void set_learning_rate(float lr) override;
// model_backend override: runs a training step on the named `inputs` and
// returns named outputs (assumed to be one batch in and e.g. the loss out;
// TODO confirm the expected keys against the implementation).
float_array_map train(const float_array_map &inputs) override;
// model_backend override: runs inference on the named `inputs` and returns
// named outputs. Declared const, so prediction does not modify the backend.
float_array_map predict(
    const turi::neural_net::float_array_map &inputs) const override;
private:
// MLCompute graphs; presumably training_graph_ serves train() and
// inference_graph_ serves predict(). Confirm in the .mm implementation.
MLCTrainingGraph *training_graph_ = nil;
MLCInferenceGraph *inference_graph_ = nil;

// Graph tensors. The names suggest the batch input, per-example weights and
// labels fed to the graphs. TODO confirm against the implementation.
MLCTensor *input_ = nil;
MLCTensor *weights_ = nil;
MLCTensor *labels_ = nil;

// Holder for the layers' parameter storage (type mlc_layer_weights).
mlc_layer_weights layer_weights_;