#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import argparse
from typing import Any, Dict
import torch
from torch import Tensor
from loss_fn import LOSS_REGISTRY
from loss_fn.detection.base_detection_criteria import BaseDetectionCriteria
from utils import logger
@LOSS_REGISTRY.register(name="mask_rcnn_loss", type="detection")
class MaskRCNNLoss(BaseDetectionCriteria):
    """Weighted sum of the per-head losses produced by a Mask RCNN model.

    Mask RCNN loss is computed inside the MaskRCNN model. This class is a wrapper
    that extracts the loss values of the different heads (RPN, classification,
    etc.), scales each one by its configured weight, and sums them.

    Args:
        opts: Command-line arguments. The per-head weights are read from the
            ``loss.detection.mask_rcnn_loss.*`` options (see `add_arguments`).
    """

    # Loss entries that `forward` requires in `prediction`, in the order in which
    # they are weighted and written to the output dictionary.
    _LOSS_KEYS = (
        "loss_classifier",
        "loss_box_reg",
        "loss_mask",
        "loss_objectness",
        "loss_rpn_box_reg",
    )

    def __init__(self, opts: argparse.Namespace, *args, **kwargs) -> None:
        """Read the per-head loss weights and the target device from `opts`."""
        super().__init__(opts, *args, **kwargs)
        self.classifier_weight = getattr(
            opts, "loss.detection.mask_rcnn_loss.classifier_weight"
        )
        self.box_reg_weight = getattr(
            opts, "loss.detection.mask_rcnn_loss.box_reg_weight"
        )
        self.mask_weight = getattr(opts, "loss.detection.mask_rcnn_loss.mask_weight")
        self.objectness_weight = getattr(
            opts, "loss.detection.mask_rcnn_loss.objectness_weight"
        )
        self.rpn_box_reg = getattr(opts, "loss.detection.mask_rcnn_loss.rpn_box_reg")

        # dev.device is not a part of model arguments, so tests that build `opts`
        # from the model arguments alone would fail. Fall back to CPU in that case.
        self.device = getattr(opts, "dev.device", torch.device("cpu"))

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register the loss-weight command-line options of this class.

        Args:
            parser: Parser to which the argument group is added.

        Returns:
            The same parser, for chaining.
        """
        if cls != MaskRCNNLoss:
            # Don't re-register arguments in subclasses that don't override `add_arguments()`.
            return parser

        group = parser.add_argument_group(title=cls.__name__)
        # (flag, label used in the help text) for every weight option. All options
        # share the same type and default.
        weight_options = (
            ("--loss.detection.mask-rcnn-loss.classifier-weight", "classifier"),
            ("--loss.detection.mask-rcnn-loss.box-reg-weight", "box reg"),
            ("--loss.detection.mask-rcnn-loss.mask-weight", "mask"),
            ("--loss.detection.mask-rcnn-loss.objectness-weight", "objectness"),
            ("--loss.detection.mask-rcnn-loss.rpn-box-reg", "rpn box reg."),
        )
        for flag, label in weight_options:
            group.add_argument(
                flag,
                type=float,
                default=1,
                help=f"Weight for {label} in {cls.__name__}. Defaults to 1.",
            )
        return parser

    def forward(
        self,
        input_sample: Any,
        prediction: Dict[str, Tensor],
        *args,
        **kwargs,
    ) -> Dict[str, Tensor]:
        """Compute MaskRCNN loss.

        Args:
            input_sample: Input image tensor to the model.
            prediction: Mapping of the Maskrcnn losses.

        Shapes:
            input_sample: This loss function does not care about input to the model.
            prediction: Dictionary containing scalar Mask RCNN loss values. Expected keys are
                loss_classifier, loss_box_reg, loss_mask, loss_objectness, loss_rpn_box_reg.

        Returns:
            A mapping of (string: scalar) is returned. In training mode the output contains
            the weighted value of each head loss followed by their sum, i.e. the keys
            (loss_classifier, loss_box_reg, loss_mask, loss_objectness, loss_rpn_box_reg,
            total_loss). Outside training mode only ``total_loss`` (a zero tensor on
            ``self.device``) is returned.
        """
        # NOTE(review): `self.training` is presumably the torch.nn.Module flag
        # inherited through BaseDetectionCriteria -- confirm.
        if not self.training:
            # MaskRCNN doesn't return the loss during validation. Therefore, we return 0.
            return {"total_loss": torch.tensor(0.0, device=self.device)}

        loss_keys = self._LOSS_KEYS
        required_keys = ", ".join(loss_keys)

        # NOTE(review): the two checks below rely on `logger.error` aborting
        # execution; if it only logs, the lookups further down fail on malformed
        # input -- confirm against utils.logger.
        if not isinstance(prediction, dict):
            logger.error(
                f"{self.__class__.__name__} requires prediction as a dictionary with "
                f"{required_keys} as "
                f"mandatory keys. Got: {type(prediction)}."
            )

        if not set(loss_keys).issubset(prediction.keys()):
            logger.error(
                f"{required_keys} are "
                f"required keys in {self.__class__.__name__}. Got: {prediction.keys()}"
            )

        # Weights listed in the same order as `_LOSS_KEYS`.
        loss_weights = (
            self.classifier_weight,
            self.box_reg_weight,
            self.mask_weight,
            self.objectness_weight,
            self.rpn_box_reg,
        )

        total_loss = 0.0
        mask_rcnn_losses = {}
        for loss_key, loss_wt in zip(loss_keys, loss_weights):
            weighted_loss = prediction[loss_key] * loss_wt
            # Out-of-place add: never mutates an intermediate tensor in place.
            total_loss = total_loss + weighted_loss
            mask_rcnn_losses[loss_key] = weighted_loss

        mask_rcnn_losses["total_loss"] = total_loss
        return mask_rcnn_losses