# Source code for data.video_reader.pyav_reader

#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

from typing import Dict, Optional

import av
import numpy
import torch

from data.transforms.base_transforms import BaseTransformation
from data.video_reader import VIDEO_READER_REGISTRY, BaseAVReader
from utils import logger


@VIDEO_READER_REGISTRY.register(name="pyav")
class PyAVReader(BaseAVReader):
    """
    Video Reader using PyAV.
    """

    def read_video(
        self,
        av_file: str,
        stream_idx: int = 0,
        audio_sample_rate: int = -1,
        custom_frame_transforms: Optional[BaseTransformation] = None,
        video_only: bool = False,
        *args,
        **kwargs,
    ) -> Dict:
        """Decode the audio and video streams of ``av_file`` into tensors.

        Args:
            av_file: Path to the audio/video container file.
            stream_idx: Index of the audio/video stream to decode.
            audio_sample_rate: Target audio sample rate. Values <= 0 keep the
                stream's native sample rate (no resampling).
            custom_frame_transforms: Optional per-frame transforms applied to
                each decoded video frame; defaults to ``self.frame_transforms``.
            video_only: If True, audio streams are skipped entirely.

        Returns:
            Dict with keys:
                "audio": float tensor of shape (num_samples, num_channels),
                    or None if no audio stream was decoded.
                "video": float tensor of shape (num_frames, 3, H, W),
                    or None if no video stream was decoded.
                "metadata": dict with "audio_fps", "video_fps", "filename".
        """
        with av.open(av_file) as container:
            audio_frames = video_frames = None
            audio_fps = video_fps = None
            for stream in container.streams:
                if self.fast_decoding:
                    # Use multi-threading for decoding.
                    stream.thread_type = "AUTO"
                # Rewind so each stream is decoded from the beginning.
                container.seek(0)
                if stream.type == "audio":
                    # Skip audio stream if audio not required.
                    if video_only:
                        continue
                    assert (
                        audio_fps is None
                    ), f"Multiple audio streams exist in '{av_file}', while only one is expected. (stream_idx={stream_idx})"
                    assert audio_frames is None
                    audio_stream = container.streams.audio[stream_idx]
                    n_audio_channels = len(audio_stream.layout.channels)
                    resampler = None
                    if audio_sample_rate > 0:
                        resampler = av.AudioResampler(rate=audio_sample_rate)

                    audio_chunks = []

                    def _append_audio(frame) -> None:
                        # Each chunk is reshaped to (channels, samples).
                        audio_chunks.append(
                            frame.to_ndarray().reshape(n_audio_channels, -1)
                        )

                    for frame in container.decode(audio=stream_idx):
                        if resampler is not None:
                            # FIX: AudioResampler.resample() may buffer input
                            # and return zero or several frames per call;
                            # indexing [0] (as the original code did) raises
                            # IndexError when the list is empty and silently
                            # drops extra frames. Consume everything returned.
                            for resampled in resampler.resample(frame):
                                _append_audio(resampled)
                        else:
                            _append_audio(frame)
                    if resampler is not None:
                        # FIX: flush samples still buffered in the resampler.
                        for resampled in resampler.resample(None):
                            _append_audio(resampled)

                    if audio_chunks:
                        # (channels, total_samples) -> (total_samples, channels)
                        audio_frames = torch.from_numpy(
                            numpy.concatenate(audio_chunks, axis=1)
                        ).transpose(1, 0)
                    else:
                        # FIX: avoid numpy.concatenate crashing on an empty
                        # list when the stream yields no decodable samples.
                        audio_frames = torch.empty(size=(0, n_audio_channels))
                    audio_fps = (
                        audio_sample_rate
                        if audio_sample_rate > 0
                        else audio_stream.sample_rate
                    )
                elif stream.type == "video":
                    assert video_fps is None
                    assert video_frames is None
                    video_stream = container.streams.video[stream_idx]
                    video_fps = float(video_stream.base_rate)
                    frame_transforms = (
                        self.frame_transforms
                        if custom_frame_transforms is None
                        else custom_frame_transforms
                    )
                    decoded_frames = []
                    for video_frame in container.decode(video=stream_idx):
                        image = video_frame.to_image()
                        decoded_frames.append(
                            frame_transforms({"image": image})["image"]
                        )
                    if decoded_frames:
                        # FIX: the original pre-allocated a tensor sized from
                        # video_stream.frames, container metadata that can be
                        # zero or wrong for some files, leaving uninitialized
                        # rows or raising on indexing. Stack what was actually
                        # decoded instead. Cast to float to match the
                        # original's float destination tensor.
                        video_frames = torch.stack(decoded_frames).float()
                    else:
                        video_frames = torch.empty(
                            size=(0, 3, video_stream.height, video_stream.width),
                            dtype=torch.float,
                        )
            return {
                "audio": audio_frames,
                "video": video_frames,
                "metadata": {
                    "audio_fps": audio_fps,
                    "video_fps": video_fps,
                    "filename": av_file,
                },
            }