Turi Create  4.0
row_reference.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_DML_DATA_ROW_REFERENCE_H_
7 #define TURI_DML_DATA_ROW_REFERENCE_H_
8 
9 #include <core/logging/assertions.hpp>
10 #include <ml/ml_data/data_storage/ml_data_row_format.hpp>
11 #include <ml/ml_data/data_storage/ml_data_block_manager.hpp>
12 #include <ml/ml_data/ml_data.hpp>
13 #include <core/util/code_optimization.hpp>
14 
15 #include <Eigen/SparseCore>
16 #include <Eigen/Core>
17 
18 #include <array>
19 #include <type_traits>
20 
21 namespace turi {
22 
23 typedef Eigen::Matrix<double, Eigen::Dynamic,1> DenseVector;
24 typedef Eigen::SparseVector<double> SparseVector;
25 
26 /**
27  * \ingroup mldata
28  * A class containing a reference to the row of an ml_data instance,
29  * providing access to the underlying data.
30  *
31  * In other words, you can do
32  *
33  * it->fill(x);
34  *
35  * or
36  *
37  * auto row_ref = *it;
38  *
39  * // do stuff ...
40  * row_ref.fill(x);
41  *
42  * The data block pointed to by this reference is kept alive as long
43  * as this reference class exists.
44  *
45  *
46  * Another example of how it is used is below:
47  *
48  * sframe X = make_integer_testing_sframe( {"C1", "C2"}, { {0, 0}, {1, 1}, {2, 2}, {3, 3}, {4, 4} } );
49  *
50  * ml_data data;
51  *
52  * data.fill(X);
53  *
54  * // Get row references
55  *
56  * std::vector<ml_data_row_reference> rows(data.num_rows());
57  *
58  * for(auto it = data.get_iterator(); !it.done(); ++it) {
59  * rows[it.row_index()] = *it;
60  * }
61  *
62  * // Now go through and make sure that each of these hold the
63  * // correct answers.
64  *
65  * std::vector<ml_data_entry> x;
66  *
67  * for(size_t i = 0; i < rows.size(); ++i) {
68  *
69  * // The metadata for the row is the same as that in the data.
70  * ASSERT_TRUE(rows[i].metadata().get() == data.metadata().get());
71  *
72  * rows[i].fill(x);
73  *
74  * ASSERT_EQ(x.size(), 2);
75  *
76  * ASSERT_EQ(x[0].column_index, 0);
77  * ASSERT_EQ(x[0].index, 0);
78  * ASSERT_EQ(x[0].value, i);
79  *
80  * ASSERT_EQ(x[1].column_index, 1);
81  * ASSERT_EQ(x[1].index, 0);
82  * ASSERT_EQ(x[1].value, i);
83  * }
84  * }
85  *
86  */
88  public:
89 
90  /** Create an ml_data_row_reference from a single sframe row reference.
91  *
92  * Row must be in the format {column_name, value} and columns
93  * correspond to the columns in metadata. Missing columns are
94  * treated as missing values.
95  *
96  * Returns a single row reference.
97  */
98  static GL_HOT ml_data_row_reference from_row(
99  const std::shared_ptr<ml_metadata>& metadata, const flex_dict& row,
100  ml_missing_value_action none_action = ml_missing_value_action::USE_NAN);
101 
102  /**
103  * Fill an observation vector, represented as an ml_data_entry
104  * struct. (column_index, index, value) pairs, from this row
105  * reference. For each column:
106  *
107  * Categotical: Returns (col_id, v, 1)
108  * Numeric : Returns (col_id, 0, v)
109  * Vector : Returns (col_id, i, v) for each (i,v) in vector.
110  *
111  * Example use is given by the following code:
112  *
113  * std::vector<ml_data_entry> x;
114  *
115  * row_ref.fill(x);
116  * double y = row_ref.target_value();
117  * ...
118  */
119  template <typename Entry>
121  inline void fill(std::vector<Entry>& x) const;
122 
123  /**
124  * Fill a row of an Eigen expression in the current location in the
125  * iteration.
126  *
127  * Example:
128  *
129  * Eigen::MatrixXd X;
130  *
131  * ...
132  *
133  * it.fill(X.row(row_idx));
134  *
135  * ---------------------------------------------
136  *
137  * \param[in,out] x An eigen row expression.
138  *
139  */
140  template <typename EigenXpr>
142  inline void fill(
143  EigenXpr&& x,
144  typename std::enable_if<std::is_convertible<EigenXpr, DenseVector>::value >::type* = 0) const {
145 
146  fill_eigen(x);
147  }
148 
149  /**
150  * Fill an observation vector with the untranslated columns, if any
151  * have been specified at setup time. These columns are simply
152  * mapped back to their sarray counterparts.
153  */
154  void fill_untranslated_values(std::vector<flexible_type>& x) const;
155 
156  /** The explicit version to fill an eigen expression.
157  */
158  template <typename EigenXpr>
160  inline void fill_eigen(EigenXpr&& x) const;
161 
162  /** A generic function to unpack the values into a particular
163  * format. This allows, e.g. custom distance functions and stuff
164  * to work out well.
165  *
166  * // Called for every element.
167  * // mode: What type of column it is.
168  * // feature_index: index within the column
169  * // index_size: number of features in this column.
170  * // index_offset: The global index would be index_offset + feature_index
171  * // value: the value of the feature.
172  *
173  * auto unpack_function =
174  * [&](ml_column_mode mode, size_t column_index,
175  * size_t feature_index, double value,
176  * size_t index_size, size_t index_offset) {
177  * ...
178  * };
179  *
180  * // Called after every column is done unpacking.
181  * auto column_end_function =
182  * [&](ml_column_mode mode, size_t column_index, size_t index_size) {
183  * ...
184  * };
185  */
186  template <typename ElementWriteFunction,
187  typename ColumnEndFunction>
189  void unpack(ElementWriteFunction&& ewf, ColumnEndFunction&& cef) const {
190 
191  if(UNLIKELY(!has_translated_columns))
192  return;
193 
194  DASSERT_TRUE(data_block != nullptr);
195  DASSERT_TRUE(data_block->metadata != nullptr);
196  DASSERT_TRUE(data_block->translated_rows.entry_data.size() != 0);
197 
198  const ml_data_internal::row_metadata& rm = data_block->rm;
199  ml_data_internal::entry_value_iterator row_block_ptr = current_data_iter();
200 
201  read_ml_data_row(rm, row_block_ptr, ewf, cef);
202  }
203 
204  /** Returns the current target value, if present, or 1 if not
205  * present. If the target column is supposed to be a categorical
206  * value, then use target_index().
207  */
209  return get_target_value(data_block->rm, current_data_iter());
210  }
211 
212  /** Returns the current categorical target index, if present, or 0
213  * if not present.
214  */
216  return get_target_index(data_block->rm, current_data_iter());
217  }
218 
219  /** Returns a pointer to the metadata class that describes the data
220  * that this row reference refers to.
221  */
222  const std::shared_ptr<ml_metadata>& metadata() const {
223  return data_block->metadata;
224  }
225 
226  private:
227  friend class turi::ml_data_iterator;
228 
229  std::shared_ptr<ml_data_internal::ml_data_block> data_block;
230  size_t current_in_block_index = size_t(-1);
231  size_t current_in_block_row_index = size_t(-1);
232  bool has_translated_columns = false;
233  bool has_untranslated_columns = false;
234 
235  /** Return a pointer to the current location in the data.
236  */
237  inline ml_data_internal::entry_value_iterator current_data_iter() const GL_HOT_INLINE_FLATTEN {
238 
239 #ifndef NDEBUG
240  if(data_block->translated_rows.entry_data.empty()) {
241  ASSERT_EQ(current_in_block_index, 0);
242  } else {
243  ASSERT_LT(current_in_block_index, data_block->translated_rows.entry_data.size());
244  }
245 #endif
246 
247  // Note, this may be nullptr in the case of only untranslated columns and no targets.
248  return data_block->translated_rows.entry_data.data() + current_in_block_index;
249  }
250 
251 
252 };
253 
254 ////////////////////////////////////////////////////////////////////////////////
255 // Implementations of the above
256 
257 template <typename Entry>
259 void ml_data_row_reference::fill(std::vector<Entry>& x) const {
260 
261  x.clear();
262 
263  unpack(
264 
265  [&](ml_column_mode mode, size_t column_index,
266  size_t feature_index, double value,
267  size_t index_size, size_t index_offset) GL_GCC_ONLY(GL_HOT_INLINE_FLATTEN) {
268 
269  size_t global_index = (LIKELY(feature_index < index_size)
270  ? index_offset + feature_index
271  : size_t(-1));
272 
273  Entry e;
274  e = ml_data_full_entry{column_index, feature_index, global_index, value};
275  x.push_back(e);
276  },
277 
278  // Nothing that we need to do at the end of each column.
279  [&](ml_column_mode, size_t, size_t) {});
280 }
281 
282 ////////////////////////////////////////////////////////////////////////////////
283 // fill eigen stuff
284 
285 template <typename EigenXpr>
287 inline void ml_data_row_reference::fill_eigen(EigenXpr&& x) const {
288 
289  x.setZero();
290 
291  unpack(
292 
293  /** The function to write out the data to x.
294  */
295  [&](ml_column_mode mode, size_t column_index,
296  size_t feature_index, double value,
297  size_t index_size, size_t index_offset) {
298 
299  if(UNLIKELY(feature_index >= index_size))
300  return;
301 
302  size_t idx = index_offset + feature_index;
303 
304  DASSERT_GE(idx, 0);
305  DASSERT_LT(idx, size_t(x.size()));
306  x.coeffRef(idx) = value;
307  },
308 
309  /** The function to advance the offset, called after each column
310  * is finished.
311  */
312  [&](ml_column_mode mode, size_t column_index, size_t index_size) {});
313 }
314 
315 }
316 
317 #endif /* TURI_DML_DATA_ROW_REFERENCE_H_ */
GL_HOT_INLINE_FLATTEN void fill_eigen(EigenXpr &&x) const
double target_value() const GL_HOT_INLINE_FLATTEN
GL_HOT_INLINE_FLATTEN void unpack(ElementWriteFunction &&ewf, ColumnEndFunction &&cef) const
const std::shared_ptr< ml_metadata > & metadata() const
size_t target_index() const GL_HOT_INLINE_FLATTEN
void fill_untranslated_values(std::vector< flexible_type > &x) const
GL_HOT_INLINE void fill(std::vector< Entry > &x) const
#define GL_HOT_INLINE
#define GL_HOT_INLINE_FLATTEN
static GL_HOT ml_data_row_reference from_row(const std::shared_ptr< ml_metadata > &metadata, const flex_dict &row, ml_missing_value_action none_action=ml_missing_value_action::USE_NAN)
std::vector< std::pair< flexible_type, flexible_type > > flex_dict
GL_HOT_INLINE_FLATTEN void fill(EigenXpr &&x, typename std::enable_if< std::is_convertible< EigenXpr, DenseVector >::value >::type *=0) const
#define DASSERT_TRUE(cond)
Definition: assertions.hpp:364