Turi Create  4.0
fast_top_k.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_FAST_TOP_K_H_
7 #define TURI_FAST_TOP_K_H_
8 
#include <algorithm>
#include <array>
#include <functional>
#include <utility>
#include <vector>
12 
13 namespace turi {
14 
15 template <typename T, typename LessThan>
16 GL_HOT_NOINLINE_FLATTEN
17 void __run_top_k_small_k(std::vector<T>& v, LessThan less_than, size_t k) {
18 
19  std::sort(v.begin(), v.begin() + k, less_than);
20 
21  for(size_t i = k; i < v.size(); ++i) {
22  if(less_than(v[0], v[i])) {
23 
24 #ifndef NDEBUG
25  // Preserve all the elements so the debug routines below can check things.
26  std::swap(v[0], v[i]);
27 #else
28  // Just do an assignment.
29  v[0] = v[i];
30 #endif
31 
32  for(size_t j = 1; j < k; ++j) {
33  if(!less_than(v[j-1], v[j])) {
34  std::swap(v[j], v[j-1]);
35  } else {
36  break;
37  }
38  }
39  }
40  }
41 
42 #ifndef NDEBUG
43 
44  // Run checking code here to make sure this is equivalent to
45  // nth_element + sort.
46  std::vector<T> va;
47  va.assign(v.begin(), v.begin() + k);
48 
49  auto gt_sorter = [&](const T& t1, const T& t2) {
50  return less_than(t2, t1);
51  };
52 
53  std::nth_element(v.begin(), v.begin() + k, v.end(), gt_sorter);
54 
55  std::sort(v.begin(), v.begin() + k, gt_sorter);
56  for(size_t j = 0; j < k; ++j) {
57  // test for equality using the less_than operator
58  ASSERT_TRUE(!less_than(v[j], va[k - 1 - j]) && !less_than(va[k - 1 - j], v[j]));
59  }
60 
61  for(size_t i = k; i < v.size(); ++i) {
62  for(size_t j = 0; j < k; ++j) {
63  ASSERT_TRUE(bool(!less_than(v[j], v[i])));
64  }
65  }
66 
67  // Copy them back in sorted decreasing order.
68  for(size_t i = 0; i < k; ++i) {
69  v[k - 1 - i] = va[i];
70  }
71 #else
72  std::reverse(v.begin(), v.begin() + k);
73 #endif
74 
75  DASSERT_TRUE(bool(std::is_sorted(v.begin(), v.begin() + k, gt_sorter)));
76 
77  v.resize(k);
78 }
79 
80 /**
81  * \ingroup util
82  * Goes through and extracts the top k out of all the elements in v,
83  * then resizes v to be of size top_k. After running this, the
84  * elements of v are in sorted descending order.
85  */
86 template <typename T, typename LessThan>
88  std::vector<T>& v, size_t top_k, LessThan less_than) {
89 
90  auto gt_sorter = [&](const T& t1, const T& t2) {
91  return less_than(t2, t1);
92  };
93 
94  if(v.size() <= top_k) {
95  std::sort(v.begin(), v.end(), gt_sorter);
96  return;
97  }
98 
99  if(top_k <= 10) {
100  __run_top_k_small_k(v, less_than, top_k);
101  return;
102  }
103 
104  std::nth_element(v.begin(), v.begin() + top_k, v.end(), gt_sorter);
105  v.resize(top_k);
106  std::sort(v.begin(), v.end(), gt_sorter);
107  return;
108 }
109 
110 /**
111  * \ingroup util
112  * Goes through and extracts the top k out of all the elements in v,
113  * then resizes v to be of size top_k. After running this, the
114  * elements of v are in sorted descending order.
115  */
116 template <typename T>
118  std::vector<T>& v, size_t top_k) {
119  extract_and_sort_top_k(v, top_k, std::less<T>());
120 }
121 
122 
123 }
124 
#endif /* TURI_FAST_TOP_K_H_ */
void extract_and_sort_top_k(std::vector< T > &v, size_t top_k, LessThan less_than)
Definition: fast_top_k.hpp:87
std::shared_ptr< sframe > sort(std::shared_ptr< planner_node > sframe_planner_node, const std::vector< std::string > column_names, const std::vector< size_t > &sort_column_indices, const std::vector< bool > &sort_orders)
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
#define DASSERT_TRUE(cond)
Definition: assertions.hpp:364