Turi Create  4.0
transform.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_TRANSFORM_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_TRANSFORM_HPP
8 #include <core/data/flexible_type/flexible_type.hpp>
9 #include <core/random/random.hpp>
10 #include <core/parallel/pthread_tools.hpp>
11 #include <core/storage/query_engine/operators/operator.hpp>
12 #include <core/storage/query_engine/execution/query_context.hpp>
13 #include <core/storage/query_engine/operators/operator_properties.hpp>
14 #include <core/util/coro.hpp>
15 namespace turi {
16 namespace query_eval {
17 
18 typedef std::function<flexible_type(const sframe_rows::row&)> transform_type;
19 
20 /**
21  * \ingroup sframe_query_engine
22  * \addtogroup operators Logical Operators
23  * \{
24  */
25 
26 /**
27  * A "transform" operator applys a transform function on a
28  * stream of input.
29  */
30 template<>
31 class operator_impl<planner_node_type::TRANSFORM_NODE> : public query_operator {
32  public:
33  DECL_CORO_STATE(execute);
34  planner_node_type type() const { return planner_node_type::TRANSFORM_NODE; }
35 
36  static std::string name() { return "transform"; }
37 
38  static query_operator_attributes attributes() {
40  ret.attribute_bitfield = query_operator_attributes::LINEAR;
41  ret.num_inputs = 1;
42  return ret;
43  }
44 
45  ////////////////////////////////////////////////////////////////////////////////
46 
47  inline operator_impl(const transform_type& f,
48  flex_type_enum output_type,
49  size_t random_seed=size_t(-1))
50  : m_transform_fn(f), m_output_type(output_type), m_random_seed(random_seed)
51  { }
52 
53  inline std::shared_ptr<query_operator> clone() const {
54  return std::make_shared<operator_impl>(*this);
55  }
56 
57  inline bool coro_running() const {
58  return CORO_RUNNING(execute);
59  }
60  inline void execute(query_context& context) {
61  CORO_BEGIN(execute)
62  if (m_random_seed != size_t(-1)){
63  random::get_source().seed(m_random_seed + thread::thread_id());
64  }
65  while(1) {
66  {
67  auto rows = context.get_next(0);
68  if (rows == nullptr)
69  break;
70  auto output = context.get_output_buffer();
71  output->resize(1, rows->num_rows());
72 
73  auto iter = rows->cbegin();
74  auto output_iter = output->begin();
75  while(iter != rows->cend()) {
76  auto outval = m_transform_fn((*iter));
77  if (m_output_type == flex_type_enum::UNDEFINED ||
78  outval.get_type() == m_output_type ||
79  outval.get_type() == flex_type_enum::UNDEFINED) {
80  (*output_iter)[0] = outval;
81  } else {
82  flexible_type f(m_output_type);
83  f.soft_assign(outval);
84  (*output_iter)[0] = f;
85  }
86  ++output_iter;
87  ++iter;
88  }
89  context.emit(output);
90  }
91  CORO_YIELD();
92  }
93  CORO_END
94  }
95 
96  static std::shared_ptr<planner_node> make_planner_node(
97  std::shared_ptr<planner_node> source,
98  transform_type fn,
99  flex_type_enum output_type,
100  size_t random_seed=size_t(-1)) {
101  return planner_node::make_shared(planner_node_type::TRANSFORM_NODE,
102  {{"output_type", (int)(output_type)},
103  {"random_seed", random_seed}},
104  {{"function", any(fn)}},
105  {source});
106  }
107 
108  static std::shared_ptr<query_operator> from_planner_node(
109  std::shared_ptr<planner_node> pnode) {
110  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::TRANSFORM_NODE);
111  ASSERT_EQ(pnode->inputs.size(), 1);
112  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
113  ASSERT_TRUE(pnode->any_operator_parameters.count("function"));
114  transform_type fn;
115  flex_type_enum output_type =
116  (flex_type_enum)(flex_int)(pnode->operator_parameters["output_type"]);
117  fn = pnode->any_operator_parameters["function"].as<transform_type>();
118  int random_seed = (int)(flex_int)(pnode->operator_parameters["random_seed"]);
119  return std::make_shared<operator_impl>(fn, output_type, random_seed);
120  }
121 
122  static std::vector<flex_type_enum> infer_type(std::shared_ptr<planner_node> pnode) {
123  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::TRANSFORM_NODE);
124  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
125  return {(flex_type_enum)(int)(pnode->operator_parameters["output_type"])};
126  }
127 
128  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
129  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::TRANSFORM_NODE);
130  return infer_planner_node_length(pnode->inputs[0]);
131  }
132 
133  private:
134  transform_type m_transform_fn;
135  flex_type_enum m_output_type;
136  size_t m_random_seed;
137 };
138 
140 
141 /// \}
142 } // query_eval
143 } // turicreate
144 
145 #endif // TURI_SFRAME_QUERY_MANAGER_TRANSFORM_HPP
int64_t infer_planner_node_length(std::shared_ptr< planner_node > pnode)
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
void seed()
Seed the generator using the default seed.
Definition: random.hpp:101
flexible_type & soft_assign(const flexible_type &other)
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
generator & get_source()
static size_t thread_id()
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())