// Turi Create 4.0 — mlc_layer_weights.hpp
1 /* Copyright © 2020 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 #pragma once
8 
9 #include <string>
10 #include <unordered_map>
11 
12 #include <ml/neural_net/float_array.hpp>
13 
14 #import <MLCompute/MLCompute.h>
15 
16 NS_ASSUME_NONNULL_BEGIN
17 
18 namespace turi {
19 namespace neural_net {
20 
/**
 * Owns the float buffers backing a model's layer weights and exposes them to
 * MLCompute as MLCTensor views. The raw storage lives in `weights_` (keyed by
 * layer/parameter name); `tensors_` holds MLCTensor wrappers around that same
 * memory, so MLCompute reads initial values from — and writes trained values
 * back into — buffers this class owns.
 *
 * Requires macOS 10.16 (Big Sur) because it depends on the MLCompute
 * framework introduced there.
 */
class API_AVAILABLE(macos(10.16)) mlc_layer_weights {
 public:
  mlc_layer_weights();

  // Movable but not copyable: copying would duplicate the weight buffers
  // while MLCompute still holds pointers into the originals.
  mlc_layer_weights(const mlc_layer_weights &) = delete;
  mlc_layer_weights(mlc_layer_weights &&);
  mlc_layer_weights &operator=(const mlc_layer_weights &) = delete;
  mlc_layer_weights &operator=(mlc_layer_weights &&);

  ~mlc_layer_weights();

  /**
   * Returns views into the weights managed by this instance, represented as
   * tensors formatted according to the method used to add each name.
   *
   * NOTE(review): this hands out the internal (mutable) dictionary directly;
   * callers are expected to treat it as read-only.
   */
  NSDictionary<NSString *, MLCTensor *> *tensor_views() const { return tensors_; }

  /**
   * Returns strong references to the memory backing the weights managed by this
   * instance. The return value uses the same keys as the dictionary returned by
   * `tensor_views()`.
   */
  NSDictionary<NSString *, NSData *> *tensor_weights() const;

  /**
   * Returns copies of all the weights managed by this instance. The caller is
   * responsible for ensuring that MLCompute is not modifying any of these
   * weights for the duration of this function call (and that MLCompute has
   * synchronized weights from GPU to this memory, if necessary).
   */
  float_array_map export_weights() const;

  /**
   * Returns copies of all the weights managed by this instance. It also copies
   * the optimizer data attached to the tensors into the returned weights map.
   * The caller is responsible for ensuring that MLCompute is not modifying any
   * of these weights for the duration of this function call (and that MLCompute
   * has synchronized weights from GPU to this memory, if necessary).
   */
  float_array_map export_weights_and_optimizer_data() const;

  /**
   * Imports the kernel weights for a convolution layer. The input must have
   * shape OIHW.
   */
  void add_conv_weight(const std::string &name, const float_array &arr);

  /**
   * Imports the bias weights for a convolution layer. The input must be
   * one-dimensional.
   */
  void add_conv_bias(const std::string &name, const float_array &arr);

  /**
   * Imports one-dimensional weights, such as those for batch-normalization
   * layers.
   */
  void add_flat_array(const std::string &name, const float_array &arr);

  /**
   * Imports one weight matrix to be passed to an MLCompute LSTM layer. The
   * input must be two-dimensional.
   */
  void add_lstm_weight(const std::string &name, const float_array &arr);

  /**
   * Adds the optimizer data (two auxiliary arrays, e.g. optimizer moment
   * buffers — keyed by the given keys) to the tensor saved under `layer_name`.
   */
  void add_optimizer_data(const std::string &layer_name, const std::string &optimizer_data_1_key,
                          const float_array &optimizer_data_1,
                          const std::string &optimizer_data_2_key,
                          const float_array &optimizer_data_2);

 private:
  /**
   * Copies arr and stores a strong reference to it with the given name. Returns
   * a weak pointer to the copied memory to pass to MLCompute, which will write
   * updated (trained) weights back into the same memory. (Yes, MLCompute
   * ignores the fact that the tensor data is initialized with a const pointer.)
   */
  MLCTensorData *copy_float_array(const std::string &name, const float_array &arr);

  /**
   * Strong references to the memory that MLCompute will use to read the initial
   * weights and to pass back the updated weights. Although this data structure
   * uses the shared_float_array type, the memory here should not actually be
   * shared with clients, since in general MLCompute might be modifying it
   * asynchronously.
   */
  std::unordered_map<std::string, shared_float_array> weights_;

  /**
   * Collection of MLCompute tensors that wrap the data owned by weights_ above.
   */
  NSMutableDictionary<NSString *, MLCTensor *> *tensors_ = nil;
};
118 
119 } // namespace neural_net
120 } // namespace turi
121 
122 NS_ASSUME_NONNULL_END