Turi Create  4.0
alias.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_TEXT_ALIAS_H_
7 #define TURI_TEXT_ALIAS_H_
8 
9 #include <vector>
10 #include <core/random/alias.hpp>
11 #include <toolkits/util/spmat.hpp>
12 #include <toolkits/text/topic_model.hpp>
13 #include <core/export.hpp>
14 #include <toolkits/text/topic_model.hpp>
15 
/**
TODO:
- Replace spmat with flex_dict for each document.
- Parallelize over documents. Trim zeros at the end of each sample.
- Try map, hopscotch_map for spmat.
- Change the predict_counts and sample_counts API so it can handle
  both the training set and the validation set? That way we aren't
  starting from scratch on the validation set.
- Combine all the word alias computation into one method.
- Use Eigen vectors instead of matrices with one row.
- Make sure to use const auto& where appropriate.
- Choose whether to use w, s, t, d, psdw, etc.
- Track the MH acceptance ratio.
- Consider using row-order Eigen matrices and checking for speedup
  (at least for CGS word_topic_counts?)
 */
32 
33 namespace turi {
34 
35 namespace text {
36 
37 
/**
 *
 * The basic pseudocode for the AliasLDA method is as follows:
 *
 * \code
 * initialize n_{t,w}
 * for w in vocab:
 *   compute q_w(t) for all t
 *   compute Q_w = sum_t q_w(t)
 *   A = GenerateAlias(q_w, K)
 *   for k = 1:K
 *     S_w.push(SampleAlias(A, K))
 *   store q_w(t), Q_w, S_w
 *
 * for d in docs:
 *   for i in len(d):
 *     w = i'th word in d
 *     s = current topic for w in doc d
 *     decrement n_{s,d} and n_{s,w} by 1
 *     for z where n_{z,d} != 0
 *       compute p_dw(z)
 *     compute P_dw
 *     t = sample from q(t) by popping from S_w
 *     if S_w empty:
 *       Recompute A and populate S_w
 *       Recompute q_w(t), Q_w
 *     compute pi
 *     if not rand(1) < min(1, pi)
 *       t = s
 *     increment n_{t,d} and n_{t,w} by 1
 * \endcode
 *
 * Here pi is the Metropolis-Hastings acceptance ratio for moving the token
 * from its current topic s to the proposed topic t.
 */
68 
69 
70 class EXPORT alias_topic_model : public topic_model {
71 
72  public:
73 
74  static constexpr size_t ALIAS_TOPIC_MODEL_VERSION = 1;
75 
76  /**
77  * Destructor. Make sure bad things don't happen
78  */
80 
81  /**
82  * Clone objects to a topic_model class
83  */
84  topic_model* topic_model_clone() override;
85 
86  /**
87  * Set the model options. Use the option manager to set these options. The
88  * option manager should throw errors if the options do not satisfy the option
89  * manager's conditions.
90  *
91  * \param[in] opts Options to set
92  */
93  void init_options(const std::map<std::string,flexible_type>& _opts) override;
94 
95  inline size_t get_version() const override {
96  return ALIAS_TOPIC_MODEL_VERSION;
97  }
98 
99  /**
100  * Turi serialization save
101  */
102  void save_impl(turi::oarchive& oarc) const override;
103 
104  /**
105  * Turi serialization save
106  */
107  void load_version(turi::iarchive& iarc, size_t version) override;
108 
109 
110  /**
111  * Train the model using the method described in (Li, 2014).
112  *
113  */
114  void train(std::shared_ptr<sarray<flexible_type> > data, bool verbose) override;
115 
116  /**
117  * Use the dataset to create an initial set of topic assignments.
118  * Each element is a vector whose length is the total number of
119  * words in the respective document. If the first word occurs
120  * M times, then the first M elements of this vector are the
121  * latent assignments for that word.
122  * While sampling new assignments, topic_counts and
123  * doc_topic_counts are incremented.
124  */
125  std::shared_ptr<sarray<std::vector<size_t>>>
126  forward_sample(v2::ml_data d);
127 
128  /**
129  * For the given word do the following:
130  * - Compute q_w(t) and Q_w for word w. Stores this in members q and Q.
131  * - Compute the alias datastructures for each word w.
132  * - Fill the cache of topic samples, S_w.
133  */
134  void cache_word_pmf_and_samples(size_t w);
135 
136  /**
137  * Simultaneously iterate through an v2::ml_data object and the sarray of
138  * latent topic assignments. For each instance of a word, resample its topic.
139  */
140  std::map<std::string, size_t> sample_counts(v2::ml_data d, size_t num_blocks);
141 
142  /**
143  * Perform sampling given a block of data d (typically a slice
144  * of an SArray represnted via an ml_data object).
145  */
146  void sample_block(const v2::ml_data& d,
147  std::vector<std::vector<size_t>>& doc_assignments);
148 
149  /**
150  * Sample a new topic for word w in document d.
151  * \param document d
152  * \param word w
153  * \param initial topic s
154  * \param vector of topic probabilities that gets used for sampling
155  *
156  */
157  size_t sample_topic(size_t d, size_t w, size_t s,
158  std::vector<double>& pd);
159 
160  private:
161  std::shared_ptr<sarray<std::vector<size_t>>> assignments;
162 
163  spmat doc_topic_counts;
164  count_vector_type topic_counts;
165 
166  // Initialize counter for tracking sampling statistics.
167  atomic<size_t> token_count;
168 
169  // Word pdf datastructures
170  Eigen::MatrixXd q; // pmf for each word
171  Eigen::MatrixXd Q; // normalizing const for each word
172  std::vector<random::alias_sampler> word_samplers;
173  std::vector<std::vector<size_t>> word_samples;
174 
175  // Constants
176  size_t TARGET_BLOCK_NUM_ELEMENTS = 1000000000/16; // approx 1gb in memory per block
177 
178  public:
179 
180  // TODO: convert interface above to use the extensions methods here
181  BEGIN_CLASS_MEMBER_REGISTRATION("alias_topic_model")
184 
185 };
186 }
187 }
188 #endif
size_t get_version() const override
Definition: alias.hpp:95
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
std::vector< std::string > list_fields()
#define END_CLASS_MEMBER_REGISTRATION
Definition: spmat.hpp:22
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80