turicreate.decision_tree_classifier.DecisionTreeClassifier.predict_topk

DecisionTreeClassifier.predict_topk(dataset, output_type='probability', k=3, missing_value_action='auto')

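Return the top-k predictions for each example in the dataset, using the trained model. Predictions are returned as an SFrame with three columns: id, class, and a third column named probability, margin, or rank, depending on the output_type parameter. The result contains one row for each of the top k classes of every input example. The input dataset must have the same feature columns as the data used to train the model.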
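To make the output layout concrete, here is a minimal pure-Python sketch that does not use Turi Create. It assumes the per-class probabilities of each example are already available as a dictionary. `topk_rows` is a hypothetical helper that only mirrors the long (id, class, probability) shape of the result, not the library's implementation.

```python
from typing import Dict, Hashable, List, Tuple


def topk_rows(
    class_probs: List[Dict[Hashable, float]], k: int = 3
) -> List[Tuple[int, Hashable, float]]:
    """Flatten per-example class probabilities into (id, class, probability) rows.

    Hypothetical helper: illustrates the shape of predict_topk output only.
    """
    rows = []
    for row_id, probs in enumerate(class_probs):
        # Sort classes from most to least probable and keep the first k.
        ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
        for label, p in ranked:
            rows.append((row_id, label, p))
    return rows
```

For two examples with k=2, `topk_rows([{4: 0.90, 9: 0.06, 7: 0.03, 1: 0.01}, {1: 0.7, 3: 0.2, 2: 0.1}], k=2)` yields four rows, two per id, which is the same long layout as the SFrame in the example below.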

Parameters:
dataset : SFrame

A dataset that has the same columns that were used during training. If the target column exists in the dataset, it is ignored when making predictions.

output_type : {'probability', 'rank', 'margin'}, optional

Choose the return type of the prediction (default 'probability'):

  • probability: Probability associated with each label in the prediction.
  • rank: Rank associated with each label in the prediction.
  • margin: Margin associated with each label in the prediction.
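The rank option replaces each probability with the class's position in the sorted top-k list. The pure-Python sketch below shows that idea. It assumes rank 1 marks the most probable class; the base Turi Create actually uses may differ, and `topk_with_rank` is a hypothetical helper, not part of the library.

```python
from typing import Dict, Hashable, List, Tuple


def topk_with_rank(probs: Dict[Hashable, float], k: int = 3) -> List[Tuple[Hashable, int]]:
    """Return (class, rank) pairs for the k most probable classes.

    ASSUMPTION: rank 1 is the most probable class (illustrative only).
    """
    # Order class labels by probability, highest first, then keep k of them.
    ranked = sorted(probs, key=probs.get, reverse=True)[:k]
    # enumerate(..., start=1) turns the sort position into a 1-based rank.
    return [(label, rank) for rank, label in enumerate(ranked, start=1)]
```

If k exceeds the number of classes, this sketch simply returns every class; how the library handles that case is not stated here.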

k : int, optional

Number of classes to return for each input example (default 3).

missing_value_action : str, optional

Action to perform when missing values are encountered. Can be one of:

  • 'auto': Default. The model treats missing values as is.
  • 'impute': Proceed with evaluation by filling in the missing values with the mean of the training data. Missing values are also imputed if an entire column of data is missing during evaluation.
  • 'error': Do not proceed with evaluation and terminate with an error message.
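The 'impute' behavior amounts to mean imputation with statistics taken from the training data. A minimal pure-Python sketch, assuming numeric features stored as plain lists and dictionaries; `training_means` and `impute_missing` are hypothetical helpers, not Turi Create functions.

```python
from statistics import mean
from typing import Dict, List, Optional


def training_means(train_columns: Dict[str, List[float]]) -> Dict[str, float]:
    """Compute the mean of each numeric feature column in the training data."""
    return {name: mean(values) for name, values in train_columns.items()}


def impute_missing(
    eval_rows: List[Dict[str, Optional[float]]],
    train_means: Dict[str, float],
) -> List[Dict[str, float]]:
    """Fill missing numeric features with the training-data mean.

    A value of None, or a feature absent from the row entirely (a whole
    column missing at evaluation time), is replaced by the training mean.
    Only features seen during training are kept.
    """
    filled = []
    for row in eval_rows:
        new_row = {}
        for feature, train_mean in train_means.items():
            value = row.get(feature)
            new_row[feature] = train_mean if value is None else value
        filled.append(new_row)
    return filled
```

For training columns `{"x": [1.0, 2.0, 3.0], "y": [10.0, 20.0]}`, the means are 2.0 and 15.0, so a row `{"x": None}` with no `y` column becomes `{"x": 2.0, "y": 15.0}`.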

Returns:
out : SFrame

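An SFrame of model predictions with the columns id, class, and one of probability, rank, or margin (as selected by output_type), containing the top k classes for each input example.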

Examples

>>> # m is a trained DecisionTreeClassifier; validation_data is an SFrame
>>> pred = m.predict_topk(validation_data, k=3)
>>> pred
+--------+-------+-------------------+
| id     | class |   probability     |
+--------+-------+-------------------+
|   0    |   4   |   0.995623886585  |
|   0    |   9   |  0.0038311756216  |
|   0    |   7   | 0.000301006948575 |
|   1    |   1   |   0.928708016872  |
|   1    |   3   |  0.0440889261663  |
|   1    |   2   |  0.0176190119237  |
|   2    |   3   |   0.996967732906  |
|   2    |   2   |  0.00151345680933 |
|   2    |   7   | 0.000637513934635 |
|   3    |   1   |   0.998070061207  |
|  ...   |  ...  |        ...        |
+--------+-------+-------------------+
[35688 rows x 3 columns]
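Because the result is in long format, reducing it to a single label per example means keeping each id's most probable class. The plain-Python sketch below does this over the visible rows of the output above, copied as tuples. `top1_per_id` is a hypothetical helper, not part of Turi Create.

```python
from typing import Dict, Hashable, List, Tuple

# Visible rows of the example output above, as (id, class, probability).
rows: List[Tuple[int, int, float]] = [
    (0, 4, 0.995623886585), (0, 9, 0.0038311756216), (0, 7, 0.000301006948575),
    (1, 1, 0.928708016872), (1, 3, 0.0440889261663), (1, 2, 0.0176190119237),
    (2, 3, 0.996967732906), (2, 2, 0.00151345680933), (2, 7, 0.000637513934635),
    (3, 1, 0.998070061207),
]


def top1_per_id(rows: List[Tuple[int, Hashable, float]]) -> Dict[int, Hashable]:
    """Map each id to its highest-probability class."""
    best: Dict[int, Tuple[Hashable, float]] = {}
    for row_id, label, p in rows:
        # Keep the most probable class seen so far for this id.
        if row_id not in best or p > best[row_id][1]:
            best[row_id] = (label, p)
    return {row_id: label for row_id, (label, _) in best.items()}
```

Here `top1_per_id(rows)` gives `{0: 4, 1: 1, 2: 3, 3: 1}`. When only the single best class is needed, the model's `predict` method is the more direct route than post-processing `predict_topk`.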