Turi Create  4.0
dc_data_iterator.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef TURI_DRAWING_CLASSIFICATION_DC_DATA_ITERATOR_HPP_
9 #define TURI_DRAWING_CLASSIFICATION_DC_DATA_ITERATOR_HPP_
10 
11 #include <random>
12 #include <string>
13 #include <vector>
14 
15 #include <core/data/flexible_type/flexible_type.hpp>
16 #include <core/data/sframe/gl_sframe.hpp>
17 #include <ml/neural_net/float_array.hpp>
18 
19 namespace turi {
20 namespace drawing_classifier {
21 
22 /**
23  * Pure virtual interface for classes that produce batches of data
24  * (pre-augmentation) from a raw SFrame.
25  * \TODO Factor out the shared structure for data iterators
26  * with the other iterators!
27  */
29  public:
30  /**
31  * Defines the inputs to a data_iterator factory function.
32  */
33  struct parameters {
34  /** The SFrame to traverse */
36 
37  /**
38  * The name of the column containing the target variable.
39  *
40  * If empty, then the output will not contain labels.
41  */
42  std::string target_column_name;
43 
44  /** The name of the feature column. */
45  std::string feature_column_name{"feature"};
46 
47  /** The name of the predictions column. */
49 
50  /**
51  * The expected class labels, indexed by identifier.
52  *
53  * If empty, then the labels will be inferred from the data. If non-empty,
54  * an exception will be thrown upon encountering an unexpected label.
55  */
57 
58  /** Whether this is training data or not. */
59  bool is_train = true;
60 
61  /** Whether to traverse the data more than once. */
62  bool repeat = true;
63 
64  /** Whether to shuffle the data on subsequent traversals. */
65  bool shuffle = true;
66 
67  /** Determines results of shuffle operations if enabled. */
68  int random_seed = 0;
69 
70  // normalization factor for input data
71  float scale_factor = 1 / 255.f;
72  };
73 
74  /** Defines the output of a data_iterator. */
75  struct batch {
76  /* Number of examples in batch */
77  size_t num_samples = 256;
78 
79  /**
80  * An array with shape: (requested_batch_size, 28, 28 1)
81  *
82  * Each row is an image.
83  */
84  neural_net::shared_float_array drawings;
85 
86  /**
87  * An array with shape: (requested_batch_size, 1)
88  *
89  * Each row is the target.
90  */
91  neural_net::shared_float_array targets;
92 
93  /**
94  * An array with shape: (requested_batch_size, 1)
95  *
96  * Each row is the weight associated with the target.
97  */
98  neural_net::shared_float_array weights;
99 
100  /**
101  * An array with shape: (requested_batch_size, 1)
102  *
103  * Each row is the prediction.
104  */
105  neural_net::shared_float_array predictions;
106  };
107 
108  virtual ~data_iterator() = default;
109 
110  /**
111  * Returns a vector whose size is equal to `batch_size`.
112  *
113  * If `repeat` was set in the parameters, then the iterator will cycle
114  * indefinitely through the SFrame over and over. Otherwise, the last
115  * non-empty batch may contain fewer than `batch_size` elements, and every
116  * batch after that will be empty.
117  *
118  */
119  virtual batch next_batch(size_t batch_size) = 0;
120 
121  /**
122  * Returns true if and only if the next call to `next_batch` will return a
123  * batch with size greater than 0.
124  */
125  virtual bool has_next_batch() = 0;
126 
127  /** Begins a fresh traversal of the dataset. */
128  virtual void reset() = 0;
129 
130  /**
131  * Returns a sorted list of the unique "label" values found in the
132  * target.
133  */
134  virtual const flex_list& class_labels() const = 0;
135 
136 };
137 
138 /**
139  * Concrete data_iterator implementation that doesn't attempt any
140  * parallelization or background I/O.
141  *
142  * \todo This classs should become an abstract_data_iterator base class with
143  * override points for dispatching work to other threads.
144  */
146  public:
147  simple_data_iterator(const parameters& params);
148 
149  // Not copyable or movable.
151  simple_data_iterator& operator=(const simple_data_iterator&) = delete;
152 
153  batch next_batch(size_t batch_size) override;
154 
155  bool has_next_batch() override;
156 
157  void reset() override;
158 
159  const flex_list& class_labels() const override {
160  return target_properties_.classes;
161  }
162 
163  private:
164  struct target_properties {
165  flex_list classes;
166  };
167 
168  target_properties compute_properties(
169  const gl_sframe& data, const std::string& target_column_name,
170  const flex_list& expected_class_labels);
171 
172  gl_sframe data_;
173  const int64_t target_index_;
174  const int64_t predictions_index_; // -1 if not present
175  const int64_t feature_index_;
176  const bool repeat_;
177  const bool shuffle_;
178  const float scale_factor_ = 1 / 255.0f;
179 
180  const target_properties target_properties_;
181 
182  gl_sframe_range range_iterator_;
183  gl_sframe_range::iterator next_row_;
184  gl_sframe_range::iterator end_of_rows_;
185 
186  std::default_random_engine random_engine_;
187 };
188 
189 } // namespace drawing_classifier
190 } // namespace turi
191 
192 #endif // TURI_DRAWING_CLASSIFICATION_DC_DATA_ITERATOR_HPP_
const flex_list & class_labels() const override
virtual batch next_batch(size_t batch_size)=0
std::vector< flexible_type > flex_list