Turi Create 4.0
neighbor_search.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_SPARSE_SIMILARITY_NEIGHBOR_SEARCH_H
#define TURI_SPARSE_SIMILARITY_NEIGHBOR_SEARCH_H

#include <core/storage/sframe_data/sarray.hpp>
#include <vector>
#include <core/parallel/pthread_tools.hpp>
#include <core/util/try_finally.hpp>
#include <core/util/dense_bitset.hpp>
#include <toolkits/sparse_similarity/similarities.hpp>
#include <toolkits/sparse_similarity/item_processing.hpp>

namespace turi { namespace sparse_sim {

/** Efficiently performs an all-pairs brute-force similarity
 * computation over sarrays of sparse, sorted vectors.
 *
 * reference_data and query_data are two sarrays of sparse vectors.
 * A similarity score is calculated between each entry of
 * reference_data and each entry of query_data, with process_function
 * called for each pair (minus the exceptions below).
 *
 * reference_item_info and query_item_info are obtained by calling
 * one of the methods in item_processing.hpp.
 *
 * The SimilarityType similarity function is one of the classes in
 * similarities.hpp, or a class that conforms to the same interface.
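 *
 * For reference, the interface surface this function actually relies
 * on looks roughly like the sketch below. This is inferred from the
 * calls made in the function body, not a definitive specification;
 * ExampleSimilarity and the `...` placeholder types are hypothetical.
 *
 *   struct ExampleSimilarity {
 *     typedef ... item_data_type;
 *     typedef ... interaction_data_type;
 *     typedef ... final_item_data_type;   // unused_value_type if not needed
 *     typedef ... final_interaction_data_type;
 *
 *     static constexpr bool missing_values_are_zero();
 *
 *     // Accumulate one (reference, query) co-occurrence into the
 *     // running interaction value.
 *     void update_interaction_unsafe(
 *         interaction_data_type& interaction,
 *         const item_data_type& ref_item, const item_data_type& query_item,
 *         double ref_value, double query_value) const;
 *
 *     // Turn the accumulated interaction into the final output value.
 *     void finalize_interaction(
 *         final_interaction_data_type& out,
 *         const final_item_data_type& ref_final,
 *         const final_item_data_type& query_final,
 *         const interaction_data_type& interaction,
 *         const item_data_type& ref_item,
 *         const item_data_type& query_item) const;
 *   };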
 *
 * process_function should have the signature
 *
 *   process_function(size_t reference_idx, size_t query_idx,
 *                    const final_interaction_data_type& interaction)
 *
 * It is called in parallel for each reference and query entry. (The
 * all_pairs_similarity wrapper below translates the final interaction
 * value into a double similarity score.)
 *
 * num_dimensions is the maximum dimension of each sparse vector. An
 * error is raised if any index is >= num_dimensions.
 *
 * max_memory_usage is used to determine the block size for doing the
 * query; a larger value means fewer passes through the reference set.
 *
 * skip_pair has the signature
 *
 *   skip_pair(size_t query_idx, size_t reference_idx) -> bool
 *
 * If it returns true, then the similarity score is not calculated for
 * that query_idx and reference_idx pair. Normally, this can be set to
 * always return false, in which case nothing is skipped. (This is
 * used, for example, when reference_data and query_data are the same
 * and only one direction needs to be calculated.)
 *
 * If provided, query_mask is a dense_bitset of the same length as
 * query_data. If a particular entry is false, then that row is
 * skipped in the similarity comparisons.
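 *
 * A minimal usage sketch (hypothetical variable names; assumes the
 * item info vectors were produced by the helpers in
 * item_processing.hpp):
 *
 *   brute_force_all_pairs_similarity_with_vector_reference(
 *       reference_data, reference_item_info,
 *       query_data, query_item_info,
 *       similarity,
 *       [&](size_t ref_idx, size_t q_idx, const auto& interaction) {
 *         // Record or accumulate the finalized interaction here.
 *       },
 *       num_dimensions,
 *       max_memory_usage,
 *       [](size_t q_idx, size_t ref_idx) { return false; });  // skip nothing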
 */
template <typename SimilarityType,
          typename ProcessFunction,
          typename SkipFunction>
void brute_force_all_pairs_similarity_with_vector_reference(
    std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > reference_data,
    const std::vector<item_processing_info<SimilarityType> >& reference_item_info,
    std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > query_data,
    const std::vector<item_processing_info<SimilarityType> >& query_item_info,
    const SimilarityType& similarity,
    ProcessFunction&& process_function,
    size_t num_dimensions,
    size_t max_memory_usage,
    SkipFunction&& skip_pair,
    const dense_bitset* query_mask = nullptr) {

  // These type aliases are referenced later on.
  typedef typename SimilarityType::item_data_type item_data_type;
  typedef typename SimilarityType::interaction_data_type interaction_data_type;
  typedef typename SimilarityType::final_item_data_type final_item_data_type;
  typedef typename SimilarityType::final_interaction_data_type final_interaction_data_type;

  // Set constants used later.
  static constexpr bool use_final_item_data = (
      !std::is_same<unused_value_type, final_item_data_type>::value);

  static constexpr bool missing_values_are_zero = SimilarityType::missing_values_are_zero();

  final_item_data_type _unused;

  // Constants used later.
  size_t max_num_threads = thread::cpu_count();
  size_t num_query_rows = query_data->size();
  size_t num_reference_rows = reference_data->size();
  bool using_mask = (query_mask != nullptr);

  ////////////////////////////////////////////////////////////////////////////////
  // Check input.

  DASSERT_NE(num_dimensions, 0);

  DASSERT_EQ(reference_item_info.size(), reference_data->size());
  DASSERT_EQ(query_item_info.size(), query_data->size());

  if(using_mask) {
    DASSERT_EQ(query_mask->size(), query_data->size());
  }

  ////////////////////////////////////////////////////////////////////////////////
  // If we are using a mask, then the number of query rows is
  // calculated from that.
  if(using_mask) {
    num_query_rows = query_mask->popcount();
  }

  // Nothing to do here.
  if(num_query_rows == 0)
    return;

  // Set the block size as a function of the maximum memory usage.
  size_t max_query_rows_per_block = max_memory_usage / (num_dimensions * sizeof(double));
  max_query_rows_per_block = std::max(max_num_threads, max_query_rows_per_block);
  max_query_rows_per_block = std::min(num_query_rows, max_query_rows_per_block);

  // This is needed, as we use an int counter in the inner loop to
  // enable easier autovectorization (see
  // http://www.slideshare.net/linaroorg/using-gcc-autovectorizer).
  max_query_rows_per_block = std::min<size_t>(
      std::numeric_limits<int>::max() / 2, max_query_rows_per_block);

  // Set up the number of blocks; it is the ceiling of
  // num_query_rows / max_query_rows_per_block.
  size_t num_blocks = (num_query_rows + (max_query_rows_per_block - 1)) / max_query_rows_per_block;

  // Now that we have the number of blocks, further minimize memory
  // use by making the number of query rows per block as even as
  // possible. That way we won't end up with a single block that has
  // only one row while the rest have many more.

  max_query_rows_per_block = (num_query_rows + (num_blocks - 1)) / num_blocks;
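  // For example, with num_query_rows = 9 and an initial
  // max_query_rows_per_block of 4 (hypothetical numbers), num_blocks =
  // ceil(9 / 4) = 3, and the rebalanced max_query_rows_per_block =
  // ceil(9 / 3) = 3, so the blocks hold 3, 3, and 3 rows rather than
  // 4, 4, and 1.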

  // Get the reader for the query data.
  auto query_reader = query_data->get_reader(max_num_threads);

  // Get the reader for the reference data.
  auto reference_reader = reference_data->get_reader(max_num_threads);

  // Set up the query data so that, for each dimension, the values
  // across all rows in the block are contiguous in memory. That way,
  // on a query, we can process each dimension in one pass, thereby
  // optimizing memory access and increasing the likelihood that the
  // compiler can vectorize it.
  std::vector<double> block_data(max_query_rows_per_block * num_dimensions);
  auto block_data_index = [&](size_t row_idx, size_t element_idx) {
    DASSERT_LT(row_idx, max_query_rows_per_block);
    DASSERT_LT(element_idx, num_dimensions);
    return element_idx * max_query_rows_per_block + row_idx;
  };
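
  // In other words, the layout is dimension-major: for a block of R
  // rows, block_data holds [dim 0: rows 0..R-1][dim 1: rows 0..R-1]...,
  // so the inner loops below can walk all rows of a single dimension
  // with a unit-stride pointer.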

  // For each slot in the current block, this is the original query
  // row index that the slot holds.
  std::vector<size_t> block_query_row_indices(max_query_rows_per_block);

  // The vertex info for each of these rows.
  std::vector<item_data_type> block_item_data(max_query_rows_per_block);

  std::vector<final_item_data_type> block_final_item_data;
  if(use_final_item_data) {
    block_final_item_data.resize(max_query_rows_per_block);
  }

  // Counters indicating where each thread is within its reading
  // segment; these persist across blocks so each thread resumes where
  // it left off.
  std::vector<size_t> query_row_counters(max_num_threads, size_t(-1));

  // Loop over the blocks.
  for(size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {

    // This is the location of the current open slot for dumping one of the rows.
    atomic<size_t> block_write_idx = 0;

    // Clear out the data in this block.
    std::fill(block_data.begin(), block_data.end(), missing_values_are_zero ? 0 : NAN);

    // Fill the block with appropriate rows.
    in_parallel([&](size_t thread_idx, size_t num_threads) GL_GCC_ONLY(GL_HOT_NOINLINE_FLATTEN) {

      // This is the segment we are responsible for in this thread.
      size_t query_row_idx_start = (query_data->size() * thread_idx) / num_threads;
      size_t query_row_idx_end = (query_data->size() * (thread_idx+1)) / num_threads;
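      // (For example, with query_data->size() == 10 and num_threads == 4,
      // the segments are [0, 2), [2, 5), [5, 7), and [7, 10).)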

      // Get the current position within the segment this thread is
      // assigned to.
      size_t& current_query_row_index = query_row_counters[thread_idx];

      // Initialize it at the appropriate location if needed.
      if(current_query_row_index == size_t(-1)) {
        // It has not been initialized yet; do it at the start of our segment.
        current_query_row_index = query_row_idx_start;
      }

      // Row buffer.
      std::vector<std::vector<std::pair<size_t, double> > > row_v(1);

      // Now, read in rows until we are out of space in this block,
      // or until we are out of rows in this reading segment.
      while(current_query_row_index < query_row_idx_end) {

        // If we are using the query mask, then check if we are in a
        // valid spot. If not, then advance forward until we are.
        if(using_mask && !query_mask->get(current_query_row_index)) {
          size_t new_q_idx = current_query_row_index;
          bool any_more = query_mask->next_bit(new_q_idx);

          if(UNLIKELY(!any_more || new_q_idx >= query_row_idx_end)) {
            // Done.
            current_query_row_index = query_row_idx_end;
            break;
          } else {
            DASSERT_NE(current_query_row_index, new_q_idx);

            // Next row.
            current_query_row_index = new_q_idx;
          }
        }

        if(using_mask) {
          // Just make sure we've got a live one.
          DASSERT_TRUE(query_mask->get(current_query_row_index));
        }

        // Get the next index.
        size_t internal_block_idx = (++block_write_idx) - 1;

        // Do we have a place in the block to put this? If not, break
        // and leave this position for the next block.
        if(internal_block_idx >= max_query_rows_per_block) {
          break;
        }

        // Assert that we do indeed have a row left.
        DASSERT_LT(current_query_row_index, query_row_idx_end);

        // Now that we know we have a spot in the block, write it
        // out to the block data.
        query_reader->read_rows(current_query_row_index, current_query_row_index + 1, row_v);
        const auto& row = row_v[0];

        // Write out the row's bookkeeping information.
        block_query_row_indices[internal_block_idx] = current_query_row_index;
        block_item_data[internal_block_idx] = query_item_info[current_query_row_index].item_data;

        // Write out the final vertex data.
        if(use_final_item_data) {
          block_final_item_data[internal_block_idx]
              = query_item_info[current_query_row_index].final_item_data;
        }

        // Write the row out to the block data.
        for(size_t i = 0; i < row.size(); ++i) {
          size_t idx = block_data_index(internal_block_idx, row[i].first);
          block_data[idx] = row[i].second;
        }

        // Finally, advance the counter to continue.
        ++current_query_row_index;
      }

      // If we are on the last pass, make sure that we have
      // covered all the query data.
      if(block_idx == num_blocks - 1) {
        DASSERT_EQ(current_query_row_index, query_row_idx_end);
      }
    });

    // Check to make sure our math is correct regarding the number of query
    // rows and the number of blocks.
#ifndef NDEBUG
    {
      size_t _block_write_idx = block_write_idx; // copy out of the atomic
      if(block_idx < num_blocks - 1) {
        DASSERT_GE(_block_write_idx, max_query_rows_per_block);
      } else {
        DASSERT_LE(_block_write_idx, max_query_rows_per_block);
      }

      // If fewer rows than the block's capacity were written, then
      // everything is done, so it must be the last block.
      if(block_write_idx < max_query_rows_per_block) {
        DASSERT_EQ(block_idx, num_blocks - 1);
      }
    }
#endif

    // Set the number of query rows in this block. The
    // block_write_idx may have been incremented multiple times by
    // different threads.
    size_t num_query_rows_in_block = std::min<size_t>(block_write_idx, max_query_rows_per_block);

    // If all the math is correct, this block will never be empty.
    DASSERT_GT(num_query_rows_in_block, 0);

    // Now, if we're using a mask, make sure all the indices are
    // masked properly.
#ifndef NDEBUG
    {
      if(using_mask) {
        for(size_t i = 0; i < num_query_rows_in_block; ++i) {
          DASSERT_TRUE(query_mask->get(block_query_row_indices[i]));
        }
      }
    }
#endif

    // Okay, now that we have a specific block of query data, go
    // through and perform the nearest neighbors query on it.
    in_parallel([&](size_t thread_idx, size_t num_threads) GL_GCC_ONLY(GL_HOT_NOINLINE_FLATTEN) {
      // This is the segment we are responsible for in this thread.
      size_t reference_row_idx_start = (num_reference_rows * thread_idx) / num_threads;
      size_t reference_row_idx_end = (num_reference_rows * (thread_idx+1)) / num_threads;

      const size_t n_reference_rows_per_block = 16;
      std::vector<std::vector<std::pair<size_t, double> > > reference_rows_v;

      std::vector<interaction_data_type> edges(num_query_rows_in_block);

      // Read the reference data in chunks of n_reference_rows_per_block
      // rows for efficiency.
      for(size_t outer_idx = reference_row_idx_start;
          outer_idx < reference_row_idx_end;
          outer_idx += n_reference_rows_per_block) {

        ////////////////////////////////////////////////////////////////////////////////

        reference_reader->read_rows(
            outer_idx,
            std::min(outer_idx + n_reference_rows_per_block, reference_row_idx_end),
            reference_rows_v);

        if(reference_rows_v.size() != n_reference_rows_per_block) {
          DASSERT_EQ(outer_idx + reference_rows_v.size(), reference_row_idx_end);
        }

        // Now loop over the rows in the buffer.
        for(size_t inner_idx = 0; inner_idx < reference_rows_v.size(); ++inner_idx) {

          // Now, for each row, go through and calculate the full intersection.
          const size_t ref_idx = outer_idx + inner_idx;
          const auto& row = reference_rows_v[inner_idx];

          // Get the information for this particular vertex.
          item_data_type ref_item_data = reference_item_info[ref_idx].item_data;

          const final_item_data_type& ref_final_item_data
              = reference_item_info[ref_idx].final_item_data;

          // Zero the edges.
          edges.assign(num_query_rows_in_block, interaction_data_type());

          // Loop over the nonzero entries of this reference row.
          for(const auto& p : row) {
            size_t dim_index = p.first;
            double ref_value = p.second;

            if(missing_values_are_zero) {
              // This is in the inner loop, so a lot of time is spent
              // in this computation. Try to make it as friendly to
              // the vectorizer as possible.

              double* __restrict__ bd_ptr = &(block_data[block_data_index(0, dim_index)]);
              item_data_type* __restrict__ it_data_ptr = block_item_data.data();
              interaction_data_type* __restrict__ int_data_ptr = edges.data();

              for(int i = 0; i < int(num_query_rows_in_block);
                  ++i, ++bd_ptr, ++it_data_ptr, ++int_data_ptr) {

                similarity.update_interaction_unsafe(
                    *int_data_ptr,
                    ref_item_data, *it_data_ptr,
                    ref_value, *bd_ptr);
              }
            } else {
              for(size_t i = 0; i < num_query_rows_in_block; ++i) {
                // We branch on individual entries, so this loop can't
                // be vectorized anyway.
                double block_data_entry = block_data[block_data_index(i, dim_index)];

                if(std::isnan(block_data_entry))
                  continue;

                // Aggregate it along this edge.
                similarity.update_interaction_unsafe(
                    edges[i],
                    ref_item_data, block_item_data[i],
                    ref_value, block_data_entry);
              }
            }
          }

          // Now, go through, finalize the answers, and record them.
          for(size_t i = 0; i < num_query_rows_in_block; ++i) {
            size_t query_index = block_query_row_indices[i];

            if(skip_pair(query_index, ref_idx))
              continue;

            // Get the vertex and value info for this query row.
            const auto& q_item_data = block_item_data[i];
            const auto& q_final_item_data =
                (use_final_item_data ? block_final_item_data[i] : _unused);

            // Set up the output value.
            final_interaction_data_type e_out = final_interaction_data_type();

            similarity.finalize_interaction(e_out,
                                            ref_final_item_data, q_final_item_data,
                                            edges[i],
                                            ref_item_data, q_item_data);

            // Now do the meat of the operation -- record the result.
            process_function(ref_idx, query_index, e_out);
          }
        }
      }
    });

    // Now, we're done, so go to the next block.
  }
}

/** An easier-to-use wrapper for the above all-pairs similarity
 * search. reference_data and query_data are two sarrays of sparse
 * vectors. A similarity score is calculated between each entry of
 * reference_data and each entry of query_data, with process_function
 * called for each pair (minus the exceptions below).
 *
 * The SimilarityType similarity function is one of the classes in
 * similarities.hpp, or a class that conforms to the same interface.
 *
 * process_function should have the signature
 *
 *   process_function(size_t reference_idx, size_t query_idx, double similarity)
 *
 * It is called in parallel for each reference and query entry.
 *
 * max_memory_usage is used to determine the block size for doing the
 * query; a larger value means fewer passes through the reference set.
 *
 * skip_pair has the signature
 *
 *   skip_pair(size_t query_idx, size_t reference_idx) -> bool
 *
 * If it returns true, then the similarity score is not calculated for
 * that query_idx and reference_idx pair. Normally, this can be set to
 * always return false, in which case nothing is skipped. (This is
 * used, for example, when reference_data and query_data are the same
 * and only one direction needs to be calculated.)
 *
 * If provided, query_mask is a dense_bitset of the same length as
 * query_data. If a particular entry is false, then that row is
 * skipped in the similarity comparisons.
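 *
 * A minimal usage sketch (hypothetical names; `sim` stands for an
 * instance of one of the similarity classes in similarities.hpp):
 *
 *   all_pairs_similarity(
 *       data, data, sim,
 *       [&](size_t ref_idx, size_t q_idx, double score) {
 *         // Consume (ref_idx, q_idx, score) here.
 *       },
 *       max_memory_usage,
 *       // Skip self-pairs when querying the data against itself:
 *       [](size_t q_idx, size_t ref_idx) { return q_idx == ref_idx; });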
 */
template<typename T,
         typename SimilarityType,
         typename ProcessFunction,
         typename SkipFunction>
void all_pairs_similarity(
    std::shared_ptr<sarray<std::vector<std::pair<size_t, T> > > > reference_data,
    std::shared_ptr<sarray<std::vector<std::pair<size_t, T> > > > query_data,
    const SimilarityType& similarity,
    ProcessFunction&& process_function,
    size_t max_memory_usage,
    SkipFunction&& skip_pair,
    const dense_bitset* query_mask = nullptr) {

  std::vector<item_processing_info<SimilarityType> > reference_item_info;
  std::vector<item_processing_info<SimilarityType> > query_item_info;

  size_t reference_num_users = 0;
  size_t query_num_users = 0;

  reference_num_users = calculate_item_processing_rowwise(
      reference_item_info, similarity, reference_data);

  if(query_data.get() != reference_data.get()) {
    query_num_users = calculate_item_processing_rowwise(query_item_info, similarity, query_data);
  }

  size_t num_dimensions = std::max(reference_num_users, query_num_users);

  // Now, we have everything we need to use the above function.
  brute_force_all_pairs_similarity_with_vector_reference(
      reference_data,
      reference_item_info,
      query_data,
      (query_data.get() == reference_data.get()) ? reference_item_info : query_item_info,
      similarity,

      // The similarity value passed to the process function has to be
      // translated into a score, so do that here.
      [&](size_t i, size_t j, const typename SimilarityType::final_interaction_data_type& v)
          GL_GCC_ONLY(GL_HOT_INLINE_FLATTEN)
      {
        process_function(i, j, similarity.export_similarity_score(v));
      },
      num_dimensions,
      max_memory_usage,
      skip_pair,
      query_mask);
}

}}

#endif