Turi Create  4.0
sparse_nn.hpp
1 /* Copyright © 2018 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef SPARSE_NN_HPP
7 #define SPARSE_NN_HPP
8 
9 #include <core/export.hpp>
10 #include <model_server/lib/extensions/ml_model.hpp>
11 #include <core/data/sframe/gl_sframe.hpp>
12 #include <model_server/lib/toolkit_class_macros.hpp>
13 
14 namespace turi {
15 namespace prototype {
16 
17 class EXPORT sparse_nn : public ml_model_base {
18  public:
19  static constexpr size_t SPARSE_NN_VERSION = 0;
20 
21  typedef uint128_t hash_type;
22 
23  // Call this function to set up the training data. All columns must be string
24  // or integer columns and are treated as categorical indicators.
25  void train(const gl_sframe& data, const std::string& id_column);
26 
27  // Returns a dictionary of id to distance for the top k values
28  flex_dict query(const flex_dict& x, size_t k) const;
29 
30  // Serialization
31  void save_impl(oarchive& oarc) const override;
32  void load_version(iarchive& iarc, size_t version) override;
33 
35  IMPORT_BASE_CLASS_REGISTRATION(ml_model_base);
36  REGISTER_CLASS_MEMBER_FUNCTION(sparse_nn::train, "data", "id_column");
37  REGISTER_CLASS_MEMBER_FUNCTION(sparse_nn::query, "x", "k");
39 
40  private:
41  // Some metadata.
42  size_t m_num_columns = 0;
43  std::vector<flexible_type> m_ids;
44 
45  // This is optimized for extremely efficient lookup on small, in-memory
46  // datasets.
47  //
48  // Query Algorithm:
49  //
50  // 1. For each feature, do bisection search to determine which index is the
51  // correct one for the given column/feature combo.
52  //
53  // 2. Increment the corresponding values in hit_indices that are denoted by
54  // the bounds given in access_bounds at the index of the found hash in hashes.
55  std::vector<hash_type> m_hashes;
56 
57  std::vector<std::pair<uint32_t, uint32_t> > m_access_bounds;
58  std::vector<uint32_t> m_hit_indices;
59 
60  inline hash_type feature_hash(const std::string& column,
61  const flexible_type& feature) const {
62  return hash128_combine(hash128(column), feature.hash128());
63  }
64 };
65 
66 } // namespace prototype
67 } // namespace turi
68 
69 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
static uint128_t hash128(const char *s, size_t len)
STL namespace.
#define IMPORT_BASE_CLASS_REGISTRATION(base_class)
static uint128_t hash128_combine(uint128_t h1, uint128_t h2)
#define END_CLASS_MEMBER_REGISTRATION
std::vector< std::pair< flexible_type, flexible_type > > flex_dict