Turi Create  4.0
recsys_model_base.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_RECSYS_MODEL_BASE_H_
7 #define TURI_RECSYS_MODEL_BASE_H_
8 
9 #include <map>
10 #include <string>
11 #include <set>
12 
13 #include <model_server/lib/extensions/option_manager.hpp>
14 #include <core/data/sframe/gl_sframe.hpp>
15 #include <core/data/sframe/gl_sarray.hpp>
16 #include <model_server/lib/toolkit_function_specification.hpp>
17 #include <model_server/lib/unity_base_types.hpp>
18 #include <toolkits/ml_data_2/ml_data.hpp>
19 #include <toolkits/ml_data_2/ml_data_iterators.hpp>
20 #include <core/util/fast_top_k.hpp>
21 #include <toolkits/coreml_export/mlmodel_wrapper.hpp>
22 
23 
24 // Interfaces
25 #include <model_server/lib/extensions/ml_model.hpp>
26 #include <core/export.hpp>
27 
28 namespace turi {
29 
30 class iarchive;
31 class oarchive;
32 class sframe;
33 class flexible_type;
34 class column_metadata;
35 
36 namespace recsys {
37 
38 class recsys_popularity;
39 
40 
41 /** The base class for recsys model classes. Individual models are
42  * expected to implement all of the pure virtual functions below,
43  * along with (optionally) overriding any of the other virtual
44  * methods.
45  */
46 class EXPORT recsys_model_base : public ml_model_base {
47  public:
48 
49  static constexpr size_t RECSYS_MODEL_BASE_VERSION = 2;
50 
51  /// Default constructor
53 
54  virtual ~recsys_model_base() {}
55 
56  protected:
57 
58  /** Train the algorithm.
59  * Takes the training data as a single v2::ml_data. Returns a map of information about the run.
60  */
61  virtual std::map<std::string, flexible_type> train(const v2::ml_data& training_data) = 0;
62 
63  /**
64  * Takes two datasets for training. This default implementation only reports an internal error through log_and_throw.
65  * \param[in] training_data_by_user ML-Data sorted by user
66  * \param[in] training_data_by_item ML-Data sorted by item
67  */
68  virtual std::map<std::string, flexible_type> train(
69  const v2::ml_data& training_data_by_user,
70  const v2::ml_data& training_data_by_item){
71  log_and_throw("Internal error. ALS not implemented");
72  }
73 
74  public:
75  virtual bool use_target_column(bool target_is_present) const = 0;
76  virtual bool include_columns_beyond_user_item() const { return false; }
77 
78  public:
79  /** Run predictions on each element in the test data set. Returns an
80  * sframe holding the response prediction of each observation in the
81  * test_data set.
82  * NOTE(review): earlier wording also mentioned a second ml_data of
83  * "available" observations; this signature takes only test_data -- confirm.
84  */
85  virtual sframe predict(const v2::ml_data& test_data) const = 0;
86 
87  /**
88  * Get the nearest k users for each of the provided users.
89  * If no users are provided, then similar users are retrieved
90  * for all users observed during training.
91  */
92  virtual sframe get_similar_users(
93  std::shared_ptr<sarray<flexible_type> > users, size_t k) const = 0;
94 
95  /**
96  * Get the nearest k items for each of the provided items.
97  * If no items are provided, then similar items are retrieved
98  * for all items observed during training.
99  */
100  virtual sframe get_similar_items(
101  std::shared_ptr<sarray<flexible_type> > items, size_t k) const = 0;
102 
103  /** For each of the items in sim_scores (first part of tuple), sets
104  * a similarity score (second part of tuple) that is higher for
105  * items similar to item.
106  */
108  size_t item, std::vector<std::pair<size_t, double> >& sim_scores) const {}
109 
110  /**
111  * Returns information about all the users in the overlap of the
112  * item pairs listed in two columns in unindexed_item_pairs. All
113  * these items must be present in the training data.
114  *
115  * Returns an sframe with information about this
116  * intersection. Columns are item_1, item_2, num_users_1, num_users_2, item_intersection (dict, user ->
117  */
118  virtual sframe get_item_intersection_info(const sframe& unindexed_item_pairs) const;
119 
120 protected:
121 
122  /** Utility function to aid in the retrieval of similar items.
123  *
124  * GetSimilarFunction is a function called as
125  * f(size_t idx, std::vector<std::pair<size_t, double> >& idx_dist_dest);
126  */
127  template <typename GetSimilarFunction>
128  sframe _create_similar_sframe(size_t column_index, std::shared_ptr<sarray<flexible_type> > items,
129  size_t k, GetSimilarFunction&& similar) const;
130 
131 
132 public:
133 
134  /** For a given base observation, predict the score for all the
135  * items with all non-item columns replaced by the values in the
136  * base observation.
137  * NOTE(review): base_observation / template_observation / item_column_index below are not names in this signature; presumably the first two mean query_row -- confirm.
138  * The base_observation vector is used to generate all the
139  * observations predicted. New observations are generated by
140  * repeatedly copying template_observation, then replacing the
141  * values in item_column_index by each possible item value.
142  */
143  virtual void score_all_items(
144  std::vector<std::pair<size_t, double> >& scores,
145  const std::vector<v2::ml_data_entry>& query_row,
146  size_t top_k,
147  const std::vector<std::pair<size_t, double> >& user_item_list,
148  const std::vector<std::pair<size_t, double> >& new_user_item_data,
149  const std::vector<v2::ml_data_row_reference>& new_observation_data,
150  const std::shared_ptr<v2::ml_data_side_features>& known_side_features) const = 0;
151 
152 
153  // Set additional data for the method; this default implementation does nothing.
154  virtual void set_extra_data(const std::map<std::string, variant_type>& other_data) {}
155 
156  protected:
157  virtual size_t internal_get_version() const = 0;
158 
159  /** Implement serialization (save). The model subclass should
160  * reimplement this particular function. The syntax follows the
161  * standard turicreate save() method.
162  */
163  virtual void internal_save(turi::oarchive& oarc) const = 0;
164 
165  /** Implement serialization (load). The model subclass should
166  * reimplement this particular function. The syntax follows the
167  * standard turicreate load() method.
168  *
169  * When this method is called, all the model options have been set
170  * up in the base class and are readily accessible. Furthermore,
171  * once this function is called, the model is treated as trained and
172  * ready to be used for prediction and ranking. Thus loading a model
173  * can effectively replace the training stage.
174  */
175  virtual void internal_load(turi::iarchive& iarc, size_t version) = 0;
176 
177  ////////////////////////////////////////////////////////////////////////////////
178  //
179  // Interacting with the data set by the train part of the model
180  //
181  ////////////////////////////////////////////////////////////////////////////////
182 
183  public:
184 
185  /// The metadata needed for translating the data back and forth
186  static constexpr size_t USER_COLUMN_INDEX = 0;
187  static constexpr size_t ITEM_COLUMN_INDEX = 1;
188 
189  std::shared_ptr<v2::ml_metadata> metadata;
190  std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > trained_user_items;
191 
192  /** Creates an ml_data object according to the given schema. No
193  * target column.
194  */
195  v2::ml_data create_ml_data(const sframe& data,
196  const sframe& new_user_side_data=sframe(),
197  const sframe& new_item_side_data=sframe()) const;
198 
199  private:
200  /// Returns an sframe with the columns renamed such that they will
201  /// not conflict with anything.
202  sframe _sanitize_side_column_names(size_t main_index, sframe side_table) const;
203 
204  public:
205 
206  /** Returns the flexible data type of the user column;
207  * The model must be trained at this point.
208  */
210  return metadata->column_type(USER_COLUMN_INDEX);
211  }
212 
213  /** Returns the flexible data type of the item column;
214  * The model must be trained at this point.
215  */
217  return metadata->column_type(ITEM_COLUMN_INDEX);
218  }
219 
220  ////////////////////////////////////////////////////////////////////////////////
221  //
222  // The methods for train, test, etc.
223  //
224  ////////////////////////////////////////////////////////////////////////////////
225 
226  public:
227 
228  /** Train the model using an sframe as the primary observations.
229  * This method constructs the internal ml_data objects from the
230  * current options.
231  *
232  * \param observation_data An SFrame containing at least a column containing
233  * user ids and a column containing item ids.
234  * \param user_side_data An SFrame containing side information about users,
235  * where one column matches with the user column of observation data.
236  * \param item_side_data An SFrame containing side information about items,
237  * where one column matches with the item column of observation data.
238  * \param other_data When provided, each model can implement a method set_extra_data
239  * in order to use this argument during training.
240  * \returns Statistics about the training.
241  */
242  void setup_and_train(const sframe& observation_data,
243  const sframe& user_side_data=sframe(),
244  const sframe& item_side_data=sframe(),
245  const std::map<std::string, variant_type>& other_data=(std::map<std::string, variant_type>()));
246 
247  /** Some of the models, such as popularity, can be built entirely
248  * from data already contained in the model. This method allows us
249  * to create a new model while bypassing the typical
250  * setup_and_train method. This simply imports all the relevant
251  * variables over; the final training is left up to the model.
252  */
253  void import_all_from_other_model(const recsys_model_base* other);
254 
255  recsys_model_base& operator=(const recsys_model_base&) = default;
256 
257  gl_sframe api_get_similar_items(gl_sarray items, size_t k, size_t verbose, int get_all_items) const;
258 
259  gl_sframe api_get_similar_users(gl_sarray users, size_t k, int get_all_users) const;
260 
261 
262  gl_sframe api_predict(gl_sframe data_to_predict, gl_sframe new_user_data, gl_sframe new_item_data) const;
263  variant_map_type api_set_current_options(std::map<std::string, flexible_type> options);
264 
265  void api_train(gl_sframe _dataset, gl_sframe _user_data, gl_sframe _item_data,
266  const std::map<std::string, flexible_type>& opts,
267  const variant_map_type& extra_data);
268 
269  variant_map_type api_get_current_options();
270 
271  gl_sframe api_recommend(gl_sframe _query, gl_sframe _exclude, gl_sframe _restrictions, gl_sframe _new_data, gl_sframe _new_user_data,
272  gl_sframe new_item_data, bool exclude_training_interactions, size_t top_k, double diversity, size_t random_seed);
273 
274  gl_sframe api_get_item_intersection_info(gl_sframe item_pairs);
275 
276  gl_sframe api_precision_recall_by_user(gl_sframe validation_data, gl_sframe recommend_output, const std::vector<size_t>& cutoffs);
277 
278  variant_map_type api_get_train_stats();
279 
280  EXPORT variant_map_type api_get_data_schema();
281 
282 
283  /** Creates and returns a popularity baseline
284  *
285  */
286  std::shared_ptr<recsys_model_base> get_popularity_baseline() const;
287 
288  flex_dict get_data_schema() const;
289 
290 private:
291 
292  /** Choose some things diversely. The struct below is the buffer type that choose_diversely() takes as dv_buffer.
293  */
294  struct diversity_choice_buffer {
295  std::vector<size_t> current_candidates;
296  std::vector<size_t> chosen_items;
297  std::vector<size_t> current_diversity_score;
298  std::vector<std::pair<size_t, double> > sim_scores;
299  };
300 
301  void choose_diversely(size_t top_k,
302  std::vector<std::pair<size_t, double> >& candidates,
303  size_t random_seed,
304  diversity_choice_buffer& dv_buffer) const;
305 
306 public:
307 
308  /** Return the top_k ranks for this model based on sorted
309  * predictions.
310  *
311  * Here, for each user in users, the top_k ranks are returned in the
312  * same format as the previous function.
313  *
314  * If exclusion_data is given, these observations are excluded
315  * from the returned values.
316  *
317  * \overload
318  */
319  sframe recommend(const sframe& reference_data,
320  size_t top_k,
321  const sframe& restriction_data = sframe(),
322  const sframe& exclusion_data = sframe(),
323  const sframe& new_observation_data = sframe(),
324  const sframe& new_user_data = sframe(),
325  const sframe& new_item_data = sframe(),
326  bool exclude_training_interactions = true,
327  double diversity_factor = 0,
328  size_t random_seed = 0) const;
329 
330  std::shared_ptr<unity_sframe_base> recommend_extension_wrapper(
331  std::shared_ptr<unity_sframe_base> reference_data,
332  std::shared_ptr<unity_sframe_base> new_observation_data,
333  flex_int top_k) const;
334 
335  std::shared_ptr<unity_sframe_base> get_num_users_per_item_extension_wrapper() const;
336 
337  std::shared_ptr<unity_sframe_base> get_num_items_per_user_extension_wrapper() const;
338 
339  virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
340  const std::string& filename,
341  const std::map<std::string, flexible_type>& additional_user_defined);
342 
343  variant_map_type summary();
344 
345  /**
346  * Compute the precision and recall for a (potentially held out) set of
347  * observations.
348  *
349  * \param indexed_validation_data An sframe giving the validation set the
350  * precision and recall should be calculated on.
351  *
352  * \param recommend_output The output of the recommend method. Note
353  * that recommend should be called with top_k larger than the max
354  * value in cutoffs.
355  *
356  * \param cutoffs A vector of cutoffs for computing e.g. the top
357  * [5,10,50] rankings.
358  *
359  * \return An sframe with 5 columns -- user, cutoff, precision,
360  * recall, and item counts.
361  */
362  sframe precision_recall_stats(const sframe& indexed_validation_data,
363  const sframe& recommend_output,
364  const std::vector<size_t>& cutoffs) const;
365 
366 
367  /**
368  * Return an SFrame containing each user id and the number of
369  * observations with that user in the training set.
370  */
371  sframe get_num_items_per_user() const;
372 
373  /**
374  * Return an SFrame containing each item and the number of
375  * observations with that item in the training set.
376  */
377  sframe get_num_users_per_item() const;
378 
379 
380  inline size_t get_version() const override {
381  return RECSYS_MODEL_BASE_VERSION;
382  }
383 
384  /// Serialization -- save
385  virtual void save_impl(turi::oarchive& oarc) const override;
386 
387  /// Serialization -- load
388  void load_version(turi::iarchive& iarc, size_t version) override;
389 
390  /// Get stats about algorithm runtime
391  std::map<std::string, flexible_type> get_train_stats();
392 
393 
396 
397  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("get_similar_items",
398  recsys_model_base::api_get_similar_items,
399  "items", "k", "verbose",
400  "get_all_items");
401 
404  "options");
405 
406  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("get_similar_users",
407  recsys_model_base::api_get_similar_users,
408  "users", "k", "get_all_users");
409 
411  recsys_model_base::api_predict,
412  "data_to_predict", "new_user_data",
413  "new_item_data");
414 
415  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("train", recsys_model_base::api_train,
416  "dataset", "user_data", "item_data",
417  "opts", "extra_data");
418 
420  "recommend", recsys_model_base::api_recommend, "query", "exclude",
421  "restrictions", "new_data", "new_user_data", "new_item_data",
422  "exclude_training_interactions", "top_k", "diversity", "random_seed");
423 
424  register_defaults("recommend",
425  {{"exclude", gl_sframe()},
426  {"restrictions", gl_sframe()},
427  {"new_data", gl_sframe()},
428  {"new_user_data", gl_sframe()},
429  {"new_item_data", gl_sframe()},
430  {"exclude_training_interactions", true},
431  {"diversity", 0},
432  {"random_seed", 1}});  // NOTE(review): recommend() above declares random_seed = 0 as its default -- confirm the difference is intended
433 
435  "get_current_options", recsys_model_base::api_get_current_options);
436 
437  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("get_num_users_per_item", recsys_model_base::get_num_users_per_item_extension_wrapper);
438 
439  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("get_num_items_per_user", recsys_model_base::get_num_items_per_user_extension_wrapper);
440 
441  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("summary", recsys_model_base::summary);
442 
444  "get_popularity_baseline", recsys_model_base::get_popularity_baseline);
445 
447  "get_item_intersection_info",
448  recsys_model_base::api_get_item_intersection_info, "item_pairs");
449 
450  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("export_to_coreml",
451  recsys_model_base::export_to_coreml,
452  "filename", "additional_user_defined");
453  register_defaults("export_to_coreml",
454  {{"additional_user_defined", to_variant(std::map<std::string, flexible_type>())}});
455 
457  "precision_recall_by_user", recsys_model_base::api_precision_recall_by_user,
458  "indexed_validation_data", "recommend_output", "cutoffs");
459 
460  REGISTER_NAMED_CLASS_MEMBER_FUNCTION("get_data_schema",
461  recsys_model_base::api_get_data_schema);
462 
464 
465  REGISTER_CLASS_MEMBER_FUNCTION(recsys_model_base::recommend_extension_wrapper,
466  "reference_data", "new_observation_data",
467  "top_k")
468 
470 };
471 
472 ////////////////////////////////////////////////////////////////////////////////
473 // Implementation of the get_similar utility functions
474 
475 template <typename GetSimilarFunction>
477  size_t column_index, std::shared_ptr<sarray<flexible_type> > query,
478  size_t k, GetSimilarFunction&& similar) const {
479  // Builds an sframe with columns {<name of column_index>, "similar", "score", "rank"}: for every query value, `similar` fills score_list and up to k best-scoring entries other than the query itself are written out.
480  sframe res;
481  size_t num_segments = thread::cpu_count();  // passed to open_for_write as the segment count
482 
483  const bool use_all_values = (query == nullptr);  // no query array: walk every index of the column instead
484 
485  size_t n = use_all_values ? metadata->index_size(column_index) : query->size();
486 
487  decltype(query->get_reader()) reader;  // only assigned when a query array was supplied
488 
489  if(!use_all_values) {
490  reader = query->get_reader();
491  }
492 
493  auto indexer = metadata->indexer(column_index);
494 
495  flex_type_enum t = metadata->column_type(column_index);
496 
497  res.open_for_write(
498  {metadata->column_name(column_index), "similar", "score", "rank"},
500  "", num_segments);
501 
502  in_parallel([&](size_t thread_idx, size_t num_threads) {
503 
504  std::vector<flexible_type> data;
505 
506  Eigen::Matrix<float, Eigen::Dynamic, 1> similarities;  // NOTE(review): not referenced again in this function -- candidate for removal
507  typedef std::pair<size_t, double> item_score_pair;
508  std::vector<item_score_pair> score_list(metadata->index_size(column_index));
509 
510  auto it_out = res.get_output_iterator(thread_idx);  // each thread writes through the iterator for its own thread_idx
511 
512  size_t start_idx = (n * thread_idx) / num_threads;  // this thread handles [start_idx, end_idx) of the n queries
513  size_t end_idx = (n * (thread_idx+1)) / num_threads;
514  size_t n_in_block = 1000;
515 
516  for(size_t block_start = start_idx; block_start < end_idx; block_start += 1000) {  // work through the slice in blocks of at most 1000 rows
517 
518  if(!use_all_values) {
519  reader->read_rows(block_start, std::min(end_idx, block_start + 1000), data);
520  n_in_block = data.size();
521  } else {
522  n_in_block = std::min(end_idx, block_start + 1000) - block_start;
523  }
524 
525  for(size_t i = 0; i < n_in_block; ++i) {
526 
527  size_t query_idx = use_all_values ? block_start + i : indexer->immutable_map_value_to_index(data[i]);
528 
529  if(query_idx == static_cast<size_t>(-1))  // presumably the indexer's "value not in index" result -- confirm; such queries produce no rows
530  continue;
531 
532  similar(query_idx, score_list);  // NOTE(review): score_list is reused after extract_and_sort_top_k received it by non-const reference; presumably `similar` re-sizes/refills it -- confirm
533 
534  // Assume that higher scores are better.
535  auto score_sorter = [](const item_score_pair& vi1, const item_score_pair& vi2) {
536  return vi1.second < vi2.second;
537  };
538 
539  // get an extra item in case the query_idx is in there.
540  // Then ignore it if it is.
541  extract_and_sort_top_k(score_list, k + 1, score_sorter);
542 
543  // Write out the top-k items in indexed format to an output
544  // sframe with the original values. Skip the current query
545  // index if it's in there.
546  const flexible_type& query_item = use_all_values ? indexer->map_index_to_value(query_idx) : data[i];
547 
548  for(size_t j = 0, rank = 1; j < score_list.size(); ++j, ++it_out) {  // NOTE(review): ++it_out also runs on the `continue` path below -- confirm the output iterator tolerates an increment without a write
549  if(score_list[j].first == query_idx)
550  continue;
551 
552 
553  flexible_type ref_datum = indexer->map_index_to_value(score_list[j].first);
554  *it_out = {query_item, ref_datum, score_list[j].second, rank};
555  ++rank;
556 
557  if(rank > k)  // k rows written for this query: stop
558  break;
559  }
560  }
561  }
562  });
563 
564  res.close();
565  return res;
566 }
567 
568 variant_map_type train_test_split(gl_sframe _dataset,
569  const std::string& user_column,
570  const std::string& item_column,
571  flexible_type max_num_users,
572  double item_test_proportion,
573  size_t random_seed);
574 
575 variant_map_type init(variant_map_type& params);
576 
577 variant_map_type get_train_stats(variant_map_type& params);
578 
579 std::vector<toolkit_function_specification> get_toolkit_function_registration();
580 
581 }}
582 
583 #endif /* TURI_RECSYS_MODEL_BASE_H_ */
std::shared_ptr< recsys_model_base > get_popularity_baseline() const
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
void extract_and_sort_top_k(std::vector< T > &v, size_t top_k, LessThan less_than)
Definition: fast_top_k.hpp:87
virtual void get_item_similarity_scores(size_t item, std::vector< std::pair< size_t, double > > &sim_scores) const
#define BEGIN_BASE_CLASS_MEMBER_REGISTRATION()
recsys_model_base()
Default constructor.
virtual void init_options(const std::map< std::string, flexible_type > &_options)
Definition: ml_model.hpp:80
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
iterator get_output_iterator(size_t segmentid)
sframe _create_similar_sframe(size_t column_index, std::shared_ptr< sarray< flexible_type > > items, size_t k, GetSimilarFunction &&similar) const
#define END_CLASS_MEMBER_REGISTRATION
static size_t cpu_count()
virtual std::map< std::string, flexible_type > train(const v2::ml_data &training_data_by_user, const v2::ml_data &training_data_by_item)
#define REGISTER_NAMED_CLASS_MEMBER_FUNCTION(name, function,...)
void open_for_write(const std::vector< std::string > &column_names, const std::vector< flex_type_enum > &column_types, const std::string &frame_sidx_file="", size_t nsegments=SFRAME_DEFAULT_NUM_SEGMENTS, bool fail_on_column_names=true)
Definition: sframe.hpp:265
variant_type to_variant(const T &f)
Definition: variant.hpp:308
size_t get_version() const override
void in_parallel(const std::function< void(size_t thread_id, size_t num_threads)> &fn)
Definition: lambda_omp.hpp:35
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80
std::vector< std::pair< flexible_type, flexible_type > > flex_dict
std::map< std::string, flexible_type > get_train_stats()
Get stats about algorithm runtime.