Turi Create  4.0
model_backend.hpp
1 /* Copyright © 2018 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 
7 #ifndef UNITY_TOOLKITS_NEURAL_NET_MODEL_BACKEND_HPP_
8 #define UNITY_TOOLKITS_NEURAL_NET_MODEL_BACKEND_HPP_
9 
10 #include <memory>
11 #include <string>
12 #include <vector>
13 
14 #include <ml/neural_net/float_array.hpp>
15 
16 namespace turi {
17 namespace neural_net {
18 
19 /**
20  * A pure virtual interface for neural networks, used to abstract across model
21  * architectures and backend implementations.
22  */
24 public:
25 
26  virtual ~model_backend() = default;
27 
28  /**
29  * Exports the network weights.
30  *
31  * \todo Someday, once no Python frontend depends on this method, this could
32  * just take a mutable model_spec (updating the one used to initialize
33  * the model_backend).
34  */
35  virtual float_array_map export_weights() const = 0;
36 
37  // TODO: Accessors describing name inputs and expected shapes.
38 
39  /**
40  * Performs a forward pass.
41  *
42  * \param inputs A map containing all the named inputs required by the model.
43  * \return A map containing all the named outputs from the model. The values
44  * may be deferred_float_array instances wrapping future
45  * (asynchronous) results.
46  */
47  virtual float_array_map predict(const float_array_map& inputs) const = 0;
48 
49  /**
50  * Sets the learning rate to be used for future calls to train.
51  */
52  virtual void set_learning_rate(float lr) = 0;
53 
54  /**
55  * Performs one forward-backward pass.
56  *
57  * \param inputs A map containing all the named inputs and labels required by
58  * the model.
59  * \return A map containing all the named outputs and loss images from the
60  * model. The values may be deferred_float_array instances wrapping
61  * future (asynchronous) results.
62  *
63  * \todo Introduce a separate mutable subclass, so that prediction-only models
64  * don't need to have a `train` method at all.
65  */
66  virtual float_array_map train(const float_array_map& inputs) = 0;
67 };
68 
69 } // namespace neural_net
70 } // namespace turi
71 
72 #endif // UNITY_TOOLKITS_NEURAL_NET_MODEL_BACKEND_HPP_
virtual void set_learning_rate(float lr)=0
virtual float_array_map export_weights() const =0
virtual float_array_map train(const float_array_map &inputs)=0
virtual float_array_map predict(const float_array_map &inputs) const =0