Turi Create  4.0
mlc_compute_context.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 #pragma once
8 
9 #import <MLCompute/MLCompute.h>
10 
11 #include <ml/neural_net/compute_context.hpp>
12 
13 namespace turi {
14 namespace neural_net {
15 
16 /**
17  * A compute_context implementation backed by MLFoundation (and Core Image for
18  * data augmentation).
19  */
20 class API_AVAILABLE(macos(10.16)) mlc_compute_context : public compute_context {
21  public:
22  mlc_compute_context(MLCDevice* device);
23  ~mlc_compute_context();
24 
25  void print_training_device_info() const override;
26  size_t memory_budget() const override;
27 
28  std::unique_ptr<model_backend> create_object_detector(int n, int c_in, int h_in, int w_in,
29  int c_out, int h_out, int w_out,
30  const float_array_map& config,
31  const float_array_map& weights) override;
32 
33  std::unique_ptr<model_backend> create_activity_classifier(
34  const ac_parameters& ac_params) override;
35 
36  std::unique_ptr<model_backend> create_drawing_classifier(const float_array_map& weights,
37  size_t batch_size,
38  size_t num_classes) override;
39 
40  std::unique_ptr<image_augmenter> create_image_augmenter(
41  const image_augmenter::options& opts) override;
42 
43  std::unique_ptr<model_backend> create_style_transfer(const float_array_map& config,
44  const float_array_map& weights) override;
45 
46  std::unique_ptr<model_backend> create_multilayer_perceptron_classifier(
47  int n, int c_in, int c_out, const std::vector<size_t>& layer_sizes,
48  const turi::neural_net::float_array_map& config) override;
49 
50  protected:
51  MLCDevice* GetDevice() const;
52 
53  private:
54  MLCDevice* device_ = nil;
55 };
56 
57 } // namespace neural_net
58 } // namespace turi