Turi Create 4.0
turi::neural_net::mps_od_backend Class Reference

#include <ml/neural_net/mps_od_backend.hpp>

Public Member Functions

void set_learning_rate (float lr) override
float_array_map train (const float_array_map &inputs) override
float_array_map predict (const float_array_map &inputs) const override
float_array_map export_weights () const override

Detailed Description

Model backend for object detection that uses separate mps_graph_cnnmodule instances for training and for inference, because a single mps_graph_cnnmodule does not currently support both.

Definition at line 22 of file mps_od_backend.hpp.
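A minimal sketch of the two-module arrangement, under stated assumptions: float_array_map is simplified to a std::map of std::vector<float>, and toy_module, two_module_backend, train_step() and inference_module() are hypothetical names. This page does not say how the two modules share weights. Here the inference module is assumed to be rebuilt lazily from the training module's current weights whenever training has changed them.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Simplified stand-in for float_array_map (assumption): name -> flat array.
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical module that is built either for training or for inference.
class toy_module {
 public:
  toy_module(bool is_train, float_array_map weights)
      : is_train_(is_train), weights_(std::move(weights)) {}

  bool is_train() const { return is_train_; }
  const float_array_map& weights() const { return weights_; }
  float_array_map& mutable_weights() { return weights_; }

 private:
  bool is_train_;
  float_array_map weights_;
};

// Hypothetical backend: one training module plus an inference module that is
// rebuilt lazily from the training module's latest weights.
class two_module_backend {
 public:
  explicit two_module_backend(float_array_map initial_weights)
      : training_(std::make_unique<toy_module>(/*is_train=*/true,
                                               std::move(initial_weights))) {}

  // Stand-in for train(): the "update" simply adds 1 to every weight.
  void train_step() {
    for (auto& entry : training_->mutable_weights())
      for (float& w : entry.second) w += 1.0f;
    inference_.reset();  // The inference copy is now stale.
  }

  // Stand-in for the setup done by predict(): build the inference module on
  // first use after a training step, copying the current weights.
  const toy_module& inference_module() const {
    if (!inference_) {
      inference_ = std::make_unique<toy_module>(/*is_train=*/false,
                                                training_->weights());
      ++inference_builds_;
    }
    return *inference_;
  }

  int inference_builds() const { return inference_builds_; }

 private:
  std::unique_ptr<toy_module> training_;
  mutable std::unique_ptr<toy_module> inference_;
  mutable int inference_builds_ = 0;
};
```

Rebuilding on demand keeps repeated predictions between training steps cheap, at the cost of one weight copy after each training step that is followed by a prediction.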

Member Function Documentation

◆ export_weights()

float_array_map turi::neural_net::mps_od_backend::export_weights ( ) const
override virtual

Exports the network weights.

Implements turi::neural_net::model_backend.
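export_weights() returns the network weights as a map of named arrays. As an illustration only, the helpers below treat such an export as a simplified std::map of std::vector<float> and compare two snapshots, for example one taken before and one after a training step. count_parameters(), max_abs_difference() and any layer names used with them are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for float_array_map (assumption): name -> flat array.
using float_array_map = std::map<std::string, std::vector<float>>;

// Total number of scalar parameters across all exported arrays.
std::size_t count_parameters(const float_array_map& weights) {
  std::size_t total = 0;
  for (const auto& entry : weights) total += entry.second.size();
  return total;
}

// Largest absolute element-wise difference between two exports that have the
// same names and sizes; returns -1 if the layouts do not match.
float max_abs_difference(const float_array_map& a, const float_array_map& b) {
  if (a.size() != b.size()) return -1.0f;
  float max_diff = 0.0f;
  for (const auto& entry : a) {
    auto it = b.find(entry.first);
    if (it == b.end() || it->second.size() != entry.second.size()) return -1.0f;
    for (std::size_t i = 0; i < entry.second.size(); ++i)
      max_diff = std::max(max_diff, std::fabs(entry.second[i] - it->second[i]));
  }
  return max_diff;
}
```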

◆ predict()

float_array_map turi::neural_net::mps_od_backend::predict ( const float_array_map &  inputs) const
override virtual

Performs a forward pass.

Parameters
inputs  A map containing all the named inputs required by the model.
Returns
A map containing all the named outputs from the model. The values may be deferred_float_array instances wrapping future (asynchronous) results.

Implements turi::neural_net::model_backend.
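Because predict() may return deferred_float_array values wrapping future results, a caller can issue the forward pass and block only when it actually reads an output. The sketch below models that with a hypothetical deferred_array around std::shared_future; it is not the Turi Create class. std::launch::deferred is used so the example needs no threads, whereas a GPU backend would complete the work asynchronously. The output key "output" and the y = 2x computation are placeholders.

```cpp
#include <cassert>
#include <cstddef>
#include <future>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified analogue of deferred_float_array: it holds a
// shared future and only waits when the data is read.
class deferred_array {
 public:
  explicit deferred_array(std::shared_future<std::vector<float>> data)
      : data_(std::move(data)) {}

  // Waits for (here: runs) the pending computation, then returns its result.
  const std::vector<float>& data() const { return data_.get(); }

 private:
  std::shared_future<std::vector<float>> data_;
};

using deferred_map = std::map<std::string, deferred_array>;

// Hypothetical forward pass: schedules y = 2 * x and returns immediately,
// leaving the caller to decide when to wait for the result.
deferred_map predict_deferred(const std::vector<float>& input) {
  std::shared_future<std::vector<float>> result =
      std::async(std::launch::deferred, [input] {
        std::vector<float> out(input.size());
        for (std::size_t i = 0; i < input.size(); ++i) out[i] = 2.0f * input[i];
        return out;
      }).share();
  deferred_map outputs;
  outputs.emplace("output", deferred_array(result));
  return outputs;
}
```

Returning futures lets a caller queue several batches before synchronizing, which matters when the real work runs on the GPU.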

◆ set_learning_rate()

void turi::neural_net::mps_od_backend::set_learning_rate ( float  lr)
override virtual

Sets the learning rate to be used for future calls to train.

Implements turi::neural_net::model_backend.
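Since set_learning_rate() only affects later calls to train(), a training loop can change the rate between iterations. This sketch assumes a step-decay schedule; backend_interface is a hypothetical stand-in that mirrors only the two signatures documented on this page (not the real model_backend header), and recording_backend, step_decay() and the schedule constants are illustrative.

```cpp
#include <cassert>
#include <cmath>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for float_array_map (assumption).
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical subset of a backend interface used by this sketch.
class backend_interface {
 public:
  virtual ~backend_interface() = default;
  virtual void set_learning_rate(float lr) = 0;
  virtual float_array_map train(const float_array_map& inputs) = 0;
};

// Fake backend that records the learning rate in effect for each train().
class recording_backend : public backend_interface {
 public:
  void set_learning_rate(float lr) override { lr_ = lr; }
  float_array_map train(const float_array_map&) override {
    seen_rates.push_back(lr_);
    return {};
  }
  std::vector<float> seen_rates;

 private:
  float lr_ = 0.0f;
};

// Step decay: multiply base_lr by factor once for every step already reached.
float step_decay(float base_lr, float factor, const std::vector<int>& steps,
                 int iteration) {
  float lr = base_lr;
  for (int s : steps)
    if (iteration >= s) lr *= factor;
  return lr;
}

// Illustrative loop: decay by 10x at iterations 2 and 4.
void run_training(backend_interface& backend, int num_iterations) {
  const std::vector<int> steps = {2, 4};
  for (int i = 0; i < num_iterations; ++i) {
    backend.set_learning_rate(step_decay(0.001f, 0.1f, steps, i));
    backend.train(float_array_map{});  // Uses the rate set just above.
  }
}
```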

◆ train()

float_array_map turi::neural_net::mps_od_backend::train ( const float_array_map &  inputs)
override virtual

Performs one forward-backward pass.

Parameters
inputs  A map containing all the named inputs and labels required by the model.
Returns
A map containing all the named outputs and loss images from the model. The values may be deferred_float_array instances wrapping future (asynchronous) results.

Implements turi::neural_net::model_backend.
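To make "one forward-backward pass" concrete, the toy backend below fits y = w * x with a mean-squared-error loss: train() computes the outputs (forward pass), the loss and its gradient with respect to w (backward pass), then applies one SGD step using the current learning rate. It is a hypothetical CPU illustration, not the MPS implementation, and the keys "input", "labels", "output" and "loss" are assumed names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for float_array_map (assumption).
using float_array_map = std::map<std::string, std::vector<float>>;

// Toy backend for y = w * x with mean-squared-error loss.
class toy_linear_backend {
 public:
  void set_learning_rate(float lr) { lr_ = lr; }

  // One forward-backward pass followed by an SGD update of w.
  float_array_map train(const float_array_map& inputs) {
    const std::vector<float>& x = inputs.at("input");
    const std::vector<float>& y = inputs.at("labels");
    const float n = static_cast<float>(x.size());

    std::vector<float> out(x.size());
    float loss = 0.0f;
    float grad = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
      out[i] = w_ * x[i];               // Forward pass.
      const float err = out[i] - y[i];
      loss += err * err / n;            // Mean squared error.
      grad += 2.0f * err * x[i] / n;    // dLoss/dw (backward pass).
    }
    w_ -= lr_ * grad;                   // Parameter update.
    return {{"output", out}, {"loss", {loss}}};
  }

  float weight() const { return w_; }

 private:
  float w_ = 0.0f;
  float lr_ = 0.1f;
};
```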


The documentation for this class was generated from the following file: