Source code for cvnets.matcher_det.ssd_matcher

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from cvnets.matcher_det import MATCHER_REGISTRY, BaseMatcher
from cvnets.misc.box_utils import (
    center_form_to_corner_form,
    convert_boxes_to_locations,
    convert_locations_to_boxes,
    corner_form_to_center_form,
)
from cvnets.misc.third_party.ssd_utils import assign_priors
from utils import logger


@MATCHER_REGISTRY.register(name="ssd")
class SSDMatcher(BaseMatcher):
    """
    Label assigner for anchors that follows the
    `SSD matching process <https://arxiv.org/abs/1512.02325>`_

    Args:
        opts: command line arguments
        bg_class_id: Background class index

    Shape:
        - Input:
            - gt_boxes: Ground-truth boxes in corner form (xyxy format). Shape is :math:`(N, 4)` where :math:`N` is the number of boxes
            - gt_labels: Ground-truth box labels. Shape is :math:`(N)`
            - anchors: Anchor boxes in center form (c_x, c_y, w, h). Shape is :math:`(M, 4)` where :math:`M` is the number of anchors

        - Output:
            - matched_boxes of shape :math:`(M, 4)`
            - matched_box_labels of shape :math:`(M)`
    """

    def __init__(self, opts, bg_class_id: Optional[int] = 0, *args, **kwargs) -> None:
        # Each option is read from ``opts`` and validated right away, in the
        # same order as the command-line flags are declared, before the base
        # class is initialised.
        ctr_var = getattr(opts, "matcher.ssd.center_variance", None)
        check_variable(val=ctr_var, args_str="--matcher.ssd.center-variance")

        sz_var = getattr(opts, "matcher.ssd.size_variance", None)
        check_variable(val=sz_var, args_str="--matcher.ssd.size-variance")

        iou_thr = getattr(opts, "matcher.ssd.iou_threshold", None)
        check_variable(val=iou_thr, args_str="--matcher.ssd.iou-threshold")

        super().__init__(opts=opts, *args, **kwargs)

        self.center_variance = ctr_var
        self.size_variance = sz_var
        self.iou_threshold = iou_thr
        self.bg_class_id = bg_class_id

    def __repr__(self):
        template = "{}(center_variance={}, size_variance={}, iou_threshold={})"
        return template.format(
            self.__class__.__name__,
            self.center_variance,
            self.size_variance,
            self.iou_threshold,
        )

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Register the SSD matcher options on ``parser`` and return the parser.
        """
        group = parser.add_argument_group(title=cls.__name__)

        # (flag, default value, help text) for every float-valued option.
        option_table = (
            ("--matcher.ssd.center-variance", 0.1, "Center variance for matching"),
            ("--matcher.ssd.size-variance", 0.2, "Size variance."),
            ("--matcher.ssd.iou-threshold", 0.45, "IOU Threshold."),
        )
        for flag, default_value, help_text in option_table:
            group.add_argument(flag, type=float, default=default_value, help=help_text)

        return parser

    def __call__(
        self,
        gt_boxes: Union[np.ndarray, Tensor],
        gt_labels: Union[np.ndarray, Tensor],
        anchors: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        # NumPy inputs are turned into tensors; tensors pass through untouched.
        if isinstance(gt_boxes, np.ndarray):
            gt_boxes = torch.from_numpy(gt_boxes)
        if isinstance(gt_labels, np.ndarray):
            gt_labels = torch.from_numpy(gt_labels)

        # Anchors arrive in center form (c_x, c_y, w, h); the matching step is
        # fed corner-form (xyxy) boxes on both sides.
        priors_corner = center_form_to_corner_form(boxes=anchors)

        # gt_boxes are already in corner form (xyxy). The IoU threshold and the
        # background class id are forwarded to the matching routine.
        assigned_corner, assigned_labels = assign_priors(
            gt_boxes,
            gt_labels,
            priors_corner,
            self.iou_threshold,
            background_id=self.bg_class_id,
        )

        # Back to center form so the matched boxes and the anchors share a format.
        assigned_center = corner_form_to_center_form(assigned_corner)

        # Eq.(2) in paper https://arxiv.org/pdf/1512.02325.pdf
        regression_targets = convert_boxes_to_locations(
            gt_boxes=assigned_center,  # center form
            prior_boxes=anchors,  # center form
            center_variance=self.center_variance,
            size_variance=self.size_variance,
        )
        return regression_targets, assigned_labels

    def convert_to_boxes(
        self, pred_locations: torch.Tensor, anchors: torch.Tensor
    ) -> Tensor:
        """
        Turn predicted locations plus anchors back into corner-form boxes.
        """
        # Decoding yields boxes in center form (c_x, c_y, w, h) ...
        decoded_center = convert_locations_to_boxes(
            pred_locations=pred_locations,
            anchor_boxes=anchors,
            center_variance=self.center_variance,
            size_variance=self.size_variance,
        )
        # ... which are then expressed in corner form (xyxy) for the caller.
        return center_form_to_corner_form(decoded_center)
def check_variable(val, args_str: str):
    """Validate that an option value is set and lies strictly between 0 and 1.

    Problems are reported through ``logger.error``; the function itself
    returns nothing.

    Args:
        val: Option value to validate. ``None`` means the option was not set.
        args_str: Command-line flag name used in the error message.
    """
    if val is None:
        logger.error("{} cannot be None".format(args_str))
        # Stop here: the range comparison below would raise a TypeError on
        # ``None`` should ``logger.error`` return instead of aborting.
        return

    if not (0.0 < val < 1.0):
        logger.error(
            "The value of {} should be between 0 and 1. Got: {}".format(args_str, val)
        )