// Include guard and project includes for the sparse-similarity utilities,
// followed by the start of find_slice_boundary_indices (the return type is the
// last thing on the next line).
//
// find_slice_boundary_indices: given a vector `v` of (index, value) pairs and a
// half-open index range [item_index_lb, item_index_ub), returns the pair
// {list_idx_start, list_idx_end} of positions in `v` such that the elements in
// [list_idx_start, list_idx_end) are the ones whose `.first` lies in that range.
//
// Parameters (the first one, `v`, is declared on a line not present in this
// listing):
//   v              - (index, value) pairs; must already be ordered by `.first`.
//   item_index_lb  - inclusive lower bound on `.first`.
//   item_index_ub  - exclusive upper bound on `.first`.
//
// NOTE(review): v.front() and v.back() are read below with no empty-check
// visible in this listing. Unless one exists in the omitted lines, callers
// must not pass an empty vector -- confirm.
6 #ifndef TURI_SPARSE_SIMILARITY_UTILITIES_H 7 #define TURI_SPARSE_SIMILARITY_UTILITIES_H 9 #include <core/storage/sframe_data/sarray.hpp> 11 #include <core/parallel/pthread_tools.hpp> 12 #include <core/util/try_finally.hpp> 13 #include <core/logging/table_printer/table_printer.hpp> 14 #include <core/storage/sframe_data/sarray_iterators.hpp> 25 std::pair<size_t, size_t>
27 size_t item_index_lb,
size_t item_index_ub) {
// Comparator that orders pairs by index only; the value part is ignored so a
// default-constructed T() can be used as a search probe below.
31 auto idx_cmp_f = [](
const std::pair<size_t, T>& p1,
const std::pair<size_t, T>& p2) {
32 return p1.first < p2.first;
// Precondition check: the binary searches below require `v` ordered by index.
36 DASSERT_TRUE(std::is_sorted(v.begin(), v.end(), idx_cmp_f));
// Fast path test: every element of `v` already lies inside [lb, ub).
// (The body of this branch is not present in this listing.)
43 if(item_index_lb <= v.front().first && v.back().first < item_index_ub) {
// Default to the full range; each end is narrowed only if it falls outside.
49 size_t list_idx_start = 0;
50 size_t list_idx_end = v.size();
// Some leading elements are below the lower bound: locate the first element
// whose index is >= item_index_lb.
54 if(!( item_index_lb <= v.front().first) ) {
55 auto lb_it = std::lower_bound(v.begin(), v.end(),
56 std::make_pair(item_index_lb, T()), idx_cmp_f);
57 list_idx_start = std::distance(v.begin(), lb_it);
// Some trailing elements are at or above the upper bound: locate the first
// element whose index is >= item_index_ub, searching only from
// list_idx_start onward.
62 if(!(v.back().first < item_index_ub)) {
63 auto ub_it = std::lower_bound(v.begin() + list_idx_start,
65 std::make_pair(item_index_ub, T()), idx_cmp_f);
66 list_idx_end = std::distance(v.begin(), ub_it);
// Sanity checks on a non-empty result: both the first and the last selected
// element must lie inside [item_index_lb, item_index_ub).
69 if(list_idx_start != list_idx_end) {
71 DASSERT_LT(v[list_idx_end-1].first, item_index_ub);
72 DASSERT_GE(v[list_idx_end-1].first, item_index_lb);
73 DASSERT_LT(v[list_idx_start].first, item_index_ub);
74 DASSERT_GE(v[list_idx_start].first, item_index_lb);
// Half-open position range [list_idx_start, list_idx_end) into `v`.
77 return {list_idx_start, list_idx_end};
// iterate_through_sparse_item_array_by_slice: walks the sparse rows of `data`
// once per slice, where slice k covers the item-index range
// [slice_delimiters[k], slice_delimiters[k + 1]).
//
// For every slice, in order:
//   1. init_slice(slice_idx, item_idx_start, item_idx_end) is called.
//   2. Inside in_parallel, each row read from `data` is passed to
//      preprocess_row(thread_idx, row_idx, item_idx_start, item_idx_end, row),
//      which receives the row by non-const reference.
//   3. process_element(thread_idx, row_idx, item_idx_start, item_idx_end,
//      item_index, value) is called for each element of the row whose item
//      index lies inside the slice.
//   4. finalize_slice(slice_idx, item_idx_start, item_idx_end) is called.
//
// Cancellation is polled repeatedly; when requested, log_and_throw is called
// with "Cancelled by user.".
141 template <
typename T,
142 typename SliceInitFunction,
143 typename RowProcessFunction,
144 typename ElementProcessFunction,
145 typename SliceFinalizeFunction>
147 const std::shared_ptr<
sarray<std::vector<std::pair<size_t, T> > > >& data,
148 const std::vector<size_t>& slice_delimiters,
149 SliceInitFunction&& init_slice,
150 RowProcessFunction&& preprocess_row,
151 ElementProcessFunction&& process_element,
152 SliceFinalizeFunction&& finalize_slice) {
// Number of rows in the input; not referenced again in the lines shown here.
154 const size_t n = data->size();
// Flag recording that cancellation has been triggered. It is set in
// execute_user_cancelation and read inside the in_parallel workers below.
// NOTE(review): `volatile` gives no inter-thread synchronization guarantees in
// C++; std::atomic<bool> is the standard tool for a flag shared like this --
// confirm and consider switching.
159 volatile bool user_cancelation =
false;
// Cold path: record the cancellation, then raise via log_and_throw.
161 auto execute_user_cancelation = [&]() GL_GCC_ONLY(GL_COLD_NOINLINE) {
162 user_cancelation =
true;
163 log_and_throw(
"Cancelled by user.");
// Hot path: cancels if the flag is already set or cppipc::must_cancel()
// reports a pending request.
// (The identifier is spelled "cancelatation" consistently throughout.)
166 auto check_user_cancelatation = [&]() GL_GCC_ONLY(
GL_HOT_INLINE) {
167 if(user_cancelation || cppipc::must_cancel()) {
168 execute_user_cancelation();
// One pass over the data per slice.
// NOTE(review): slice_delimiters.size() is unsigned, so `size() - 1` wraps to
// a huge value when slice_delimiters is empty -- callers presumably always
// pass at least two delimiters; confirm.
172 for(
size_t slice_idx = 0; slice_idx < slice_delimiters.size() - 1; ++slice_idx) {
173 check_user_cancelatation();
// Item-index range [item_idx_start, item_idx_end) covered by this slice.
176 size_t item_idx_start = slice_delimiters[slice_idx];
177 size_t item_idx_end = slice_delimiters[slice_idx + 1];
180 init_slice(slice_idx, item_idx_start, item_idx_end);
183 check_user_cancelatation();
// Row processing for this slice runs inside in_parallel.
188 in_parallel([&](
size_t thread_idx,
size_t num_threads) GL_GCC_ONLY(GL_HOT) {
// Per-thread buffer that read_next fills with a block of rows.
191 std::vector<std::vector< std::pair<size_t, T> > > item_buffer_v;
195 check_user_cancelatation();
// `data_it` is declared on a line not present in this listing (presumably a
// block iterator over `data` -- confirm). read_next receives the index of the
// first row in the block and the buffer to fill; the meaning of its return
// value and the body of this branch are not visible here.
197 size_t block_row_index_start = 0;
198 if(data_it.read_next(&block_row_index_start, &item_buffer_v) ) {
202 size_t n_rows_read = item_buffer_v.size();
// Walk the rows of the block, stopping early once cancellation is flagged.
204 for(
size_t inner_idx = 0; inner_idx < n_rows_read && !user_cancelation; ++inner_idx) {
207 check_user_cancelatation();
// Absolute row index = first row of the block + offset within the block.
209 size_t row_idx = block_row_index_start + inner_idx;
210 auto& item_list_nonconst = item_buffer_v[inner_idx];
// The row is handed over by non-const reference before any elements are
// visited.
213 preprocess_row(thread_idx, row_idx, item_idx_start, item_idx_end, item_list_nonconst);
// Rows that are empty after preprocessing get special handling (the branch
// body is not present in this listing; presumably the row is skipped).
216 if(item_list_nonconst.empty()) {
221 check_user_cancelatation();
// From here on the row is only read.
223 const auto& item_list = item_list_nonconst;
// Position range within the row whose item indices fall inside this slice.
// The callee name on the assignment below is not present in this listing
// (presumably find_slice_boundary_indices -- confirm).
225 size_t list_idx_start = 0;
226 size_t list_idx_end = item_list.size();
229 std::tie(list_idx_start, list_idx_end)
231 item_list, item_idx_start, item_idx_end);
// No element of this row belongs to the slice (branch body not shown).
234 if(UNLIKELY(list_idx_start == list_idx_end))
// Visit each in-slice element; the assertions re-check the slice bounds.
237 for(
size_t idx_a = list_idx_start; idx_a < list_idx_end; ++idx_a) {
239 DASSERT_GE(item_list[idx_a].first, item_idx_start);
240 DASSERT_LT(item_list[idx_a].first, item_idx_end);
242 process_element(thread_idx, row_idx,
243 item_idx_start, item_idx_end,
244 item_list[idx_a].first, item_list[idx_a].second);
// Last cancellation poll before the slice is finalized.
252 check_user_cancelatation();
254 finalize_slice(slice_idx, item_idx_start, item_idx_end);
// iterate_through_sparse_item_array: row-only convenience wrapper. It passes a
// single slice spanning item indices [0, numeric_limits<size_t>::max()) plus
// adapter lambdas to a slice-based iterator, so that the caller's
// process_row(thread_idx, row_idx, row) is invoked for each row.
//
// The head of the forwarded call is on a line not present in this listing
// (presumably iterate_through_sparse_item_array_by_slice -- confirm).
278 template <
typename T,
typename RowProcessFunction>
280 const std::shared_ptr<
sarray<std::vector<std::pair<size_t, T> > > >& data,
281 RowProcessFunction&& process_row) {
// Slice callback; passed below as both the first and the last callback
// argument. Its body is not present in this listing (presumably a no-op,
// per its name).
288 auto empty_slice_function = [&](
size_t slice_idx,
size_t item_idx_start,
size_t item_idx_end)
// Adapter with the five-argument row-callback shape; it drops the slice
// bounds and forwards the rest to the caller's process_row.
293 auto _process_row = [&](
size_t thread_idx,
size_t row_idx,
294 size_t slice_item_idx_start,
size_t slice_item_idx_end,
295 std::vector<std::pair<size_t, T> >& row)
299 process_row(thread_idx, row_idx, row);
// Per-element callback; its body is not present in this listing (presumably
// a no-op, per its name).
305 auto empty_process_element = [&](
size_t thread_idx,
size_t row_idx,
306 size_t item_idx,
size_t item_idx_end,
307 size_t item_idx,
const T& value)
// Slice delimiters {0, max}: one slice covering the whole index range.
314 {0, std::numeric_limits<size_t>::max()},
315 empty_slice_function,
317 empty_process_element,
318 empty_slice_function);
// transpose_sparse_sarray: builds the transpose of a sparse row array. The
// input rows hold (item_index, value) pairs; each output row corresponds to an
// item index and holds (row_index, value) pairs ordered by row index.
//
// Parameters:
//   data             - input sparse rows.
//   item_counts      - item_counts[i] is used as the number of entries that
//                      item i receives; it sizes the buffer regions below.
//   max_memory_usage - budget compared against sizeof(pair) * count when
//                      choosing slice boundaries.
//
// Items are processed in slices. For each slice the entries are scattered into
// one shared buffer, each item's region is sorted, and the regions are written
// to the output in item order.
//
// Calls log_and_throw if more than 256 slice delimiters would be needed.
339 template <
typename T>
341 std::shared_ptr<
sarray<std::vector<std::pair<size_t, T> > > > data,
342 const std::vector<size_t>& item_counts,
343 size_t max_memory_usage) {
345 size_t num_items = item_counts.size();
// Slice boundaries over the item indices; the first slice starts at item 0.
348 std::vector<size_t> slice_delimiters = {0};
// Running size of the slice currently being accumulated.
350 size_t mem_usage = 0;
// Start a new slice at item i whenever adding item i would push the running
// size past max_memory_usage.
// NOTE(review): the reset of mem_usage at a slice boundary is not visible in
// this listing -- presumably it happens on an omitted line; confirm.
352 for(
size_t i = 0; i < item_counts.size(); ++i) {
353 size_t this_mem_usage =
sizeof(std::pair<size_t, T>) * item_counts[i];
355 if(mem_usage + this_mem_usage > max_memory_usage) {
357 slice_delimiters.push_back(i);
360 mem_usage += this_mem_usage;
// Too many slices means too many passes over the data; give up instead.
363 if(slice_delimiters.size() > 256) {
364 log_and_throw( ( std::string(
"Memory limit of ")
365 + std::to_string(max_memory_usage)
366 +
" too low to efficiently transpose sparse sarray.").c_str() );
// Close the last slice at the end of the item range.
370 slice_delimiters.push_back(num_items);
373 const size_t n = data->size();
374 const size_t num_slices = slice_delimiters.size() - 1;
// Per-slice working state, shared by the callbacks defined below:
//   row_locations[i] - start offset in slice_t_data of the i-th item of the
//                      slice (one extra trailing entry holds the total).
//   row_sizes[i]     - atomic count of entries written so far for that item.
//   slice_t_data     - flat buffer holding every (row_idx, value) of the slice.
376 std::vector<size_t> row_locations;
377 std::vector<atomic<size_t> > row_sizes;
378 std::vector<std::pair<size_t, T> > slice_t_data;
// NOTE(review): reserve() takes an element count, whereas max_memory_usage is
// compared above against sizes computed with sizeof(std::pair<size_t, T>).
// As written this looks like it reserves sizeof(pair) times more than the
// budget -- confirm whether a division by the element size was intended.
383 slice_t_data.reserve(max_memory_usage);
// NOTE(review): the column title below has an unbalanced "))"; it is a
// runtime string and is left unchanged here.
386 logprogress_stream <<
"Transposing user-item lists for use in nearest neighbor search. " 389 table_printer table( { {
"Elapsed Time (Data Transposition))", 0}, {
"% Complete", 0} } );
// Progress accounting: the row callback below increments row_count once per
// row per slice pass, so the expected total is n * num_slices.
392 atomic<size_t> row_count = 0;
393 size_t total_rows_to_process = n * num_slices;
// Slice setup: compute the offset of every item's region in slice_t_data as a
// running sum of item_counts, zero the fill counters, and size the buffer.
398 auto init_slice = [&](
size_t slice_idx,
size_t item_idx_start,
size_t item_idx_end) {
399 size_t num_items_in_slice = item_idx_end - item_idx_start;
401 row_locations.resize(num_items_in_slice + 1);
403 size_t item_cumsum = 0;
404 for(
size_t i = 0; i < num_items_in_slice; ++i) {
405 row_locations[i] = item_cumsum;
406 item_cumsum += item_counts[i + item_idx_start];
409 row_sizes.resize(num_items_in_slice);
410 std::fill(row_sizes.begin(), row_sizes.end(), 0);
412 slice_t_data.resize(item_cumsum);
// Trailing entry, so that [row_locations[i], row_locations[i + 1]) is always
// a valid region for item i.
413 row_locations[num_items_in_slice] = item_cumsum;
// Row callback. Despite its name it is not empty: it leaves the row untouched
// but counts it and periodically reports progress.
417 auto empty_preprocess_row = [&](
size_t thread_idx,
size_t row_idx,
418 size_t slice_item_idx_start,
size_t slice_item_idx_end,
419 std::vector<std::pair<size_t, T> >& row)
// Atomic pre-increment minus one yields this row's unique 0-based ticket.
421 size_t cur_row_count = (++row_count) - 1;
// Report every 100th row. The integer division by total_rows_to_process
// followed by /4 gives a percentage in steps of 0.25.
423 if(UNLIKELY(cur_row_count % 100 == 0)) {
424 double percent_complete = double((400 * cur_row_count) / total_rows_to_process) / 4;
426 table.print_timed_progress_row(
progress_time(), percent_complete);
// Element callback: scatter one (row_idx, value) entry into the region that
// belongs to its item.
435 auto process_element = [&](
size_t thread_idx,
size_t row_idx,
436 size_t item_idx_start,
size_t item_idx_end,
437 size_t item_idx,
const T& value) {
// Position of the item within the current slice.
439 size_t internal_idx = item_idx - item_idx_start;
// Atomically claim the next free slot in this item's region.
440 size_t item_count_idx = (++ (row_sizes[internal_idx]) ) - 1;
// The claimed slot must stay within the count declared for the item.
442 DASSERT_LT(item_count_idx, item_counts[item_idx]);
444 slice_t_data[row_locations[internal_idx] + item_count_idx] = {row_idx, value};
// Output array, opened for writing with one segment; all rows go through the
// output iterator of segment 0.
449 auto out_data = std::make_shared<sarray<std::vector<std::pair<size_t, T> > > >();
450 out_data->open_for_write(1);
451 auto it_out = out_data->get_output_iterator(0);
// Slice teardown: order each item's region by row index and write the regions
// out in item order.
453 auto finalize_slice = [&](
size_t slice_idx,
size_t item_idx_start,
size_t item_idx_end) {
454 size_t num_items_in_slice = item_idx_end - item_idx_start;
456 DASSERT_EQ(row_locations.size(), num_items_in_slice + 1);
// process_idx - atomic ticket for the next item region to sort.
// write_idx   - next item region to be written out.
458 atomic<size_t> process_idx = 0;
461 size_t write_idx = 0;
// Copies the region of item write_idx into row_out and advances the output
// iterator and write_idx. (The enclosing helper and the actual write to
// it_out are on lines not present in this listing.)
464 std::vector<std::pair<size_t, T> > row_out;
466 DASSERT_LT(write_idx, num_items_in_slice);
469 row_out.assign(slice_t_data.begin() + row_locations[write_idx],
470 slice_t_data.begin() + row_locations[write_idx + 1]);
473 ++it_out, ++write_idx;
// Item regions are handed out through process_idx and sorted inside
// in_parallel.
477 in_parallel([&](
size_t thread_idx,
size_t num_threads) {
// Thread 0 additionally writes out regions, in order, while the next one is
// marked finished. `idx_is_finished` is declared on a line not present in
// this listing (presumably a per-item completion bitset -- confirm).
481 if(thread_idx == 0) {
486 while(write_idx < num_items_in_slice && idx_is_finished.
get(write_idx)) {
// Claim the next unsorted item region.
491 size_t idx = (++process_idx) - 1;
// Ticket past the last item of the slice (branch body not shown).
493 if(idx >= num_items_in_slice) {
// Order the claimed region by row index (the first member of each pair).
497 std::sort(slice_t_data.begin() + row_locations[idx],
498 slice_t_data.begin() + row_locations[idx + 1],
499 [](
const std::pair<size_t, T>& p1,
500 const std::pair<size_t, T>& p2) {
501 return p1.first < p2.first;
// Write out whatever regions are still pending (loop body not shown).
510 while(write_idx < num_items_in_slice) {
// The callbacks above are handed to a slice iterator whose call head is on a
// line not present in this listing (presumably
// iterate_through_sparse_item_array_by_slice -- confirm).
521 empty_preprocess_row,
// Close the progress table.
528 table.print_footer();
GL_HOT_INLINE std::pair< size_t, size_t > find_slice_boundary_indices(const std::vector< std::pair< size_t, T > > &v, size_t item_index_lb, size_t item_index_ub)
std::shared_ptr< sframe > sort(std::shared_ptr< planner_node > sframe_planner_node, const std::vector< std::string > column_names, const std::vector< size_t > &sort_column_indices, const std::vector< bool > &sort_orders)
std::shared_ptr< sarray< std::vector< std::pair< size_t, T > > > > transpose_sparse_sarray(std::shared_ptr< sarray< std::vector< std::pair< size_t, T > > > > data, const std::vector< size_t > &item_counts, size_t max_memory_usage)
void iterate_through_sparse_item_array_by_slice(const std::shared_ptr< sarray< std::vector< std::pair< size_t, T > > > > &data, const std::vector< size_t > &slice_delimiters, SliceInitFunction &&init_slice, RowProcessFunction &&preprocess_row, ElementProcessFunction &&process_element, SliceFinalizeFunction &&finalize_slice)
bool get(size_t b) const
Returns the value of the bit b.
#define logprogress_stream
void iterate_through_sparse_item_array(const std::shared_ptr< sarray< std::vector< std::pair< size_t, T > > > > &data, RowProcessFunction &&process_row)
void in_parallel(const std::function< void(size_t thread_id, size_t num_threads)> &fn)
void print_header() const
bool set_bit(size_t b)
Atomically sets the bit at position b to true returning the old value.
sarray_block_iterator< T > make_sarray_block_iterator(const std::shared_ptr< sarray< T > > &data)
#define DASSERT_TRUE(cond)