Source code for loss_fn.distillation.base_distillation

#  For licensing see accompanying LICENSE file.
#  Copyright (C) 2023 Apple Inc. All Rights Reserved.

import argparse
from typing import Mapping, Union

import torch
from torch import Tensor

from loss_fn import LOSS_REGISTRY, BaseCriteria
from loss_fn.utils.build_helper import build_cls_teacher_from_opts
from utils import logger


@LOSS_REGISTRY.register(name="__base__", type="distillation")
class BaseDistillationCriteria(BaseCriteria):
    """Base class for defining distillation loss functions. Sub-classes must implement
    the `_forward_distill` function.

    Args:
        opts: Command-line arguments.
    """
    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        super().__init__(opts, *args, **kwargs)
        self.teacher = build_cls_teacher_from_opts(opts=opts)
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        if cls != BaseDistillationCriteria:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser
        group = parser.add_argument_group(cls.__name__)
        group.add_argument(
            "--loss.distillation.name",
            type=str,
            default=None,
            help="Name of the loss function. Defaults to None.",
        )
        return parser
    @torch.no_grad()
    def _logits_from_teacher(self, input_sample: Tensor) -> Tensor:
        """Compute logits from the teacher given an input image tensor.

        Args:
            input_sample: Input image tensor.

        Shape:
            input_sample: Shape is [Batch size, 3, height, width].
            teacher_output or teacher_output["logits"]: Shape is
                [Batch size, number of classes].

        Returns:
            Teacher output tensor (without softmax).

        .. note::
            The output of the teacher can be a Tensor or Dict[str, Tensor]. In case
            of a dictionary, `logits` is a mandatory key.
        """
        self.teacher.eval()
        teacher_output: Union[Tensor, Mapping[str, Tensor]] = self.teacher(input_sample)
        if isinstance(teacher_output, Mapping):
            if "logits" not in teacher_output:
                logger.error(
                    "The output type of teacher is a dictionary and must contain logits as a key. "
                    f"Got: {teacher_output.keys()}"
                )
            return teacher_output["logits"]
        return teacher_output

    def _forward_distill(
        self, input_sample: Tensor, prediction: Tensor, *args, **kwargs
    ) -> Tensor:
        """Computes the distillation loss.

        Args:
            input_sample: Input image tensor.
            prediction: Student model's output.

        Shape:
            input_sample: Shape is [Batch size, 3, height, width].
            prediction: Shape is [Batch size, number of classes].

        Returns:
            A scalar loss value.
        """
        raise NotImplementedError
    def forward(
        self,
        input_sample: Tensor,
        prediction: Union[Mapping[str, Tensor], Tensor],
        target: Tensor,
        *args,
        **kwargs,
    ) -> Union[Mapping[str, Tensor], Tensor]:
        """Computes the distillation loss.

        Args:
            input_sample: Input image tensor.
            prediction: Output of the model. It can be a tensor or a mapping of
                (string: Tensor). In case of a dictionary, `logits` is a required key.
            target: Target label tensor containing values in the range `[0, C)`,
                where :math:`C` is the number of classes.

        Shape:
            input_sample: The shape of the input tensor is [N, C, H, W].
            prediction:
                * When prediction is a tensor, the shape is [N, C].
                * When prediction is a dictionary, the shape of prediction["logits"] is [N, C].
            target: The shape of the target tensor is [N].

        Returns:
            A scalar loss value.
        """
        if isinstance(prediction, Tensor):
            return self._forward_distill(
                input_sample=input_sample, prediction=prediction, *args, **kwargs
            )
        elif isinstance(prediction, Mapping):
            if "logits" not in prediction:
                logger.error(
                    f"logits is a required key in {self.__class__.__name__} when the prediction "
                    f"type is a dictionary. Got keys: {prediction.keys()}"
                )
            predicted_logits = prediction["logits"]
            # Compute the distillation loss on the student logits.
            distill_loss = self._forward_distill(
                input_sample=input_sample, prediction=predicted_logits, *args, **kwargs
            )
            return distill_loss
        else:
            logger.error(
                f"Prediction should be either a Tensor or Dict[str, Tensor]. Got: {type(prediction)}"
            )
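
# ---------------------------------------------------------------------------
# Example (not part of the original module): a minimal sketch of a concrete
# subclass. It implements `_forward_distill` as a temperature-scaled KL
# divergence between the softened teacher and student distributions. The
# registry name "soft_kl", the class name, and the fixed temperature are
# illustrative assumptions; real losses in this codebase would read such
# hyper-parameters from `opts`.
# ---------------------------------------------------------------------------
from torch.nn import functional as F


@LOSS_REGISTRY.register(name="soft_kl", type="distillation")
class SoftKLDistillationCriteria(BaseDistillationCriteria):
    """Illustrative distillation loss: KL divergence on temperature-softened logits."""

    def _forward_distill(
        self, input_sample: Tensor, prediction: Tensor, *args, **kwargs
    ) -> Tensor:
        temperature = 1.0  # Hypothetical fixed value for the sketch.
        # Teacher logits come from the base-class helper, computed without gradients.
        teacher_logits = self._logits_from_teacher(input_sample)
        # KL(teacher || student) on softened distributions, scaled by T^2
        # as in Hinton et al. (2015).
        student_log_probs = F.log_softmax(prediction / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * (
            temperature**2
        )


# A subclass registered as above would presumably be selected at training time
# via the `--loss.distillation.name soft_kl` argument added by `add_arguments`.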