turicreate.topic_model.perplexity

turicreate.topic_model.perplexity(test_data, predictions, topics, vocabulary)

Compute the perplexity of a set of test documents given a set of predicted topics.

Let \(\theta\) be the matrix of document-topic probabilities, where \(\theta_{ik} = p(\mbox{topic } k \mid \mbox{document } i)\). Let \(\phi\) be the matrix of term-topic probabilities, where \(\phi_{jk} = p(\mbox{word } j \mid \mbox{topic } k)\).

Then for each word in each document we compute the probability of word word_id appearing in document doc_id:

\[p(word | \theta[doc\_id,:], \phi[word\_id,:]) = \sum_k \theta[doc\_id, k] * \phi[word\_id, k]\]

We compute the log likelihood to be:

\[l(D) = \sum_{i \in D} \sum_{j \in D_i} count_{i,j} * \log \Pr(word_{i,j} | \theta, \phi)\]

and perplexity to be

\[\exp \{ - l(D) / \sum_i \sum_j count_{i,j} \}\]
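
The computation above can be written out directly with NumPy. The following is a minimal reference sketch, not the library's implementation; theta, phi, and counts are hypothetical names, assumed to be a (num_docs, num_topics) array, a (vocab_size, num_topics) array, and a list of (doc_id, word_id, count) triples for the test corpus, respectively.

import numpy as np

def perplexity_sketch(theta, phi, counts):
    # theta[d, k] = p(topic k | document d)
    # phi[w, k]   = p(word w | topic k)
    log_likelihood = 0.0
    total_words = 0
    for doc_id, word_id, count in counts:
        # p(word) = sum_k theta[doc_id, k] * phi[word_id, k]
        p_word = np.dot(theta[doc_id], phi[word_id])
        log_likelihood += count * np.log(p_word)
        total_words += count
    # perplexity = exp(-l(D) / total word count)
    return np.exp(-log_likelihood / total_words)

For instance, with a single document whose topic mixture is [0.7, 0.3] and a single word whose per-topic probabilities are [0.01, 0.05], each occurrence of that word has probability 0.7 * 0.01 + 0.3 * 0.05 = 0.022, so the perplexity is 1 / 0.022, or about 45.5.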
Parameters:
test_data : SArray of type dict or SFrame with a single column of type dict

Documents in bag-of-words format.

predictions : SArray

An SArray of vector type, where each vector contains estimates of the probability that the corresponding document belongs to each of the topics. This must have the same size as test_data; otherwise an exception occurs. This can be the output of predict() with output_type='probability', for example.

topics : SFrame

An SFrame containing two columns: 'vocabulary' and 'topic_probabilities'. The value returned by m['topics'] is a valid input for this argument, where m is a trained TopicModel.

vocabulary : SArray

An SArray of words to use. All words in test_data that are not in this vocabulary will be ignored.
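
For concreteness, here is a minimal sketch of inputs in the formats described above; the documents, words, and probabilities are made up rather than taken from a trained model.

import turicreate as tc

# Bag-of-words test documents: one {word: count} dict per document.
test_data = tc.SArray([{'cat': 2, 'dog': 1}, {'fish': 3, 'cat': 1}])

# Per-document topic probabilities for a two-topic model, one vector
# per test document (in practice the output of predict()).
predictions = tc.SArray([[0.9, 0.1], [0.2, 0.8]])

# Per-word topic probabilities and the matching vocabulary; each
# column of topic_probabilities sums to 1 over the vocabulary.
topic_probabilities = tc.SArray([[0.5, 0.1], [0.2, 0.3], [0.3, 0.6]])
vocabulary = tc.SArray(['cat', 'dog', 'fish'])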

Notes

For more details, see equations 13-16 of [PattersonTeh2013].

References

[PERP] Wikipedia, "Perplexity". https://en.wikipedia.org/wiki/Perplexity
[PattersonTeh2013] Patterson, S. and Teh, Y. W. "Stochastic Gradient Riemannian Langevin Dynamics on the Probability Simplex". NIPS, 2013.

Examples

Assuming docs is an SArray of documents in bag-of-words format:

>>> import turicreate
>>> from turicreate import topic_model
>>> train_data, test_data = turicreate.text_analytics.random_split(docs)
>>> m = topic_model.create(train_data)
>>> pred = m.predict(test_data, output_type='probability')
>>> topics = m['topics']
>>> p = topic_model.perplexity(test_data, pred,
...                            topics['topic_probabilities'],
...                            topics['vocabulary'])
>>> p
1720.7  # lower values are better