Turi Create  4.0
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
generalized_transform.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_GENERALIZED_TRANSFORM_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_GENERALIZED_TRANSFORM_HPP
8 #include <core/data/flexible_type/flexible_type.hpp>
9 #include <core/random/random.hpp>
10 #include <core/parallel/pthread_tools.hpp>
11 #include <core/storage/query_engine/operators/operator.hpp>
12 #include <core/storage/query_engine/execution/query_context.hpp>
13 #include <core/storage/query_engine/operators/operator_properties.hpp>
14 #include <core/util/coro.hpp>
15 namespace turi {
16 namespace query_eval {
17 
18 typedef std::function<void (const sframe_rows::row&,
19  sframe_rows::row&)> generalized_transform_type;
20 
21 /**
22  * \ingroup sframe_query_engine
23  * \addtogroup operators Logical Operators
24  * \{
25  */
26 
27 /**
28  * The generalized transform operator is like the transform operator
29  * but produces a vector output.
30  */
31 template<>
32 class operator_impl<planner_node_type::GENERALIZED_TRANSFORM_NODE> : public query_operator {
33  public:
34  DECL_CORO_STATE(execute);
35 
36  planner_node_type type() const { return planner_node_type::GENERALIZED_TRANSFORM_NODE; }
37 
38  static std::string name() { return "generalized_transform"; }
39 
40  static query_operator_attributes attributes() {
42  ret.attribute_bitfield = query_operator_attributes::LINEAR;
43  ret.num_inputs = 1;
44  return ret;
45  }
46 
47  ////////////////////////////////////////////////////////////////////////////////
48 
49  inline operator_impl(const generalized_transform_type& f,
50  const std::vector<flex_type_enum>& output_types,
51  int random_seed=-1)
52  : m_transform_fn(f), m_output_types(output_types), m_random_seed(random_seed)
53  { }
54 
55  inline std::shared_ptr<query_operator> clone() const {
56  return std::make_shared<operator_impl>(*this);
57  }
58 
59  inline bool coro_running() const {
60  return CORO_RUNNING(execute);
61  }
62  inline void execute(query_context& context) {
63  CORO_BEGIN(execute)
64  if (m_random_seed != -1){
65  random::get_source().seed(m_random_seed + thread::thread_id());
66  }
67  while(1) {
68  {
69  auto rows = context.get_next(0);
70  if (rows == nullptr)
71  break;
72  auto output = context.get_output_buffer();
73  output->resize(m_output_types.size(), rows->num_rows());
74 
75  auto iter = rows->cbegin();
76  auto output_iter = output->begin();
77  while(iter != rows->cend()) {
78  m_transform_fn((*iter), (*output_iter));
79  ++output_iter;
80  ++iter;
81  }
82  output->type_check_inplace(m_output_types);
83  context.emit(output);
84  }
85  CORO_YIELD();
86  }
87  CORO_END
88  }
89 
90  static std::shared_ptr<planner_node> make_planner_node(
91  std::shared_ptr<planner_node> source,
92  generalized_transform_type fn,
93  const std::vector<flex_type_enum>& output_types,
94  int random_seed=-1) {
95 
96  flex_list type_list(output_types.size());
97  for (size_t i = 0; i < output_types.size(); ++i) {
98  type_list[i] = flex_int(output_types[i]);
99  }
100 
101  return planner_node::make_shared(planner_node_type::GENERALIZED_TRANSFORM_NODE,
102  {{"output_types", type_list},
103  {"random_seed", random_seed}},
104  {{"function", any(fn)}},
105  {source});
106  }
107 
108  static std::shared_ptr<query_operator> from_planner_node(
109  std::shared_ptr<planner_node> pnode) {
110  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::GENERALIZED_TRANSFORM_NODE);
111  ASSERT_EQ(pnode->inputs.size(), 1);
112  ASSERT_TRUE(pnode->operator_parameters.count("output_types"));
113  ASSERT_TRUE(pnode->any_operator_parameters.count("function"));
114  generalized_transform_type fn;
115 
116  flex_list list_output_types = pnode->operator_parameters["output_types"];
117  std::vector<flex_type_enum> output_types;
118  for (auto t: list_output_types) output_types.push_back((flex_type_enum)(flex_int)t);
119 
120  fn = pnode->any_operator_parameters["function"].as<generalized_transform_type>();
121 
122  int random_seed = (int)(flex_int)(pnode->operator_parameters["random_seed"]);
123  return std::make_shared<operator_impl>(fn, output_types, random_seed);
124  }
125 
126  static std::vector<flex_type_enum> infer_type(std::shared_ptr<planner_node> pnode) {
127  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::GENERALIZED_TRANSFORM_NODE);
128  ASSERT_TRUE(pnode->operator_parameters.count("output_types"));
129 
130  flex_list outtypes = pnode->operator_parameters["output_types"];
131  std::vector<flex_type_enum> ret;
132  for (auto t: outtypes) ret.push_back((flex_type_enum)(flex_int)t);
133  return ret;
134  }
135 
136  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
137  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::GENERALIZED_TRANSFORM_NODE);
138  return infer_planner_node_length(pnode->inputs[0]);
139  }
140 
141  static std::string repr(std::shared_ptr<planner_node> pnode, pnode_tagger&) {
142  size_t n_outs = infer_length(pnode);
143 
144  switch(n_outs) {
145  case 1: return "Tr->[C0]";
146  case 2: return "Tr->[C0,C1]";
147  case 3: return "Tr->[C0,C1,C2]";
148  default: {
149  std::ostringstream out;
150  out << "Tr->[C0,...,C" << (n_outs - 1) << "]";
151  return out.str();
152  }
153  }
154  }
155 
156  private:
157  generalized_transform_type m_transform_fn;
158  std::vector<flex_type_enum> m_output_types;
159  int m_random_seed;
160 };
161 
163 
164 /// \}
165 } // query_eval
166 } // turi
167 
168 #endif // TURI_SFRAME_QUERY_MANAGER_GENERALIZED_TRANSFORM_HPP
int64_t infer_planner_node_length(std::shared_ptr< planner_node > pnode)
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
void seed()
Seed the generator using the default seed.
Definition: random.hpp:101
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
generator & get_source()
static size_t thread_id()
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())
std::vector< flexible_type > flex_list