turicreate.evaluation.roc_curve

turicreate.evaluation.roc_curve(targets, predictions, average=None, index_map=None)

Compute an ROC curve for the given targets and predictions. Both binary and multi-class classification are supported; for multi-class classification, one ROC curve is computed for each class.

Parameters:
targets : SArray

An SArray containing the observed values. For binary classification, the alphanumerically first category is considered the reference (negative) category.

predictions : SArray

The prediction that corresponds to each target value. This vector must have the same length as targets. Target scores can be probability estimates of the positive class, confidence values, or binary decisions. For multi-class classification, each prediction is a vector of class probabilities.

average : string, [None (default)]

Metric averaging strategies for multiclass classification. Averaging strategies can be one of the following:

  • None: No averaging is performed; a separate ROC curve is returned for each class.

index_map : dict[int], [None (default)]

For binary classification, a dictionary mapping the two target labels to either 0 (negative) or 1 (positive). For multi-class classification, a dictionary mapping potential target labels to the associated index into the vectors in predictions.

Returns:
out : SFrame

Each row represents the predictive performance when using a given cutoff threshold, where all predictions above that cutoff are considered “positive”. The following columns describe the performance:

  • threshold : The cutoff threshold for this row.
  • tpr : True positive rate, the number of true positives divided by the number of positives.
  • fpr : False positive rate, the number of false positives divided by the number of negatives.
  • p : Total number of positive values.
  • n : Total number of negative values.
  • class : Reference class for this ROC curve (multi-class classification only).
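To make the meaning of each row concrete, the following pure-Python sketch recomputes the threshold, fpr, tpr, p and n columns for the binary example data used in the Examples section. It is not Turi Create's implementation: the function name roc_curve_sketch is hypothetical, and it assumes 0/1 targets, scores in [0, 1], thresholds evenly spaced at i / num_bins, and that a prediction counts as positive only when it is strictly above the cutoff.

```python
def roc_curve_sketch(targets, scores, num_bins=100_000):
    """Return one dict per threshold with the columns threshold, fpr, tpr, p, n.

    Assumptions (not taken from Turi Create): targets are 0/1, scores lie in
    [0, 1], thresholds are i / num_bins for i = 0..num_bins, and a prediction
    counts as positive only when its score is strictly above the threshold.
    """
    p = sum(1 for t in targets if t == 1)  # total number of positives
    n = len(targets) - p                   # total number of negatives
    if p == 0 or n == 0:
        raise ValueError("targets must contain both classes")
    rows = []
    for i in range(num_bins + 1):
        threshold = i / num_bins
        # Count predictions above the cutoff, split by their true label.
        tp = sum(1 for t, s in zip(targets, scores) if t == 1 and s > threshold)
        fp = sum(1 for t, s in zip(targets, scores) if t == 0 and s > threshold)
        rows.append({"threshold": threshold, "fpr": fp / n, "tpr": tp / p,
                     "p": p, "n": n})
    return rows


# Same data as the binary example, with a coarse grid of 10 bins.
rows = roc_curve_sketch([0, 1, 1, 0], [0.1, 0.35, 0.7, 0.99], num_bins=10)
for row in rows[:3]:
    print(row)
```

With num_bins=10 the first row is threshold 0.0, fpr 1.0, tpr 1.0, p 2, n 2, the same as the first row of the binary example table; with the sketch's default of num_bins=100_000 it returns 100,001 rows, matching the row count reported there.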

See also

confusion_matrix, auc

Notes

  • For binary classification, when the target label is of type “string”, the labels are sorted alphanumerically and the largest label is chosen as the “positive” label. For example, if the classifier labels are {“cat”, “dog”}, then “dog” is chosen as the positive label. This behavior can be overridden by providing an explicit index_map.
  • For multi-class classification, when the target label is of type “string”, the probability vector is assumed to hold the class probabilities in alphanumerically sorted class order. Hence, for the probability vector [0.1, 0.2, 0.7] for a dataset with classes “cat”, “dog”, and “rat”, the 0.1 corresponds to “cat”, the 0.2 to “dog”, and the 0.7 to “rat”. This behavior can be overridden by providing an explicit index_map.
  • The ROC curve is computed using a binning approximation with 100,000 bins (a threshold step of 1e-05, hence the 100,001 rows per class in the examples below), so it is accurate only to the 5th decimal place.
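The label-ordering rules in the first two notes can be illustrated with a small sketch. The helper default_index_map is hypothetical, not part of Turi Create, and it assumes that Python's default string ordering (sorted) matches the alphanumeric ordering described above; that holds for these lowercase labels but may not for mixed-case or mixed-type labels.

```python
def default_index_map(labels):
    """Hypothetical helper: map each distinct label to its position in sorted order."""
    # Assumption: Python's default string ordering stands in for the
    # "alphanumeric" ordering described in the notes.
    return {label: i for i, label in enumerate(sorted(set(labels)))}


# Binary case: the largest label gets index 1, i.e. it is the positive class.
binary_map = default_index_map(["dog", "cat", "dog"])
positive = max(binary_map, key=binary_map.get)
print(binary_map, positive)  # {'cat': 0, 'dog': 1} dog

# Multi-class case: entry i of each probability vector belongs to the i-th sorted label.
multi_map = default_index_map(["rat", "cat", "dog"])
probs = [0.1, 0.2, 0.7]
by_label = {label: probs[i] for label, i in multi_map.items()}
print(by_label)  # {'cat': 0.1, 'dog': 0.2, 'rat': 0.7}
```

Following the index_map parameter description, making “cat” the positive label instead would mean passing a mapping such as {"cat": 1, "dog": 0}.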

References

An introduction to ROC analysis. Tom Fawcett.

Examples

>>> import turicreate
>>> targets = turicreate.SArray([0, 1, 1, 0])
>>> predictions = turicreate.SArray([0.1, 0.35, 0.7, 0.99])

# Calculate the ROC curve.
>>> roc_curve = turicreate.evaluation.roc_curve(targets, predictions)
>>> roc_curve
+-------------------+-----+-----+---+---+
|     threshold     | fpr | tpr | p | n |
+-------------------+-----+-----+---+---+
|        0.0        | 1.0 | 1.0 | 2 | 2 |
| 9.99999974738e-06 | 1.0 | 1.0 | 2 | 2 |
| 1.99999994948e-05 | 1.0 | 1.0 | 2 | 2 |
| 2.99999992421e-05 | 1.0 | 1.0 | 2 | 2 |
| 3.99999989895e-05 | 1.0 | 1.0 | 2 | 2 |
| 4.99999987369e-05 | 1.0 | 1.0 | 2 | 2 |
| 5.99999984843e-05 | 1.0 | 1.0 | 2 | 2 |
| 7.00000018696e-05 | 1.0 | 1.0 | 2 | 2 |
|  7.9999997979e-05 | 1.0 | 1.0 | 2 | 2 |
| 9.00000013644e-05 | 1.0 | 1.0 | 2 | 2 |
+-------------------+-----+-----+---+---+
[100001 rows x 5 columns]

For the multi-class setting, an ROC curve is returned for each class.

# Targets and Predictions
>>> targets     = turicreate.SArray([ 1, 0, 2, 1])
>>> predictions = turicreate.SArray([[.1, .8, 0.1],
...                                [.9, .1, 0.0],
...                                [.8, .1, 0.1],
...                                [.3, .6, 0.1]])

# Compute the ROC curve.
>>> roc_curve = turicreate.evaluation.roc_curve(targets, predictions)
>>> roc_curve
+-----------+-----+-----+---+---+-------+
| threshold | fpr | tpr | p | n | class |
+-----------+-----+-----+---+---+-------+
|    0.0    | 1.0 | 1.0 | 1 | 3 |   0   |
|   1e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   2e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   3e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   4e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   5e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   6e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   7e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   8e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
|   9e-05   | 1.0 | 1.0 | 1 | 3 |   0   |
+-----------+-----+-----+---+---+-------+
[300003 rows x 6 columns]
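The p and n columns above (1 positive and 3 negatives for class 0) are consistent with a one-vs-rest reading: for each class, that class is treated as positive, all other classes as negative, and the matching entry of each probability vector is used as the score. The sketch below assumes that reading, integer labels that index directly into each probability vector, and the strictly-above cutoff rule; one_vs_rest_row is a hypothetical helper, not part of Turi Create.

```python
# Multi-class example data from above, as plain Python lists.
targets = [1, 0, 2, 1]
predictions = [[.1, .8, 0.1],
               [.9, .1, 0.0],
               [.8, .1, 0.1],
               [.3, .6, 0.1]]


def one_vs_rest_row(targets, predictions, cls, threshold):
    """Return (fpr, tpr, p, n) for class `cls` at a single threshold.

    Assumes a one-vs-rest reading (cls is positive, every other class is
    negative), integer labels that index into each probability vector, and
    that a score counts as positive only when strictly above the threshold.
    """
    is_pos = [t == cls for t in targets]
    scores = [vec[cls] for vec in predictions]  # probability of `cls` per row
    p = sum(is_pos)
    n = len(targets) - p
    tp = sum(1 for pos, s in zip(is_pos, scores) if pos and s > threshold)
    fp = sum(1 for pos, s in zip(is_pos, scores) if not pos and s > threshold)
    return fp / n, tp / p, p, n


print(one_vs_rest_row(targets, predictions, 0, 0.0))  # (1.0, 1.0, 1, 3)
```

For class 0 at threshold 0.0 the sketch returns the same fpr, tpr, p and n as the first row of the table above.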

This metric also works for string classes.

# Targets and Predictions
>>> targets     = turicreate.SArray(["cat", "dog", "foosa", "dog"])
>>> predictions = turicreate.SArray([[.1, .8, 0.1],
...                                [.9, .1, 0.0],
...                                [.8, .1, 0.1],
...                                [.3, .6, 0.1]])

# Compute the ROC curve.
>>> roc_curve = turicreate.evaluation.roc_curve(targets, predictions)
>>> roc_curve
+-----------+-----+-----+---+---+-------+
| threshold | fpr | tpr | p | n | class |
+-----------+-----+-----+---+---+-------+
|    0.0    | 1.0 | 1.0 | 1 | 3 |  cat  |
|   1e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   2e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   3e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   4e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   5e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   6e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   7e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   8e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
|   9e-05   | 1.0 | 1.0 | 1 | 3 |  cat  |
+-----------+-----+-----+---+---+-------+
[300003 rows x 6 columns]