# Source code for loss_fn.utils.class_weighting

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import torch
from torch import Tensor


def compute_class_weights(
    target: Tensor, n_classes: int, norm_val: float = 1.1
) -> Tensor:
    """Implementation of a class-weighting scheme, as defined in Section 5.2
    of `ENet <https://arxiv.org/pdf/1606.02147.pdf>`_ paper.

    Args:
        target: Tensor of shape [Batch_size, *] containing values in the range `[0, C)`.
        n_classes: Integer specifying the number of classes :math:`C`
        norm_val: Normalization value. Defaults to 1.1. This value is decided based on the
            `ESPNetv2 paper <https://arxiv.org/abs/1811.11431>`_.
            Link: https://github.com/sacmehta/ESPNetv2/blob/b78e323039908f31347d8ca17f49d5502ef1a594/segmentation/loadData.py#L16

    Returns:
        A :math:`C`-dimensional tensor containing class weights
    """
    # Per-class sample counts within this batch.
    class_hist = torch.histc(target.float(), bins=n_classes, min=0, max=n_classes - 1)
    # Classes absent from the batch are remembered so their weight can be
    # zeroed out below (their raw weight would otherwise be 1/log(norm_val)).
    mask_indices = class_hist == 0

    # normalize between 0 and 1 by dividing by the sum
    norm_hist = torch.div(class_hist, class_hist.sum())
    norm_hist = torch.add(norm_hist, norm_val)

    # compute class weights.
    # samples with more frequency will have less weight and vice-versa
    class_wts = torch.div(torch.ones_like(class_hist), torch.log(norm_hist))

    # mask the classes which do not have samples in the current batch
    class_wts[mask_indices] = 0.0
    return class_wts.to(device=target.device)