8 #ifndef MPS_OD_BACKEND_HPP_ 9 #define MPS_OD_BACKEND_HPP_ 11 #include <ml/neural_net/mps_graph_cnnmodule.h> 12 #include <ml/neural_net/model_backend.hpp> 15 namespace neural_net {
25 std::shared_ptr<mps_command_queue> command_queue;
33 float_array_map config;
34 float_array_map weights;
41 float_array_map
train(
const float_array_map& inputs)
override;
44 float_array_map
predict(
const float_array_map& inputs)
const override;
49 void ensure_training_module();
50 void ensure_prediction_module()
const;
54 std::unique_ptr<mps_graph_cnn_module> training_module_;
57 mutable std::unique_ptr<mps_graph_cnn_module> prediction_module_;
63 #endif // MPS_OD_BACKEND_HPP_ float_array_map export_weights() const override
float_array_map predict(const float_array_map &inputs) const override
float_array_map train(const float_array_map &inputs) override
void set_learning_rate(float lr) override