Turi Create  4.0
brute_force_neighbors.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_BRUTE_FORCE_NEIGHBORS_H_
7 #define TURI_BRUTE_FORCE_NEIGHBORS_H_
8 
9 // Types
10 #include <core/parallel/atomic.hpp>
11 
12 // ML-Data Utils
13 
14 // Toolkits
15 #include <toolkits/nearest_neighbors/nearest_neighbors.hpp>
16 
17 namespace turi {
18 namespace nearest_neighbors {
19 
20 
21 
22 class EXPORT brute_force_neighbors: public nearest_neighbors_model {
23 
24  protected:
25 
26  // bool is_dense = true; // Indicates if SparseVector is needed
27 
28 
29  public:
30 
31  static constexpr size_t BRUTE_FORCE_NEIGHBORS_VERSION = 2;
32 
33  /**
34  * Destructor. Make sure bad things don't happen
35  */
36  ~brute_force_neighbors();
37 
38  /**
39  * Set the model options. Use the option manager to set these options. The
40  * option manager should throw errors if the options do not satisfy the option
41  * manager's conditions.
42  *
43  * \param[in] opts Options to set
44  */
45  void init_options(const std::map<std::string,flexible_type>& _opts) override;
46 
47  /**
48  * Create a brute force nearest neighbors model.
49  *
50  * \param[in] X sframe input feature data
51  * \param[in] ref_labels row labels for the reference dataset
52  * \param[in] composite_distance
53  * \param[in] opts
54  */
55  void train(const sframe& X, const std::vector<flexible_type>& ref_labels,
56  const std::vector<dist_component_type>& composite_distance_params,
57  const std::map<std::string, flexible_type>& opts) override;
58 
59  /**
60  * Find neighbors of queries in a created brute_force model. Depending on data
61  * attributes, calls either blockwise query or pairwise query.
62  *
63  * \param[in] mld_queries query data
64  * \param[in] query_labels vector query labels
65  * \param[in] k size_t max number of neighbors to return for each query
66  * \param[in] radius flexible_type max distance for returned neighbors to each query
67  * \param[in] include_self_edges if false, don't include results where the
68  * query index and the reference index are the same.
69  *
70  * \returns ret SFrame with four columns: query label, reference
71  * label, distance, and rank.
72  *
73  */
74  sframe query(const v2::ml_data& mld_queries,
75  const std::vector<flexible_type>& query_labels,
76  const size_t k, const double radius,
77  const bool include_self_edges) const override;
78 
79  /**
80  * Search a nearest neighbors reference object for the neighbors of every
81  * point.
82  *
83  * \param[in] k number of neighbors to return for each query
84  * \param[in] radius distance threshold to call a reference point a neighbor
85  * \param[in] include_self_edges if false, don't include results where the
86  * query index and the reference index are the same.
87  *
88  * \returns ret Shared pointer to an SFrame containing query results.
89  */
90  sframe similarity_graph(const size_t k, const double radius,
91  const bool include_self_edges) const override;
92 
93  /**
94  * Find neighbors of queries in a created brute_force model. Break the
95  * reference and query data into blocks small enough to be read into memory,
96  * then use matrix multiplication to compute distances in bulk. Only
97  * appropriate for dense, numeric data with standard distance functions.
98  *
99  * \param mld_queries v2::ml_data query data
100  * \param neighbors std::vector<neighbor_candidates> container for results
101  * \param dist_name std::string name of the distance function.
102  */
103  void blockwise_query(const v2::ml_data& mld_queries,
104  std::vector<neighbor_candidates>& neighbors,
105  const std::string& dist_name) const;
106 
107  /**
108  * Find neighbors of queries in a created brute_force model, by explicitly
109  * computing the distance function for each pair of query and reference
110  * points. This is the default strategy because it works with any distance
111  * function (including composite distances).
112  *
113  * Pseudo code
114  * ++++++++++++++++++++++++++++++++++++++++++++++++++
115  * for query_block in query_data {
116  * load query_block in memory
117  * parallel_for ref_row in ref_data {
118  * for query_row in query_block {
119  * evaluate_point(query_row, ref_row, row_id)
120  * }
121  * }
122  * }
123  *
124  * \param mld_queries v2::ml_data query data
125  * \param neighbors std::vector<neighbor_candidates> container for results
126  */
127  void pairwise_query(const v2::ml_data& mld_queries,
128  std::vector<neighbor_candidates>& neighbors) const;
129 
130 /**
131  * Construct the similarity graph for the reference data, using blockwise matrix
132  * multiplication for distance computations.
133  *
134  * \param neighbors std::vector<neighbor_candidates> container for results
135  * \param dist_name std::string name of the distance function.
136  */
137  void blockwise_similarity_graph(std::vector<neighbor_candidates>& neighbors,
138  const std::string& dist_name) const;
139 
140 
141  inline size_t get_version() const override {
142  return BRUTE_FORCE_NEIGHBORS_VERSION;
143  }
144 
145  /**
146  * Turi serialization save
147  */
148  void save_impl(turi::oarchive& oarc) const override;
149 
150  /**
151  * Turi serialization save
152  */
153  void load_version(turi::iarchive& iarc, size_t version) override;
154 
155  // TODO: convert interface above to use the extensions methods here
156  BEGIN_CLASS_MEMBER_REGISTRATION("nearest_neighbors_brute_force")
157  REGISTER_CLASS_MEMBER_FUNCTION(brute_force_neighbors::list_fields)
159 
160 };
161 
162 
163 } // namespace nearest_neighbors
164 } // namespace turi
165 
166 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80