Turi Create  4.0
rolling_aggregate.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_SFRAME_ROLLING_AGGREGATE_HPP
7 #define TURI_SFRAME_ROLLING_AGGREGATE_HPP
8 
9 #include <core/data/flexible_type/flexible_type.hpp>
10 #include <core/storage/sframe_data/sarray.hpp>
11 #include <core/storage/sframe_data/groupby_aggregate_operators.hpp>
12 
13 namespace turi {
14 
15 /**
16  * \ingroup sframe_physical
17  * \addtogroup groupby_aggregate Groupby Aggregation
18  * \{
19  */
20 
21 /**
22  * Rolling window aggregators
23  */
24 namespace rolling_aggregate {
25 
26 typedef boost::circular_buffer<flexible_type>::iterator circ_buffer_iterator_t;
27 typedef std::function<flexible_type(circ_buffer_iterator_t,circ_buffer_iterator_t)> full_window_fn_type_t;
28 
29 /**
30  * Apply an aggregate function over a moving window.
31  *
32  * \param input The input SArray (expects to be materialized)
33  * \param agg_op The aggregator. These classes are the same as used by groupby.
34  * \param window_start The start of the moving window relative to the current
35  * value being calculated, inclusive. For example, 2 values behind the current
36  * would be -2, and 0 indicates that the start of the window is the current
37  * value.
38  * \param window_end The end of the moving window relative to the current value
39  * being calculated, inclusive. Must be greater than `window_start`. For
40  * example, 0 would indicate that the current value is the end of the window,
41  * and 2 would indicate that the window ends at 2 data values after the
42  * current.
43  * \param min_observations The minimum allowed number of non-NULL values in the
44  * moving window for the emitted value to be non-NULL. size_t(-1) indicates
45  * that all values must be non-NULL.
46  *
47  * Returns an SArray of the same length as the input, with a type that matches
48  * the type output by the aggregation function.
49  *
50  * Throws an exception if:
51  * - window_end < window_start
52  * - The window size is excessively large (currently hardcoded to UINT_MAX).
53  * - The given function name corresponds to a function that will not operate
54  * on the data type of the input SArray.
55  * - The aggregation function returns more than one non-NULL types.
56  */
57 std::shared_ptr<sarray<flexible_type>> rolling_apply(
58  const sarray<flexible_type> &input,
59  std::shared_ptr<group_aggregate_value> agg_op,
60  ssize_t window_start,
61  ssize_t window_end,
62  size_t min_observations);
63 
64 
65 /// Aggregate functions
66 template<typename Iterator>
67 flexible_type full_window_aggregate(std::shared_ptr<group_aggregate_value> agg_op,
68  Iterator first, Iterator last) {
69  auto agg = agg_op->new_instance();
70  for(; first != last; ++first) {
71  agg->add_element_simple(*first);
72  }
73 
74  return agg->emit();
75 }
76 
77 /**
78  * Scans the current window to check for the number of non-NULL values.
79  *
80  * Returns true if the number of non-NULL values is >= min_observations, false
81  * otherwise.
82  */
83 template<typename Iterator>
84 bool has_min_observations(size_t min_observations,
85  Iterator first,
86  Iterator last) {
87  size_t observations = 0;
88  size_t count = 0;
89  bool need_all = (min_observations == size_t(-1));
90  for(; first != last; ++first, ++count) {
91  if(first->get_type() != flex_type_enum::UNDEFINED) {
92  ++observations;
93  if(!need_all && (observations >= min_observations)) {
94  return true;
95  }
96  }
97  }
98 
99  if(need_all)
100  return (observations == count);
101 
102  return false;
103 }
104 
105 } // namespace rolling_aggregate
106 } // namespace turi
107 #endif // TURI_SFRAME_ROLLING_AGGREGATE_HPP
bool has_min_observations(size_t min_observations, Iterator first, Iterator last)
std::shared_ptr< sarray< flexible_type > > rolling_apply(const sarray< flexible_type > &input, std::shared_ptr< group_aggregate_value > agg_op, ssize_t window_start, ssize_t window_end, size_t min_observations)
flexible_type full_window_aggregate(std::shared_ptr< group_aggregate_value > agg_op, Iterator first, Iterator last)
Aggregate functions.