Turi Create 4.0
lsh_neighbors.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_LSH_NEIGHBORS_H_
7 #define TURI_LSH_NEIGHBORS_H_
8 
9 
10 // Toolkits
11 #include <toolkits/nearest_neighbors/nearest_neighbors.hpp>
12 #include <toolkits/nearest_neighbors/lsh_family.hpp>
13 
14 namespace turi {
15 namespace nearest_neighbors {
16 
17 /**
18  * LSH nearest neighbor class.
19  *
20  * The intuition behind LSH-based indexes is to hash data points into
21  * buckets, such that similar points are more likely to be hashed to the same
22  * bucket than dissimilar ones. We could then find the approximate nearest
23  * neighbors of any point, simply by finding the bucket that it is hashed to.
24  *
25  * It works as follows:
26  *
27  * 1. Choose k hash functions h_1, h_2, ..., h_k from a uniform of some
28  * family of LSH functions. For any data point v, place v in the bucket with
29  * key g(v) = (h_1(v), h_2(v), ..., h_k(v)).
30  * 2. Independently perform step 1 l times to construct l separate hash
31  * tables, with hash functions g_1, g_2, ..., g_l
32  *
33  * You can set k and l by setting num_projections_per_table and
34  * num_tables respectively.
35  *
36  *
37  */
38 class EXPORT lsh_neighbors: public nearest_neighbors_model {
39 
40  public:
41 
42  static constexpr size_t LSH_NEIGHBORS_VERSION = 1;
43 
44  /**
45  * Destructor. Make sure bad things don't happen
46  */
47  ~lsh_neighbors();
48 
49  /**
50  * Set the model options. Use the option manager to set these options. The
51  * option manager should throw errors if the options do not satisfy the option
52  * manager's conditions.
53  *
54  * \param[in] opts Options to set
55  */
56  void init_options(const std::map<std::string,flexible_type>& _opts) override;
57 
58  /**
59  * Create a LSH nearest neighbors model.
60  *
61  * \param[in] X sframe input feature data
62  * \param[in] ref_labels row labels for the reference dataset
63  * \param[in] composite_distance_params
64  * \param[in] y sframe input labels
65  */
66  void train(const sframe& X, const std::vector<flexible_type>& ref_labels,
67  const std::vector<dist_component_type>& composite_distance_params,
68  const std::map<std::string, flexible_type>& opts) override;
69 
70  /**
71  * Find neighbors of queries in a created LSH model.
72  *
73  * For each query, the method keeps track of the current k-nearest neighbors
74  * in the LSH. At each node, the closest possible point in each child
75  * node to the query is computed, and if this distance is further than the
76  * current k'th nearest neighbor, that child node (and its descendants) is
77  * skpped in the traversal.
78  *
79  * \param[in] mld_queries query data
80  * \param[in] query_labels sframe query labels
81  * \param[in] k size_t max number of neighbors to return for each query
82  * \param[in] radius double max distance for returned neighbors to each query
83  *
84  * \param[out] ret sframe SFrame with four columns: query label, reference
85  * label, distance, and rank.
86  *
87  * \note Assumes that data is already in the right shape.
88  */
89  sframe query(const v2::ml_data& mld_queries,
90  const std::vector<flexible_type>& query_labels,
91  const size_t k, const double radius,
92  const bool include_self_edges) const override;
93 
94 
95  /**
96  * Gets the model version number
97  */
98  inline size_t get_version() const override {
99  return LSH_NEIGHBORS_VERSION;
100  }
101 
102  /**
103  * Turi serialization save
104  */
105  void save_impl(turi::oarchive& oarc) const override;
106 
107  /**
108  * Turi serialization save
109  */
110  void load_version(turi::iarchive& iarc, size_t version) override;
111 
112  // TODO: convert interface above to use the extensions methods here
113  BEGIN_CLASS_MEMBER_REGISTRATION("nearest_neighbors_lsh")
114  REGISTER_CLASS_MEMBER_FUNCTION(lsh_neighbors::list_fields)
116 
117  private:
118  std::shared_ptr<lsh_family> lsh_model;
119 };
120 
121 
122 } // namespace nearest_neighbors
123 } // namespace turi
124 
125 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive: given a reference to an istream, it reads from that stream to provide deserialization.
Definition: iarchive.hpp:60
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive: given a reference to an ostream, it writes to that stream to provide serialization.
Definition: oarchive.hpp:80