Source code for metrics.image_text_retrieval

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

import argparse
from numbers import Number
from typing import Any, Dict, Tuple, Union

import torch
from torch import Tensor

from metrics import METRICS_REGISTRY
from metrics.metric_base import BaseMetric
from metrics.retrieval_cmc import DISTANCE_REGISTRY
from utils import logger
from utils.tensor_utils import all_gather_list


@METRICS_REGISTRY.register("image_text_retrieval")
class ImageTextRetrievalMetric(BaseMetric):
    """
    Computes the image-text retrieval metrics for a list of images and their
    captions using the distance between their embeddings.

    Expects predictions to contain two keys:
        image (Tensor): [batch, hidden_dim]
        text (Tensor): [batch * num_captions, hidden_dim]

    Computes the following metrics:
        image2text: recall@1, recall@5, recall@10, mean_rank, median_rank
        text2image: recall@1, recall@5, recall@10, mean_rank, median_rank

    NOTE: each image MUST have the same number of captions.
    """
    def __init__(
        self,
        image: str = "image",
        text: str = "text",
        opts: Dict[str, Any] = None,
        is_distributed: bool = False,
    ) -> None:
        # Ignore pred_key and target_key, as we won't be using them: both text
        # and image live in the prediction, so pred_key and target_key don't
        # make sense here. We could still use pred_key to support nested
        # dicts, but that didn't seem required.
        super().__init__(opts, is_distributed)
        self._image_key = image
        self._text_key = text
        distance_metric = getattr(
            opts, "stats.metrics.img_text_retrieval.distance_metric"
        )
        self.measure = DISTANCE_REGISTRY[distance_metric]
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add metric-specific arguments."""
        if cls == ImageTextRetrievalMetric:
            parser.add_argument(
                "--stats.metrics.img-text-retrieval.distance-metric",
                type=str,
                default="cosine",
                choices=list(DISTANCE_REGISTRY.keys()),
                help="Distance to use for nearest-neighbor calculation.",
            )
        return parser
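    # An illustrative aside (hypothetical snippet, not part of the module):
    # argparse derives the attribute name from the flag by stripping the
    # leading "--" and replacing hyphens with underscores while keeping the
    # dots, which is why __init__ can read the value back with getattr:
    #
    #   parser = argparse.ArgumentParser()
    #   parser = ImageTextRetrievalMetric.add_arguments(parser)
    #   opts = parser.parse_args([])  # no flags given, so the "cosine" default applies
    #   getattr(opts, "stats.metrics.img_text_retrieval.distance_metric")  # -> "cosine"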
    def reset(self) -> None:
        self._images = []
        self._texts = []
    def update(
        self,
        prediction: Union[Tensor, Dict],
        target: Union[Tensor, Dict],
        extras: Dict[str, Any],
        batch_size: int = 1,
    ) -> None:
        images = prediction[self._image_key]
        texts = prediction[self._text_key]
        if not isinstance(images, Tensor) or not isinstance(texts, Tensor):
            logger.error(
                "ImageTextRetrievalMetric only works on Tensor, got {} and {}.".format(
                    type(images), type(texts)
                )
            )
            return

        if self.is_distributed:
            # gather a list of per-rank tensors from all processes
            images = all_gather_list(images)
            texts = all_gather_list(texts)
        else:
            images = [images.detach()]
            texts = [texts.detach()]

        self._images.extend(images)
        self._texts.extend(texts)
    def get_aggregates(self) -> Tuple[Tensor, Tensor]:
        # concatenate once and cache the result, so repeated calls (or further
        # updates) keep working on a single aggregated tensor
        self._images = [torch.cat(self._images, dim=0)]
        self._texts = [torch.cat(self._texts, dim=0)]
        return self._images[0], self._texts[0]
    def _text2image(
        self, images: Tensor, texts: Tensor, num_captions: int
    ) -> torch.LongTensor:
        """
        Compute the distance between embeddings for text captions and their
        respective images.

        Args:
            images: A tensor of image embeddings. Shape: [batch, hidden_dim]
            texts: A tensor of text embeddings.
                Shape: [batch * num_captions, hidden_dim]
            num_captions: The number of captions paired with a single image.

        Returns:
            A tensor containing, for each image, the rank of its best-scoring
            caption among all captions.
        """
        ranks = torch.zeros(images.shape[0], dtype=torch.long)
        for i, image in enumerate(images):
            # [1, hidden_dim] dist [batch * num_captions, hidden_dim]
            #     --> [batch * num_captions], i.e. dists of size [num_texts]
            dists = self.measure(image.unsqueeze(0), texts).squeeze(0)
            # find the rank of the best-scoring caption among num_captions
            inds = torch.argsort(dists) // num_captions
            ranks[i] = (inds == i).nonzero()[0, 0]
        return ranks

    def _image2text(
        self, images: Tensor, texts: Tensor, num_captions: int
    ) -> torch.LongTensor:
        """
        Compute the distance between embeddings for images and their
        respective captions.

        Args:
            images: A tensor of image embeddings. Shape: [batch, hidden_dim]
            texts: A tensor of text embeddings.
                Shape: [batch * num_captions, hidden_dim]
            num_captions: The number of captions paired with a single image.

        Returns:
            A tensor containing, for each caption, the rank of its
            corresponding image among all images.
        """
        ranks = torch.zeros(texts.shape[0], dtype=torch.long)
        for i, text in enumerate(texts):
            # [1, hidden_dim] dist [batch, hidden_dim] --> [batch],
            # i.e. dists of size [num_images]
            dists = self.measure(text.unsqueeze(0), images).squeeze(0)
            # find the rank of the corresponding image
            inds = torch.argsort(dists)
            ranks[i] = (inds == (i // num_captions)).nonzero()[0, 0]
        return ranks
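    # A worked example of the ranking trick in _text2image, with made-up
    # numbers. With num_captions = 2, caption j belongs to image j // 2, so
    # integer-dividing the argsort output maps each ranked caption back to
    # its owning image:
    #
    #   dists = torch.tensor([0.9, 0.3, 0.1, 0.7])  # image 0 vs. 4 captions
    #   torch.argsort(dists)       # tensor([2, 1, 3, 0]): captions by closeness
    #   torch.argsort(dists) // 2  # tensor([1, 0, 1, 0]): owning image per caption
    #   ((torch.argsort(dists) // 2) == 0).nonzero()[0, 0]  # tensor(1)
    #
    # The best caption of image 0 thus sits at rank 1, i.e. second place.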
    def compute(self) -> Union[Number, Dict[str, Number]]:
        # image: [batch, hidden_dim]
        # text: [batch, num_captions, hidden_dim] or [batch * num_captions, hidden_dim]
        images, texts = self.get_aggregates()

        # make sure the text shape is [batch * num_captions, hidden_dim]
        if texts.dim() == 3:
            # [batch, num_captions, hidden_dim] --> [batch * num_captions, hidden_dim]
            texts = texts.reshape(-1, texts.shape[-1])

        num_images = images.shape[0]
        num_texts = texts.shape[0]
        assert num_texts % num_images == 0, "Number of captions is not consistent"
        num_captions = num_texts // num_images

        with torch.no_grad():
            i2t_ranks = self._image2text(images, texts, num_captions)
            t2i_ranks = self._text2image(images, texts, num_captions)

        return {
            "text2image": self._rank_metrics(t2i_ranks),
            "image2text": self._rank_metrics(i2t_ranks),
        }
    def _rank_metrics(self, ranks: torch.LongTensor) -> Dict[str, Number]:
        # ranks are 0-based, hence the "+ 1" when reporting mean/median rank
        return {
            "recall@1": 100 * (ranks < 1).float().mean().item(),
            "recall@5": 100 * (ranks < 5).float().mean().item(),
            "recall@10": 100 * (ranks < 10).float().mean().item(),
            "mean_rank": 1 + ranks.float().mean().item(),
            "median_rank": 1 + ranks.median().item(),
        }
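
# A minimal, self-contained sanity check of the ranking logic above. The
# cosine_distance stand-in is an assumption: it mimics what
# DISTANCE_REGISTRY["cosine"] is expected to do here, i.e. measure(a, b) with
# a: [1, hidden_dim] and b: [n, hidden_dim] returning a [1, n] matrix where
# smaller means closer; the real registry entry may differ.
if __name__ == "__main__":

    def cosine_distance(a: Tensor, b: Tensor) -> Tensor:
        # hypothetical stand-in for DISTANCE_REGISTRY["cosine"]
        a = torch.nn.functional.normalize(a, dim=-1)
        b = torch.nn.functional.normalize(b, dim=-1)
        return 1.0 - a @ b.t()

    batch, num_captions, hidden_dim = 4, 2, 8
    images = torch.randn(batch, hidden_dim)
    # give every image num_captions captions with identical embeddings, so
    # retrieval is perfect and every rank should come out as 0
    texts = images.repeat_interleave(num_captions, dim=0)

    ranks = torch.zeros(batch, dtype=torch.long)
    for i, image in enumerate(images):
        dists = cosine_distance(image.unsqueeze(0), texts).squeeze(0)
        inds = torch.argsort(dists) // num_captions
        ranks[i] = (inds == i).nonzero()[0, 0]

    assert (ranks == 0).all(), ranks
    print("recall@1 =", 100 * (ranks < 1).float().mean().item())  # 100.0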