Turi Create  4.0
binary_transform.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_BINARY_TRANSFORM_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_BINARY_TRANSFORM_HPP
8 
9 #include <functional>
10 #include <core/logging/assertions.hpp>
11 #include <core/data/flexible_type/flexible_type.hpp>
12 #include <core/storage/query_engine/operators/operator.hpp>
13 #include <core/storage/query_engine/execution/query_context.hpp>
14 #include <core/storage/query_engine/operators/operator_properties.hpp>
15 #include <core/util/coro.hpp>
16 
17 namespace turi {
18 namespace query_eval {
19 
20 typedef std::function<flexible_type(const sframe_rows::row&,
21  const sframe_rows::row&)> binary_transform_type;
22 
23 /**
24  * \ingroup sframe_query_engine
25  * \addtogroup operators Logical Operators
26  * \{
27  */
28 
29 /**
30  * A "binary transform" operator applys a transform function on two
31  * stream of input.
32  */
33 template<>
34 class operator_impl<planner_node_type::BINARY_TRANSFORM_NODE> : public query_operator {
35  public:
36  DECL_CORO_STATE(execute);
37  std::shared_ptr<const sframe_rows> rows_left, rows_right;
38  sframe_rows::const_iterator left_iter, right_iter;
39  sframe_rows::iterator out_iter;
40  std::shared_ptr<sframe_rows> output_buffer;
41 
42 
43  planner_node_type type() const { return planner_node_type::BINARY_TRANSFORM_NODE; }
44 
45  static std::string name() { return "binary_transform"; }
46 
47  static query_operator_attributes attributes() {
49  ret.attribute_bitfield = query_operator_attributes::LINEAR;
50  ret.num_inputs = 2;
51  return ret;
52  }
53 
54  inline operator_impl(const binary_transform_type& f,
55  flex_type_enum output_type)
56  : m_transform_fn(f)
57  { }
58 
59  inline std::shared_ptr<query_operator> clone() const {
60  return std::make_shared<operator_impl>(*this);
61  }
62  inline bool coro_running() const {
63  return CORO_RUNNING(execute);
64  }
65  inline void execute(query_context& context) {
66  CORO_BEGIN(execute)
67  while(1) {
68  {
69  rows_left = context.get_next(0);
70  rows_right = context.get_next(1);
71  if (rows_left == nullptr && rows_right == nullptr) break;
72  ASSERT_TRUE(rows_left != nullptr && rows_right != nullptr);
73  ASSERT_EQ(rows_left->num_rows(), rows_right->num_rows());
74  ASSERT_EQ(rows_left->num_columns(), 1);
75  ASSERT_EQ(rows_right->num_columns(), 1);
76  output_buffer = context.get_output_buffer();
77  output_buffer->resize(1, rows_left->num_rows());
78 
79  left_iter = rows_left->cbegin();
80  right_iter = rows_right->cbegin();
81  out_iter = output_buffer->begin();
82  while(left_iter != rows_left->cend()) {
83  (*out_iter)[0] = m_transform_fn((*left_iter), (*right_iter));
84  ++left_iter;
85  ++right_iter;
86  ++out_iter;
87  }
88  context.emit(output_buffer);
89  }
90  CORO_YIELD();
91  }
92  CORO_END
93  }
94 
95  static std::shared_ptr<planner_node> make_planner_node(
96  std::shared_ptr<planner_node> left,
97  std::shared_ptr<planner_node> right,
98  binary_transform_type fn,
99  flex_type_enum output_type) {
100 
101  return planner_node::make_shared(planner_node_type::BINARY_TRANSFORM_NODE,
102  {{"output_type", (int)(output_type)}},
103  {{"function", any(fn)}},
104  {left, right});
105  }
106 
107  static std::shared_ptr<query_operator> from_planner_node(
108  std::shared_ptr<planner_node> pnode) {
109 
110  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::BINARY_TRANSFORM_NODE);
111  ASSERT_EQ(pnode->inputs.size(), 2);
112  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
113  ASSERT_TRUE(pnode->any_operator_parameters.count("function"));
114  binary_transform_type fn;
115  flex_type_enum output_type =
116  (flex_type_enum)(flex_int)(pnode->operator_parameters["output_type"]);
117 
118  fn = pnode->any_operator_parameters["function"].as<binary_transform_type>();
119  return std::make_shared<operator_impl>(fn, output_type);
120  }
121 
122  static std::vector<flex_type_enum> infer_type(std::shared_ptr<planner_node> pnode) {
123  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::BINARY_TRANSFORM_NODE);
124  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
125  return {(flex_type_enum)(int)(pnode->operator_parameters["output_type"])};
126  }
127 
128  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
129  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::BINARY_TRANSFORM_NODE);
130  return infer_planner_node_length(pnode->inputs[0]);
131  }
132 
133  private:
134  binary_transform_type m_transform_fn;
135 };
136 
138 
139 
140 /// \}
141 } // query_eval
142 } // turi
143 
144 #endif // TURI_SFRAME_QUERY_MANAGER_BINARY_TRANSFORM_HPP
int64_t infer_planner_node_length(std::shared_ptr< planner_node > pnode)
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())