Turi Create 4.0
topic_model.hpp
/* Copyright © 2017 Apple Inc. All rights reserved.
 *
 * Use of this source code is governed by a BSD-3-clause license that can
 * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
 */
#ifndef TURI_TEXT_TOPICMODEL_H_
#define TURI_TEXT_TOPICMODEL_H_

// SFrame
#include <core/storage/sframe_data/sarray.hpp>
#include <core/storage/sframe_data/sframe.hpp>

// Other
#include <core/storage/fileio/temp_files.hpp>
#include <iostream>

// Types
#include <model_server/lib/unity_base_types.hpp>
#include <core/util/hash_value.hpp>
#include <model_server/lib/flex_dict_view.hpp>
#include <toolkits/ml_data_2/ml_data.hpp>
#include <toolkits/ml_data_2/metadata.hpp>

// Interfaces
#include <model_server/lib/extensions/ml_model.hpp>

// External
#include <Eigen/Core>
#include <core/export.hpp>

namespace turi {

namespace text {

/**
 * Class for learning topic models of text corpora.
 *
 * Typical use (as seen in cgs.cpp):
 *
 * 1) Create a topic model with a map of options:
 *
 *    topic_model* m = new topic_model(options);
 *
 * 2) Create an ml_data object where words have been assigned integers
 *    to facilitate indexing.
 *
 *    ml_data d = m->create_ml_data_using_metadata(dataset);
 *
 * 3) Initialize the model so that we have the internal parameters needed
 *    for each of the words observed in the dataset.
 *
 *    m->init();
 *
 * Note: Two other actions can be useful after initialization:
 *
 *    set_topics: Loads a set of topics and vocabulary.
 *    set_associations: Loads a set of word-topic assignments.
 *
 */
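
// A minimal end-to-end sketch of the steps above (illustrative only):
// `make_cgs_topic_model()` is a hypothetical factory standing in for however
// the concrete solver subclass (e.g. the collapsed Gibbs sampler declared in
// cgs.hpp) is actually constructed, and the option names simply mirror the
// hyperparameters declared below.
//
//   std::map<std::string, flexible_type> opts = {
//       {"num_topics", 20}, {"alpha", 0.1}, {"beta", 0.1}};
//   std::shared_ptr<topic_model> m = make_cgs_topic_model();  // hypothetical
//   m->init_options(opts);
//   v2::ml_data d = m->create_ml_data_using_metadata(dataset);
//   m->train(dataset, /*verbose=*/true);
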
class EXPORT topic_model : public ml_model_base {

 public:

  typedef Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> count_matrix_type;
  typedef Eigen::Matrix<int, 1, Eigen::Dynamic, Eigen::RowMajor> count_vector_type;

  static constexpr size_t TOPIC_MODEL_VERSION = 1;

 protected:

  // Model options
  size_t num_topics;                      /**< Number of topics to learn. */
  size_t vocab_size;                      /**< Number of words in the vocabulary. */
  size_t num_words;                       /**< Number of words in the corpus. */
  std::map<size_t, size_t> associations;  /**< Fixed word-topic associations. */

  // Hyperparameters
  double alpha;                           /**< Controls smoothing over topics. */
  double beta;                            /**< Controls smoothing over words. */

  // Vocabulary lookup
  std::shared_ptr<v2::ml_metadata> metadata;

  // Statistics
  count_matrix_type word_topic_counts;    /**< Counts for each (word, topic) pair. */

  // State
  bool is_initialized;                    /**< Flag for whether the model is ready. */
  bool option_info_set;

  // Validation data
  std::shared_ptr<sarray<flexible_type>> validation_train;
  std::shared_ptr<sarray<flexible_type>> validation_test;

  /**
   * Methods that must be implemented in a new topic model
   * ---------------------------------------------------------------------------
   */
 public:

  /**
   * Clone the object into a new topic_model.
   *
   * \returns A new model containing the same state as this one.
   *
   * \ref model_base for details.
   */
  virtual topic_model* topic_model_clone() = 0;

  /**
   * Set the model options. Use the option manager to set these options. The
   * option manager should throw errors if the options do not satisfy its
   * conditions.
   *
   * \param[in] _opts Options to set.
   */
  virtual void init_options(const std::map<std::string,flexible_type>& _opts) override = 0;

  /**
   * Gets the model version number.
   */
  virtual size_t get_version() const override = 0;

  /**
   * Serialize the model object.
   */
  virtual void save_impl(turi::oarchive& oarc) const override = 0;

  /**
   * Load the model object.
   */
  virtual void load_version(turi::iarchive& iarc, size_t version) override = 0;

  /**
   * Train a topic model on the given dataset.
   */
  virtual void train(std::shared_ptr<sarray<flexible_type>> dataset, bool verbose) = 0;

  /**
   * Lists all the keys accessible in the "model" map.
   *
   * \returns List of keys in the model map.
   * \ref model_base for details.
   */
  std::vector<std::string> list_fields();

  /**
   * Methods with meaningful default implementations.
   * -------------------------------------------------------------------------
   */
 public:

  /**
   * Helper function for creating the appropriate ml_data from an sarray of
   * documents.
   *
   * \param dataset An SArray (of dictionary type) containing document
   * data in bag-of-words format, where each element has words as keys
   * and the corresponding counts as values.
   */
  v2::ml_data create_ml_data_using_metadata(
      std::shared_ptr<sarray<flexible_type>> dataset);
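
  // For illustration, each element of `dataset` is a dictionary (flex_dict)
  // mapping a word to its count within that document, e.g. a hypothetical
  // three-word-vocabulary document:
  //
  //   {"the": 3, "cat": 1, "sat": 1}
  //
  // create_ml_data_using_metadata then maps each word to an integer id using
  // the model's metadata.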

  /**
   * Methods available to all topic_models.
   * ----------------------------------------------------------------------
   */

  /**
   * Load a set of associations, each a (word, topic) pair whose assignment
   * should be considered fixed during learning.
   *
   * \param associations An SFrame with two columns named 'word' and 'topic'.
   */
  void set_associations(const sframe& associations);
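
  // Expected layout of `associations` (illustrative values):
  //
  //   | word     | topic |
  //   |----------|-------|
  //   | "apple"  |   3   |
  //   | "banana" |   3   |
  //
  // Each listed word then keeps the given topic assignment fixed during
  // learning.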

  /**
   * Remove the current vocabulary and topics and load these instead.
   *
   * \param word_topic_prob An SArray of vector type, where each element
   * has size num_topics. The k'th element represents the probability of
   * the corresponding word in the vocabulary under topic k.
   * \param vocabulary An SArray of string type containing the unique
   * words that should be loaded into the model. This must have the same
   * length as word_topic_prob.
   * \param weight The weight the model should give these probabilities
   * when learning. In other words, the provided word-topic probabilities
   * are multiplied by this weight before being used as count
   * matrices within the model.
   */
  void set_topics(const std::shared_ptr<sarray<flexible_type>> word_topic_prob,
                  const std::shared_ptr<sarray<flexible_type>> vocabulary,
                  size_t weight);
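
  // For example (hypothetical numbers): with weight = 1000, a word whose
  // provided probability under topic k is 0.02 contributes roughly
  // 0.02 * 1000 = 20 to the initial count for that (word, topic) pair.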

  /**
   * Get the most probable words for a given topic.
   *
   * \param topic_id The integer id of the topic. Must be in [0, num_topics),
   * where num_topics is the number of topics used to construct the
   * topic_model object.
   * \param num_words The number of words to return for the given topic.
   * \param cdf_cutoff After ordering words by probability, this will only
   * return words while the cumulative probability of the words is below
   * this cutoff value.
   *
   * \returns A pair containing the most probable words for the topic and
   * their corresponding scores, sorted by score.
   */
  std::pair<std::vector<flexible_type>, std::vector<double>>
  get_topic(size_t topic_id, size_t num_words=5, double cdf_cutoff=1.0);
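
  // Usage sketch (assumes a trained model `m`): fetch up to ten words for
  // topic 0, keeping only words whose cumulative probability stays below 0.5.
  //
  //   auto topic = m->get_topic(/*topic_id=*/0, /*num_words=*/10,
  //                             /*cdf_cutoff=*/0.5);
  //   const std::vector<flexible_type>& words  = topic.first;
  //   const std::vector<double>&        scores = topic.second;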

 public:

  /**
   * Make predictions on the given data set.
   *
   * This method closely resembles the sampler in the collapsed Gibbs
   * sampler solver found in cgs.hpp. Here, however, the word_topic_counts
   * matrix is held fixed. For each document, num_burnin iterations are
   * performed, where in each iteration we sample the topic assignments.
   * The returned predictions are probabilities, and are computed by
   * smoothing the doc_topic_counts matrix that arises from sampling.
   */
  std::shared_ptr<sarray<flexible_type>>
  predict_gibbs(std::shared_ptr<sarray<flexible_type>> data,
                size_t num_burnin);
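
  // Usage sketch (assumes a trained model `m` and an SArray of bag-of-words
  // documents `docs`): each element of the result holds the smoothed
  // per-topic probabilities for the corresponding document.
  //
  //   auto topic_probs = m->predict_gibbs(docs, /*num_burnin=*/5);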

  /**
   * Make predictions for a given data set. Returns the number of assignments
   * of each topic for each document in the dataset.
   */
  count_matrix_type predict_counts(std::shared_ptr<sarray<flexible_type>> dataset, size_t num_burnin);

  /**
   * Returns the current topics matrix as an SArray of vectors, one element
   * per word in the vocabulary.
   */
  std::shared_ptr<sarray<flexible_type>> get_topics_matrix();

  /**
   * Returns the current vocabulary of words.
   */
  std::shared_ptr<sarray<flexible_type>> get_vocabulary();

  /**
   * Compute perplexity. For more details see the docstrings for
   * the version that is not a member of the topic_model class.
   * This version is for a model's internal usage, i.e. where the two
   * count matrices are already available. Note that the first thing
   * this method does is normalize counts to be proper probabilities.
   * This is done via:
   *
   *   doc_topic_prob[d, k]  = p(topic k | document d)
   *                         = (doc_topic_count[d, k] + alpha) /
   *                           \sum_k' (doc_topic_count[d, k'] + alpha)
   *
   *   word_topic_prob[w, k] = p(word w | topic k)
   *                         = (word_topic_count[w, k] + beta) /
   *                           \sum_w' (word_topic_count[w', k] + beta)
   */
  double perplexity(std::shared_ptr<sarray<flexible_type>> documents,
                    const count_matrix_type& doc_topic_counts,
                    const count_matrix_type& word_topic_counts);
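
  // Sketch of the normalization above for a single (document d, word w,
  // topic k) entry, assuming the count matrices are shaped documents x topics
  // and words x topics respectively (illustrative, not the member
  // implementation):
  //
  //   double doc_topic_prob =
  //       (doc_topic_counts(d, k) + alpha) /
  //       (doc_topic_counts.row(d).sum() + num_topics * alpha);
  //   double word_topic_prob =
  //       (word_topic_counts(w, k) + beta) /
  //       (word_topic_counts.col(k).sum() + vocab_size * beta);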

  void init_validation(std::shared_ptr<sarray<flexible_type>> validation_train,
                       std::shared_ptr<sarray<flexible_type>> validation_test);

};

}  // namespace text
}  // namespace turi

#endif