Turi Create  4.0
general_union_project_transforms.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_OPTIMIZATION_GENERALIZED_UNION_PROJECT_TRANSFORMS_HPP_
7 #define TURI_SFRAME_QUERY_OPTIMIZATION_GENERALIZED_UNION_PROJECT_TRANSFORMS_HPP_
8 
9 #include <core/storage/query_engine/planning/optimizations/optimization_transforms.hpp>
10 #include <core/storage/query_engine/planning/optimization_engine.hpp>
11 #include <core/storage/query_engine/operators/all_operators.hpp>
12 #include <core/storage/query_engine/planning/optimization_node_info.hpp>
13 #include <core/storage/query_engine/operators/operator_properties.hpp>
14 #include <core/data/flexible_type/flexible_type.hpp>
15 
16 #include <array>
17 
18 namespace turi {
19 namespace query_eval {
20 
21 ////////////////////////////////////////////////////////////////////////////////
22 // Transforms that apply at the beginning of optimization
23 
24 class opt_expand_generalized_union_project : public opt_transform {
25 
26  std::string description() { return "gen_union_proj(...) -> project(union(...), ...)"; }
27 
28  bool transform_applies(planner_node_type t) {
29  return t == planner_node_type::GENERALIZED_UNION_PROJECT_NODE;
30  }
31 
32  bool apply_transform(optimization_engine *opt_manager, cnode_info_ptr n) {
33 
34  std::vector<size_t> idx_offsets(n->inputs.size());
35  for(size_t i = 0, offset = 0; i < n->inputs.size(); offset += n->inputs[i]->num_columns(), ++i) {
36  idx_offsets[i] = offset;
37  }
38 
39  const flex_dict& input_index_maps = n->p("index_map").get<flex_dict>();
40 
41  std::vector<size_t> project_map(input_index_maps.size());
42 
43  for(size_t i= 0; i < input_index_maps.size(); ++i) {
44  const auto& p = input_index_maps[i];
45 
46  project_map[i] = idx_offsets[size_t(p.first)] + size_t(p.second);
47  }
48 
49  pnode_ptr u_node = op_union::make_planner_node(n->pnode->inputs);
50  pnode_ptr p_node = op_project::make_planner_node(u_node, project_map);
51 
52  opt_manager->replace_node(n, p_node);
53  return true;
54  }
55 };
56 
57 ////////////////////////////////////////////////////////////////////////////////
58 // Transforms that apply in the final stage of optimization
59 
60 class opt_union_to_generalized_union_project : public opt_transform {
61 
62  std::string description() { return "union(...) -> gen_union_proj(...)"; }
63 
64  bool transform_applies(planner_node_type t) {
65  return t == planner_node_type::UNION_NODE;
66  }
67 
68  bool apply_transform(optimization_engine *opt_manager, cnode_info_ptr n) {
69 
70  DASSERT_TRUE(n->type == planner_node_type::UNION_NODE);
71 
72  std::vector<std::pair<size_t, size_t> > index_map;
73 
74  for(size_t i = 0; i < n->inputs.size(); ++i) {
75  for(size_t j = 0; j < n->inputs[i]->num_columns(); ++j) {
76  index_map.push_back({i, j});
77  }
78  }
79 
80  pnode_ptr new_pnode = op_union_project::make_planner_node(n->pnode->inputs, index_map);
81  opt_manager->replace_node(n, new_pnode);
82 
83  return true;
84  }
85 };
86 
87 class opt_project_to_generalized_union_project : public opt_transform {
88 
89  std::string description() { return "project(...) -> gen_union_proj(...)"; }
90 
91  bool transform_applies(planner_node_type t) {
92  return t == planner_node_type::PROJECT_NODE;
93  }
94 
95  bool apply_transform(optimization_engine *opt_manager, cnode_info_ptr n) {
96 
97  DASSERT_TRUE(n->type == planner_node_type::PROJECT_NODE);
98 
99  const auto& iv = n->p("indices").get<flex_list>();
100 
101  std::vector<std::pair<size_t, size_t> > index_map(iv.size());
102 
103  for(size_t i = 0; i < iv.size(); ++i) {
104  index_map[i] = {0, iv[i]};
105  }
106 
107  pnode_ptr new_pnode = op_union_project::make_planner_node({n->inputs[0]->pnode}, index_map);
108  opt_manager->replace_node(n, new_pnode);
109 
110  return true;
111  }
112 };
113 
114 class opt_merge_generalized_union_projects : public opt_transform {
115 
116  std::string description() { return "gen_union_proj(..., gen_union_proj(...), ...) -> gen_union_proj(...)"; }
117 
118  bool transform_applies(planner_node_type t) {
119  return t == planner_node_type::GENERALIZED_UNION_PROJECT_NODE;
120  }
121 
122  bool apply_transform(optimization_engine *opt_manager, cnode_info_ptr n) {
123 
124  if(! n->input_type_present(planner_node_type::GENERALIZED_UNION_PROJECT_NODE))
125  return false;
126 
127  std::vector<pnode_ptr> inputs;
128  std::map<const node_info*, size_t> _input_loc_map;
129 
130  ////////////////////////////////////////////////////////////////////////////////
131  // Bookkeeping functions
132 
133  auto get_input = [&](const cnode_info_ptr& nn) -> size_t {
134  auto it = _input_loc_map.lower_bound(nn.get());
135 
136  if(it != _input_loc_map.end() && it->first == nn.get()) {
137  return it->second;
138  } else {
139  size_t idx = _input_loc_map.size();
140 
141  inputs.push_back(nn->pnode);
142  _input_loc_map.insert(it, {nn.get(), idx});
143  return idx;
144  }
145  };
146 
147  ////////////////////////////////////////////////////////////////////////////////
148  // Bookkeeping functions
149 
150  const flex_dict& input_index_maps = n->p("index_map").get<flex_dict>();
151 
152  std::vector<std::pair<size_t, size_t> > index_map(input_index_maps.begin(), input_index_maps.end());
153 
154  for(auto& idx_p : index_map) {
155  size_t old_input_idx = idx_p.first;
156  size_t old_input_col = idx_p.second;
157  size_t new_input_idx;
158  size_t new_input_col;
159 
160  DASSERT_LT(old_input_idx, n->inputs.size());
161  DASSERT_LT(old_input_col, n->inputs[old_input_idx]->num_columns());
162 
163  cnode_info_ptr nn = n->inputs[old_input_idx];
164 
165  // Now, we get to the interesting part
166  if(nn->type == planner_node_type::GENERALIZED_UNION_PROJECT_NODE) {
167  const flex_dict& local_input_index_maps = nn->p("index_map").get<flex_dict>();
168 
169  const auto& sub_idx_p = local_input_index_maps[old_input_col];
170 
171  new_input_idx = get_input(nn->inputs[sub_idx_p.first]);
172  new_input_col = sub_idx_p.second;
173  } else {
174  // Just stays more or less the same.
175  new_input_idx = get_input(nn);
176  new_input_col = old_input_col;
177  }
178 
179  idx_p.first = new_input_idx;
180  idx_p.second = new_input_col;
181  }
182 
183  pnode_ptr new_pnode = op_union_project::make_planner_node(inputs, index_map);
184  opt_manager->replace_node(n, new_pnode);
185 
186  return true;
187  }
188 };
189 
190 class opt_project_add_direct_source_tags : public opt_transform {
191 
192  std::string description() { return "add_source_metadata(gen_union_proj(...))"; }
193 
194  bool transform_applies(planner_node_type t) {
195  return t == planner_node_type::GENERALIZED_UNION_PROJECT_NODE;
196  }
197 
198  bool apply_transform(optimization_engine *opt_manager, cnode_info_ptr n) {
199 
200  bool have_source = false;
201  for(const auto& nn : n->inputs) {
202  if(nn->type == planner_node_type::SFRAME_SOURCE_NODE
203  || nn->type == planner_node_type::SARRAY_SOURCE_NODE) {
204  have_source = true;
205  break;
206  }
207  }
208  if(!have_source)
209  return false;
210 
211  if(n->has_any_p("direct_source_mapping"))
212  return false;
213 
214  std::map<size_t, std::shared_ptr<sarray<flexible_type> > > input_mapping;
215 
216  const flex_dict& index_map = n->p("index_map").get<flex_dict>();
217 
218  for(size_t i = 0; i < index_map.size(); ++i) {
219  size_t idx_1 = index_map[i].first;
220  size_t idx_2 = index_map[i].second;
221 
222  auto nn = n->inputs[idx_1];
223 
224  if(nn->type == planner_node_type::SFRAME_SOURCE_NODE) {
225  auto sa = nn->any_p<sframe>("sframe").select_column(idx_2);
226  if (nn->p("begin_index") == 0 && nn->p("end_index") == sa->size()) {
227  input_mapping[i] = sa;
228  }
229  } else if (nn->type == planner_node_type::SARRAY_SOURCE_NODE) {
230  DASSERT_EQ(idx_2, 0);
231  auto sa = nn->any_p<std::shared_ptr<sarray<flexible_type> > >("sarray");
232  if (nn->p("begin_index") == 0 && nn->p("end_index") == sa->size()) {
233  input_mapping[i] = sa;
234  }
235  }
236  }
237 
238  auto new_pnode = pnode_ptr(new planner_node(*n->pnode));
239  new_pnode->any_operator_parameters["direct_source_mapping"] = input_mapping;
240 
241  opt_manager->replace_node(n, new_pnode);
242 
243  return true;
244  }
245 };
246 
247 }}
248 
249 #endif /* TURI_SFRAME_QUERY_OPTIMIZATION_GENERALIZED_UNION_PROJECT_TRANSFORMS_HPP_ */
std::shared_ptr< planner_node > pnode_ptr
A handy typedef.
std::vector< std::pair< flexible_type, flexible_type > > flex_dict
std::vector< flexible_type > flex_list
#define DASSERT_TRUE(cond)
Definition: assertions.hpp:364