# Source code for metrics.confusion_mat

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from numbers import Number
from typing import Any, Dict, List, Optional, Union

import torch
from torch import Tensor

from metrics import METRICS_REGISTRY
from metrics.metric_base import BaseMetric
from utils.tensor_utils import reduce_tensor_sum


# TODO: tests
@METRICS_REGISTRY.register("confusion_matrix")
class ConfusionMatrix(BaseMetric):
    """
    Computes the confusion matrix and is based on
    `FCN <https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/score.py>`_

    The accumulated matrix is indexed ``[target_class, predicted_class]``
    (see ``update``) and is summarised into accuracy / IoU statistics by
    ``compute``.
    """

    def reset(self):
        """Clear the accumulated confusion matrix and restore the default prediction key."""
        # Allocated lazily in ``update`` once the number of classes is known.
        self.confusion_mat = None
        # Key used to extract the score tensor from a dict-valued prediction.
        self.prediction_key = "logits"

    def update(
        self,
        prediction: Union[Tensor, Dict],
        target: Union[Tensor, Dict],
        extras: Optional[Dict[str, Any]] = None,
        batch_size: Optional[int] = 1,
    ):
        """Accumulate one batch of (target, predicted-class) counts.

        Args:
            prediction: Per-class scores, or a dict holding them under
                ``self.prediction_key``. The class dimension is read from
                ``prediction.shape[1]`` and reduced with ``argmax(1)``
                (presumably shape ``(batch, n_classes, ...)`` — confirm with callers).
            target: Integer-valued class labels whose flattened length matches the
                flattened argmax of ``prediction``. Labels outside
                ``[0, n_classes)`` are ignored.
            extras: Unused; accepted for interface compatibility.
            batch_size: Unused; accepted for interface compatibility.

        Raises:
            NotImplementedError: If ``prediction`` (after key extraction) or
                ``target`` is a dict.
        """
        if isinstance(prediction, dict) and self.prediction_key in prediction:
            prediction = prediction[self.prediction_key]

        # Both inputs must be plain tensors from here on.
        if isinstance(prediction, dict) or isinstance(target, dict):
            raise NotImplementedError(
                "ConfusionMatrix does not currently support Dict predictions or targets"
            )

        n_classes = prediction.shape[1]
        if self.confusion_mat is None:
            self.confusion_mat = torch.zeros(
                (n_classes, n_classes), dtype=torch.int64, device=target.device
            )

        with torch.no_grad():
            predicted = prediction.argmax(1).flatten()
            labels = target.flatten()
            # Drop out-of-range labels (e.g. negative / "ignore" labels).
            valid = (labels >= 0) & (labels < n_classes)
            # Encode each (target, prediction) pair as one flat index so a single
            # bincount yields the whole matrix: row = target, column = prediction.
            flat_idx = n_classes * labels[valid].to(torch.int64) + predicted[valid]
            counts = torch.bincount(flat_idx, minlength=n_classes**2).reshape(
                n_classes, n_classes
            )
            if self.is_distributed:
                # Presumably sums the counts across processes — see reduce_tensor_sum.
                counts = reduce_tensor_sum(counts)
            self.confusion_mat += counts

    def compute(self) -> Union[Number, Dict[str, Union[Number, List[Number]]]]:
        """Summarise the accumulated confusion matrix.

        Returns:
            ``None`` (after printing a warning) if ``update`` was never called.
            Otherwise a dict of plain Python values with these keys:

            - ``accuracy_global``: trace / total count.
            - ``class_accuracy``: per-class recall, diag / row sum.
            - ``mean_class_accuracy``: NaN-ignoring mean of ``class_accuracy``.
            - ``iou``: per-class diag / (row sum + column sum - diag).
            - ``mean_iou``: NaN-ignoring mean of ``iou``.
            - ``confusion``: the raw integer matrix as nested lists.
        """
        if self.confusion_mat is None:
            print("Confusion matrix is None. Check code")
            return None

        hist = self.confusion_mat.float()
        diag = torch.diag(hist)
        target_totals = hist.sum(1)  # row sums: samples whose target is each class
        predicted_totals = hist.sum(0)  # column sums: samples predicted as each class

        class_accuracy = diag / target_totals
        iou = diag / (target_totals + predicted_totals - diag)

        results: Dict[str, Tensor] = {
            "accuracy_global": diag.sum() / hist.sum(),
            "class_accuracy": class_accuracy,
            # Classes with no target samples give 0/0 = NaN accuracy (and NaN IoU
            # if also never predicted). Skip them when averaging, as FCN's
            # np.nanmean does, rather than letting one absent class make the
            # mean NaN.
            "mean_class_accuracy": torch.nanmean(class_accuracy),
            "iou": iou,
            "mean_iou": torch.nanmean(iou),
            "confusion": self.confusion_mat,
        }
        # Convert every tensor to plain Python values (floats / nested lists).
        return {key: value.detach().cpu().tolist() for key, value in results.items()}