loss_fn.classification package

Submodules

loss_fn.classification.base_classification_criteria module

class loss_fn.classification.base_classification_criteria.BaseClassificationCriteria(opts: Namespace, *args, **kwargs)[source]

Bases: BaseCriteria

Base class for defining classification loss functions. Subclasses must implement the forward function.

Parameters:

opts – command-line arguments

__init__(opts: Namespace, *args, **kwargs) None[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod add_arguments(parser: ArgumentParser) ArgumentParser[source]

Add criterion-specific arguments to the parser.
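The add_arguments pattern (a classmethod that takes a parser, registers options on it, and returns it) can be sketched with the standard library alone. This is a minimal illustration, not the package's implementation: the class ExampleCriteria and the flag --loss.example.label-smoothing are hypothetical names chosen for the example.

```python
import argparse


class ExampleCriteria:
    """Hypothetical stand-in for a criterion class; not part of the package."""

    @classmethod
    def add_arguments(
        cls, parser: argparse.ArgumentParser
    ) -> argparse.ArgumentParser:
        # Group the options under the class name so --help output stays readable.
        group = parser.add_argument_group(title=cls.__name__)
        group.add_argument(
            "--loss.example.label-smoothing",  # hypothetical flag name
            type=float,
            default=0.0,
            help="Hypothetical label-smoothing factor.",
        )
        # Return the parser so calls from several classes can be chained.
        return parser


parser = argparse.ArgumentParser(description="add_arguments sketch")
parser = ExampleCriteria.add_arguments(parser)
opts = parser.parse_args(["--loss.example.label-smoothing", "0.1"])

# argparse keeps the dots and turns dashes into underscores in the attribute
# name, so the value has to be read with getattr.
value = getattr(opts, "loss.example.label_smoothing")
```

Returning the parser lets a caller pass one parser through every criterion class in turn; the resulting Namespace is the kind of object that the opts constructor parameter receives.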

forward(input_sample: Any, prediction: Dict[str, Tensor] | Tensor, target: Tensor, *args, **kwargs) Tensor[source]

Computes the cross entropy loss.

Parameters:
  • input_sample – Input image tensor to model.

  • prediction – Output of model. It can be a tensor or a mapping of (string: Tensor). In case of a dictionary, logits is a required key.

  • target – Target label tensor containing values in the range [0, C), where \(C\) is the number of classes.

Shapes:

input_sample: This loss function ignores this argument.

prediction:

  • When prediction is a tensor, then shape is [N, C]

  • When prediction is a dictionary, then the shape of prediction["logits"] is [N, C]

target: The shape of target tensor is [N]

Returns:

A scalar loss value.
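The input contract described above (prediction is either an [N, C] tensor or a dictionary with a "logits" entry, target has shape [N], and the result is a scalar) can be illustrated with a dependency-free sketch. It is not the package's code: nested Python lists stand in for tensors, the helper names extract_logits and cross_entropy are hypothetical, and averaging over the batch is an assumption about the reduction.

```python
import math
from typing import Any, Dict, List, Sequence, Union

Logits = List[List[float]]  # shape [N, C]; a stand-in for a Tensor


def extract_logits(prediction: Union[Logits, Dict[str, Logits]]) -> Logits:
    """Accept either raw logits or a mapping that holds them under 'logits'."""
    if isinstance(prediction, dict):
        if "logits" not in prediction:
            raise KeyError("prediction dictionary must contain a 'logits' key")
        return prediction["logits"]
    return prediction


def cross_entropy(
    input_sample: Any,
    prediction: Union[Logits, Dict[str, Logits]],
    target: Sequence[int],
) -> float:
    """Mean cross-entropy over the batch (the mean reduction is an assumption)."""
    del input_sample  # unused, mirroring the documented behaviour
    logits = extract_logits(prediction)
    if len(logits) != len(target):
        raise ValueError("prediction has N rows, so target must have N labels")

    total = 0.0
    for row, label in zip(logits, target):
        if not 0 <= label < len(row):
            raise ValueError("labels must lie in the range [0, C)")
        # Numerically stable log-sum-exp: subtract the row maximum first.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(v - row_max) for v in row))
        total += log_sum_exp - row[label]
    return total / len(target)


logits = [[0.0, 0.0], [0.0, 0.0]]  # N = 2 samples, C = 2 classes
loss_from_tensor = cross_entropy(None, logits, [0, 1])
loss_from_dict = cross_entropy(None, {"logits": logits}, [0, 1])
```

With uniform logits over two classes, both call styles produce the same scalar, equal to the natural logarithm of 2.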

loss_fn.classification.binary_cross_entropy module

class loss_fn.classification.binary_cross_entropy.BinaryCrossEntropy(opts: Namespace, *args, **kwargs)[source]

Bases: BaseClassificationCriteria

Binary cross-entropy loss for classification tasks.

Parameters:

opts – command-line arguments

__init__(opts: Namespace, *args, **kwargs) None[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod add_arguments(parser: ArgumentParser) ArgumentParser[source]

Add criterion-specific arguments to the parser.

extra_repr() str[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

loss_fn.classification.cross_entropy module

class loss_fn.classification.cross_entropy.CrossEntropy(opts: Namespace, *args, **kwargs)[source]

Bases: BaseClassificationCriteria

Cross-entropy loss function for image classification tasks.

Parameters:

opts – command-line arguments

__init__(opts: Namespace, *args, **kwargs) None[source]

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod add_arguments(parser: ArgumentParser) ArgumentParser[source]

Add cross-entropy criterion-specific arguments to the parser.

extra_repr() str[source]

Set the extra representation of the module

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Module contents