Turi Create  4.0
factorization_models.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_RECSYS_FACTORIZATION_MODELS_H_
7 #define TURI_RECSYS_FACTORIZATION_MODELS_H_
8 
9 #include <toolkits/recsys/recsys_model_base.hpp>
10 #include <toolkits/recsys/recsys_model_base.hpp>
11 
12 namespace turi {
13 
14 namespace v2 {
15 class ml_data;
16 }
17 
18 namespace factorization {
19 class factorization_model;
20 }
21 
22 namespace recsys {
23 
24 /** Implements the factorization-based recommender functionality as a
25  * thin wrapper around the factorization models.
26  */
28 
29  public:
/** Always returns true.
 *  NOTE(review): presumably tells the base recommender that columns other
 *  than the user and item columns are used by this model -- confirm against
 *  recsys_model_base.
 */
30  bool include_columns_beyond_user_item() const { return true; }
31 
/** Sets up the model options from a map of option name to value.
 *  Declaration only; the definition is not in this header.
 *
 *  \param _options  Option values keyed by option name.
 */
32  void init_options(const std::map<std::string, flexible_type>& _options);
33 
/** Trains the model on the given data.
 *
 *  \param training_data  The data to train on.
 *  \return A string-keyed map of flexible_type values.
 *          NOTE(review): presumably training statistics -- confirm in the
 *          implementation file.
 */
34  std::map<std::string, flexible_type> train(const v2::ml_data& training_data);
35 
/** Produces predictions for test_data and returns them as an sframe.
 *  Const: does not modify the observable state of the model.
 *
 *  \param test_data  The data to generate predictions for.
 */
36  sframe predict(const v2::ml_data& test_data) const;
37 
/** Returns an sframe of items similar to the given items.
 *
 *  \param indexed_items  The query items.
 *                        NOTE(review): the name suggests these are already
 *                        mapped to internal indices -- confirm with callers.
 *  \param k              Defaults to 0.
 *                        NOTE(review): presumably the number of similar items
 *                        per query; the meaning of 0 is not visible here.
 */
38  sframe get_similar_items(std::shared_ptr<sarray<flexible_type>> indexed_items,
39  size_t k=0) const;
40 
/** Computes similarity scores for a single item and writes them to
 *  sim_scores.
 *
 *  \param item        The query item.
 *  \param sim_scores  Output parameter holding (size_t, double) pairs.
 *                     NOTE(review): presumably (item index, score); whether
 *                     the vector must be pre-sized by the caller is not
 *                     visible here -- confirm in the implementation.
 */
41  void get_item_similarity_scores(
42  size_t item, std::vector<std::pair<size_t, double> >& sim_scores) const;
43 
/** Returns an sframe of users similar to the given users.
 *
 *  NOTE(review): the first parameter is named `indexed_items`, the same as in
 *  get_similar_items(); for this method it presumably holds users. Consider
 *  renaming it in a follow-up.
 *
 *  \param indexed_items  The query entries (see the note above).
 *  \param k              Defaults to 0.
 *                        NOTE(review): presumably the number of similar users
 *                        per query; the meaning of 0 is not visible here.
 */
44  sframe get_similar_users(std::shared_ptr<sarray<flexible_type>> indexed_items,
45  size_t k=0) const;
46 
47  protected:
/** Similarity query parameterized by column.
 *  NOTE(review): presumably the shared implementation behind
 *  get_similar_items() and get_similar_users(), with `column` selecting the
 *  user or item column -- confirm in the implementation file.
 *
 *  \param column         The column the query applies to.
 *  \param indexed_items  The query entries.
 *  \param k              No default here, unlike the public wrappers.
 */
48  sframe get_similar(size_t column, std::shared_ptr<sarray<flexible_type> > indexed_items, size_t k) const;
49 
50  private:
/** Scratch state declared `mutable` so that const member functions can
 *  modify it.
 *  NOTE(review): by its name, _get_similar_buffers_lock presumably guards
 *  _get_similar_buffers -- confirm in the implementation file.
 */
51  mutable mutex _get_similar_buffers_lock;
/** Dynamically sized single-precision column vectors. */
52  mutable std::vector<Eigen::Matrix<float, Eigen::Dynamic, 1> > _get_similar_buffers;
53 
54 
55  public:
/** Scores items for a query row, writing the results into `scores`.
 *  Declaration only; the parameter semantics below are inferred from names
 *  and types and should be confirmed against the implementation.
 *
 *  \param scores               Output parameter of (size_t, double) pairs;
 *                              the only non-const argument.
 *  \param query_row            The ml_data entries describing the query.
 *  \param top_k                NOTE(review): presumably the number of top
 *                              results wanted -- confirm.
 *  \param user_item_list       (size_t, double) pairs, read-only.
 *  \param new_user_item_data   (size_t, double) pairs, read-only.
 *  \param new_observation_data Row references, read-only.
 *  \param known_side_features  Shared pointer to side-feature data.
 *                              NOTE(review): whether null is permitted is not
 *                              visible here.
 */
56  void score_all_items(
57  std::vector<std::pair<size_t, double> >& scores,
58  const std::vector<v2::ml_data_entry>& query_row,
59  size_t top_k,
60  const std::vector<std::pair<size_t, double> >& user_item_list,
61  const std::vector<std::pair<size_t, double> >& new_user_item_data,
62  const std::vector<v2::ml_data_row_reference>& new_observation_data,
63  const std::shared_ptr<v2::ml_data_side_features>& known_side_features) const;
64 
/** Version number reported by internal_get_version(). Currently 1. */
65  static constexpr size_t RECSYS_FACTORIZATION_MODEL_VERSION = 1;
66 
/** Returns RECSYS_FACTORIZATION_MODEL_VERSION. */
67  inline size_t internal_get_version() const {
68  return RECSYS_FACTORIZATION_MODEL_VERSION;
69  }
/** Serializes the model to the given output archive.
 *
 *  \param oarc  The archive to write to.
 */
70  void internal_save(turi::oarchive& oarc) const;
71 
/** Deserializes the model from the given input archive.
 *
 *  \param iarc     The archive to read from.
 *  \param version  NOTE(review): presumably the version the archive was saved
 *                  with (see internal_get_version()) -- confirm.
 */
72  void internal_load(turi::iarchive& iarc, size_t version);
73 
74  protected:
75 
76  /** Selects whether the model works as a ranking factorization model.
77  * Pure virtual. The overrides in this file return a constant true or false.
78  */
79  virtual bool include_ranking_options() const = 0;
80 
/** Protected training overload taking two ml_data objects.
 *  NOTE(review): by their names, presumably the same training data arranged
 *  by user and by item respectively -- confirm in the implementation file.
 *
 *  \return A string-keyed map of flexible_type values.
 */
81  std::map<std::string, flexible_type> train(
82  const v2::ml_data& training_data_by_user,
83  const v2::ml_data& training_data_by_item);
84  private:
/** Shared pointer to the factorization model that this class wraps. */
85  std::shared_ptr<factorization::factorization_model> model;
86 
87 };
88 
89 ////////////////////////////////////////////////////////////////////////////////
90 // Now the individual model definitions
91 
92 /** Implements factorization model.
93  */
95 
96  public:
/** Always returns true; target_is_present is ignored. */
97  bool use_target_column(bool target_is_present) const override { return true; }
98 
99  private:
/** Always returns false: this model does not include the ranking options. */
100  bool include_ranking_options() const override { return false; }
101 
102  public:
103  // TODO: convert interface above to use the extensions methods here
104  BEGIN_CLASS_MEMBER_REGISTRATION("factorization_recommender")
107 };
108 
109 /** Implements the ranking factorization model.
110  */
112 
113  public:
/** Returns target_is_present unchanged. */
114  bool use_target_column(bool target_is_present) const override { return target_is_present; }
115 
116  private:
/** Always returns true: this model includes the ranking options. */
117  bool include_ranking_options() const override { return true; }
118 
119  public:
120 
121  // TODO: convert interface above to use the extensions methods here
122  BEGIN_CLASS_MEMBER_REGISTRATION("ranking_factorization_recommender")
125 };
126 
127 
128 }}
129 
130 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80