Turi Create  4.0
model_factory.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_FACTORIZATION_MODEL_CREATION_FACTORY_H_
7 #define TURI_FACTORIZATION_MODEL_CREATION_FACTORY_H_
8 
9 #include <toolkits/factorization/factorization_model_sgd_interface.hpp>
10 #include <toolkits/factorization/factorization_model_impl.hpp>
11 #include <toolkits/sgd/basic_sgd_solver.hpp>
12 #include <toolkits/sgd/sgd_interface.hpp>
13 #include <toolkits/factorization/ranking_sgd_solver_explicit.hpp>
14 #include <toolkits/factorization/ranking_sgd_solver_implicit.hpp>
15 
16 #include <string>
17 #include <cstdlib>
18 
19 ////////////////////////////////////////////////////////////////////////////////
20 // All of the macros we use to create an instance of the solvers,
21 // instantiate the templates, and suppress the instantiations. We
22 // have a set of instantiation files in factory_instantiations/*,
23 // which separates the different parts of the tree created by these
24 // macros into different compilation units.
25 
26 
// Reports an unrecognized option value, then invokes ASSERT_UNREACHABLE().
// `param` is both stringified (giving the option's name) and concatenated
// into the message (giving its value), so it must name a std::string-like
// variable.
//
// The two statements are wrapped in do { } while(false) so that the macro
// behaves as a single statement.  Without the wrapper, `else _BAD(x);` binds
// only the ASSERT_MSG to the `else` and leaves ASSERT_UNREACHABLE() after the
// whole if/else chain.
#define _BAD(param)                                                         \
  do {                                                                      \
    ASSERT_MSG(false, (std::string(#param) + " \'" + param                  \
                       + "\' not recognized.").c_str());                    \
    ASSERT_UNREACHABLE();                                                   \
  } while(false)
30 
31 ////////////////////////////////////////////////////////////////////////////////
32 // Step 1: Select the solver.
33 
// Dispatches on the runtime string `solver_class` (read from the expansion
// site) and forwards the matching solver template, followed by the caller's
// arguments, to CREATE_RETURN_LOSS_NORMAL.  Any other value goes to _BAD.
#define CREATE_RETURN_SOLVER(...)                                           \
  do {                                                                      \
    if(solver_class == "sgd::basic_sgd_solver") {                           \
      CREATE_RETURN_LOSS_NORMAL(sgd::basic_sgd_solver, __VA_ARGS__);        \
    } else if(solver_class                                                  \
              == "factorization::explicit_ranking_sgd_solver") {            \
      CREATE_RETURN_LOSS_NORMAL(                                            \
          factorization::explicit_ranking_sgd_solver, __VA_ARGS__);         \
    } else {                                                                \
      _BAD(solver_class);                                                   \
    }                                                                       \
  } while(false)
43 
// Expands _SUPPRESS_BY_LOSS once for each solver template that
// CREATE_RETURN_SOLVER can select.
#define SUPPRESS_SOLVERS()                                                  \
  _SUPPRESS_BY_LOSS(sgd::basic_sgd_solver);                                 \
  _SUPPRESS_BY_LOSS(factorization::explicit_ranking_sgd_solver);
47 
48 
49 ////////////////////////////////////////////////////////////////////////////////
50 // Step 2: Select the loss. We have to handle two cases; the first
51 // one is the case with a normal loss, and the second with the
52 // implicit_ranking_sgd_solver.
53 
// Dispatches on the runtime string `loss_type` (read from the expansion site)
// and prepends the matching loss class to the argument list handed to
// CREATE_RETURN_REGULARIZER.  Any other value goes to _BAD.
#define CREATE_RETURN_LOSS_NORMAL(...)                                      \
  do {                                                                      \
    if(loss_type == "loss_squared_error") {                                 \
      CREATE_RETURN_REGULARIZER(loss_squared_error, __VA_ARGS__);           \
    } else if(loss_type == "loss_logistic") {                               \
      CREATE_RETURN_REGULARIZER(loss_logistic, __VA_ARGS__);                \
    } else {                                                                \
      _BAD(loss_type);                                                      \
    }                                                                       \
  } while(false)
63 
// Expands _SUPPRESS_BY_REGULARIZER once per loss class that
// CREATE_RETURN_LOSS_NORMAL can select, forwarding the caller's arguments.
#define _SUPPRESS_BY_LOSS(...)                                              \
  _SUPPRESS_BY_REGULARIZER(loss_squared_error, __VA_ARGS__);                \
  _SUPPRESS_BY_REGULARIZER(loss_logistic, __VA_ARGS__);
67 
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 // Step 3: Select the regularizer.
71 
// Dispatches on the runtime string `regularization_type` (read from the
// expansion site) and prepends the matching regularizer tag to the argument
// list handed to CREATE_RETURN_FACTORS.  Any other value goes to _BAD.
#define CREATE_RETURN_REGULARIZER(...)                                      \
  do {                                                                      \
    /* "NONE" takes the L2 path too: it does tempering with L2, */          \
    /* so that variant is still needed.                         */          \
    if(regularization_type == "L2" || regularization_type == "NONE") {      \
      CREATE_RETURN_FACTORS(L2, __VA_ARGS__);                               \
    } else if(regularization_type == "ON_THE_FLY") {                        \
      CREATE_RETURN_FACTORS(ON_THE_FLY, __VA_ARGS__);                       \
    } else {                                                                \
      _BAD(regularization_type);                                            \
    }                                                                       \
  } while(false)
85 
// Expands _SUPPRESS_BY_FACTORS once per regularizer tag that
// CREATE_RETURN_REGULARIZER can select, forwarding the caller's arguments.
#define _SUPPRESS_BY_REGULARIZER(...)                                       \
  _SUPPRESS_BY_FACTORS(L2, __VA_ARGS__);                                    \
  _SUPPRESS_BY_FACTORS(ON_THE_FLY, __VA_ARGS__);
89 
// Given a loss and a solver, expands _INSTANTIATE_BY_FACTORS once per
// regularizer tag, so that every regularizer / factor-mode combination for
// that pair is explicitly instantiated.
#define _INSTANTIATE_LOSS_AND_SOLVER(...)                                   \
  _INSTANTIATE_BY_FACTORS(L2, __VA_ARGS__);                                 \
  _INSTANTIATE_BY_FACTORS(ON_THE_FLY, __VA_ARGS__);
94 
95 ////////////////////////////////////////////////////////////////////////////////
96 // Step 4: Select the factor mode and associated number of factors.
97 
// Chooses the model's factor mode and factor-count template argument from
// the runtime values `factor_mode`, `num_factors` and `train_data`, all read
// from the expansion site:
//   - "pure_linear_model", or zero factors       -> pure_linear_model, 0
//   - "matrix_factorization", or
//     "factorization_machine" on 2-column data   -> matrix_factorization
//   - any other "factorization_machine"          -> factorization_machine
// For the last two, exactly 8 factors selects the template argument 8 and
// every other count selects -1.  Unknown modes go to _BAD.
#define CREATE_RETURN_FACTORS(...)                                          \
  do {                                                                      \
    if(factor_mode == "pure_linear_model" || num_factors == 0) {            \
      _CREATE_AND_RETURN(pure_linear_model, 0, __VA_ARGS__);                \
    } else if(factor_mode == "matrix_factorization"                         \
              || (factor_mode == "factorization_machine"                    \
                  && train_data.metadata()->num_columns() == 2)) {          \
      if(num_factors == 8) {                                                \
        _CREATE_AND_RETURN(matrix_factorization, 8, __VA_ARGS__);           \
      } else {                                                              \
        _CREATE_AND_RETURN(matrix_factorization, -1, __VA_ARGS__);          \
      }                                                                     \
    } else if(factor_mode == "factorization_machine") {                     \
      if(num_factors == 8) {                                                \
        _CREATE_AND_RETURN(factorization_machine, 8, __VA_ARGS__);          \
      } else {                                                              \
        _CREATE_AND_RETURN(factorization_machine, -1, __VA_ARGS__);         \
      }                                                                     \
    } else {                                                                \
      _BAD(factor_mode);                                                    \
    }                                                                       \
  } while(false)
118 
// Expands _SUPPRESS_SOLVER once for each (factor mode, factor count) pair
// that CREATE_RETURN_FACTORS can select, forwarding the caller's arguments.
//
// The last line deliberately carries no line continuation: a trailing
// backslash there would splice whatever source line follows into the macro.
#define _SUPPRESS_BY_FACTORS(...)                                           \
  _SUPPRESS_SOLVER(pure_linear_model, 0, __VA_ARGS__);                      \
  _SUPPRESS_SOLVER(matrix_factorization, 8, __VA_ARGS__);                   \
  _SUPPRESS_SOLVER(matrix_factorization, -1, __VA_ARGS__);                  \
  _SUPPRESS_SOLVER(factorization_machine, 8, __VA_ARGS__);                  \
  _SUPPRESS_SOLVER(factorization_machine, -1, __VA_ARGS__);
125 
// Expands _INSTANTIATE_SOLVER once for each (factor mode, factor count) pair
// that CREATE_RETURN_FACTORS can select, forwarding the caller's arguments.
//
// The last line deliberately carries no line continuation: a trailing
// backslash there would splice whatever source line follows into the macro.
#define _INSTANTIATE_BY_FACTORS(...)                                        \
  _INSTANTIATE_SOLVER(pure_linear_model, 0, __VA_ARGS__);                   \
  _INSTANTIATE_SOLVER(matrix_factorization, 8, __VA_ARGS__);                \
  _INSTANTIATE_SOLVER(matrix_factorization, -1, __VA_ARGS__);               \
  _INSTANTIATE_SOLVER(factorization_machine, 8, __VA_ARGS__);               \
  _INSTANTIATE_SOLVER(factorization_machine, -1, __VA_ARGS__);
132 
133 
134 ////////////////////////////////////////////////////////////////////////////////
135 // Step 5: Now we know everything we need to create, instantiate, or
136 // suppress the solvers
137 
138 ////////////////////////////////////////
// The macro to suppress implicit instantiation of a solver template.
140 
// Emits an `extern template class` declaration for one fully specified
// solver.  Arguments, in order: a member of model_factor_mode, the factor
// count template argument of factorization_model_impl, a member of
// model_regularization_type, the loss class, and the qualified solver class
// template (which is also pulled into namespace turi by a using-declaration).
#define _SUPPRESS_SOLVER(FACTOR_MODE, NUM_FACTORS, REGULARIZER, LOSS, SOLVER) \
  namespace turi {                                                            \
  using SOLVER;                                                               \
  using namespace factorization;                                              \
  extern template class SOLVER<factorization_sgd_interface<                   \
      factorization_model_impl<model_factor_mode::FACTOR_MODE, NUM_FACTORS>,  \
      LOSS, model_regularization_type::REGULARIZER> >;                        \
  }
157 
158 ////////////////////////////////////////
159 // The macro to create full instantiation of the solver.
160 
// Emits an explicit instantiation (`template class`) of one fully specified
// solver.  Arguments, in order: a member of model_factor_mode, the factor
// count template argument of factorization_model_impl, a member of
// model_regularization_type, the loss class, and the qualified solver class
// template (which is also pulled into namespace turi by a using-declaration).
#define _INSTANTIATE_SOLVER(FACTOR_MODE, NUM_FACTORS, REGULARIZER, LOSS,      \
                            SOLVER)                                           \
  namespace turi {                                                            \
  using SOLVER;                                                               \
  using namespace factorization;                                              \
  template class SOLVER<factorization_sgd_interface<                          \
      factorization_model_impl<model_factor_mode::FACTOR_MODE, NUM_FACTORS>,  \
      LOSS, model_regularization_type::REGULARIZER> >;                        \
  }
176 
177 ////////////////////////////////////////
178 // The main macro to actually create the model and the solver.
179 
// Builds the model, its SGD interface and the solver for one fully resolved
// combination of template arguments, then returns {model, solver} from the
// enclosing function.  The first five arguments select the types; TRAIN_DATA
// and OPTIONS are forwarded to the model's setup() and to the solver's
// constructor.
#define _CREATE_AND_RETURN(FACTOR_MODE, NUM_FACTORS, REGULARIZER, LOSS,       \
                           SOLVER, TRAIN_DATA, OPTIONS)                       \
  do {                                                                        \
    /* Concrete model type for this factor mode and factor count. */          \
    using model_type =                                                        \
        factorization_model_impl<model_factor_mode::FACTOR_MODE,              \
                                 NUM_FACTORS>;                                \
                                                                              \
    /* Interface type binding the model to its loss and regularizer. */       \
    using interface_type =                                                    \
        factorization_sgd_interface<model_type, LOSS,                         \
                                    model_regularization_type::REGULARIZER>;  \
                                                                              \
    /* Create the model and set it up under the name of the loss. */          \
    std::shared_ptr<model_type> new_model(new model_type);                    \
    new_model->setup(LOSS::name(), TRAIN_DATA, OPTIONS);                      \
                                                                              \
    /* Wrap the model in the interface, held through its base class. */       \
    std::shared_ptr<sgd::sgd_interface_base> sgd_iface(                       \
        new interface_type(new_model));                                       \
                                                                              \
    /* Create the solver on top of the interface. */                          \
    std::shared_ptr<sgd::sgd_solver_base> new_solver(                         \
        new SOLVER<interface_type>(sgd_iface, TRAIN_DATA, OPTIONS));          \
                                                                              \
    return {new_model, new_solver};                                           \
  } while(false)
213 
214 
215 
216 namespace turi { namespace factorization {
217 
218 /** The main function to create a version of the model and the solver.
219  *
220  */
221 std::pair<std::shared_ptr<factorization_model>,
222  std::shared_ptr<sgd::sgd_solver_base> >
223 create_model_and_solver(const v2::ml_data& train_data,
224  std::map<std::string, flexible_type> options,
225  const std::string& loss_type,
226  const std::string& solver_class,
227  const std::string& regularization_type,
228  const std::string& factor_mode,
229  flex_int num_factors);
230 
231 }}
232 
233 #endif /* TURI_FACTORIZATION_MODEL_CREATION_FACTORY_H_ */