Turi Create 4.0
transformation.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef _CANVAS_STREAMING_TRANSFORMATION
#define _CANVAS_STREAMING_TRANSFORMATION

#include <core/data/flexible_type/flexible_type.hpp>
#include <core/parallel/lambda_omp.hpp>

namespace turi {
namespace visualization {

class transformation_output {
 public:
  virtual ~transformation_output() = default;
  virtual std::string vega_column_data(bool sframe = false) const = 0;
};

class sframe_transformation_output : public transformation_output {
 public:
  virtual std::string vega_summary_data() const = 0;
};

/* Base interface for a streaming transformation: call get() repeatedly to
 * consume the source in batches until eof() returns true. */
class transformation_base {
 public:
  virtual ~transformation_base() = default;
  virtual std::shared_ptr<transformation_output> get() = 0;
  virtual bool eof() const = 0;
  double get_percent_complete() const;
  virtual size_t get_batch_size() const = 0;
  virtual flex_int get_total_rows() const = 0;
  virtual flex_int get_rows_processed() const = 0;
};

class transformation_collection : public std::vector<std::shared_ptr<transformation_base>> {
 public:
  // Combines all of the transformations in the collection
  // into a single transformer interface to simplify consumption.
};

template<typename InputIterable,
         typename Output>
class transformation : public transformation_base {
 protected:
  size_t m_batch_size;
  InputIterable m_source;
  std::shared_ptr<Output> m_transformer;
  size_t m_currentIdx = 0;
  bool m_initialized = false;

 private:
  void check_init(const char * msg, bool initialized) const {
    if (initialized != m_initialized) {
      log_and_throw(msg);
    }
  }
  void require_init() const {
    check_init("Transformer must be initialized before performing this operation.", true);
  }

 protected:
  /* Subclasses may override: */
  /* Get the current result (without iterating over any new values) */
  virtual Output get_current() {
    return *m_transformer;
  }
  /* Create multiple transformers from input */
  virtual std::vector<Output> split_input(size_t num_threads) {
    return std::vector<Output>(num_threads);
  }
  /* Merge multiple transformers into output */
  virtual void merge_results(std::vector<Output>& transformers) = 0;

 public:
  virtual void init(const InputIterable& source, size_t batch_size) {
    check_init("Transformer is already initialized.", false);
    m_batch_size = batch_size;
    m_source = source;
    m_transformer = std::make_shared<Output>();
    m_currentIdx = 0;
    m_initialized = true;
  }
  virtual bool eof() const override {
    require_init();
    DASSERT_LE(m_currentIdx, m_source.size());
    return m_currentIdx == m_source.size();
  }
  virtual flex_int get_rows_processed() const override {
    require_init();
    DASSERT_LE(m_currentIdx, m_source.size());
    return m_currentIdx;
  }
  virtual flex_int get_total_rows() const override {
    require_init();
    return m_source.size();
  }
  virtual std::shared_ptr<transformation_output> get() override {
    require_init();
    if (this->eof()) {
      // bail out, done streaming
      return m_transformer;
    }

    // Process at most m_batch_size rows on this call.
    const size_t num_threads_reported = thread_pool::get_instance().size();
    const size_t start = m_currentIdx;
    const size_t input_size = std::min(m_batch_size, m_source.size() - m_currentIdx);
    const size_t end = start + input_size;
    // Give each worker thread its own Output to accumulate into.
    auto transformers = this->split_input(num_threads_reported);
    const auto& source = this->m_source;
    in_parallel(
      [&transformers, &source, input_size, start]
      (size_t thread_idx, size_t num_threads) {

        DASSERT_LE(transformers.size(), num_threads);
        if (thread_idx >= transformers.size()) {
          // this operation isn't parallel enough to use all threads.
          // bail out on this thread.
          return;
        }

        // Each thread takes a contiguous slice of the batch;
        // the last thread also picks up any remainder.
        auto& transformer = transformers[thread_idx];
        size_t thread_input_size = input_size / transformers.size();
        size_t thread_start = start + (thread_idx * thread_input_size);
        size_t thread_end = thread_idx == transformers.size() - 1 ?
          start + input_size :
          thread_start + thread_input_size;
        DASSERT_LE(thread_end, start + input_size);
        for (const auto& value : source.range_iterator(thread_start, thread_end)) {
          transformer.add_element_simple(value);
        }
      });

    // Fold the per-thread results back into m_transformer and advance the cursor.
    this->merge_results(transformers);
    m_currentIdx = end;

    return m_transformer;
  }

  virtual size_t get_batch_size() const override {
    return m_batch_size;
  }
};

}}

#endif
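
The header only supplies the streaming and threading scaffolding; the Output type provides the actual accumulation logic. As a rough sketch of how the pieces fit together (the row_counter and count_transformation names below are hypothetical, not part of Turi Create), a concrete transformer needs an element-wise accumulator, a merge step for the per-thread partial results, and a Vega serialization of the current state:

#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "transformation.hpp"            // the header documented above
// Include path for gl_sarray is assumed to follow the same tree as the
// includes above; adjust to your checkout.
#include <core/data/sframe/gl_sarray.hpp>

using namespace turi;
using namespace turi::visualization;

// Hypothetical accumulator: counts non-missing values. It must be default-
// constructible (the default split_input() value-initializes one per thread)
// and expose add_element_simple(), which the parallel loop in get() calls
// once per row.
class row_counter : public transformation_output {
 public:
  size_t count = 0;
  void add_element_simple(const flexible_type& value) {
    if (value.get_type() != flex_type_enum::UNDEFINED) {
      count++;
    }
  }
  std::string vega_column_data(bool sframe = false) const override {
    return "{\"count\": " + std::to_string(count) + "}";
  }
};

// Hypothetical transformation over a gl_sarray source; only merge_results()
// must be provided, since the base class handles batching and threading.
class count_transformation : public transformation<gl_sarray, row_counter> {
 public:
  void merge_results(std::vector<row_counter>& partial) override {
    for (const auto& p : partial) {
      m_transformer->count += p.count;
    }
  }
};

// Streaming consumption: each get() call processes at most one batch and
// returns the accumulated result so far, so intermediate states can be
// emitted while the column is still being scanned.
void stream_counts(const gl_sarray& column) {
  count_transformation t;
  t.init(column, /* batch_size */ 100000);
  while (!t.eof()) {
    std::shared_ptr<transformation_output> partial = t.get();
    std::cout << partial->vega_column_data() << std::endl;
  }
}

The same loop can report progress between batches via get_rows_processed() and get_total_rows() on the transformation_base interface.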