Turi Create  4.0
spmat.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_UTIL_SPMAT_H_
7 #define TURI_UTIL_SPMAT_H_
8 
9 #include <vector>
10 #include <map>
11 #include <core/random/alias.hpp>
12 #include <Eigen/Core>
13 
14 
15 /**
16  * \ingroup toolkit_util
17  * A simple utility class for representing sparse matrices of counts.
18  * It exposes getting particular elements, incrementing elements by a
19  * value, and removing zero elements from the internal data structure.
20  * It is row-based, so only exposes get_row and num_rows.
21  */
22 class spmat {
23  public:
24 
25  /**
26  * Create a sparse matrix with a fixed number of rows.
27  */
28  spmat(size_t num_rows = 0) {
29  m = std::vector<std::map<size_t, size_t>>(num_rows);
30  }
31 
32  /**
33  * Get a vector of nonzero elements in a single row.
34  */
35  const std::map<size_t, size_t>& get_row(size_t i) {
36  return m[i];
37  }
38 
39  /**
40  * Get the count at element(i,j).
41  */
42  size_t get(const size_t i, const size_t j) {
43  if (m[i].count(j) == 0) {
44  return 0;
45  } else {
46  return m[i].at(j);
47  }
48  }
49 
50  /**
51  * Get the number of rows.
52  */
53  size_t num_rows() const { return m.size(); }
54 
55  /**
56  * Delete zeros in a single row
57  */
58  void trim(const size_t i) {
59  auto it = m[i].begin();
60  for( ; it != m[i].end();) {
61  if (it->second == 0) {
62  it = m[i].erase(it);
63  } else {
64  ++it;
65  }
66  }
67  }
68 
69  /**
70  * Increment the element (a, b) by v.
71  */
72  void increment(const size_t& a, const size_t& b, const size_t& v) {
73  auto it = m[a].find(b);
74  if (it == m[a].end()) {
75  m[a][b] = v;
76  } else {
77  it->second += v;
78  }
79  }
80 
81  /**
82  * Convert to Eigen matrix.
83  */
84  Eigen::MatrixXi as_matrix() {
85  size_t nrows = m.size();
86  size_t ncols = 0;
87  for (size_t i = 0; i < nrows; ++i) {
88  auto row = get_row(i);
89  for (auto it = row.begin(); it != row.end(); ++it) {
90  auto col = it->first;
91  if (col >= ncols) ncols = col+1; // zero-based indexing
92  }
93  }
94  auto ret = Eigen::MatrixXi(nrows, ncols);
95  for (size_t i = 0; i < nrows; ++i) {
96  auto row = get_row(i);
97  for (auto it = row.begin(); it != row.end(); ++it) {
98  auto j = it->first;
99  auto v = it->second;
100  ret(i, j) = static_cast<int>(v);
101  }
102  }
103  return ret;
104  }
105 
106  private:
107  std::vector<std::map<size_t, size_t>> m;
108 };
109 
110 #endif
size_t num_rows() const
Definition: spmat.hpp:53
spmat(size_t num_rows=0)
Definition: spmat.hpp:28
void trim(const size_t i)
Definition: spmat.hpp:58
const std::map< size_t, size_t > & get_row(size_t i)
Definition: spmat.hpp:35
Eigen::MatrixXi as_matrix()
Definition: spmat.hpp:84
void increment(const size_t &a, const size_t &b, const size_t &v)
Definition: spmat.hpp:72
Definition: spmat.hpp:22