turicreate.topic_model.TopicModel.predict

TopicModel.predict(dataset, output_type='assignment', num_burnin=None)

Use the model to predict topics for each document. The provided dataset should be an SArray object where each element is a dict representing a single document in bag-of-words format, where keys are words and values are their corresponding counts. If dataset is an SFrame, then it must contain a single column of dict type.
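As a rough illustration of the bag-of-words format described above, the sketch below builds one dict per document using only the standard library. The helper name `to_bag_of_words` and the sample sentences are hypothetical, and the whitespace tokenization is a simplifying assumption rather than what Turi Create does internally.

```python
from collections import Counter

def to_bag_of_words(text):
    """Map each lowercase, whitespace-separated token to its count (hypothetical helper)."""
    return dict(Counter(text.lower().split()))

# Made-up sample documents, for illustration only.
raw_docs = ["the cat sat on the mat", "the dog chased the cat"]
bow_docs = [to_bag_of_words(t) for t in raw_docs]

# With Turi Create installed, the list could then be wrapped for predict():
#   import turicreate
#   dataset = turicreate.SArray(bow_docs)
```

Turi Create also provides `turicreate.text_analytics.count_words`, which produces the same kind of dict-typed SArray directly from an SArray of strings.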

The current implementation will make inferences about each document given its estimates of the topics learned when creating the model. This is done via Gibbs sampling.

Parameters:
dataset : SArray, SFrame of type dict

A set of documents to use for making predictions.

output_type : str, optional

The type of output desired. This can be either:

  • assignment: the returned values are integers in [0, num_topics)
  • probability: each returned prediction is a vector with length num_topics, where element k represents the probability that the document belongs to topic k.

num_burnin : int, optional

The number of iterations of Gibbs sampling to perform when inferring the topics for documents at prediction time. If provided, this overrides the burn-in value set during training.

Returns:
out : SArray

See also

evaluate

Notes

For each unique word w in a document d, we sample an assignment to topic k with probability proportional to

\[p(z_{dw} = k) \propto (n_{d,k} + \alpha) * \Phi_{w,k}\]

where

  • \(W\) is the size of the vocabulary,
  • \(n_{d,k}\) is the number of other times we have assigned a word in document \(d\) to topic \(k\),
  • \(\alpha\) is the document-topic smoothing hyperparameter set when the model was created,
  • \(\Phi_{w,k}\) is the probability under the model of choosing word \(w\) given the word is of topic \(k\). This is the matrix returned by calling m['topics'].

This represents a collapsed Gibbs sampler for the document assignments while we keep the topics learned during training fixed. This process is done in parallel across all documents, five times per document.
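The sampling step above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not the library's implementation: the function name `sample_topics`, the dict-based layout of `phi`, the random initialization, the weighting of each unique word by its count, and the skipping of out-of-vocabulary words are all choices made for illustration.

```python
import random

def sample_topics(doc, phi, alpha, num_topics, num_iters=5, seed=0):
    """Resample one topic per unique word of `doc` while holding `phi` fixed.

    doc: dict mapping word -> count (bag of words)
    phi: dict mapping word -> list where phi[w][k] = p(word w | topic k)
    Returns n_d, the per-topic assignment counts for this document.
    """
    rng = random.Random(seed)
    # Assumption: words unseen during training are ignored.
    words = [w for w in doc if w in phi]

    # Assumption: start from a uniformly random assignment.
    z = {w: rng.randrange(num_topics) for w in words}
    n_d = [0] * num_topics
    for w in words:
        n_d[z[w]] += doc[w]

    for _ in range(num_iters):
        for w in words:
            # Remove this word so n_d reflects only the *other* assignments.
            n_d[z[w]] -= doc[w]
            # p(z_dw = k) is proportional to (n_dk + alpha) * phi_wk
            weights = [(n_d[k] + alpha) * phi[w][k] for k in range(num_topics)]
            z[w] = rng.choices(range(num_topics), weights=weights)[0]
            n_d[z[w]] += doc[w]
    return n_d

# Toy two-topic model with made-up probabilities.
phi = {"ball": [0.9, 0.1], "game": [0.8, 0.2], "vote": [0.1, 0.9]}
doc = {"ball": 3, "game": 2, "unknown": 1}
counts = sample_topics(doc, phi, alpha=0.1, num_topics=2)
```

Because `phi` is never updated inside the loop, each document can be processed independently, which is what makes the parallel execution described above possible.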

Examples

Make predictions about which topic each document belongs to.

>>> docs = turicreate.SArray('https://static.turi.com/datasets/nips-text')
>>> m = turicreate.topic_model.create(docs)
>>> pred = m.predict(docs)

To obtain the probability of each topic for every document instead:

>>> pred = m.predict(docs, output_type='probability')
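To relate the two output types, one plausible approach is to take the index of the largest entry in each probability vector. The sketch below assumes the predictions have already been converted to plain Python lists; `most_likely_topic` is a hypothetical helper and `example_probs` holds made-up values, not real model output.

```python
def most_likely_topic(probabilities):
    """Return the index of the largest entry (the first one on ties)."""
    return max(range(len(probabilities)), key=lambda k: probabilities[k])

# Made-up probability vectors for two documents and three topics.
example_probs = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]
assignments = [most_likely_topic(p) for p in example_probs]
```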