Turi Create  4.0
column_statistics.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_ML2_COLUMN_STATISTICS_H_
7 #define TURI_ML2_COLUMN_STATISTICS_H_
8 
9 #include <core/data/flexible_type/flexible_type.hpp>
10 #include <core/logging/assertions.hpp>
11 #include <core/storage/serialization/serialization_includes.hpp>
12 #include <toolkits/ml_data_2/ml_data_column_modes.hpp>
13 #include <model_server/lib/variant.hpp>
14 
15 namespace turi { namespace v2 { namespace ml_data_internal {
16 
17 /** Uses the factory model for saving and loading.
18  */
20 
21 public:
22 
23  column_statistics() = default;
24 
25  virtual ~column_statistics() = default;
26 
27  /**
28  * Equality testing in subclass -- slow! Use for
29  * debugging/testing. Upcast this to superclass to do full testing.
30  */
31  virtual bool is_equal(const column_statistics* other) const = 0;
32 
33  /**
34  * Equality testing -- slow! Use for debugging/testing
35  */
36  bool operator==(const column_statistics& other) const;
37 
38  /**
39  * Inequality testing -- slow! Use for debugging/testing
40  */
41  bool operator!=(const column_statistics& other) const;
42 
43  ////////////////////////////////////////////////////////////
44  // Functions to access the statistics
45 
46  /** Returns the number of seen by the methods collecting the
47  * statistics.
48  */
49  virtual size_t num_observations() const { return size_t(-1); }
50 
51  /* The count; index here is the index obtained by one of the
52  * map_value_to_index functions previously.
53  */
54  virtual size_t count(size_t index) const { return size_t(-1); }
55 
56  /* The mean; index here is the index obtained by one of the
57  * map_value_to_index functions previously.
58  */
59  virtual double mean(size_t index) const { return NAN; }
60 
61  /* The variance; index here is the index obtained by one of the
62  * map_value_to_index functions previously.
63  */
64  virtual double stdev(size_t index) const { return NAN; }
65 
66  /* The variance; index here is the index obtained by one of the
67  * map_value_to_index functions previously.
68  */
69  virtual size_t n_positive(size_t index) const { return size_t(-1); }
70 
71 
72  ////////////////////////////////////////////////////////////
73  // Routines for updating the statistics. This is done online, while
74  // new categories are being added, etc., so we have to be
75 
76  /// Initialize the statistics -- counting, mean, and stdev
77  virtual void initialize() = 0;
78 
79  /// Update categorical statistics for a batch of categorical indices.
80  virtual void update_categorical_statistics(
81  size_t thread_idx, const std::vector<size_t>& cat_index_vect) = 0;
82 
83  /// Update categorical statistics for a batch of real values.
84  virtual void update_numeric_statistics(
85  size_t thread_idx, const std::vector<double>& value_vect) = 0;
86 
87  /// Update statistics after observing a dictionary.
88  virtual void update_dict_statistics(
89  size_t thread_idx, const std::vector<std::pair<size_t, double> >& dict) = 0;
90 
91  /** Perform final computations on the different statistics. Called
92  * after all the data is filled.
93  */
94  virtual void finalize() = 0;
95 
96  ////////////////////////////////////////////////////////////////////////////////
97  // Methods for creation and serialization
98 
99  /** Returns the current version used for the serialization.
100  */
101  virtual size_t get_version() const = 0;
102 
103  /** Serialize the object (save).
104  */
105  virtual void save_impl(turi::oarchive& oarc) const = 0;
106 
107  /** Load the object.
108  */
109  virtual void load_version(turi::iarchive& iarc, size_t version) = 0;
110 
111  /** The factory method for loading and instantiating the proper class
112  */
113  static std::shared_ptr<column_statistics> factory_create(
114  const std::map<std::string, variant_type>& creation_options);
115 
116  const std::map<std::string, variant_type>& get_serialization_parameters() const {
117  return creation_options;
118  }
119 
120  /** One way to set the statistics. Used by the serialization converters.
121  */
122  virtual void set_data(const std::map<std::string, variant_type>& params) {}
123 
124  /** Create a copy with the index cleared.
125  */
126  virtual std::shared_ptr<column_statistics> create_cleared_copy() const = 0;
127 
128  private:
129 
130  /** A snapshot of the options needed for creating the class.
131  */
132  std::map<std::string, variant_type> creation_options;
133 
134  protected:
135 
136  // Store the basic column data. This allows us to do error checking
137  // and error reporting intelligently.
138 
139  std::string column_name;
140  ml_column_mode mode;
141  flex_type_enum original_column_type;
142  std::map<std::string, flexible_type> options;
143 };
144 
145 }}}
146 
147 ////////////////////////////////////////////////////////////////////////////////
148 // Implement serialization for vector<std::shared_ptr<column_statistics>
149 // > and std::shared_ptr<column_statistics>. A leading bool records whether
150 // the pointer is non-null; for a null pointer nothing else is written.
151 BEGIN_OUT_OF_PLACE_SAVE(arc, std::shared_ptr<v2::ml_data_internal::column_statistics>, m) {
152  if(m == nullptr) {
153  arc << false;
154  } else {
155  arc << true;
156 
157  // Save the version number first, as a standalone value.
158  size_t version = m->get_version();
159  arc << version;
160 
161  // Copy the creation options into a local map so the version can be added to the copy.
162  std::map<std::string, variant_type> serialization_parameters =
163  m->get_serialization_parameters();
164 
165  // Save the version along with the creation options (under the "version" key).
166  serialization_parameters["version"] = to_variant(m->get_version());
167 
168  variant_deep_save(serialization_parameters, arc);
169  // Finally, let the concrete class write its own data.
170  m->save_impl(arc);
171  }
172 
173 } END_OUT_OF_PLACE_SAVE()
174 
175 // Out-of-place load for std::shared_ptr<column_statistics>: reads the null flag, then rebuilds the object.
176 BEGIN_OUT_OF_PLACE_LOAD(arc, std::shared_ptr<v2::ml_data_internal::column_statistics>, m) {
177  bool is_not_nullptr;
178  arc >> is_not_nullptr;
179  if(is_not_nullptr) {
180  // Read the fields back in the order the matching save wrote them.
181  size_t version;
182  arc >> version;
183 
184  std::map<std::string, variant_type> creation_options;
185  variant_deep_load(creation_options, arc);
186  // Instantiate the proper class from its creation options before it loads its own data.
187  m = v2::ml_data_internal::column_statistics::factory_create(creation_options);
188 
189  m->load_version(arc, version);
190 
191  } else {
192  m = std::shared_ptr<v2::ml_data_internal::column_statistics>(nullptr);
193  }
194 } END_OUT_OF_PLACE_LOAD()
195 
196 
197 #endif
#define BEGIN_OUT_OF_PLACE_LOAD(arc, tname, tval)
Macro to make it easy to define out-of-place loads.
Definition: iarchive.hpp:314
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
virtual void update_dict_statistics(size_t thread_idx, const std::vector< std::pair< size_t, double > > &dict)=0
Update statistics after observing a dictionary.
void variant_deep_load(variant_type &v, iarchive &iarc)
virtual void save_impl(turi::oarchive &oarc) const =0
virtual void initialize()=0
Initialize the statistics – counting, mean, and stdev.
void variant_deep_save(const variant_type &v, oarchive &oarc)
virtual void load_version(turi::iarchive &iarc, size_t version)=0
static std::shared_ptr< column_statistics > factory_create(const std::map< std::string, variant_type > &creation_options)
virtual std::shared_ptr< column_statistics > create_cleared_copy() const =0
bool operator!=(const column_statistics &other) const
virtual void update_categorical_statistics(size_t thread_idx, const std::vector< size_t > &cat_index_vect)=0
Update categorical statistics for a batch of categorical indices.
virtual void update_numeric_statistics(size_t thread_idx, const std::vector< double > &value_vect)=0
Update numeric statistics for a batch of real values.
virtual void set_data(const std::map< std::string, variant_type > &params)
virtual bool is_equal(const column_statistics *other) const =0
variant_type to_variant(const T &f)
Definition: variant.hpp:308
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80
bool operator==(const column_statistics &other) const
#define BEGIN_OUT_OF_PLACE_SAVE(arc, tname, tval)
Macro to make it easy to define out-of-place saves.
Definition: oarchive.hpp:346