Turi Create  4.0
ball_tree_neighbors.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_BALL_TREE_NEIGHBORS_H_
7 #define TURI_BALL_TREE_NEIGHBORS_H_
8 
9 // Toolkits
10 #include <toolkits/nearest_neighbors/nearest_neighbors.hpp>
11 
12 // Miscellaneous
13 #include <Eigen/SparseCore>
14 
15 
16 namespace turi {
17 namespace nearest_neighbors {
18 
19 
20 /**
21  * Ball tree nearest neighbors class
22  * -----------------------------------------------------------------------------
23  *
24  * Implements the ball tree method for k-nearest neighbors search.
25  *
26  * The ball tree works by partitioning the reference data into successively
27  * smaller balls, and recording the center (i.e. pivot) and radius of each ball.
28  * A ball tree query uses the pivots and radii to exclude many of the balls from
29  * the k-nearest neighbor search, allowing it to run in sub-linear time.
30  *
31  * In addition to the objects contained in the nearest_neighbors_model base
32  * class, the ball tree contains the following:
33  *
34  * - membership:
35  * Each element of this vector indicates which node the corresponding
36  * reference data point belongs to. After the tree is constructed, the
37  * elements in this vector correspond to leaf nodes of the tree only.
38  *
39  * - pivots:
40  * The reference data point at the center of each tree node.
41  *
42  * - node_radii:
43  * The distance from the pivot of each node to the most distant
44  * reference point belonging to the tree node.
45  */
46 class EXPORT ball_tree_neighbors: public nearest_neighbors_model {
47 
48  protected:
49 
50  std::vector<size_t> membership; // leaf node membership
51  std::vector<DenseVector> pivots; // dense pivot observations
52  std::vector<SparseVector> pivots_sp; // sparse pivot observations
53  std::vector<double> node_radii; // node radii
54 
55  size_t tree_depth; // number of levels in the tree
56 
57  /**
58  * Decide if a node should be activated for a query. Activating a node means
59  * it will be traversed in the search for a query's nearest neighbors. For
60  * internal nodes, this means the search will in turn check if each child node
61  * should be activated. For leaf nodes, it means the distances between the
62  * query and all members of the node will be computed (and potentially added
63  * to the set of candidate nearest neighbors).
64  *
65  * \param[in] k size_t Max number of neighbors
66  * \param[in] radius double Max distance for a neighbor
67  * \param[in] min_poss_dist double Minimum possible distance from the query
68  * point to the node in question.
69  * \param[in] num_current_neighbors size_t Current number of neighbors
70  * \param[in] max_current_dist double Max distance to the current neighbors
71  * set. Note that if the neighbor candidates set is empty, this will be -1.0.
72  *
73  * \return bool If true, the node should be activated.
74  */
75  bool activate_query_node(size_t k, double radius, double min_poss_dist,
76  size_t num_current_neighbors,
77  double max_current_dist) const;
78 
79 
80  public:
81 
82  /**
83  * version 2 (GLC 1.6/sprint 1509): Add the original_row_index member, to
84  * facilitate the 'include_self_edges' flag.
    * NOTE(review): version number aligned with the constant below - confirm.
85  */
86  static constexpr size_t BALL_TREE_NEIGHBORS_VERSION = 2;
87 
88  /**
89  * Destructor. Make sure bad things don't happen
90  */
92 
93  /**
94  * Set the model options. Use the option manager to set these options. The
95  * option manager should throw errors if the options do not satisfy the option
96  * manager's conditions.
97  *
98  * \param[in] _opts Options to set
99  */
100  void init_options(const std::map<std::string,flexible_type>& _opts) override;
101 
102 
103  /**
104  * Create a ball tree nearest neighbors model.
105  *
106  * \param[in] X sframe input feature data
107  * \param[in] ref_labels row labels for the reference dataset
108  * \param[in] composite_distance_params
109  * \param[in] opts model options, as a map from option name to value
110  */
111  void train(const sframe& X, const std::vector<flexible_type>& ref_labels,
112  const std::vector<dist_component_type>& composite_distance_params,
113  const std::map<std::string, flexible_type>& opts) override;
114 
115  /**
116  * Find neighbors of queries in a created ball tree model.
117  *
118  * For each query, the method keeps track of the current k-nearest neighbors
119  * in the ball tree. At each node, the closest possible point in each child
120  * node to the query is computed, and if this distance is further than the
121  * current k'th nearest neighbor, that child node (and its descendants) is
122  * skipped in the traversal.
123  *
124  * \param[in] mld_queries query data
125  * \param[in] query_labels std::vector<flexible_type> query labels
126  * \param[in] k size_t max number of neighbors to return for each query
127  * \param[in] radius double max distance for returned neighbors to each query
    * \param[in] include_self_edges bool presumably whether a query point may be
    * returned as its own neighbor - TODO confirm against the implementation.
128  *
129  * \return sframe SFrame with four columns: query label, reference
130  * label, distance, and rank.
131  *
132  * \note Assumes that data is already in the right shape.
133  */
134  sframe query(const v2::ml_data& mld_queries,
135  const std::vector<flexible_type>& query_labels,
136  const size_t k, const double radius,
137  const bool include_self_edges) const override;
138 
139  /**
140  * Gets the model version number
141  */
142  inline size_t get_version() const override {
143  return BALL_TREE_NEIGHBORS_VERSION;
144  }
145 
146  /**
147  * Turi serialization save
148  */
149  void save_impl(turi::oarchive& oarc) const override;
150 
151  /**
152  * Turi serialization load
153  */
154  void load_version(turi::iarchive& iarc, size_t version) override;
155 
156  // TODO: convert interface above to use the extensions methods here
157  BEGIN_CLASS_MEMBER_REGISTRATION("nearest_neighbors_ball_tree")
158  REGISTER_CLASS_MEMBER_FUNCTION(ball_tree_neighbors::list_fields)
160 
161 };
162 
163 
164 } // namespace nearest_neighbors
165 } // namespace turi
166 
167 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80