loss_fn.classification package
Submodules
loss_fn.classification.base_classification_criteria module
- class loss_fn.classification.base_classification_criteria.BaseClassificationCriteria(opts: Namespace, *args, **kwargs)
Bases: BaseCriteria
Base class for defining classification loss functions. Sub-classes must implement the forward function.
- Parameters:
opts – command-line arguments
- __init__(opts: Namespace, *args, **kwargs) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- classmethod add_arguments(parser: ArgumentParser) → ArgumentParser
Add criterion-specific arguments to the parser.
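A minimal sketch of how an add_arguments classmethod of this kind can register options on an argparse parser and return it. The class name HypotheticalCriteria and the --loss.classification.label-smoothing flag are hypothetical, chosen only for illustration; the dotted option-name style, read back with getattr, is likewise an assumption of this sketch rather than a documented option of this package.

```python
import argparse


class HypotheticalCriteria:
    """Illustrative stand-in; not a class from loss_fn.classification."""

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Group criterion-specific options so they appear together in --help.
        group = parser.add_argument_group(title=cls.__name__)
        # Hypothetical flag, for illustration only.
        group.add_argument(
            "--loss.classification.label-smoothing",
            type=float,
            default=0.0,
            help="Label smoothing factor (hypothetical option)",
        )
        # Return the parser so several components can extend it in turn.
        return parser


parser = HypotheticalCriteria.add_arguments(argparse.ArgumentParser())
opts = parser.parse_args(["--loss.classification.label-smoothing", "0.1"])
# argparse turns "-" into "_" in the destination name but keeps the dots,
# so the value has to be read back with getattr.
smoothing = getattr(opts, "loss.classification.label_smoothing")
```

The resulting opts object is an argparse.Namespace, the same type the constructors in this package accept as their opts argument.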
- forward(input_sample: Any, prediction: Dict[str, Tensor] | Tensor, target: Tensor, *args, **kwargs) → Tensor
Computes the cross entropy loss.
- Parameters:
input_sample – Input image tensor to the model.
prediction – Output of the model. It can be a tensor or a mapping of (string: Tensor). In case of a dictionary, logits is a required key.
target – Target label tensor containing values in the range [0, C), where C is the number of classes.
- Shapes:
input_sample: This loss function ignores this argument.
prediction: When prediction is a tensor, its shape is [N, C], where N is the batch size. When prediction is a dictionary, the shape of prediction["logits"] is [N, C].
target: The shape of the target tensor is [N].
- Returns:
Scalar loss value is returned.
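The input contract above (a tensor or a dictionary with a required "logits" entry, logits of shape [N, C], targets of shape [N]) can be sketched as a small PyTorch module. SketchClassificationCriteria is a hypothetical name, and the choice of torch.nn.functional.cross_entropy for the loss itself is an assumption made for illustration, not the package's implementation.

```python
from typing import Any, Dict, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class SketchClassificationCriteria(nn.Module):
    """Hypothetical stand-in illustrating the documented forward contract."""

    def forward(
        self,
        input_sample: Any,
        prediction: Union[Dict[str, Tensor], Tensor],
        target: Tensor,
        *args,
        **kwargs,
    ) -> Tensor:
        # input_sample is part of the signature but is not used by the loss.
        if isinstance(prediction, dict):
            if "logits" not in prediction:
                raise KeyError("prediction dictionary must contain a 'logits' key")
            logits = prediction["logits"]
        else:
            logits = prediction
        # Documented shapes: logits [N, C], target [N] with class indices in [0, C).
        if logits.dim() != 2 or target.shape != logits.shape[:1]:
            raise ValueError(
                f"expected logits [N, C] and target [N], got "
                f"{tuple(logits.shape)} and {tuple(target.shape)}"
            )
        # Cross-entropy averaged over the batch yields a scalar tensor.
        return F.cross_entropy(logits, target)


torch.manual_seed(0)
criterion = SketchClassificationCriteria()
logits = torch.randn(4, 10)          # N = 4 samples, C = 10 classes
target = torch.randint(0, 10, (4,))  # class indices in [0, 10)
loss_from_tensor = criterion(None, logits, target)
loss_from_dict = criterion(None, {"logits": logits}, target)
```

Passing the same logits bare or wrapped in a dictionary gives the same scalar loss, which is the point of accepting both forms.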
loss_fn.classification.binary_cross_entropy module
- class loss_fn.classification.binary_cross_entropy.BinaryCrossEntropy(opts: Namespace, *args, **kwargs)
Bases: BaseClassificationCriteria
Binary cross-entropy loss for classification tasks.
- Parameters:
opts – command-line arguments
- __init__(opts: Namespace, *args, **kwargs) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.
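One common way to apply binary cross-entropy to the [N, C] logits and [N] class labels described for the base class is to expand the labels to one-hot targets and score each of the C outputs as an independent binary decision. Whether BinaryCrossEntropy does exactly this is an assumption; binary_cross_entropy_sketch is a hypothetical helper, not the class's implementation.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def binary_cross_entropy_sketch(logits: Tensor, target: Tensor) -> Tensor:
    """Hypothetical helper, not the BinaryCrossEntropy implementation."""
    # Assumption: integer labels of shape [N] become one-hot targets of shape [N, C],
    # so each of the C outputs is scored as an independent yes/no decision.
    if target.dim() == 1:
        target = F.one_hot(target, num_classes=logits.shape[-1]).to(logits.dtype)
    # Sigmoid and BCE fused in one numerically stable call; mean over all N * C entries.
    return F.binary_cross_entropy_with_logits(logits, target)


logits = torch.tensor([[2.0, -1.0, 0.5], [0.0, 1.0, -2.0]])  # N = 2, C = 3
labels = torch.tensor([0, 1])                                # class indices in [0, 3)
loss = binary_cross_entropy_sketch(logits, labels)
```

Targets that are already float tensors of shape [N, C] (for example multi-label or soft targets) pass through unchanged in this sketch.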
loss_fn.classification.cross_entropy module
- class loss_fn.classification.cross_entropy.CrossEntropy(opts: Namespace, *args, **kwargs)
Bases: BaseClassificationCriteria
Cross-entropy loss function for image classification tasks.
- Parameters:
opts – command-line arguments
- __init__(opts: Namespace, *args, **kwargs) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.
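For reference, plain cross-entropy over [N, C] logits and [N] class indices is the negative log-softmax probability of the true class, averaged over the batch. The sketch below computes it by hand and matches torch.nn.functional.cross_entropy with default settings; cross_entropy_sketch is a hypothetical name, and any extra options the CrossEntropy class may read from opts are not modeled here.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def cross_entropy_sketch(logits: Tensor, target: Tensor) -> Tensor:
    """Hypothetical helper: mean cross-entropy for logits [N, C], target [N]."""
    # Convert unnormalized scores to log-probabilities over the C classes.
    log_probs = F.log_softmax(logits, dim=-1)
    # Select the log-probability of the true class for each sample.
    picked = log_probs.gather(dim=1, index=target.unsqueeze(1)).squeeze(1)
    # Negate and average over the batch to obtain a scalar loss.
    return -picked.mean()


torch.manual_seed(0)
logits = torch.randn(8, 5)          # N = 8 samples, C = 5 classes
target = torch.randint(0, 5, (8,))  # class indices in [0, 5)
manual_loss = cross_entropy_sketch(logits, target)
library_loss = F.cross_entropy(logits, target)
```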