Turi Create  4.0
weight_init.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 
7 #ifndef UNITY_TOOLKITS_NEURAL_NET_WEIGHT_INIT_HPP_
8 #define UNITY_TOOLKITS_NEURAL_NET_WEIGHT_INIT_HPP_
9 
10 #include <functional>
11 #include <random>
12 
13 namespace turi {
14 namespace neural_net {
15 
/**
 * Signature shared by every weight initializer declared in this header.
 *
 * The callable receives the iterator range [first_weight, last_weight) backing
 * an underlying WeightParams instance. Every element of that range already
 * holds 0.f on entry; the callback overwrites the elements with the desired
 * starting values (or leaves them untouched if zero is what it wants).
 *
 * NOTE(review): the half-open reading follows the usual first/last iterator
 * convention; confirm against the callers that build these ranges.
 */
using weight_initializer =
    std::function<void(float* first_weight, float* last_weight)>;
24 
/**
 * Weight initializer implementing Xavier (Glorot) uniform initialization.
 *
 * Instances expose operator()(float*, float*), matching the
 * weight_initializer signature, and are copy-constructible, so they can be
 * stored in a weight_initializer std::function.
 */
class xavier_weight_initializer {
public:

  /**
   * Creates a weight initializer that performs Xavier initialization.
   *
   * \param fan_in The number of inputs that affect each output from the layer
   * \param fan_out The number of outputs affected by each input to the layer
   * \param random_engine The random number generator to use, which must remain
   *                      valid for the lifetime of this instance. It is bound
   *                      to a reference member, so it presumably must not be
   *                      null.
   */
  xavier_weight_initializer(size_t fan_in, size_t fan_out,
                            std::mt19937* random_engine);

  /**
   * Initializes each value uniformly at random in the range [-c, c], where
   * c = sqrt(3 / (0.5 * fan_in + 0.5 * fan_out)), or equivalently
   * c = sqrt(6 / (fan_in + fan_out)).
   */
  void operator()(float* first_weight, float* last_weight);

private:

  // Distribution the weights are sampled from.
  std::uniform_real_distribution<float> dist_;

  // Non-owning reference to the caller's engine. Because it is a reference,
  // the class is copy-constructible (all std::function requires) but not
  // copy-assignable.
  std::mt19937& random_engine_;
};
50 
51 
/**
 * Weight initializer that samples each weight independently from a uniform
 * distribution with caller-supplied bounds.
 */
class uniform_weight_initializer {
 public:

  /**
   * Creates a weight initializer that performs uniform initialization.
   *
   * \param lower_bound The lower bound of the uniform distribution to be sampled
   * \param upper_bound The upper bound of the uniform distribution to be sampled
   * \param random_engine The random number generator to use, which must remain
   *                      valid for the lifetime of this instance.
   */
  uniform_weight_initializer(float lower_bound, float upper_bound,
                             std::mt19937* random_engine);

  /**
   * Initializes each value uniformly at random between lower_bound and
   * upper_bound, as passed to the constructor.
   *
   * NOTE(review): the range is stated per the constructor's parameter docs
   * (lower_bound is not negated); confirm against the constructor definition.
   */
  void operator()(float* first_weight, float* last_weight);

 private:

  // Distribution the weights are sampled from.
  std::uniform_real_distribution<float> dist_;

  // Non-owning reference to the caller's engine (makes the class
  // non-copy-assignable, but still copy-constructible).
  std::mt19937& random_engine_;
};
76 
/**
 * Weight initializer that samples each weight independently from a normal
 * (Gaussian) distribution.
 */
class normal_weight_initializer {
public:
  /**
   * Builds a normal-distribution weight initializer.
   *
   * \param mean Mean of the normal distribution the weights are drawn from.
   * \param std_dev Standard deviation of that normal distribution.
   * \param random_engine Random number generator used for sampling. It is not
   *                      owned and must stay valid for as long as this
   *                      instance is alive.
   */
  normal_weight_initializer(float mean, float std_dev,
                            std::mt19937* random_engine);

  /**
   * Overwrites each value in the range starting at first_weight and ending at
   * last_weight with a sample from a normal distribution having the
   * configured mean and standard deviation.
   */
  void operator()(float* first_weight, float* last_weight);

private:
  // Gaussian distribution the weights are sampled from.
  std::normal_distribution<float> dist_;
  // Caller-owned engine, held by reference.
  std::mt19937& random_engine_;
};
101 
/**
 * Weight initializer that writes one constant value into every weight.
 */
struct scalar_weight_initializer {
 public:
  /**
   * Creates a weight initializer that initializes all of the weights to a
   * constant scalar value.
   *
   * \param scalar The scalar value to initialize the weights to.
   *
   * NOTE(review): this constructor is not explicit, so a float converts
   * implicitly to scalar_weight_initializer; left as-is so existing callers
   * keep compiling.
   */
  scalar_weight_initializer(float scalar);

  /** Sets every weight in the given range to the configured scalar. */
  void operator()(float* first_weight, float* last_weight);

 private:
  // Constant assigned to each weight.
  float scalar_;
};
115 
/**
 * Weight initializer that leaves every weight at zero.
 *
 * No work is required, since we assume the buffer handed to an initializer is
 * already zero-initialized. The parameter names are commented out because the
 * arguments are intentionally unused; named-but-unused parameters trigger
 * -Wunused-parameter.
 */
struct zero_weight_initializer {
  void operator()(float* /* first_weight */, float* /* last_weight */) const {}
};
121 
122 /** Convenience struct to hold all the weight initializers required by LSTM */
124 
125  static lstm_weight_initializers create_with_xavier_method(
126  size_t input_size, size_t state_size, std::mt19937* random_engine);
127  static lstm_weight_initializers create_with_zero();
128 
129  // Initializers for matrices applied to sequence input
130  weight_initializer input_gate_weight_fn;
131  weight_initializer forget_gate_weight_fn;
132  weight_initializer block_input_weight_fn;
133  weight_initializer output_gate_weight_fn;
134 
135  // Initializers for matrices applied to hidden state
136  weight_initializer input_gate_recursion_fn;
137  weight_initializer forget_gate_recursion_fn;
138  weight_initializer block_input_recursion_fn;
139  weight_initializer output_gate_recursion_fn;
140 
141  // Initializers for bias
142  weight_initializer input_gate_bias_fn;
143  weight_initializer forget_gate_bias_fn;
144  weight_initializer block_input_bias_fn;
145  weight_initializer output_gate_bias_fn;
146 };
147 
148 } // neural_net
149 } // turi
150 
151 #endif // UNITY_TOOLKITS_NEURAL_NET_WEIGHT_INIT_HPP_