Source code for cvnets.misc.box_utils

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import torch
from torch import Tensor

"""
    This file implements different conversion functions to implement `SSD <https://arxiv.org/pdf/1512.02325.pdf>`_ object detector.
    Equations are written inside each function for brevity.
"""


[docs]def convert_locations_to_boxes( pred_locations: Tensor, anchor_boxes: Tensor, center_variance: float, size_variance: float, ) -> Tensor: """ This is an inverse of convert_boxes_to_locations function (or Eq.(2) in `SSD paper <https://arxiv.org/pdf/1512.02325.pdf>`_ Args: pred_locations (Tensor): predicted locations from detector anchor_boxes (Tensor): prior boxes in center form center_variance (float): variance value for centers (c_x and c_y) size_variance (float): variance value for size (height and width) Returns: predicted boxes tensor in center form """ # priors can have one dimension less. if anchor_boxes.dim() + 1 == pred_locations.dim(): anchor_boxes = anchor_boxes.unsqueeze(0) # T_w = log(g_w/d_w) / size_variance ==> g_w = exp(T_w * size_variance) * d_w # T_h = log(g_h/d_h) / size_variance ==> g_h = exp(T_h * size_variance) * d_h pred_size = ( torch.exp(pred_locations[..., 2:] * size_variance) * anchor_boxes[..., 2:] ) # T_cx = ((g_cx - d_cx) / d_w) / center_variance ==> g_cx = ((T_cx * center_variance) * d_w) + d_cx # T_cy = ((g_cy - d_cy) / d_w) / center_variance ==> g_cy = ((T_cy * center_variance) * d_h) + d_cy pred_center = ( pred_locations[..., :2] * center_variance * anchor_boxes[..., 2:] ) + anchor_boxes[..., :2] return torch.cat((pred_center, pred_size), dim=-1)
[docs]def convert_boxes_to_locations( gt_boxes: Tensor, prior_boxes: Tensor, center_variance: float, size_variance: float ): """ This function implements Eq.(2) in the `SSD paper <https://arxiv.org/pdf/1512.02325.pdf>`_ Args: gt_boxes (Tensor): Ground truth boxes in center form (cx, cy, w, h) prior_boxes (Tensor): Prior boxes in center form (cx, cy, w, h) center_variance (float): variance value for centers (c_x and c_y) size_variance (float): variance value for size (height and width) Returns: boxes tensor for training """ # T_cx = ((g_cx - d_cx) / d_w) / center_variance; Center vairance is nothing but normalization # T_cy = ((g_cy - d_cy) / d_h) / center_variance # T_w = log(g_w/d_w) / size_variance and T_h = log(g_h/d_h) / size_varianc # priors can have one dimension less if prior_boxes.dim() + 1 == gt_boxes.dim(): prior_boxes = prior_boxes.unsqueeze(0) target_centers = ( (gt_boxes[..., :2] - prior_boxes[..., :2]) / prior_boxes[..., 2:] ) / center_variance target_size = torch.log(gt_boxes[..., 2:] / prior_boxes[..., 2:]) / size_variance return torch.cat((target_centers, target_size), dim=-1)
[docs]def center_form_to_corner_form(boxes: Tensor) -> Tensor: """ This function convert boxes from center to corner form Args: boxes (Tensor): Boxes in center form (cx,cy,w,h) Returns: Boxes tensor in corner form (x,y,w,h) """ # x = c_x - (delta_w * 0.5), y = c_y - (delta_h * 0.5) # w = c_x + (delta_w * 0.5), h = c_y + (delta_h * 0.5) return torch.cat( ( boxes[..., :2] - (boxes[..., 2:] * 0.5), boxes[..., :2] + (boxes[..., 2:] * 0.5), ), dim=-1, )
[docs]def corner_form_to_center_form(boxes: torch.Tensor) -> torch.Tensor: """ This function converts boxes from corner to center form Args: boxes (Tensor): boxes in corner form (x, y, w, h) Returns: Boxes tensor in center form (c_x, c_y, w, h) """ # c_x = ( x + w ) * 0.5, c_y = (y + h) * 0.5 # delta_w = w - x, delta_h = h - y return torch.cat( ((boxes[..., :2] + boxes[..., 2:]) * 0.5, boxes[..., 2:] - boxes[..., :2]), dim=-1, )