# Source code for cvnets.anchor_generator.base_anchor_generator

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from typing import Optional, Tuple, Union

import torch
from torch import Tensor


class BaseAnchorGenerator(torch.nn.Module):
    """
    Base class for anchor generators for the task of object detection.

    Child classes implement :meth:`num_anchors_per_os` and :meth:`_generate_anchors`.
    Anchors produced by :meth:`_generate_anchors` are cached in ``self.anchors_dict``
    and re-used on later calls with the same feature-map height, width, output stride
    and device.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the module with an empty anchor cache.

        ``*args`` and ``**kwargs`` are accepted but not used by this base class.
        """
        super().__init__()
        # Maps a cache key (see `_get_anchors`) to previously generated anchors.
        self.anchors_dict = dict()

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """
        Add anchor generator-specific arguments to the parser.

        The base class adds nothing and returns ``parser`` unchanged.
        """
        return parser

    def num_anchors_per_os(self):
        """Returns anchors per output stride. Child classes must implement this function.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @torch.no_grad()
    def _generate_anchors(
        self,
        height: int,
        width: int,
        output_stride: int,
        device: Optional[str] = "cpu",
        *args,
        **kwargs
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """Generate anchors for a single feature map. Child classes must implement this.

        Args:
            height (int): Height of the feature map
            width (int): Width of the feature map
            output_stride (int): Output stride of the feature map
            device (Optional, str): Device on which anchors are requested. Defaults to cpu

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    @torch.no_grad()
    def _get_anchors(
        self,
        fm_height: int,
        fm_width: int,
        fm_output_stride: int,
        device: Optional[str] = "cpu",
        *args,
        **kwargs
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """Return anchors for the feature map, generating them only on a cache miss.

        Args:
            fm_height (int): Height of the feature map
            fm_width (int): Width of the feature map
            fm_output_stride (int): Output stride of the feature map
            device (Optional, str): Device (cpu or cuda). Defaults to cpu

        Returns:
            Whatever :meth:`_generate_anchors` returned for this configuration
            (the same cached object on repeated calls).
        """
        # The device is part of the key so that anchors cached for one device are
        # never handed back for a request that targets a different device.
        key = "h_{}_w_{}_os_{}_dev_{}".format(
            fm_height, fm_width, fm_output_stride, device
        )
        if key not in self.anchors_dict:
            # NOTE: `*args` is unpacked before the keyword arguments, so any extra
            # positional value binds to the callee's leading positional parameters.
            self.anchors_dict[key] = self._generate_anchors(
                height=fm_height,
                width=fm_width,
                output_stride=fm_output_stride,
                device=device,
                *args,
                **kwargs
            )
        return self.anchors_dict[key]

    @torch.no_grad()
    def forward(
        self,
        fm_height: int,
        fm_width: int,
        fm_output_stride: int,
        device: Optional[str] = "cpu",
        *args,
        **kwargs
    ) -> Union[Tensor, Tuple[Tensor, ...]]:
        """
        Returns anchors for the feature map

        Args:
            fm_height (int): Height of the feature map
            fm_width (int): Width of the feature map
            fm_output_stride (int): Output stride of the feature map
            device (Optional, str): Device (cpu or cuda). Defaults to cpu

        Returns:
            Tensor or Tuple of Tensors
        """
        return self._get_anchors(
            fm_height=fm_height,
            fm_width=fm_width,
            fm_output_stride=fm_output_stride,
            device=device,
            *args,
            **kwargs
        )