Turi Create  4.0
lambda_transform.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_QUERY_MANAGER_LAMBDA_TRANSFORM_HPP
7 #define TURI_SFRAME_QUERY_MANAGER_LAMBDA_TRANSFORM_HPP
8 #include <core/data/flexible_type/flexible_type.hpp>
9 #include <core/storage/query_engine/operators/operator.hpp>
10 #include <core/storage/query_engine/execution/query_context.hpp>
11 #include <core/storage/query_engine/operators/operator_properties.hpp>
12 #include <core/system/lambda/pylambda_function.hpp>
13 #include <core/system/exceptions/error_types.hpp>
14 #include <core/util/coro.hpp>
15 namespace turi {
16 namespace query_eval {
17 
18 /**
19  * \ingroup sframe_query_engine
20  * \addtogroup operators Logical Operators
21  * \{
22  */
23 
24 /**
25  * A "transform" operator that applies a python lambda function to a
26  * single stream of input.
27  */
28 template<>
29 class operator_impl<planner_node_type::LAMBDA_TRANSFORM_NODE> : public query_operator {
30 
31  public:
32  DECL_CORO_STATE(execute);
33  planner_node_type type() const { return planner_node_type::LAMBDA_TRANSFORM_NODE; }
34 
35  static std::string name() { return "lambda_transform"; }
36 
37  static query_operator_attributes attributes() {
39  ret.attribute_bitfield = query_operator_attributes::LINEAR;
40  ret.num_inputs = 1;
41  return ret;
42  }
43 
44  ////////////////////////////////////////////////////////////////////////////////
45  inline operator_impl(std::shared_ptr<lambda::pylambda_function> lambda,
46  flex_type_enum output_type,
47  const std::vector<std::string>& column_names = {})
48  : m_lambda(lambda), m_output_type(output_type),
49  m_column_names(column_names) { }
50 
51  inline std::shared_ptr<query_operator> clone() const {
52  return std::make_shared<operator_impl>(*this);
53  }
54 
55  inline bool coro_running() const {
56  return CORO_RUNNING(execute);
57  }
58  inline void execute(query_context& context) {
59  CORO_BEGIN(execute)
60  while(1) {
61  {
62  auto rows = context.get_next(0);
63  if (rows == nullptr)
64  break;
65 
66  auto output = context.get_output_buffer();
67  output->resize(1, rows->num_rows());
68  std::vector<flexible_type> out;
69 
70  // TODO exception handling
71  if (m_column_names.empty()) {
72  // evalute on sarray
73  m_lambda->eval(*rows, out);
74  } else {
75  // need column names to evalute on sframe
76  m_lambda->eval(m_column_names, *rows, out);
77  }
78 
79  for (size_t i = 0;i < out.size(); ++i) {
80  (*output)[i][0] = convert_value_to_output_type(out[i], m_output_type);
81  }
82  context.emit(output);
83  }
84  CORO_YIELD();
85  }
86  CORO_END
87  }
88 
89  static std::shared_ptr<planner_node> make_planner_node(
90  std::shared_ptr<planner_node> source,
91  const std::string& lambda_str,
92  flex_type_enum output_type,
93  const std::vector<std::string> column_names = {},
94  bool skip_undefined = false,
95  size_t random_seed = size_t(-1)) {
96 
97  flex_list column_names_list(column_names.begin(), column_names.end());
98  auto lambda_function = std::make_shared<lambda::pylambda_function>(lambda_str);
99  lambda_function->set_skip_undefined(skip_undefined);
100  lambda_function->set_random_seed(random_seed);
101  return planner_node::make_shared(planner_node_type::LAMBDA_TRANSFORM_NODE,
102  {{"output_type", (int)(output_type)},
103  {"lambda_str", lambda_str},
104  {"skip_undefined", (int)(skip_undefined)},
105  {"random_seed", (int)(random_seed)},
106  {"column_names", column_names_list}},
107  {{"lambda_fn", any(lambda_function)}},
108  {source});
109  }
110 
111  static std::shared_ptr<query_operator> from_planner_node(
112  std::shared_ptr<planner_node> pnode) {
113  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::LAMBDA_TRANSFORM_NODE);
114  ASSERT_EQ(pnode->inputs.size(), 1);
115  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
116  ASSERT_TRUE(pnode->operator_parameters.count("lambda_str"));
117  ASSERT_TRUE(pnode->operator_parameters.count("column_names"));
118  ASSERT_TRUE(pnode->operator_parameters.count("skip_undefined"));
119  ASSERT_TRUE(pnode->operator_parameters.count("random_seed"));
120  ASSERT_TRUE(pnode->any_operator_parameters.count("lambda_fn"));
121 
122  flex_type_enum output_type =
123  (flex_type_enum)(flex_int)(pnode->operator_parameters["output_type"]);
124  flex_list column_names_list =
125  (pnode->operator_parameters["column_names"]).get<flex_list>();
126  std::vector<std::string> column_names(column_names_list.begin(), column_names_list.end());
127 
128  auto fn = pnode->any_operator_parameters["lambda_fn"]
129  .as<std::shared_ptr<lambda::pylambda_function>>();
130  return std::make_shared<operator_impl>(fn, output_type, column_names);
131  }
132 
133  static std::vector<flex_type_enum> infer_type(std::shared_ptr<planner_node> pnode) {
134  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::LAMBDA_TRANSFORM_NODE);
135  ASSERT_TRUE(pnode->operator_parameters.count("output_type"));
136  return {(flex_type_enum)(int)(pnode->operator_parameters["output_type"])};
137  }
138 
139  static int64_t infer_length(std::shared_ptr<planner_node> pnode) {
140  ASSERT_EQ((int)pnode->operator_type, (int)planner_node_type::LAMBDA_TRANSFORM_NODE);
141  return infer_planner_node_length(pnode->inputs[0]);
142  }
143 
144  static std::string repr(std::shared_ptr<planner_node> pnode, pnode_tagger&) {
145  std::ostringstream out;
146  out << "PyLambda";
147 
148  flex_list column_names_list =
149  (pnode->operator_parameters["column_names"]).get<flex_list>();
150 
151  if(!column_names_list.empty()) {
152  out << "(";
153 
154  for(size_t i = 0; i < column_names_list.size() - 1; ++i) {
155  out << column_names_list[i] << ",";
156  }
157 
158  out << column_names_list.back() << ")";
159  }
160 
161  return out.str();
162  }
163 
164  private:
165  std::shared_ptr<lambda::pylambda_function> m_lambda;
166  flex_type_enum m_output_type;
167  std::vector<std::string> m_column_names;
168 
169  /**
170  * Helper function to convert flexible_type value to expected type.
171  */
172  static flexible_type convert_value_to_output_type(const flexible_type& val,
173  flex_type_enum type) {
174  if (val.get_type() == type ||
176  type == flex_type_enum::UNDEFINED) {
177  return val;
178  } else if (flex_type_is_convertible(val.get_type(), type)) {
179  flexible_type res(type);
180  res.soft_assign(val);
181  return res;
182  } else if ( (val.get_type() == flex_type_enum::VECTOR &&
183  type == flex_type_enum::LIST)
184  || (val.get_type() == flex_type_enum::LIST &&
185  type == flex_type_enum::VECTOR)) {
186  // empty lists / vectors cast between each other.
187  flexible_type res(type);
188  res.soft_assign(val);
189  return res;
190  } else {
191  std::string message = "Cannot convert " + std::string(val) +
192  " of type " + flex_type_enum_to_name(val.get_type()) +
193  " to " + flex_type_enum_to_name(type) + " which is the type of the SArray." +
194  " If this is not the correct type try specifying it during creation.";
195  logstream(LOG_ERROR) << message << std::endl;
196  throw(bad_cast(message));
197  }
198  }
199 };
200 
202 
203 /// \}
204 } // query_eval
205 } // turicreate
206 
207 #endif // TURI_SFRAME_QUERY_MANAGER_LAMBDA_TRANSFORM_HPP
#define logstream(lvl)
Definition: logger.hpp:276
int64_t infer_planner_node_length(std::shared_ptr< planner_node > pnode)
int num_inputs
Number of inputs expected to the operator.
Definition: operator.hpp:56
flexible_type & soft_assign(const flexible_type &other)
size_t attribute_bitfield
A bitfield of the attribute enum.
Definition: operator.hpp:55
const char * flex_type_enum_to_name(flex_type_enum en)
std::shared_ptr< const sframe_rows > get_next(size_t input_number)
void emit(const std::shared_ptr< sframe_rows > &rows)
std::shared_ptr< sframe_rows > get_output_buffer()
#define LOG_ERROR
Definition: logger.hpp:97
bool flex_type_is_convertible(flex_type_enum from, flex_type_enum to)
flex_type_enum get_type() const
#define ASSERT_TRUE(cond)
Definition: assertions.hpp:309
static std::shared_ptr< planner_node > make_shared(planner_node_type operator_type, const std::map< std::string, flexible_type > &operator_parameters=std::map< std::string, flexible_type >(), const std::map< std::string, any > &any_operator_parameters=std::map< std::string, any >(), const std::vector< std::shared_ptr< planner_node >> &inputs=std::vector< std::shared_ptr< planner_node >>())
std::vector< flexible_type > flex_list