Turi Create  4.0
cgs.hpp
1 /* Copyright © 2017 Apple Inc. All rights reserved.
2  *
3  * Use of this source code is governed by a BSD-3-clause license that can
4  * be found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause
5  */
6 #ifndef TURI_TEXT_CGS_H_
7 #define TURI_TEXT_CGS_H_
8 
9 #include <vector>
10 #include <core/export.hpp>
11 #include <toolkits/text/topic_model.hpp>
12 #include <toolkits/util/spmat.hpp>
13 
14 namespace turi {
15 
16 namespace text {
17 
/**
 * Draws one category index from a discrete (categorical) distribution.
 *
 * \param[in,out] logprobs  One entry per category, so K == logprobs.size().
 *                          Judging by the name, the entries are presumably
 *                          (log-)weights rather than normalized probabilities
 *                          (TODO confirm against the definition). On return
 *                          the vector has been overwritten with the
 *                          normalized probabilities.
 * \returns An index in the range [0, K-1].
 */
auto random_categorical(std::vector<double>& logprobs) -> size_t;
24 
25 class EXPORT cgs_topic_model : public topic_model {
26 
27  public:
28 
29  static constexpr size_t CGS_TOPIC_MODEL_VERSION = 1;
30 
31  /**
32  * Destructor. Make sure bad things don't happen
33  */
34  ~cgs_topic_model();
35 
36  /**
37  * Clone objects to a topic_model class
38  */
39  topic_model* topic_model_clone() override;
40 
41  /**
42  * Set the model options. Use the option manager to set these options. The
43  * option manager should throw errors if the options do not satisfy the option
44  * manager's conditions.
45  *
46  * \param[in] opts Options to set
47  */
48  void init_options(const std::map<std::string,flexible_type>& _opts) override;
49 
50  inline size_t get_version() const override {
51  return CGS_TOPIC_MODEL_VERSION;
52  }
53 
54  /**
55  * Turi serialization save
56  */
57  void save_impl(turi::oarchive& oarc) const override;
58 
59  /**
60  * Turi serialization save
61  */
62  void load_version(turi::iarchive& iarc, size_t version) override;
63 
64 
65  /**
66  * Train the model using collapsed Gibbs sampling.
67  *
68  * For the seminal work on this, see Griffiths, Steyvers 2004.
69  *
70  * This algorithm is a Gibbs sampler where we sample the latent topic for
71  * each word conditioned on all other latent assignments. This particular
72  * algorithm is "collapsed" in the sense that sample from the conditional
73  * distribution of a model where many of the parameters have been
74  * analytically integrated out. This has been experimentally shown to
75  * yield more (statistically) efficient samplers.
76  *
77  * A few departures from the vanilla version:
78  * - Like several other implementations, we sample a single latent
79  * assignment z_ij per (document, word, count) token, rather than
80  * a latent assignment for every occurrence of every word. This
81  * is done for speed reasons, but it no longer is the proper
82  * distribution. It would be easy to add in a loop over the counts
83  * for each (document, word) pair.
84  * - Initialization is done by "forward sampling", where we sample
85  * from the conditional distribution of each latent assignment using
86  * the assignments sampled previously. This allows us to naturally
87  * handle the case where a user has provided a set of topics for
88  * initialization purposes.
89  *
90  */
91  void train(std::shared_ptr<sarray<flexible_type>> data, bool verbose) override;
92 
93  std::shared_ptr<sarray<std::vector<size_t>>>
94  forward_sample(const v2::ml_data& d,
95  count_vector_type& topic_counts,
96  count_matrix_type& doc_topic_counts);
97 
98  std::map<std::string, size_t>
99  sample_counts(const v2::ml_data& d,
100  count_vector_type& topic_counts,
101  count_matrix_type& doc_topic_counts,
102  std::shared_ptr<sarray<std::vector<size_t>>>& assignments);
103 
104  // TODO: convert interface above to use the extensions methods here
105  BEGIN_CLASS_MEMBER_REGISTRATION("cgs_topic_model")
106  REGISTER_CLASS_MEMBER_FUNCTION(cgs_topic_model::list_fields)
108 
109 }; // kmeans_model class
110 
111 }
112 }
113 #endif
#define BEGIN_CLASS_MEMBER_REGISTRATION(python_facing_classname)
#define REGISTER_CLASS_MEMBER_FUNCTION(function,...)
The serialization input archive object which, provided with a reference to an istream, will read from the istream, providing deserialization capabilities.
Definition: iarchive.hpp:60
#define END_CLASS_MEMBER_REGISTRATION
The serialization output archive object which, provided with a reference to an ostream, will write to the ostream, providing serialization capabilities.
Definition: oarchive.hpp:80