Turi Create  4.0
ternary_operator.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_TERNARY_OPERATOR_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_TERNARY_OPERATOR_HPP
8 
#include <functional>
#include <numeric>
#include <core/logging/assertions.hpp>
#include <core/data/flexible_type/flexible_type.hpp>
#include <core/storage/query_engine/operators/operator.hpp>
#include <core/storage/query_engine/execution/query_context.hpp>
#include <core/storage/query_engine/operators/operator_properties.hpp>
#include <core/util/coro.hpp>
16 
17 namespace turi {
18 namespace query_eval {
19 
20 /**
21  * \ingroup sframe_query_engine
22  * \addtogroup operators Logical Operators
23  * \{
24  */
25 
26 /**
27  * An element-wise "ternary operator".
28  * Takes 3 columns: condition, istrue, isfalse.
29  * For each row,
30  * if condition == True, the corresponding row is selected from istrue
31  * if condition == False, the corresponding row is selected from isfalse
32  */
33 template<>
34 class operator_impl<planner_node_type::TERNARY_OPERATOR> : public query_operator {
35  public:
36  DECL_CORO_STATE(execute);
37 
38  planner_node_type type() const { return planner_node_type::TERNARY_OPERATOR; }
39 
40  static std::string name() { return "ternary"; }
41 
42  static query_operator_attributes attributes() {
44  ret.attribute_bitfield = query_operator_attributes::LINEAR;
45  ret.num_inputs = 3;
46  return ret;
47  }
48 
49  inline operator_impl() { }
50 
51  inline std::shared_ptr<query_operator> clone() const {
52  return std::make_shared<operator_impl>(*this);
53  }
54 
55  inline bool coro_running() const {
56  return CORO_RUNNING(execute);
57  }
58  inline void execute(query_context& context) {
59  constexpr size_t CONDITION_INPUT = 0;
60  constexpr size_t ISTRUE_INPUT = 1;
61  constexpr size_t ISFALSE_INPUT = 2;
62  CORO_BEGIN(execute)
63  while(1) {
64  {
65  auto condition = context.get_next(CONDITION_INPUT);
66 
67  if (condition == nullptr) break;
68  ASSERT_EQ(condition->num_columns(), 1);
69  const auto& condition_column = *(condition->cget_columns()[0]);
70  // sum the number of non-zeros
71  size_t num_non_zero = std::accumulate(condition_column.begin(),
72  condition_column.end(),
73  (size_t)0,
74  [](size_t acc, const flexible_type& f) {
75  return acc + (! f.is_zero());
76  });
77 
78  auto output_buffer = context.get_output_buffer();
79 
80  // fast path. All true, or all false
81  if (num_non_zero == 0 || num_non_zero == condition_column.size()) {
82  size_t input_number;
83  size_t skip_number;
84  if (num_non_zero == 0) {
85  skip_number = ISTRUE_INPUT;
86  input_number = ISFALSE_INPUT;
87  } else {
88  skip_number = ISFALSE_INPUT;
89  input_number = ISTRUE_INPUT;
90  }
91 
92  // all is false
93  context.skip_next(skip_number);
94 
95  auto& out_columns = output_buffer->get_columns();
96 
97  auto input = context.get_next(input_number);
98  ASSERT_EQ(input->num_rows(), condition_column.size());
99  ASSERT_EQ(input->num_columns(), 1);
100 
101  out_columns.clear();
102  out_columns.push_back(input->cget_columns()[0]);
103  } else {
104 
105  auto isfalse = context.get_next(ISFALSE_INPUT);
106  auto istrue = context.get_next(ISTRUE_INPUT);
107  ASSERT_TRUE(istrue != nullptr);
108  ASSERT_TRUE(isfalse != nullptr);
109  ASSERT_EQ(isfalse->num_rows(), condition_column.size());
110  ASSERT_EQ(istrue->num_rows(), condition_column.size());
111  ASSERT_EQ(isfalse->num_columns(), 1);
112  ASSERT_EQ(istrue->num_columns(), 1);
113 
114  output_buffer->resize(1, condition_column.size());
115 
116  auto istrue_iter = istrue->cbegin();
117  auto isfalse_iter = isfalse->cbegin();
118  auto out_iter = output_buffer->begin();
119  for (auto& cval : condition_column) {
120  if (cval.is_zero()) {
121  (*out_iter)[0] = (*isfalse_iter)[0];
122  } else {
123  (*out_iter)[0] = (*istrue_iter)[0];
124  }
125  ++istrue_iter;
126  ++isfalse_iter;
127  ++out_iter;
128  }
129  }
130  context.emit(output_buffer);
131  }
132  CORO_YIELD();
133  }
134  CORO_END
135  }
136 
137  static std::shared_ptr<planner_node> make_planner_node(
138  std::shared_ptr<planner_node> condition,
139  std::shared_ptr<planner_node> istrue,
140  std::shared_ptr<planner_node> isfalse) {
141 
142  return planner_node::make_shared(planner_node_type::TERNARY_OPERATOR,
143  std::map<std::string, flexible_type>(),
144  std::map<std::string, any>(),
145  {condition, istrue, isfalse});
146  }
147 
148  static std::shared_ptr<query_operator> from_planner_node(
149  std::shared_ptr<planner_node> pnode) {
150 
151  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::TERNARY_OPERATOR);
152  ASSERT_EQ(pnode->inputs.size(), 3);
153 
154  return std::make_shared<operator_impl>();
155  }
156 
157  static std::vector<flex_type_enum> infer_type(std::shared_ptr<planner_node> pnode) {
158  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::TERNARY_OPERATOR);
159  return infer_planner_node_type(pnode->inputs[1]);
160  }
161 
162  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
163  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::TERNARY_OPERATOR);
164  return infer_planner_node_length(pnode->inputs[0]);
165  }
166 
167 };
168 
170 
171 
172 /// \}
173 } // query_eval
174 } // turi
175 
176 #endif // TURI_SFRAME_QUERY_MANAGER_TERNARY_OPERATOR_HPP
int64_t infer_planner_node_length(std::shared_ptr< planner_node > pnode)
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())
std::vector< flex_type_enum > infer_planner_node_type(std::shared_ptr< planner_node > pnode)
void skip_next(size_t input_number)
bool is_zero() const