Turi Create  4.0
groupby.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef __TC_VIS_GROUPBY
7 #define __TC_VIS_GROUPBY
8 
9 #include <core/storage/sframe_data/groupby_aggregate_operators.hpp>
10 #include <core/data/sframe/gl_sframe.hpp>
11 #include <core/util/sys_util.hpp>
12 
13 #include "transformation.hpp"
14 
15 namespace turi {
16 namespace visualization {
17 
18 class summary_stats {
19  private:
20  groupby_operators::average m_average;
21  groupby_operators::count m_count;
22  groupby_operators::max m_max;
23  groupby_operators::min m_min;
24  groupby_operators::sum m_sum;
25  groupby_operators::stdv m_stdv;
26  groupby_operators::variance m_variance;
27 
28  public:
29  void add_element_simple(const flexible_type& value);
30  void combine(const summary_stats& other);
31  void partial_finalize();
32  flexible_type emit() const;
33  void set_input_type(flex_type_enum type);
34 };
35 
36 // Intended for boxes and whiskers or bar chart (bivariate plot, categorical
37 // vs. numeric). For now, just groups by one column (x), doing aggregation per
38 // category on a second column (y). Limited to the first n categories
39 // encountered in the x column.
40 // TODO -- pick the limited set of categories intelligently (n most popular
41 // rather than n first)
42 template<typename Aggregation>
43 class groupby_result {
44  protected:
45  // keeps track of one aggregator per category (unique value on first column)
46  std::unordered_map<flexible_type, Aggregation> m_aggregators;
47 
48  virtual void insert_category(const flexible_type& category) {
49  TURI_ATTRIBUTE_UNUSED_NDEBUG auto inserted =
50  m_aggregators.emplace(category, Aggregation());
51  DASSERT_TRUE(inserted.second); // emplace should succeed
52  auto& agg = m_aggregators.at(category);
54  agg.set_input_type(m_type);
55  }
56 
57  private:
58  constexpr static size_t CATEGORY_LIMIT = 1000;
59  flex_int m_omitted_categories = 0;
61 
62  static void update_or_combine(Aggregation& aggregation, const flexible_type& other) {
63  aggregation.add_element_simple(other);
64  }
65  static void update_or_combine(Aggregation& aggregation, const Aggregation& other) {
66  // TODO this is bad -- we need a non-const Aggregation in order to call
67  // partial_finalize, but this parameter is deeply const.
68  const_cast<Aggregation&>(other).partial_finalize();
69  aggregation.combine(other);
70  }
71 
72  protected:
73  template<typename T>
74  void update_or_combine(const flexible_type& category, const T& value) {
75  auto find_key = m_aggregators.find(category);
76  if (find_key == m_aggregators.end()) {
77  // insert new category if there is room
78  if (m_aggregators.size() < CATEGORY_LIMIT) {
79  this->insert_category(category);
80  groupby_result::update_or_combine(m_aggregators.at(category), value);
81  } else {
82  m_omitted_categories++;
83  }
84  } else {
85  groupby_result::update_or_combine((*find_key).second, value);
86  }
87  }
88  void update(const flexible_type& category, const flexible_type& value) {
89  const flex_type_enum type = value.get_type();
90  if (type == flex_type_enum::UNDEFINED) {
91  return; // ignore undefined values, they don't make sense in groupby
92  }
93  this->set_input_type(type);
94  this->update_or_combine(category, value);
95  }
96 
97  public:
98  void combine(const groupby_result<Aggregation>& other) {
99  this->set_input_type(other.get_input_type());
100  for (const auto& pair : other.m_aggregators) {
101  this->update_or_combine(pair.first, pair.second);
102  }
103  }
104  void update(const std::vector<flexible_type>& values) {
105  // by convention, values[0] is the grouped column,
106  // and values[1] is the aggregated column
107  DASSERT_GE(values.size(), 2);
108  this->update(values[0], values[1]);
109  }
110  std::unordered_map<flexible_type, flexible_type> get_grouped() const {
111  std::unordered_map<flexible_type, flexible_type> ret;
112  for (const auto& pair : m_aggregators) {
113  ret.emplace(pair.first, pair.second.emit());
114  }
115  return ret;
116  }
117  flex_int get_omitted() { return m_omitted_categories; }
118  void set_input_type(flex_type_enum type) {
119  if (m_type == flex_type_enum::UNDEFINED) {
120  m_type = type;
121  } else {
122  DASSERT_TRUE(m_type == type);
123  }
124  }
125  flex_type_enum get_input_type() const {
126  return m_type;
127  }
128  void add_element_simple(const flexible_type& value) {
129  DASSERT_TRUE(value.get_type() == flex_type_enum::LIST);
130  this->update(value.get<flex_list>());
131  }
132 };
133 
134 template<typename Result>
135 class groupby : public transformation<gl_sframe, Result> {
136  protected:
137  virtual void merge_results(std::vector<Result>& transformers) override {
138  for (auto& result : transformers) {
139  this->m_transformer->combine(result);
140  }
141  }
142 };
143 
144 class groupby_summary_result : public groupby_result<summary_stats> {
145 };
146 
147 class groupby_summary : public groupby<groupby_summary_result> {
148 };
149 
150 class groupby_quantile_result : public groupby_result<groupby_operators::quantile> {
151  public:
152  virtual void insert_category(const flexible_type& category) override;
153 };
154 
155 class groupby_quantile : public groupby<groupby_quantile_result> {
156 };
157 
158 }}
159 
160 #endif // __TC_VIS_GROUPBY
std::set< T > values(const std::map< Key, T > &map)
Definition: stl_util.hpp:386
std::vector< flexible_type > flex_list
#define DASSERT_TRUE(cond)
Definition: assertions.hpp:364