Turi Create  4.0
source_transforms.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_OPTIMIZATION_SOURCE_TRANSFORMS_H_
7 #define TURI_SFRAME_QUERY_OPTIMIZATION_SOURCE_TRANSFORMS_H_
8 
9 #include <core/storage/query_engine/planning/optimizations/optimization_transforms.hpp>
10 #include <core/storage/query_engine/planning/optimization_engine.hpp>
11 #include <core/storage/query_engine/operators/all_operators.hpp>
12 #include <core/storage/query_engine/planning/optimization_node_info.hpp>
13 #include <core/storage/query_engine/operators/operator_properties.hpp>
14 #include <core/data/flexible_type/flexible_type.hpp>
15 
16 #include <array>
17 
18 namespace turi {
19 namespace query_eval {
20 
21 /** This optimization scans the entire graph for duplicate sarrays,
22  * including inside of sframes, and then merges them, using projects
23  * to keep things consistent.
24  *
25  * This works by:
26  *
27  * 1. If it's the tip node of the tree to be optimized, it goes
28  * through and makes a list of all source nodes. sframe sources are
29  * assumed to be a collection of sarray sources.
30  *
31  * 2. Identical sarrays are merged. If they are part of sframes,
32  * the two sframes are merged together, with a project transform
33  * following them for each of the outputs.
34  *
35  * 3. Other identical source nodes (e.g. range nodes) are merged.
36  *
37  */
39 
40  std::string description() { return "source_a, ..., source_a -> source_a"; }
41 
42  // Only apply this to the node at the head of the graph
43  bool transform_applies(planner_node_type t) {
44  return (t == planner_node_type::IDENTITY_NODE);
45  }
46 
47  void fill_source_sets(const cnode_info_ptr& n,
48  std::set<const node_info*>& seen, std::vector<cnode_info_ptr>& source_nodes) {
49 
50  auto it = seen.lower_bound(n.get());
51 
52  if(it != seen.end() && *it == n.get())
53  return;
54  else
55  seen.insert(it, n.get());
56 
57  if(n->is_source_node())
58  source_nodes.push_back(n);
59 
60  for(const auto& nn : n->inputs) {
61  fill_source_sets(nn, seen, source_nodes);
62  }
63  }
64 
65  // Only apply this to the node at the head of the graph
66  bool apply_transform(optimization_engine *opt_manager, cnode_info_ptr n) {
67 
68  // First, go through and get all the source_nodes present.
69  std::vector<cnode_info_ptr> source_nodes;
70  std::set<const node_info*> seen_nodes;
71  fill_source_sets(n, seen_nodes, source_nodes);
72 
73  // What identifies a unique source.
74  struct source_id {
75  ptrdiff_t ptr_key;
76  size_t begin_index;
77  size_t end_index;
78 
79  bool operator<(const source_id& s) const {
80  return (std::make_pair(ptr_key, std::make_pair(begin_index, end_index))
81  < std::make_pair(s.ptr_key, std::make_pair(s.begin_index, s.end_index)));
82  }
83  };
84 
85  struct source_out {
86  cnode_info_ptr src_node;
87  size_t column_index;
88  };
89 
90  std::map<source_id, std::vector<source_out> > all_sources;
91 
92  for(const cnode_info_ptr& sn : source_nodes) {
93 
94  switch(sn->type) {
95  case planner_node_type::SFRAME_SOURCE_NODE: {
96  size_t begin_index = sn->p("begin_index");
97  size_t end_index = sn->p("end_index");
98  sframe sf = sn->any_p<sframe>("sframe");
99 
100  for(size_t i = 0; i < sf.num_columns(); ++i) {
101  source_id id;
102  id.ptr_key = ptrdiff_t(sf.select_column(i).get());
103  id.begin_index = begin_index;
104  id.end_index = end_index;
105 
106  source_out out;
107  out.src_node = sn;
108  out.column_index = i;
109 
110  all_sources[id].push_back(out);
111  }
112 
113  break;
114  }
115 
116  case planner_node_type::SARRAY_SOURCE_NODE: {
117  auto sa = sn->any_p<std::shared_ptr<sarray<flexible_type> > >("sarray");
118 
119  source_id id;
120  id.ptr_key = ptrdiff_t(sa.get());
121  id.begin_index = sn->p("begin_index");
122  id.end_index = sn->p("end_index");
123 
124  source_out out;
125  out.src_node = sn;
126  out.column_index = 0;
127 
128  all_sources[id].push_back(out);
129 
130  break;
131  }
132 
133  case planner_node_type::RANGE_NODE: {
134  source_id id;
135  // Use this as the key since any range nodes with the same
136  // begin and end indices can be merged.
137  id.ptr_key = int(planner_node_type::RANGE_NODE);
138  id.begin_index = sn->p("start") + sn->p("begin_index");
139  id.end_index = sn->p("start") + sn->p("end_index");
140 
141  source_out out;
142  out.src_node = sn;
143  out.column_index = 0;
144 
145  all_sources[id].push_back(out);
146 
147  break;
148  }
149 
150  default: break;
151  }
152  }
153 
154  ////////////////////////////////////////////////////////////////////////////////
155  // Merge all the similar nodes.
156  bool change_occured = false;
157  for(auto p : all_sources) {
158 
159  const source_id& id = p.first;
160  const std::vector<source_out>& out_v = p.second;
161 
162  if(out_v.size() == 1) {
163  continue;
164  }
165 
166  // Based on what sources are present, we should go through and
167  // see what's the best way to merge things. This can be tricky
168  // depending on the different cases.
169 
170  size_t sarray_count = 0;
171  size_t sframe_count = 0;
172  size_t range_node_count = 0;
173 
174  for(const source_out& out : out_v) {
175  switch(out.src_node->type) {
176 
177  // Count sframe sources with one column as sarray sources.
178  case planner_node_type::SFRAME_SOURCE_NODE: {
179  if(out.src_node->num_columns() == 1)
180  ++sarray_count;
181  else
182  ++sframe_count;
183  break;
184  }
185 
186  case planner_node_type::SARRAY_SOURCE_NODE:
187  ++sarray_count;
188  break;
189 
190  case planner_node_type::RANGE_NODE:
191  ++range_node_count;
192  break;
193 
194  default: DASSERT_TRUE(false); break;
195  }
196  }
197 
198  if(range_node_count > 0) {
199  // All range nodes.
200  DASSERT_EQ(sarray_count, 0);
201  DASSERT_EQ(sframe_count, 0);
202 
203  for(size_t i = 1; i < out_v.size(); ++i) {
204  opt_manager->replace_node(out_v[i].src_node, out_v[0].src_node->pnode);
205  return true;
206  }
207 
208  } else if(sframe_count == 0) {
209  // All sarrays.
210  DASSERT_EQ(sarray_count, out_v.size());
211 
212  for(size_t i = 1; i < out_v.size(); ++i) {
213  opt_manager->replace_node(out_v[i].src_node, out_v[0].src_node->pnode);
214  }
215  change_occured = true;
216 
217  } else if(sframe_count == 1) {
218  // One sframe, the rest sarrays. Another easy case.
219 
220  size_t sframe_index = 0;
221 
222  for(size_t i = 1; i < out_v.size(); ++i) {
223  if(out_v[i].src_node->type == planner_node_type::SFRAME_SOURCE_NODE
224  && out_v[i].src_node->num_columns() > 1) {
225  sframe_index = i;
226  break;
227  }
228  }
229 
230  pnode_ptr out_project = op_project::make_planner_node(
231  out_v[sframe_index].src_node->pnode, {out_v[sframe_index].column_index});
232 
233  for(size_t i = 0; i < out_v.size(); ++i) {
234  if(i != sframe_index)
235  opt_manager->replace_node(out_v[i].src_node, out_project);
236  }
237 
238  change_occured = true;
239 
240  } else {
241  // The hardest. Merge all of these into one sframe, slap
242  // projections on it afterwords, and then exit as we've
243  // possibly invalidated the rest of the lookup tables since
244  // they will possibly refer to the other components of the
245  // replaced sframe.
246 
247  std::map<void*, size_t> idx_map;
248 
249  std::vector<std::vector<size_t> > projections(out_v.size());
250  std::vector<std::shared_ptr<sarray<flexible_type> > > new_columns;
251 
252  auto get_index = [&](const std::shared_ptr<sarray<flexible_type> >& s) -> size_t {
253  auto it = idx_map.lower_bound(s.get());
254  if(it != idx_map.end() && it->first == s.get()) {
255  return it->second;
256  } else {
257  size_t idx = idx_map.size();
258  idx_map.insert(it, {s.get(), idx});
259  new_columns.push_back(s);
260  return idx;
261  }
262  };
263 
264  for(size_t i = 0; i < out_v.size(); ++i) {
265  if(out_v[i].src_node->type == planner_node_type::SFRAME_SOURCE_NODE) {
266 
267  sframe sf = out_v[i].src_node->any_p<sframe>("sframe");
268  projections[i].resize(sf.num_columns());
269 
270  for(size_t j = 0; j < sf.num_columns(); ++j) {
271  projections[i][j] = get_index(sf.select_column(j));
272  }
273 
274  } else if(out_v[i].src_node->type == planner_node_type::SARRAY_SOURCE_NODE) {
275  auto sa = out_v[i].src_node->any_p<std::shared_ptr<sarray<flexible_type> > >("sarray");
276 
277  projections[i] = {get_index(sa)};
278 
279  } else {
280  DASSERT_TRUE(false);
281  }
282  }
283 
284  // Now, make a new sframe
285  pnode_ptr sf_src = op_sframe_source::make_planner_node(sframe(new_columns), id.begin_index, id.end_index);
286 
287  for(size_t i = 0; i < out_v.size(); ++i) {
288  pnode_ptr rep_node = op_project::make_planner_node(sf_src, projections[i]);
289  opt_manager->replace_node(out_v[i].src_node, rep_node);
290  }
291 
292  // This is needed as the rest of the cache needs to be rebuilt.
293  return true;
294  }
295  }
296 
297  return change_occured;
298  }
299 
300 
301 };
302 
303 }}
304 
305 #endif
std::shared_ptr< sarray< flexible_type > > select_column(size_t column_id) const
std::shared_ptr< planner_node > pnode_ptr
A handy typedef.
size_t num_columns() const
Returns the number of columns in the SFrame. Does not throw.
Definition: sframe.hpp:341
#define DASSERT_TRUE(cond)
Definition: assertions.hpp:364