loss_fn.utils package
Submodules
loss_fn.utils.build_helper module
loss_fn.utils.class_weighting module
- loss_fn.utils.class_weighting.compute_class_weights(target: Tensor, n_classes: int, norm_val: float = 1.1) → Tensor
Implements the class-weighting scheme defined in Section 5.2 of the ENet paper.
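For reference, ENet weights each class \(i\) by the inverse log of a shifted class frequency, where \(N_i\) is the number of target elements labelled \(i\) and ENet itself uses a shift of 1.02. This page does not state the formula, so treating norm_val as the shift constant is an assumption (it matches the 1.1 default used in the ESPNet family of models):

\[
w_i = \frac{1}{\ln\left(\mathrm{norm\_val} + p_i\right)}, \qquad p_i = \frac{N_i}{\sum_{k=0}^{C-1} N_k}, \qquad i = 0, \dots, C-1
\]

Because \(\mathrm{norm\_val} > 1\), the denominator is at least \(\ln(\mathrm{norm\_val}) > 0\), so every weight stays finite and bounded above by \(1/\ln(\mathrm{norm\_val})\) (about 10.5 for the default 1.1), even for classes that never appear in target.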
- Parameters:
target – Tensor of shape [Batch_size, *] containing values in the range [0, C).
n_classes – Integer specifying the number of classes \(C\).
norm_val – Normalization value. Defaults to 1.1. This value is chosen based on the ESPNetv2 paper (https://arxiv.org/abs/1811.11431).
- Returns:
A \(C\)-dimensional tensor containing class weights.
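The following is a minimal sketch of how such a function could be written in PyTorch. It is not the package's actual source: the name compute_class_weights_sketch is hypothetical, and it assumes the weighting \(w_i = 1/\ln(\mathrm{norm\_val} + p_i)\) with \(p_i\) taken as the empirical frequency of class \(i\) over all elements of target.

```python
import torch
from torch import Tensor


def compute_class_weights_sketch(target: Tensor, n_classes: int, norm_val: float = 1.1) -> Tensor:
    """Hypothetical ENet-style class weighting; an illustration, not loss_fn's implementation."""
    # Flatten so that any [Batch_size, *] shape is handled the same way.
    flat = target.reshape(-1).long()
    # Count the elements of each class; minlength keeps absent classes as zero counts.
    counts = torch.bincount(flat, minlength=n_classes).to(torch.float32)
    # Empirical class probabilities p_i.
    probs = counts / counts.sum()
    # Assumed weighting: w_i = 1 / ln(norm_val + p_i); rarer classes get larger weights.
    return 1.0 / torch.log(norm_val + probs)


if __name__ == "__main__":
    # Toy segmentation target with 3 classes: class 1 is the rarest, class 0 the most common.
    target = torch.tensor([[0, 0, 0, 1], [0, 0, 2, 2]])
    print(compute_class_weights_sketch(target, n_classes=3))
```

The resulting C-dimensional tensor can be passed as the weight argument of torch.nn.CrossEntropyLoss, which is the usual way such per-class weights are consumed.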