7 #ifndef UNITY_TOOLKITS_NEURAL_NET_WEIGHT_INIT_HPP_ 8 #define UNITY_TOOLKITS_NEURAL_NET_WEIGHT_INIT_HPP_ 14 namespace neural_net {
// Callback type used to fill a buffer of float weights in place; it is a
// std::function taking raw float pointers and returning nothing.
// The visible first parameter is first_weight; the sibling initializers'
// operator() take (float* first_weight, float* last_weight), so this alias
// presumably has the same shape — TODO confirm, including whether the range
// is half-open [first_weight, last_weight).
// NOTE(review): this declaration appears truncated in this copy (the
// remaining parameter and the closing `)>;` are absent), and the leading
// "22" is a stray token (likely a leaked source line number) that will not
// compile. Restore from the canonical header.
22 using weight_initializer = std::function<void(
float* first_weight,
// Weight initializer that samples each weight from a
// std::uniform_real_distribution<float> (dist_) driven by a caller-supplied
// std::mt19937 engine. The distribution bounds are presumably derived from
// fan_in and fan_out per Xavier/Glorot initialization — confirm in the .cpp.
// NOTE(review): this copy has stray numeric tokens (e.g. "25", "36", "43")
// that look like leaked line numbers, and no visible `public:`/`private:`
// labels or closing `};`. It will not compile as-is; restore from source.
25 class xavier_weight_initializer {
// Parameters:
//   fan_in        - presumably the number of inputs feeding the layer.
//   fan_out       - presumably the number of outputs of the layer.
//   random_engine - engine used for sampling. random_engine_ is a reference
//                   member, so the engine presumably bound from this pointer
//                   must outlive this initializer (and must not be null).
36 xavier_weight_initializer(
size_t fan_in,
size_t fan_out,
37 std::mt19937* random_engine);
// Writes sampled values into the buffer delimited by first_weight and
// last_weight (presumably the half-open range [first_weight, last_weight)).
// Declared non-const, presumably because sampling mutates dist_ and the
// shared engine.
43 void operator()(
float* first_weight,
float* last_weight);
// Distribution sampled once per weight.
47 std::uniform_real_distribution<float> dist_;
// Non-owning reference to the caller's engine. Being a reference member, it
// also makes this class non-copy-assignable (copy construction still works).
48 std::mt19937& random_engine_;
// Weight initializer that samples each weight from a
// std::uniform_real_distribution<float> (dist_) whose bounds presumably come
// from lower_bound and upper_bound — confirm in the .cpp.
// NOTE(review): this copy has stray numeric tokens (e.g. "52", "63", "69")
// that look like leaked line numbers, and no visible `public:`/`private:`
// labels or closing `};`. It will not compile as-is; restore from source.
52 class uniform_weight_initializer {
// Parameters:
//   lower_bound   - lower end of the sampling interval (presumably).
//   upper_bound   - upper end of the sampling interval (presumably).
//   random_engine - engine used for sampling. random_engine_ is a reference
//                   member, so the engine presumably bound from this pointer
//                   must outlive this initializer (and must not be null).
63 uniform_weight_initializer(
float lower_bound,
float upper_bound,
64 std::mt19937* random_engine);
// Writes sampled values into the buffer delimited by first_weight and
// last_weight (presumably the half-open range [first_weight, last_weight)).
// Declared non-const, presumably because sampling mutates the engine.
69 void operator()(
float* first_weight,
float* last_weight);
// Distribution sampled once per weight.
73 std::uniform_real_distribution<float> dist_;
// Non-owning reference to the caller's engine; also makes the class
// non-copy-assignable.
74 std::mt19937& random_engine_;
// Weight initializer that samples each weight from a
// std::normal_distribution<float> (dist_), presumably parameterized by mean
// and std_dev — confirm in the .cpp.
// NOTE(review): this copy has stray numeric tokens (e.g. "77", "88", "95")
// that look like leaked line numbers, and no visible `public:`/`private:`
// labels or closing `};`. It will not compile as-is; restore from source.
77 class normal_weight_initializer {
// Parameters:
//   mean          - mean of the normal distribution (presumably).
//   std_dev       - standard deviation of the distribution (presumably).
//   random_engine - engine used for sampling. random_engine_ is a reference
//                   member, so the engine presumably bound from this pointer
//                   must outlive this initializer (and must not be null).
88 normal_weight_initializer(
float mean,
float std_dev,
89 std::mt19937* random_engine);
// Writes sampled values into the buffer delimited by first_weight and
// last_weight (presumably the half-open range [first_weight, last_weight)).
// Declared non-const: std::normal_distribution may cache state between
// draws, and the engine advances.
95 void operator()(
float* first_weight,
float* last_weight);
// Distribution sampled once per weight.
98 std::normal_distribution<float> dist_;
// Non-owning reference to the caller's engine; also makes the class
// non-copy-assignable.
99 std::mt19937& random_engine_;
// Weight initializer constructed from a single float, presumably writing that
// constant to every weight in the range — confirm in the .cpp (the data
// member that stores `scalar` is not visible in this copy).
// NOTE(review): stray numeric tokens (e.g. "102", "109", "110") and the
// missing member declarations / closing `};` make this copy uncompilable.
102 struct scalar_weight_initializer {
// NOTE(review): this single-argument constructor is not `explicit`, so a
// bare float converts implicitly to scalar_weight_initializer (and from
// there to a std::function-based weight_initializer). Consider marking it
// explicit.
109 scalar_weight_initializer(
float scalar);
// Fills the buffer delimited by first_weight and last_weight (presumably the
// half-open range [first_weight, last_weight)).
110 void operator()(
float* first_weight,
float* last_weight);
// Weight initializer whose call operator is a no-op.
// NOTE(review): stray numeric tokens (e.g. "116", "119") and the missing
// closing `};` make this copy uncompilable; restore from source.
116 struct zero_weight_initializer {
// The body is empty: this does NOT write zeros into the range. It leaves the
// buffer exactly as it was, presumably relying on the caller to
// zero-initialize weights beforehand — confirm against callers.
119 void operator()(
float* first_weight,
float* last_weight)
const {}
// NOTE(review): the start of this declaration and its enclosing struct are
// absent from this copy; only the trailing parameter list remains, with stray
// numeric tokens (e.g. "126", "130") fused in. Restore from the canonical
// header. The remaining parameters are an input size, a state size, and a
// random engine, so this is presumably a factory that builds the initializers
// below — confirm.
126 size_t input_size,
size_t state_size, std::mt19937* random_engine);
// One weight_initializer (a std::function) per LSTM gate / block input.
// Invoking a default-constructed (empty) member throws
// std::bad_function_call, so every member must be set before use.
// Initializers for the "weight" matrices, presumably those applied to the
// layer input.
130 weight_initializer input_gate_weight_fn;
131 weight_initializer forget_gate_weight_fn;
132 weight_initializer block_input_weight_fn;
133 weight_initializer output_gate_weight_fn;
// Initializers for the "recursion" matrices, presumably the recurrent weights
// applied to the previous state — confirm naming.
136 weight_initializer input_gate_recursion_fn;
137 weight_initializer forget_gate_recursion_fn;
138 weight_initializer block_input_recursion_fn;
139 weight_initializer output_gate_recursion_fn;
// Initializers for the per-gate bias vectors.
142 weight_initializer input_gate_bias_fn;
143 weight_initializer forget_gate_bias_fn;
144 weight_initializer block_input_bias_fn;
145 weight_initializer output_gate_bias_fn;
151 #endif // UNITY_TOOLKITS_NEURAL_NET_WEIGHT_INIT_HPP_