Source code for utils.tensor_utils

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch import distributed as dist

from common import (
    DEFAULT_IMAGE_CHANNELS,
    DEFAULT_IMAGE_HEIGHT,
    DEFAULT_IMAGE_WIDTH,
    DEFAULT_VIDEO_FRAMES,
)
from utils.third_party.ddp_functional_utils import (
    all_gather as all_gather_with_backward,
)


def image_size_from_opts(opts) -> Tuple[int, int]:
    """Return the (height, width) of the input image as configured by the sampler options."""
    try:
        sampler_name = getattr(opts, "sampler.name", "variable_batch_sampler").lower()
        if sampler_name.find("var") > -1:
            im_w = getattr(opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH)
            im_h = getattr(opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT)
        elif sampler_name.find("multi") > -1:
            im_w = getattr(opts, "sampler.msc.crop_size_width", DEFAULT_IMAGE_WIDTH)
            im_h = getattr(opts, "sampler.msc.crop_size_height", DEFAULT_IMAGE_HEIGHT)
        else:
            im_w = getattr(opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH)
            im_h = getattr(opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT)
    except Exception:
        im_h = DEFAULT_IMAGE_HEIGHT
        im_w = DEFAULT_IMAGE_WIDTH
    return im_h, im_w

def video_size_from_opts(opts) -> Tuple[int, int, int]:
    """Return the (height, width, num_frames) of a video clip as configured by the sampler options."""
    try:
        sampler_name = getattr(opts, "sampler.name", "video_batch_sampler").lower()
        if sampler_name.find("var") > -1:
            im_w = getattr(opts, "sampler.vbs.crop_size_width", DEFAULT_IMAGE_WIDTH)
            im_h = getattr(opts, "sampler.vbs.crop_size_height", DEFAULT_IMAGE_HEIGHT)
            n_frames = getattr(
                opts, "sampler.vbs.num_frames_per_clip", DEFAULT_VIDEO_FRAMES
            )
        else:
            im_w = getattr(opts, "sampler.bs.crop_size_width", DEFAULT_IMAGE_WIDTH)
            im_h = getattr(opts, "sampler.bs.crop_size_height", DEFAULT_IMAGE_HEIGHT)
            n_frames = getattr(
                opts, "sampler.bs.num_frames_per_clip", DEFAULT_VIDEO_FRAMES
            )
    except Exception:
        im_h = DEFAULT_IMAGE_HEIGHT
        im_w = DEFAULT_IMAGE_WIDTH
        n_frames = DEFAULT_VIDEO_FRAMES
    return im_h, im_w, n_frames

def create_rand_tensor(
    opts, device: Optional[str] = "cpu", batch_size: Optional[int] = 1
) -> Tensor:
    """Create a random float tensor with values in [0, 1], shaped to match the
    image or video dimensions configured by the sampler options."""
    sampler = getattr(opts, "sampler.name", "batch_sampler")
    if sampler.lower().find("video") > -1:
        video_stack = getattr(opts, "video_reader.frame_stack_format", "channel_first")
        im_h, im_w, n_frames = video_size_from_opts(opts=opts)
        if video_stack == "channel_first":
            # [batch, channels, frames, height, width]
            inp_tensor = torch.randint(
                low=0,
                high=255,
                size=(batch_size, DEFAULT_IMAGE_CHANNELS, n_frames, im_h, im_w),
                device=device,
            )
        else:
            # [batch, frames, channels, height, width]
            inp_tensor = torch.randint(
                low=0,
                high=255,
                size=(batch_size, n_frames, DEFAULT_IMAGE_CHANNELS, im_h, im_w),
                device=device,
            )
    else:
        im_h, im_w = image_size_from_opts(opts=opts)
        inp_tensor = torch.randint(
            low=0,
            high=255,
            size=(batch_size, DEFAULT_IMAGE_CHANNELS, im_h, im_w),
            device=device,
        )
    inp_tensor = inp_tensor.float().div(255.0)
    return inp_tensor
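
# Usage sketch (illustrative, not part of the module): `opts` can be any object
# that answers getattr() with dotted option names, e.g. an argparse.Namespace
# populated via setattr. The assumption that DEFAULT_IMAGE_CHANNELS == 3 is
# hypothetical for this example.
#
#     import argparse
#
#     opts = argparse.Namespace()
#     setattr(opts, "sampler.name", "batch_sampler")
#     setattr(opts, "sampler.bs.crop_size_width", 224)
#     setattr(opts, "sampler.bs.crop_size_height", 224)
#
#     x = create_rand_tensor(opts, device="cpu", batch_size=2)
#     assert x.shape == (2, 3, 224, 224)  # float values in [0, 1]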

def reduce_tensor(inp_tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce `inp_tensor` across workers and return the element-wise mean."""
    size = dist.get_world_size() if dist.is_initialized() else 1
    inp_tensor_clone = inp_tensor.clone().detach()
    # dist_barrier()
    dist.all_reduce(inp_tensor_clone, op=dist.ReduceOp.SUM)
    inp_tensor_clone /= size
    return inp_tensor_clone

def reduce_tensor_sum(inp_tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce `inp_tensor` across workers and return the element-wise sum."""
    inp_tensor_clone = inp_tensor.clone().detach()
    # dist_barrier()
    dist.all_reduce(inp_tensor_clone, op=dist.ReduceOp.SUM)
    return inp_tensor_clone
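
# Usage sketch (assumed setup): aggregating per-worker metrics under
# torch.distributed. Requires an initialized process group, e.g. via
# torch.distributed.init_process_group(); `loss` and `num_correct` are
# hypothetical values produced by a training step.
#
#     avg_loss = reduce_tensor(loss.detach())  # mean over all workers
#     total_correct = reduce_tensor_sum(
#         torch.tensor(num_correct, device=loss.device)
#     )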

def all_gather_list(data: Union[List, Tensor, Dict[str, Tensor]]) -> List:
    """Gather a picklable object from every worker into a list of length world_size."""
    world_size = dist.get_world_size()
    data_list = [None] * world_size
    # dist_barrier()
    dist.all_gather_object(data_list, data)
    return data_list

def gather_all_features(features: Tensor, dim=0) -> Tensor:
    """Gather `features` from all workers (keeping the autograd graph intact)
    and concatenate them along `dim`."""
    return torch.cat(all_gather_with_backward(features), dim=dim)
    # Non-differentiable alternative, kept for reference:
    # world_size = dist.get_world_size()
    # gathered_data = [torch.zeros_like(features)] * world_size
    # dist.all_gather(gathered_data, features)
    # gathered_data = torch.cat(gathered_data, dim=dim)
    # return gathered_data
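
# Usage sketch (hypothetical training code): because the gather is
# differentiable, gradients flow back to each worker's local features, which
# matters for losses computed over the global batch. `encoder` and
# `contrastive_loss` are assumed names, not part of this module.
#
#     z = encoder(x)                         # [local_batch, embed_dim]
#     z_all = gather_all_features(z, dim=0)  # [world_size * local_batch, embed_dim]
#     loss = contrastive_loss(z, z_all)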

def tensor_to_python_float(
    inp_tensor: Union[int, float, torch.Tensor],
    is_distributed: bool,
    reduce_op: str = "mean",
) -> Union[int, float, np.ndarray]:
    """
    Convert a number or a Tensor (potentially in a distributed setting) to a
    Python float, or to a numpy array when the tensor has more than one element.

    If `is_distributed` is True, the tensor is aggregated across workers first.

    Args:
        inp_tensor: The input tensor.
        is_distributed: Indicates whether we are running in distributed mode.
        reduce_op: Reduce operation used for aggregation. If "mean", reduces
            using the mean; otherwise, uses the sum.
    """
    if is_distributed and isinstance(inp_tensor, torch.Tensor):
        if reduce_op == "mean":
            inp_tensor = reduce_tensor(inp_tensor=inp_tensor)
        else:
            inp_tensor = reduce_tensor_sum(inp_tensor=inp_tensor)

    if isinstance(inp_tensor, torch.Tensor) and inp_tensor.numel() > 1:
        # For metrics such as IoU, we get a C-dimensional tensor
        # (C = number of classes), so we convert it to a numpy array.
        return inp_tensor.cpu().numpy()
    elif hasattr(inp_tensor, "item"):
        return inp_tensor.item()
    elif isinstance(inp_tensor, (int, float)):
        return inp_tensor * 1.0
    else:
        raise NotImplementedError(
            "The data type is not supported yet in the tensor_to_python_float function"
        )
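
# Examples (non-distributed path, illustrative values):
#
#     tensor_to_python_float(torch.tensor(0.25), is_distributed=False)  # 0.25
#     tensor_to_python_float(3, is_distributed=False)                   # 3.0
#     tensor_to_python_float(torch.rand(4), is_distributed=False)       # np.ndarray, shape (4,)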

def to_numpy(img_tensor: torch.Tensor) -> np.ndarray:
    # [0, 1] --> [0, 255]
    img_tensor = torch.mul(img_tensor, 255.0)
    # BCHW --> BHWC
    img_tensor = img_tensor.permute(0, 2, 3, 1)
    img_np = img_tensor.byte().cpu().numpy()
    return img_np
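
# Example (assumed shapes): converting a normalized image batch for
# visualization; the PIL usage is illustrative, not a dependency of this module.
#
#     batch = torch.rand(2, 3, 224, 224)  # BCHW, values in [0, 1]
#     arr = to_numpy(batch)               # uint8 numpy array, shape (2, 224, 224, 3)
#     # PIL.Image.fromarray(arr[0]) yields a viewable image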