loss_fn.distillation package
Submodules
loss_fn.distillation.base_distillation module
- class loss_fn.distillation.base_distillation.BaseDistillationCriteria(opts: Namespace, *args, **kwargs)
Bases: BaseCriteria
Base class for defining distillation loss functions. Subclasses must implement the _forward_distill method.
- Parameters:
opts – command-line arguments
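The subclassing contract can be illustrated with a minimal pure-Python sketch. It assumes that the base class's public entry point simply delegates to _forward_distill. The class names DistillationCriteriaSketch and ConstantDistillation and the hook's exact signature are hypothetical, and the real class is a PyTorch module rather than a plain ABC.

```python
from abc import ABC, abstractmethod
from typing import Any, Sequence


class DistillationCriteriaSketch(ABC):
    """Hypothetical stand-in for BaseDistillationCriteria (not the library class)."""

    def __call__(self, input_sample: Any, prediction: Any, target: Sequence[int]) -> float:
        # The base class owns the public entry point (forward() in the real,
        # nn.Module-based class) and delegates the loss to the subclass hook.
        return self._forward_distill(input_sample, prediction, target)

    @abstractmethod
    def _forward_distill(self, input_sample: Any, prediction: Any, target: Sequence[int]) -> float:
        """Subclasses must return a scalar loss."""


class ConstantDistillation(DistillationCriteriaSketch):
    """Toy subclass (hypothetical): implements the hook with a constant loss."""

    def _forward_distill(self, input_sample, prediction, target):
        return 0.0


# The unimplemented abstract hook prevents instantiating the base class itself.
print(sorted(DistillationCriteriaSketch.__abstractmethods__))  # ['_forward_distill']
print(ConstantDistillation()(None, [[1.0, 2.0]], [1]))  # 0.0
```

Keeping the shared entry point in the base class means every subclass only has to supply the loss-specific computation.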
- __init__(opts: Namespace, *args, **kwargs) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- classmethod add_arguments(parser: ArgumentParser) → ArgumentParser
Add criterion-specific arguments to the parser.
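A minimal sketch of the add_arguments pattern using the standard argparse module. The class name ExampleCriterionSketch and the --distill-temperature flag are hypothetical and are not options defined by this package; the sketch only shows a classmethod that registers options and returns the same parser.

```python
import argparse


class ExampleCriterionSketch:
    """Hypothetical criterion illustrating the add_arguments classmethod pattern."""

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Put criterion-specific options in their own group so --help stays readable.
        group = parser.add_argument_group(title=cls.__name__)
        # Hypothetical flag; it is not an option defined by the library.
        group.add_argument(
            "--distill-temperature",
            type=float,
            default=1.0,
            help="Softmax temperature for distillation (hypothetical).",
        )
        # Returning the parser lets several classes chain their arguments.
        return parser


parser = ExampleCriterionSketch.add_arguments(argparse.ArgumentParser())
opts = parser.parse_args(["--distill-temperature", "2.0"])
print(opts.distill_temperature)  # 2.0
```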
- forward(input_sample: Tensor, prediction: Mapping[str, Tensor] | Tensor, target: Tensor, *args, **kwargs) → Mapping[str, Tensor] | Tensor
Computes the distillation loss.
- Parameters:
input_sample – Input image tensor.
prediction – Output of the model. It can be a tensor or a mapping of (string: Tensor). In the case of a dictionary, logits is a required key.
target – Target label tensor containing values in the range [0, C), where C is the number of classes.
- Shapes:
input_sample: The shape of the input tensor is [N, C, H, W], where C is the number of image channels.
prediction: When prediction is a tensor, its shape is [N, C]. When prediction is a dictionary, the shape of prediction["logits"] is [N, C], where C is the number of classes.
target: The shape of the target tensor is [N].
- Returns:
A scalar loss value.
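The input contract of forward (a tensor or a mapping with a required logits key, logits of shape [N, C], labels in [0, C)) can be checked with a small sketch. Nested Python lists stand in for tensors; the helper names extract_logits and check_shapes and the extra aux_out key are hypothetical, and the library's own validation, if any, may differ.

```python
import collections.abc
from typing import List, Mapping, Sequence, Union

Matrix = List[List[float]]  # stand-in for an [N, C] logits tensor


def extract_logits(prediction: Union[Matrix, Mapping[str, Matrix]]) -> Matrix:
    """Return the [N, C] logits whether prediction is a tensor or a mapping."""
    if isinstance(prediction, collections.abc.Mapping):
        if "logits" not in prediction:
            raise KeyError("prediction mapping must contain a 'logits' key")
        return prediction["logits"]
    return prediction


def check_shapes(logits: Matrix, target: Sequence[int]) -> None:
    """Validate logits of shape [N, C] and target of shape [N] with labels in [0, C)."""
    if not logits:
        raise ValueError("empty batch")
    num_classes = len(logits[0])
    if any(len(row) != num_classes for row in logits):
        raise ValueError("every logits row must have C entries")
    if len(target) != len(logits):
        raise ValueError(f"target has {len(target)} entries, expected N={len(logits)}")
    for label in target:
        if not 0 <= label < num_classes:
            raise ValueError(f"label {label} is outside [0, {num_classes})")


# A dictionary prediction may carry extra keys; only "logits" is required.
pred = {"logits": [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], "aux_out": [[0.0]]}
logits = extract_logits(pred)
check_shapes(logits, [0, 2])
print(len(logits), len(logits[0]))  # 2 3
```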
loss_fn.distillation.hard_distillation module
- class loss_fn.distillation.hard_distillation.HardDistillationLoss(opts: Namespace, *args, **kwargs)
Bases: BaseDistillationCriteria
Hard distillation using cross-entropy for classification tasks. Given an input sample, hard labels are generated by a teacher model, and the cross-entropy loss is computed between these hard labels and the student model’s output.
- Parameters:
opts – command-line arguments
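The hard-distillation computation described above can be sketched in pure Python. This assumes the teacher logits have already been computed for the input sample (the real class obtains them from a teacher model) and that the loss is averaged over the batch; the function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def hard_distillation_loss(student_logits, teacher_logits) -> float:
    """Mean cross-entropy between student logits and the teacher's argmax labels."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        # Hard label: index of the teacher's highest logit (first index on ties).
        hard_label = max(range(len(t_row)), key=t_row.__getitem__)
        # Cross-entropy for one sample is the negative log-probability of that label.
        total += -log_softmax(s_row)[hard_label]
    return total / len(student_logits)


student = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
teacher = [[5.0, 1.0, 0.0], [0.0, 0.0, 4.0]]
print(hard_distillation_loss(student, teacher))
```

In PyTorch the same quantity would typically be written as F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1)), with the teacher evaluated under torch.no_grad(); whether this class uses exactly that form is not stated here.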
- __init__(opts: Namespace, *args, **kwargs) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.
loss_fn.distillation.soft_kl_distillation module
- class loss_fn.distillation.soft_kl_distillation.SoftKLLoss(opts: Namespace, *args, **kwargs)
Bases: BaseDistillationCriteria
Soft KL loss for classification tasks. Given an input sample, soft labels (i.e., class probabilities) are generated by a teacher model, and the KL-divergence loss is computed between these soft labels and the student model’s output.
- Parameters:
opts – command-line arguments
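A pure-Python sketch of a soft KL distillation loss, computing KL(teacher || student) between softmax distributions and averaging over the batch. The temperature parameter, the direction of the KL divergence, and the batch-mean reduction are assumptions made for illustration; teacher logits are passed in precomputed, and the function names are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(row: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [x / temperature for x in row]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]


def soft_kl_loss(student_logits, teacher_logits, temperature: float = 1.0) -> float:
    """Batch-mean KL(teacher || student); temperature is an assumed option."""
    total = 0.0
    for s_row, t_row in zip(student_logits, teacher_logits):
        log_p_student = log_softmax(s_row, temperature)
        log_p_teacher = log_softmax(t_row, temperature)
        # KL(p_t || p_s) = sum_c p_t(c) * (log p_t(c) - log p_s(c))
        total += sum(
            math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_teacher, log_p_student)
        )
    return total / len(student_logits)


print(soft_kl_loss([[3.0, 2.0, 1.0]], [[1.0, 2.0, 3.0]]))
```

In PyTorch this would typically be expressed as F.kl_div(F.log_softmax(student / T, dim=-1), F.log_softmax(teacher / T, dim=-1), reduction="batchmean", log_target=True). Some implementations additionally multiply the result by T² to keep gradient magnitudes comparable across temperatures; the sketch does not.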
- __init__(opts: Namespace, *args, **kwargs) → None
Initializes internal Module state, shared by both nn.Module and ScriptModule.