Turi Create  4.0
reduce.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_REDUCE_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_REDUCE_HPP
8 
9 #include <core/data/flexible_type/flexible_type.hpp>
10 #include <core/storage/sframe_data/group_aggregate_value.hpp>
11 #include <core/storage/query_engine/operators/operator.hpp>
12 #include <core/storage/query_engine/execution/query_context.hpp>
13 #include <core/storage/query_engine/operators/operator_properties.hpp>
14 #include <core/util/coro.hpp>
15 #include <iostream>
16 
17 namespace turi {
18 namespace query_eval {
19 
20 /**
21  * \ingroup sframe_query_engine
22  * \addtogroup operators Logical Operators
23  * \{
24  */
25 
26 /**
27  * Takes a stream of input reducing it to a single value.
28  * Uses the \ref group_aggregate_value class to define the reduction operations.
29  */
30 template <>
31 struct operator_impl<planner_node_type::REDUCE_NODE> : public query_operator {
32  public:
33  DECL_CORO_STATE(execute);
34 
35  planner_node_type type() const { return planner_node_type::REDUCE_NODE; }
36 
37  static std::string name() { return "reduce"; }
38 
39  inline operator_impl(std::shared_ptr<group_aggregate_value> aggregator,
40  flex_type_enum output_type)
41  : m_aggregator(aggregator)
42  , m_output_type(output_type) { }
43 
44  static query_operator_attributes attributes() {
47  ret.num_inputs = 1;
48  return ret;
49  }
50 
51  inline std::shared_ptr<query_operator> clone() const {
52  std::shared_ptr<group_aggregate_value> agg(m_aggregator->new_instance());
53  return std::make_shared<operator_impl>(agg, m_output_type);
54  }
55 
56  inline bool coro_running() const {
57  return CORO_RUNNING(execute);
58  }
59  inline void execute(query_context& context) {
60  CORO_BEGIN(execute)
61  {
62  while(1) {
63  auto rows = context.get_next(0);
64  if (rows == nullptr)
65  break;
66  for (const auto& row : *rows) {
67  // TODO make add_element take a sframe_row::row_reference instead
68  if (row.size() == 1) m_aggregator->add_element_simple(row[0]);
69  else m_aggregator->add_element(std::vector<flexible_type>(row));
70  }
71  }
72  auto out = context.get_output_buffer();
73  out->resize(1, 1);
74  (*out)[0][0] = m_aggregator->emit();
75  context.emit(out);
76  }
77  CORO_YIELD();
78  CORO_END
79  }
80 
81  static std::shared_ptr<planner_node> make_planner_node(
82  std::shared_ptr<planner_node> source,
83  group_aggregate_value& aggregator,
84  flex_type_enum output_type) {
85  std::shared_ptr<group_aggregate_value> agg(aggregator.new_instance());
86  return planner_node::make_shared(planner_node_type::REDUCE_NODE,
87  {{"output_type", (int)(output_type)}},
88  {{"aggregator", any(agg)}},
89  {source});
90  }
91 
92  static std::shared_ptr<query_operator> from_planner_node(
93 
94  std::shared_ptr<planner_node> pnode) {
95  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::REDUCE_NODE);
96  ASSERT_EQ(pnode->inputs.size(), 1);
97  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
98  ASSERT_TRUE(pnode->any_operator_parameters.count("aggregator"));
99 
100  std::shared_ptr<group_aggregate_value> aggregator =
101  pnode->any_operator_parameters["aggregator"].as<std::shared_ptr<group_aggregate_value>>();
102  flex_type_enum output_type =
103  (flex_type_enum)(flex_int)(pnode->operator_parameters["output_type"]);
104 
105  std::shared_ptr<group_aggregate_value> agg(aggregator->new_instance());
106  return std::make_shared<operator_impl>(agg, output_type);
107  }
108 
109  static std::vector<flex_type_enum> infer_type(std::shared_ptr<planner_node> pnode) {
110  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::REDUCE_NODE);
111  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
112  return {(flex_type_enum)(int)(pnode->operator_parameters["output_type"])};
113  }
114 
115  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
116  return -1;
117  }
118 
119  static std::string print(std::shared_ptr<planner_node> pnode) {
120  ASSERT_TRUE(pnode->any_operator_parameters.count("aggregator"));
121 
122  std::shared_ptr<group_aggregate_value> aggregator =
123  pnode->any_operator_parameters["aggregator"].as<std::shared_ptr<group_aggregate_value> >();
124 
125  return std::string("Agg.") + aggregator->name();
126  }
127 
128  private:
129  std::shared_ptr<group_aggregate_value> m_aggregator;
130  flex_type_enum m_output_type;
131 };
132 
134 
135 /// \}
136 } // query_eval
137 } // turi
138 
139 #endif // TURI_SFRAME_QUERY_MANAGER_REDUCE_HPP
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
virtual group_aggregate_value * new_instance() const =0
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
std::shared_ptr< query_operator > clone() const
Definition: reduce.hpp:51
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())
virtual std::string name() const =0