Turi Create
4.0
#include <ml/neural_net/mps_od_backend.hpp>
Public Member Functions
void set_learning_rate(float lr) override
float_array_map train(const float_array_map &inputs) override
float_array_map predict(const float_array_map &inputs) const override
float_array_map export_weights() const override
Model backend for object detection that uses one mps_graph_cnnmodule for training and a separate one for inference, since mps_graph_cnnmodule does not currently support doing both.
Definition at line 22 of file mps_od_backend.hpp.
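The description above says training and inference run on separate mps_graph_cnnmodule instances. The following is a minimal sketch of how such a split could be organized; it is not the Turi Create implementation. `toy_module`, `dual_module_backend`, the `conv1_weight` key, and the `std::vector<float>` stand-in for `float_array_map` are all hypothetical. The policy shown here (drop the inference module on every training step, then rebuild it lazily from `export_weights()` on the next inference) is an assumption chosen to keep the two modules in sync.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for turi::neural_net::float_array_map (assumption:
// the real type stores shared float arrays, not std::vector<float>).
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical module standing in for an mps_graph_cnnmodule that can be
// used either for training or for inference, but not both.
class toy_module {
 public:
  explicit toy_module(float_array_map weights) : weights_(std::move(weights)) {}

  // Toy "training step": shift every weight by a fixed delta.
  void train_step(float delta) {
    for (auto& kv : weights_)
      for (float& w : kv.second) w += delta;
  }

  float_array_map export_weights() const { return weights_; }

 private:
  float_array_map weights_;
};

// Hypothetical backend owning one module for training and a separate,
// lazily built module for inference.
class dual_module_backend {
 public:
  explicit dual_module_backend(float_array_map init)
      : training_module_(std::make_unique<toy_module>(std::move(init))) {}

  void train(float delta) {
    inference_module_.reset();  // weights are about to change; drop the stale copy
    training_module_->train_step(delta);
  }

  // Rebuild the inference module from the latest training weights on demand.
  const toy_module& inference_module() const {
    if (!inference_module_) {
      inference_module_ =
          std::make_unique<toy_module>(training_module_->export_weights());
    }
    return *inference_module_;
  }

  bool inference_module_ready() const { return inference_module_ != nullptr; }

 private:
  std::unique_ptr<toy_module> training_module_;
  mutable std::unique_ptr<toy_module> inference_module_;  // cache for const inference
};
```

Rebuilding lazily means a run of consecutive training steps pays no synchronization cost; the weights are copied only when inference is actually requested.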
float_array_map export_weights() const [override, virtual]
Exports the network weights.
Implements turi::neural_net::model_backend.
float_array_map predict(const float_array_map &inputs) const [override, virtual]
Performs a forward pass.
Parameters:
inputs: A map containing all the named inputs required by the model.
Implements turi::neural_net::model_backend.
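predict() takes a map of named inputs and is declared const. The sketch below illustrates that contract on a toy linear layer; it is hypothetical throughout. `toy_linear_model`, the key names `input` and `output`, and the `std::vector<float>` stand-in for `float_array_map` are assumptions, not Turi Create's actual names.

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for turi::neural_net::float_array_map (assumption).
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical toy model y = w * x + b. predict() is const, like the
// documented override, so a forward pass cannot change the weights.
class toy_linear_model {
 public:
  toy_linear_model(float w, float b) : w_(w), b_(b) {}

  float_array_map predict(const float_array_map& inputs) const {
    // "input" and "output" are illustrative key names only.
    auto it = inputs.find("input");
    if (it == inputs.end()) {
      throw std::invalid_argument("missing required input: \"input\"");
    }
    std::vector<float> out;
    out.reserve(it->second.size());
    for (float x : it->second) out.push_back(w_ * x + b_);  // forward pass only
    return {{"output", out}};
  }

 private:
  float w_;
  float b_;
};
```

Reporting a missing named input as an exception, rather than silently substituting a default, is one reasonable choice for a map-based interface; the real backend's error handling may differ.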
void set_learning_rate(float lr) [override, virtual]
Sets the learning rate to be used for future calls to train.
Implements turi::neural_net::model_backend.
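Because set_learning_rate only affects future calls to train, a caller can drive a learning-rate schedule by updating the rate between training steps. The sketch below assumes a step-decay schedule purely for illustration; `step_decay`, `recording_backend`, `run_schedule`, and the constants used are hypothetical and are not taken from Turi Create.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical step-decay schedule: multiply base_lr by gamma once every
// `step` iterations (integer division gives the number of completed steps).
float step_decay(float base_lr, float gamma, int step, int iteration) {
  return base_lr * std::pow(gamma, static_cast<float>(iteration / step));
}

// Toy stand-in for a backend: it only records the rate each train() used.
class recording_backend {
 public:
  void set_learning_rate(float lr) { lr_ = lr; }
  void train() { used_rates_.push_back(lr_); }
  const std::vector<float>& used_rates() const { return used_rates_; }

 private:
  float lr_ = 0.0f;
  std::vector<float> used_rates_;
};

// Call set_learning_rate only when the schedule changes the rate, then train.
void run_schedule(recording_backend& backend, int iterations) {
  float current = -1.0f;  // sentinel: no rate set yet
  for (int i = 0; i < iterations; ++i) {
    const float lr = step_decay(1e-3f, 0.1f, 2, i);
    if (lr != current) {
      backend.set_learning_rate(lr);
      current = lr;
    }
    backend.train();
  }
}
```

With a step of 2, the first two training calls use the base rate and the next two use one tenth of it.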
float_array_map train(const float_array_map &inputs) [override, virtual]
Performs one forward-backward pass.
Parameters:
inputs: A map containing all the named inputs and labels required by the model.
Implements turi::neural_net::model_backend.
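A forward-backward pass computes a loss from the named inputs and labels, derives gradients, and applies an update. The minimal sketch below shows that sequence for a one-variable linear regression with mean squared error and plain SGD; it is an illustration, not the object detection model. `toy_regressor`, the keys `input`, `labels`, `loss`, `w`, `b`, and the `std::vector<float>` stand-in for `float_array_map` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Simplified stand-in for turi::neural_net::float_array_map (assumption).
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical toy model y = w * x + b trained with mean squared error.
class toy_regressor {
 public:
  void set_learning_rate(float lr) { lr_ = lr; }

  // One forward-backward pass: compute the loss and its gradients, then apply
  // a single SGD update. Returns the loss measured before the update.
  float_array_map train(const float_array_map& inputs) {
    const std::vector<float>& x = inputs.at("input");
    const std::vector<float>& y = inputs.at("labels");
    if (x.empty() || x.size() != y.size()) {
      throw std::invalid_argument("input and labels must be non-empty and equal length");
    }
    const float n = static_cast<float>(x.size());
    float loss = 0.0f, grad_w = 0.0f, grad_b = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
      const float err = w_ * x[i] + b_ - y[i];  // forward pass
      loss += err * err / n;
      grad_w += 2.0f * err * x[i] / n;          // backward pass: dL/dw
      grad_b += 2.0f * err / n;                 // backward pass: dL/db
    }
    w_ -= lr_ * grad_w;
    b_ -= lr_ * grad_b;
    return {{"loss", {loss}}};
  }

  float_array_map export_weights() const { return {{"w", {w_}}, {"b", {b_}}}; }

 private:
  float w_ = 0.0f;
  float b_ = 0.0f;
  float lr_ = 0.01f;
};
```

Returning the loss in the output map mirrors the map-in, map-out shape of the documented train signature; which outputs the real backend returns is not stated here.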