// NOTE(review): this region reads like a flattened extraction of the header.
// Each physical line packs many original lines, the bare integers appear to be
// the original line numbers, and gaps in that numbering suggest some lines
// (for example `else` arms and a few template arguments) are not shown here.
// Confirm against the real header before editing any macro body.
//
// Purpose: preprocessor dispatch tables that turn the runtime strings
// `solver_class`, `loss_type`, `regularization_type` and `factor_mode`
// (together with `num_factors`) into one concrete template combination.
//
// Creation chain. Each level tests one string and prepends the matching
// identifier to __VA_ARGS__ before handing off to the next level:
//   CREATE_RETURN_SOLVER      : solver_class is "sgd::basic_sgd_solver" or
//                               "factorization::explicit_ranking_sgd_solver".
//   CREATE_RETURN_LOSS_NORMAL : loss_type is "loss_squared_error" or
//                               "loss_logistic".
//   CREATE_RETURN_REGULARIZER : regularization_type is "L2", "ON_THE_FLY" or
//                               "NONE"; the "NONE" arm forwards L2, the same
//                               identifier as the "L2" arm. _BAD(...) is
//                               invoked for the remaining case (presumably a
//                               final else arm that is not visible here).
//   CREATE_RETURN_FACTORS     : factor_mode == "pure_linear_model" or
//                               num_factors == 0 selects pure_linear_model
//                               with 0 factors. "matrix_factorization", or
//                               "factorization_machine" when
//                               train_data.metadata()->num_columns() == 2,
//                               selects matrix_factorization. Any other
//                               "factorization_machine" selects
//                               factorization_machine. For the latter two,
//                               num_factors == 8 gets its own <..., 8>
//                               combination alongside a <..., -1> one
//                               (presumably the fallback arm).
//   _CREATE_AND_RETURN        : builds factorization_model_impl<mode, N> in a
//                               shared_ptr, calls
//                               setup(loss_type::name(), train_data, options),
//                               wraps it in factorization_sgd_interface,
//                               constructs solver_class<interface_type> from
//                               (iface, train_data, options), and returns
//                               {model, solver}. Because it contains `return`,
//                               it must be expanded inside a function whose
//                               return type accepts that pair.
//
// Declaration / instantiation chains enumerate the same combinations, without
// any runtime test:
//   SUPPRESS_SOLVERS -> _SUPPRESS_BY_LOSS -> _SUPPRESS_BY_REGULARIZER
//     -> _SUPPRESS_BY_FACTORS -> _SUPPRESS_SOLVER, which emits an
//     `extern template class` declaration.
//   _INSTANTIATE_LOSS_AND_SOLVER -> _INSTANTIATE_BY_FACTORS
//     -> _INSTANTIATE_SOLVER, which emits a `template class solver_class<...>`
//     explicit instantiation. The regularizers enumerated are L2 and
//     ON_THE_FLY; the factor combinations are (pure_linear_model, 0),
//     (matrix_factorization, 8 / -1) and (factorization_machine, 8 / -1).
//
// The ASSERT_MSG(false, "<name> '<value>' not recognized.") plus
// ASSERT_UNREACHABLE() fragment near the top is a macro body whose #define
// line is not visible here; presumably it is the _BAD(param) helper used by
// CREATE_RETURN_REGULARIZER. TODO confirm.
//
// NOTE(review): ranking_sgd_solver_implicit.hpp is included, but only the
// explicit ranking solver appears in the dispatch shown here. Confirm whether
// the implicit solver is handled on lines that are not visible.
6 #ifndef TURI_FACTORIZATION_MODEL_CREATION_FACTORY_H_ 7 #define TURI_FACTORIZATION_MODEL_CREATION_FACTORY_H_ 9 #include <toolkits/factorization/factorization_model_sgd_interface.hpp> 10 #include <toolkits/factorization/factorization_model_impl.hpp> 11 #include <toolkits/sgd/basic_sgd_solver.hpp> 12 #include <toolkits/sgd/sgd_interface.hpp> 13 #include <toolkits/factorization/ranking_sgd_solver_explicit.hpp> 14 #include <toolkits/factorization/ranking_sgd_solver_implicit.hpp> 28 ASSERT_MSG(false, (std::string(#param) + " \'" + param + "\' not recognized.").c_str()); \ 29 ASSERT_UNREACHABLE(); \ 34 #define CREATE_RETURN_SOLVER(...) \ 36 if(solver_class == "sgd::basic_sgd_solver") \ 37 CREATE_RETURN_LOSS_NORMAL(sgd::basic_sgd_solver, __VA_ARGS__); \ 38 else if(solver_class == "factorization::explicit_ranking_sgd_solver") \ 39 CREATE_RETURN_LOSS_NORMAL(factorization::explicit_ranking_sgd_solver, __VA_ARGS__); \ 44 #define SUPPRESS_SOLVERS() \ 45 _SUPPRESS_BY_LOSS(sgd::basic_sgd_solver); \ 46 _SUPPRESS_BY_LOSS(factorization::explicit_ranking_sgd_solver); 54 #define CREATE_RETURN_LOSS_NORMAL(...) \ 56 if(loss_type == "loss_squared_error") \ 57 CREATE_RETURN_REGULARIZER(loss_squared_error, __VA_ARGS__); \ 58 else if(loss_type == "loss_logistic") \ 59 CREATE_RETURN_REGULARIZER(loss_logistic, __VA_ARGS__); \ 64 #define _SUPPRESS_BY_LOSS(...) \ 65 _SUPPRESS_BY_REGULARIZER(loss_squared_error, __VA_ARGS__); \ 66 _SUPPRESS_BY_REGULARIZER(loss_logistic, __VA_ARGS__); 72 #define CREATE_RETURN_REGULARIZER(...) \ 74 if(regularization_type == "L2") { \ 75 CREATE_RETURN_FACTORS(L2, __VA_ARGS__); \ 76 } else if(regularization_type == "ON_THE_FLY") { \ 77 CREATE_RETURN_FACTORS(ON_THE_FLY, __VA_ARGS__); \ 78 } else if(regularization_type == "NONE") { \ 80 CREATE_RETURN_FACTORS(L2, __VA_ARGS__); \ 82 _BAD(regularization_type); \ 86 #define _SUPPRESS_BY_REGULARIZER(...) 
\ 87 _SUPPRESS_BY_FACTORS(L2, __VA_ARGS__); \ 88 _SUPPRESS_BY_FACTORS(ON_THE_FLY, __VA_ARGS__); 91 #define _INSTANTIATE_LOSS_AND_SOLVER(...) \ 92 _INSTANTIATE_BY_FACTORS(L2, __VA_ARGS__); \ 93 _INSTANTIATE_BY_FACTORS(ON_THE_FLY, __VA_ARGS__); 98 #define CREATE_RETURN_FACTORS(...) \ 100 if(factor_mode == "pure_linear_model" || num_factors == 0) { \ 101 _CREATE_AND_RETURN(pure_linear_model, 0, __VA_ARGS__); \ 102 } else if(factor_mode == "matrix_factorization" \ 103 || ( ((factor_mode == "factorization_machine" \ 104 && (train_data.metadata()->num_columns() == 2))))) { \ 105 if(num_factors == 8) \ 106 _CREATE_AND_RETURN(matrix_factorization, 8, __VA_ARGS__); \ 108 _CREATE_AND_RETURN(matrix_factorization, -1, __VA_ARGS__); \ 109 } else if(factor_mode == "factorization_machine") { \ 110 if(num_factors == 8) \ 111 _CREATE_AND_RETURN(factorization_machine, 8, __VA_ARGS__); \ 113 _CREATE_AND_RETURN(factorization_machine, -1, __VA_ARGS__); \ 119 #define _SUPPRESS_BY_FACTORS(...) \ 120 _SUPPRESS_SOLVER(pure_linear_model, 0, __VA_ARGS__); \ 121 _SUPPRESS_SOLVER(matrix_factorization, 8, __VA_ARGS__); \ 122 _SUPPRESS_SOLVER(matrix_factorization, -1, __VA_ARGS__); \ 123 _SUPPRESS_SOLVER(factorization_machine, 8, __VA_ARGS__); \ 124 _SUPPRESS_SOLVER(factorization_machine, -1, __VA_ARGS__); \ 126 #define _INSTANTIATE_BY_FACTORS(...) 
\ 127 _INSTANTIATE_SOLVER(pure_linear_model, 0, __VA_ARGS__); \ 128 _INSTANTIATE_SOLVER(matrix_factorization, 8, __VA_ARGS__); \ 129 _INSTANTIATE_SOLVER(matrix_factorization, -1, __VA_ARGS__); \ 130 _INSTANTIATE_SOLVER(factorization_machine, 8, __VA_ARGS__); \ 131 _INSTANTIATE_SOLVER(factorization_machine, -1, __VA_ARGS__); \ 141 #define _SUPPRESS_SOLVER( \ 142 factor_mode, num_factors_if_known, regularization_type, \ 143 loss_type, solver_class) \ 146 using solver_class; \ 147 using namespace factorization; \ 149 extern template class \ 151 factorization_sgd_interface< \ 152 factorization_model_impl<model_factor_mode::factor_mode, \ 153 num_factors_if_known>, \ 155 model_regularization_type::regularization_type> >; \ 161 #define _INSTANTIATE_SOLVER( \ 162 factor_mode, num_factors_if_known, regularization_type, \ 163 loss_type, solver_class) \ 166 using solver_class; \ 167 using namespace factorization; \ 169 template class solver_class< \ 170 factorization_sgd_interface< \ 171 factorization_model_impl<model_factor_mode::factor_mode, \ 172 num_factors_if_known>, \ 174 model_regularization_type::regularization_type> >; \ 180 #define _CREATE_AND_RETURN( \ 181 factor_mode, num_factors_if_known, regularization_type, \ 182 loss_type, solver_class, \ 183 train_data, options) \ 188 typedef factorization_model_impl \ 189 <model_factor_mode::factor_mode, \ 190 num_factors_if_known> model_type; \ 192 std::shared_ptr<model_type> model(new model_type); \ 195 model->setup(loss_type::name(), train_data, options); \ 198 typedef factorization_sgd_interface \ 199 <model_type, loss_type, \ 200 model_regularization_type::regularization_type> interface_type; \ 203 std::shared_ptr<sgd::sgd_interface_base> iface( \ 204 new interface_type(model)); \ 207 std::shared_ptr<sgd::sgd_solver_base> solver( \ 208 new solver_class<interface_type>( \ 209 iface, train_data, options)); \ 211 return {model, solver}; \ 216 namespace turi {
namespace factorization {
221 std::pair<std::shared_ptr<factorization_model>,
222 std::shared_ptr<sgd::sgd_solver_base> >
223 create_model_and_solver(
const v2::ml_data& train_data,
224 std::map<std::string, flexible_type> options,
225 const std::string& loss_type,
226 const std::string& solver_class,
227 const std::string& regularization_type,
228 const std::string& factor_mode,