loss_fn.distillation package

Submodules

loss_fn.distillation.base_distillation module

class loss_fn.distillation.base_distillation.BaseDistillationCriteria(opts: Namespace, *args, **kwargs)

Bases: BaseCriteria

Base class for defining distillation loss functions. Subclasses must implement the _forward_distill method.

Parameters:

opts – command-line arguments
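The subclass contract can be sketched as follows. This is a minimal, hypothetical stand-in built on plain torch.nn.Module: it omits opts and BaseCriteria, and the _forward_distill signature used here (input_sample, prediction, target) is an assumption rather than the documented one.

```python
import torch
from torch import Tensor, nn


class DistillationBase(nn.Module):
    # Hypothetical stand-in for BaseDistillationCriteria; the real class also
    # takes `opts` and inherits from BaseCriteria.
    def forward(self, input_sample: Tensor, prediction: Tensor, target: Tensor) -> Tensor:
        # The base class owns the public entry point and defers the actual
        # loss computation to the subclass hook.
        return self._forward_distill(input_sample, prediction, target)

    def _forward_distill(self, input_sample: Tensor, prediction: Tensor, target: Tensor) -> Tensor:
        raise NotImplementedError("subclasses must implement _forward_distill")


class ZeroLoss(DistillationBase):
    # Trivial subclass that only demonstrates where the override goes.
    def _forward_distill(self, input_sample: Tensor, prediction: Tensor, target: Tensor) -> Tensor:
        return prediction.sum() * 0.0
```

Calling the base class directly raises NotImplementedError, while a subclass that overrides the hook can be used like any other module.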

__init__(opts: Namespace, *args, **kwargs) → None

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod add_arguments(parser: ArgumentParser) → ArgumentParser

Add criterion-specific arguments to the parser.
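The add_arguments pattern (register options on the parser, then return it) can be sketched with the standard library's argparse. The class name HypotheticalCriterion and the --example-temperature flag are hypothetical; the real option names are defined in each criterion's source.

```python
import argparse


class HypotheticalCriterion:
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Group the criterion's options so they appear together in --help output.
        group = parser.add_argument_group(title=cls.__name__)
        # Illustrative flag only; not one of the library's real options.
        group.add_argument(
            "--example-temperature",
            type=float,
            default=1.0,
            help="Softmax temperature (hypothetical option).",
        )
        # Returning the parser allows calls to be chained across criteria.
        return parser
```

Because the method returns the parser, several criteria can register their options one after another on the same ArgumentParser.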

forward(input_sample: Tensor, prediction: Mapping[str, Tensor] | Tensor, target: Tensor, *args, **kwargs) → Mapping[str, Tensor] | Tensor

Computes the distillation loss.

Parameters:
  • input_sample – Input image tensor.

  • prediction – Output of the model. It can be a tensor or a mapping of (string: Tensor). In the case of a mapping, logits is a required key.

  • target – Target label tensor containing values in the range [0, C), where C is the number of classes.

Shapes:

input_sample: The shape of the input tensor is [N, C, H, W], where C is the number of input channels.

prediction:

  • When prediction is a tensor, then shape is [N, C]

  • When prediction is a dictionary, then the shape of prediction["logits"] is [N, C]

target: The shape of target tensor is [N]

Returns:

  • A scalar loss value.
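A sketch of the input handling described above, assuming PyTorch. The helpers extract_logits and check_shapes are hypothetical; they only illustrate the documented contract (a tensor, or a mapping with a "logits" key, plus the listed shapes) and are not the library's internal code.

```python
from collections.abc import Mapping
from typing import Union

import torch
from torch import Tensor


def extract_logits(prediction: Union[Mapping, Tensor]) -> Tensor:
    # A mapping must carry the classification logits under the "logits" key.
    if isinstance(prediction, Mapping):
        if "logits" not in prediction:
            raise KeyError('prediction mapping must contain a "logits" key')
        return prediction["logits"]
    return prediction


def check_shapes(input_sample: Tensor, logits: Tensor, target: Tensor) -> None:
    # Documented shapes: input_sample [N, C, H, W], logits [N, num_classes], target [N].
    if input_sample.dim() != 4:
        raise ValueError("input_sample must have shape [N, C, H, W]")
    n = input_sample.shape[0]
    if logits.dim() != 2 or logits.shape[0] != n:
        raise ValueError("logits must have shape [N, num_classes]")
    if target.shape != (n,):
        raise ValueError("target must have shape [N]")
    # Class indices must lie in [0, num_classes).
    num_classes = logits.shape[1]
    if target.numel() > 0 and (int(target.min()) < 0 or int(target.max()) >= num_classes):
        raise ValueError("target values must lie in [0, num_classes)")
```

Raising ValueError instead of using assert keeps these checks active even when Python runs with -O.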

loss_fn.distillation.hard_distillation module

class loss_fn.distillation.hard_distillation.HardDistillationLoss(opts: Namespace, *args, **kwargs)

Bases: BaseDistillationCriteria

Hard distillation using cross-entropy for classification tasks. Given an input sample, hard labels are generated from a teacher, and the cross-entropy loss is computed between these hard labels and the student model’s output.

Parameters:

opts – command-line arguments
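The hard-label computation can be sketched as follows, assuming the teacher is an ordinary nn.Module that returns [N, C] logits. The function hard_distillation_loss is hypothetical and ignores whatever options HardDistillationLoss reads from opts.

```python
import torch
import torch.nn.functional as F
from torch import Tensor, nn


def hard_distillation_loss(student_logits: Tensor, teacher: nn.Module, input_sample: Tensor) -> Tensor:
    # No gradients are tracked through the teacher.
    with torch.no_grad():
        teacher_logits = teacher(input_sample)
    # Hard labels: index of the teacher's highest-scoring class, shape [N].
    hard_labels = teacher_logits.argmax(dim=-1)
    # Standard cross-entropy between the student's logits and the hard labels.
    return F.cross_entropy(student_logits, hard_labels)
```

Only the teacher's top-1 decision reaches the student; the teacher's confidence in other classes is discarded, which is the main difference from the soft KL variant.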

__init__(opts: Namespace, *args, **kwargs) → None

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod add_arguments(parser: ArgumentParser) → ArgumentParser

Add criterion-specific arguments to the parser.

extra_repr() → str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
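A minimal illustration of overriding nn.Module.extra_repr. TemperatureLoss and its temperature attribute are hypothetical and not part of this package.

```python
from torch import nn


class TemperatureLoss(nn.Module):
    # Hypothetical loss module; `temperature` is an illustrative attribute.
    def __init__(self, temperature: float = 1.0) -> None:
        super().__init__()
        self.temperature = temperature

    def extra_repr(self) -> str:
        # nn.Module.__repr__ places this string inside the parentheses.
        return f"temperature={self.temperature}"
```

With this override, repr(TemperatureLoss(2.0)) evaluates to TemperatureLoss(temperature=2.0).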

loss_fn.distillation.soft_kl_distillation module

class loss_fn.distillation.soft_kl_distillation.SoftKLLoss(opts: Namespace, *args, **kwargs)

Bases: BaseDistillationCriteria

Soft KL loss for classification tasks. Given an input sample, soft labels (i.e., probabilities) are generated from a teacher, and the KL-divergence loss is computed between these soft labels and the student model’s output.

Parameters:

opts – command-line arguments
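One common way to compute a soft-label KL loss is sketched below, assuming PyTorch. The temperature parameter and the T² rescaling follow the usual knowledge-distillation recipe and are assumptions here; whether SoftKLLoss uses them, and with which defaults, depends on the options it registers.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def soft_kl_loss(student_logits: Tensor, teacher_logits: Tensor, temperature: float = 1.0) -> Tensor:
    # Soften both distributions with the same (assumed) temperature.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # F.kl_div takes log-probabilities as input and probabilities as target,
    # yielding KL(teacher || student); "batchmean" divides the sum by N.
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    # Assumed T^2 scaling keeps gradient magnitudes comparable across temperatures.
    return kl * temperature ** 2
```

For example, with teacher logits [0, 0] and student logits [0, ln 3], the distributions are [1/2, 1/2] and [1/4, 3/4], and the loss at temperature 1 is 0.5 · ln(4/3).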

__init__(opts: Namespace, *args, **kwargs) → None

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod add_arguments(parser: ArgumentParser) → ArgumentParser

Add criterion-specific arguments to the parser.

extra_repr() → str

Set the extra representation of the module.

To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.

Module contents