Turi Create  4.0
statistics_tracker.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_TOPK_STATISTICS_TRACKER_H_
7 #define TURI_TOPK_STATISTICS_TRACKER_H_
8 
9 #include <core/data/flexible_type/flexible_type.hpp>
10 #include <core/util/hash_value.hpp>
11 #include <core/logging/assertions.hpp>
12 #include <core/util/bitops.hpp>
13 #include <core/storage/serialization/serialization_includes.hpp>
14 #include <core/generics/hopscotch_map.hpp>
15 #include <core/parallel/pthread_tools.hpp>
16 #include <core/export.hpp>
17 
18 namespace turi {
19 
20 /**
21  *
22  * Parallel statistics(mean) tracker
23  *
24  * Note: This implementation is intended to be general and will be moved to some
25  * place more general later.
26  *
27  * Construction
28  * -------------
29  *
30  * // Construct the tracker with the arguments.
31  * auto tracker = statistics_tracker(10, 1, "column_name_for_error_messages");
32  * tracker.initialize();
33  *
34  * // Insert flexible types into the tracker
35  * for (const flexible_type& v: sa.range_iterator() {
36  * tracker.insert_or_update(v);
37  * }
38  *
39  * // Finalize mapping
40  * tracker.finalize();
41  *
42  * Lookups
43  * --------
44  * size_t index = tracker.lookup(v); // Returns (size_t) -1 if not present.
45  *
46  * size_t counts = tracker.lookup_counts(v); // Returns 0 if not present.
47  *
48  * flexible_type v = tracker.inverse_lookup(1) // Fails if index doesn't exist.
49  *
50  * Parallel construction
51  * -----------------------
52  *
53  * // Initialize
54  * tracker.initialize();
55  *
56  * // Perform the indexing.
57  * in_parallel([&](size_t thread_idx, size_t num_threads) {
58  *
59  * size_t start_idx = src_size * thread_idx / num_threads;
60  * size_t end_idx = src_size * (thread_idx + 1) / num_threads;
61  *
62  * // Inserts value of 1 for each key k
63  * for (const flexible_type& k: sa.range_iterator(start_idx, end_idx) {
64  * tracker.insert_or_update(k,1,thread_id);
65  * }
66  *
67  * // Finalize
68  * tracker.finalize();
69  *
70  *
71  */
72 class EXPORT statistics_tracker {
73 
74  public:
75 
76  /**
77  * Default constructor
78  *
79  * \param[in] column_name Column name for display.
80  *
81  */
82  statistics_tracker( const std::string _column_name = "") :
83  column_name(_column_name) {
84  }
85 
86 
87  /**
88  * Copy constructor: Don't want to risk making copies of this.
89  */
90  statistics_tracker(const statistics_tracker&) = delete;
91 
92 
93  /**
94  * Initialize the index mapping and setup. Should be called before
95  * starting the map.
96  */
97  void initialize();
98 
99  /**
100  * Insert
101  *
102  * \param[in] key Flexible type.
103  * \param[in] value Flexible type.
104  * \param[in] thread_idx Thread id (For parallel insertion).
105  *
106  */
107  void insert_or_update(const flexible_type& key, flexible_type value,size_t thread_idx = 0) GL_HOT;
108 
109  /**
110  * Returns the index associated with the value.
111  *
112  * \param[in] value Search for the value.
113  * \returns The index. (Returns size_t(-1) if not present).
114  */
115  size_t lookup(const flexible_type& value) const;
116 
117  /**
118  * Returns the counts associated with the value.
119  *
120  * \param[in] value Search for the value.
121  * \returns Counts (Returns 0 if not present).
122  */
123  size_t lookup_counts(const flexible_type& value) const;
124 
125  /**
126  * Returns the counts associated with the value.
127  *
128  * \param[in] value Search for the value.
129  * \returns Counts (Returns 0 if not present).
130  */
131  flex_float lookup_means(const flexible_type& value) const;
132 
133  /**
134  * Finalize by dropping indices that dont meet
135  * - Count requirement i.e count >= threshold.
136  * - Topk requirement.
137  */
138  void finalize(size_t num_examples);
139 
140  /**
141  * Returns the "value" associated with the index.
142  *
143  * \param[\in] idx Index associated with the feature value.
144  * \return The "value" in the original data associated with the given id.
145  */
146  flexible_type inverse_lookup(size_t idx) const;
147 
148 
149 
150  /** Returns the number of categorical variables.
151  *
152  * \return Column size.
153  */
154  inline size_t size() const {
155  return index_lookup.size();
156  }
157 
158  /** Returns the number of categorical variables.
159  *
160  * \return Column size.
161  */
162  inline std::vector<flexible_type> get_keys() const {
163  return keys;
164  }
165 
166  /**
167  * Returns the current version used for the serialization.
168  */
169  size_t get_version() const;
170 
171  /**
172  * Serialize the object (save).
173  */
174  void save_impl(turi::oarchive& oarc) const;
175 
176  /**
177  * Load the object.
178  */
179  void load_version(turi::iarchive& iarc, size_t version);
180 
181  private:
182 
183  // Private members.
184  std::string column_name = "";
185 
186  // List of Map(hash : (value, count)) per thread.
187  struct threadlocal_accumulator{
188  std::vector<hopscotch_map<hash_value, size_t>> count;
189  std::vector<hopscotch_map<hash_value, flex_float>> mean;
190  std::vector<hopscotch_map<hash_value, size_t>> missing;
191  std::vector<hopscotch_map<hash_value, flexible_type>> key_index;
192  } threadlocal_accumulator;
193  // Index -> value/cound
194  std::vector<size_t> counts;
195  std::vector<flex_float> means;
196  std::vector<size_t> missing;
197  std::vector<flexible_type> keys;
198 
199  // Map(value : index)
201 
202 
203  // Private helper functions.
204  // ------------------------------------------------------------------------
205 
206  /**
207  * Validate feature types.
208  */
209  void valdidate_types(const flexible_type& value) const;
210 };
} // namespace turi
212 
213 
214 // Implement serialization for std::shared_ptr
215 BEGIN_OUT_OF_PLACE_SAVE(arc, std::shared_ptr<statistics_tracker>, m) {
216  if(m == nullptr) {
217  arc << false;
218  } else {
219  arc << true;
220 
221  // Save the version number
222  size_t version = m->get_version();
223  arc << version;
224  // Save the object.
225  m->save_impl(arc);
226  }
227 } END_OUT_OF_PLACE_SAVE()
228 
229 
230 // Implement deserialization
231 BEGIN_OUT_OF_PLACE_LOAD(arc, std::shared_ptr<statistics_tracker>, m) {
232  bool is_not_nullptr;
233  arc >> is_not_nullptr;
234  if(is_not_nullptr) {
235 
236  // Load version
237  size_t version;
238  arc >> version;
239 
240  // Load object.
241  m.reset(new statistics_tracker(""));
242  m->load_version(arc, version);
243 
244  } else {
245  m = std::shared_ptr<statistics_tracker>(nullptr);
246  }
247 } END_OUT_OF_PLACE_LOAD()
248 #endif
#define BEGIN_OUT_OF_PLACE_LOAD(arc, tname, tval)
Macro to make it easy to define out-of-place loads.
Definition: iarchive.hpp:314
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
std::set< Key > keys(const std::map< Key, T > &map)
Definition: stl_util.hpp:358
statistics_tracker(const std::string _column_name="")
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80
#define BEGIN_OUT_OF_PLACE_SAVE(arc, tname, tval)
Macro to make it easy to define out-of-place saves.
Definition: oarchive.hpp:346
std::vector< flexible_type > get_keys() const