Turi Create  4.0
serialize_eigen.hpp
1 /**
2  * Copyright (C) 2016 Turi
3  * All rights reserved.
4  *
5  * This software may be modified and distributed under the terms
6  * of the BSD license. See the LICENSE file for details.
7  */
8 
9 /**
10  * Copyright (c) 2009 Carnegie Mellon University.
11  * All rights reserved.
12  *
13  * Licensed under the Apache License, Version 2.0 (the "License");
14  * you may not use this file except in compliance with the License.
15  * You may obtain a copy of the License at
16  *
17  * http://www.apache.org/licenses/LICENSE-2.0
18  *
19  * Unless required by applicable law or agreed to in writing,
20  * software distributed under the License is distributed on an "AS
21  * IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
22  * express or implied. See the License for the specific language
23  * governing permissions and limitations under the License.
24  *
25  * For more about this software visit:
26  *
27  * http://www.turicreate.ml.cmu.edu
28  *
29  */
30 
31 #ifndef EIGEN_SERIALIZATION_HPP
32 #define EIGEN_SERIALIZATION_HPP
33 
34 #include <core/storage/serialization/serialization_includes.hpp>
35 #include <core/logging/assertions.hpp>
36 #include <typeinfo>
37 
38 ////////////////////////////////////////////////////////////////////////////////
39 // Some macros to tame the template gymnastics below
40 
// _EIGEN_TEMPLATE_ARGS(p) expands to the six template *parameters* shared by
// Eigen::Matrix and Eigen::Array, each name prefixed with `p`
// (prefix `_` gives: typename _Scalar, int _Rows, int _Cols, int _Options,
// int _MaxRows, int _MaxCols).
#define _EIGEN_TEMPLATE_ARGS(prefix)   \
  typename prefix##Scalar,             \
  int prefix##Rows,                    \
  int prefix##Cols,                    \
  int prefix##Options,                 \
  int prefix##MaxRows,                 \
  int prefix##MaxCols

// _EIGEN_TEMPLATE_ARG_DEF(p) expands to the matching template *argument*
// list (the same six names, without the typename/int keywords), suitable for
// writing e.g. Eigen::Matrix<_EIGEN_TEMPLATE_ARG_DEF(_)>.
#define _EIGEN_TEMPLATE_ARG_DEF(prefix) \
  prefix##Scalar,                       \
  prefix##Rows,                         \
  prefix##Cols,                         \
  prefix##Options,                      \
  prefix##MaxRows,                      \
  prefix##MaxCols
50 
51 
52 ////////////////////////////////////////////////////////////////////////////////
53 // Forward declare the eigen classes so they don't have to be linked
54 // in with the serialization libraries.
55 
// Declarations only: the serializers below name these class templates but
// never need their definitions, so Eigen itself does not have to be included.
// No default template arguments are given here.
namespace Eigen {

template <typename Scalar_, int Rows_, int Cols_,
          int Options_, int MaxRows_, int MaxCols_>
class Matrix;

template <typename Scalar_, int Rows_, int Cols_,
          int Options_, int MaxRows_, int MaxCols_>
class Array;

template <typename Scalar_, int Flags_, typename Index_>
class SparseVector;

}  // namespace Eigen
64 
65 namespace turi { namespace archive_detail {
66 
67 ////////////////////////////////////////////////////////////////////////////////
68 // The main serializer function. Used for both array and matrix.
69 
70 template <typename InArcType,
71  template <_EIGEN_TEMPLATE_ARGS(__)> class EigenContainer,
72  _EIGEN_TEMPLATE_ARGS(_)>
73 void eigen_serialize_impl(InArcType& arc, const EigenContainer<_EIGEN_TEMPLATE_ARG_DEF(_)>& X) {
74 
75  // This code does type checking for making sure the correct types
76  // are loaded. However, it's not backwards compatible, so we'll
77  // instead just do the simplest thing possible (that is backwards
78  // compatible). Thus this code is commented out for now.
79 
80  // size_t version = 1;
81  // arc << version;
82 
83  // // An id of the scalar type.
84  // std::string scalar_type_id = typeid(_Scalar).name();
85  // arc << scalar_type_id;
86 
87  // arc << _Rows << _Cols << _Options << _MaxRows << _MaxCols;
88 
89  typedef typename EigenContainer<_EIGEN_TEMPLATE_ARG_DEF(_)>::Index index_type;
90 
91  const index_type rows = X.rows();
92  const index_type cols = X.cols();
93 
94  arc << rows << cols;
95 
96  turi::serialize(arc, X.data(), rows*cols*sizeof(_Scalar));
97 }
98 
99 ////////////////////////////////////////////////////////////////////////////////
100 // The main deserializer function. Used for both array and matrix.
101 
102 template <typename InArcType,
103  template <_EIGEN_TEMPLATE_ARGS(__)> class EigenContainer,
104  _EIGEN_TEMPLATE_ARGS(_)>
105 void eigen_deserialize_impl(InArcType& arc, EigenContainer<_EIGEN_TEMPLATE_ARG_DEF(_)>& X) {
106 
107 
108  // This code does type checking for making sure the correct types
109  // are loaded. However, it's not backwards compatible, so we'll
110  // instead just do the simplest thing possible (that is backwards
111  // compatible). Thus this code is commented out for now.
112 
113  // size_t version;
114  // arc >> version;
115  // ASSERT_EQ(version, 1);
116 
117  // std::string scalar_type_id;
118 
119  // arc >> scalar_type_id;
120 
121  // ASSERT_MSG(scalar_type_id == typeid(_Scalar).name(),
122  // "Attempt to load Eigen matrix from conflicting type.");
123 
124  // int __Rows, __Cols, __Options, __MaxRows, __MaxCols;
125 
126  // arc >> __Rows >> __Cols >> __Options >> __MaxRows >> __MaxCols;
127 
128  // // Right now, only really need to check the options; the rest
129  // // are going to be.
130  // ASSERT_MSG(__Options == _Options,
131  // "Eigen interanl storage options not matched on load.");
132 
133  typedef typename EigenContainer<_EIGEN_TEMPLATE_ARG_DEF(_)>::Index index_type;
134 
135  index_type rows, cols;
136 
137  arc >> rows >> cols;
138 
139  X.resize(rows,cols);
140  turi::deserialize(arc, X.data(), rows*cols*sizeof(_Scalar));
141 
142 }
143 
144 /////////////////////////////////////////////////////////////////////////////////
145 //
146 // The matrix class
147 
148 template <typename InArcType, _EIGEN_TEMPLATE_ARGS(_)>
149 struct deserialize_impl<InArcType, Eigen::Matrix<_EIGEN_TEMPLATE_ARG_DEF(_)>, false> {
150 
151  static void exec(InArcType& arc, Eigen::Matrix<_EIGEN_TEMPLATE_ARG_DEF(_)>& X) {
152  eigen_deserialize_impl(arc, X);
153  }
154 };
155 
156 template <typename OutArcType, _EIGEN_TEMPLATE_ARGS(_)>
157 struct serialize_impl<OutArcType, Eigen::Matrix<_EIGEN_TEMPLATE_ARG_DEF(_)>, false> {
158 
159  static void exec(OutArcType& arc, const Eigen::Matrix<_EIGEN_TEMPLATE_ARG_DEF(_)>& X) {
160  eigen_serialize_impl(arc, X);
161  }
162 };
163 
164 /////////////////////////////////////////////////////////////////////////////////
165 //
166 // The array class
167 
168 template <typename InArcType, _EIGEN_TEMPLATE_ARGS(_)>
169 struct deserialize_impl<InArcType, Eigen::Array<_EIGEN_TEMPLATE_ARG_DEF(_)>, false> {
170 
171  static void exec(InArcType& arc, Eigen::Array<_EIGEN_TEMPLATE_ARG_DEF(_)>& X) {
172  eigen_deserialize_impl(arc, X);
173  }
174 };
175 
176 template <typename OutArcType, _EIGEN_TEMPLATE_ARGS(_)>
177 struct serialize_impl<OutArcType, Eigen::Array<_EIGEN_TEMPLATE_ARG_DEF(_)>, false> {
178 
179  static void exec(OutArcType& arc, const Eigen::Array<_EIGEN_TEMPLATE_ARG_DEF(_)>& X) {
180  eigen_serialize_impl(arc, X);
181  }
182 };
183 
184 /////////////////////////////////////////////////////////////////////////////////
185 //
186 // The SparseVector class
187 
188 template <typename InArcType, typename _Scalar, int _Flags, typename _Index>
189 struct deserialize_impl<InArcType, Eigen::SparseVector<_Scalar, _Flags, _Index>, false> {
190 
191  static void exec(InArcType& arc, Eigen::SparseVector<_Scalar, _Flags, _Index>& vec) {
192  size_t version;
193  arc >> version;
194 
195  ASSERT_EQ(version, 1);
196 
197  size_t _size, _nnz, index;
198  double value;
199 
200  arc >> _size;
201  vec.resize(_size);
202 
203  arc >> _nnz;
204  vec.reserve(_nnz);
205 
206  for(size_t i = 0; i < _nnz; i++) {
207  arc >> index;
208  arc >> value;
209  vec.coeffRef(index) = value;
210  }
211  }
212 };
213 
214 template <typename OutArcType, typename _Scalar, int _Flags, typename _Index>
215 struct serialize_impl<OutArcType, Eigen::SparseVector<_Scalar, _Flags, _Index>, false> {
216 
217  static void exec(OutArcType& arc, const Eigen::SparseVector<_Scalar, _Flags, _Index>& vec) {
218  size_t version = 1;
219 
220  arc << version;
221 
222  arc << (size_t)vec.size() << (size_t)vec.nonZeros();
223 
224  for (typename Eigen::SparseVector<_Scalar, _Flags, _Index>::InnerIterator i(vec); i; ++i) {
225  arc << (size_t)i.index() << (double)i.value();
226  }
227  }
228 };
229 
230 }}
231 
232 #undef _EIGEN_TEMPLATE_ARGS
233 #undef _EIGEN_TEMPLATE_ARG_DEF
234 
235 #endif