7 #ifndef UNITY_TOOLKITS_NEURAL_NET_COMPUTE_CONTEXT_HPP_ 8 #define UNITY_TOOLKITS_NEURAL_NET_COMPUTE_CONTEXT_HPP_ 12 #include <core/export.hpp> 13 #include <core/system/exceptions/TuriException.hpp> 14 #include <ml/neural_net/image_augmentation.hpp> 15 #include <ml/neural_net/model_backend.hpp> 18 namespace neural_net {
58 using factory = std::function<std::unique_ptr<compute_context>()>;
73 int priority()
const {
return priority_; }
74 std::unique_ptr<compute_context> create_context()
const {
78 std::unique_ptr<compute_context> create_tensorflow_context()
const {
79 return tf_factory_fn_ ? tf_factory_fn_() :
nullptr;
82 std::unique_ptr<compute_context> create_mlc_context()
const 84 return mlc_factory_fn_ ? mlc_factory_fn_() :
nullptr;
100 static std::unique_ptr<compute_context> create();
102 static std::unique_ptr<compute_context> create_tf();
104 static std::unique_ptr<compute_context> create_mlc();
111 virtual void print_training_device_info()
const = 0;
119 virtual size_t memory_budget()
const = 0;
130 int c_out,
int h_out,
int w_out,
131 const float_array_map& config,
132 const float_array_map&
weights)
134 throw TuriException(TuriErrorCode::NotImplemented);
147 throw TuriException(TuriErrorCode::NotImplemented);
159 const float_array_map&
weights)
161 throw TuriException(TuriErrorCode::NotImplemented);
176 throw TuriException(TuriErrorCode::NotImplemented);
185 throw TuriException(TuriErrorCode::NotImplemented);
192 int n,
int c_in,
int c_out,
const std::vector<size_t>& layer_sizes,
193 const turi::neural_net::float_array_map& config)
195 throw TuriException(TuriErrorCode::NotImplemented);
202 #endif // UNITY_TOOLKITS_NEURAL_NET_COMPUTE_CONTEXT_HPP_
virtual std::unique_ptr< model_backend > create_activity_classifier(const ac_parameters &ac_params)
int num_predictions_per_chunk
virtual std::unique_ptr< image_augmenter > create_image_augmenter(const image_augmenter::options &opts)
std::function< std::unique_ptr< compute_context >()> factory
virtual std::unique_ptr< model_backend > create_object_detector(int n, int c_in, int h_in, int w_in, int c_out, int h_out, int w_out, const float_array_map &config, const float_array_map &weights)
virtual std::unique_ptr< model_backend > create_style_transfer(const float_array_map &config, const float_array_map &weights)
virtual std::unique_ptr< turi::neural_net::model_backend > create_multilayer_perceptron_classifier(int n, int c_in, int c_out, const std::vector< size_t > &layer_sizes, const turi::neural_net::float_array_map &config)
virtual std::unique_ptr< model_backend > create_drawing_classifier(const float_array_map &weights, size_t batch_size, size_t num_classes)