Turi Create  4.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
style_transfer_data_iterator.hpp
1 /* Copyright © 2019 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at
5  * https://opensource.org/licenses/BSD-3-Clause
6  */
7 
8 #ifndef __TOOLKITS_STYLE_TRANSFER_DATA_ITERATOR_H_
9 #define __TOOLKITS_STYLE_TRANSFER_DATA_ITERATOR_H_
10 
11 #include <random>
12 
13 #include <core/data/sframe/gl_sarray.hpp>
14 
15 namespace turi {
16 namespace style_transfer {
17 
/**
 * One style-transfer example: a content image, a style image, and the
 * index of that style.
 *
 * NOTE(review): `style_index` has no default member initializer, so a
 * default-initialized st_example (e.g. `st_example ex;`) leaves it
 * indeterminate until it is assigned. Presumably it indexes into the style
 * SArray (data_iterator::parameters::style) — confirm against the
 * next_batch() implementation.
 */
struct st_example {
  /** The content image of this example. */
  image_type content_image;
  /** The style image paired with `content_image`. */
  image_type style_image;
  /** Index identifying which style `style_image` is (see note above). */
  size_t style_index;
};
23 
/**
 * Operating mode for a style-transfer data iterator.
 *
 * Unscoped on purpose: enumerators are usable both bare (`TRAIN`) and
 * qualified (`st_mode::TRAIN`), and convert implicitly to int.
 */
enum st_mode {
  TRAIN = 0,  ///< Training mode; the default in data_iterator::parameters.
  PREDICT     ///< Prediction mode (implicitly numbered 1).
};
28 
29 class data_iterator {
30  public:
31  struct parameters {
32  /** The Style SArray to traverse */
33  gl_sarray style;
34 
35  /** The Content SArray to traverse */
36  gl_sarray content;
37 
38  /**
39  * Whether to traverse the data more than once.
40  */
41  bool repeat = true;
42 
43  /** Whether to shuffle the data on subsequent traversals. */
44  bool shuffle = true;
45 
46  /** Check whether in train or predict mode */
47  enum st_mode mode = st_mode::TRAIN;
48 
49  /** Determines results of shuffle operations if enabled. */
50  int random_seed = 0;
51  };
52 
53  virtual ~data_iterator() = default;
54 
55  /** Returns true when `next_batch` will return a non-empty value. */
56  virtual bool has_next_batch() const = 0;
57 
58  virtual std::vector<st_example> next_batch(size_t batch_size) = 0;
59 
60  virtual void reset() = 0;
61 };
62 
63 class style_transfer_data_iterator : public data_iterator {
64  public:
65  style_transfer_data_iterator(const data_iterator::parameters& params);
66 
67  style_transfer_data_iterator(const style_transfer_data_iterator&) = delete;
68  style_transfer_data_iterator& operator=(const style_transfer_data_iterator&) =
69  delete;
70 
71  bool has_next_batch() const override {
72  // TODO: gl_sframe_range::end() should be a const method.
73  gl_sarray_range range_iterator(m_content_range_iterator);
74  return m_content_next_row != range_iterator.end();
75  }
76 
77  std::vector<st_example> next_batch(size_t batch_size) override;
78 
79  void reset() override;
80 
81  private:
82  gl_sarray m_style_images;
83  gl_sarray m_content_images;
84 
85  const bool m_repeat;
86  const bool m_shuffle;
87  const enum st_mode m_mode;
88 
89  gl_sarray_range m_content_range_iterator;
90  gl_sarray_range::iterator m_content_next_row;
91 
92  std::default_random_engine m_random_engine;
93 };
94 
95 } // namespace style_transfer
96 } // namespace turi
97 
98 #endif // __TOOLKITS_STYLE_TRANSFER_DATA_ITERATOR_H_
std::vector< sframe > shuffle(sframe sframe_in, size_t n, std::function< size_t(const std::vector< flexible_type > &)> hash_fn, std::function< void(const std::vector< flexible_type > &, size_t)> emit_call_back=std::function< void(const std::vector< flexible_type > &, size_t)>())