Turi Create 4.0
mlc_dc_backend.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #pragma once
9 
10 #import <MLCompute/MLCompute.h>
11 
12 #include <ml/neural_net/mlc_layer_weights.hpp>
13 #include <ml/neural_net/model_backend.hpp>
14 
15 namespace turi {
16 namespace neural_net {
17 
/**
 * model_backend implementation built on Apple's MLCompute framework, used (as
 * the class name indicates) by the drawing classifier.
 *
 * The backend owns one MLCTrainingGraph and one MLCInferenceGraph, plus the
 * MLCTensor handles through which inputs are bound to those graphs. Only
 * available where MLCompute is, i.e. macOS 10.16 and later (see API_AVAILABLE).
 */
class API_AVAILABLE(macos(10.16)) mlc_drawing_classifier_backend : public model_backend {
 public:
  /**
   * Creates the backend for the given device and initial weights.
   *
   * @param device      MLCompute device the backend is created for.
   * @param weights     Initial model parameters. NOTE(review): presumably keyed
   *                    by layer/parameter name -- confirm against the .mm file.
   * @param batch_size  Number of examples per batch.
   * @param num_classes Number of output classes.
   */
  mlc_drawing_classifier_backend(MLCDevice *device, const float_array_map &weights,
                                 size_t batch_size, size_t num_classes);

  // model_backend interface. See model_backend for the contract of each method.
  float_array_map export_weights() const override;
  void set_learning_rate(float lr) override;
  float_array_map train(const float_array_map &inputs) override;
  float_array_map predict(const float_array_map &inputs) const override;

 private:
  // MLCompute graphs used by train() and predict(). nil until assigned;
  // presumably built in the constructor.
  MLCTrainingGraph *training_graph_ = nil;
  MLCInferenceGraph *inference_graph_ = nil;

  // Tensor handles for the graph inputs (nil until assigned).
  // NOTE(review): weights_ is an input tensor -- presumably per-example
  // weights -- and is distinct from the model parameters held in
  // layer_weights_ below; confirm against the .mm file.
  MLCTensor *input_ = nil;
  MLCTensor *weights_ = nil;
  MLCTensor *labels_ = nil;

  // Model parameters (see mlc_layer_weights).
  mlc_layer_weights layer_weights_;

  // Presumably set from the constructor's num_classes argument.
  size_t num_classes_;
};
39 
40 } // namespace neural_net
41 } // namespace turi