6 #ifndef TURI_RECSYS_MODEL_BASE_H_ 7 #define TURI_RECSYS_MODEL_BASE_H_ 13 #include <model_server/lib/extensions/option_manager.hpp> 14 #include <core/data/sframe/gl_sframe.hpp> 15 #include <core/data/sframe/gl_sarray.hpp> 16 #include <model_server/lib/toolkit_function_specification.hpp> 17 #include <model_server/lib/unity_base_types.hpp> 18 #include <toolkits/ml_data_2/ml_data.hpp> 19 #include <toolkits/ml_data_2/ml_data_iterators.hpp> 20 #include <core/util/fast_top_k.hpp> 21 #include <toolkits/coreml_export/mlmodel_wrapper.hpp> 25 #include <model_server/lib/extensions/ml_model.hpp> 26 #include <core/export.hpp> 34 class column_metadata;
38 class recsys_popularity;
/** Version number of the recsys base-class state; currently 2.
 *  NOTE(review): presumably written/checked by the save/load path -- confirm.
 */
49 static constexpr
size_t RECSYS_MODEL_BASE_VERSION = 2;
/** Trains the model on a single data set.
 *
 *  Pure virtual: every concrete model must provide an implementation.
 *
 *  \param training_data  The training observations, as a v2::ml_data.
 *  \return A map from string keys to flexible_type values.
 *          NOTE(review): presumably training statistics -- confirm against
 *          the implementations.
 */
61 virtual std::map<std::string, flexible_type> train(
const v2::ml_data& training_data) = 0;
68 virtual std::map<std::string, flexible_type>
train(
69 const v2::ml_data& training_data_by_user,
70 const v2::ml_data& training_data_by_item){
71 log_and_throw(
"Internal error. ALS not implemented");
/** Pure virtual predicate that each concrete model must implement.
 *
 *  \param target_is_present  Whether a target is present (as told by the caller).
 *  \return A bool. NOTE(review): by its name, presumably whether the model
 *          should use the target column -- confirm against implementations.
 */
75 virtual bool use_target_column(
bool target_is_present)
const = 0;
/** Overridable flag; this base implementation always returns false.
 *
 *  NOTE(review): by its name, presumably tells the data-setup code whether
 *  columns other than the user and item columns are consumed -- confirm.
 */
76 virtual bool include_columns_beyond_user_item()
const {
return false; }
/** Produces predictions for the given data.
 *
 *  Pure virtual and const: implementations must not modify the model.
 *
 *  \param test_data  The data to predict on, as a v2::ml_data.
 *  \return An sframe holding the result. The column layout is not visible
 *          from this declaration.
 */
85 virtual sframe predict(
const v2::ml_data& test_data)
const = 0;
92 virtual sframe get_similar_users(
100 virtual sframe get_similar_items(
108 size_t item, std::vector<std::pair<size_t, double> >& sim_scores)
const {}
/** Returns information about pairs of items, as an sframe.
 *
 *  Virtual and const; only declared here, so the definition lives elsewhere.
 *
 *  \param unindexed_item_pairs  An sframe of item pairs. NOTE(review): the
 *         name suggests raw (not yet index-mapped) item values -- confirm.
 */
118 virtual sframe get_item_intersection_info(
const sframe& unindexed_item_pairs)
const;
127 template <
typename GetSimilarFunction>
129 size_t k, GetSimilarFunction&& similar)
const;
143 virtual void score_all_items(
144 std::vector<std::pair<size_t, double> >& scores,
145 const std::vector<v2::ml_data_entry>& query_row,
147 const std::vector<std::pair<size_t, double> >& user_item_list,
148 const std::vector<std::pair<size_t, double> >& new_user_item_data,
149 const std::vector<v2::ml_data_row_reference>& new_observation_data,
150 const std::shared_ptr<v2::ml_data_side_features>& known_side_features)
const = 0;
/** Hook for passing additional named data to a model.
 *
 *  The base implementation is an intentional no-op; subclasses that need
 *  the data override it.
 *
 *  \param other_data  Map from name to variant_type value.
 */
154 virtual void set_extra_data(
const std::map<std::string, variant_type>& other_data) {}
/** Returns a version number supplied by the concrete model.
 *
 *  Pure virtual and const.
 *  NOTE(review): presumably the counterpart of the `version` argument of
 *  internal_load() -- confirm.
 */
157 virtual size_t internal_get_version()
const = 0;
/** Loads the concrete model's state from an input archive.
 *
 *  Pure virtual: each concrete model implements its own loading.
 *
 *  \param iarc     The serialization input archive to read from.
 *  \param version  A version number supplied by the caller.
 *                  NOTE(review): presumably the value that
 *                  internal_get_version() returned at save time -- confirm.
 */
175 virtual void internal_load(
turi::iarchive& iarc,
size_t version) = 0;
/** Column position of the user column; passed to metadata->column_type()
 *  elsewhere in this header.
 */
186 static constexpr
size_t USER_COLUMN_INDEX = 0;
/** Column position of the item column; passed to metadata->column_type()
 *  elsewhere in this header.
 */
187 static constexpr
size_t ITEM_COLUMN_INDEX = 1;
/** Shared column metadata. Within this header it is queried for
 *  column_type(), column_name(), index_size() and indexer().
 */
189 std::shared_ptr<v2::ml_metadata> metadata;
/** Shared sarray whose rows are vectors of (size_t, double) pairs.
 *  NOTE(review): presumably one row per user holding (item index, value)
 *  pairs from the training data -- confirm.
 */
190 std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > trained_user_items;
195 v2::ml_data create_ml_data(
const sframe& data,
/** Returns an sframe derived from a side-data table.
 *
 *  Const member; only declared here.
 *
 *  \param main_index  A column index. NOTE(review): presumably
 *         USER_COLUMN_INDEX or ITEM_COLUMN_INDEX -- confirm.
 *  \param side_table  The side-data table, taken by value.
 *  \return An sframe. NOTE(review): by the name, presumably the table with
 *          adjusted column names -- confirm in the definition.
 */
202 sframe _sanitize_side_column_names(
size_t main_index,
sframe side_table)
const;
210 return metadata->column_type(USER_COLUMN_INDEX);
217 return metadata->column_type(ITEM_COLUMN_INDEX);
242 void setup_and_train(
const sframe& observation_data,
245 const std::map<std::string, variant_type>& other_data=(std::map<std::string, variant_type>()));
/** API-level entry point for querying similar items; const.
 *
 *  The member-function registration in this header lists "items", "k" and
 *  "verbose" among its argument names.
 *
 *  \param items          The query items, as a gl_sarray.
 *  \param k              A count. NOTE(review): presumably the number of
 *                        similar items returned per query -- confirm.
 *  \param verbose        Verbosity setting (a size_t, not a bool).
 *  \param get_all_items  An int flag; its exact meaning is not visible here.
 *  \return A gl_sframe with the result.
 */
257 gl_sframe api_get_similar_items(
gl_sarray items,
size_t k,
size_t verbose,
int get_all_items)
const;
/** API-level setter for model options.
 *
 *  \param options  Map from option name to flexible_type value, taken by value.
 *  \return A variant_map_type; its contents are not visible from this
 *          declaration.
 */
263 variant_map_type api_set_current_options(std::map<std::string, flexible_type> options);
266 const std::map<std::string, flexible_type>& opts,
267 const variant_map_type& extra_data);
/** API-level getter for the current options, returned as a variant_map_type.
 *  Registered in this header under the name "get_current_options".
 */
269 variant_map_type api_get_current_options();
272 gl_sframe new_item_data,
bool exclude_training_interactions,
size_t top_k,
double diversity,
size_t random_seed);
/** API-level accessor returning a variant_map_type.
 *  NOTE(review): presumably wraps get_train_stats() -- confirm in the definition.
 */
278 variant_map_type api_get_train_stats();
/** API-level accessor returning a variant_map_type; marked EXPORT and
 *  included in this header's member-function registration.
 *  NOTE(review): presumably describes the column names/types of the
 *  training data -- confirm in the definition.
 */
280 EXPORT variant_map_type api_get_data_schema();
/** Returns a shared pointer to another recsys_model_base; const.
 *  NOTE(review): by the name (and the recsys_popularity forward declaration
 *  in this header), presumably a popularity model used as a baseline --
 *  confirm in the definition.
 */
286 std::shared_ptr<recsys_model_base> get_popularity_baseline()
const;
294 struct diversity_choice_buffer {
295 std::vector<size_t> current_candidates;
296 std::vector<size_t> chosen_items;
297 std::vector<size_t> current_diversity_score;
298 std::vector<std::pair<size_t, double> > sim_scores;
301 void choose_diversely(
size_t top_k,
302 std::vector<std::pair<size_t, double> >& candidates,
304 diversity_choice_buffer& dv_buffer)
const;
326 bool exclude_training_interactions =
true,
327 double diversity_factor = 0,
328 size_t random_seed = 0)
const;
330 std::shared_ptr<unity_sframe_base> recommend_extension_wrapper(
331 std::shared_ptr<unity_sframe_base> reference_data,
332 std::shared_ptr<unity_sframe_base> new_observation_data,
/** Extension-facing wrapper returning a unity_sframe_base pointer; const.
 *  NOTE(review): presumably wraps get_num_users_per_item() -- confirm.
 */
335 std::shared_ptr<unity_sframe_base> get_num_users_per_item_extension_wrapper()
const;
/** Extension-facing wrapper returning a unity_sframe_base pointer; const.
 *  NOTE(review): presumably wraps get_num_items_per_user() -- confirm.
 */
337 std::shared_ptr<unity_sframe_base> get_num_items_per_user_extension_wrapper()
const;
/** Exports the model as a Core ML model wrapper; virtual, so subclasses may
 *  override it.
 *
 *  \param filename                 A file name. NOTE(review): presumably the
 *                                  output path -- confirm.
 *  \param additional_user_defined  Extra string -> flexible_type entries.
 *                                  The registration in this header defaults
 *                                  this argument to an empty map.
 *  \return A shared pointer to a coreml::MLModelWrapper.
 */
339 virtual std::shared_ptr<coreml::MLModelWrapper> export_to_coreml(
340 const std::string& filename,
341 const std::map<std::string, flexible_type>& additional_user_defined);
/** Returns a variant_map_type summarizing the model; the exact keys are not
 *  visible from this declaration.
 */
343 variant_map_type summary();
/** Computes precision/recall statistics; const.
 *
 *  \param indexed_validation_data  Validation data as an sframe.
 *         NOTE(review): the name suggests values already mapped to indices
 *         -- confirm.
 *  \param recommend_output         An sframe of recommendations to evaluate.
 *  \param cutoffs                  The cutoff values to evaluate at.
 *  \return An sframe with the statistics; the column layout is not visible
 *          from this declaration.
 */
362 sframe precision_recall_stats(
const sframe& indexed_validation_data,
363 const sframe& recommend_output,
364 const std::vector<size_t>& cutoffs)
const;
/** Returns an sframe; const.
 *  NOTE(review): by the name, presumably the number of items observed for
 *  each user -- confirm in the definition.
 */
371 sframe get_num_items_per_user()
const;
/** Returns an sframe; const.
 *  NOTE(review): by the name, presumably the number of users observed for
 *  each item -- confirm in the definition.
 */
377 sframe get_num_users_per_item()
const;
381 return RECSYS_MODEL_BASE_VERSION;
/** Get stats about algorithm runtime.
 *
 *  \return Map from stat name to flexible_type value.
 */
391 std::map<std::string, flexible_type> get_train_stats();
398 recsys_model_base::api_get_similar_items,
399 "items",
"k",
"verbose",
407 recsys_model_base::api_get_similar_users,
408 "users",
"k",
"get_all_users");
411 recsys_model_base::api_predict,
412 "data_to_predict",
"new_user_data",
416 "dataset",
"user_data",
"item_data",
417 "opts",
"extra_data");
420 "recommend", recsys_model_base::api_recommend,
"query",
"exclude",
421 "restrictions",
"new_data",
"new_user_data",
"new_item_data",
422 "exclude_training_interactions",
"top_k",
"diversity",
"random_seed");
424 register_defaults(
"recommend",
430 {
"exclude_training_interactions",
true},
432 {
"random_seed", 1}});
435 "get_current_options", recsys_model_base::api_get_current_options);
447 "get_item_intersection_info",
448 recsys_model_base::api_get_item_intersection_info,
"item_pairs");
451 recsys_model_base::export_to_coreml,
452 "filename",
"additional_user_defined");
453 register_defaults(
"export_to_coreml",
454 {{
"additional_user_defined",
to_variant(std::map<std::string, flexible_type>())}});
457 "precision_recall_by_user", recsys_model_base::api_precision_recall_by_user,
458 "indexed_validation_data",
"recommend_output",
"cutoffs");
461 recsys_model_base::api_get_data_schema);
466 "reference_data",
"new_observation_data",
475 template <
typename GetSimilarFunction>
478 size_t k, GetSimilarFunction&& similar)
const {
483 const bool use_all_values = (query ==
nullptr);
485 size_t n = use_all_values ? metadata->index_size(column_index) : query->size();
487 decltype(query->get_reader()) reader;
489 if(!use_all_values) {
490 reader = query->get_reader();
493 auto indexer = metadata->indexer(column_index);
498 {metadata->column_name(column_index),
"similar",
"score",
"rank"},
502 in_parallel([&](
size_t thread_idx,
size_t num_threads) {
504 std::vector<flexible_type> data;
506 Eigen::Matrix<float, Eigen::Dynamic, 1> similarities;
507 typedef std::pair<size_t, double> item_score_pair;
508 std::vector<item_score_pair> score_list(metadata->index_size(column_index));
512 size_t start_idx = (n * thread_idx) / num_threads;
513 size_t end_idx = (n * (thread_idx+1)) / num_threads;
514 size_t n_in_block = 1000;
516 for(
size_t block_start = start_idx; block_start < end_idx; block_start += 1000) {
518 if(!use_all_values) {
519 reader->read_rows(block_start, std::min(end_idx, block_start + 1000), data);
520 n_in_block = data.size();
522 n_in_block = std::min(end_idx, block_start + 1000) - block_start;
525 for(
size_t i = 0; i < n_in_block; ++i) {
527 size_t query_idx = use_all_values ? block_start + i : indexer->immutable_map_value_to_index(data[i]);
529 if(query_idx == static_cast<size_t>(-1))
532 similar(query_idx, score_list);
535 auto score_sorter = [](
const item_score_pair& vi1,
const item_score_pair& vi2) {
536 return vi1.second < vi2.second;
546 const flexible_type& query_item = use_all_values ? indexer->map_index_to_value(query_idx) : data[i];
548 for(
size_t j = 0, rank = 1; j < score_list.size(); ++j, ++it_out) {
549 if(score_list[j].first == query_idx)
553 flexible_type ref_datum = indexer->map_index_to_value(score_list[j].first);
554 *it_out = {query_item, ref_datum, score_list[j].second, rank};
568 variant_map_type train_test_split(
gl_sframe _dataset,
569 const std::string& user_column,
570 const std::string& item_column,
572 double item_test_proportion,
/** Free function taking a mutable parameter map and returning a
 *  variant_map_type. NOTE(review): presumably the toolkit entry point that
 *  creates/initializes a model -- confirm in the definition.
 */
575 variant_map_type init(variant_map_type& params);
/** Free-function counterpart of the member get_train_stats(); takes a
 *  mutable parameter map and returns a variant_map_type.
 */
577 variant_map_type get_train_stats(variant_map_type& params);
/** Returns the list of toolkit function specifications.
 *  NOTE(review): presumably the functions this toolkit registers -- confirm
 *  in the definition.
 */
579 std::vector<toolkit_function_specification> get_toolkit_function_registration();
std::shared_ptr< recsys_model_base > get_popularity_baseline() const
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
void extract_and_sort_top_k(std::vector< T > &v, size_t top_k, LessThan less_than)
virtual void get_item_similarity_scores(size_t item, std::vector< std::pair< size_t, double > > &sim_scores) const
#define BEGIN_BASE_CLASS_MEMBER_REGISTRATION()
recsys_model_base()
Default constructor.
virtual void init_options(const std::map< std::string, flexible_type > &_options)
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
iterator get_output_iterator(size_t segmentid)
sframe _create_similar_sframe(size_t column_index, std::shared_ptr< sarray< flexible_type > > items, size_t k, GetSimilarFunction &&similar) const
#define END_CLASS_MEMBER_REGISTRATION
static size_t cpu_count()
virtual std::map< std::string, flexible_type > train(const v2::ml_data &training_data_by_user, const v2::ml_data &training_data_by_item)
#define REGISTER_NAMED_CLASS_MEMBER_FUNCTION(name, function,...)
void open_for_write(const std::vector< std::string > &column_names, const std::vector< flex_type_enum > &column_types, const std::string &frame_sidx_file="", size_t nsegments=SFRAME_DEFAULT_NUM_SEGMENTS, bool fail_on_column_names=true)
variant_type to_variant(const T &f)
size_t get_version() const override
void in_parallel(const std::function< void(size_t thread_id, size_t num_threads)> &fn)
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
flex_type_enum item_type() const
std::vector< std::pair< flexible_type, flexible_type > > flex_dict
std::map< std::string, flexible_type > get_train_stats()
Get stats about algorithm runtime.
flex_type_enum user_type() const