9 #import <MLCompute/MLCompute.h> 11 #include <ml/neural_net/compute_context.hpp> 14 namespace neural_net {
20 class API_AVAILABLE(macos(10.16)) mlc_compute_context :
public compute_context {
/// Constructs a compute context that targets the given MLCompute device.
/// Marked `explicit` so an MLCDevice* can never silently convert into a
/// temporary mlc_compute_context (for example when binding to a
/// `const mlc_compute_context&` parameter).
/// NOTE(review): presumably `device` is stored in device_ -- confirm in the
/// implementation file.
explicit mlc_compute_context(MLCDevice* device);

/// Destructor; declared here, its definition is not part of this header.
~mlc_compute_context();
/// Override of compute_context::print_training_device_info().
/// NOTE(review): presumably describes the device used for training; the
/// output destination is chosen by the implementation.
void print_training_device_info() const override;
/// Override of compute_context::memory_budget().
/// NOTE(review): the unit of the returned size_t is presumably bytes -- confirm.
size_t memory_budget() const override;
/// Factory override: returns a caller-owned model_backend for the object
/// detector.
/// NOTE(review): n/c/h/w presumably denote batch size, channels, height and
/// width; the `_in` group describes the input and the `_out` group the output
/// -- confirm against the implementation.
/// @param config  Configuration map forwarded to the backend.
/// @param weights Weight map forwarded to the backend.
std::unique_ptr<model_backend> create_object_detector(
    int n, int c_in, int h_in, int w_in,
    int c_out, int h_out, int w_out,
    const float_array_map& config,
    const float_array_map& weights) override;
/// Factory override: returns a caller-owned model_backend for the activity
/// classifier, built from `ac_params`.
std::unique_ptr<model_backend> create_activity_classifier(
    const ac_parameters& ac_params) override;
/// Factory override: returns a caller-owned model_backend for the drawing
/// classifier.
/// @param weights     Weight map forwarded to the backend.
/// @param num_classes Presumably the number of output classes -- confirm.
std::unique_ptr<model_backend> create_drawing_classifier(
    const float_array_map& weights, size_t num_classes) override;
/// Factory override: returns a caller-owned image_augmenter configured by
/// `opts`.
std::unique_ptr<image_augmenter> create_image_augmenter(
    const image_augmenter::options& opts) override;
/// Factory override: returns a caller-owned model_backend for style transfer.
/// @param config  Configuration map forwarded to the backend.
/// @param weights Weight map forwarded to the backend.
std::unique_ptr<model_backend> create_style_transfer(
    const float_array_map& config, const float_array_map& weights) override;
/// Factory override: returns a caller-owned model_backend for the multilayer
/// perceptron classifier.
/// NOTE(review): n, c_in and c_out presumably are the batch size, input width
/// and output width; `layer_sizes` presumably lists hidden-layer widths --
/// confirm against the implementation.
/// @param config Configuration map forwarded to the backend.
std::unique_ptr<model_backend> create_multilayer_perceptron_classifier(
    int n, int c_in, int c_out, const std::vector<size_t>& layer_sizes,
    const float_array_map& config) override;
/// Accessor for the MLCompute device behind this context.
/// NOTE(review): presumably returns device_. The CamelCase name differs from
/// the snake_case used by the other members; renaming would break callers.
MLCDevice* GetDevice() const;
/// MLCompute device for this context; default-initialized to nil.
/// NOTE(review): presumably assigned by the constructor -- confirm.
MLCDevice* device_{nil};