6 #ifndef TURI_SPARSE_SIMILARITY_NEIGHBOR_SEARCH_H 7 #define TURI_SPARSE_SIMILARITY_NEIGHBOR_SEARCH_H 9 #include <core/storage/sframe_data/sarray.hpp> 11 #include <core/parallel/pthread_tools.hpp> 12 #include <core/util/try_finally.hpp> 13 #include <core/util/dense_bitset.hpp> 14 #include <toolkits/sparse_similarity/similarities.hpp> 15 #include <toolkits/sparse_similarity/item_processing.hpp> 17 namespace turi {
namespace sparse_sim {
// NOTE(review): this copy of the header has the original file's line numbers
// fused onto the start of many lines, and several lines appear to be missing
// (closing braces, some statement bodies, e.g. after `if(num_query_rows == 0)`).
// It will not compile as-is; compare against the upstream file before use.
//
// Brute-force similarity between every (unmasked) query row and every
// reference row.
//
// Query rows are copied, one block at a time, into a dense buffer
// (`block_data`). All reference rows are then streamed in parallel against
// that block. For each reference row, similarity.update_interaction_unsafe
// accumulates one interaction per query row in the block. Each interaction is
// finalized with similarity.finalize_interaction and reported as
//   process_function(ref_idx, query_index, e_out)
// (reference index FIRST).
//
// Parameters:
//   reference_data / query_data: sarrays of sparse rows; each row is a vector
//     of (dimension index, value) pairs.
//   reference_item_info / query_item_info: one entry per row of the matching
//     sarray (size is DASSERTed). Provides item_data and final_item_data.
//   similarity: the similarity policy; supplies the typedefs and the
//     update/finalize operations used below.
//   process_function: callback invoked as (ref_idx, query_index, e_out).
//   num_dimensions: must be non-zero. Every dimension index in either sarray
//     must be < num_dimensions (DASSERTed in block_data_index).
//   max_memory_usage: byte budget for the dense query block. The number of
//     rows per block is max_memory_usage / (num_dimensions * sizeof(double)),
//     but never fewer than max_num_threads rows.
//   skip_pair: predicate invoked as (query_index, ref_idx) -- note the
//     argument order is the reverse of process_function's.
//     Presumably pairs for which it returns true are skipped; the statement
//     following the check is not visible here.
//   query_mask: optional. When non-null, only query rows whose bit is set are
//     processed. Its size must equal query_data->size().
61 template <
typename SimilarityType,
62 typename ProcessFunction,
63 typename SkipFunction>
64 void brute_force_all_pairs_similarity_with_vector_reference(
65 std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > reference_data,
66 const std::vector<item_processing_info<SimilarityType> >& reference_item_info,
67 std::shared_ptr<sarray<std::vector<std::pair<size_t, double> > > > query_data,
68 const std::vector<item_processing_info<SimilarityType> >& query_item_info,
69 const SimilarityType& similarity,
70 ProcessFunction&& process_function,
71 size_t num_dimensions,
72 size_t max_memory_usage,
73 SkipFunction&& skip_pair,
74 const dense_bitset* query_mask =
nullptr) {
77 typedef typename SimilarityType::item_data_type item_data_type;
78 typedef typename SimilarityType::interaction_data_type interaction_data_type;
79 typedef typename SimilarityType::final_item_data_type final_item_data_type;
80 typedef typename SimilarityType::final_interaction_data_type final_interaction_data_type;
// Final item data is only stored/used when the similarity declares a type
// other than unused_value_type for it.
83 static constexpr
bool use_final_item_data = (
84 !std::is_same<unused_value_type, final_item_data_type>::value);
// When true, absent entries are treated as 0. Otherwise they are stored as NaN
// in block_data and skipped during accumulation (see the NaN check below).
86 static constexpr
bool missing_values_are_zero = SimilarityType::missing_values_are_zero();
// Placeholder bound to q_final_item_data when final item data is not in use.
88 final_item_data_type _unused;
92 size_t num_query_rows = query_data->size();
93 size_t num_reference_rows = reference_data->size();
94 bool using_mask = (query_mask !=
nullptr);
99 DASSERT_NE(num_dimensions, 0);
101 DASSERT_EQ(reference_item_info.size(), reference_data->size());
102 DASSERT_EQ(query_item_info.size(), query_data->size());
// NOTE(review): the next two lines dereference query_mask, but no
// `if(using_mask)` guard is visible in this copy; confirm they are guarded.
105 DASSERT_EQ(query_mask->size(), query_data->size());
// With a mask, only the set bits count as query rows.
112 num_query_rows = query_mask->popcount();
// NOTE(review): the body of this `if` is not visible here; presumably an early
// return when there is nothing to query.
116 if(num_query_rows == 0)
// Rows per block: as many dense rows as fit in max_memory_usage bytes,
// at least one row per thread, and no more than the number of query rows.
120 size_t max_query_rows_per_block = max_memory_usage / (num_dimensions *
sizeof(double));
121 max_query_rows_per_block = std::max(max_num_threads, max_query_rows_per_block);
122 max_query_rows_per_block = std::min(num_query_rows, max_query_rows_per_block);
// The inner dense loop indexes query rows with an `int`; this cap presumably
// keeps that index from overflowing.
127 max_query_rows_per_block = std::min<size_t>(
128 std::numeric_limits<int>::max() / 2, max_query_rows_per_block);
132 size_t num_blocks = (num_query_rows + (max_query_rows_per_block - 1)) / max_query_rows_per_block;
// Re-spread the rows so all blocks have (nearly) equal size.
139 max_query_rows_per_block = (num_query_rows + (num_blocks - 1)) / num_blocks;
142 auto query_reader = query_data->get_reader(max_num_threads);
145 auto reference_reader = reference_data->get_reader(max_num_threads);
// Dense query block, stored dimension-major: the entries for one dimension
// across all rows of the block are contiguous (see block_data_index).
151 std::vector<double> block_data(max_query_rows_per_block * num_dimensions);
152 auto block_data_index = [&](
size_t row_idx,
size_t element_idx) {
153 DASSERT_LT(row_idx, max_query_rows_per_block);
154 DASSERT_LT(element_idx, num_dimensions);
155 return element_idx * max_query_rows_per_block + row_idx;
// Per block slot: the original query row index and the row's item data.
160 std::vector<size_t> block_query_row_indices(max_query_rows_per_block);
163 std::vector<item_data_type> block_item_data(max_query_rows_per_block);
165 std::vector<final_item_data_type> block_final_item_data;
166 if(use_final_item_data) {
167 block_final_item_data.resize(max_query_rows_per_block);
// Per-thread cursor into the query rows. size_t(-1) means "not started yet".
// It lives outside the block loop so each thread resumes where it stopped
// when the previous block filled up.
172 std::vector<size_t> query_row_counters(max_num_threads,
size_t(-1));
175 for(
size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
// Shared slot allocator for filling the current block.
178 atomic<size_t> block_write_idx = 0;
// Reset the block to the "missing" value (0 or NaN, per the similarity).
181 std::fill(block_data.begin(), block_data.end(), missing_values_are_zero ? 0 : NAN);
// Phase 1: threads load their share of the query rows into the block.
184 in_parallel([&](
size_t thread_idx,
size_t num_threads) GL_GCC_ONLY(GL_HOT_NOINLINE_FLATTEN) {
187 size_t query_row_idx_start = (query_data->size() * thread_idx) / num_threads;
188 size_t query_row_idx_end = (query_data->size() * (thread_idx+1)) / num_threads;
192 size_t& current_query_row_index = query_row_counters[thread_idx];
195 if(current_query_row_index ==
size_t(-1)) {
197 current_query_row_index = query_row_idx_start;
201 std::vector<std::vector<std::pair<size_t, double> > > row_v(1);
205 while(current_query_row_index < query_row_idx_end) {
// Skip ahead to the next set mask bit, or to the end of this thread's range.
209 if(using_mask && !query_mask->get(current_query_row_index)) {
210 size_t new_q_idx = current_query_row_index;
211 bool any_more = query_mask->next_bit(new_q_idx);
213 if(UNLIKELY(!any_more || new_q_idx >= query_row_idx_end)) {
215 current_query_row_index = query_row_idx_end;
218 DASSERT_NE(current_query_row_index, new_q_idx);
221 current_query_row_index = new_q_idx;
// Claim a slot in the block. Pre-increment minus one yields the value before
// the increment.
231 size_t internal_block_idx = (++block_write_idx) - 1;
// NOTE(review): the body is not visible; presumably the thread stops filling
// once the block is full and keeps its cursor for the next block.
235 if(internal_block_idx >= max_query_rows_per_block) {
240 DASSERT_LT(current_query_row_index, query_row_idx_end);
244 query_reader->read_rows(current_query_row_index, current_query_row_index + 1, row_v);
245 const auto& row = row_v[0];
248 block_query_row_indices[internal_block_idx] = current_query_row_index;
249 block_item_data[internal_block_idx] = query_item_info[current_query_row_index].item_data;
252 if(use_final_item_data) {
253 block_final_item_data[internal_block_idx]
254 = query_item_info[current_query_row_index].final_item_data;
// Scatter the sparse row into its dense column of the block.
258 for(
size_t i = 0; i < row.size(); ++i) {
259 size_t idx = block_data_index(internal_block_idx, row[i].first);
260 block_data[idx] = row[i].second;
264 ++current_query_row_index;
269 if(block_idx == num_blocks - 1) {
270 DASSERT_EQ(current_query_row_index, query_row_idx_end);
// Sanity checks: all but the last block are full, and no block overflows.
278 size_t _block_write_idx = block_write_idx;
279 if(block_idx < num_blocks - 1) {
280 DASSERT_GE(_block_write_idx, max_query_rows_per_block);
282 DASSERT_LE(_block_write_idx, max_query_rows_per_block);
286 if(block_write_idx < max_query_rows_per_block) {
288 DASSERT_EQ(block_idx, num_blocks - 1);
// block_write_idx can exceed the block size when threads over-claimed slots,
// hence the min.
296 size_t num_query_rows_in_block = std::min<size_t>(block_write_idx, max_query_rows_per_block);
299 DASSERT_GT(num_query_rows_in_block, 0);
306 for(
size_t i = 0; i < num_query_rows_in_block; ++i) {
// NOTE(review): dereferences query_mask; confirm this check is guarded by
// using_mask.
307 DASSERT_TRUE(query_mask->get(block_query_row_indices[i]));
// Phase 2: each thread streams its contiguous share of the reference rows
// against the whole query block.
315 in_parallel([&](
size_t thread_idx,
size_t num_threads) GL_GCC_ONLY(GL_HOT_NOINLINE_FLATTEN) {
317 size_t reference_row_idx_start = (num_reference_rows * thread_idx) / num_threads;
318 size_t reference_row_idx_end = (num_reference_rows * (thread_idx+1)) / num_threads;
// Reference rows are read from the sarray in batches of this many rows.
320 const size_t n_reference_rows_per_block = 16;
321 std::vector<std::vector<std::pair<size_t, double> > > reference_rows_v;
// One accumulated interaction per query row in the block.
323 std::vector<interaction_data_type> edges(num_query_rows_in_block);
326 for(
size_t outer_idx = reference_row_idx_start;
327 outer_idx < reference_row_idx_end;
328 outer_idx += n_reference_rows_per_block) {
332 reference_reader->read_rows(
334 std::min(outer_idx + n_reference_rows_per_block, reference_row_idx_end),
// A short batch is only expected at the end of this thread's range.
337 if(reference_rows_v.size() != n_reference_rows_per_block) {
338 DASSERT_EQ(outer_idx + reference_rows_v.size(), reference_row_idx_end);
342 for(
size_t inner_idx = 0; inner_idx < reference_rows_v.size(); ++inner_idx) {
345 const size_t ref_idx = outer_idx + inner_idx;
346 const auto& row = reference_rows_v[inner_idx];
349 item_data_type ref_item_data = reference_item_info[ref_idx].item_data;
351 const final_item_data_type& ref_final_item_data
352 = reference_item_info[ref_idx].final_item_data;
// Fresh interactions for this reference row.
355 edges.assign(num_query_rows_in_block, interaction_data_type());
// For each non-zero dimension of the reference row, update the interaction
// with every query row in the block.
358 for(
const auto& p : row) {
359 size_t dim_index = p.first;
360 double ref_value = p.second;
// Dense path: missing values are 0, so every query row in the block is
// updated. This dimension's entries are contiguous, so the pointers just
// advance together.
362 if(missing_values_are_zero) {
367 double* __restrict__ bd_ptr = &(block_data[block_data_index(0, dim_index)]);
368 item_data_type* __restrict__ it_data_ptr = block_item_data.data();
369 interaction_data_type* __restrict__ int_data_ptr = edges.data();
371 for(
int i = 0; i < int(num_query_rows_in_block);
372 ++i, ++bd_ptr, ++it_data_ptr, ++int_data_ptr) {
374 similarity.update_interaction_unsafe(
376 ref_item_data, *it_data_ptr,
// Sparse path: NaN marks a query entry that is missing for this dimension.
380 for(
size_t i = 0; i < num_query_rows_in_block; ++i) {
383 double block_data_entry = block_data[block_data_index(i, dim_index)];
385 if(std::isnan(block_data_entry))
389 similarity.update_interaction_unsafe(
391 ref_item_data, block_item_data[i],
392 ref_value, block_data_entry);
// Finalize and report each (reference, query) pair.
398 for(
size_t i = 0; i < num_query_rows_in_block; ++i) {
399 size_t query_index = block_query_row_indices[i];
// Note the argument order: skip_pair takes (query, reference).
401 if(skip_pair(query_index, ref_idx))
405 const auto& q_item_data = block_item_data[i];
406 const auto& q_final_item_data =
407 (use_final_item_data ? block_final_item_data[i] : _unused);
410 final_interaction_data_type e_out = final_interaction_data_type();
412 similarity.finalize_interaction(e_out,
413 ref_final_item_data, q_final_item_data,
415 ref_item_data, q_item_data);
// process_function takes (reference, query).
418 process_function(ref_idx, query_index, e_out);
// Convenience entry point for all-pairs similarity over sparse rows.
//
// Steps:
//   1. Computes per-row item_processing_info for reference_data with
//      calculate_item_processing_rowwise.
//   2. Does the same for query_data, but only when query_data is a different
//      sarray object from reference_data.
//   3. Runs brute_force_all_pairs_similarity_with_vector_reference.
//
// The caller's process_function receives
// similarity.export_similarity_score(v) rather than the raw finalized
// interaction v.
//
// NOTE(review): `T` is used in the parameter types below, but its template
// parameter declaration is not visible in this copy (the `template <` line is
// missing); confirm it is declared.
463 typename SimilarityType,
464 typename ProcessFunction,
465 typename SkipFunction>
466 void all_pairs_similarity(
467 std::shared_ptr<sarray<std::vector<std::pair<size_t, T> > > > reference_data,
468 std::shared_ptr<sarray<std::vector<std::pair<size_t, T> > > > query_data,
469 const SimilarityType& similarity,
470 ProcessFunction&& process_function,
471 size_t max_memory_usage,
472 SkipFunction&& skip_pair,
473 const dense_bitset* query_mask =
nullptr) {
476 std::vector<item_processing_info<SimilarityType> > reference_item_info;
477 std::vector<item_processing_info<SimilarityType> > query_item_info;
479 size_t reference_num_users = 0;
480 size_t query_num_users = 0;
// The returned count is used below as the dimension count for the dense
// query block.
// NOTE(review): presumably it is one past the largest column index seen;
// confirm against calculate_item_processing_rowwise.
482 reference_num_users = calculate_item_processing_rowwise(
483 reference_item_info, similarity, reference_data);
// When query and reference are the same sarray object, the reference info is
// reused and query_num_users stays 0.
485 if(query_data.get() != reference_data.get()) {
486 query_num_users = calculate_item_processing_rowwise(query_item_info, similarity, query_data);
489 size_t num_dimensions = std::max(reference_num_users, query_num_users);
492 brute_force_all_pairs_similarity_with_vector_reference(
496 (query_data.get() == reference_data.get()) ? reference_item_info : query_item_info,
// Adapter: converts each finalized interaction into an exported similarity
// score before forwarding (i, j, score) to the caller.
500 [&](
size_t i,
size_t j,
const typename SimilarityType::final_interaction_data_type& v)
503 process_function(i, j, similarity.export_similarity_score(v));
static size_t cpu_count()
#define GL_HOT_INLINE_FLATTEN
void in_parallel(const std::function< void(size_t thread_id, size_t num_threads)> &fn)
#define DASSERT_TRUE(cond)