Turi Create 4.0
popularity.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_RECSYS_MODEL_POPULARITY_H_
7 #define TURI_RECSYS_MODEL_POPULARITY_H_
8 
9 #include <vector>
10 #include <string>
11 #include <toolkits/recsys/recsys_model_base.hpp>
12 #include <toolkits/nearest_neighbors/nearest_neighbors.hpp>
13 #include <toolkits/nearest_neighbors/ball_tree_neighbors.hpp>
14 
15 
16 namespace turi {
17 
18 namespace v2 {
19 class ml_data;
20 }
21 
22 class sframe;
23 class flexible_type;
24 
25 namespace recsys {
26 
27 class EXPORT recsys_popularity : public recsys_model_base {
28  public:
29 
30  // Implement the bare minimum of the pure virtual methods
31  void init_options(const std::map<std::string, flexible_type>&_options) override;
32 
33  bool use_target_column(bool target_is_present) const override {
34  return target_is_present;
35  }
36 
37  /** Creates and trains the model. Training can be done either
38  * through the ml_data version, or the sarray of item-target pairs.
39  *
40  * At the end of training, the state variable "item_predictions"
41  * holds the predicted value of each of the items.
42  */
43  std::map<std::string, flexible_type> train(const v2::ml_data& data) override;
44 
45  #ifdef __clang__
46  #pragma clang diagnostic push
47  #pragma clang diagnostic ignored "-Woverloaded-virtual" // TODO: fix this issue
48  #endif
49  std::map<std::string, flexible_type> train(
50  std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > trained_user_item);
51  #ifdef __clang__
52  #pragma clang diagnostic pop
53  #endif
54 
55  sframe predict(const v2::ml_data& test_data) const override;
56 
57  sframe get_similar_items(std::shared_ptr<sarray<flexible_type>> items,
58  size_t k=0) const override;
59 
60  sframe get_similar_users(std::shared_ptr<sarray<flexible_type>> users,
61  size_t k=0) const override;
62 
63  void score_all_items(
64  std::vector<std::pair<size_t, double> >& scores,
65  const std::vector<v2::ml_data_entry>& query_row,
66  size_t top_k,
67  const std::vector<std::pair<size_t, double> >& user_item_list,
68  const std::vector<std::pair<size_t, double> >& new_user_item_data,
69  const std::vector<v2::ml_data_row_reference>& new_observation_data,
70  const std::shared_ptr<v2::ml_data_side_features>& known_side_features) const override;
71 
72  /////////////////////////////////////////////////////////////////////////////////
73  // Save and load stuff
74  static constexpr size_t POPULARITY_RECOMMENDER_VERSION = 0;
75 
76  inline size_t internal_get_version() const override {
77  return POPULARITY_RECOMMENDER_VERSION;
78  }
79  void internal_save(turi::oarchive& oarc) const override;
80  void internal_load(turi::iarchive& iarc, size_t version) override;
81 
82  private:
83  std::vector<double> item_predictions;
84  double unseen_item_prediction;
85  std::shared_ptr<nearest_neighbors::ball_tree_neighbors> nearest_items_model;
86 
87  public:
89  IMPORT_BASE_CLASS_REGISTRATION(recsys_model_base)
91 
92 };
93 
94 }}
95 
96 #endif /* TURI_RECSYS_MODEL_POP_COUNT_H_ */
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
The serialization input archive object: given a reference to an istream, it reads from that stream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive object: given a reference to an ostream, it writes to that stream, providing serialization capabilities.
Definition: oarchive.hpp:80