Turi Create  4.0
logical_filter.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_LOGICAL_FILTER_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_LOGICAL_FILTER_HPP
8 #include <core/data/flexible_type/flexible_type.hpp>
9 #include <core/storage/query_engine/operators/operator.hpp>
10 #include <core/storage/query_engine/execution/query_context.hpp>
11 #include <core/storage/query_engine/operators/operator_properties.hpp>
12 #include <core/util/coro.hpp>
13 
14 namespace turi {
15 namespace query_eval {
16 
17 /**
18  * \ingroup sframe_query_engine
19  * \addtogroup operators Logical Operators
20  * \{
21  */
22 
23 /**
24  * A "logical_filter" operator which takes two inputs of the same size:
25  * "values", and "logical indices", and output the value in "values" for which
26  * the logical index is 1.
27  */
28 template<>
29 class operator_impl<planner_node_type::LOGICAL_FILTER_NODE> : public query_operator {
30  public:
31  DECL_CORO_STATE(execute);
32  std::shared_ptr<const sframe_rows> rows_left, rows_right;
33  sframe_rows::const_iterator left_iter, right_iter, out_iter;
34  std::shared_ptr<sframe_rows> output_buffer;
35  size_t cur_output_index = 0;
36  size_t ncols = 0;
37  size_t nrows = 0;
38  bool has_data = false;
39 
40  planner_node_type type() const { return planner_node_type::LOGICAL_FILTER_NODE; }
41 
42  static std::string name() { return "logical_filter"; }
43 
44  inline operator_impl() { };
45 
46  static query_operator_attributes attributes() {
49  ret.num_inputs = 2;
50  return ret;
51  }
52 
53  inline std::shared_ptr<query_operator> clone() const {
54  return std::make_shared<operator_impl>(*this);
55  }
56 
57  // tests if the first column of col is all zeros
58  bool is_all_zero(const std::shared_ptr<const sframe_rows>& col) {
59  // if it is all zero, we can skip the left data
60  for (auto& row: *col) {
61  if (!(row[0].is_zero())) return false;
62  }
63  return true;
64  }
65 
66  inline bool coro_running() const {
67  return CORO_RUNNING(execute);
68  }
69  inline void execute(query_context& context) {
70  CORO_BEGIN(execute)
71  // read one block
72  rows_left = context.get_next(0);
73  rows_right = context.get_next(1);
74  if (rows_left == nullptr && rows_right == nullptr) return;
75  ASSERT_TRUE(rows_left != nullptr && rows_right != nullptr);
76 
77  // set up the output shape
78  output_buffer = context.get_output_buffer();
79  cur_output_index = 0;
80  ncols = rows_left->num_columns();
81  nrows = context.block_size();
82  output_buffer->resize(ncols, nrows);
83 
84 
85  while(1) {
86  ASSERT_TRUE(rows_left != nullptr && rows_right != nullptr);
87  ASSERT_EQ(rows_left->num_rows(), rows_right->num_rows());
88 
89  left_iter = rows_left->cbegin();
90  right_iter = rows_right->cbegin();
91  while(left_iter != rows_left->cend()) {
92  if (!(*right_iter)[0].is_zero()) {
93  (*output_buffer)[cur_output_index] = (*left_iter);
94  ++cur_output_index;
95  if (cur_output_index == nrows) {
96  context.emit(output_buffer);
97  CORO_YIELD();
98  output_buffer = context.get_output_buffer();
99  output_buffer->resize(ncols, nrows);
100  cur_output_index = 0;
101  }
102  }
103  ++left_iter;
104  ++right_iter;
105  }
106  has_data = false;
107  do {
108  // get the binary column first
109  rows_right = context.get_next(1);
110  // skip left if it is all zeros
111  if (rows_right != nullptr && is_all_zero(rows_right)) {
112  context.skip_next(0);
113  } else {
114  has_data = true;
115  rows_left = context.get_next(0);
116  }
117  // loop until there is really no data.
118  } while(has_data == false);
119  if(rows_left == nullptr && rows_right == nullptr) break;
120  }
121 
122  if (cur_output_index > 0) {
123  output_buffer->resize(ncols, cur_output_index);
124  context.emit(output_buffer);
125  CORO_YIELD();
126  }
127  CORO_END
128  }
129 
130  static std::shared_ptr<planner_node> make_planner_node(
131  std::shared_ptr<planner_node> left,
132  std::shared_ptr<planner_node> right) {
133  return planner_node::make_shared(planner_node_type::LOGICAL_FILTER_NODE,
134  std::map<std::string, flexible_type>(),
135  std::map<std::string, any>(),
136  {left, right});
137  }
138 
139  static std::shared_ptr<query_operator> from_planner_node(
140  std::shared_ptr<planner_node> pnode) {
141  ASSERT_EQ((int)pnode->operator_type,
142  (int)planner_node_type::LOGICAL_FILTER_NODE);
143  ASSERT_EQ(pnode->inputs.size(), 2);
144  return std::make_shared<operator_impl>();
145  }
146 
147  static std::vector<flex_type_enum> infer_type(
148  std::shared_ptr<planner_node> pnode) {
149  ASSERT_EQ((int)pnode->operator_type,
150  (int)planner_node_type::LOGICAL_FILTER_NODE);
151  ASSERT_EQ(pnode->inputs.size(), 2);
152  return infer_planner_node_type(pnode->inputs[0]);
153  }
154 
155  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
156  return -1;
157  }
158 
159  static std::string repr(std::shared_ptr<planner_node> pnode, pnode_tagger& get_tag) {
160  ASSERT_EQ(pnode->inputs.size(), 2);
161  return std::string("Filter(") + get_tag(pnode->inputs[0]) + "[" + get_tag(pnode->inputs[1]) + "])";
162  }
163 
164 };
165 
167 
168 /// \}
169 } // query_eval
170 } // turi
171 
172 #endif // TURI_SFRAME_QUERY_MANAGER_LOGICAL_FILTER_HPP
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())
std::vector< flex_type_enum > infer_planner_node_type(std::shared_ptr< planner_node > pnode)
void skip_next(size_t input_number)