#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#
import torch
from torch import Tensor
"""
This file implements different conversion functions to implement `SSD <https://arxiv.org/pdf/1512.02325.pdf>`_ object detector.
Equations are written inside each function for brevity.
"""
[docs]def convert_locations_to_boxes(
pred_locations: Tensor,
anchor_boxes: Tensor,
center_variance: float,
size_variance: float,
) -> Tensor:
"""
This is an inverse of convert_boxes_to_locations function (or Eq.(2) in `SSD paper <https://arxiv.org/pdf/1512.02325.pdf>`_
Args:
pred_locations (Tensor): predicted locations from detector
anchor_boxes (Tensor): prior boxes in center form
center_variance (float): variance value for centers (c_x and c_y)
size_variance (float): variance value for size (height and width)
Returns:
predicted boxes tensor in center form
"""
# priors can have one dimension less.
if anchor_boxes.dim() + 1 == pred_locations.dim():
anchor_boxes = anchor_boxes.unsqueeze(0)
# T_w = log(g_w/d_w) / size_variance ==> g_w = exp(T_w * size_variance) * d_w
# T_h = log(g_h/d_h) / size_variance ==> g_h = exp(T_h * size_variance) * d_h
pred_size = (
torch.exp(pred_locations[..., 2:] * size_variance) * anchor_boxes[..., 2:]
)
# T_cx = ((g_cx - d_cx) / d_w) / center_variance ==> g_cx = ((T_cx * center_variance) * d_w) + d_cx
# T_cy = ((g_cy - d_cy) / d_w) / center_variance ==> g_cy = ((T_cy * center_variance) * d_h) + d_cy
pred_center = (
pred_locations[..., :2] * center_variance * anchor_boxes[..., 2:]
) + anchor_boxes[..., :2]
return torch.cat((pred_center, pred_size), dim=-1)
[docs]def convert_boxes_to_locations(
gt_boxes: Tensor, prior_boxes: Tensor, center_variance: float, size_variance: float
):
"""
This function implements Eq.(2) in the `SSD paper <https://arxiv.org/pdf/1512.02325.pdf>`_
Args:
gt_boxes (Tensor): Ground truth boxes in center form (cx, cy, w, h)
prior_boxes (Tensor): Prior boxes in center form (cx, cy, w, h)
center_variance (float): variance value for centers (c_x and c_y)
size_variance (float): variance value for size (height and width)
Returns:
boxes tensor for training
"""
# T_cx = ((g_cx - d_cx) / d_w) / center_variance; Center vairance is nothing but normalization
# T_cy = ((g_cy - d_cy) / d_h) / center_variance
# T_w = log(g_w/d_w) / size_variance and T_h = log(g_h/d_h) / size_varianc
# priors can have one dimension less
if prior_boxes.dim() + 1 == gt_boxes.dim():
prior_boxes = prior_boxes.unsqueeze(0)
target_centers = (
(gt_boxes[..., :2] - prior_boxes[..., :2]) / prior_boxes[..., 2:]
) / center_variance
target_size = torch.log(gt_boxes[..., 2:] / prior_boxes[..., 2:]) / size_variance
return torch.cat((target_centers, target_size), dim=-1)