Turi Create 4.0
turi::neural_net::model_backend Class Reference (abstract)

#include <ml/neural_net/model_backend.hpp>

Public Member Functions

virtual float_array_map export_weights() const = 0
virtual float_array_map predict(const float_array_map &inputs) const = 0
virtual void set_learning_rate(float lr) = 0
virtual float_array_map train(const float_array_map &inputs) = 0

Detailed Description

A pure virtual interface for neural networks, used to abstract across model architectures and backend implementations.

Definition at line 23 of file model_backend.hpp.
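The shape of the interface can be illustrated with a minimal sketch. It assumes a stand-in `float_array_map` (a `std::map` from names to `std::vector<float>`) instead of Turi's actual array types, and `identity_backend` and `make_backend` are hypothetical names used only for illustration; a real implementation such as `mps_od_backend` does far more work.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for Turi's float_array_map: named float buffers.
using float_array_map = std::map<std::string, std::vector<float>>;

// Mirrors the four pure virtual members documented on this page.
class model_backend {
 public:
  virtual ~model_backend() = default;
  virtual float_array_map export_weights() const = 0;
  virtual float_array_map predict(const float_array_map& inputs) const = 0;
  virtual void set_learning_rate(float lr) = 0;
  virtual float_array_map train(const float_array_map& inputs) = 0;
};

// Hypothetical trivial backend: it has no weights, predict() echoes its
// inputs unchanged, and train() reports a zero loss.
class identity_backend : public model_backend {
 public:
  float_array_map export_weights() const override { return {}; }
  float_array_map predict(const float_array_map& inputs) const override {
    return inputs;
  }
  void set_learning_rate(float lr) override { lr_ = lr; }
  float_array_map train(const float_array_map&) override {
    return {{"loss", {0.0f}}};
  }
  float learning_rate() const { return lr_; }

 private:
  float lr_ = 0.0f;
};

// Calling code depends only on model_backend, so the concrete
// implementation can be swapped without touching the caller.
std::unique_ptr<model_backend> make_backend() {
  return std::make_unique<identity_backend>();
}
```

Because callers hold only a `model_backend` pointer, replacing the toy backend with a different architecture or hardware backend leaves the calling code unchanged.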

Member Function Documentation

◆ export_weights()

virtual float_array_map turi::neural_net::model_backend::export_weights() const
pure virtual

Exports the network weights as a map of named arrays.

Implemented in turi::neural_net::mps_od_backend.
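Since the exported weights are keyed by name, a caller can inspect them without knowing which backend produced them. The sketch below assumes the same hypothetical `std::vector<float>` stand-in for `float_array_map`; `count_parameters` and the layer names in the usage are illustrative, not part of Turi Create.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for Turi's float_array_map.
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical helper: total number of scalar parameters in a weight map,
// such as the result of model_backend::export_weights().
std::size_t count_parameters(const float_array_map& weights) {
  std::size_t total = 0;
  for (const auto& entry : weights) total += entry.second.size();
  return total;
}
```

For example, a map holding a 3x3 kernel under `"conv1/weights"` and a single bias under `"conv1/bias"` yields 10 parameters.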

◆ predict()

virtual float_array_map turi::neural_net::model_backend::predict(const float_array_map &inputs) const
pure virtual

Performs a forward pass.

Parameters
inputs: A map containing all the named inputs required by the model.
Returns
A map containing all the named outputs from the model. The values may be deferred_float_array instances wrapping future (asynchronous) results.

Implemented in turi::neural_net::mps_od_backend.
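The idea of returning deferred values can be sketched with standard C++ futures: the forward pass is launched asynchronously, and the caller blocks only when it actually reads an output. `deferred_array` and `predict_async` are hypothetical analogues, not Turi's `deferred_float_array`; the toy model computes y = 2x on a CPU thread, whereas a real backend would typically wait on GPU work.

```cpp
#include <cassert>
#include <future>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical analogue of deferred_float_array: wraps a future result and
// blocks only when the data is first read.
class deferred_array {
 public:
  explicit deferred_array(std::shared_future<std::vector<float>> f)
      : future_(std::move(f)) {}
  const std::vector<float>& data() const { return future_.get(); }

 private:
  std::shared_future<std::vector<float>> future_;
};

using deferred_map = std::map<std::string, deferred_array>;

// Hypothetical forward pass for y = 2x, started asynchronously so that the
// function returns before the computation necessarily finishes.
deferred_map predict_async(const std::vector<float>& x) {
  std::shared_future<std::vector<float>> f =
      std::async(std::launch::async, [x] {
        std::vector<float> y;
        y.reserve(x.size());
        for (float v : x) y.push_back(2.0f * v);
        return y;
      }).share();
  return {{"output", deferred_array(f)}};
}
```

Using `std::shared_future` lets several copies of the map refer to the same pending result.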

◆ set_learning_rate()

virtual void turi::neural_net::model_backend::set_learning_rate(float lr)
pure virtual

Sets the learning rate used by subsequent calls to train().

Implemented in turi::neural_net::mps_od_backend.
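One way to read "future calls" is that the setter only records the value and the next update step consumes it. The sketch assumes plain SGD (w := w - lr * g); `sgd_state` and `apply_gradients` are hypothetical names, and the optimizer a real backend uses may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical optimizer state: set_learning_rate() only stores the value;
// it takes effect on the next call to apply_gradients().
class sgd_state {
 public:
  void set_learning_rate(float lr) { lr_ = lr; }

  // Plain SGD update applied element-wise: w[i] -= lr * g[i].
  void apply_gradients(std::vector<float>& w,
                       const std::vector<float>& g) const {
    for (std::size_t i = 0; i < w.size(); ++i) w[i] -= lr_ * g[i];
  }

 private:
  float lr_ = 0.01f;
};
```

Changing the rate between steps affects only the steps that follow; earlier updates are not revisited.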

◆ train()

virtual float_array_map turi::neural_net::model_backend::train(const float_array_map &inputs)
pure virtual

Performs one forward-backward pass.

Parameters
inputs: A map containing all the named inputs and labels required by the model.
Returns
A map containing all the named outputs and loss images from the model. The values may be deferred_float_array instances wrapping future (asynchronous) results.

Implemented in turi::neural_net::mps_od_backend.
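A single forward-backward pass can be illustrated on a one-weight model y_hat = w * x with mean squared error, followed by an SGD update. Everything here is a hypothetical sketch: the `"input"` and `"labels"` keys, the `linear_backend` class, and the scalar `"loss"` output are assumptions, whereas the documented backend returns loss images.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for Turi's float_array_map.
using float_array_map = std::map<std::string, std::vector<float>>;

// Hypothetical one-parameter model y_hat = w * x trained with MSE.
class linear_backend {
 public:
  void set_learning_rate(float lr) { lr_ = lr; }
  float weight() const { return w_; }

  // One training step: forward pass, loss, gradient, weight update.
  float_array_map train(const float_array_map& inputs) {
    const std::vector<float>& x = inputs.at("input");
    const std::vector<float>& y = inputs.at("labels");
    const float n = static_cast<float>(x.size());
    float loss = 0.0f;
    float grad = 0.0f;
    for (std::size_t i = 0; i < x.size(); ++i) {
      float err = w_ * x[i] - y[i];   // forward pass and residual
      loss += err * err / n;          // mean squared error
      grad += 2.0f * err * x[i] / n;  // dLoss/dw (backward pass)
    }
    w_ -= lr_ * grad;  // gradient-descent update
    return {{"loss", {loss}}};
  }

 private:
  float w_ = 0.0f;
  float lr_ = 0.1f;
};
```

With inputs {1, 2}, labels {2, 4}, and a learning rate of 0.25, the first step reports a loss of 10 and moves the weight from 0 to 2.5; the next step reports a smaller loss.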

