Turi Create  4.0
mps_compute_context.hpp
1 /* Copyright © 2018 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 
7 #ifndef UNITY_TOOLKITS_NEURAL_NET_MPS_COMPUTE_CONTEXT_HPP_
8 #define UNITY_TOOLKITS_NEURAL_NET_MPS_COMPUTE_CONTEXT_HPP_
9 
10 #include <memory>
11 
12 #include <ml/neural_net/compute_context.hpp>
13 #include <ml/neural_net/mps_command_queue.hpp>
14 
15 namespace turi {
16 namespace neural_net {
17 
18 std::unique_ptr<compute_context> create_mps_compute_context(); /**< Factory returning an owning pointer to a compute_context; presumably an mps_compute_context instance -- confirm in the implementation. */
19 
20 /**
21  * A compute_context implementation backed by Apple frameworks: Metal
22  * Performance Shaders for neural network computation and Core Image for data
23  * augmentation.
24  */
26  public:
27 
28  /**
29  * Constructs a context wrapping the given Metal command queue. Ownership of the queue is transferred in (std::unique_ptr passed by value).
30  */
31  mps_compute_context(std::unique_ptr<mps_command_queue> command_queue); // NOTE(review): single-argument constructor is not `explicit`; confirm that implicit conversion from the unique_ptr is intended.
32 
34 
35  void print_training_device_info() const override; /**< Overrides the compute_context method; judging by the name, reports the device used for training -- where the output goes is not visible from this header. */
36  size_t memory_budget() const override; /**< Overrides the compute_context method; returns a size_t budget, presumably in bytes -- TODO confirm units. */
37 
38  std::unique_ptr<model_backend> create_object_detector(
39  int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out,
40  const float_array_map& config, const float_array_map& weights) override; /**< Returns an owning pointer to an object-detector model_backend. n, c_in/h_in/w_in and c_out/h_out/w_out look like batch size and input/output channels/height/width -- TODO confirm. */
41 
42  std::unique_ptr<model_backend> create_activity_classifier(
43  const ac_parameters& ac_params) override; /**< Returns an owning pointer to an activity-classifier model_backend built from ac_params. */
44 
45  static bool has_style_transfer(); /**< Static query, callable without an instance; presumably reports whether create_style_transfer is supported in the current environment -- TODO confirm. */
46 
47  std::unique_ptr<model_backend> create_style_transfer(
48  const float_array_map& config, const float_array_map& weights) override; /**< Returns an owning pointer to a style-transfer model_backend built from the given config and weights maps. */
49 
50  std::unique_ptr<model_backend> create_drawing_classifier(
51  /* TODO: const float_array_map& config if needed */
52  const float_array_map& weights, size_t batch_size,
53  size_t num_classes) override; /**< Returns an owning pointer to a drawing-classifier model_backend; unlike the other factories it takes no config map (see the TODO above). */
54 
55  std::unique_ptr<image_augmenter> create_image_augmenter(
56  const image_augmenter::options &opts) override; /**< Returns an owning pointer to an image_augmenter configured by opts. */
57 
58  /**
59  * Alternate implementation of create_image_augmenter supporting injection of
60  * the random number generator, for test purposes. Static, so it can be called without a context instance.
61  */
62  static std::unique_ptr<image_augmenter> create_image_augmenter_for_testing(
63  const image_augmenter::options& opts,
64  std::function<float(float lower, float upper)> rng); // rng takes (lower, upper) and returns a float, presumably a sample within that range -- TODO confirm. NOTE(review): std::function is used here but <functional> is not included directly by this header.
65 
66  private:
67  std::shared_ptr<mps_command_queue> command_queue_; /**< Metal command queue held by this context. NOTE(review): stored as shared_ptr although the constructor receives a unique_ptr -- presumably so it can be shared with created backends; confirm. */
68 };
69 
70 } // namespace neural_net
71 } // namespace turi
72 
73 #endif // UNITY_TOOLKITS_NEURAL_NET_MPS_COMPUTE_CONTEXT_HPP_
std::unique_ptr< model_backend > create_drawing_classifier(const float_array_map &weights, size_t batch_size, size_t num_classes) override
std::unique_ptr< model_backend > create_object_detector(int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map &config, const float_array_map &weights) override
std::unique_ptr< model_backend > create_style_transfer(const float_array_map &config, const float_array_map &weights) override
std::unique_ptr< model_backend > create_activity_classifier(const ac_parameters &ac_params) override
void print_training_device_info() const override
static std::unique_ptr< image_augmenter > create_image_augmenter_for_testing(const image_augmenter::options &opts, std::function< float(float lower, float upper)> rng)
std::unique_ptr< image_augmenter > create_image_augmenter(const image_augmenter::options &opts) override
size_t memory_budget() const override
mps_compute_context(std::unique_ptr< mps_command_queue > command_queue)