Turi Create  4.0
ml_data_iterator.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_DML_DATA_ITERATOR_H_
7 #define TURI_DML_DATA_ITERATOR_H_
8 
9 #include <core/logging/assertions.hpp>
10 #include <ml/ml_data/data_storage/ml_data_row_translation.hpp>
11 #include <ml/ml_data/data_storage/ml_data_block_manager.hpp>
12 #include <ml/ml_data/ml_data.hpp>
13 #include <ml/ml_data/row_reference.hpp>
14 #include <core/util/code_optimization.hpp>
15 
16 // SArray and Flex type
17 #include <core/storage/sframe_data/sarray.hpp>
18 
19 #include <Eigen/SparseCore>
20 #include <Eigen/Core>
21 
22 #include <array>
23 
24 namespace turi {
25 
26 class ml_data;
27 
28 typedef Eigen::Matrix<double, Eigen::Dynamic,1> DenseVector;
29 typedef Eigen::SparseVector<double> SparseVector;
30 
31 /**
32  * \ingroup mldata
33  * Just a simple iterator on the ml_data class. It's just a
34  * convenience structure that keeps track of everything relevant for
35  * the toolkits.
36  */
38  private:
39 
40  // To be initialized only from the get_iterator() method of ml_data.
41  friend class ml_data;
42 
43  // Internal -- called from ml_data;
44  void setup(const ml_data& _data,
46  size_t thread_idx, size_t num_threads);
47 
48  public:
49 
51  ml_data_iterator(const ml_data_iterator&) = delete;
53 
54  ml_data_iterator& operator=(const ml_data_iterator&) = delete;
55  ml_data_iterator& operator=(ml_data_iterator&&) = default;
56 
57  /// Resets the iterator to the start of the sframes in ml_data.
58  void reset();
59 
60  /// Returns true if the iteration is done, false otherwise.
61  inline bool done() const { return current_row_index == iter_row_index_end; }
62 
63  /// Returns the current index of the sframe row, respecting all
64  /// slicing operations on the original ml_data.
65  inline size_t row_index() const { return current_row_index - global_row_start; }
66 
67  ////////////////////////////////////////////////////////////////////////////////
68 
69  /** Return a row reference. The row reference can be used to fill
70  * the observation vectors.
71  */
73  return row;
74  }
75 
76  /** Dereference the iterator.
77  */
79  return &row;
80  }
81 
82 
83  ////////////////////////////////////////////////////////////////////////////////
84 
85  /// Advance the iterator to the next observation.
87  this->advance_row();
88  return *this;
89  }
90 
91  ////////////////////////////////////////////////////////////////////////////////
92 
93  /** Return the data this iterator is working with.
94  */
95  inline const ml_data& ml_data_source() const {
96  return data;
97  }
98 
99  /** Return the raw value of the internal row storage. Used by some
100  * of the internal ml_data processing routines.
101  */
103 
104  if(!rm.data_size_is_constant)
105  ++raw_index;
106 
107  return *(current_data_iter() + raw_index);
108  }
109 
110  ////////////////////////////////////////////////////////////////////////////////
111 
112  /** Seeks to the row given by row_index.
113  *
114  */
115  void seek(size_t row_index) {
116  size_t absolute_row_index = row_index + global_row_start;
117 
118  ASSERT_MSG(absolute_row_index <= global_row_end,
119  "Requested row index out of bounds.");
120 
121  ASSERT_MSG((iter_row_index_start == global_row_start
122  && iter_row_index_end == global_row_end),
123  "Seek not supported with multithreaded iterators.");
124 
125  current_row_index = absolute_row_index;
126 
127  if(!done())
128  setup_block_containing_current_row_index();
129  }
130 
131  private:
132 
133  // Internally, ml_data is just a bunch of shared pointers, so it's
134  // not expensive to store a copy.
135  ml_data data;
136 
138 
// Each index below starts out as size_t(-1), i.e. the largest value a
// size_t can hold.

/// Starting row index for this iterator.
size_t iter_row_index_start = static_cast<size_t>(-1);

/// Ending row index for this iterator.
size_t iter_row_index_end = static_cast<size_t>(-1);

/// Current row index for this iterator.
size_t current_row_index = static_cast<size_t>(-1);

/// Index of the currently loaded block.
size_t current_block_index = static_cast<size_t>(-1);
143 
144  /** A reference to the current row that we're pointing to. Holds
145  * the data_block and current_in_block_index
146  */
148 
/** The absolute values of the global row starting and ending
 *  locations.
 *
 *  Both default to size_t(-1), the same initial value used by the
 *  iter_row_index_* and current_*_index members, so that they never
 *  hold an indeterminate value before being assigned.
 *  NOTE(review): presumably assigned in setup() -- confirm.
 */
size_t global_row_start = -1;
size_t global_row_end = -1;
152 
153  private:
154 
155  /** Return a pointer to the current location in the data.
156  */
157  inline ml_data_internal::entry_value_iterator current_data_iter() const GL_HOT_INLINE_FLATTEN {
158 
159  DASSERT_FALSE(done());
160  DASSERT_LT(row.current_in_block_index, row.data_block->translated_rows.entry_data.size());
161 
162  return &(row.data_block->translated_rows.entry_data[row.current_in_block_index]);
163  }
164 
165  /** Return a pointer to the current location in the data.
166  */
167  inline size_t current_block_row_index() const GL_HOT_INLINE_FLATTEN {
168 
169  size_t index = current_row_index - (current_block_index * data.row_block_size);
170 
171  DASSERT_FALSE(done());
172  DASSERT_LT(index, data.row_block_size);
173 
174  return index;
175  }
176 
177 
178  /** Advance to the next row.
179  */
180  inline void advance_row() GL_HOT_INLINE_FLATTEN {
181 
182  if(row.has_translated_columns || rm.has_target)
183  row.current_in_block_index += get_row_data_size(rm, current_data_iter());
184 
185  ++current_row_index;
186 
187  DASSERT_GE(current_row_index, current_block_index * data.row_block_size);
188 
189  row.current_in_block_row_index = current_row_index - current_block_index * data.row_block_size;
190 
191  if(row.current_in_block_row_index == data.row_block_size && !done())
192  load_next_block();
193  }
194 
195  ////////////////////////////////////////////////////////////////////////////////
196  // Internal reader functions
197 
198  /// Loads the block containing the row index row_index
199  void setup_block_containing_current_row_index() GL_HOT_NOINLINE;
200 
201  /// Loads the next block, resetting all the values so iteration will
202  /// be supported over the next row.
203  void load_next_block() GL_HOT_NOINLINE;
204 
205 };
206 
207 }
208 
209 #endif /* TURI_DML_DATA_ITERATOR_H_ */
void reset()
Resets the iterator to the start of the sframes in ml_data.
const ml_data & ml_data_source() const
void seek(size_t row_index)
const ml_data_iterator & operator++() GL_HOT_INLINE_FLATTEN
Advance the iterator to the next observation.
ml_data_row_reference const * operator->() const GL_HOT_INLINE_FLATTEN
ml_data_internal::entry_value _raw_row_entry(size_t raw_index) const GL_HOT_INLINE_FLATTEN
#define DASSERT_FALSE(cond)
Definition: assertions.hpp:365
#define GL_HOT_INLINE_FLATTEN
bool done() const
Returns true if the iteration is done, false otherwise.
ml_data_row_reference operator*() const GL_HOT_INLINE_FLATTEN